Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions chromadb/api/collection_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,21 +812,26 @@ def update_schema_from_collection_configuration(
Returns:
Updated Schema object
"""
# TODO: Remove this check once schema is enabled in local.
if schema.defaults.float_list is None:
return schema

# Get the vector index from defaults and #embedding key
if schema.defaults.float_list is None or schema.defaults.float_list.vector_index is None:
if (
schema.defaults.float_list is None
or schema.defaults.float_list.vector_index is None
):
raise ValueError("Schema is missing defaults.float_list.vector_index")

embedding_key = "#embedding"
if embedding_key not in schema.keys:
raise ValueError(f"Schema is missing keys[{embedding_key}]")

embedding_value_types = schema.keys[embedding_key]
if embedding_value_types.float_list is None or embedding_value_types.float_list.vector_index is None:
raise ValueError(f"Schema is missing keys[{embedding_key}].float_list.vector_index")
if (
embedding_value_types.float_list is None
or embedding_value_types.float_list.vector_index is None
):
raise ValueError(
f"Schema is missing keys[{embedding_key}].float_list.vector_index"
)

# Update vector index config in both locations
for vector_index in [
Expand Down Expand Up @@ -868,7 +873,10 @@ def update_schema_from_collection_configuration(
spann_config.ef_search = update_spann["ef_search"]

# Update embedding function if present
if "embedding_function" in configuration and configuration["embedding_function"] is not None:
if (
"embedding_function" in configuration
and configuration["embedding_function"] is not None
):
vector_index.config.embedding_function = configuration["embedding_function"]

return schema
19 changes: 15 additions & 4 deletions chromadb/api/rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def list_collections(
CollectionModel(
id=collection.id,
name=collection.name,
serialized_schema=None,
serialized_schema=collection.schema,
configuration_json=collection.configuration,
metadata=collection.metadata,
dimension=collection.dimension,
Expand Down Expand Up @@ -229,14 +229,25 @@ def create_collection(
else:
configuration_json_str = None

if schema:
schema_str = json.dumps(schema.serialize_to_json())
Comment on lines +232 to +233
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[BestPractice]

Missing error handling for JSON serialization. The json.dumps() call raises TypeError (or ValueError, e.g. on circular references) if schema.serialize_to_json() returns data that is not JSON-serializable. Without handling, callers see a raw serialization exception instead of a clear error that identifies the schema as invalid.

try:
    schema_str = json.dumps(schema.serialize_to_json())
except (TypeError, ValueError) as e:
    raise ValueError(f"Failed to serialize schema: {e}")
Context for Agents
[**BestPractice**]

Missing error handling for JSON serialization. The `json.dumps()` call raises `TypeError` (or `ValueError`, e.g. on circular references) if `schema.serialize_to_json()` returns data that is not JSON-serializable. Without handling, callers see a raw serialization exception instead of a clear error that identifies the schema as invalid.

```python
try:
    schema_str = json.dumps(schema.serialize_to_json())
except (TypeError, ValueError) as e:
    raise ValueError(f"Failed to serialize schema: {e}")
```

File: chromadb/api/rust.py
Line: 233

else:
schema_str = None

collection = self.bindings.create_collection(
name, configuration_json_str, metadata, get_or_create, tenant, database
name,
configuration_json_str,
schema_str,
metadata,
get_or_create,
tenant,
database,
)
collection_model = CollectionModel(
id=collection.id,
name=collection.name,
configuration_json=collection.configuration,
serialized_schema=None,
serialized_schema=collection.schema,
metadata=collection.metadata,
dimension=collection.dimension,
tenant=collection.tenant,
Expand All @@ -256,7 +267,7 @@ def get_collection(
id=collection.id,
name=collection.name,
configuration_json=collection.configuration,
serialized_schema=None,
serialized_schema=collection.schema,
metadata=collection.metadata,
dimension=collection.dimension,
tenant=collection.tenant,
Expand Down
1 change: 1 addition & 0 deletions chromadb/chromadb_rust_bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class Bindings:
self,
name: str,
configuration_json_str: Optional[str] = None,
schema_str: Optional[str] = None,
metadata: Optional[CollectionMetadata] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
Expand Down
136 changes: 108 additions & 28 deletions chromadb/test/api/test_schema_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
is_spann_disabled_mode,
skip_if_not_cluster,
skip_reason_spann_disabled,
skip_reason_spann_enabled,
)
from chromadb.test.utils.wait_for_version_increase import (
get_collection_version,
Expand All @@ -29,7 +30,7 @@
register_sparse_embedding_function,
)
from chromadb.api.models.Collection import Collection
from chromadb.errors import InvalidArgumentError
from chromadb.errors import InvalidArgumentError, InternalError
from chromadb.execution.expression import Knn, Search
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast
from uuid import uuid4
Expand Down Expand Up @@ -94,8 +95,7 @@ def build_from_config(config: Dict[str, Any]) -> "RecordingSearchEmbeddingFuncti
return RecordingSearchEmbeddingFunction(config.get("label", "default"))


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_schema_spann_vector_config_persistence(
def test_schema_vector_config_persistence(
client_factories: "ClientFactories",
) -> None:
"""Ensure schema-provided SPANN settings persist across client restarts."""
Expand Down Expand Up @@ -136,29 +136,46 @@ def test_schema_spann_vector_config_persistence(
assert vector_index is not None
assert vector_index.enabled is True
assert vector_index.config is not None
assert vector_index.config.spann is not None
spann_config = vector_index.config.spann
assert spann_config.search_nprobe == 16
assert spann_config.write_nprobe == 32
assert spann_config.ef_construction == 120
assert spann_config.max_neighbors == 24

if not is_spann_disabled_mode:
assert vector_index.config.spann is not None
spann_config = vector_index.config.spann
assert spann_config.search_nprobe == 16
assert spann_config.write_nprobe == 32
assert spann_config.ef_construction == 120
assert spann_config.max_neighbors == 24
else:
assert vector_index.config.spann is None
assert vector_index.config.hnsw is not None
hnsw_config = vector_index.config.hnsw
assert hnsw_config.ef_construction == 100
assert hnsw_config.ef_search == 100
assert hnsw_config.max_neighbors == 16
assert hnsw_config.resize_factor == 1.2

ef = vector_index.config.embedding_function
assert ef is not None
assert ef.name() == "simple_ef"
assert ef.get_config() == {"dim": 6}

persisted_json = persisted_schema.serialize_to_json()
spann_json = persisted_json["keys"]["#embedding"]["float_list"]["vector_index"][
"config"
]["spann"]
assert spann_json["search_nprobe"] == 16
assert spann_json["write_nprobe"] == 32
if not is_spann_disabled_mode:
spann_json = persisted_json["keys"]["#embedding"]["float_list"]["vector_index"][
"config"
]["spann"]
assert spann_json["search_nprobe"] == 16
assert spann_json["write_nprobe"] == 32
else:
hnsw_json = persisted_json["keys"]["#embedding"]["float_list"]["vector_index"][
"config"
]["hnsw"]
assert hnsw_json["ef_construction"] == 100
assert hnsw_json["ef_search"] == 100
assert hnsw_json["max_neighbors"] == 16

client_reloaded = client_factories.create_client_from_system()
reloaded_collection = client_reloaded.get_collection(
name=collection_name,
embedding_function=SimpleEmbeddingFunction(dim=6), # type: ignore[arg-type]
)

reloaded_schema = reloaded_collection.schema
Expand All @@ -168,9 +185,23 @@ def test_schema_spann_vector_config_persistence(
reloaded_vector_index = reloaded_embedding_override.vector_index
assert reloaded_vector_index is not None
assert reloaded_vector_index.config is not None
assert reloaded_vector_index.config.spann is not None
assert reloaded_vector_index.config.spann.search_nprobe == 16
assert reloaded_vector_index.config.spann.write_nprobe == 32
if not is_spann_disabled_mode:
assert reloaded_vector_index.config.spann is not None
assert reloaded_vector_index.config.spann.search_nprobe == 16
assert reloaded_vector_index.config.spann.write_nprobe == 32
else:
assert reloaded_vector_index.config.hnsw is not None
assert reloaded_vector_index.config.hnsw.ef_construction == 100
assert reloaded_vector_index.config.hnsw.ef_search == 100
assert reloaded_vector_index.config.hnsw.max_neighbors == 16
assert reloaded_vector_index.config.hnsw.resize_factor == 1.2

config = reloaded_collection.configuration
assert config is not None
config_ef = config.get("embedding_function")
assert config_ef is not None
assert config_ef.name() == "simple_ef"
assert config_ef.get_config() == {"dim": 6}


@register_sparse_embedding_function
Expand Down Expand Up @@ -258,7 +289,6 @@ def _collect_knn_queries(rank: Any) -> List[Any]:
return queries


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_schema_defaults_enable_indexed_operations(
client_factories: "ClientFactories",
) -> None:
Expand Down Expand Up @@ -336,7 +366,6 @@ def test_schema_defaults_enable_indexed_operations(
assert reloaded.schema.serialize_to_json() == schema.serialize_to_json()


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_get_or_create_and_get_collection_preserve_schema(
client_factories: "ClientFactories",
) -> None:
Expand Down Expand Up @@ -379,7 +408,6 @@ def test_get_or_create_and_get_collection_preserve_schema(
assert set(stored["ids"]) == {"schema-preserve"}


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_delete_collection_resets_schema_configuration(
client_factories: "ClientFactories",
) -> None:
Expand Down Expand Up @@ -408,6 +436,19 @@ def test_delete_collection_resets_schema_configuration(
assert set(recreated_json["keys"].keys()) == set(baseline_json["keys"].keys())


@pytest.mark.skipif(not is_spann_disabled_mode, reason=skip_reason_spann_enabled)
def test_sparse_vector_not_allowed_locally(
    client_factories: "ClientFactories",
) -> None:
    """Creating a collection whose schema has a sparse vector index must fail.

    Only runs when SPANN is disabled (see the skipif condition); in that mode
    the collection creation is expected to raise an InternalError.
    """
    sparse_schema = Schema()
    sparse_schema.create_index(
        key="sparse_metadata",
        config=SparseVectorIndexConfig(),
    )

    expected_message = "Sparse vector indexing is not enabled in local"
    with pytest.raises(InternalError, match=expected_message):
        _create_isolated_collection(client_factories, schema=sparse_schema)


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_sparse_vector_source_key_and_index_constraints(
client_factories: "ClientFactories",
Expand Down Expand Up @@ -466,7 +507,6 @@ def test_sparse_vector_source_key_and_index_constraints(
assert set(string_filter["ids"]) == {"sparse-1"}


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_schema_persistence_with_custom_overrides(
client_factories: "ClientFactories",
) -> None:
Expand Down Expand Up @@ -507,7 +547,6 @@ def test_schema_persistence_with_custom_overrides(
assert set(fetched["ids"]) == {"persist-1"}


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_collection_embed_uses_schema_or_collection_embedding_function(
client_factories: "ClientFactories",
) -> None:
Expand Down Expand Up @@ -594,7 +633,6 @@ def test_search_embeds_string_queries_in_nested_ranks(
assert all(isinstance(query, list) for query in all_queries)


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_schema_delete_index_and_restore(
client_factories: "ClientFactories",
) -> None:
Expand Down Expand Up @@ -651,6 +689,52 @@ def test_schema_delete_index_and_restore(
assert set(search["ids"]) == {"key-enabled"}


def test_disabled_metadata_index_filters_raise_invalid_argument_all_modes(
    client_factories: "ClientFactories",
) -> None:
    """Filtering on a key whose string inverted index is disabled must fail.

    Exercises get, query, and delete with a ``where`` filter on the disabled
    key and expects each to raise InvalidArgumentError. Carries no skip
    marker, so it is intended to run for local, single node, and distributed.
    """
    disabled_schema = Schema().delete_index(
        key="restricted_tag", config=StringInvertedIndexConfig()
    )
    collection, _ = _create_isolated_collection(
        client_factories, schema=disabled_schema
    )

    collection.add(
        ids=["restricted-doc"],
        embeddings=cast(Embeddings, [[0.1, 0.2, 0.3, 0.4]]),
        metadatas=[{"restricted_tag": "blocked"}],
        documents=["doc"],
    )

    # The collection's schema must report the inverted index as disabled.
    assert collection.schema is not None
    string_types = collection.schema.keys["restricted_tag"].string
    assert string_types is not None
    inverted_index = string_types.string_inverted_index
    assert inverted_index is not None
    assert inverted_index.enabled is False

    where_clause: Dict[str, Any] = {"restricted_tag": "blocked"}

    def _filtered_get() -> None:
        collection.get(where=where_clause)

    def _filtered_query() -> None:
        collection.query(
            query_embeddings=cast(Embeddings, [[0.1, 0.2, 0.3, 0.4]]),
            n_results=1,
            where=where_clause,
        )

    def _filtered_delete() -> None:
        collection.delete(where=where_clause)

    # Same order as before: get, then query, then delete.
    for attempt in (_filtered_get, _filtered_query, _filtered_delete):
        with pytest.raises(InvalidArgumentError) as exc_info:
            attempt()
        assert "Cannot filter using metadata key 'restricted_tag'" in str(
            exc_info.value
        )


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_disabled_metadata_index_filters_raise_invalid_argument(
client_factories: "ClientFactories",
Expand Down Expand Up @@ -1548,7 +1632,6 @@ def test_sparse_auto_embedding_with_empty_documents(
assert "empty_sparse" in metadata


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_default_embedding_function_when_no_schema_provided(
client_factories: "ClientFactories",
) -> None:
Expand Down Expand Up @@ -1604,7 +1687,6 @@ def test_default_embedding_function_when_no_schema_provided(
assert sparse_ef_config["type"] == "unknown"


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_custom_embedding_function_without_schema(
client_factories: "ClientFactories",
) -> None:
Expand Down Expand Up @@ -1679,7 +1761,6 @@ def test_custom_embedding_function_without_schema(
assert len(result["embeddings"][0]) == 8


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_custom_embedding_function_with_default_schema(
client_factories: "ClientFactories",
) -> None:
Expand Down Expand Up @@ -1755,7 +1836,6 @@ def test_custom_embedding_function_with_default_schema(
assert len(result["embeddings"][0]) == 8


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_conflicting_embedding_functions_in_schema_and_config_fails(
client_factories: "ClientFactories",
) -> None:
Expand Down
4 changes: 3 additions & 1 deletion chromadb/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@
skip_reason_spann_disabled = (
"SPANN creation/modification disallowed in Rust bindings or integration test mode"
)

skip_reason_spann_enabled = (
"SPANN creation/modification allowed in Rust bindings or integration test mode"
)


def reset(api: BaseAPI) -> None:
Expand Down
Loading