diff --git a/src/langchain_google_cloud_sql_pg/engine.py b/src/langchain_google_cloud_sql_pg/engine.py index 102467bd..f1adc95e 100644 --- a/src/langchain_google_cloud_sql_pg/engine.py +++ b/src/langchain_google_cloud_sql_pg/engine.py @@ -82,7 +82,18 @@ async def _get_iam_principal_email( class PostgresEngine(PGEngine): """A class for managing connections to a Cloud SQL for Postgres database.""" - _connector: Optional[Connector] = None + def __init__( + self, + key: object, + pool: AsyncEngine, + loop: Optional[asyncio.AbstractEventLoop] = None, + thread: Optional[Thread] = None, + connector: Optional[Connector] = None, + ): + """Initialize PostgresEngine.""" + # Initialize the parent PGEngine + super().__init__(key, pool, loop=loop, thread=thread) + self._connector = connector @classmethod async def _create( @@ -130,13 +141,12 @@ async def _create( "both should be specified to use basic user/password " "authentication or neither for IAM DB authentication." ) - if cls._connector is None: - cls._connector = Connector( - loop=loop, - user_agent=USER_AGENT, - quota_project=quota_project, - refresh_strategy=RefreshStrategy.LAZY, - ) + connector = Connector( + loop=loop, + user_agent=USER_AGENT, + quota_project=quota_project, + refresh_strategy=RefreshStrategy.LAZY, + ) # if user and password are given, use basic auth if user and password: @@ -156,23 +166,45 @@ async def _create( # anonymous function to be used for SQLAlchemy 'creator' argument async def getconn() -> asyncpg.Connection: - conn = await cls._connector.connect_async( # type: ignore - f"{project_id}:{region}:{instance}", - "asyncpg", - user=db_user, - password=password, - db=database, - enable_iam_auth=enable_iam_auth, - ip_type=ip_type, - ) - return conn + async def _connect() -> asyncpg.Connection: + return await connector.connect_async( # type: ignore + f"{project_id}:{region}:{instance}", + "asyncpg", + user=db_user, + password=password, + db=database, + enable_iam_auth=enable_iam_auth, + ip_type=ip_type, + ) + + # Jump to the background loop to execute connector logic. + if loop and asyncio.get_running_loop() != loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(_connect(), loop) + ) + + return await _connect() engine = create_async_engine( "postgresql+asyncpg://", async_creator=getconn, **engine_args, ) - return cls(PGEngine._PGEngine__create_key, engine, loop, thread) # type: ignore + return cls(PGEngine._PGEngine__create_key, engine, loop, thread, connector=connector) # type: ignore + + async def _run_as_async(self, coro: Awaitable[T]) -> T: + """Run an async coroutine asynchronously.""" + if not self._loop: + return await coro + try: + if asyncio.get_running_loop() == self._loop: + return await coro + except RuntimeError: + pass + + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, self._loop) + ) @classmethod def __start_background_loop( @@ -190,7 +222,7 @@ def __start_background_loop( ) -> Future: # Running a loop in a background thread allows us to support # async methods from non-async environments - if cls._default_loop is None: + if cls._default_loop is None or cls._default_loop.is_closed(): cls._default_loop = asyncio.new_event_loop() cls._default_thread = Thread( target=cls._default_loop.run_forever, daemon=True @@ -636,3 +668,9 @@ async def _aload_table_schema( ) return metadata.tables[f"{schema_name}.{table_name}"] + + async def close(self) -> None: + """Close the engine and the connector.""" + if self._connector: + await self._run_as_async(self._connector.close()) # type: ignore + await self._run_as_async(super().close())