Skip to content

Commit f4c7e4c

Browse files
GitMarco27DeanChensj
authored andcommitted
feat: Add support for reusing an existing SQLAlchemy AsyncEngine in DatabaseSessionService
Merges #3565 Also resolved reviewer comments: - Track engine ownership to avoid disposing caller-owned engine on close. - Avoid mutating caller-owned engine with sqlite connection pragma listeners. - Added tests to verify caller-owned engine remains usable after service close. - Configured manual engine with StaticPool in tests to avoid flakiness. Co-authored-by: Shangjie Chen <deanchen@google.com> COPYBARA_INTEGRATE_REVIEW=#3565 from GitMarco27:feature/database-session-service-accept-engine 2abdee6 PiperOrigin-RevId: 934104748
1 parent fa18d26 commit f4c7e4c

2 files changed

Lines changed: 194 additions & 35 deletions

File tree

src/google/adk/sessions/database_session_service.py

Lines changed: 86 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import Any
2323
from typing import AsyncIterator
2424
from typing import Optional
25+
from typing import overload
2526
from typing import TypeAlias
2627
from typing import TypeVar
2728

@@ -193,51 +194,100 @@ def __init__(self, version: str):
193194
class DatabaseSessionService(BaseSessionService):
194195
"""A session service that uses a database for storage."""
195196

196-
def __init__(self, db_url: str, **kwargs: Any):
197-
"""Initializes the database session service with a database URL."""
198-
# 1. Create DB engine for db connection
199-
# 2. Create all tables based on schema
200-
# 3. Initialize all properties
197+
@overload
198+
def __init__(
199+
self,
200+
db_url: str,
201+
**kwargs: Any,
202+
) -> None:
203+
"""Initializes the database session service with a database URL.
204+
205+
Args:
206+
db_url: Database URL string for creating a new engine.
207+
**kwargs: Additional keyword arguments passed to create_async_engine.
208+
"""
209+
210+
@overload
211+
def __init__(
212+
self,
213+
*,
214+
db_engine: AsyncEngine,
215+
) -> None:
216+
"""Initializes the database session service with an existing SQLAlchemy AsyncEngine.
217+
218+
Args:
219+
db_engine: Existing SQLAlchemy AsyncEngine instance to use.
220+
"""
221+
222+
def __init__(
223+
self,
224+
db_url: Optional[str] = None,
225+
db_engine: Optional[AsyncEngine] = None,
226+
**kwargs: Any,
227+
) -> None:
228+
"""Initializes the database session service.
229+
230+
Args:
231+
db_url: Database URL string for creating a new engine. Mutually exclusive
232+
with db_engine.
233+
db_engine: Existing AsyncEngine instance. Mutually exclusive with db_url.
234+
**kwargs: Additional keyword arguments passed to create_async_engine when
235+
db_url is provided. Ignored when db_engine is provided.
236+
237+
Raises:
238+
ValueError: If neither or both db_url and db_engine are provided, or if
239+
engine creation fails.
240+
"""
201241
try:
202242
import sqlalchemy # noqa: F401
203243
except ImportError as e:
204244
from ..utils._dependency import missing_extra
205245

206246
raise missing_extra("sqlalchemy", "db") from e
207247

208-
try:
209-
engine_kwargs = dict(kwargs)
210-
url = make_url(db_url)
211-
if (
212-
url.get_backend_name() == _SQLITE_DIALECT
213-
and url.database == ":memory:"
214-
):
215-
engine_kwargs.setdefault("poolclass", StaticPool)
216-
connect_args = dict(engine_kwargs.get("connect_args", {}))
217-
connect_args.setdefault("check_same_thread", False)
218-
engine_kwargs["connect_args"] = connect_args
219-
elif url.get_backend_name() != _SQLITE_DIALECT:
220-
engine_kwargs.setdefault("pool_pre_ping", True)
221-
222-
db_engine = create_async_engine(db_url, **engine_kwargs)
223-
if db_engine.dialect.name == _SQLITE_DIALECT:
224-
# Set sqlite pragma to enable foreign keys constraints
225-
event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma)
226-
227-
except Exception as e:
228-
if isinstance(e, ArgumentError):
229-
raise ValueError(
230-
f"Invalid database URL format or argument '{db_url}'."
231-
) from e
232-
if isinstance(e, ImportError):
248+
if (db_url is None) == (db_engine is None):
249+
raise ValueError(
250+
"Exactly one of 'db_url' or 'db_engine' must be provided."
251+
)
252+
253+
if db_engine is None:
254+
self._owns_db_engine = True
255+
try:
256+
engine_kwargs = dict(kwargs)
257+
url = make_url(db_url)
258+
if (
259+
url.get_backend_name() == _SQLITE_DIALECT
260+
and url.database == ":memory:"
261+
):
262+
engine_kwargs.setdefault("poolclass", StaticPool)
263+
connect_args = dict(engine_kwargs.get("connect_args", {}))
264+
connect_args.setdefault("check_same_thread", False)
265+
engine_kwargs["connect_args"] = connect_args
266+
elif url.get_backend_name() != _SQLITE_DIALECT:
267+
engine_kwargs.setdefault("pool_pre_ping", True)
268+
269+
db_engine = create_async_engine(db_url, **engine_kwargs)
270+
if db_engine.dialect.name == _SQLITE_DIALECT:
271+
# Set sqlite pragma to enable foreign keys constraints
272+
event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma)
273+
274+
except Exception as e:
275+
if isinstance(e, ArgumentError):
276+
raise ValueError(
277+
f"Invalid database URL format or argument '{db_url}'."
278+
) from e
279+
if isinstance(e, ImportError):
280+
raise ValueError(
281+
f"Database related module not found for URL '{db_url}'."
282+
) from e
233283
raise ValueError(
234-
f"Database related module not found for URL '{db_url}'."
284+
f"Failed to create database engine for URL '{db_url}'"
235285
) from e
236-
raise ValueError(
237-
f"Failed to create database engine for URL '{db_url}'"
238-
) from e
286+
else:
287+
self._owns_db_engine = False
239288

