diff --git a/include/ez/__init__.py b/include/ez/__init__.py index edd51f9..8294ff7 100644 --- a/include/ez/__init__.py +++ b/include/ez/__init__.py @@ -7,6 +7,7 @@ from pathlib import Path from . import ( + admin, data, database, events, @@ -37,6 +38,7 @@ "EZ_PYTHON_EXECUTABLE", "SITE_DIR", + "admin", "data", "database", "events", diff --git a/include/utilities/utils.py b/include/utilities/utils.py index 9e62e94..51f01b7 100644 --- a/include/utilities/utils.py +++ b/include/utilities/utils.py @@ -16,3 +16,14 @@ def wrapper(fn: Callable[B, Callable[P, T]]) -> Callable[P, T]: return fn(*args, **kwargs) return wrapper + + +def spacify(text: str, sep: str = ' ') -> str: + """ + Takes in a string of either PascalCase or camelCase and return it split to words + with the given separator. + + Default separator is a single space. + """ + return ''.join(sep + char if char.isupper() else char.strip() for char in text).strip() + diff --git a/modules/admin/__api__/__init__.py b/modules/admin/__api__/__init__.py index fd13de2..d739524 100644 --- a/modules/admin/__api__/__init__.py +++ b/modules/admin/__api__/__init__.py @@ -1,4 +1,6 @@ from ..menu import MENU +from . import view + def add_item(name: str): MENU.add_item(name) diff --git a/modules/admin/__api__/notifications.py b/modules/admin/__api__/notifications.py new file mode 100644 index 0000000..2415add --- /dev/null +++ b/modules/admin/__api__/notifications.py @@ -0,0 +1,33 @@ +from typing import Any, Iterable, overload + +from seamless.types import Renderable + +from ..user import User + + +_EMPTY: Any = object() + + +@overload +def push(users: Iterable[User], message: str, /) -> None: ... +@overload +def push(users: Iterable[User], message: str, /, *, title: str) -> None: ... +@overload +def push(users: Iterable[User], custom: Renderable, /) -> None: ... + +def push(users: Iterable[User], message: str | Renderable, /, *, title: str = _EMPTY): + raise NotImplementedError + + +@overload +def push_all(message: str, /, *, exclude: Iterable[User] | None = None) -> None: ... +@overload +def push_all(message: str, /, *, title: str, exclude: Iterable[User] | None = None) -> None: ... +@overload +def push_all(custom: Renderable, /, *, exclude: Iterable[User] | None = None) -> None: ... + +def push_all(message: str | Renderable, /, *, title: str = _EMPTY, exclude: Iterable[User] | None = None): + if exclude is None: + exclude = [] + + raise NotImplementedError diff --git a/modules/admin/__api__/view.py b/modules/admin/__api__/view.py new file mode 100644 index 0000000..37821bf --- /dev/null +++ b/modules/admin/__api__/view.py @@ -0,0 +1 @@ +from ..views.settings_viewer import render_settings_viewer diff --git a/modules/admin/__init__.py b/modules/admin/__init__.py index e69de29..d5eccf2 100644 --- a/modules/admin/__init__.py +++ b/modules/admin/__init__.py @@ -0,0 +1,4 @@ +__deps__ = [ + "plugins", + "web", +] diff --git a/modules/admin/__main__.py b/modules/admin/__main__.py index e69de29..4ef953c 100644 --- a/modules/admin/__main__.py +++ b/modules/admin/__main__.py @@ -0,0 +1 @@ +from . import router diff --git a/modules/admin/authentication.py b/modules/admin/authentication.py new file mode 100644 index 0000000..9b8e735 --- /dev/null +++ b/modules/admin/authentication.py @@ -0,0 +1,21 @@ +from uuid import uuid4 +from .user import User + +from . import context + +USERNAME = "admin" +PASSWORD = "pwd" + + +ADMIN_USER = User() + + +def authenticate(username: str, password: str) -> context.Session | None: + if username != USERNAME or password != PASSWORD: + return None + + user = ADMIN_USER + + session = context.create_session(user, uuid4().hex) + + return session diff --git a/modules/admin/context.py b/modules/admin/context.py new file mode 100644 index 0000000..e77537a --- /dev/null +++ b/modules/admin/context.py @@ -0,0 +1,261 @@ +""" +The admin dashboard context module. + +The context for the dashboard is composed of multiple layers. + +First, the dashboard context is implemented by the ContextRoot class. + +The context root then contains a: +- shared context (SharedContext) +- user contexts (UserContext) +- sessions (Session) + +The shared context is data that is shared between all logged in users. +The user context is data for a specific user, but shared with all its sessions (open tabs, etc..) +The session context is data for a specific session (connection, tab in browser, etc..) +""" + +import ez + +from typing import TypeAlias, overload +from utilities.utils import bind + +from .user import User + + +SessionID: TypeAlias = str + + +class SharedContext: + ... + + +class SessionContext: + ... + + +class UserContext: + def __init__(self, user: User) -> None: + self._user = user + self._sessions = [] + + @property + def user(self) -> User: + return self._user + + @property + def sessions(self) -> list["Session"]: + return self._sessions.copy() + + def create_session(self, session_id: str) -> "Session": + session = Session(self._user, session_id) + self._sessions.append(session) + return session + + def connect_session(self, session: "Session"): + if session.user != self._user: + raise ValueError("Session user does not match user context.") + self._sessions.append(session) + + def disconnect_session(self, session: "Session"): + try: + self._sessions.remove(session) + except ValueError: + return False + return True + + +class Session: + def __init__(self, user: User, session_id: SessionID) -> None: + self._id = session_id + self._user = user + self._context = SessionContext() + + @property + def id(self) -> SessionID: + return self._id + + @property + def user(self) -> User: + return self._user + + @property + def context(self) -> SessionContext: + return self._context + + +class ContextRoot: + _user_contexts: dict[User, UserContext] + _sessions: dict[SessionID, Session] + + def __init__(self) -> None: + self._global_context = SharedContext() + self._user_contexts = {} + self._sessions = {} + + @property + def global_context(self) -> SharedContext: + return self._global_context + + def user_context(self, user: User) -> UserContext: + if user not in self._user_contexts: + self._user_contexts[user] = UserContext(user) + return self._user_contexts[user] + + def get_user_context(self, user: User) -> UserContext | None: + try: + return self._user_contexts[user] + except KeyError: + return None + + def _clear_user_context(self, user: User) -> None: + try: + del self._user_contexts[user] + except KeyError: + return + + def session(self, session_id: SessionID) -> Session: + if session_id not in self._sessions: + raise KeyError(f"Session with ID '{session_id}' not found.") + return self._sessions[session_id] + + def create_session(self, user: User, session_id: str) -> Session: + session = self.user_context(user).create_session(session_id) + self._sessions[session_id] = session + return session + + def delete_session(self, session_id: SessionID) -> bool: + try: + session = self._sessions.pop(session_id) + except KeyError: + return False + context = self.get_user_context(session.user) + + if context is None: + return False + + result = context.disconnect_session(session) + + if not context.sessions: + self._clear_user_context(session.user) + + return result + + def disconnect_user(self, user: User): + context = self.get_user_context(user) + + if context is None: + return + + for session in context.sessions: + self.delete_session(session.id) + + +_SITE_CONTEXT = ContextRoot() + + +@bind(_SITE_CONTEXT) +def get_shared_context(ctx: ContextRoot): + def _get_global_context() -> SharedContext: + return ctx.global_context + + return _get_global_context + + +@bind(_SITE_CONTEXT) +def get_context_root(ctx: ContextRoot): + def _get_site_context() -> ContextRoot: + return ctx + + return _get_site_context + + +@bind(_SITE_CONTEXT) +def get_user_context(ctx: ContextRoot): + @overload + def _get_user_context(user: User, /) -> UserContext: ... + @overload + def _get_user_context(session: Session, /) -> UserContext: ... + @overload + def _get_user_context(session_id: SessionID, /) -> UserContext: ... + + def _get_user_context(value: User | Session | SessionID, /) -> UserContext: + match value: + case User() as user: + result = ctx.get_user_context(user) + if result is None: + raise ValueError("User context not found.") + return result + case Session() as session: + return _get_user_context(session.user) + case SessionID() as session_id: + return _get_user_context(ctx.session(session_id)) + case _: + raise TypeError("Invalid argument type.") + + return _get_user_context + + +@bind(_SITE_CONTEXT) +def create_session(ctx: ContextRoot): + def _create_session(user: User, session_id: SessionID) -> Session: + return ctx.create_session(user, session_id) + + return _create_session + + +@bind(_SITE_CONTEXT) +def get_session(ctx: ContextRoot): + def _get_session(session_id: SessionID) -> Session: + return ctx.session(session_id) + + return _get_session + + +@bind(_SITE_CONTEXT) +def close_session(ctx: ContextRoot): + @overload + def _close_session(session: Session, /) -> bool: ... + @overload + def _close_session(session_id: SessionID, /) -> bool: ... + + def _close_session(session_id: Session | SessionID, /) -> bool: + if isinstance(session_id, Session): + session_id = session_id.id + return ctx.delete_session(session_id) + + return _close_session + + +@bind(_SITE_CONTEXT) +def disconnect_user(ctx: ContextRoot): + def _disconnect_user(user: User): + ctx.disconnect_user(user) + + return _disconnect_user + + +@bind(_SITE_CONTEXT) +def get_connected_users(ctx: ContextRoot): + def _get_users() -> list[User]: + return list(ctx._user_contexts.keys()) + + return _get_users + + +def get_current_session() -> Session: + request = ez.web.http.request() + if request is None: + raise ValueError("No request found") + try: + sid = request.session["ez-admin:sid"] + except KeyError: + raise ValueError("Missing session id") + return get_session(sid) + + +def get_current_user(): + return get_current_session().user + + +del _SITE_CONTEXT diff --git a/modules/admin/requests.py b/modules/admin/requests.py new file mode 100644 index 0000000..188a8bd --- /dev/null +++ b/modules/admin/requests.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class LoginRequest(BaseModel): + username: str + password: str diff --git a/modules/admin/router.py b/modules/admin/router.py new file mode 100644 index 0000000..557f816 --- /dev/null +++ b/modules/admin/router.py @@ -0,0 +1,28 @@ +from starlette.requests import Request +from starlette.responses import JSONResponse, PlainTextResponse + +import ez + +from .context import get_current_session, close_session +from .requests import LoginRequest + +from .authentication import authenticate + + +admin = ez.web.http.router("/admin") + + +@admin.post("/login") +async def login(request: Request): + async with request.form() as form: + login = LoginRequest.model_validate(form) + + return JSONResponse({"message": "Logged in"}) + + +@admin.post("/logout") +def logout(request: Request): + session = get_current_session() + close_session(session) + return JSONResponse({"message": "Logged out"}) + diff --git a/modules/admin/ui/menu.py b/modules/admin/ui/menu.py new file mode 100644 index 0000000..c3578a2 --- /dev/null +++ b/modules/admin/ui/menu.py @@ -0,0 +1,183 @@ +from typing import ( + Any, + TypeAlias, + SupportsIndex, + overload, +) + + +_EMPTY: Any = object() + +MenuID: TypeAlias = str + + +class MenuEntry: + _parent: "Menu | None" + + def __init__(self, parent: "Menu | None", id: str | None = None): + self._parent = parent + self._id = id + + @property + def menu(self): + return self._parent + + @property + def id(self): + return self._id + + +class Menu(MenuEntry): + _entries: list[MenuEntry] + _mapping: dict[MenuID, MenuEntry] + + def __init__(self, parent: "Menu | None"): + super().__init__(parent) + + self._entries = [] + self._mapping = {} + + def add_entry(self, entry: MenuEntry) -> None: + self._entries.append(entry) + if entry.id is not None: + self._mapping[entry.id] = entry + + def add_entry_at(self, index: SupportsIndex, entry: MenuEntry) -> None: + self._entries.insert(index, entry) + if entry.id is not None: + self._mapping[entry.id] = entry + + def add_entries(self, *entries: MenuEntry) -> None: + for entry in entries: + self.add_entry(entry) + + def add_entries_at(self, index: SupportsIndex, *entries: MenuEntry) -> None: + index = index.__index__() + for i, entry in enumerate(entries): + self.add_entry_at(index + i, entry) + + @overload + def add_entry_before(self, target: MenuEntry, entry: MenuEntry, /) -> None: ... + @overload + def add_entry_before(self, entry_id: MenuID, entry: MenuEntry, /) -> None: ... + + def add_entry_before(self, target: MenuEntry | MenuID, entry: MenuEntry, /): + if isinstance(target, MenuID): + target = self._mapping[target] + index = self._entries.index(target) + self.add_entry_at(index - 1, entry) + + @overload + def add_entry_after(self, target: MenuEntry, entry: MenuEntry, /) -> None: ... + @overload + def add_entry_after(self, entry_id: MenuID, entry: MenuEntry, /) -> None: ... + + def add_entry_after(self, target: MenuEntry | MenuID, entry: MenuEntry, /): + if isinstance(target, MenuID): + target = self._mapping[target] + index = self._entries.index(target) + self.add_entry_at(index, entry) + + @overload + def add_entries_before(self, target: MenuEntry, /, *entries: MenuEntry) -> None: ... + @overload + def add_entries_before(self, entry_id: MenuID, /, *entries: MenuEntry) -> None: ... + + def add_entries_before(self, target: MenuEntry | MenuID, /, *entries: MenuEntry): + if isinstance(target, MenuID): + target = self._mapping[target] + index = self._entries.index(target) + self.add_entries_at(index - 1, *entries) + + @overload + def add_entries_after(self, target: MenuEntry, /, *entries: MenuEntry) -> None: ... + @overload + def add_entries_after(self, entry_id: MenuID, /, *entries: MenuEntry) -> None: ... + + def add_entries_after(self, target: MenuEntry | MenuID, /, *entries: MenuEntry): + if isinstance(target, MenuID): + target = self._mapping[target] + index = self._entries.index(target) + self.add_entries_at(index, *entries) + + @overload + def entry_at(self, index: SupportsIndex, /) -> MenuEntry: ... + @overload + def entry_at(self, index: SupportsIndex, entry: MenuEntry, /) -> None: ... + + def entry_at(self, index: SupportsIndex, entry: MenuEntry | None = _EMPTY, /): + if entry is _EMPTY: + return self[index] + if entry is None: + del self[index] + else: + self[index] = entry + + def remove_entry(self, entry: MenuEntry) -> None: + del self[entry] + + def remove_entry_at(self, index: SupportsIndex) -> None: + del self[index] + + @overload + def index(self, entry: MenuEntry, /) -> int: ... + @overload + def index(self, entry_id: MenuID, /) -> int: ... + + def index(self, entry: MenuEntry | MenuID, /) -> int: + if isinstance(entry, MenuID): + entry = self._mapping[entry] + return self._entries.index(entry) + + @overload + def __contains__(self, entry: MenuEntry, /) -> bool: ... + @overload + def __contains__(self, entry_id: MenuID, /) -> bool: ... + + def __contains__(self, entry: MenuEntry | MenuID, /) -> bool: + if isinstance(entry, MenuID): + return entry in self._mapping + return entry in self._entries + + @overload + def __getitem__(self, index: SupportsIndex, /) -> MenuEntry: ... + @overload + def __getitem__(self, entry_id: MenuID, /) -> MenuEntry: ... + + def __getitem__(self, index_or_id: SupportsIndex | MenuID) -> MenuEntry: + if isinstance(index_or_id, MenuID): + return self._mapping[index_or_id] + return self._entries[index_or_id] + + @overload + def __setitem__(self, index: SupportsIndex, entry: MenuEntry, /) -> None: ... + @overload + def __setitem__(self, entry_id: MenuID, entry: MenuEntry, /) -> None: ... + + def __setitem__(self, index: SupportsIndex | MenuID, entry: MenuEntry, /): + if isinstance(index, MenuID): + index = self.index(index) + self._entries[index] = entry + + @overload + def __delitem__(self, index: SupportsIndex, /) -> None: ... + @overload + def __delitem__(self, entry_id: MenuID, /) -> None: ... + @overload + def __delitem__(self, entry: MenuEntry, /) -> None: ... + + def __delitem__(self, entry: SupportsIndex | MenuEntry | MenuID, /): + if not isinstance(entry, MenuEntry): + entry = self[entry] + self._entries.remove(entry) + if entry.id is not None: + del self._mapping[entry.id] + + def __iter__(self): + return iter(self._entries) + + def __len__(self): + return len(self._entries) + + +__all__ = ["Menu", "MenuEntry"] diff --git a/modules/admin/user.py b/modules/admin/user.py new file mode 100644 index 0000000..8e981d1 --- /dev/null +++ b/modules/admin/user.py @@ -0,0 +1,2 @@ +class User: + ... diff --git a/modules/admin/views/settings_viewer.py b/modules/admin/views/settings_viewer.py new file mode 100644 index 0000000..6af4c6c --- /dev/null +++ b/modules/admin/views/settings_viewer.py @@ -0,0 +1,40 @@ +import ez + +from pydantic.fields import FieldInfo + +from ez.plugins import Settings + +from seamless.html import Div +from seamless.styling import Style + + +def render_settings_viewer(settings: Settings): + def render_section(section: type[Settings], value: Settings): + return Div()( + Div()(section.__ez_section_title__), + render_settings_viewer(value) + ) + + def render_field(name: str, field: FieldInfo, value): + annotation = field.annotation + if not isinstance(annotation, type): + raise TypeError + provider_type = ez.data.providers.get_provider_type(annotation) + provider = provider_type.load(value) + return Div(style=Style( + display="flex", + flexDirection="column" + + ))( + (field.title or name) + ": ", + provider.render_input() + ) + + return Div()( + *[ + render_section(field.annotation, getattr(settings, name)) + if isinstance(field.annotation, type) and issubclass(field.annotation, Settings) + else render_field(name, field, getattr(settings, name)) + for name, field in settings.model_fields.items() + ] + ) diff --git a/modules/plugins/__api__/__init__.py b/modules/plugins/__api__/__init__.py index e8876af..3622beb 100644 --- a/modules/plugins/__api__/__init__.py +++ b/modules/plugins/__api__/__init__.py @@ -1,14 +1,29 @@ -from typing import Callable +from typing import Callable, TYPE_CHECKING -from ..plugin import Plugin, PluginInfo, PluginId, PluginAPI -from ..machinery.installer import IPluginInstaller, PluginInstallerInfo, PluginInstallerId -from ..machinery.loader import IPluginLoader, PluginLoaderInfo +if TYPE_CHECKING: + from modules.plugins.plugin import Plugin, PluginInfo, PluginId, PluginAPI + from modules.plugins.machinery.installer import IPluginInstaller, PluginInstallerInfo, PluginInstallerId + from modules.plugins.machinery.loader import IPluginLoader, PluginLoaderInfo -from .errors import EZPluginError, UnknownPluginError, DuplicateIDError -from .events import Plugins + from modules.plugins.errors import EZPluginError, UnknownPluginError, DuplicateIDError + from .events import Plugins -from ..manager import PLUGIN_MANAGER as __pm -from ..config import PLUGINS_PUBLIC_API_MODULE_NAME + from modules.plugins.manager import PLUGIN_MANAGER as __pm + from modules.plugins.config import PLUGINS_PUBLIC_API_MODULE_NAME + + from modules.plugins.framework.settings import Settings +else: + from ..plugin import Plugin, PluginInfo, PluginId, PluginAPI + from ..machinery.installer import IPluginInstaller, PluginInstallerInfo, PluginInstallerId + from ..machinery.loader import IPluginLoader, PluginLoaderInfo + + from .errors import EZPluginError, UnknownPluginError, DuplicateIDError + from .events import Plugins + + from ..manager import PLUGIN_MANAGER as __pm + from ..config import PLUGINS_PUBLIC_API_MODULE_NAME + + from ..framework.settings import Settings def get_plugins() -> list[Plugin]: @@ -108,6 +123,7 @@ def get_loaders() -> list[PluginLoaderInfo]: "PluginInfo", "EZPluginError", "Plugins", + "Settings", "get_plugins", "get_plugin", "install", diff --git a/modules/plugins/__api__/events.py b/modules/plugins/__api__/events.py index 804fea8..1bbfa38 100644 --- a/modules/plugins/__api__/events.py +++ b/modules/plugins/__api__/events.py @@ -1 +1,7 @@ -from ..events import Plugins +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from modules.plugins.events import Plugins +else: + from ..events import Plugins diff --git a/modules/plugins/__main__.py b/modules/plugins/__main__.py index 2c17ba4..848b410 100644 --- a/modules/plugins/__main__.py +++ b/modules/plugins/__main__.py @@ -79,6 +79,11 @@ def load_plugins(): ) plugin_manager.run_plugins(*plugin_ids) + plugins = plugin_manager.get_plugins() + ez.log.info(f"Loaded {len(plugins)} plugins:") + for plugin in plugins: + ez.log.info(f"\tLoaded plugin: {plugin.info.package_name}") + ez.events.emit(Plugins.DidLoad, plugins) diff --git a/modules/plugins/framework/settings.py b/modules/plugins/framework/settings.py new file mode 100644 index 0000000..39aee85 --- /dev/null +++ b/modules/plugins/framework/settings.py @@ -0,0 +1,27 @@ +from typing import Any, ClassVar, Unpack + +from pydantic import BaseModel, ConfigDict + +from utilities.utils import spacify + + +_EMPTY: Any = object() + + +class Settings(BaseModel): + __ez_section_title__: ClassVar[str | None] + __ez_section_id__: ClassVar[str | None] + + def __init_subclass__(cls, *, section: str = _EMPTY, section_id: str = _EMPTY, **kwargs: Unpack[ConfigDict]): + super().__init_subclass__(**kwargs) + + if section is _EMPTY and section_id is _EMPTY: + section = spacify(cls.__name__) + + if section_id is _EMPTY: + section_id = section.lower().replace(' ', '-') + elif section is _EMPTY: + section = section_id.replace('-', ' ').title() + + cls.__ez_section_title__ = section + cls.__ez_section_id__ = section_id diff --git a/modules/templates/template.py b/modules/templates/template.py index 9140e67..2db50b1 100644 --- a/modules/templates/template.py +++ b/modules/templates/template.py @@ -16,7 +16,7 @@ class Template(TemplateBase, Generic[T]): def __init__(self, name: str, params: type[T] | type, parent: "TemplatePack | None" = None): if not isinstance(params, type) or not issubclass(params, TemplateParams): - raise TypeError(f"Functional template parameter must be a subclass of BaseModel. Got {params}.") + raise TypeError(f"Functional template parameter must be a subclass of TemplateParams. Got {params}.") super().__init__(name, parent) diff --git a/modules/web/__api__/http.py b/modules/web/__api__/http.py index 97e38c4..d25968b 100644 --- a/modules/web/__api__/http.py +++ b/modules/web/__api__/http.py @@ -3,11 +3,15 @@ from sandbox import current_plugin from utilities.utils import bind -from ..http import HTTPException, HTTPMethod, HTTPStatus -from ..routing import API_ROUTER - if TYPE_CHECKING: - from ..routing import PluginRouter + from modules.web.routing import PluginRouter + from modules.web.context import get_request, set_request, request + from modules.web.http import HTTPException, HTTPMethod, HTTPStatus + from modules.web.routing import API_ROUTER +else: + from ..context import get_request, set_request, request + from ..http import HTTPException, HTTPMethod, HTTPStatus + from ..routing import API_ROUTER _EMPTY: Any = object() @@ -23,7 +27,7 @@ def _current_router(): @bind(API_ROUTER) def router(api_router: "PluginRouter"): - def _router(route: str = _EMPTY): + def _router(route: str = _EMPTY, **kwargs): plugin = current_plugin() if not plugin.is_root: @@ -38,7 +42,7 @@ def _router(route: str = _EMPTY): if route is _EMPTY: route = '/' + plugin.oid - return api_router.add_router(plugin, route) + return api_router.add_router(plugin, route, **kwargs) return _router @@ -107,4 +111,8 @@ def head(_route: str, **kwargs): "HTTPException", "HTTPMethod", "HTTPStatus", + + "get_request", + "set_request", + "request", ] diff --git a/modules/web/__main__.py b/modules/web/__main__.py index b9512d8..27f6850 100644 --- a/modules/web/__main__.py +++ b/modules/web/__main__.py @@ -1,4 +1,4 @@ -from . import routing +from . import routing, context routing.add_route(routing.API_MOUNT) diff --git a/modules/web/context.py b/modules/web/context.py new file mode 100644 index 0000000..dbc3a75 --- /dev/null +++ b/modules/web/context.py @@ -0,0 +1,70 @@ +from contextvars import ContextVar +from typing import Any, TypeAlias, overload + +from utilities.utils import bind + +from .request import Request + +from starlette.middleware.base import BaseHTTPMiddleware + +import ez + + +RequestVar: TypeAlias = ContextVar[Request | None] +_REQUEST: RequestVar = ContextVar("request", default=None) + + +@bind(_REQUEST) +def get_request(var: RequestVar): + def _get_request() -> Request: + return var.get() + + return _get_request + + +def has_request(): + return get_request() is not None + + +@bind(_REQUEST) +def set_request(var: RequestVar): + def _set_request(request: Request | None): + var.set(request) + + return _set_request + + +@bind() +def request(): + _EMPTY: Any = object() + + @overload + def _request() -> Request: ... + @overload + def _request(request: Request) -> None: ... + + def _request(request: Request = _EMPTY): + if request is _EMPTY: + return get_request() + set_request(request) + + return _request + + +async def _on_request(request, call_next): + set_request(Request(request)) + response = await call_next(request) + set_request(None) + return response + + +ez.lowlevel.WEB_APP.add_middleware(BaseHTTPMiddleware, dispatch=_on_request) + + +del _REQUEST + +__all__ = [ + "get_request", + "set_request", + "request", +] diff --git a/modules/web/request.py b/modules/web/request.py new file mode 100644 index 0000000..d49d9c1 --- /dev/null +++ b/modules/web/request.py @@ -0,0 +1,18 @@ +""" +Defines the Request class, which represents an incoming HTTP request. + +As of v1, this module is a thin wrapper around Starlette's Request class. + +When we move to use our own web app framework, or to a different ASGI framework, +we will replace this module with the appropriate request class. + +However, the basic API should remain the same, so the rest of the application +should not have to change. +""" + +from starlette.requests import Request + + +__all__ = [ + "Request" +] diff --git a/modules/web/routing.py b/modules/web/routing.py index f8b4a90..e2e338a 100644 --- a/modules/web/routing.py +++ b/modules/web/routing.py @@ -21,11 +21,11 @@ def __init__(self) -> None: def router(self): return self._router - def add_router(self, app: Application, route: str, *, cls: type[Router] = EZRouter) -> Router: + def add_router(self, app: Application, route: str, *args, cls: type[Router] = EZRouter, **kwargs) -> Router: if app in self._mounts: return cast(Router, self._mounts[app].app) - router = cls() + router = cls(*args, **kwargs) mount = Mount(route, app=router) if not app.is_root: