From 0bcca68ae975f7853dc1e5b9e36c2b9feb4e39f9 Mon Sep 17 00:00:00 2001 From: c4ffein <c4ffein@gmail.com> Date: Tue, 20 Aug 2024 01:41:47 +0200 Subject: [PATCH 1/4] PoC - named view functions --- ninja/operation.py | 81 +++++++++++++++++++++----------------- tests/test_api_instance.py | 14 ++++++- 2 files changed, 56 insertions(+), 39 deletions(-) diff --git a/ninja/operation.py b/ninja/operation.py index 455a16689..74251e5af 100644 --- a/ninja/operation.py +++ b/ninja/operation.py @@ -21,12 +21,17 @@ from ninja.constants import NOT_SET, NOT_SET_TYPE from ninja.errors import AuthenticationError, ConfigError, Throttled, ValidationError from ninja.params.models import TModels -from ninja.schema import Schema, pydantic_version +from ninja.schema import Schema from ninja.signature import ViewSignature, is_async from ninja.throttling import BaseThrottle from ninja.types import DictStrAny from ninja.utils import check_csrf, is_async_callable +try: + from asgiref.sync import sync_to_async +except ModuleNotFoundError: + pass + if TYPE_CHECKING: from ninja import NinjaAPI, Router # pragma: no cover @@ -454,44 +459,46 @@ def set_api_instance(self, api: "NinjaAPI", router: "Router") -> None: op.set_api_instance(api, router) def get_view(self) -> Callable: - view: Callable - if self.is_async: - view = self._async_view - else: - view = self._sync_view - - view.__func__.csrf_exempt = True # type: ignore - return view - - def _sync_view(self, request: HttpRequest, *a: Any, **kw: Any) -> HttpResponseBase: - operation = self._find_operation(request) - if operation is None: - return self._not_allowed() - return operation.run(request, *a, **kw) - - async def _async_view( - self, request: HttpRequest, *a: Any, **kw: Any - ) -> HttpResponseBase: - from asgiref.sync import sync_to_async - - operation = self._find_operation(request) - if operation is None: - return self._not_allowed() - if operation.is_async: - return await cast(AsyncOperation, operation).run(request, *a, **kw) - return await sync_to_async(operation.run)(request, *a, **kw) + is_async = self.is_async + operations = self.operations + allowed_methods = {method for op in operations for method in op.methods} + response_not_allowed = ( + HttpResponseNotAllowed(allowed_methods, content=b"Method not allowed"), + ) - def _find_operation(self, request: HttpRequest) -> Optional[Operation]: - for op in self.operations: - if request.method in op.methods: - return op - return None + def sync_view(request: HttpRequest, *a: Any, **kw: Any) -> HttpResponseBase: + operation = next( + (op for op in operations if request.method in op.methods), None + ) + if operation is None: + return HttpResponseNotAllowed( + allowed_methods, content=b"Method not allowed" + ) + return operation.run(request, *a, **kw) - def _not_allowed(self) -> HttpResponse: - allowed_methods = set() - for op in self.operations: - allowed_methods.update(op.methods) - return HttpResponseNotAllowed(allowed_methods, content=b"Method not allowed") + async def async_view( + request: HttpRequest, *a: Any, **kw: Any + ) -> HttpResponseBase: + operation = next( + (op for op in operations if request.method in op.methods), None + ) + if operation is None: + return HttpResponseNotAllowed( + allowed_methods, content=b"Method not allowed" + ) + if operation.is_async: + return await cast(AsyncOperation, operation).run(request, *a, **kw) + return await sync_to_async(operation.run)(request, *a, **kw) + + if is_async: + sync_to_async # Ensure we fail here and not in view + view = async_view if is_async else sync_view + view.csrf_exempt = True # type: ignore + if self.url_name: + view.__name__ = "".join(c if c.isalnum() else "_" for c in self.url_name) + if view.__name__[0].isnumeric(): + view.__name__ = f"_{view.__name__}" + return view class ResponseObject: diff --git a/tests/test_api_instance.py b/tests/test_api_instance.py index 7e0043218..7b8988cdf 100644 --- a/tests/test_api_instance.py +++ b/tests/test_api_instance.py @@ -9,12 +9,12 @@ router = Router() -@api.get("/global") +@api.get("/global", url_name="global-op") def global_op(request): pass -@router.get("/router") +@router.get("/router", url_name="45") def router_op(request): pass @@ -28,6 +28,16 @@ def test_api_instance(): for path_ops in rtr.path_operations.values(): for op in path_ops.operations: assert op.api is api + global_op_pattern, router_op_pattern = ( + next( + url_pattern + for url_pattern in api.urls[0] + if url_pattern.name == pattern_name + ) + for pattern_name in ["global-op", "45"] + ) + assert global_op_pattern.callback.__name__ == "global_op" + assert router_op_pattern.callback.__name__ == "_45" def test_reuse_router_error(): From b403d48120a1d81f19cda1a26d42259b2372d454 Mon Sep 17 00:00:00 2001 From: c4ffein <c4ffein@gmail.com> Date: Tue, 20 Aug 2024 01:50:35 +0200 Subject: [PATCH 2/4] missing import --- ninja/operation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ninja/operation.py b/ninja/operation.py index 74251e5af..cac3fb848 100644 --- a/ninja/operation.py +++ b/ninja/operation.py @@ -21,7 +21,7 @@ from ninja.constants import NOT_SET, NOT_SET_TYPE from ninja.errors import AuthenticationError, ConfigError, Throttled, ValidationError from ninja.params.models import TModels -from ninja.schema import Schema +from ninja.schema import Schema, pydantic_version from ninja.signature import ViewSignature, is_async from ninja.throttling import BaseThrottle from ninja.types import DictStrAny From 9b385d92f056a202cf3a65e2ee59280b45918138 Mon Sep 17 00:00:00 2001 From: c4ffein <c4ffein@gmail.com> Date: Tue, 20 Aug 2024 02:02:51 +0200 Subject: [PATCH 3/4] sometimes ruff makes me so bored --- ninja/operation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ninja/operation.py b/ninja/operation.py index cac3fb848..e40a8a556 100644 --- a/ninja/operation.py +++ b/ninja/operation.py @@ -30,7 +30,7 @@ try: from asgiref.sync import sync_to_async except ModuleNotFoundError: - pass + sync_to_async = None if TYPE_CHECKING: from ninja import NinjaAPI, Router # pragma: no cover @@ -466,6 +466,9 @@ def get_view(self) -> Callable: HttpResponseNotAllowed(allowed_methods, content=b"Method not allowed"), ) + if is_async: + assert sync_to_async # Ensure we fail here and not in view + def sync_view(request: HttpRequest, *a: Any, **kw: Any) -> HttpResponseBase: operation = next( (op for op in operations if request.method in op.methods), None @@ -490,8 +493,6 @@ async def async_view( return await cast(AsyncOperation, operation).run(request, *a, **kw) return await sync_to_async(operation.run)(request, *a, **kw) - if is_async: - sync_to_async # Ensure we fail here and not in view view = async_view if is_async else sync_view view.csrf_exempt = True # type: ignore if self.url_name: From 69146f606ae84d360df8d5ba95fc9f4da9faed2b Mon Sep 17 00:00:00 2001 From: c4ffein <c4ffein@gmail.com> Date: Tue, 20 Aug 2024 19:09:19 +0200 Subject: [PATCH 4/4] quickfix for PoC --- ninja/operation.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/ninja/operation.py b/ninja/operation.py index e40a8a556..0f60ed1ff 100644 --- a/ninja/operation.py +++ b/ninja/operation.py @@ -27,11 +27,6 @@ from ninja.types import DictStrAny from ninja.utils import check_csrf, is_async_callable -try: - from asgiref.sync import sync_to_async -except ModuleNotFoundError: - sync_to_async = None - if TYPE_CHECKING: from ninja import NinjaAPI, Router # pragma: no cover @@ -462,12 +457,9 @@ def get_view(self) -> Callable: is_async = self.is_async operations = self.operations allowed_methods = {method for op in operations for method in op.methods} - response_not_allowed = ( - HttpResponseNotAllowed(allowed_methods, content=b"Method not allowed"), - ) if is_async: - assert sync_to_async # Ensure we fail here and not in view + from asgiref.sync import sync_to_async def sync_view(request: HttpRequest, *a: Any, **kw: Any) -> HttpResponseBase: operation = next(