240289
self.db_engine: AsyncEngine = db_engine
290+
241291
# DB session factory method
242292
self.database_session_factory: async_sessionmaker[
243293
DatabaseSessionFactory
@@ -802,7 +852,8 @@ async def append_event(self, session: Session, event: Event) -> Event:
802852

803853
async def close(self) -> None:
804854
"""Disposes the SQLAlchemy engine and closes pooled connections."""
805-
await self.db_engine.dispose()
855+
if self._owns_db_engine:
856+
await self.db_engine.dispose()
806857

807858
async def __aenter__(self) -> DatabaseSessionService:
808859
"""Enters the async context manager and returns this service."""

tests/unittests/sessions/test_session_service.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
from google.genai import types
3535
import pytest
3636
from sqlalchemy import delete
37+
from sqlalchemy import text
38+
from sqlalchemy.ext.asyncio import create_async_engine
39+
from sqlalchemy.pool import StaticPool
3740

3841

3942
class SessionServiceType(enum.Enum):
@@ -1951,3 +1954,108 @@ def test_database_session_service_visible_in_module_namespace():
19511954

19521955
assert 'DatabaseSessionService' in dir(sessions_module)
19531956
assert sessions_module.DatabaseSessionService is DatabaseSessionService
1957+
1958+
1959+
@pytest.mark.asyncio
1960+
async def test_database_session_service_with_db_url():
1961+
"""Test DatabaseSessionService initialization with db_url."""
1962+
# Test db_url as positional argument
1963+
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
1964+
app_name = 'test_app'
1965+
user_id = 'test_user'
1966+
1967+
# Create and retrieve a session
1968+
session = await service.create_session(
1969+
app_name=app_name, user_id=user_id, state={'key': 'value'}
1970+
)
1971+
assert session.app_name == app_name
1972+
assert session.user_id == user_id
1973+
assert session.state == {'key': 'value'}
1974+
1975+
# Let's check that we can retrieve it
1976+
retrieved = await service.get_session(
1977+
app_name=app_name, user_id=user_id, session_id=session.id
1978+
)
1979+
assert retrieved == session
1980+
1981+
# test db_url as keyword argument
1982+
service2 = DatabaseSessionService(db_url='sqlite+aiosqlite:///:memory:')
1983+
session2 = await service2.create_session(
1984+
app_name=app_name, user_id=user_id, state={'key': 'value2'}
1985+
)
1986+
assert session2.state == {'key': 'value2'}
1987+
1988+
1989+
@pytest.mark.asyncio
1990+
async def test_database_session_service_with_db_engine():
1991+
"""Test DatabaseSessionService initialization with db_engine."""
1992+
# Create an engine manually with StaticPool to avoid flakes
1993+
engine = create_async_engine(
1994+
'sqlite+aiosqlite:///:memory:',
1995+
poolclass=StaticPool,
1996+
connect_args={'check_same_thread': False},
1997+
)
1998+
1999+
# Create service with db_engine
2000+
service = DatabaseSessionService(db_engine=engine)
2001+
app_name = 'test_app'
2002+
user_id = 'test_user'
2003+
2004+
# Create and retrieve a session
2005+
session = await service.create_session(
2006+
app_name=app_name, user_id=user_id, state={'key': 'value'}
2007+
)
2008+
assert session.app_name == app_name
2009+
assert session.user_id == user_id
2010+
assert session.state == {'key': 'value'}
2011+
2012+
# Let's check that we can retrieve it
2013+
retrieved = await service.get_session(
2014+
app_name=app_name, user_id=user_id, session_id=session.id
2015+
)
2016+
assert retrieved == session
2017+
2018+
2019+
@pytest.mark.asyncio
2020+
async def test_database_session_service_caller_owned_engine_not_disposed_on_close():
2021+
"""Verifies that a caller-owned engine is not disposed when the service is closed."""
2022+
engine = create_async_engine(
2023+
'sqlite+aiosqlite:///:memory:',
2024+
poolclass=StaticPool,
2025+
connect_args={'check_same_thread': False},
2026+
)
2027+
2028+
service = DatabaseSessionService(db_engine=engine)
2029+
2030+
# Use the service
2031+
session = await service.create_session(app_name='app', user_id='user')
2032+
assert session is not None
2033+
2034+
# Close the service
2035+
await service.close()
2036+
2037+
# Verify engine is still usable by running a query
2038+
async with engine.connect() as conn:
2039+
result = await conn.execute(text('SELECT 1;'))
2040+
assert result.scalar() == 1
2041+
2042+
2043+
@pytest.mark.asyncio
2044+
async def test_database_session_service_requires_one_argument():
2045+
"""Test that DatabaseSessionService requires exactly one of db_url or db_engine."""
2046+
# Neither argument provided
2047+
with pytest.raises(
2048+
ValueError,
2049+
match="Exactly one of 'db_url' or 'db_engine' must be provided",
2050+
):
2051+
DatabaseSessionService()
2052+
2053+
# Both arguments provided
2054+
engine = create_async_engine('sqlite+aiosqlite:///:memory:')
2055+
with pytest.raises(
2056+
ValueError,
2057+
match="Exactly one of 'db_url' or 'db_engine' must be provided",
2058+
):
2059+
DatabaseSessionService(
2060+
db_url='sqlite+aiosqlite:///:memory:', db_engine=engine
2061+
)

0 commit comments

Comments
 (0)