diff --git a/requirements.txt b/requirements.txt index ffec811..e3299a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -cloud-sql-python-connector[asyncpg]==1.18.4 +cloud-sql-python-connector[asyncpg]==1.18.5 llama-index-core==0.14.4 pgvector==0.4.1 SQLAlchemy[asyncio]==2.0.43 diff --git a/tests/test_async_chat_store.py b/tests/test_async_chat_store.py index 32042e7..2abffc6 100644 --- a/tests/test_async_chat_store.py +++ b/tests/test_async_chat_store.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import asyncio import os import uuid -from typing import Sequence +from typing import Any, Coroutine, Sequence import pytest import pytest_asyncio @@ -28,18 +28,35 @@ sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead." +# Helper to bridge the Main Test Loop and the Engine Background Loop +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._loop) + ) + return await coro + + async def aexecute(engine: PostgresEngine, query: str) -> None: - async with engine._pool.connect() as conn: - await conn.execute(text(query)) - await conn.commit() + async def _impl(): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await run_on_background(engine, _impl()) async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: - async with engine._pool.connect() as conn: - result = await conn.execute(text(query)) - result_map = result.mappings() - result_fetch = result_map.fetchall() - return result_fetch + async def _impl(): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + result = await run_on_background(engine, _impl()) + return result def get_env_var(key: str, desc: str) -> str: @@ -96,8 +113,10 @@ async def async_engine( @pytest_asyncio.fixture(scope="class") async def chat_store(self, async_engine): - await async_engine._ainit_chat_store_table(table_name=default_table_name_async) - + await run_on_background( + async_engine, + async_engine._ainit_chat_store_table(table_name=default_table_name_async), + ) chat_store = await AsyncPostgresChatStore.create( engine=async_engine, table_name=default_table_name_async ) @@ -117,21 +136,23 @@ async def test_async_add_message(self, async_engine, chat_store): key = "test_add_key" message = ChatMessage(content="add_message_test", role="user") - await chat_store.async_add_message(key, message=message) + await run_on_background( + async_engine, chat_store.async_add_message(key, message=message) + ) query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" results = await afetch(async_engine, query) result = results[0] assert result["message"] == message.model_dump() - async def test_aset_and_aget_messages(self, chat_store): + async def test_aset_and_aget_messages(self, async_engine, chat_store): message_1 = ChatMessage(content="First message", role="user") message_2 = ChatMessage(content="Second message", role="user") messages = [message_1, message_2] key = "test_set_and_get_key" - await 
chat_store.aset_messages(key, messages) + await run_on_background(async_engine, chat_store.aset_messages(key, messages)) - results = await chat_store.aget_messages(key) + results = await run_on_background(async_engine, chat_store.aget_messages(key)) assert len(results) == 2 assert results[0].content == message_1.content @@ -140,9 +161,9 @@ async def test_aset_and_aget_messages(self, chat_store): async def test_adelete_messages(self, async_engine, chat_store): messages = [ChatMessage(content="Message to delete", role="user")] key = "test_delete_key" - await chat_store.aset_messages(key, messages) + await run_on_background(async_engine, chat_store.aset_messages(key, messages)) - await chat_store.adelete_messages(key) + await run_on_background(async_engine, chat_store.adelete_messages(key)) query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" results = await afetch(async_engine, query) @@ -153,9 +174,9 @@ async def test_adelete_message(self, async_engine, chat_store): message_2 = ChatMessage(content="Delete me", role="user") messages = [message_1, message_2] key = "test_delete_message_key" - await chat_store.aset_messages(key, messages) + await run_on_background(async_engine, chat_store.aset_messages(key, messages)) - await chat_store.adelete_message(key, 1) + await run_on_background(async_engine, chat_store.adelete_message(key, 1)) query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" results = await afetch(async_engine, query) @@ -168,9 +189,9 @@ async def test_adelete_last_message(self, async_engine, chat_store): message_3 = ChatMessage(content="Message 3", role="user") messages = [message_1, message_2, message_3] key = "test_delete_last_message_key" - await chat_store.aset_messages(key, messages) + await run_on_background(async_engine, chat_store.aset_messages(key, messages)) - await chat_store.adelete_last_message(key) + await run_on_background(async_engine, chat_store.adelete_last_message(key)) query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" results = await afetch(async_engine, query) @@ -183,10 +204,14 @@ async def test_aget_keys(self, async_engine, chat_store): message_2 = [ChatMessage(content="Second message", role="user")] key_1 = "key1" key_2 = "key2" - await chat_store.aset_messages(key_1, message_1) - await chat_store.aset_messages(key_2, message_2) + await run_on_background( + async_engine, chat_store.aset_messages(key_1, message_1) + ) + await run_on_background( + async_engine, chat_store.aset_messages(key_2, message_2) + ) - keys = await chat_store.aget_keys() + keys = await run_on_background(async_engine, chat_store.aget_keys()) assert key_1 in keys assert key_2 in keys @@ -194,7 +219,7 @@ async def test_aget_keys(self, async_engine, chat_store): async def test_set_exisiting_key(self, async_engine, chat_store): message_1 = [ChatMessage(content="First message", role="user")] key = "test_set_exisiting_key" - await chat_store.aset_messages(key, message_1) + await run_on_background(async_engine, chat_store.aset_messages(key, message_1)) query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" results = await afetch(async_engine, query) @@ -207,7 +232,7 @@ async def test_set_exisiting_key(self, async_engine, chat_store): message_3 = ChatMessage(content="Third message", role="user") messages = [message_2, message_3] - await chat_store.aset_messages(key, messages) + await run_on_background(async_engine, 
chat_store.aset_messages(key, messages)) query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" results = await afetch(async_engine, query) diff --git a/tests/test_async_document_store.py b/tests/test_async_document_store.py index 4c0dacb..cb4cb16 100644 --- a/tests/test_async_document_store.py +++ b/tests/test_async_document_store.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import os import uuid import warnings -from typing import Sequence +from typing import Any, Coroutine, Sequence import pytest import pytest_asyncio @@ -31,18 +32,35 @@ sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." +# Helper to bridge the Main Test Loop and the Engine Background Loop +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._loop) + ) + return await coro + + async def aexecute(engine: PostgresEngine, query: str) -> None: - async with engine._pool.connect() as conn: - await conn.execute(text(query)) - await conn.commit() + async def _impl(): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await run_on_background(engine, _impl()) async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: - async with engine._pool.connect() as conn: - result = await conn.execute(text(query)) - result_map = result.mappings() - result_fetch = result_map.fetchall() - return result_fetch + async def _impl(): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + result = await run_on_background(engine, _impl()) + return result def get_env_var(key: str, desc: str) -> str: @@ -93,10 +111,16 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): @pytest_asyncio.fixture(scope="class") async def doc_store(self, async_engine): - await async_engine._ainit_doc_store_table(table_name=default_table_name_async) + await run_on_background( + async_engine, + async_engine._ainit_doc_store_table(table_name=default_table_name_async), + ) - doc_store = await AsyncPostgresDocumentStore.create( - engine=async_engine, table_name=default_table_name_async + doc_store = await run_on_background( + async_engine, + AsyncPostgresDocumentStore.create( + engine=async_engine, table_name=default_table_name_async + ), ) yield doc_store @@ -106,10 +130,16 @@ async def doc_store(self, async_engine): @pytest_asyncio.fixture(scope="class") async def custom_doc_store(self, async_engine): - await async_engine._ainit_doc_store_table(table_name=custom_table_name_async) + await run_on_background( + async_engine, + async_engine._ainit_doc_store_table(table_name=custom_table_name_async), + ) - custom_doc_store = await AsyncPostgresDocumentStore.create( - engine=async_engine, table_name=custom_table_name_async, batch_size=0 + custom_doc_store = await run_on_background( + async_engine, + AsyncPostgresDocumentStore.create( + engine=async_engine, table_name=custom_table_name_async, batch_size=0 + ), ) yield custom_doc_store @@ -125,8 +155,11 @@ async def test_init_with_constructor(self, async_engine): async def test_create_without_table(self, 
async_engine):
         with pytest.raises(ValueError):
-            await AsyncPostgresDocumentStore.create(
-                engine=async_engine, table_name="non-existent-table"
+            await run_on_background(
+                async_engine,
+                AsyncPostgresDocumentStore.create(
+                    engine=async_engine, table_name="non-existent-table"
+                ),
             )
 
     async def test_warning(self, custom_doc_store):
@@ -143,13 +176,13 @@ async def test_warning(self, custom_doc_store):
                 w[-1].message
             )
 
-    async def test_adocs(self, doc_store):
+    async def test_adocs(self, doc_store, async_engine):
         # Create and add document into the doc store.
         document_text = "add document test"
         doc = Document(text=document_text, id_="add_doc_test", metadata={"doc": "info"})
 
         # Add document into the store
-        await doc_store.async_add_documents([doc])
+        await run_on_background(async_engine, doc_store.async_add_documents([doc]))
 
         # Assert document is found using the docs property.
         docs = await doc_store.adocs
@@ -161,7 +194,7 @@ async def test_async_add_document(self, async_engine, doc_store):
         document_text = "add document test"
         doc = Document(text=document_text, id_="add_doc_test", metadata={"doc": "info"})
 
-        await doc_store.async_add_documents([doc])
+        await run_on_background(async_engine, doc_store.async_add_documents([doc]))
 
         # Query the table to confirm the inserted document is present.
         query = f"""select * from "public"."{default_table_name_async}" where id = '{doc.doc_id}';"""
@@ -175,10 +208,13 @@ async def test_add_hash_before_data(self, async_engine, doc_store):
         doc = Document(text=document_text, id_="add_doc_test", metadata={"doc": "info"})
 
         # Insert the document id with its doc_hash.
-        await doc_store.aset_document_hash(doc_id=doc.doc_id, doc_hash=doc.hash)
+        await run_on_background(
+            async_engine,
+            doc_store.aset_document_hash(doc_id=doc.doc_id, doc_hash=doc.hash),
+        )
 
         # Insert the document's data
-        await doc_store.async_add_documents([doc])
+        await run_on_background(async_engine, doc_store.async_add_documents([doc]))
 
         # Confirm the overwrite was successful.
         query = f"""select * from "public"."{default_table_name_async}" where id = '{doc.doc_id}';"""
@@ -186,7 +222,7 @@ async def test_add_hash_before_data(self, async_engine, doc_store):
         result = results[0]
         assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text
 
-    async def test_aref_doc_exists(self, doc_store):
+    async def test_aref_doc_exists(self, doc_store, async_engine):
         # Create a ref_doc & a doc and add them to the store.
         ref_doc = Document(
             text="first doc", id_="doc_exists_doc_1", metadata={"doc": "info"}
@@ -196,23 +232,31 @@ async def test_aref_doc_exists(self, doc_store):
         )
         doc = Document(
             text="second doc", id_="doc_exists_doc_2", metadata={"doc": "info"}
         )
         doc.relationships[NodeRelationship.SOURCE] = ref_doc.as_related_node_info()
 
-        await doc_store.async_add_documents([ref_doc, doc])
+        await run_on_background(
+            async_engine, doc_store.async_add_documents([ref_doc, doc])
+        )
 
         # Confirm that ref_doc_id is recorded for the doc.
-        result = await doc_store.aref_doc_exists(ref_doc_id=ref_doc.doc_id)
+        result = await run_on_background(
+            async_engine, doc_store.aref_doc_exists(ref_doc_id=ref_doc.doc_id)
+        )
         assert result == True
 
-    async def test_fetch_ref_doc_info(self, doc_store):
+    async def test_fetch_ref_doc_info(self, doc_store, async_engine):
         # Create a ref_doc & doc and add them to the store.
         ref_doc = Document(
             text="first doc", id_="ref_parent_doc", metadata={"doc": "info"}
         )
         doc = Document(text="second doc", id_="ref_child_doc", metadata={"doc": "info"})
         doc.relationships[NodeRelationship.SOURCE] = ref_doc.as_related_node_info()
 
-        await doc_store.async_add_documents([ref_doc, doc])
+        await run_on_background(
+            async_engine, doc_store.async_add_documents([ref_doc, doc])
+        )
 
         # Fetch to see if ref_doc_info is found.
-        result = await doc_store.aget_ref_doc_info(ref_doc_id=ref_doc.doc_id)
+        result = await run_on_background(
+            async_engine, doc_store.aget_ref_doc_info(ref_doc_id=ref_doc.doc_id)
+        )
         assert result is not None
 
         # Add a new_doc with reference to doc.
@@ -220,95 +264,144 @@
             text="third_doc", id_="ref_new_doc", metadata={"doc": "info"}
         )
         new_doc.relationships[NodeRelationship.SOURCE] = doc.as_related_node_info()
-        await doc_store.async_add_documents([new_doc])
+        await run_on_background(async_engine, doc_store.async_add_documents([new_doc]))
 
         # Fetch to see if ref_doc_info is found for both ref_doc and doc.
-        results = await doc_store.aget_all_ref_doc_info()
+        results = await run_on_background(
+            async_engine, doc_store.aget_all_ref_doc_info()
+        )
         assert ref_doc.doc_id in results
         assert doc.doc_id in results
 
-    async def test_adelete_ref_doc(self, doc_store):
+    async def test_adelete_ref_doc(self, doc_store, async_engine):
         # Create a ref_doc & doc and add them to the store.
         ref_doc = Document(
             text="first doc", id_="ref_parent_doc", metadata={"doc": "info"}
         )
         doc = Document(text="second doc", id_="ref_child_doc", metadata={"doc": "info"})
         doc.relationships[NodeRelationship.SOURCE] = ref_doc.as_related_node_info()
 
-        await doc_store.async_add_documents([ref_doc, doc])
+        await run_on_background(
+            async_engine, doc_store.async_add_documents([ref_doc, doc])
+        )
 
         # Delete the reference doc
-        await doc_store.adelete_ref_doc(ref_doc_id=ref_doc.doc_id)
+        await run_on_background(
+            async_engine, doc_store.adelete_ref_doc(ref_doc_id=ref_doc.doc_id)
+        )
 
         # Confirm the reference doc along with its child nodes is deleted.
         assert (
-            await doc_store.aget_document(doc_id=doc.doc_id, raise_error=False) is None
+            await run_on_background(
+                async_engine,
+                doc_store.aget_document(doc_id=doc.doc_id, raise_error=False),
+            )
+            is None
         )
 
         # Confirm deleting a non-existent reference doc returns None.
-        assert await doc_store.adelete_ref_doc(ref_doc_id=ref_doc.doc_id) is None
+        assert (
+            await run_on_background(
+                async_engine, doc_store.adelete_ref_doc(ref_doc_id=ref_doc.doc_id)
+            )
+            is None
+        )
 
-    async def test_set_and_get_document_hash(self, doc_store):
+    async def test_set_and_get_document_hash(self, doc_store, async_engine):
         # Set a doc hash for a document
        doc_id = "document_id"
         doc_hash = "document_hash"
-        await doc_store.aset_document_hash(doc_id=doc_id, doc_hash=doc_hash)
+        await run_on_background(
+            async_engine, doc_store.aset_document_hash(doc_id=doc_id, doc_hash=doc_hash)
+        )
 
         # Assert with get that the hash is the same as the one set.
- assert await doc_store.aget_document_hash(doc_id=doc_id) == doc_hash + assert ( + await run_on_background( + async_engine, doc_store.aget_document_hash(doc_id=doc_id) + ) + == doc_hash + ) - async def test_aget_document_hash(self, doc_store): - assert await doc_store.aget_document_hash(doc_id="non-existent-doc") is None + async def test_aget_document_hash(self, doc_store, async_engine): + assert ( + await run_on_background( + async_engine, doc_store.aget_document_hash(doc_id="non-existent-doc") + ) + is None + ) - async def test_set_and_get_document_hashes(self, doc_store): + async def test_set_and_get_document_hashes(self, doc_store, async_engine): # Create a dictionary of doc_id -> doc_hash mappings and add it to the table. document_dict = { "document one": "document one hash", "document two": "document two hash", } expected_dict = {v: k for k, v in document_dict.items()} - await doc_store.aset_document_hashes(doc_hashes=document_dict) + await run_on_background( + async_engine, doc_store.aset_document_hashes(doc_hashes=document_dict) + ) # Get all the doc hashes and assert it is same as the one set. - results = await doc_store.aget_all_document_hashes() + results = await run_on_background( + async_engine, doc_store.aget_all_document_hashes() + ) assert "document one hash" in results assert "document two hash" in results assert results["document one hash"] == expected_dict["document one hash"] assert results["document two hash"] == expected_dict["document two hash"] - async def test_doc_store_basic(self, doc_store): + async def test_doc_store_basic(self, doc_store, async_engine): # Create a doc and a node and add them to the store. doc = Document(text="document_1", id_="doc_id_1", metadata={"doc": "info"}) node = TextNode(text="node_1", id_="node_id_1", metadata={"node": "info"}) - await doc_store.async_add_documents([doc, node]) + await run_on_background( + async_engine, doc_store.async_add_documents([doc, node]) + ) # Assert if document exists - assert await doc_store.adocument_exists(doc_id=doc.doc_id) == True + assert ( + await run_on_background( + async_engine, doc_store.adocument_exists(doc_id=doc.doc_id) + ) + == True + ) # Assert if retrieved doc is the same as the one inserted. - retrieved_doc = await doc_store.aget_document(doc_id=doc.doc_id) + retrieved_doc = await run_on_background( + async_engine, doc_store.aget_document(doc_id=doc.doc_id) + ) assert retrieved_doc == doc # Assert if retrieved node is the same as the one inserted. - retrieved_node = await doc_store.aget_document(doc_id=node.node_id) + retrieved_node = await run_on_background( + async_engine, doc_store.aget_document(doc_id=node.node_id) + ) assert retrieved_node == node async def test_adelete_document(self, async_engine, doc_store): # Create a doc and add it to the store. doc = Document(text="document_2", id_="doc_id_2", metadata={"doc": "info"}) - await doc_store.async_add_documents([doc]) + await run_on_background(async_engine, doc_store.async_add_documents([doc])) # Delete the document from the store. - await doc_store.adelete_document(doc_id=doc.doc_id) + await run_on_background( + async_engine, doc_store.adelete_document(doc_id=doc.doc_id) + ) # Assert the document is deleted by querying the table. 
query = f"""select * from "public"."{default_table_name_async}" where id = '{doc.doc_id}';""" result = await afetch(async_engine, query) assert len(result) == 0 - async def test_delete_non_existent_document(self, doc_store): - await doc_store.adelete_document(doc_id="non-existent-doc", raise_error=False) + async def test_delete_non_existent_document(self, doc_store, async_engine): + await run_on_background( + async_engine, + doc_store.adelete_document(doc_id="non-existent-doc", raise_error=False), + ) with pytest.raises(ValueError): - await doc_store.adelete_document(doc_id="non-existent-doc") + await run_on_background( + async_engine, doc_store.adelete_document(doc_id="non-existent-doc") + ) async def test_doc_store_ref_doc_not_added(self, async_engine, doc_store): # Create a ref_doc & doc. @@ -319,7 +412,7 @@ async def test_doc_store_ref_doc_not_added(self, async_engine, doc_store): doc.relationships[NodeRelationship.SOURCE] = ref_doc.as_related_node_info() # Insert only the node into the document store. - await doc_store.async_add_documents([doc]) + await run_on_background(async_engine, doc_store.async_add_documents([doc])) query = f"""select id as node_ids from "public"."{default_table_name_async}" where ref_doc_id = '{ref_doc.doc_id}';""" result = await afetch(async_engine, query) @@ -328,7 +421,9 @@ async def test_doc_store_ref_doc_not_added(self, async_engine, doc_store): assert len(result) != 0 # Delete the document - await doc_store.adelete_ref_doc(ref_doc_id=ref_doc.doc_id) + await run_on_background( + async_engine, doc_store.adelete_ref_doc(ref_doc_id=ref_doc.doc_id) + ) # Assert if parent doc is deleted query = f"""select * from "public"."{default_table_name_async}" where id = '{ref_doc.doc_id}';""" @@ -351,7 +446,9 @@ async def test_doc_store_delete_all_ref_doc_nodes(self, async_engine, doc_store) node.relationships[NodeRelationship.SOURCE] = ref_doc.as_related_node_info() # Add all the structures into the store. - await doc_store.async_add_documents([ref_doc, doc, node]) + await run_on_background( + async_engine, doc_store.async_add_documents([ref_doc, doc, node]) + ) query = f"""select id as node_ids from "public"."{default_table_name_async}" where ref_doc_id = '{ref_doc.doc_id}';""" result = await afetch(async_engine, query) @@ -364,7 +461,7 @@ async def test_doc_store_delete_all_ref_doc_nodes(self, async_engine, doc_store) ] # Delete the child document. - await doc_store.adelete_document(doc.doc_id) + await run_on_background(async_engine, doc_store.adelete_document(doc.doc_id)) # Assert the ref_doc still exists. query = f"""select * from "public"."{default_table_name_async}" where id = '{ref_doc.doc_id}';""" @@ -379,7 +476,7 @@ async def test_doc_store_delete_all_ref_doc_nodes(self, async_engine, doc_store) assert result["node_ids"] == [node.node_id] # Delete the child node - await doc_store.adelete_document(node.node_id) + await run_on_background(async_engine, doc_store.adelete_document(node.node_id)) # Assert the ref_doc is also deleted from the store. query = f"""select * from "public"."{default_table_name_async}" where id = '{ref_doc.doc_id}';""" diff --git a/tests/test_async_index_store.py b/tests/test_async_index_store.py index b59975d..2e2e324 100644 --- a/tests/test_async_index_store.py +++ b/tests/test_async_index_store.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import os import uuid import warnings -from typing import Sequence +from typing import Any, Coroutine, Sequence import pytest import pytest_asyncio @@ -29,18 +30,35 @@ sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresIndexStore. Use PostgresIndexStore interface instead." +# Helper to bridge the Main Test Loop and the Engine Background Loop +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._loop) + ) + return await coro + + async def aexecute(engine: PostgresEngine, query: str) -> None: - async with engine._pool.connect() as conn: - await conn.execute(text(query)) - await conn.commit() + async def _impl(): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await run_on_background(engine, _impl()) async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: - async with engine._pool.connect() as conn: - result = await conn.execute(text(query)) - result_map = result.mappings() - result_fetch = result_map.fetchall() - return result_fetch + async def _impl(): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + result = await run_on_background(engine, _impl()) + return result def get_env_var(key: str, desc: str) -> str: @@ -91,10 +109,16 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): - await async_engine._ainit_index_store_table(table_name=default_table_name_async) + await run_on_background( + async_engine, + async_engine._ainit_index_store_table(table_name=default_table_name_async), + ) - index_store = await AsyncPostgresIndexStore.create( - engine=async_engine, table_name=default_table_name_async + index_store = await run_on_background( + async_engine, + AsyncPostgresIndexStore.create( + engine=async_engine, table_name=default_table_name_async + ), ) yield index_store @@ -111,44 +135,61 @@ async def test_init_with_constructor(self, async_engine): async def test_create_without_table(self, async_engine): with pytest.raises(ValueError): - await AsyncPostgresIndexStore.create( - engine=async_engine, table_name="non-existent-table" + await run_on_background( + async_engine, + AsyncPostgresIndexStore.create( + engine=async_engine, table_name="non-existent-table" + ), ) async def test_add_and_delete_index(self, index_store, async_engine): index_struct = IndexGraph() index_id = index_struct.index_id index_type = index_struct.get_type() - await index_store.aadd_index_struct(index_struct) + await run_on_background( + async_engine, index_store.aadd_index_struct(index_struct) + ) query = f"""select * from "public"."{default_table_name_async}" where index_id = '{index_id}';""" results = await afetch(async_engine, query) result = results[0] assert result.get("type") == index_type - await index_store.adelete_index_struct(index_id) + await run_on_background( + async_engine, index_store.adelete_index_struct(index_id) + ) query = f"""select * from "public"."{default_table_name_async}" where index_id = '{index_id}';""" results = await afetch(async_engine, query) assert results == [] - async def test_get_index(self, index_store): + async def test_get_index(self, 
index_store, async_engine): index_struct = IndexGraph() index_id = index_struct.index_id index_type = index_struct.get_type() - await index_store.aadd_index_struct(index_struct) + await run_on_background( + async_engine, index_store.aadd_index_struct(index_struct) + ) - ind_struct = await index_store.aget_index_struct(index_id) + ind_struct = await run_on_background( + async_engine, index_store.aget_index_struct(index_id) + ) assert index_struct == ind_struct - async def test_aindex_structs(self, index_store): + async def test_aindex_structs(self, index_store, async_engine): index_dict_struct = IndexDict() index_list_struct = IndexList() index_graph_struct = IndexGraph() - await index_store.aadd_index_struct(index_dict_struct) - await index_store.async_add_index_struct(index_graph_struct) - await index_store.aadd_index_struct(index_list_struct) + await run_on_background( + async_engine, index_store.aadd_index_struct(index_dict_struct) + ) + await run_on_background( + async_engine, index_store.async_add_index_struct(index_graph_struct) + ) + await run_on_background( + async_engine, index_store.aadd_index_struct(index_list_struct) + ) indexes = await index_store.aindex_structs() @@ -156,16 +197,21 @@ async def test_aindex_structs(self, index_store): assert index_list_struct in indexes assert index_graph_struct in indexes - async def test_warning(self, index_store): + async def test_warning(self, index_store, async_engine): index_dict_struct = IndexDict() index_list_struct = IndexList() - await index_store.aadd_index_struct(index_dict_struct) - await index_store.aadd_index_struct(index_list_struct) + await run_on_background( + async_engine, index_store.aadd_index_struct(index_dict_struct) + ) + await run_on_background( + async_engine, index_store.aadd_index_struct(index_list_struct) + ) with warnings.catch_warnings(record=True) as w: - index_struct = await index_store.aget_index_struct() - + index_struct = await run_on_background( + async_engine, index_store.aget_index_struct() + ) assert len(w) == 1 assert "No struct_id specified and more than one struct exists." in str( w[-1].message diff --git a/tests/test_async_reader.py b/tests/test_async_reader.py index 6e2a665..8a10282 100644 --- a/tests/test_async_reader.py +++ b/tests/test_async_reader.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import json import os import uuid -from typing import Sequence +from typing import Any, Coroutine, Sequence import pytest import pytest_asyncio @@ -29,18 +30,35 @@ sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresReader. Use PostgresReader interface instead." 
+# Helper to bridge the Main Test Loop and the Engine Background Loop +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._loop) + ) + return await coro + + async def aexecute(engine: PostgresEngine, query: str) -> None: - async with engine._pool.connect() as conn: - await conn.execute(text(query)) - await conn.commit() + async def _impl(): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await run_on_background(engine, _impl()) async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: - async with engine._pool.connect() as conn: - result = await conn.execute(text(query)) - result_map = result.mappings() - result_fetch = result_map.fetchall() - return result_fetch + async def _impl(): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + result = await run_on_background(engine, _impl()) + return result def get_env_var(key: str, desc: str) -> str: @@ -116,41 +134,56 @@ async def _collect_async_items(self, docs_generator): async def test_create_reader_with_invalid_parameters(self, async_engine): with pytest.raises(ValueError): - await AsyncPostgresReader.create( - engine=async_engine, + await run_on_background( + async_engine, + AsyncPostgresReader.create( + engine=async_engine, + ), ) with pytest.raises(ValueError): def fake_formatter(): return None - await AsyncPostgresReader.create( - engine=async_engine, - table_name=default_table_name_async, - format="text", - formatter=fake_formatter, + await run_on_background( + async_engine, + AsyncPostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="text", + formatter=fake_formatter, + ), ) with pytest.raises(ValueError): - await AsyncPostgresReader.create( - engine=async_engine, - table_name=default_table_name_async, - format="fake_format", + await run_on_background( + async_engine, + AsyncPostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="fake_format", + ), ) async def test_lazy_load_data(self, async_engine): with pytest.raises(Exception, match=sync_method_exception_str): - reader = await AsyncPostgresReader.create( - engine=async_engine, - table_name=default_table_name_async, + reader = await run_on_background( + async_engine, + AsyncPostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + ), ) reader.lazy_load_data() async def test_load_data(self, async_engine): with pytest.raises(Exception, match=sync_method_exception_str): - reader = await AsyncPostgresReader.create( - engine=async_engine, - table_name=default_table_name_async, + reader = await run_on_background( + async_engine, + AsyncPostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + ), ) reader.load_data() @@ -176,9 +209,12 @@ async def test_load_from_query_default(self, async_engine): """ await aexecute(async_engine, insert_query) - reader = await AsyncPostgresReader.create( - engine=async_engine, - table_name=table_name, + reader = await run_on_background( + async_engine, + AsyncPostgresReader.create( + engine=async_engine, + table_name=table_name, + ), ) documents = await self._collect_async_items(reader.alazy_load_data()) @@ -237,17 +273,20 @@ async 
def test_load_from_query_customized_content_customized_metadata( """ await aexecute(async_engine, insert_query) - reader = await AsyncPostgresReader.create( - engine=async_engine, - query=f'SELECT * FROM "{table_name}";', - content_columns=[ - "fruit_name", - "variety", - "quantity_in_stock", - "price_per_unit", - "organic", - ], - metadata_columns=["fruit_id"], + reader = await run_on_background( + async_engine, + AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ), ) documents = await self._collect_async_items(reader.alazy_load_data()) @@ -281,14 +320,17 @@ async def test_load_from_query_customized_content_default_metadata( """ await aexecute(async_engine, insert_query) - reader = await AsyncPostgresReader.create( - engine=async_engine, - query=f'SELECT * FROM "{table_name}";', - content_columns=[ - "variety", - "quantity_in_stock", - "price_per_unit", - ], + reader = await run_on_background( + async_engine, + AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ), ) documents = await self._collect_async_items(reader.alazy_load_data()) @@ -304,15 +346,18 @@ async def test_load_from_query_customized_content_default_metadata( assert expected.text == actual.text assert expected.metadata == actual.metadata - reader = await AsyncPostgresReader.create( - engine=async_engine, - query=f'SELECT * FROM "{table_name}";', - content_columns=[ - "variety", - "quantity_in_stock", - "price_per_unit", - ], - format="JSON", + reader = await run_on_background( + async_engine, + AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="JSON", + ), ) actual_documents = await self._collect_async_items(reader.alazy_load_data()) @@ -357,12 +402,15 @@ async def test_load_from_query_with_json(self, async_engine): VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" await aexecute(async_engine, insert_query) - reader = await AsyncPostgresReader.create( - engine=async_engine, - query=f'SELECT * FROM "{table_name}";', - metadata_columns=[ - "variety", - ], + reader = await run_on_background( + async_engine, + AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "variety", + ], + ), ) documents = await self._collect_async_items(reader.alazy_load_data()) @@ -411,15 +459,18 @@ def my_formatter(row, content_columns): str(row[column]) for column in content_columns if column in row ) - reader = await AsyncPostgresReader.create( - engine=async_engine, - query=f'SELECT * FROM "{table_name}";', - content_columns=[ - "variety", - "quantity_in_stock", - "price_per_unit", - ], - formatter=my_formatter, + reader = await run_on_background( + async_engine, + AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + formatter=my_formatter, + ), ) documents = await self._collect_async_items(reader.alazy_load_data()) @@ -463,15 +514,18 @@ async def test_load_from_query_customized_content_default_metadata_custom_page_c """ await aexecute(async_engine, insert_query) - reader = await AsyncPostgresReader.create( - 
engine=async_engine, - query=f'SELECT * FROM "{table_name}";', - content_columns=[ - "variety", - "quantity_in_stock", - "price_per_unit", - ], - format="YAML", + reader = await run_on_background( + async_engine, + AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="YAML", + ), ) documents = await self._collect_async_items(reader.alazy_load_data()) diff --git a/tests/test_async_vector_store.py b/tests/test_async_vector_store.py index 785a5b0..7bdf35e 100644 --- a/tests/test_async_vector_store.py +++ b/tests/test_async_vector_store.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import os import uuid import warnings -from typing import Sequence +from typing import Any, Coroutine, Sequence import pytest import pytest_asyncio @@ -60,18 +61,35 @@ def get_env_var(key: str, desc: str) -> str: return v +# Helper to bridge the Main Test Loop and the Engine Background Loop +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._loop) + ) + return await coro + + async def aexecute(engine: PostgresEngine, query: str) -> None: - async with engine._pool.connect() as conn: - await conn.execute(text(query)) - await conn.commit() + async def _impl(): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await run_on_background(engine, _impl()) async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: - async with engine._pool.connect() as conn: - result = await conn.execute(text(query)) - result_map = result.mappings() - result_fetch = result_map.fetchall() - return result_fetch + async def _impl(): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + result = await run_on_background(engine, _impl()) + return result @pytest.mark.asyncio(loop_scope="class") @@ -116,40 +134,51 @@ async def engine(self, db_project, db_region, db_instance, db_name): @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - await engine._ainit_vector_store_table( - DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True + await run_on_background( + engine, + engine._ainit_vector_store_table( + DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True + ), + ) + vs = await run_on_background( + engine, AsyncPostgresVectorStore.create(engine, table_name=DEFAULT_TABLE) ) - vs = await AsyncPostgresVectorStore.create(engine, table_name=DEFAULT_TABLE) yield vs @pytest_asyncio.fixture(scope="class") async def custom_vs(self, engine): - await engine._ainit_vector_store_table( - DEFAULT_TABLE_CUSTOM_VS, - VECTOR_SIZE, - overwrite_existing=True, - metadata_columns=[ - Column(name="len", data_type="INTEGER", nullable=False), - Column( - name="nullable_int_field", - data_type="INTEGER", - nullable=True, - ), - Column( - name="nullable_str_field", - data_type="VARCHAR", - nullable=True, - ), - ], + await run_on_background( + engine, + engine._ainit_vector_store_table( + DEFAULT_TABLE_CUSTOM_VS, + VECTOR_SIZE, + overwrite_existing=True, + metadata_columns=[ + Column(name="len", data_type="INTEGER", nullable=False), + Column( + 
name="nullable_int_field", + data_type="INTEGER", + nullable=True, + ), + Column( + name="nullable_str_field", + data_type="VARCHAR", + nullable=True, + ), + ], + ), ) - vs = await AsyncPostgresVectorStore.create( + vs = await run_on_background( engine, - table_name=DEFAULT_TABLE_CUSTOM_VS, - metadata_columns=[ - "len", - "nullable_int_field", - "nullable_str_field", - ], + AsyncPostgresVectorStore.create( + engine, + table_name=DEFAULT_TABLE_CUSTOM_VS, + metadata_columns=[ + "len", + "nullable_int_field", + "nullable_str_field", + ], + ), ) yield vs @@ -163,8 +192,11 @@ async def test_validate_id_column_create(self, engine, vs): with pytest.raises( Exception, match=f"Id column, {test_id_column}, does not exist." ): - await AsyncPostgresVectorStore.create( - engine, table_name=DEFAULT_TABLE, id_column=test_id_column + await run_on_background( + engine, + AsyncPostgresVectorStore.create( + engine, table_name=DEFAULT_TABLE, id_column=test_id_column + ), ) async def test_validate_text_column_create(self, engine, vs): @@ -172,8 +204,11 @@ async def test_validate_text_column_create(self, engine, vs): with pytest.raises( Exception, match=f"Text column, {test_text_column}, does not exist." ): - await AsyncPostgresVectorStore.create( - engine, table_name=DEFAULT_TABLE, text_column=test_text_column + await run_on_background( + engine, + AsyncPostgresVectorStore.create( + engine, table_name=DEFAULT_TABLE, text_column=test_text_column + ), ) async def test_validate_embedding_column_create(self, engine, vs): @@ -182,10 +217,13 @@ async def test_validate_embedding_column_create(self, engine, vs): Exception, match=f"Embedding column, {test_embed_column}, does not exist.", ): - await AsyncPostgresVectorStore.create( + await run_on_background( engine, - table_name=DEFAULT_TABLE, - embedding_column=test_embed_column, + AsyncPostgresVectorStore.create( + engine, + table_name=DEFAULT_TABLE, + embedding_column=test_embed_column, + ), ) async def test_validate_node_column_create(self, engine, vs): @@ -193,8 +231,11 @@ async def test_validate_node_column_create(self, engine, vs): with pytest.raises( Exception, match=f"Node column, {test_node_column}, does not exist." 
): - await AsyncPostgresVectorStore.create( - engine, table_name=DEFAULT_TABLE, node_column=test_node_column + await run_on_background( + engine, + AsyncPostgresVectorStore.create( + engine, table_name=DEFAULT_TABLE, node_column=test_node_column + ), ) async def test_validate_ref_doc_id_column_create(self, engine, vs): @@ -203,10 +244,13 @@ async def test_validate_ref_doc_id_column_create(self, engine, vs): Exception, match=f"Reference Document Id column, {test_ref_doc_id_column}, does not exist.", ): - await AsyncPostgresVectorStore.create( + await run_on_background( engine, - table_name=DEFAULT_TABLE, - ref_doc_id_column=test_ref_doc_id_column, + AsyncPostgresVectorStore.create( + engine, + table_name=DEFAULT_TABLE, + ref_doc_id_column=test_ref_doc_id_column, + ), ) async def test_validate_metadata_json_column_create(self, engine, vs): @@ -215,14 +259,17 @@ async def test_validate_metadata_json_column_create(self, engine, vs): Exception, match=f"Metadata column, {test_metadata_json_column}, does not exist.", ): - await AsyncPostgresVectorStore.create( + await run_on_background( engine, - table_name=DEFAULT_TABLE, - metadata_json_column=test_metadata_json_column, + AsyncPostgresVectorStore.create( + engine, + table_name=DEFAULT_TABLE, + metadata_json_column=test_metadata_json_column, + ), ) async def test_async_add(self, engine, vs): - await vs.async_add(nodes) + await run_on_background(engine, vs.async_add(nodes)) results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 4 @@ -232,7 +279,7 @@ async def test_async_add_custom_vs(self, engine, custom_vs): for node in nodes: node.metadata["len"] = len(node.text) - await custom_vs.async_add(nodes) + await run_on_background(engine, custom_vs.async_add(nodes)) results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE_CUSTOM_VS}"') assert len(results) == 4 @@ -244,8 +291,8 @@ async def test_adelete(self, engine, vs): # Note: To be migrated to a pytest dependency on test_async_add # Blocked due to unexpected fixtures reloads while running integration test suite await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') - await vs.async_add(nodes) - await vs.adelete(nodes[0].node_id) + await run_on_background(engine, vs.async_add(nodes)) + await run_on_background(engine, vs.adelete(nodes[0].node_id)) results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 3 @@ -254,19 +301,24 @@ async def test_adelete_nodes(self, engine, vs): # Note: To be migrated to a pytest dependency on test_async_add # Blocked due to unexpected fixtures reloads while running integration test suite await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') - await vs.async_add(nodes) - await vs.adelete_nodes( - node_ids=[nodes[0].node_id, nodes[1].node_id], - filters=MetadataFilters( - filters=[ - MetadataFilter( - key="text", - value="foo", - operator=FilterOperator.TEXT_MATCH, - ), - MetadataFilter(key="text", value="bar", operator=FilterOperator.EQ), - ], - condition=FilterCondition.OR, + await run_on_background(engine, vs.async_add(nodes)) + await run_on_background( + engine, + vs.adelete_nodes( + node_ids=[nodes[0].node_id, nodes[1].node_id], + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="text", + value="foo", + operator=FilterOperator.TEXT_MATCH, + ), + MetadataFilter( + key="text", value="bar", operator=FilterOperator.EQ + ), + ], + condition=FilterCondition.OR, + ), ), ) @@ -277,23 +329,26 @@ async def test_aget_nodes(self, engine, vs): # Note: To be migrated to a pytest 
dependency on test_async_add # Blocked due to unexpected fixtures reloads while running integration test suite await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') - await vs.async_add(nodes) - results = await vs.aget_nodes( - filters=MetadataFilters( - filters=[ - MetadataFilter( - key="text", - value="foo", - operator=FilterOperator.TEXT_MATCH, - ), - MetadataFilter( - key="text", - value="bar", - operator=FilterOperator.TEXT_MATCH, - ), - ], - condition=FilterCondition.AND, - ) + await run_on_background(engine, vs.async_add(nodes)) + results = await run_on_background( + engine, + vs.aget_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="text", + value="foo", + operator=FilterOperator.TEXT_MATCH, + ), + MetadataFilter( + key="text", + value="bar", + operator=FilterOperator.TEXT_MATCH, + ), + ], + condition=FilterCondition.AND, + ) + ), ) assert len(results) == 1 @@ -303,11 +358,11 @@ async def test_aquery(self, engine, vs): # Note: To be migrated to a pytest dependency on test_async_add # Blocked due to unexpected fixtures reloads while running integration test suite await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') - await vs.async_add(nodes) + await run_on_background(engine, vs.async_add(nodes)) query = VectorStoreQuery( query_embedding=[1.0] * VECTOR_SIZE, similarity_top_k=3 ) - results = await vs.aquery(query) + results = await run_on_background(engine, vs.aquery(query)) assert results.nodes is not None assert results.ids is not None @@ -323,7 +378,7 @@ async def test_aquery_filters(self, engine, custom_vs): for node in nodes: node.metadata["len"] = len(node.text) - await custom_vs.async_add(nodes) + await run_on_background(engine, custom_vs.async_add(nodes)) filters = MetadataFilters( filters=[ @@ -383,8 +438,8 @@ async def test_aclear(self, engine, vs): # Note: To be migrated to a pytest dependency on test_adelete # Blocked due to unexpected fixtures reloads while running integration test suite await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') - await vs.async_add(nodes) - await vs.aclear() + await run_on_background(engine, vs.async_add(nodes)) + await run_on_background(engine, vs.aclear()) results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') assert len(results) == 0 diff --git a/tests/test_async_vector_store_index.py b/tests/test_async_vector_store_index.py index cefcf85..05ebcf4 100644 --- a/tests/test_async_vector_store_index.py +++ b/tests/test_async_vector_store_index.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-
+import asyncio
 import os
 import uuid
+from typing import Any, Coroutine
 
 import pytest
 import pytest_asyncio
 from llama_index.core.schema import TextNode
 from sqlalchemy import text
 
 from llama_index_cloud_sql_pg import PostgresEngine
 from llama_index_cloud_sql_pg.async_vector_store import AsyncPostgresVectorStore
 from llama_index_cloud_sql_pg.indexes import (
     DEFAULT_INDEX_NAME_SUFFIX,
@@ -56,10 +57,23 @@ def get_env_var(key: str, desc: str) -> str:
     return v
 
 
+# Helper to bridge the Main Test Loop and the Engine Background Loop
+async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any:
+    """Runs a coroutine on the engine's background loop."""
+    if engine._loop:
+        return await asyncio.wrap_future(
+            asyncio.run_coroutine_threadsafe(coro, engine._loop)
+        )
+    return await coro
+
+
 async def aexecute(engine: PostgresEngine, query: str) -> None:
-    async with engine._pool.connect() as conn:
-        await conn.execute(text(query))
-        await conn.commit()
+    async def _impl():
+        async with engine._pool.connect() as conn:
+            await conn.execute(text(query))
+            await conn.commit()
+
+    await run_on_background(engine, _impl())
 
 
 @pytest.mark.asyncio(loop_scope="class")
@@ -102,51 +116,59 @@ async def engine(self, db_project, db_region, db_instance, db_name):
 
     @pytest_asyncio.fixture(scope="class")
     async def vs(self, engine):
-        await engine._ainit_vector_store_table(
-            DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True
+        await run_on_background(
+            engine,
+            engine._ainit_vector_store_table(
+                DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True
+            ),
         )
-        vs = await AsyncPostgresVectorStore.create(
+        vs = await run_on_background(
             engine,
-            table_name=DEFAULT_TABLE,
+            AsyncPostgresVectorStore.create(
+                engine,
+                table_name=DEFAULT_TABLE,
+            ),
         )
-        await vs.async_add(nodes)
-        await vs.adrop_vector_index()
+        await run_on_background(engine, vs.async_add(nodes))
+        await run_on_background(engine, vs.adrop_vector_index())
         yield vs
 
-    async def test_aapply_vector_index(self, vs):
+    async def test_aapply_vector_index(self, vs, engine):
         index = HNSWIndex()
-        await vs.aapply_vector_index(index)
-        assert await vs.is_valid_index(DEFAULT_INDEX_NAME)
-        await vs.adrop_vector_index(DEFAULT_INDEX_NAME)
+        await run_on_background(engine, vs.aapply_vector_index(index))
+        assert await run_on_background(engine, vs.is_valid_index(DEFAULT_INDEX_NAME))
+        await run_on_background(engine, vs.adrop_vector_index(DEFAULT_INDEX_NAME))
 
-    async def test_areindex(self, vs):
-        if not await vs.is_valid_index(DEFAULT_INDEX_NAME):
+    async def test_areindex(self, vs, engine):
+        if not await run_on_background(engine, vs.is_valid_index(DEFAULT_INDEX_NAME)):
             index = HNSWIndex()
-            await vs.aapply_vector_index(index)
-        await vs.areindex()
-        await vs.areindex(DEFAULT_INDEX_NAME)
-        assert await vs.is_valid_index(DEFAULT_INDEX_NAME)
-        await vs.adrop_vector_index()
-
-    async def test_dropindex(self, vs):
-        await vs.adrop_vector_index()
-        result = await vs.is_valid_index(DEFAULT_INDEX_NAME)
+            await run_on_background(engine, vs.aapply_vector_index(index))
+        await run_on_background(engine, vs.areindex())
+        await run_on_background(engine, vs.areindex(DEFAULT_INDEX_NAME))
+        assert await run_on_background(engine, vs.is_valid_index(DEFAULT_INDEX_NAME))
+        await run_on_background(engine, vs.adrop_vector_index())
+
+    async def test_dropindex(self, vs, engine):
+        await run_on_background(engine, vs.adrop_vector_index())
+        result = await run_on_background(engine, vs.is_valid_index(DEFAULT_INDEX_NAME))
assert not result - async def test_aapply_vector_index_ivfflat(self, vs): + async def test_aapply_vector_index_ivfflat(self, vs, engine): index = IVFFlatIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) - await vs.aapply_vector_index(index, concurrently=True) - assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + await run_on_background( + engine, vs.aapply_vector_index(index, concurrently=True) + ) + assert await run_on_background(engine, vs.is_valid_index(DEFAULT_INDEX_NAME)) index = IVFFlatIndex( name="secondindex", distance_strategy=DistanceStrategy.INNER_PRODUCT, ) - await vs.aapply_vector_index(index) - assert await vs.is_valid_index("secondindex") - await vs.adrop_vector_index() - await vs.adrop_vector_index("secondindex") + await run_on_background(engine, vs.aapply_vector_index(index)) + assert await run_on_background(engine, vs.is_valid_index("secondindex")) + await run_on_background(engine, vs.adrop_vector_index()) + await run_on_background(engine, vs.adrop_vector_index("secondindex")) - async def test_is_valid_index(self, vs): - is_valid = await vs.is_valid_index("invalid_index") + async def test_is_valid_index(self, vs, engine): + is_valid = await run_on_background(engine, vs.is_valid_index("invalid_index")) assert is_valid == False diff --git a/tests/test_engine.py b/tests/test_engine.py index f6df414..31f0a36 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import os import uuid -from typing import Sequence +from typing import Any, Coroutine, Sequence import asyncpg # type: ignore import pytest @@ -46,27 +47,38 @@ def get_env_var(key: str, desc: str) -> str: return v +# Helper to bridge the Main Test Loop and the Engine Background Loop +async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any: + """Runs a coroutine on the engine's background loop.""" + if engine._loop: + return await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(coro, engine._loop) + ) + return await coro + + async def aexecute( engine: PostgresEngine, query: str, ) -> None: - async def run(engine, query): + async def _impl(): async with engine._pool.connect() as conn: await conn.execute(text(query)) await conn.commit() - await engine._run_as_async(run(engine, query)) + await run_on_background(engine, _impl()) async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: - async def run(engine, query): + async def _impl(): async with engine._pool.connect() as conn: result = await conn.execute(text(query)) result_map = result.mappings() result_fetch = result_map.fetchall() return result_fetch - return await engine._run_as_async(run(engine, query)) + result = await run_on_background(engine, _impl()) + return result @pytest.mark.asyncio @@ -150,7 +162,7 @@ async def test_from_engine( user, password, ): - async with Connector() as connector: + async with Connector(loop=asyncio.get_running_loop()) as connector: async def getconn() -> asyncpg.Connection: conn = await connector.connect_async( # type: ignore
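
Background on the run_on_background helper that each test module now defines: PostgresEngine drives its connection pool from a private event loop (engine._loop) running on a background thread, so pool-bound coroutines cannot be awaited directly from pytest's own loop. The helper submits the coroutine to the engine's loop with asyncio.run_coroutine_threadsafe and adapts the resulting concurrent.futures.Future back into an awaitable via asyncio.wrap_future. Below is a minimal, self-contained sketch of the same two-loop pattern; FakeEngine is a hypothetical stand-in for PostgresEngine, not part of the library.

import asyncio
import threading
from typing import Any, Coroutine


class FakeEngine:
    """Hypothetical stand-in: owns a private event loop on a daemon thread."""

    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        threading.Thread(target=self._loop.run_forever, daemon=True).start()


async def run_on_background(engine: FakeEngine, coro: Coroutine) -> Any:
    """Runs a coroutine on the engine's background loop."""
    if engine._loop:
        # run_coroutine_threadsafe returns a concurrent.futures.Future;
        # wrap_future turns it into a Future the caller's loop can await.
        return await asyncio.wrap_future(
            asyncio.run_coroutine_threadsafe(coro, engine._loop)
        )
    return await coro


async def which_loop() -> asyncio.AbstractEventLoop:
    await asyncio.sleep(0)  # suspension points resume on the background loop
    return asyncio.get_running_loop()


async def main() -> None:
    engine = FakeEngine()
    loop = await run_on_background(engine, which_loop())
    assert loop is engine._loop  # the coroutine ran on the engine's loop
    assert loop is not asyncio.get_running_loop()  # not on the caller's loop


asyncio.run(main())

Exceptions raised inside the scheduled coroutine propagate through the wrapped future, which is why the pytest.raises blocks above can wrap run_on_background calls just as they previously wrapped direct awaits.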