11"""Tests for operation utilities."""
22
33import os
4- from unittest .mock import Mock
4+ from unittest .mock import Mock , patch
55
66import pytest
77from bson import ObjectId
88from pymongo import MongoClient
99from pymongo .collection import Collection
1010
11- from pymongo_vectorsearch_utils .operation import bulk_embed_and_insert_texts
11+ from pymongo_vectorsearch_utils import drop_vector_search_index
12+ from pymongo_vectorsearch_utils .index import create_vector_search_index , wait_for_docs_in_index
13+ from pymongo_vectorsearch_utils .operation import bulk_embed_and_insert_texts , execute_search_query
1214
# Database that the fixtures and tests in this module look up on the client.
1315DB_NAME = "vectorsearch_utils_test"
# Collection inside DB_NAME that the fixtures create, fill and clear.
1416COLLECTION_NAME = "test_operation"
# Name of the vector search index created for the execute_search_query tests.
17+ VECTOR_INDEX_NAME = "operation_vector_index"
1518
1619
1720@pytest .fixture (scope = "module" )
@@ -21,6 +24,15 @@ def client():
2124 yield client
2225 client .close ()
2326
@pytest.fixture(scope="module")
def preserved_collection(client):
    """Yield the test collection, emptied before and after the module's tests.

    The collection is created when it is not already listed in the test
    database; otherwise the existing one is reused. Unlike a per-test
    fixture, this one is module-scoped, so documents inserted through it
    stay in place for every test that shares it.
    """
    database = client[DB_NAME]
    already_there = COLLECTION_NAME in database.list_collection_names()
    if already_there:
        target = database[COLLECTION_NAME]
    else:
        target = database.create_collection(COLLECTION_NAME)

    # Start from an empty collection regardless of what earlier runs left.
    target.delete_many({})
    yield target
    # Teardown: leave nothing behind for other modules.
    target.delete_many({})
