Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 38 additions & 16 deletions crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
import os
from typing import Any, Callable, Optional, Type, List

from typing import Any, Callable, List, Optional, Type

try:
from qdrant_client import QdrantClient
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
from qdrant_client.http.models import FieldCondition, Filter, MatchValue

QDRANT_AVAILABLE = True
except ImportError:
Expand Down Expand Up @@ -57,8 +56,8 @@ class QdrantVectorSearchTool(BaseTool):
description: str = "A tool to search the Qdrant database for relevant information on internal documents."
args_schema: Type[BaseModel] = QdrantToolSchema
query: Optional[str] = None
filter_by: Optional[str] = None
filter_value: Optional[str] = None
filter_conditions: Optional[list[tuple[str, Any]]] = []
search_filter: Optional[Filter] = None
collection_name: Optional[str] = None
limit: Optional[int] = Field(default=3)
score_threshold: float = Field(default=0.35)
Expand Down Expand Up @@ -101,6 +100,18 @@ def __init__(self, **kwargs):
"The 'qdrant-client' package is required to use the QdrantVectorSearchTool. "
"Please install it with: uv add qdrant-client"
)
if kwargs.get("filter_conditions"):
must_conditions = []
for filter_condition in kwargs.get("filter_conditions"):
must_conditions.append(
FieldCondition(
key=filter_condition[0],
match=MatchValue(value=filter_condition[1]),
)
)
self.search_filter = Filter(must=must_conditions)
else:
self.search_filter = None

def _run(
self,
Expand All @@ -126,14 +137,26 @@ def _run(
if not self.qdrant_url:
raise ValueError("QDRANT_URL is not set")

# Create filter if filter parameters are provided
search_filter = None
# If filter_by and filter_value are provided, add them to the search filter
# without modifying the original search filter
search_filter = self.search_filter.copy() if self.search_filter else None
if filter_by and filter_value:
search_filter = Filter(
must=[
if (
search_filter
and hasattr(search_filter, "must")
and isinstance(search_filter.must, list)
):
search_filter.must.append(
FieldCondition(key=filter_by, match=MatchValue(value=filter_value))
]
)
)
else:
search_filter = Filter(
must=[
FieldCondition(
key=filter_by, match=MatchValue(value=filter_value)
)
]
)

# Search in Qdrant using the built-in query method
query_vector = (
Expand All @@ -151,12 +174,11 @@ def _run(

# Format results similar to storage implementation
results = []
# Extract the list of ScoredPoint objects from the tuple
for point in search_results:
for point in search_results.points:
result = {
"metadata": point[1][0].payload.get("metadata", {}),
"context": point[1][0].payload.get("text", ""),
"distance": point[1][0].score,
"distance": point.score,
"metadata": point.payload.get("metadata", {}),
"context": point.payload.get("text", ""),
}
results.append(result)

Expand Down
Loading