# Source code for zsl.resource.guard

"""
:mod:`zsl.resource.guard`
-------------------------

The guard module defines tools to inject security checks into a resource.
With the help of the ``guard`` class decorator and the declarative
``ResourcePolicy`` policy class, complex resource security behaviour can be
achieved.
"""
from enum import Enum
from functools import wraps
import http.client
from typing import Any, Callable, Dict, List, Optional

from zsl.interface.resource import ResourceResult
from zsl.service.service import _TX_HOLDER_ATTRIBUTE, SessionFactory, transactional
from zsl.utils.http import get_http_status_code_value

# Status code value for HTTP 403 Forbidden. It is the default ``code`` of
# ``PolicyViolationError`` and the code used when a guard denies access.
_HTTP_STATUS_FORBIDDEN = get_http_status_code_value(http.client.FORBIDDEN)


class Access(Enum):
    """Outcome of a single policy check.

    ``ALLOW`` grants access and ``DENY`` refuses it outright. ``CONTINUE``
    leaves the decision to the next policy in the chain.
    """

    ALLOW = 1
    DENY = 2
    CONTINUE = 3
class ResourcePolicy:
    """Declarative policy class.

    Every CRUD method has its corresponding *can_method__before* and
    *can_method__after* check, where *method* is one of (*create*, *read*,
    *update*, *delete*). A *__before* check receives the CRUD method
    parameters and an *__after* check receives the CRUD method result.

    Returning ``Access.ALLOW`` grants access. Return ``Access.CONTINUE`` when
    the policy is not met but not broken either, i.e. it is not this
    policy's responsibility to decide. Returning ``Access.DENY`` or raising
    a ``PolicyViolationError`` means the policy is broken, and access is
    immediately denied.

    The default implementations of these methods look up the corresponding
    *can_method* attribute. So ``can_read = Access.ALLOW`` allows reading
    without declaring ``can_read__before`` or ``can_read__after``. The
    *default* attribute is used when the *can_method* attribute is not
    declared. For more complex logic, either can be declared as a property.
    See the examples:

    .. code-block:: python

        class SimplePolicy(ResourcePolicy):
            '''Allow read and create'''

            default = Access.ALLOW
            can_delete = Access.CONTINUE
            can_update = Access.CONTINUE


        class AdminPolicy(ResourcePolicy):
            '''Only admin has access'''

            @inject(user_service=UserService)
            def __init__(self, user_service):
                self._user_service = user_service

            @property
            def default(self):
                if self._user_service.current_user.is_admin:
                    return Access.ALLOW
                # A check must always return an ``Access`` member; any
                # other value makes the guard raise ``TypeError``.
                return Access.CONTINUE
    """

    # Decision used when the method-specific ``can_*`` attribute is absent.
    default = Access.CONTINUE

    # Optional per-method attributes consulted by ``_check_default``:
    # can_create
    # can_read
    # can_update
    # can_delete

    def can_create__before(self, *args, **kwargs):
        """Check create method before executing."""
        return self._check_default('can_create')

    def can_create__after(self, *args, **kwargs):
        """Check create method after executing."""
        return self._check_default('can_create')

    def can_read__before(self, *args, **kwargs):
        """Check read method before executing."""
        return self._check_default('can_read')

    def can_read__after(self, *args, **kwargs):
        """Check read method after executing."""
        return self._check_default('can_read')

    def can_update__before(self, *args, **kwargs):
        """Check update method before executing."""
        return self._check_default('can_update')

    def can_update__after(self, *args, **kwargs):
        """Check update method after executing."""
        return self._check_default('can_update')

    def can_delete__before(self, *args, **kwargs):
        """Check delete method before executing."""
        return self._check_default('can_delete')

    def can_delete__after(self, *args, **kwargs):
        """Check delete method after executing."""
        return self._check_default('can_delete')

    def _check_default(self, prop):
        # type: (str) -> Access
        """Return attribute ``prop``, falling back to ``default``.

        ``default`` is only evaluated when ``prop`` is not declared, so a
        ``default`` property is not invoked needlessly. Catching
        ``AttributeError`` mirrors ``getattr(obj, name, fallback)``.
        """
        try:
            return getattr(self, prop)
        except AttributeError:
            return self.default