2436
2537@pytest .fixture
2638def collection (client ):
@@ -266,3 +278,176 @@ def test_custom_field_names(self, collection: Collection, mock_embedding_func):
266278 assert "vector" in doc
267279 assert doc ["content" ] == texts [0 ]
268280 assert doc ["vector" ] == [0.0 , 0.0 , 0.0 ]
281+
282+
class TestExecuteSearchQuery:
    """Tests for ``execute_search_query``.

    Two class-scoped autouse fixtures prepare the environment once for all
    tests in the class: ``vector_search_index`` creates the index and
    ``sample_docs`` inserts the documents and waits for them to be indexed.
    """

    @pytest.fixture(scope="class", autouse=True)
    def vector_search_index(self, preserved_collection: Collection):
        """Create the vector search index if it is missing; drop it on teardown.

        ``preserved_collection`` is requested (instead of rebuilding the
        handle from the client) so that the collection is guaranteed to
        exist before the index is created on it.
        """
        coll = preserved_collection
        # An empty listing means no index with this name exists yet.
        if not coll.list_search_indexes(VECTOR_INDEX_NAME).to_list():
            create_vector_search_index(
                collection=coll,
                index_name=VECTOR_INDEX_NAME,
                dimensions=3,
                path="embedding",
                similarity="cosine",
                filters=["category", "color", "wheels"],
                wait_until_complete=120,
            )
        yield
        drop_vector_search_index(collection=coll, index_name=VECTOR_INDEX_NAME)

    @pytest.fixture(scope="class", autouse=True)
    def sample_docs(self, preserved_collection: Collection, vector_search_index):
        """Insert the sample documents and return the collection holding them.

        ``vector_search_index`` is requested explicitly so that the index is
        set up before ``wait_for_docs_in_index`` is called on it, rather than
        relying on the implicit ordering of two autouse fixtures.

        Four documents with text, metadata and a 3-dimensional embedding are
        inserted, plus one document that has an embedding but no ``text``
        field.
        """
        texts = ["apple fruit", "banana fruit", "car vehicle", "bike vehicle"]
        metadatas = [
            {"category": "fruit", "color": "red"},
            {"category": "fruit", "color": "yellow"},
            {"category": "vehicle", "wheels": 4},
            {"category": "vehicle", "wheels": 2},
        ]

        def embeddings(texts):
            """Return the fixed embedding assigned to each of the given texts."""
            mapping = {
                "apple fruit": [1.0, 0.5, 0.0],
                "banana fruit": [0.5, 0.5, 0.0],
                "car vehicle": [0.0, 0.5, 1.0],
                "bike vehicle": [0.0, 1.0, 0.5],
            }
            return [mapping[text] for text in texts]

        bulk_embed_and_insert_texts(
            texts=texts,
            metadatas=metadatas,
            embedding_func=embeddings,
            collection=preserved_collection,
            text_key="text",
            embedding_key="embedding",
        )
        # Add a document that should not be returned in searches
        preserved_collection.insert_one(
            {
                '_id': ObjectId('68c1a038fd976373aa4ec19f'),
                'category': 'fruit',
                'color': 'red',
                'embedding': [1.0, 1.0, 1.0],
            }
        )
        # Four text documents plus the text-less one above.
        wait_for_docs_in_index(preserved_collection, VECTOR_INDEX_NAME, n_docs=5)
        return preserved_collection

    def test_basic_search_query(self, sample_docs: Collection):
        """The two closest documents come back in order, each with a score."""
        query_vector = [1.0, 0.5, 0.0]

        result = execute_search_query(
            query_vector=query_vector,
            collection=sample_docs,
            embedding_key="embedding",
            text_key="text",
            index_name=VECTOR_INDEX_NAME,
            k=2,
        )

        assert len(result) == 2
        assert result[0]["text"] == "apple fruit"
        assert result[1]["text"] == "banana fruit"
        assert "score" in result[0]
        assert "score" in result[1]

    def test_search_with_pre_filter(self, sample_docs: Collection):
        """A ``category`` pre-filter restricts the results to fruit documents."""
        query_vector = [1.0, 0.5, 1.0]
        pre_filter = {"category": "fruit"}

        result = execute_search_query(
            query_vector=query_vector,
            collection=sample_docs,
            embedding_key="embedding",
            text_key="text",
            index_name=VECTOR_INDEX_NAME,
            k=4,
            pre_filter=pre_filter,
        )

        assert len(result) == 2
        assert result[0]["category"] == "fruit"
        assert result[1]["category"] == "fruit"

    def test_search_with_post_filter_pipeline(self, sample_docs: Collection):
        """Post-filter stages are applied: only one hit passes ``score >= 0.99``."""
        query_vector = [1.0, 0.5, 0.0]
        post_filter_pipeline = [
            {"$match": {"score": {"$gte": 0.99}}},
            {"$sort": {"score": -1}},
        ]

        result = execute_search_query(
            query_vector=query_vector,
            collection=sample_docs,
            embedding_key="embedding",
            text_key="text",
            index_name=VECTOR_INDEX_NAME,
            k=2,
            post_filter_pipeline=post_filter_pipeline,
        )

        assert len(result) == 1

    def test_search_with_embeddings_included(self, sample_docs: Collection):
        """With ``include_embeddings=True`` the stored vector is returned."""
        query_vector = [1.0, 0.5, 0.0]

        result = execute_search_query(
            query_vector=query_vector,
            collection=sample_docs,
            embedding_key="embedding",
            text_key="text",
            index_name=VECTOR_INDEX_NAME,
            k=1,
            include_embeddings=True,
        )

        assert len(result) == 1
        assert "embedding" in result[0]
        assert result[0]["embedding"] == [1.0, 0.5, 0.0]

    def test_search_with_custom_field_names(self, sample_docs: Collection):
        """Custom text/embedding keys are reflected in the aggregation pipeline.

        ``aggregate`` is patched, so this test inspects the pipeline that is
        built and does not query the index.
        """
        query_vector = [1.0, 0.5, 0.25]

        mock_cursor = [
            {
                "_id": ObjectId(),
                "content": "apple fruit",
                "vector": [1.0, 0.5, 0.25],
                "score": 0.9,
            }
        ]

        with patch.object(sample_docs, "aggregate") as mock_aggregate:
            mock_aggregate.return_value = mock_cursor

            result = execute_search_query(
                query_vector=query_vector,
                collection=sample_docs,
                embedding_key="vector",
                text_key="content",
                index_name=VECTOR_INDEX_NAME,
                k=1,
            )

            assert len(result) == 1
            assert "content" in result[0]
            assert result[0]["content"] == "apple fruit"

            # First positional argument of aggregate() is the pipeline.
            pipeline_arg = mock_aggregate.call_args[0][0]
            vector_search_stage = pipeline_arg[0]["$vectorSearch"]
            assert vector_search_stage["path"] == "vector"
            assert {"$project": {"vector": 0}} in pipeline_arg

    def test_search_filters_documents_without_text_key(self, sample_docs: Collection):
        """Documents lacking the text field are left out of the results."""
        query_vector = [1.0, 0.5, 0.0]

        result = execute_search_query(
            query_vector=query_vector,
            collection=sample_docs,
            embedding_key="embedding",
            text_key="text",
            index_name=VECTOR_INDEX_NAME,
            k=3,
        )

        # Should only return documents with text field
        assert len(result) == 2
        assert all("text" in doc for doc in result)
        assert result[0]["text"] == "apple fruit"
        assert result[1]["text"] == "banana fruit"
0 commit comments