Skip to content

Commit e7655bc

Browse files
committed
[ENH] Embed query strings in search api
1 parent 13791bd commit e7655bc

File tree

6 files changed

+371
-9
lines changed

6 files changed

+371
-9
lines changed

chromadb/api/models/AsyncCollection.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ async def search(
344344
345345
# Single search
346346
result = await collection.search(search)
347-
347+
348348
# Multiple searches at once
349349
searches = [
350350
Search().where(K("type") == "article").rank(Knn(query=[0.1, 0.2])),
@@ -357,9 +357,14 @@ async def search(
357357
if searches_list is None:
358358
searches_list = []
359359

360+
# Embed any string queries in Knn objects
361+
embedded_searches = [
362+
self._embed_search_string_queries(search) for search in searches_list
363+
]
364+
360365
return await self._client._search(
361366
collection_id=self.id,
362-
searches=cast(List[Search], searches_list),
367+
searches=cast(List[Search], embedded_searches),
363368
tenant=self.tenant,
364369
database=self.database,
365370
)

chromadb/api/models/Collection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
)
2323
from chromadb.api.collection_configuration import UpdateCollectionConfiguration
2424
from chromadb.execution.expression.plan import Search
25-
from typing import cast, List
2625

2726
import logging
2827

@@ -362,9 +361,14 @@ def search(
362361
if searches_list is None:
363362
searches_list = []
364363

364+
# Embed any string queries in Knn objects
365+
embedded_searches = [
366+
self._embed_search_string_queries(search) for search in searches_list
367+
]
368+
365369
return self._client._search(
366370
collection_id=self.id,
367-
searches=cast(List[Search], searches_list),
371+
searches=cast(List[Search], embedded_searches),
368372
tenant=self.tenant,
369373
database=self.database,
370374
)

chromadb/api/models/CollectionCommon.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from uuid import UUID
1818

1919
from chromadb.api.types import (
20+
EMBEDDING_KEY,
2021
URI,
2122
Schema,
2223
SparseVectorIndexConfig,
@@ -741,3 +742,232 @@ def _sparse_embed(
741742
if is_query:
742743
return sparse_embedding_function.embed_query(input=input)
743744
return sparse_embedding_function(input=input)
745+
746+
def _embed_knn_string_queries(self, knn: Any) -> Any:
    """Replace a string query on a Knn with the corresponding embedding.

    Anything that is not a Knn, or a Knn whose query is not a string, is
    returned untouched. Otherwise the embedding function is chosen from the
    Knn's key:

    * ``EMBEDDING_KEY`` -- the collection's own ``self._embed``.
    * any other key -- the embedding function configured for that key in
      ``self.schema.key_overrides``; an enabled sparse vector index is
      checked first, then an enabled dense (float_list) vector index.

    Args:
        knn: A Knn object that may have a string query.

    Returns:
        The input unchanged, or a new Knn carrying the embedded query and
        the same key, limit, default and return_rank as the input.

    Raises:
        ValueError: If the key is not present in the schema, if no embedding
            function is configured for the key, or if the embedding function
            does not return exactly one embedding for the single query.
    """
    # Function-scope import (presumably to avoid a circular dependency).
    from chromadb.execution.expression.operator import Knn

    if not isinstance(knn, Knn) or not isinstance(knn.query, str):
        return knn

    query_text = knn.query
    key = knn.key

    def with_query(vector: Any) -> Any:
        # Rebuild the Knn so the caller's object is never mutated.
        return Knn(
            query=vector,
            key=key,
            limit=knn.limit,
            default=knn.default,
            return_rank=knn.return_rank,
        )

    # Main embedding field: use the collection's embedding function.
    if key == EMBEDDING_KEY:
        embedding = self._embed(input=[query_text], is_query=True)
        if not embedding or len(embedding) != 1:
            raise ValueError(
                "Embedding function returned unexpected number of embeddings"
            )
        return with_query(embedding[0])

    # Any other key must be described by the schema.
    schema = self.schema
    if schema is None or key not in schema.key_overrides:
        raise ValueError(
            f"Cannot embed string query for key '{key}': "
            "key not found in schema. Please provide an embedded vector or "
            "configure an embedding function for this key in the schema."
        )

    value_type = schema.key_overrides[key]

    # Sparse vector index with an embedding function.
    sparse_type = value_type.sparse_vector
    sparse_index = (
        sparse_type.sparse_vector_index if sparse_type is not None else None
    )
    if (
        sparse_index is not None
        and sparse_index.enabled
        and sparse_index.config.embedding_function is not None
    ):
        # cast() only informs the type checker; it does nothing at runtime.
        sparse_func = cast(
            SparseEmbeddingFunction[Any], sparse_index.config.embedding_function
        )
        validate_sparse_embedding_function(sparse_func)

        sparse_embedding = self._sparse_embed(
            input=[query_text],
            sparse_embedding_function=sparse_func,
            is_query=True,
        )
        if not sparse_embedding or len(sparse_embedding) != 1:
            raise ValueError(
                "Sparse embedding function returned unexpected number of embeddings"
            )
        return with_query(sparse_embedding[0])

    # Dense vector index (float_list) with an embedding function.
    dense_type = value_type.float_list
    vector_index = dense_type.vector_index if dense_type is not None else None
    if (
        vector_index is not None
        and vector_index.enabled
        and vector_index.config.embedding_function is not None
    ):
        dense_func = vector_index.config.embedding_function
        validate_embedding_function(dense_func)

        # Look the method up before calling it. Wrapping the call in
        # `except AttributeError` would also swallow AttributeErrors raised
        # inside embed_query and silently re-embed through __call__.
        embed_query = getattr(dense_func, "embed_query", None)
        if callable(embed_query):
            embeddings = embed_query(input=[query_text])
        else:
            embeddings = dense_func([query_text])

        if not embeddings or len(embeddings) != 1:
            raise ValueError(
                "Embedding function returned unexpected number of embeddings"
            )
        return with_query(embeddings[0])

    raise ValueError(
        f"Cannot embed string query for key '{key}': "
        "no embedding function configured for this key in the schema. "
        "Please provide an embedded vector or configure an embedding function."
    )
867+
868+
def _embed_rank_string_queries(self, rank: Any) -> Any:
    """Walk a Rank expression tree and embed every string Knn query in it.

    Args:
        rank: A Rank expression (or None) that may contain Knn nodes whose
            query is a string.

    Returns:
        None when ``rank`` is None; otherwise a rank expression in which
        each Knn has been passed through ``_embed_knn_string_queries``.
        Composite nodes are rebuilt around their processed children, while
        Val nodes and unrecognised objects are returned as they are.
    """
    # Import here to avoid circular dependency
    from chromadb.execution.expression.operator import (
        Knn,
        Abs,
        Div,
        Exp,
        Log,
        Max,
        Min,
        Mul,
        Sub,
        Sum,
        Val,
        Rrf,
    )

    if rank is None:
        return None

    # Leaves: a Knn gets embedded, a Val has nothing to embed.
    if isinstance(rank, Knn):
        return self._embed_knn_string_queries(rank)
    if isinstance(rank, Val):
        return rank

    recurse = self._embed_rank_string_queries

    # Composite nodes, checked in a fixed order. The tag names the
    # attributes that hold the children: "unary" -> .rank,
    # "binary" -> .left/.right, "nary" -> .ranks.
    composite_nodes = (
        (Abs, "unary"),
        (Div, "binary"),
        (Exp, "unary"),
        (Log, "unary"),
        (Max, "nary"),
        (Min, "nary"),
        (Mul, "nary"),
        (Sub, "binary"),
        (Sum, "nary"),
    )
    for node_cls, shape in composite_nodes:
        if not isinstance(rank, node_cls):
            continue
        if shape == "unary":
            return node_cls(recurse(rank.rank))
        if shape == "binary":
            return node_cls(recurse(rank.left), recurse(rank.right))
        return node_cls([recurse(child) for child in rank.ranks])

    # Rrf carries extra settings that must be copied to the rebuilt node.
    if isinstance(rank, Rrf):
        return Rrf(
            ranks=[recurse(child) for child in rank.ranks],
            k=rank.k,
            weights=rank.weights,
            normalize=rank.normalize,
        )

    # Unknown rank type - return as is
    return rank
948+
949+
def _embed_search_string_queries(self, search: Any) -> Any:
    """Return a Search whose rank expression has its string queries embedded.

    Args:
        search: A Search object whose rank may contain Knn objects with
            string queries. Any other object is passed through.

    Returns:
        The input unchanged when it is not a Search; otherwise a new Search
        built from the original where/limit/select and the rank returned by
        ``_embed_rank_string_queries``.
    """
    # Import here to avoid circular dependency
    from chromadb.execution.expression.plan import Search

    if isinstance(search, Search):
        # Process the rank first, then rebuild the Search around it.
        resolved_rank = self._embed_rank_string_queries(search._rank)
        return Search(
            where=search._where,
            rank=resolved_rank,
            limit=search._limit,
            select=search._select,
        )

    return search

chromadb/execution/expression/operator.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from dataclasses import dataclass, field
2-
from enum import Enum
32
from typing import Optional, List, Dict, Set, Any, Union
43

54
import numpy as np
@@ -1009,7 +1008,10 @@ class Knn(Rank):
10091008
"""KNN-based ranking
10101009
10111010
Args:
1012-
query: The query vector for KNN search (dense, sparse, or numpy array)
1011+
query: The query for KNN search. Can be:
1012+
- A string (will be automatically embedded using the collection's embedding function)
1013+
- A dense vector (list or numpy array)
1014+
- A sparse vector (SparseVector dict)
10131015
key: The embedding key to search against. Can be:
10141016
- "#embedding" (default) - searches the main embedding field
10151017
- A metadata field name (e.g., "my_custom_field") - searches that metadata field
@@ -1018,16 +1020,23 @@ class Knn(Rank):
10181020
return_rank: If True, return the rank position (0, 1, 2, ...) instead of distance (default: False)
10191021
10201022
Examples:
1021-
# Search main embeddings (equivalent forms)
1023+
# Search with string query (automatically embedded)
1024+
Knn(query="hello world") # Will use collection's embedding function
1025+
1026+
# Search main embeddings with vectors (equivalent forms)
10221027
Knn(query=[0.1, 0.2]) # Uses default key="#embedding"
10231028
Knn(query=[0.1, 0.2], key=K.EMBEDDING)
10241029
Knn(query=[0.1, 0.2], key="#embedding")
10251030
1026-
# Search sparse embeddings stored in metadata
1031+
# Search sparse embeddings stored in metadata with string
1032+
Knn(query="hello world", key="custom_embedding") # Will use schema's embedding function
1033+
1034+
# Search sparse embeddings stored in metadata with vector
10271035
Knn(query=my_vector, key="custom_embedding") # Example: searches a metadata field
10281036
"""
10291037

10301038
query: Union[
1039+
str,
10311040
List[float],
10321041
SparseVector,
10331042
"NDArray[np.float32]",

0 commit comments

Comments
 (0)