Skip to content

Commit 2f62745

Browse files
committed
chore: Update tests for Async Classes
1 parent c5342a6 commit 2f62745

8 files changed

+1030
-604
lines changed

tests/test_async_chatmessagehistory.py

Lines changed: 70 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import asyncio
1415
import os
1516
import uuid
17+
from typing import Any, Coroutine
1618

1719
import pytest
1820
import pytest_asyncio
@@ -33,10 +35,21 @@
3335
table_name_async = "message_store" + str(uuid.uuid4())
3436

3537

38+
# Helper to bridge the Main Test Loop and the Engine Background Loop
39+
async def run_on_background(engine: PostgresEngine, coro: Coroutine) -> Any:
40+
"""Runs a coroutine on the engine's background loop."""
41+
return await asyncio.wrap_future(
42+
asyncio.run_coroutine_threadsafe(coro, engine._loop)
43+
)
44+
45+
3646
async def aexecute(engine: PostgresEngine, query: str) -> None:
37-
async with engine._pool.connect() as conn:
38-
await conn.execute(text(query))
39-
await conn.commit()
47+
async def _impl():
48+
async with engine._pool.connect() as conn:
49+
await conn.execute(text(query))
50+
await conn.commit()
51+
52+
await run_on_background(engine, _impl())
4053

4154

4255
@pytest_asyncio.fixture
@@ -47,7 +60,10 @@ async def async_engine():
4760
instance=instance_id,
4861
database=db_name,
4962
)
50-
await async_engine._ainit_chat_history_table(table_name=table_name_async)
63+
await run_on_background(
64+
async_engine,
65+
async_engine._ainit_chat_history_table(table_name=table_name_async),
66+
)
5167
yield async_engine
5268
# use default table for AsyncPostgresChatMessageHistory
5369
query = f'DROP TABLE IF EXISTS "{table_name_async}"'
@@ -59,14 +75,19 @@ async def async_engine():
5975
async def test_chat_message_history_async(
6076
async_engine: PostgresEngine,
6177
) -> None:
62-
history = await AsyncPostgresChatMessageHistory.create(
63-
engine=async_engine, session_id="test", table_name=table_name_async
78+
history = await run_on_background(
79+
async_engine,
80+
AsyncPostgresChatMessageHistory.create(
81+
engine=async_engine, session_id="test", table_name=table_name_async
82+
),
6483
)
6584
msg1 = HumanMessage(content="hi!")
6685
msg2 = AIMessage(content="whats up?")
67-
await history.aadd_message(msg1)
68-
await history.aadd_message(msg2)
69-
messages = await history._aget_messages()
86+
87+
await run_on_background(async_engine, history.aadd_message(msg1))
88+
await run_on_background(async_engine, history.aadd_message(msg2))
89+
90+
messages = await run_on_background(async_engine, history._aget_messages())
7091

7192
# verify messages are correct
7293
assert messages[0].content == "hi!"
@@ -75,48 +96,71 @@ async def test_chat_message_history_async(
7596
assert type(messages[1]) is AIMessage
7697

7798
# verify clear() clears message history
78-
await history.aclear()
79-
assert len(await history._aget_messages()) == 0
99+
await run_on_background(async_engine, history.aclear())
100+
messages_after_clear = await run_on_background(
101+
async_engine, history._aget_messages()
102+
)
103+
assert len(messages_after_clear) == 0
80104

81105

82106
@pytest.mark.asyncio
83107
async def test_chat_message_history_sync_messages(
84108
async_engine: PostgresEngine,
85109
) -> None:
86-
history1 = await AsyncPostgresChatMessageHistory.create(
87-
engine=async_engine, session_id="test", table_name=table_name_async
110+
history1 = await run_on_background(
111+
async_engine,
112+
AsyncPostgresChatMessageHistory.create(
113+
engine=async_engine, session_id="test", table_name=table_name_async
114+
),
88115
)
89-
history2 = await AsyncPostgresChatMessageHistory.create(
90-
engine=async_engine, session_id="test", table_name=table_name_async
116+
history2 = await run_on_background(
117+
async_engine,
118+
AsyncPostgresChatMessageHistory.create(
119+
engine=async_engine, session_id="test", table_name=table_name_async
120+
),
91121
)
92122
msg1 = HumanMessage(content="hi!")
93123
msg2 = AIMessage(content="whats up?")
94-
await history1.aadd_message(msg1)
95-
await history2.aadd_message(msg2)
124+
await run_on_background(async_engine, history1.aadd_message(msg1))
125+
await run_on_background(async_engine, history2.aadd_message(msg2))
126+
127+
len_history1 = len(await run_on_background(async_engine, history1._aget_messages()))
128+
len_history2 = len(await run_on_background(async_engine, history2._aget_messages()))
96129

97-
assert len(await history1._aget_messages()) == 2
98-
assert len(await history2._aget_messages()) == 2
130+
assert len_history1 == 2
131+
assert len_history2 == 2
99132

100133
# verify clear() clears message history
101-
await history2.aclear()
102-
assert len(await history2._aget_messages()) == 0
134+
await run_on_background(async_engine, history2.aclear())
135+
len_history2_after_clear = len(
136+
await run_on_background(async_engine, history2._aget_messages())
137+
)
138+
assert len_history2_after_clear == 0
103139

104140

105141
@pytest.mark.asyncio
106142
async def test_chat_table_async(async_engine):
107143
with pytest.raises(ValueError):
108-
await AsyncPostgresChatMessageHistory.create(
109-
engine=async_engine, session_id="test", table_name="doesnotexist"
144+
await run_on_background(
145+
async_engine,
146+
AsyncPostgresChatMessageHistory.create(
147+
engine=async_engine, session_id="test", table_name="doesnotexist"
148+
),
110149
)
111150

112151

113152
@pytest.mark.asyncio
114153
async def test_chat_schema_async(async_engine):
115154
table_name = "test_table" + str(uuid.uuid4())
116-
await async_engine._ainit_document_table(table_name=table_name)
155+
await run_on_background(
156+
async_engine, async_engine._ainit_document_table(table_name=table_name)
157+
)
117158
with pytest.raises(IndexError):
118-
await AsyncPostgresChatMessageHistory.create(
119-
engine=async_engine, session_id="test", table_name=table_name
159+
await run_on_background(
160+
async_engine,
161+
AsyncPostgresChatMessageHistory.create(
162+
engine=async_engine, session_id="test", table_name=table_name
163+
),
120164
)
121165

122166
query = f'DROP TABLE IF EXISTS "{table_name}"'

0 commit comments

Comments
 (0)