Skip to content

Commit 3bd7347

Browse files
committed
fix tests
1 parent cc68c06 commit 3bd7347

File tree

6 files changed

+237
-140
lines changed

6 files changed

+237
-140
lines changed

tests/test_async_chat_store.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,7 @@ async def async_engine(
115115
async def chat_store(self, async_engine):
116116
await run_on_background(
117117
async_engine,
118-
await async_engine._ainit_chat_store_table(
119-
table_name=default_table_name_async
120-
),
118+
async_engine._ainit_chat_store_table(table_name=default_table_name_async),
121119
)
122120
chat_store = await AsyncPostgresChatStore.create(
123121
engine=async_engine, table_name=default_table_name_async
@@ -138,21 +136,23 @@ async def test_async_add_message(self, async_engine, chat_store):
138136
key = "test_add_key"
139137

140138
message = ChatMessage(content="add_message_test", role="user")
141-
await chat_store.async_add_message(key, message=message)
139+
await run_on_background(
140+
async_engine, chat_store.async_add_message(key, message=message)
141+
)
142142

143143
query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';"""
144144
results = await afetch(async_engine, query)
145145
result = results[0]
146146
assert result["message"] == message.model_dump()
147147

148-
async def test_aset_and_aget_messages(self, chat_store):
148+
async def test_aset_and_aget_messages(self, async_engine, chat_store):
149149
message_1 = ChatMessage(content="First message", role="user")
150150
message_2 = ChatMessage(content="Second message", role="user")
151151
messages = [message_1, message_2]
152152
key = "test_set_and_get_key"
153-
await chat_store.aset_messages(key, messages)
153+
await run_on_background(async_engine, chat_store.aset_messages(key, messages))
154154

155-
results = await chat_store.aget_messages(key)
155+
results = await run_on_background(async_engine, chat_store.aget_messages(key))
156156

157157
assert len(results) == 2
158158
assert results[0].content == message_1.content
@@ -161,9 +161,9 @@ async def test_aset_and_aget_messages(self, chat_store):
161161
async def test_adelete_messages(self, async_engine, chat_store):
162162
messages = [ChatMessage(content="Message to delete", role="user")]
163163
key = "test_delete_key"
164-
await chat_store.aset_messages(key, messages)
164+
await run_on_background(async_engine, chat_store.aset_messages(key, messages))
165165

166-
await chat_store.adelete_messages(key)
166+
await run_on_background(async_engine, chat_store.adelete_messages(key))
167167
query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
168168
results = await afetch(async_engine, query)
169169

@@ -174,9 +174,9 @@ async def test_adelete_message(self, async_engine, chat_store):
174174
message_2 = ChatMessage(content="Delete me", role="user")
175175
messages = [message_1, message_2]
176176
key = "test_delete_message_key"
177-
await chat_store.aset_messages(key, messages)
177+
await run_on_background(async_engine, chat_store.aset_messages(key, messages))
178178

179-
await chat_store.adelete_message(key, 1)
179+
await run_on_background(async_engine, chat_store.adelete_message(key, 1))
180180
query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
181181
results = await afetch(async_engine, query)
182182

@@ -189,9 +189,9 @@ async def test_adelete_last_message(self, async_engine, chat_store):
189189
message_3 = ChatMessage(content="Message 3", role="user")
190190
messages = [message_1, message_2, message_3]
191191
key = "test_delete_last_message_key"
192-
await chat_store.aset_messages(key, messages)
192+
await run_on_background(async_engine, chat_store.aset_messages(key, messages))
193193

194-
await chat_store.adelete_last_message(key)
194+
await run_on_background(async_engine, chat_store.adelete_last_message(key))
195195
query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;"""
196196
results = await afetch(async_engine, query)
197197

@@ -204,18 +204,22 @@ async def test_aget_keys(self, async_engine, chat_store):
204204
message_2 = [ChatMessage(content="Second message", role="user")]
205205
key_1 = "key1"
206206
key_2 = "key2"
207-
await chat_store.aset_messages(key_1, message_1)
208-
await chat_store.aset_messages(key_2, message_2)
207+
await run_on_background(
208+
async_engine, chat_store.aset_messages(key_1, message_1)
209+
)
210+
await run_on_background(
211+
async_engine, chat_store.aset_messages(key_2, message_2)
212+
)
209213

210-
keys = await chat_store.aget_keys()
214+
keys = await run_on_background(async_engine, chat_store.aget_keys())
211215

212216
assert key_1 in keys
213217
assert key_2 in keys
214218

215219
async def test_set_exisiting_key(self, async_engine, chat_store):
216220
message_1 = [ChatMessage(content="First message", role="user")]
217221
key = "test_set_exisiting_key"
218-
await chat_store.aset_messages(key, message_1)
222+
await run_on_background(async_engine, chat_store.aset_messages(key, message_1))
219223

220224
query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';"""
221225
results = await afetch(async_engine, query)
@@ -228,7 +232,7 @@ async def test_set_exisiting_key(self, async_engine, chat_store):
228232
message_3 = ChatMessage(content="Third message", role="user")
229233
messages = [message_2, message_3]
230234

231-
await chat_store.aset_messages(key, messages)
235+
await run_on_background(async_engine, chat_store.aset_messages(key, messages))
232236

233237
query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';"""
234238
results = await afetch(async_engine, query)

0 commit comments

Comments
 (0)