diff --git a/aiohttp_security/api.py b/aiohttp_security/api.py index ea184296..7ab241d9 100644 --- a/aiohttp_security/api.py +++ b/aiohttp_security/api.py @@ -64,10 +64,14 @@ async def authorized_userid(request: web.Request) -> Optional[str]: return user_id -async def permits(request: web.Request, permission: Union[str, enum.Enum], - context: Any = None) -> bool: +def _validate_permission(permission: Union[str, enum.Enum]) -> None: if not permission or not isinstance(permission, (str, enum.Enum)): raise ValueError("Permission should be a str or enum value.") + + +async def permits(request: web.Request, permission: Union[str, enum.Enum], + context: Any = None) -> bool: + _validate_permission(permission) identity_policy: _AIP = request.config_dict.get(IDENTITY_KEY) autz_policy: _AAP = request.config_dict.get(AUTZ_KEY) if identity_policy is None or autz_policy is None: @@ -104,13 +108,14 @@ async def check_authorized(request: web.Request) -> str: async def check_permission(request: web.Request, permission: Union[str, enum.Enum], context: Any = None) -> None: - """Checker that passes only to authoraised users with given permission. + """Checker that passes only to authorized users with given permission. If user is not authorized - raises HTTPUnauthorized, if user is authorized and does not have permission - raises HTTPForbidden. """ + _validate_permission(permission) await check_authorized(request) allowed = await permits(request, permission, context) if not allowed: diff --git a/tests/test_no_auth.py b/tests/test_no_auth.py index 9611a080..677623f2 100644 --- a/tests/test_no_auth.py +++ b/tests/test_no_auth.py @@ -1,6 +1,7 @@ +import pytest from aiohttp import web -from aiohttp_security import authorized_userid, permits +from aiohttp_security import authorized_userid, check_permission, permits async def test_authorized_userid(aiohttp_client): @@ -33,3 +34,17 @@ async def check(request): client = await aiohttp_client(app) resp = await client.get('/') assert 200 == resp.status + + +async def test_check_permission_rejects_invalid_value(aiohttp_client): + + async def check(request): + with pytest.raises(ValueError): + await check_permission(request, None) # type: ignore[arg-type] + return web.Response() + + app = web.Application() + app.router.add_route('GET', '/', check) + client = await aiohttp_client(app) + resp = await client.get('/') + assert 200 == resp.status