Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 58 additions & 20 deletions src/langchain_google_cloud_sql_pg/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,18 @@ async def _get_iam_principal_email(
class PostgresEngine(PGEngine):
"""A class for managing connections to a Cloud SQL for Postgres database."""

_connector: Optional[Connector] = None
def __init__(
self,
key: object,
pool: AsyncEngine,
loop: Optional[asyncio.AbstractEventLoop] = None,
thread: Optional[Thread] = None,
connector: Optional[Connector] = None,
):
"""Initialize PostgresEngine."""
# Initialize the parent PGEngine
super().__init__(key, pool, loop=loop, thread=thread)
self._connector = connector

@classmethod
async def _create(
Expand Down Expand Up @@ -130,13 +141,12 @@ async def _create(
"both should be specified to use basic user/password "
"authentication or neither for IAM DB authentication."
)
if cls._connector is None:
cls._connector = Connector(
loop=loop,
user_agent=USER_AGENT,
quota_project=quota_project,
refresh_strategy=RefreshStrategy.LAZY,
)
connector = Connector(
loop=loop,
user_agent=USER_AGENT,
quota_project=quota_project,
refresh_strategy=RefreshStrategy.LAZY,
)

# if user and password are given, use basic auth
if user and password:
Expand All @@ -156,23 +166,45 @@ async def _create(

# anonymous function to be used for SQLAlchemy 'creator' argument
async def getconn() -> asyncpg.Connection:
conn = await cls._connector.connect_async( # type: ignore
f"{project_id}:{region}:{instance}",
"asyncpg",
user=db_user,
password=password,
db=database,
enable_iam_auth=enable_iam_auth,
ip_type=ip_type,
)
return conn
async def _connect() -> asyncpg.Connection:
return await connector.connect_async( # type: ignore
f"{project_id}:{region}:{instance}",
"asyncpg",
user=db_user,
password=password,
db=database,
enable_iam_auth=enable_iam_auth,
ip_type=ip_type,
)

# Jump to the background loop to execute connector logic.
if loop and asyncio.get_running_loop() != loop:
return await asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(_connect(), loop)
)

return await _connect()

engine = create_async_engine(
"postgresql+asyncpg://",
async_creator=getconn,
**engine_args,
)
return cls(PGEngine._PGEngine__create_key, engine, loop, thread) # type: ignore
return cls(PGEngine._PGEngine__create_key, engine, loop, thread, connector=connector) # type: ignore

async def _run_as_async(self, coro: Awaitable[T]) -> T:
"""Run an async coroutine asynchronously."""
if not self._loop:
return await coro
try:
if asyncio.get_running_loop() == self._loop:
return await coro
except RuntimeError:
pass

return await asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(coro, self._loop)
)

@classmethod
def __start_background_loop(
Expand All @@ -190,7 +222,7 @@ def __start_background_loop(
) -> Future:
# Running a loop in a background thread allows us to support
# async methods from non-async environments
if cls._default_loop is None:
if cls._default_loop is None or cls._default_loop.is_closed():
cls._default_loop = asyncio.new_event_loop()
cls._default_thread = Thread(
target=cls._default_loop.run_forever, daemon=True
Expand Down Expand Up @@ -636,3 +668,9 @@ async def _aload_table_schema(
)

return metadata.tables[f"{schema_name}.{table_name}"]

async def close(self) -> None:
"""Close the engine and the connector."""
if self._connector:
await self._run_as_async(self._connector.close()) # type: ignore
await self._run_as_async(super().close())
Loading