class PolicyViolationError(Exception):
    """Raised when a resource policy is violated.

    :param message: description of the violation.
    :param code: HTTP status code carried by the error; defaults to
        403 Forbidden.
    """

    def __init__(self, message, code=_HTTP_STATUS_FORBIDDEN):
        super(PolicyViolationError, self).__init__(message)
        self.code = code
class GuardedMixin:
    """Declare the guarded CRUD methods of a resource.

    ``guard`` moves each original CRUD method to *guarded_<method>* and puts
    a security-checking wrapper in its place. The decorator mixes this class
    in automatically. Inheriting from it explicitly only helps IDEs
    recognise calls to the ``guarded_*`` methods.
    """

    def guarded_create(self, params, args, data):
        # type: (str, Dict[str, str], Dict[str, Any]) -> Dict[str, Any]
        """Placeholder for the original ``create``; returns ``None``."""
        return None

    def guarded_read(self, params, args, data):
        # type: (str, Dict[str, str], Dict[str, Any]) -> Dict[str, Any]
        """Placeholder for the original ``read``; returns ``None``."""
        return None

    def guarded_update(self, params, args, data):
        # type: (str, Dict[str, str], Dict[str, Any]) -> Dict[str, Any]
        """Placeholder for the original ``update``; returns ``None``."""
        return None

    def guarded_delete(self, params, args, data):
        # type: (str, Dict[str, str], Dict[str, Any]) -> Dict[str, Any]
        """Placeholder for the original ``delete``; returns ``None``."""
        return None
def default_error_handler(e, *_):
    # type: (PolicyViolationError, Any) -> ResourceResult
    """Turn a policy violation into an empty error response.

    Any further positional arguments, such as the previous handler's result
    and the resource, are ignored.

    :param e: the violation; its ``code`` becomes the response status.
    :return: a ``ResourceResult`` with an empty body.
    """
    status_code = e.code
    return ResourceResult(body={}, status=status_code)
