diff --git a/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py b/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py index 61fd63c8..df59a6c8 100644 --- a/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py +++ b/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py @@ -1,11 +1,10 @@ import json import os -from typing import Any, Callable, Optional, Type, List - +from typing import Any, Callable, List, Optional, Type try: from qdrant_client import QdrantClient - from qdrant_client.http.models import Filter, FieldCondition, MatchValue + from qdrant_client.http.models import FieldCondition, Filter, MatchValue QDRANT_AVAILABLE = True except ImportError: @@ -57,8 +56,8 @@ class QdrantVectorSearchTool(BaseTool): description: str = "A tool to search the Qdrant database for relevant information on internal documents." args_schema: Type[BaseModel] = QdrantToolSchema query: Optional[str] = None - filter_by: Optional[str] = None - filter_value: Optional[str] = None + filter_conditions: Optional[list[tuple[str, Any]]] = [] + search_filter: Optional[Filter] = None collection_name: Optional[str] = None limit: Optional[int] = Field(default=3) score_threshold: float = Field(default=0.35) @@ -101,6 +100,18 @@ def __init__(self, **kwargs): "The 'qdrant-client' package is required to use the QdrantVectorSearchTool. " "Please install it with: uv add qdrant-client" ) + if kwargs.get("filter_conditions"): + must_conditions = [] + for filter_condition in kwargs.get("filter_conditions"): + must_conditions.append( + FieldCondition( + key=filter_condition[0], + match=MatchValue(value=filter_condition[1]), + ) + ) + self.search_filter = Filter(must=must_conditions) + else: + self.search_filter = None def _run( self, @@ -126,14 +137,26 @@ def _run( if not self.qdrant_url: raise ValueError("QDRANT_URL is not set") - # Create filter if filter parameters are provided - search_filter = None + # If filter_by and filter_value are provided, add them to the search filter + # without modifying the original search filter + search_filter = self.search_filter.copy() if self.search_filter else None if filter_by and filter_value: - search_filter = Filter( - must=[ + if ( + search_filter + and hasattr(search_filter, "must") + and isinstance(search_filter.must, list) + ): + search_filter.must.append( FieldCondition(key=filter_by, match=MatchValue(value=filter_value)) - ] - ) + ) + else: + search_filter = Filter( + must=[ + FieldCondition( + key=filter_by, match=MatchValue(value=filter_value) + ) + ] + ) # Search in Qdrant using the built-in query method query_vector = ( @@ -151,12 +174,11 @@ def _run( # Format results similar to storage implementation results = [] - # Extract the list of ScoredPoint objects from the tuple - for point in search_results: + for point in search_results.points: result = { - "metadata": point[1][0].payload.get("metadata", {}), - "context": point[1][0].payload.get("text", ""), - "distance": point[1][0].score, + "distance": point.score, + "metadata": point.payload.get("metadata", {}), + "context": point.payload.get("text", ""), } results.append(result)