Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cloud-sql-python-connector[asyncpg]==1.18.4
cloud-sql-python-connector[asyncpg]==1.18.5
numpy==2.3.3; python_version >= "3.11"
numpy==2.2.6; python_version == "3.10"
numpy==2.0.2; python_version <= "3.9"
Expand Down
98 changes: 72 additions & 26 deletions tests/test_async_chatmessagehistory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import uuid
from typing import Any, Coroutine

import pytest
import pytest_asyncio
Expand All @@ -33,10 +35,23 @@
table_name_async = "message_store" + str(uuid.uuid4())


# Helper to bridge the Main Test Loop and the Engine Background Loop
async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any:
"""Runs a coroutine on the engine's background loop."""
if engine._loop:
return await asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(coro, engine._loop)
)
return await coro


async def aexecute(engine: PostgresEngine, query: str) -> None:
async with engine._pool.connect() as conn:
await conn.execute(text(query))
await conn.commit()
async def _impl():
async with engine._pool.connect() as conn:
await conn.execute(text(query))
await conn.commit()

await run_on_background(engine, _impl())


@pytest_asyncio.fixture
Expand All @@ -47,7 +62,10 @@ async def async_engine():
instance=instance_id,
database=db_name,
)
await async_engine._ainit_chat_history_table(table_name=table_name_async)
await run_on_background(
async_engine,
async_engine._ainit_chat_history_table(table_name=table_name_async),
)
yield async_engine
# use default table for AsyncPostgresChatMessageHistory
query = f'DROP TABLE IF EXISTS "{table_name_async}"'
Expand All @@ -59,14 +77,19 @@ async def async_engine():
async def test_chat_message_history_async(
async_engine: PostgresEngine,
) -> None:
history = await AsyncPostgresChatMessageHistory.create(
engine=async_engine, session_id="test", table_name=table_name_async
history = await run_on_background(
async_engine,
AsyncPostgresChatMessageHistory.create(
engine=async_engine, session_id="test", table_name=table_name_async
),
)
msg1 = HumanMessage(content="hi!")
msg2 = AIMessage(content="whats up?")
await history.aadd_message(msg1)
await history.aadd_message(msg2)
messages = await history._aget_messages()

await run_on_background(async_engine, history.aadd_message(msg1))
await run_on_background(async_engine, history.aadd_message(msg2))

messages = await run_on_background(async_engine, history._aget_messages())

# verify messages are correct
assert messages[0].content == "hi!"
Expand All @@ -75,48 +98,71 @@ async def test_chat_message_history_async(
assert type(messages[1]) is AIMessage

# verify clear() clears message history
await history.aclear()
assert len(await history._aget_messages()) == 0
await run_on_background(async_engine, history.aclear())
messages_after_clear = await run_on_background(
async_engine, history._aget_messages()
)
assert len(messages_after_clear) == 0


@pytest.mark.asyncio
async def test_chat_message_history_sync_messages(
async_engine: PostgresEngine,
) -> None:
history1 = await AsyncPostgresChatMessageHistory.create(
engine=async_engine, session_id="test", table_name=table_name_async
history1 = await run_on_background(
async_engine,
AsyncPostgresChatMessageHistory.create(
engine=async_engine, session_id="test", table_name=table_name_async
),
)
history2 = await AsyncPostgresChatMessageHistory.create(
engine=async_engine, session_id="test", table_name=table_name_async
history2 = await run_on_background(
async_engine,
AsyncPostgresChatMessageHistory.create(
engine=async_engine, session_id="test", table_name=table_name_async
),
)
msg1 = HumanMessage(content="hi!")
msg2 = AIMessage(content="whats up?")
await history1.aadd_message(msg1)
await history2.aadd_message(msg2)
await run_on_background(async_engine, history1.aadd_message(msg1))
await run_on_background(async_engine, history2.aadd_message(msg2))

len_history1 = len(await run_on_background(async_engine, history1._aget_messages()))
len_history2 = len(await run_on_background(async_engine, history2._aget_messages()))

assert len(await history1._aget_messages()) == 2
assert len(await history2._aget_messages()) == 2
assert len_history1 == 2
assert len_history2 == 2

# verify clear() clears message history
await history2.aclear()
assert len(await history2._aget_messages()) == 0
await run_on_background(async_engine, history2.aclear())
len_history2_after_clear = len(
await run_on_background(async_engine, history2._aget_messages())
)
assert len_history2_after_clear == 0


@pytest.mark.asyncio
async def test_chat_table_async(async_engine):
with pytest.raises(ValueError):
await AsyncPostgresChatMessageHistory.create(
engine=async_engine, session_id="test", table_name="doesnotexist"
await run_on_background(
async_engine,
AsyncPostgresChatMessageHistory.create(
engine=async_engine, session_id="test", table_name="doesnotexist"
),
)


@pytest.mark.asyncio
async def test_chat_schema_async(async_engine):
table_name = "test_table" + str(uuid.uuid4())
await async_engine._ainit_document_table(table_name=table_name)
await run_on_background(
async_engine, async_engine._ainit_document_table(table_name=table_name)
)
with pytest.raises(IndexError):
await AsyncPostgresChatMessageHistory.create(
engine=async_engine, session_id="test", table_name=table_name
await run_on_background(
async_engine,
AsyncPostgresChatMessageHistory.create(
engine=async_engine, session_id="test", table_name=table_name
),
)

query = f'DROP TABLE IF EXISTS "{table_name}"'
Expand Down
Loading