chromadb/api/models/CollectionCommon.py (69 changes: 56 additions & 13 deletions)
@@ -584,15 +584,19 @@ def _apply_sparse_embeddings_to_metadatas(
         metadatas: Optional[List[Metadata]],
         documents: Optional[List[Document]] = None,
     ) -> Optional[List[Metadata]]:
-        if metadatas is None:
-            return None
-
         sparse_targets = self._get_sparse_embedding_targets()
         if not sparse_targets:
             return metadatas
 
+        # If no metadatas provided, create empty dicts based on documents length
+        if metadatas is None:
+            if documents is None:
+                return None
+            metadatas = [{} for _ in range(len(documents))]
+
+        # Create copies, converting None to empty dict
         updated_metadatas: List[Dict[str, Any]] = [
-            dict(metadata) for metadata in metadatas
+            dict(metadata) if metadata is not None else {} for metadata in metadatas
         ]
 
         documents_list = list(documents) if documents is not None else None
@@ -607,26 +611,60 @@ def _apply_sparse_embeddings_to_metadatas(
             embedding_func = cast(SparseEmbeddingFunction[Any], embedding_func)
             validate_sparse_embedding_function(embedding_func)
 
+            # Initialize collection lists for batch processing
             inputs: List[str] = []
             positions: List[int] = []
 
+            # Handle special case: source_key is "#document"
+            if source_key == DOCUMENT_KEY:
+                if documents_list is None:
+                    continue
+
+                # Collect documents that need embedding
+                for idx, metadata in enumerate(updated_metadatas):
+                    # Skip if target already exists in metadata
+                    if target_key in metadata:
+                        continue
+
+                    # Get document at this position
+                    if idx < len(documents_list):
+                        doc = documents_list[idx]
+                        if isinstance(doc, str):
+                            inputs.append(doc)
+                            positions.append(idx)
+
+                # Generate embeddings for all collected documents
+                if len(inputs) == 0:
+                    continue
+
+                sparse_embeddings = self._sparse_embed(
+                    input=inputs,
+                    sparse_embedding_function=embedding_func,
+                )
+
+                if len(sparse_embeddings) != len(positions):
+                    raise ValueError(
+                        "Sparse embedding function returned unexpected number of embeddings."
+                    )
+
+                for position, embedding in zip(positions, sparse_embeddings):
+                    updated_metadatas[position][target_key] = embedding
+
+                continue  # Skip the metadata-based logic below
+
+            # Handle normal case: source_key is a metadata field
             for idx, metadata in enumerate(updated_metadatas):
                 if target_key in metadata:
                     continue
 
-                if source_key == DOCUMENT_KEY:
-                    source_value = None
-                    if documents_list is not None and idx < len(documents_list):
-                        source_value = documents_list[idx]
-                else:
-                    source_value = metadata.get(source_key)
+                source_value = metadata.get(source_key)
                 if not isinstance(source_value, str):
                     continue
 
                 inputs.append(source_value)
                 positions.append(idx)
 
-            if not inputs:
+            if len(inputs) == 0:
                 continue
 
             sparse_embeddings = self._sparse_embed(
@@ -642,8 +680,13 @@ def _apply_sparse_embeddings_to_metadatas(
             for position, embedding in zip(positions, sparse_embeddings):
                 updated_metadatas[position][target_key] = embedding
 
-        validate_metadatas(cast(List[Metadata], updated_metadatas))
-        return cast(List[Metadata], updated_metadatas)
+        # Convert empty dicts back to None, validation requires non-empty dicts or None
+        result_metadatas: List[Optional[Metadata]] = [
+            metadata if metadata else None for metadata in updated_metadatas
+        ]
+
+        validate_metadatas(cast(List[Metadata], result_metadatas))
+        return cast(List[Metadata], result_metadatas)
 
     def _embed_record_set(
         self,
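
For context, here is a minimal, self-contained sketch (not the chromadb API) of the behavior this diff introduces: metadatas=None is now tolerated when documents are provided, missing embeddings are collected and generated in one batch, and empty metadata dicts are converted back to None before validation. The names apply_sparse_embeddings and toy_sparse_embed are hypothetical stand-ins, and the sketch covers only the "#document"-sourced case.

# Standalone sketch of the diff's new behavior; all names here are
# hypothetical, not the chromadb API.
from typing import Any, Dict, List, Optional

def toy_sparse_embed(inputs: List[str]) -> List[Dict[str, float]]:
    # Stand-in for a real SparseEmbeddingFunction: token -> count.
    out: List[Dict[str, float]] = []
    for text in inputs:
        counts: Dict[str, float] = {}
        for token in text.split():
            counts[token] = counts.get(token, 0.0) + 1.0
        out.append(counts)
    return out

def apply_sparse_embeddings(
    metadatas: Optional[List[Optional[Dict[str, Any]]]],
    documents: Optional[List[str]] = None,
    target_key: str = "sparse_embedding",
) -> Optional[List[Optional[Dict[str, Any]]]]:
    # If no metadatas provided, create empty dicts based on documents length.
    if metadatas is None:
        if documents is None:
            return None
        metadatas = [{} for _ in range(len(documents))]

    # Create copies, converting None to empty dict.
    updated = [dict(m) if m is not None else {} for m in metadatas]

    # Collect documents that still need an embedding, then batch-embed once.
    inputs: List[str] = []
    positions: List[int] = []
    if documents is not None:
        for idx, meta in enumerate(updated):
            if target_key in meta or idx >= len(documents):
                continue
            if isinstance(documents[idx], str):
                inputs.append(documents[idx])
                positions.append(idx)

    if inputs:
        embeddings = toy_sparse_embed(inputs)
        if len(embeddings) != len(positions):
            raise ValueError("unexpected number of embeddings")
        for pos, emb in zip(positions, embeddings):
            updated[pos][target_key] = emb

    # Convert empty dicts back to None, as the diff does before validation.
    return [meta if meta else None for meta in updated]

print(apply_sparse_embeddings(None, ["hello world hello", "foo bar"]))
# -> [{'sparse_embedding': {'hello': 2.0, 'world': 1.0}},
#     {'sparse_embedding': {'foo': 1.0, 'bar': 1.0}}]

Running the sketch with metadatas=None and two documents yields two metadata dicts that each carry a sparse embedding under the target key, the case the old code's early return None silently dropped.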