1-
21from contextlib import asynccontextmanager
32from dataclasses import dataclass
43from typing import TYPE_CHECKING , Optional , TypeVar , Union , cast
2625 from litestar .types import BeforeMessageSendHookHandler , EmptyType , Message , Scope
2726
2827
29- CONNECTION_SCOPE_KEY = "_asyncpg_db_connection"
30- SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START , HTTP_DISCONNECT , WEBSOCKET_DISCONNECT , WEBSOCKET_CLOSE }
28+ SESSION_TERMINUS_ASGI_EVENTS = {
29+ HTTP_RESPONSE_START ,
30+ HTTP_DISCONNECT ,
31+ WEBSOCKET_DISCONNECT ,
32+ WEBSOCKET_CLOSE ,
33+ }
3134T = TypeVar ("T" )
3235
3336if TYPE_CHECKING :
3639 AsyncpgConnection : TypeAlias = "Union[Connection, PoolConnectionProxy]"
3740
3841
39- async def default_before_send_handler (message : "Message" , scope : "Scope" ) -> None :
40- """Handle closing and cleaning up sessions before sending.
42+ def default_before_send_handler (
43+ connection_scope_key : str ,
44+ ) -> "BeforeMessageSendHookHandler" :
45+ """Return the default before_send handler to handle asyncpg connections.
4146
4247 Args:
43- message: ASGI-``Message``
44- scope: An ASGI-``Scope``
48+ connection_scope_key: The key for the connection scope
4549
4650 Returns:
47- None
51+ The handler callable
4852 """
49- session = cast ("Union[PoolConnectionProxy,Connection,None]" , get_scope_state (scope , CONNECTION_SCOPE_KEY ))
50- if session is not None and message ["type" ] in SESSION_TERMINUS_ASGI_EVENTS :
51- delete_scope_state (scope , CONNECTION_SCOPE_KEY )
53+
54+ def before_send_handler (message : "Message" , scope : "Scope" ) -> None :
55+ """Handle closing and cleaning up sessions before sending.
56+
57+ Args:
58+ message: ASGI-``Message``
59+ scope: An ASGI-``Scope``
60+
61+ Returns:
62+ None
63+ """
64+ session = cast (
65+ "Union[PoolConnectionProxy,Connection,None]" ,
66+ get_scope_state (scope , connection_scope_key ),
67+ )
68+ if session is not None and message ["type" ] in SESSION_TERMINUS_ASGI_EVENTS :
69+ delete_scope_state (scope , connection_scope_key )
70+
71+ return before_send_handler
72+
5273
5374def serializer (value : "Any" ) -> str :
5475 """Serialize JSON field values.
@@ -115,8 +136,8 @@ class AsyncpgConfig:
115136 pool_dependency_key : str = "db_pool"
116137 """Key under which to store the asyncpg Pool in the application dependency injection map. """
117138 connection_dependency_key : str = "db_connection"
118- """Key under which to store the asyncpg Pool in the application dependency injection map. """
119- before_send_handler : "BeforeMessageSendHookHandler" = default_before_send_handler
139+ """Key under which to store the asyncpg Connection in the application dependency injection map. """
140+ before_send_handler : "BeforeMessageSendHookHandler | None " = None
120141 """Handler to call before the ASGI message is sent.
121142
122143 The handler should handle closing the session stored in the ASGI scope, if it's still open, and committing and
@@ -135,6 +156,11 @@ class AsyncpgConfig:
135156 If set, the plugin will use the provided pool rather than instantiate one.
136157 """
137158
159+ def __post_init__ (self ) -> None :
160+ self .connection_scope_key = f"_asyncpg_{ self .connection_dependency_key } "
161+ if self .before_send_handler is None :
162+ self .before_send_handler = default_before_send_handler (self .connection_scope_key )
163+
138164 @property
139165 def pool_config_dict (self ) -> "dict[str, Any]" :
140166 """Return the pool configuration as a dict.
@@ -212,15 +238,12 @@ async def set_json_handlers(conn: "AsyncpgConnection") -> None:
212238
213239 self .pool_instance = await asyncpg_create_pool (** pool_config )
214240 if self .pool_instance is None :
215- msg = "Could not configure the 'pool_instance'. Please check your configuration." # type: ignore[unreachable]
241+ msg = "Could not configure the 'pool_instance'. Please check your configuration." # type: ignore[unreachable]
216242 raise ImproperlyConfiguredException (msg )
217243 return self .pool_instance
218244
219245 @asynccontextmanager
220- async def lifespan (
221- self ,
222- app : "Litestar"
223- ) -> "AsyncGenerator[None, None]" :
246+ async def lifespan (self , app : "Litestar" ) -> "AsyncGenerator[None, None]" :
224247 db_pool = await self .create_pool ()
225248 app .state .update ({self .pool_app_state_key : db_pool })
226249 try :
@@ -240,11 +263,7 @@ def provide_pool(self, state: "State") -> "Pool":
240263 """
241264 return cast ("Pool" , state .get (self .pool_app_state_key ))
242265
243- async def provide_connection (
244- self ,
245- state : "State" ,
246- scope : "Scope"
247- ) -> "AsyncGenerator[AsyncpgConnection, None]" :
266+ async def provide_connection (self , state : "State" , scope : "Scope" ) -> "AsyncGenerator[AsyncpgConnection, None]" :
248267 """Create a connection instance.
249268
250269 Args:
@@ -254,18 +273,19 @@ async def provide_connection(
254273 Returns:
255274 A connection instance.
256275 """
257- connection = cast ("Optional[Union[Connection, PoolConnectionProxy]]" , get_scope_state (scope , CONNECTION_SCOPE_KEY ))
276+ connection = cast (
277+ "Optional[Union[Connection, PoolConnectionProxy]]" ,
278+ get_scope_state (scope , self .connection_scope_key ),
279+ )
258280 if connection is None :
259281 pool = cast ("Pool" , state .get (self .pool_app_state_key ))
260282
261283 async with pool .acquire () as connection :
262- set_scope_state (scope , CONNECTION_SCOPE_KEY , connection )
284+ set_scope_state (scope , self . connection_scope_key , connection )
263285 yield connection
264286
265287 @asynccontextmanager
266- async def get_connection (
267- self
268- ) -> "AsyncGenerator[AsyncpgConnection, None]" :
288+ async def get_connection (self ) -> "AsyncGenerator[AsyncpgConnection, None]" :
269289 """Create a connection instance.
270290
271291 Args:
0 commit comments