class guard:
    """Guard decorator.

    This decorator wraps the CRUD methods with security checks that run
    before and after the CRUD method, so the response can be stopped or
    manipulated. The original CRUD methods are renamed to *guarded_method*,
    where *method* is one of [*create*, *read*, *update*, *delete*]. By
    using a `GuardedResource` as a base, you can still redeclare the
    *guarded_methods* and won't lose the security checks.

    It takes a list of policies, which are always checked before and after
    executing the CRUD method. A policy is met when it returns
    ``Access.ALLOW``. On ``Access.CONTINUE`` the next policy is checked. On
    ``Access.DENY`` or a raised ``PolicyViolationError``, access is
    restricted. If no policy grants access, a ``PolicyViolationError`` is
    raised and access is restricted.

    A guard can also take custom exception handlers and method wrappers to
    wrap the CRUD method in. Method wrappers are applied so that the first
    one in the list is the outermost. Exception handlers are called in
    order as ``handler(error, previous_result, resource)``. The last
    handler's return value becomes the method's result.

    .. code-block:: python

        class Policy(ResourcePolicy):
            default = Access.DENY
            can_read = Access.ALLOW  # allow only read

        @guard([Policy()])
        class GuardedResource(GuardedMixin):
            def read(self, param, args, data):
                return resources[param]

        class SpecificResource(GuardedResource):
            # override GuardedResource.read, but with its security checks
            def guarded_read(self, param, args, data):
                return specific_resources[param]
    """

    # Defaults every instance starts with; subclasses override these and
    # constructor arguments are appended after them.
    method_wrappers = []
    exception_handlers = [default_error_handler]
    # CRUD methods that get wrapped when present on the decorated class.
    resource_methods = ['create', 'read', 'update', 'delete']

    def __init__(self, policies=None, method_wrappers=None,
                 exception_handlers=None):
        # type: (Optional[List[ResourcePolicy]], Optional[List[Callable]], Optional[List[Callable]]) -> None
        """Create a guard.

        :param policies: policies checked in order for every CRUD call.
        :param method_wrappers: extra decorators appended to
            ``method_wrappers``.
        :param exception_handlers: extra handlers appended to
            ``exception_handlers``.
        """
        self.policies = list(policies) if policies else []
        # Build fresh lists so class-level defaults are never shared or
        # mutated through an instance.
        self._method_wrappers = (
            list(self.method_wrappers) + list(method_wrappers or []))
        self._exception_handlers = (
            list(self.exception_handlers) + list(exception_handlers or []))

    @staticmethod
    def _resolve_access(accesses, name, stage, deny_message):
        """Resolve a chain of policy decisions.

        :param accesses: iterable of policy results. It is consumed lazily,
            so policies after the first decisive one are never called.
        :param name: CRUD method name, used in error messages.
        :param stage: ``'before'`` or ``'after'``, used in error messages.
        :param deny_message: format string used for an explicit
            ``Access.DENY``.
        :raises PolicyViolationError: on ``Access.DENY`` or when no policy
            returns ``Access.ALLOW``.
        :raises TypeError: when a policy returns a value that is not an
            ``Access`` member.
        """
        for access in accesses:
            if access == Access.ALLOW:
                return
            if access == Access.DENY:
                raise PolicyViolationError(
                    deny_message.format(name, stage),
                    code=_HTTP_STATUS_FORBIDDEN)
            if access != Access.CONTINUE:
                raise TypeError('Access has no value {}'.format(access))
        raise PolicyViolationError(
            "Access haven't been granted for {} {}".format(name, stage),
            code=_HTTP_STATUS_FORBIDDEN)

    @staticmethod
    def _check_before_policies(res, name, *args, **kwargs):
        """Run the policies' before checks for CRUD method ``name``."""
        accesses = (_call_before(policy, name)(*args, **kwargs)
                    for policy in res._guard_policies)
        guard._resolve_access(
            accesses, name, 'before', 'Access denied for {} {}')

    @staticmethod
    def _check_after_policies(res, name, result):
        """Run the policies' after checks on the CRUD method ``result``."""
        accesses = (_call_after(policy, name)(result)
                    for policy in res._guard_policies)
        # The stage is reported as 'after' (it used to say 'before').
        guard._resolve_access(
            accesses, name, 'after', 'Policy violation for {} {}')

    def _wrap(self, method):
        # type: (Callable) -> Callable
        """Return ``method`` wrapped with policy checks and method wrappers.

        The returned callable expects the resource instance as its first
        positional argument. It delegates the work to the resource's
        *guarded_<name>* method, so subclasses overriding that method stay
        guarded.
        """
        name = method.__name__

        @wraps(method)
        def wrapped(*args, **kwargs):
            res, args = args[0], args[1:]
            try:
                self._check_before_policies(res, name, *args, **kwargs)
                rv = _guarded_method(res, name)(*args, **kwargs)
                self._check_after_policies(res, name, rv)
            except PolicyViolationError as e:
                rv = self._handle_exception(e, res)
            return rv

        # Apply in reverse so the first configured wrapper is the outermost.
        for mw in reversed(self._method_wrappers):
            wrapped = mw(wrapped)
        return wrapped

    def _handle_exception(self, error, resource):
        """Feed ``error`` through the exception handlers in order.

        Each handler receives the error, the previous handler's return value
        (``None`` for the first) and the resource. The last return value is
        returned.
        """
        rv = None
        for handler in self._exception_handlers:
            rv = handler(error, rv, resource)
        return rv

    def __call__(self, cls):
        """Guard the CRUD methods of ``cls`` and return the guarded class."""
        if hasattr(cls, '_guard_policies'):
            # ``cls`` derives from an already guarded class whose CRUD
            # methods are wrapped, so only extend the policy chain. Build a
            # new list rather than mutating ``self.policies``; otherwise
            # reusing this guard instance would accumulate policies.
            cls._guard_policies = (
                self.policies + list(cls._guard_policies))
            return cls

        cls._guard_policies = list(self.policies)
        for method_name in self.resource_methods:
            if hasattr(cls, method_name):
                method = getattr(cls, method_name)
                setattr(cls, method_name, self._wrap(method))
                setattr(cls, _guarded_name(method_name), method)

        if issubclass(cls, GuardedMixin):
            return cls
        # Keep the identity metadata of the decorated class. Without it the
        # new type would report this module as its ``__module__`` and lose
        # its docstring and qualified name.
        return type(cls.__name__, (cls, GuardedMixin), {
            '__module__': cls.__module__,
            '__qualname__': cls.__qualname__,
            '__doc__': cls.__doc__,
        })
