1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import asyncio
1415import os
1516import uuid
17+ from typing import Any , Coroutine
1618
1719import pytest
1820import pytest_asyncio
3335table_name_async = "message_store" + str (uuid .uuid4 ())
3436
3537
38+ # Helper to bridge the Main Test Loop and the Engine Background Loop
39+ async def run_on_background (engine : PostgresEngine , coro : Coroutine ) -> Any :
40+ """Runs a coroutine on the engine's background loop."""
41+ return await asyncio .wrap_future (
42+ asyncio .run_coroutine_threadsafe (coro , engine ._loop )
43+ )
44+
45+
3646async def aexecute (engine : PostgresEngine , query : str ) -> None :
37- async with engine ._pool .connect () as conn :
38- await conn .execute (text (query ))
39- await conn .commit ()
47+ async def _impl ():
48+ async with engine ._pool .connect () as conn :
49+ await conn .execute (text (query ))
50+ await conn .commit ()
51+
52+ await run_on_background (engine , _impl ())
4053
4154
4255@pytest_asyncio .fixture
@@ -47,7 +60,10 @@ async def async_engine():
4760 instance = instance_id ,
4861 database = db_name ,
4962 )
50- await async_engine ._ainit_chat_history_table (table_name = table_name_async )
63+ await run_on_background (
64+ async_engine ,
65+ async_engine ._ainit_chat_history_table (table_name = table_name_async ),
66+ )
5167 yield async_engine
5268 # use default table for AsyncPostgresChatMessageHistory
5369 query = f'DROP TABLE IF EXISTS "{ table_name_async } "'
@@ -59,14 +75,19 @@ async def async_engine():
5975async def test_chat_message_history_async (
6076 async_engine : PostgresEngine ,
6177) -> None :
62- history = await AsyncPostgresChatMessageHistory .create (
63- engine = async_engine , session_id = "test" , table_name = table_name_async
78+ history = await run_on_background (
79+ async_engine ,
80+ AsyncPostgresChatMessageHistory .create (
81+ engine = async_engine , session_id = "test" , table_name = table_name_async
82+ ),
6483 )
6584 msg1 = HumanMessage (content = "hi!" )
6685 msg2 = AIMessage (content = "whats up?" )
67- await history .aadd_message (msg1 )
68- await history .aadd_message (msg2 )
69- messages = await history ._aget_messages ()
86+
87+ await run_on_background (async_engine , history .aadd_message (msg1 ))
88+ await run_on_background (async_engine , history .aadd_message (msg2 ))
89+
90+ messages = await run_on_background (async_engine , history ._aget_messages ())
7091
7192 # verify messages are correct
7293 assert messages [0 ].content == "hi!"
@@ -75,48 +96,71 @@ async def test_chat_message_history_async(
7596 assert type (messages [1 ]) is AIMessage
7697
7798 # verify clear() clears message history
78- await history .aclear ()
79- assert len (await history ._aget_messages ()) == 0
99+ await run_on_background (async_engine , history .aclear ())
100+ messages_after_clear = await run_on_background (
101+ async_engine , history ._aget_messages ()
102+ )
103+ assert len (messages_after_clear ) == 0
80104
81105
82106@pytest .mark .asyncio
83107async def test_chat_message_history_sync_messages (
84108 async_engine : PostgresEngine ,
85109) -> None :
86- history1 = await AsyncPostgresChatMessageHistory .create (
87- engine = async_engine , session_id = "test" , table_name = table_name_async
110+ history1 = await run_on_background (
111+ async_engine ,
112+ AsyncPostgresChatMessageHistory .create (
113+ engine = async_engine , session_id = "test" , table_name = table_name_async
114+ ),
88115 )
89- history2 = await AsyncPostgresChatMessageHistory .create (
90- engine = async_engine , session_id = "test" , table_name = table_name_async
116+ history2 = await run_on_background (
117+ async_engine ,
118+ AsyncPostgresChatMessageHistory .create (
119+ engine = async_engine , session_id = "test" , table_name = table_name_async
120+ ),
91121 )
92122 msg1 = HumanMessage (content = "hi!" )
93123 msg2 = AIMessage (content = "whats up?" )
94- await history1 .aadd_message (msg1 )
95- await history2 .aadd_message (msg2 )
124+ await run_on_background (async_engine , history1 .aadd_message (msg1 ))
125+ await run_on_background (async_engine , history2 .aadd_message (msg2 ))
126+
127+ len_history1 = len (await run_on_background (async_engine , history1 ._aget_messages ()))
128+ len_history2 = len (await run_on_background (async_engine , history2 ._aget_messages ()))
96129
97- assert len ( await history1 . _aget_messages ()) == 2
98- assert len ( await history2 . _aget_messages ()) == 2
130+ assert len_history1 == 2
131+ assert len_history2 == 2
99132
100133 # verify clear() clears message history
101- await history2 .aclear ()
102- assert len (await history2 ._aget_messages ()) == 0
134+ await run_on_background (async_engine , history2 .aclear ())
135+ len_history2_after_clear = len (
136+ await run_on_background (async_engine , history2 ._aget_messages ())
137+ )
138+ assert len_history2_after_clear == 0
103139
104140
105141@pytest .mark .asyncio
106142async def test_chat_table_async (async_engine ):
107143 with pytest .raises (ValueError ):
108- await AsyncPostgresChatMessageHistory .create (
109- engine = async_engine , session_id = "test" , table_name = "doesnotexist"
144+ await run_on_background (
145+ async_engine ,
146+ AsyncPostgresChatMessageHistory .create (
147+ engine = async_engine , session_id = "test" , table_name = "doesnotexist"
148+ ),
110149 )
111150
112151
113152@pytest .mark .asyncio
114153async def test_chat_schema_async (async_engine ):
115154 table_name = "test_table" + str (uuid .uuid4 ())
116- await async_engine ._ainit_document_table (table_name = table_name )
155+ await run_on_background (
156+ async_engine , async_engine ._ainit_document_table (table_name = table_name )
157+ )
117158 with pytest .raises (IndexError ):
118- await AsyncPostgresChatMessageHistory .create (
119- engine = async_engine , session_id = "test" , table_name = table_name
159+ await run_on_background (
160+ async_engine ,
161+ AsyncPostgresChatMessageHistory .create (
162+ engine = async_engine , session_id = "test" , table_name = table_name
163+ ),
120164 )
121165
122166 query = f'DROP TABLE IF EXISTS "{ table_name } "'
0 commit comments