Skip to content

Commit c8846eb

Browse files
jairad26 and itaismith authored
[ENH] Embed query strings in search api (#5599)
## Description of changes _Summarize the changes made by this PR._ - Improvements & Bug fixes - This PR adds support for embedding string queries in `Knn`: when the query is a string and an embedding function can be found for the given key, the string is embedded before the search is sent - New functionality - ... ## Test plan _How are these changes tested?_ - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_ --------- Co-authored-by: Itai Smith <[email protected]>
1 parent cb8a461 commit c8846eb

File tree

10 files changed

+615
-30
lines changed

10 files changed

+615
-30
lines changed

Cargo.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

chromadb/api/models/AsyncCollection.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ async def search(
344344
345345
# Single search
346346
result = await collection.search(search)
347-
347+
348348
# Multiple searches at once
349349
searches = [
350350
Search().where(K("type") == "article").rank(Knn(query=[0.1, 0.2])),
@@ -357,9 +357,14 @@ async def search(
357357
if searches_list is None:
358358
searches_list = []
359359

360+
# Embed any string queries in Knn objects
361+
embedded_searches = [
362+
self._embed_search_string_queries(search) for search in searches_list
363+
]
364+
360365
return await self._client._search(
361366
collection_id=self.id,
362-
searches=cast(List[Search], searches_list),
367+
searches=cast(List[Search], embedded_searches),
363368
tenant=self.tenant,
364369
database=self.database,
365370
)

chromadb/api/models/Collection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
)
2323
from chromadb.api.collection_configuration import UpdateCollectionConfiguration
2424
from chromadb.execution.expression.plan import Search
25-
from typing import cast, List
2625

2726
import logging
2827

@@ -362,9 +361,14 @@ def search(
362361
if searches_list is None:
363362
searches_list = []
364363

364+
# Embed any string queries in Knn objects
365+
embedded_searches = [
366+
self._embed_search_string_queries(search) for search in searches_list
367+
]
368+
365369
return self._client._search(
366370
collection_id=self.id,
367-
searches=cast(List[Search], searches_list),
371+
searches=cast(List[Search], embedded_searches),
368372
tenant=self.tenant,
369373
database=self.database,
370374
)

chromadb/api/models/CollectionCommon.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from uuid import UUID
1818

1919
from chromadb.api.types import (
20+
EMBEDDING_KEY,
2021
URI,
2122
Schema,
2223
SparseVectorIndexConfig,
@@ -784,3 +785,232 @@ def _sparse_embed(
784785
if is_query:
785786
return sparse_embedding_function.embed_query(input=input)
786787
return sparse_embedding_function(input=input)
788+
789+
def _embed_knn_string_queries(self, knn: Any) -> Any:
    """Replace a string query on a Knn with the embedding of that string.

    The embedding function is chosen from the Knn's ``key``:

    * ``EMBEDDING_KEY``: the collection's own embedding path (``self._embed``).
    * any other key: the key must have an override in ``self.schema``; the
      enabled sparse vector index is tried first, then the enabled dense
      (``float_list``) vector index. The first one that carries an
      embedding function is used.

    Args:
        knn: The object to inspect. Anything that is not a ``Knn`` whose
            ``query`` is a ``str`` is returned unchanged.

    Returns:
        ``knn`` itself when there is nothing to embed, otherwise a new
        ``Knn`` with the same ``key``, ``limit``, ``default`` and
        ``return_rank`` and the embedded vector as ``query``.

    Raises:
        ValueError: If the key has no override in the schema, if no enabled
            index for the key has an embedding function, or if an embedding
            function does not return exactly one embedding for the single
            input string.
    """
    from chromadb.execution.expression.operator import Knn

    # Only a Knn that carries a string query needs any work.
    if not isinstance(knn, Knn) or not isinstance(knn.query, str):
        return knn

    query_text = knn.query
    key = knn.key

    def _with_embedded_query(vectors: Any, error_message: str) -> Any:
        # One input string must yield exactly one embedding. len() is used
        # instead of truthiness so that a 2-D numpy array result is counted
        # rather than raising "truth value of an array is ambiguous".
        if vectors is None or len(vectors) != 1:
            raise ValueError(error_message)
        return Knn(
            query=vectors[0],
            key=knn.key,
            limit=knn.limit,
            default=knn.default,
            return_rank=knn.return_rank,
        )

    # Main embedding field: use the collection's own embedding function.
    if key == EMBEDDING_KEY:
        return _with_embedded_query(
            self._embed(input=[query_text], is_query=True),
            "Embedding function returned unexpected number of embeddings",
        )

    # Any other key must be described by a schema override.
    schema = self.schema
    if schema is None or key not in schema.key_overrides:
        raise ValueError(
            f"Cannot embed string query for key '{key}': "
            "key not found in schema. Please provide an embedded vector or "
            "configure an embedding function for this key in the schema."
        )

    value_type = schema.key_overrides[key]

    # Sparse vector index with an embedding function takes precedence.
    if value_type.sparse_vector is not None:
        sparse_index = value_type.sparse_vector.sparse_vector_index
        if (
            sparse_index is not None
            and sparse_index.enabled
            and sparse_index.config.embedding_function is not None
        ):
            # cast() is a typing aid only; it does nothing at runtime.
            sparse_func = cast(
                SparseEmbeddingFunction[Any],
                sparse_index.config.embedding_function,
            )
            validate_sparse_embedding_function(sparse_func)
            return _with_embedded_query(
                self._sparse_embed(
                    input=[query_text],
                    sparse_embedding_function=sparse_func,
                    is_query=True,
                ),
                "Sparse embedding function returned unexpected number of embeddings",
            )

    # Dense vector index (float_list) with an embedding function.
    if value_type.float_list is not None:
        vector_index = value_type.float_list.vector_index
        if (
            vector_index is not None
            and vector_index.enabled
            and vector_index.config.embedding_function is not None
        ):
            dense_func = vector_index.config.embedding_function
            validate_embedding_function(dense_func)

            # Look the method up explicitly instead of catching
            # AttributeError around the call: an AttributeError raised
            # *inside* a working embed_query is a real bug and must not be
            # silently converted into a fallback to __call__.
            embed_query = getattr(dense_func, "embed_query", None)
            if callable(embed_query):
                dense_vectors = embed_query(input=[query_text])
            else:
                dense_vectors = dense_func([query_text])

            return _with_embedded_query(
                dense_vectors,
                "Embedding function returned unexpected number of embeddings",
            )

    raise ValueError(
        f"Cannot embed string query for key '{key}': "
        "no embedding function configured for this key in the schema. "
        "Please provide an embedded vector or configure an embedding function."
    )
910+
911+
def _embed_rank_string_queries(self, rank: Any) -> Any:
    """Walk a Rank expression tree and embed every string Knn query in it.

    Args:
        rank: A Rank expression (or ``None``) that may contain ``Knn`` nodes
            whose query is still a string.

    Returns:
        ``None`` for ``None``; the result of ``_embed_knn_string_queries``
        for a ``Knn``; the same object for a ``Val`` or for any type not
        listed below; otherwise a freshly built node of the same kind whose
        children have been processed recursively.
    """
    # Deferred import to avoid a circular dependency.
    from chromadb.execution.expression.operator import (
        Knn,
        Abs,
        Div,
        Exp,
        Log,
        Max,
        Min,
        Mul,
        Sub,
        Sum,
        Val,
        Rrf,
    )

    if rank is None:
        return None

    walk = self._embed_rank_string_queries

    def walk_all(children: Any) -> Any:
        return [walk(child) for child in children]

    # (node type, rebuild function) pairs, checked in order; the first
    # isinstance match decides how the node is rebuilt.
    rebuilders = (
        # Leaves: Knn is embedded, Val is passed through.
        (Knn, self._embed_knn_string_queries),
        (Val, lambda node: node),
        # Composite nodes: rebuild around recursively processed children.
        (Abs, lambda node: Abs(walk(node.rank))),
        (Div, lambda node: Div(walk(node.left), walk(node.right))),
        (Exp, lambda node: Exp(walk(node.rank))),
        (Log, lambda node: Log(walk(node.rank))),
        (Max, lambda node: Max(walk_all(node.ranks))),
        (Min, lambda node: Min(walk_all(node.ranks))),
        (Mul, lambda node: Mul(walk_all(node.ranks))),
        (Sub, lambda node: Sub(walk(node.left), walk(node.right))),
        (Sum, lambda node: Sum(walk_all(node.ranks))),
        (
            Rrf,
            lambda node: Rrf(
                ranks=walk_all(node.ranks),
                k=node.k,
                weights=node.weights,
                normalize=node.normalize,
            ),
        ),
    )

    for node_type, rebuild in rebuilders:
        if isinstance(rank, node_type):
            return rebuild(rank)

    # Unrecognised rank type: hand it back untouched.
    return rank
991+
992+
def _embed_search_string_queries(self, search: Any) -> Any:
    """Return a copy of a Search whose rank expression has string queries embedded.

    Args:
        search: The object to process. Anything that is not a ``Search`` is
            returned unchanged.

    Returns:
        A new ``Search`` built from the original's ``_where``, ``_limit`` and
        ``_select`` together with the rank produced by
        ``_embed_rank_string_queries``; or ``search`` itself when it is not a
        ``Search``.
    """
    # Deferred import to avoid a circular dependency.
    from chromadb.execution.expression.plan import Search

    if isinstance(search, Search):
        # Process the rank first; the other parts are carried over as they are.
        new_rank = self._embed_rank_string_queries(search._rank)
        return Search(
            where=search._where,
            rank=new_rank,
            limit=search._limit,
            select=search._select,
        )

    return search

chromadb/execution/expression/operator.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from dataclasses import dataclass, field
2-
from enum import Enum
32
from typing import Optional, List, Dict, Set, Any, Union
43

54
import numpy as np
@@ -904,6 +903,10 @@ def __abs__(self) -> "Abs":
904903
"""Absolute value: abs(rank)"""
905904
return Abs(self)
906905

906+
def abs(self) -> "Abs":
    """Builder form of the absolute value: ``rank.abs()`` wraps this rank in ``Abs``."""
    wrapped = Abs(self)
    return wrapped
909+
907910
# Builder methods for functions
908911
def exp(self) -> "Exp":
909912
"""Exponential: e^rank"""
@@ -1016,7 +1019,10 @@ class Knn(Rank):
10161019
"""KNN-based ranking
10171020
10181021
Args:
1019-
query: The query vector for KNN search (dense, sparse, or numpy array)
1022+
query: The query for KNN search. Can be:
1023+
- A string (will be automatically embedded using the collection's embedding function)
1024+
- A dense vector (list or numpy array)
1025+
- A sparse vector (SparseVector dict)
10201026
key: The embedding key to search against. Can be:
10211027
- "#embedding" (default) - searches the main embedding field
10221028
- A metadata field name (e.g., "my_custom_field") - searches that metadata field
@@ -1025,16 +1031,23 @@ class Knn(Rank):
10251031
return_rank: If True, return the rank position (0, 1, 2, ...) instead of distance (default: False)
10261032
10271033
Examples:
1028-
# Search main embeddings (equivalent forms)
1034+
# Search with string query (automatically embedded)
1035+
Knn(query="hello world") # Will use collection's embedding function
1036+
1037+
# Search main embeddings with vectors (equivalent forms)
10291038
Knn(query=[0.1, 0.2]) # Uses default key="#embedding"
10301039
Knn(query=[0.1, 0.2], key=K.EMBEDDING)
10311040
Knn(query=[0.1, 0.2], key="#embedding")
10321041
1033-
# Search sparse embeddings stored in metadata
1042+
# Search sparse embeddings stored in metadata with string
1043+
Knn(query="hello world", key="custom_embedding") # Will use schema's embedding function
1044+
1045+
# Search sparse embeddings stored in metadata with vector
10341046
Knn(query=my_vector, key="custom_embedding") # Example: searches a metadata field
10351047
"""
10361048

10371049
query: Union[
1050+
str,
10381051
List[float],
10391052
SparseVector,
10401053
"NDArray[np.float32]",

0 commit comments

Comments
 (0)