Skip to content

Commit 97fc56f

Browse files
jairad26sanketkedia
authored andcommitted
[ENH] Add local support for schema (#5714)
## Description of changes _Summarize the changes made by this PR._ - Improvements & Bug fixes - This PR adds support for schema in sqlite sysdb, correctly reconciling with schema, legacy metadata, and supporting configuration updates. It also adds support for passing schema via bindings, to allow for local chroma support. It also updates cli usage of to allow copying of schema - New functionality - ... ## Test plan _How are these changes tested?_ expanded schema e2e tests to ensure bindings and single node all work as intended - [ x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent c424ee6 commit 97fc56f

File tree

27 files changed

+730
-192
lines changed

27 files changed

+730
-192
lines changed

chromadb/api/collection_configuration.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -812,21 +812,26 @@ def update_schema_from_collection_configuration(
812812
Returns:
813813
Updated Schema object
814814
"""
815-
# TODO: Remove this check once schema is enabled in local.
816-
if schema.defaults.float_list is None:
817-
return schema
818815

819816
# Get the vector index from defaults and #embedding key
820-
if schema.defaults.float_list is None or schema.defaults.float_list.vector_index is None:
817+
if (
818+
schema.defaults.float_list is None
819+
or schema.defaults.float_list.vector_index is None
820+
):
821821
raise ValueError("Schema is missing defaults.float_list.vector_index")
822822

823823
embedding_key = "#embedding"
824824
if embedding_key not in schema.keys:
825825
raise ValueError(f"Schema is missing keys[{embedding_key}]")
826826

827827
embedding_value_types = schema.keys[embedding_key]
828-
if embedding_value_types.float_list is None or embedding_value_types.float_list.vector_index is None:
829-
raise ValueError(f"Schema is missing keys[{embedding_key}].float_list.vector_index")
828+
if (
829+
embedding_value_types.float_list is None
830+
or embedding_value_types.float_list.vector_index is None
831+
):
832+
raise ValueError(
833+
f"Schema is missing keys[{embedding_key}].float_list.vector_index"
834+
)
830835

831836
# Update vector index config in both locations
832837
for vector_index in [
@@ -868,7 +873,10 @@ def update_schema_from_collection_configuration(
868873
spann_config.ef_search = update_spann["ef_search"]
869874

870875
# Update embedding function if present
871-
if "embedding_function" in configuration and configuration["embedding_function"] is not None:
876+
if (
877+
"embedding_function" in configuration
878+
and configuration["embedding_function"] is not None
879+
):
872880
vector_index.config.embedding_function = configuration["embedding_function"]
873881

874882
return schema

chromadb/api/rust.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def list_collections(
193193
CollectionModel(
194194
id=collection.id,
195195
name=collection.name,
196-
serialized_schema=None,
196+
serialized_schema=collection.schema,
197197
configuration_json=collection.configuration,
198198
metadata=collection.metadata,
199199
dimension=collection.dimension,
@@ -229,14 +229,25 @@ def create_collection(
229229
else:
230230
configuration_json_str = None
231231

232+
if schema:
233+
schema_str = json.dumps(schema.serialize_to_json())
234+
else:
235+
schema_str = None
236+
232237
collection = self.bindings.create_collection(
233-
name, configuration_json_str, metadata, get_or_create, tenant, database
238+
name,
239+
configuration_json_str,
240+
schema_str,
241+
metadata,
242+
get_or_create,
243+
tenant,
244+
database,
234245
)
235246
collection_model = CollectionModel(
236247
id=collection.id,
237248
name=collection.name,
238249
configuration_json=collection.configuration,
239-
serialized_schema=None,
250+
serialized_schema=collection.schema,
240251
metadata=collection.metadata,
241252
dimension=collection.dimension,
242253
tenant=collection.tenant,
@@ -256,7 +267,7 @@ def get_collection(
256267
id=collection.id,
257268
name=collection.name,
258269
configuration_json=collection.configuration,
259-
serialized_schema=None,
270+
serialized_schema=collection.schema,
260271
metadata=collection.metadata,
261272
dimension=collection.dimension,
262273
tenant=collection.tenant,

chromadb/chromadb_rust_bindings.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class Bindings:
9595
self,
9696
name: str,
9797
configuration_json_str: Optional[str] = None,
98+
schema_str: Optional[str] = None,
9899
metadata: Optional[CollectionMetadata] = None,
99100
get_or_create: bool = False,
100101
tenant: str = DEFAULT_TENANT,

chromadb/test/api/test_schema_e2e.py

Lines changed: 108 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
is_spann_disabled_mode,
2020
skip_if_not_cluster,
2121
skip_reason_spann_disabled,
22+
skip_reason_spann_enabled,
2223
)
2324
from chromadb.test.utils.wait_for_version_increase import (
2425
get_collection_version,
@@ -29,7 +30,7 @@
2930
register_sparse_embedding_function,
3031
)
3132
from chromadb.api.models.Collection import Collection
32-
from chromadb.errors import InvalidArgumentError
33+
from chromadb.errors import InvalidArgumentError, InternalError
3334
from chromadb.execution.expression import Knn, Search
3435
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast
3536
from uuid import uuid4
@@ -94,8 +95,7 @@ def build_from_config(config: Dict[str, Any]) -> "RecordingSearchEmbeddingFuncti
9495
return RecordingSearchEmbeddingFunction(config.get("label", "default"))
9596

9697

97-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
98-
def test_schema_spann_vector_config_persistence(
98+
def test_schema_vector_config_persistence(
9999
client_factories: "ClientFactories",
100100
) -> None:
101101
"""Ensure schema-provided SPANN settings persist across client restarts."""
@@ -136,29 +136,46 @@ def test_schema_spann_vector_config_persistence(
136136
assert vector_index is not None
137137
assert vector_index.enabled is True
138138
assert vector_index.config is not None
139-
assert vector_index.config.spann is not None
140-
spann_config = vector_index.config.spann
141-
assert spann_config.search_nprobe == 16
142-
assert spann_config.write_nprobe == 32
143-
assert spann_config.ef_construction == 120
144-
assert spann_config.max_neighbors == 24
139+
140+
if not is_spann_disabled_mode:
141+
assert vector_index.config.spann is not None
142+
spann_config = vector_index.config.spann
143+
assert spann_config.search_nprobe == 16
144+
assert spann_config.write_nprobe == 32
145+
assert spann_config.ef_construction == 120
146+
assert spann_config.max_neighbors == 24
147+
else:
148+
assert vector_index.config.spann is None
149+
assert vector_index.config.hnsw is not None
150+
hnsw_config = vector_index.config.hnsw
151+
assert hnsw_config.ef_construction == 100
152+
assert hnsw_config.ef_search == 100
153+
assert hnsw_config.max_neighbors == 16
154+
assert hnsw_config.resize_factor == 1.2
145155

146156
ef = vector_index.config.embedding_function
147157
assert ef is not None
148158
assert ef.name() == "simple_ef"
149159
assert ef.get_config() == {"dim": 6}
150160

151161
persisted_json = persisted_schema.serialize_to_json()
152-
spann_json = persisted_json["keys"]["#embedding"]["float_list"]["vector_index"][
153-
"config"
154-
]["spann"]
155-
assert spann_json["search_nprobe"] == 16
156-
assert spann_json["write_nprobe"] == 32
162+
if not is_spann_disabled_mode:
163+
spann_json = persisted_json["keys"]["#embedding"]["float_list"]["vector_index"][
164+
"config"
165+
]["spann"]
166+
assert spann_json["search_nprobe"] == 16
167+
assert spann_json["write_nprobe"] == 32
168+
else:
169+
hnsw_json = persisted_json["keys"]["#embedding"]["float_list"]["vector_index"][
170+
"config"
171+
]["hnsw"]
172+
assert hnsw_json["ef_construction"] == 100
173+
assert hnsw_json["ef_search"] == 100
174+
assert hnsw_json["max_neighbors"] == 16
157175

158176
client_reloaded = client_factories.create_client_from_system()
159177
reloaded_collection = client_reloaded.get_collection(
160178
name=collection_name,
161-
embedding_function=SimpleEmbeddingFunction(dim=6), # type: ignore[arg-type]
162179
)
163180

164181
reloaded_schema = reloaded_collection.schema
@@ -168,9 +185,23 @@ def test_schema_spann_vector_config_persistence(
168185
reloaded_vector_index = reloaded_embedding_override.vector_index
169186
assert reloaded_vector_index is not None
170187
assert reloaded_vector_index.config is not None
171-
assert reloaded_vector_index.config.spann is not None
172-
assert reloaded_vector_index.config.spann.search_nprobe == 16
173-
assert reloaded_vector_index.config.spann.write_nprobe == 32
188+
if not is_spann_disabled_mode:
189+
assert reloaded_vector_index.config.spann is not None
190+
assert reloaded_vector_index.config.spann.search_nprobe == 16
191+
assert reloaded_vector_index.config.spann.write_nprobe == 32
192+
else:
193+
assert reloaded_vector_index.config.hnsw is not None
194+
assert reloaded_vector_index.config.hnsw.ef_construction == 100
195+
assert reloaded_vector_index.config.hnsw.ef_search == 100
196+
assert reloaded_vector_index.config.hnsw.max_neighbors == 16
197+
assert reloaded_vector_index.config.hnsw.resize_factor == 1.2
198+
199+
config = reloaded_collection.configuration
200+
assert config is not None
201+
config_ef = config.get("embedding_function")
202+
assert config_ef is not None
203+
assert config_ef.name() == "simple_ef"
204+
assert config_ef.get_config() == {"dim": 6}
174205

175206

176207
@register_sparse_embedding_function
@@ -258,7 +289,6 @@ def _collect_knn_queries(rank: Any) -> List[Any]:
258289
return queries
259290

260291

261-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
262292
def test_schema_defaults_enable_indexed_operations(
263293
client_factories: "ClientFactories",
264294
) -> None:
@@ -336,7 +366,6 @@ def test_schema_defaults_enable_indexed_operations(
336366
assert reloaded.schema.serialize_to_json() == schema.serialize_to_json()
337367

338368

339-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
340369
def test_get_or_create_and_get_collection_preserve_schema(
341370
client_factories: "ClientFactories",
342371
) -> None:
@@ -379,7 +408,6 @@ def test_get_or_create_and_get_collection_preserve_schema(
379408
assert set(stored["ids"]) == {"schema-preserve"}
380409

381410

382-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
383411
def test_delete_collection_resets_schema_configuration(
384412
client_factories: "ClientFactories",
385413
) -> None:
@@ -408,6 +436,19 @@ def test_delete_collection_resets_schema_configuration(
408436
assert set(recreated_json["keys"].keys()) == set(baseline_json["keys"].keys())
409437

410438

439+
@pytest.mark.skipif(not is_spann_disabled_mode, reason=skip_reason_spann_enabled)
440+
def test_sparse_vector_not_allowed_locally(
441+
client_factories: "ClientFactories",
442+
) -> None:
443+
"""Sparse vector configs are not allowed to be created locally."""
444+
schema = Schema()
445+
schema.create_index(key="sparse_metadata", config=SparseVectorIndexConfig())
446+
with pytest.raises(
447+
InternalError, match="Sparse vector indexing is not enabled in local"
448+
):
449+
_create_isolated_collection(client_factories, schema=schema)
450+
451+
411452
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
412453
def test_sparse_vector_source_key_and_index_constraints(
413454
client_factories: "ClientFactories",
@@ -466,7 +507,6 @@ def test_sparse_vector_source_key_and_index_constraints(
466507
assert set(string_filter["ids"]) == {"sparse-1"}
467508

468509

469-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
470510
def test_schema_persistence_with_custom_overrides(
471511
client_factories: "ClientFactories",
472512
) -> None:
@@ -507,7 +547,6 @@ def test_schema_persistence_with_custom_overrides(
507547
assert set(fetched["ids"]) == {"persist-1"}
508548

509549

510-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
511550
def test_collection_embed_uses_schema_or_collection_embedding_function(
512551
client_factories: "ClientFactories",
513552
) -> None:
@@ -594,7 +633,6 @@ def test_search_embeds_string_queries_in_nested_ranks(
594633
assert all(isinstance(query, list) for query in all_queries)
595634

596635

597-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
598636
def test_schema_delete_index_and_restore(
599637
client_factories: "ClientFactories",
600638
) -> None:
@@ -651,6 +689,52 @@ def test_schema_delete_index_and_restore(
651689
assert set(search["ids"]) == {"key-enabled"}
652690

653691

692+
def test_disabled_metadata_index_filters_raise_invalid_argument_all_modes(
693+
client_factories: "ClientFactories",
694+
) -> None:
695+
"""Disabled metadata inverted index should block filter-based operations in get, query, and delete for local, single node, and distributed."""
696+
schema = Schema().delete_index(
697+
key="restricted_tag", config=StringInvertedIndexConfig()
698+
)
699+
collection, _ = _create_isolated_collection(client_factories, schema=schema)
700+
701+
collection.add(
702+
ids=["restricted-doc"],
703+
embeddings=cast(Embeddings, [[0.1, 0.2, 0.3, 0.4]]),
704+
metadatas=[{"restricted_tag": "blocked"}],
705+
documents=["doc"],
706+
)
707+
708+
assert collection.schema is not None
709+
schema_entry = collection.schema.keys["restricted_tag"].string
710+
assert schema_entry is not None
711+
index_config = schema_entry.string_inverted_index
712+
assert index_config is not None
713+
assert index_config.enabled is False
714+
715+
filter_payload: Dict[str, Any] = {"restricted_tag": "blocked"}
716+
717+
def _expect_disabled_error(operation: Callable[[], Any]) -> None:
718+
with pytest.raises(InvalidArgumentError) as exc_info:
719+
operation()
720+
assert "Cannot filter using metadata key 'restricted_tag'" in str(
721+
exc_info.value
722+
)
723+
724+
operations: List[Callable[[], Any]] = [
725+
lambda: collection.get(where=filter_payload),
726+
lambda: collection.query(
727+
query_embeddings=cast(Embeddings, [[0.1, 0.2, 0.3, 0.4]]),
728+
n_results=1,
729+
where=filter_payload,
730+
),
731+
lambda: collection.delete(where=filter_payload),
732+
]
733+
734+
for operation in operations:
735+
_expect_disabled_error(operation)
736+
737+
654738
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
655739
def test_disabled_metadata_index_filters_raise_invalid_argument(
656740
client_factories: "ClientFactories",
@@ -1548,7 +1632,6 @@ def test_sparse_auto_embedding_with_empty_documents(
15481632
assert "empty_sparse" in metadata
15491633

15501634

1551-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
15521635
def test_default_embedding_function_when_no_schema_provided(
15531636
client_factories: "ClientFactories",
15541637
) -> None:
@@ -1604,7 +1687,6 @@ def test_default_embedding_function_when_no_schema_provided(
16041687
assert sparse_ef_config["type"] == "unknown"
16051688

16061689

1607-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
16081690
def test_custom_embedding_function_without_schema(
16091691
client_factories: "ClientFactories",
16101692
) -> None:
@@ -1679,7 +1761,6 @@ def test_custom_embedding_function_without_schema(
16791761
assert len(result["embeddings"][0]) == 8
16801762

16811763

1682-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
16831764
def test_custom_embedding_function_with_default_schema(
16841765
client_factories: "ClientFactories",
16851766
) -> None:
@@ -1755,7 +1836,6 @@ def test_custom_embedding_function_with_default_schema(
17551836
assert len(result["embeddings"][0]) == 8
17561837

17571838

1758-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
17591839
def test_conflicting_embedding_functions_in_schema_and_config_fails(
17601840
client_factories: "ClientFactories",
17611841
) -> None:

chromadb/test/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@
9191
skip_reason_spann_disabled = (
9292
"SPANN creation/modification disallowed in Rust bindings or integration test mode"
9393
)
94-
94+
skip_reason_spann_enabled = (
95+
"SPANN creation/modification allowed in Rust bindings or integration test mode"
96+
)
9597

9698

9799
def reset(api: BaseAPI) -> None:

0 commit comments

Comments
 (0)