Skip to content

Commit e677fb7

Browse files
authored
fix: separate connection state per asyncpg plugin (#25)
Set the config connection_scope_key based on the connection_dependency_key to allow multiple asyncpg plugins to maintain connection state separately.
1 parent db83e83 commit e677fb7

File tree

2 files changed

+50
-29
lines changed

2 files changed

+50
-29
lines changed

litestar_asyncpg/config.py

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
from contextlib import asynccontextmanager
32
from dataclasses import dataclass
43
from typing import TYPE_CHECKING, Optional, TypeVar, Union, cast
@@ -26,8 +25,12 @@
2625
from litestar.types import BeforeMessageSendHookHandler, EmptyType, Message, Scope
2726

2827

29-
CONNECTION_SCOPE_KEY = "_asyncpg_db_connection"
30-
SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START, HTTP_DISCONNECT, WEBSOCKET_DISCONNECT, WEBSOCKET_CLOSE}
28+
SESSION_TERMINUS_ASGI_EVENTS = {
29+
HTTP_RESPONSE_START,
30+
HTTP_DISCONNECT,
31+
WEBSOCKET_DISCONNECT,
32+
WEBSOCKET_CLOSE,
33+
}
3134
T = TypeVar("T")
3235

3336
if TYPE_CHECKING:
@@ -36,19 +39,37 @@
3639
AsyncpgConnection: TypeAlias = "Union[Connection, PoolConnectionProxy]"
3740

3841

39-
async def default_before_send_handler(message: "Message", scope: "Scope") -> None:
40-
"""Handle closing and cleaning up sessions before sending.
42+
def default_before_send_handler(
43+
connection_scope_key: str,
44+
) -> "BeforeMessageSendHookHandler":
45+
"""Return the default before_send handler to handle asyncpg connections.
4146
4247
Args:
43-
message: ASGI-``Message``
44-
scope: An ASGI-``Scope``
48+
connection_scope_key: The key for the connection scope
4549
4650
Returns:
47-
None
51+
The handler callable
4852
"""
49-
session = cast("Union[PoolConnectionProxy,Connection,None]", get_scope_state(scope, CONNECTION_SCOPE_KEY))
50-
if session is not None and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
51-
delete_scope_state(scope, CONNECTION_SCOPE_KEY)
53+
54+
def before_send_handler(message: "Message", scope: "Scope") -> None:
55+
"""Handle closing and cleaning up sessions before sending.
56+
57+
Args:
58+
message: ASGI-``Message``
59+
scope: An ASGI-``Scope``
60+
61+
Returns:
62+
None
63+
"""
64+
session = cast(
65+
"Union[PoolConnectionProxy,Connection,None]",
66+
get_scope_state(scope, connection_scope_key),
67+
)
68+
if session is not None and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
69+
delete_scope_state(scope, connection_scope_key)
70+
71+
return before_send_handler
72+
5273

5374
def serializer(value: "Any") -> str:
5475
"""Serialize JSON field values.
@@ -115,8 +136,8 @@ class AsyncpgConfig:
115136
pool_dependency_key: str = "db_pool"
116137
"""Key under which to store the asyncpg Pool in the application dependency injection map. """
117138
connection_dependency_key: str = "db_connection"
118-
"""Key under which to store the asyncpg Pool in the application dependency injection map. """
119-
before_send_handler: "BeforeMessageSendHookHandler" = default_before_send_handler
139+
"""Key under which to store the asyncpg Connection in the application dependency injection map. """
140+
before_send_handler: "BeforeMessageSendHookHandler | None" = None
120141
"""Handler to call before the ASGI message is sent.
121142
122143
The handler should handle closing the session stored in the ASGI scope, if it's still open, and committing and
@@ -135,6 +156,11 @@ class AsyncpgConfig:
135156
If set, the plugin will use the provided pool rather than instantiate one.
136157
"""
137158

159+
def __post_init__(self) -> None:
160+
self.connection_scope_key = f"_asyncpg_{self.connection_dependency_key}"
161+
if self.before_send_handler is None:
162+
self.before_send_handler = default_before_send_handler(self.connection_scope_key)
163+
138164
@property
139165
def pool_config_dict(self) -> "dict[str, Any]":
140166
"""Return the pool configuration as a dict.
@@ -212,15 +238,12 @@ async def set_json_handlers(conn: "AsyncpgConnection") -> None:
212238

213239
self.pool_instance = await asyncpg_create_pool(**pool_config)
214240
if self.pool_instance is None:
215-
msg = "Could not configure the 'pool_instance'. Please check your configuration." # type: ignore[unreachable]
241+
msg = "Could not configure the 'pool_instance'. Please check your configuration." # type: ignore[unreachable]
216242
raise ImproperlyConfiguredException(msg)
217243
return self.pool_instance
218244

219245
@asynccontextmanager
220-
async def lifespan(
221-
self,
222-
app: "Litestar"
223-
) -> "AsyncGenerator[None, None]":
246+
async def lifespan(self, app: "Litestar") -> "AsyncGenerator[None, None]":
224247
db_pool = await self.create_pool()
225248
app.state.update({self.pool_app_state_key: db_pool})
226249
try:
@@ -240,11 +263,7 @@ def provide_pool(self, state: "State") -> "Pool":
240263
"""
241264
return cast("Pool", state.get(self.pool_app_state_key))
242265

243-
async def provide_connection(
244-
self,
245-
state: "State",
246-
scope: "Scope"
247-
) -> "AsyncGenerator[AsyncpgConnection, None]":
266+
async def provide_connection(self, state: "State", scope: "Scope") -> "AsyncGenerator[AsyncpgConnection, None]":
248267
"""Create a connection instance.
249268
250269
Args:
@@ -254,18 +273,19 @@ async def provide_connection(
254273
Returns:
255274
A connection instance.
256275
"""
257-
connection = cast("Optional[Union[Connection, PoolConnectionProxy]]", get_scope_state(scope, CONNECTION_SCOPE_KEY))
276+
connection = cast(
277+
"Optional[Union[Connection, PoolConnectionProxy]]",
278+
get_scope_state(scope, self.connection_scope_key),
279+
)
258280
if connection is None:
259281
pool = cast("Pool", state.get(self.pool_app_state_key))
260282

261283
async with pool.acquire() as connection:
262-
set_scope_state(scope, CONNECTION_SCOPE_KEY, connection)
284+
set_scope_state(scope, self.connection_scope_key, connection)
263285
yield connection
264286

265287
@asynccontextmanager
266-
async def get_connection(
267-
self
268-
) -> "AsyncGenerator[AsyncpgConnection, None]":
288+
async def get_connection(self) -> "AsyncGenerator[AsyncpgConnection, None]":
269289
"""Create a connection instance.
270290
271291
Args:

litestar_asyncpg/plugin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ def on_app_init(self, app_config: "AppConfig") -> "AppConfig":
5050
},
5151
)
5252
app_config.type_encoders = {pgproto.UUID: str, **(app_config.type_encoders or {})}
53-
app_config.before_send.append(self._config.before_send_handler)
5453
app_config.lifespan.append(self._config.lifespan)
5554
app_config.signature_namespace.update(self._config.signature_namespace)
55+
if self._config.before_send_handler is not None:
56+
app_config.before_send.append(self._config.before_send_handler)
5657

5758
return app_config

0 commit comments

Comments
 (0)