def transactional_error_handler(e: Any, rv: Any, _: SessionFactory) -> Any:
    """Escalate a policy violation so the transaction can be rolled back.

    The result built by the preceding handlers (*rv*) is wrapped in a
    ``_TransactionalPolicyViolationError``, which is raised and chained to
    the original error *e*. ``_nested_transactional`` catches it, rolls back
    and returns the wrapped result.
    """
    escalated = _TransactionalPolicyViolationError(rv)
    raise escalated from e
def _nested_transactional(fn):
    # type: (Callable) -> Callable
    """Decorate *fn* so a transactional policy violation is rolled back.

    If *fn* raises ``_TransactionalPolicyViolationError``, the transaction
    holder stored on the instance under ``_TX_HOLDER_ATTRIBUTE`` is rolled
    back. The error's ``result`` is then returned instead of the exception
    propagating.
    """
    @wraps(fn)
    def wrapped(self, *args, **kwargs):
        # type: (SessionFactory) -> Any
        try:
            return fn(self, *args, **kwargs)
        except _TransactionalPolicyViolationError as violation:
            tx_holder = getattr(self, _TX_HOLDER_ATTRIBUTE)
            tx_holder.rollback()
            return violation.result

    return wrapped
class transactional_guard(guard):
    """Security guard for ``ModelResource``.

    Adds the ``transactional`` and ``_nested_transactional`` method wrappers
    and appends ``transactional_error_handler`` to the inherited exception
    handlers. As a result, a ``PolicyViolationError`` rolls back the
    transaction.
    """

    # The first wrapper is the outermost:
    # transactional(_nested_transactional(...)).
    method_wrappers = [transactional, _nested_transactional]
    # The inherited handlers run first; transactional_error_handler then
    # re-raises their result for _nested_transactional to roll back.
    exception_handlers = [*guard.exception_handlers,
                          transactional_error_handler]
class _TransactionalPolicyViolationError(PolicyViolationError):
    """Carries an already built error result out of a transactional method.

    Raised by ``transactional_error_handler`` and caught by
    ``_nested_transactional``. The catcher rolls back the transaction and
    returns :attr:`result` as the method's response.
    """

    def __init__(self, result):
        # type: (ResourceResult) -> None
        self.result = result
        super().__init__(
            result.body,
            result.status
        )


def _guarded_method(res, method_name):
    # type: (object, str) -> Callable
    """Return the guarded method from CRUD name"""
    return getattr(res, _guarded_name(method_name))


def _guarded_name(method_name):
    # type: (str) -> str
    """Return name for guarded CRUD method.

    >>> _guarded_name('read')
    'guarded_read'
    """
    return 'guarded_' + method_name


def _before_name(method_name):
    # type: (str) -> str
    """Return the name of the before check method.

    >>> _before_name('read')
    'can_read__before'
    """
    return 'can_' + method_name + '__before'


def _after_name(method_name):
    # type: (str) -> str
    """Return the name of after check method.

    >>> _after_name('read')
    'can_read__after'
    """
    return 'can_' + method_name + '__after'


def _call_before(policy, method_name):
    # type: (ResourcePolicy, str) -> Callable
    """Return the before check method.

    >>> p = ResourcePolicy()
    >>> _call_before(p, 'read') == p.can_read__before
    True
    """
    return getattr(policy, _before_name(method_name))


def _call_after(policy, method_name):
    # type: (ResourcePolicy, str) -> Callable
    """Return the after check method.

    >>> p = ResourcePolicy()
    >>> _call_after(p, 'read') == p.can_read__after
    True
    """
    return getattr(policy, _after_name(method_name))