Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cloud-sql-python-connector[asyncpg]==1.18.4
cloud-sql-python-connector[asyncpg]==1.18.5
llama-index-core==0.14.4
pgvector==0.4.1
SQLAlchemy[asyncio]==2.0.43
79 changes: 52 additions & 27 deletions tests/test_async_chat_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import uuid
from typing import Sequence
from typing import Any, Coroutine, Sequence

import pytest
import pytest_asyncio
Expand All @@ -28,18 +28,35 @@
sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead."


# Helper to bridge the Main Test Loop and the Engine Background Loop
async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any:
"""Runs a coroutine on the engine's background loop."""
if engine._loop:
return await asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(coro, engine._loop)
)
return await coro


async def aexecute(engine: PostgresEngine, query: str) -> None:
async with engine._pool.connect() as conn:
await conn.execute(text(query))
await conn.commit()
async def _impl():
async with engine._pool.connect() as conn:
await conn.execute(text(query))
await conn.commit()

await run_on_background(engine, _impl())


async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]:
async with engine._pool.connect() as conn:
result = await conn.execute(text(query))
result_map = result.mappings()
result_fetch = result_map.fetchall()
return result_fetch
async def _impl():
async with engine._pool.connect() as conn:
result = await conn.execute(text(query))
result_map = result.mappings()
result_fetch = result_map.fetchall()
return result_fetch

result = await run_on_background(engine, _impl())
return result


def get_env_var(key: str, desc: str) -> str:
Expand Down Expand Up @@ -96,8 +113,10 @@ async def async_engine(

@pytest_asyncio.fixture(scope="class")
async def chat_store(self, async_engine):
await async_engine._ainit_chat_store_table(table_name=default_table_name_async)

await run_on_background(
async_engine,
async_engine._ainit_chat_store_table(table_name=default_table_name_async),
)
chat_store = await AsyncPostgresChatStore.create(
engine=async_engine, table_name=default_table_name_async
)
Expand All @@ -117,21 +136,23 @@ async def test_async_add_message(self, async_engine, chat_store):
key = "test_add_key"

message = ChatMessage(content="add_message_test", role="user")
await chat_store.async_add_message(key, message=message)
await run_on_background(
async_engine, chat_store.async_add_message(key, message=message)
)

query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';"""
results = await afetch(async_engine, query)
result = results[0]
assert result["message"] == message.model_dump()

async def test_aset_and_aget_messages(self, chat_store):
async def test_aset_and_aget_messages(self, async_engine, chat_store):
message_1 = ChatMessage(content="First message", role="user")
message_2 = ChatMessage(content="Second message", role="user")
messages = [message_1, message_2]
key = "test_set_and_get_key"
await chat_store.aset_messages(key, messages)
await run_on_background(async_engine, chat_store.aset_messages(key, messages))

results = await chat_store.aget_messages(key)
results = await run_on_background(async_engine, chat_store.aget_messages(key))

assert len(results) == 2
assert results[0].content == message_1.content
Expand All @@ -140,9 +161,9 @@ async def test_aset_and_aget_messages(self, chat_store):
async def test_adelete_messages(self, async_engine, chat_store):
messages = [ChatMessage(content="Message to delete", role="user")]
key = "test_delete_key"
await chat_store.aset_messages(key, messages)
await run_on_background(async_engine, chat_store.aset_messages(key, messages))

await chat_store.adelete_messages(key)
await run_on_background(async_engine, chat_store.adelete_messages(key))
query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
results = await afetch(async_engine, query)

Expand All @@ -153,9 +174,9 @@ async def test_adelete_message(self, async_engine, chat_store):
message_2 = ChatMessage(content="Delete me", role="user")
messages = [message_1, message_2]
key = "test_delete_message_key"
await chat_store.aset_messages(key, messages)
await run_on_background(async_engine, chat_store.aset_messages(key, messages))

await chat_store.adelete_message(key, 1)
await run_on_background(async_engine, chat_store.adelete_message(key, 1))
query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
results = await afetch(async_engine, query)

Expand All @@ -168,9 +189,9 @@ async def test_adelete_last_message(self, async_engine, chat_store):
message_3 = ChatMessage(content="Message 3", role="user")
messages = [message_1, message_2, message_3]
key = "test_delete_last_message_key"
await chat_store.aset_messages(key, messages)
await run_on_background(async_engine, chat_store.aset_messages(key, messages))

await chat_store.adelete_last_message(key)
await run_on_background(async_engine, chat_store.adelete_last_message(key))
query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
results = await afetch(async_engine, query)

Expand All @@ -183,18 +204,22 @@ async def test_aget_keys(self, async_engine, chat_store):
message_2 = [ChatMessage(content="Second message", role="user")]
key_1 = "key1"
key_2 = "key2"
await chat_store.aset_messages(key_1, message_1)
await chat_store.aset_messages(key_2, message_2)
await run_on_background(
async_engine, chat_store.aset_messages(key_1, message_1)
)
await run_on_background(
async_engine, chat_store.aset_messages(key_2, message_2)
)

keys = await chat_store.aget_keys()
keys = await run_on_background(async_engine, chat_store.aget_keys())

assert key_1 in keys
assert key_2 in keys

async def test_set_exisiting_key(self, async_engine, chat_store):
message_1 = [ChatMessage(content="First message", role="user")]
key = "test_set_exisiting_key"
await chat_store.aset_messages(key, message_1)
await run_on_background(async_engine, chat_store.aset_messages(key, message_1))

query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';"""
results = await afetch(async_engine, query)
Expand All @@ -207,7 +232,7 @@ async def test_set_exisiting_key(self, async_engine, chat_store):
message_3 = ChatMessage(content="Third message", role="user")
messages = [message_2, message_3]

await chat_store.aset_messages(key, messages)
await run_on_background(async_engine, chat_store.aset_messages(key, messages))

query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';"""
results = await afetch(async_engine, query)
Expand Down
Loading