Skip to content

Commit bee55b4

Browse files
committed
[ENH] Add local support for schema
1 parent 08fa6e7 commit bee55b4

File tree

26 files changed

+701
-176
lines changed

26 files changed

+701
-176
lines changed

chromadb/api/rust.py

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -193,7 +193,7 @@ def list_collections(
193193
CollectionModel(
194194
id=collection.id,
195195
name=collection.name,
196-
serialized_schema=None,
196+
serialized_schema=collection.schema,
197197
configuration_json=collection.configuration,
198198
metadata=collection.metadata,
199199
dimension=collection.dimension,
@@ -229,14 +229,25 @@ def create_collection(
229229
else:
230230
configuration_json_str = None
231231

232+
if schema:
233+
schema_str = json.dumps(schema.serialize_to_json())
234+
else:
235+
schema_str = None
236+
232237
collection = self.bindings.create_collection(
233-
name, configuration_json_str, metadata, get_or_create, tenant, database
238+
name,
239+
configuration_json_str,
240+
schema_str,
241+
metadata,
242+
get_or_create,
243+
tenant,
244+
database,
234245
)
235246
collection_model = CollectionModel(
236247
id=collection.id,
237248
name=collection.name,
238249
configuration_json=collection.configuration,
239-
serialized_schema=None,
250+
serialized_schema=collection.schema,
240251
metadata=collection.metadata,
241252
dimension=collection.dimension,
242253
tenant=collection.tenant,
@@ -256,7 +267,7 @@ def get_collection(
256267
id=collection.id,
257268
name=collection.name,
258269
configuration_json=collection.configuration,
259-
serialized_schema=None,
270+
serialized_schema=collection.schema,
260271
metadata=collection.metadata,
261272
dimension=collection.dimension,
262273
tenant=collection.tenant,

chromadb/chromadb_rust_bindings.pyi

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -95,6 +95,7 @@ class Bindings:
9595
self,
9696
name: str,
9797
configuration_json_str: Optional[str] = None,
98+
schema_str: Optional[str] = None,
9899
metadata: Optional[CollectionMetadata] = None,
99100
get_or_create: bool = False,
100101
tenant: str = DEFAULT_TENANT,

chromadb/test/api/test_schema_e2e.py

Lines changed: 108 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
is_spann_disabled_mode,
2020
skip_if_not_cluster,
2121
skip_reason_spann_disabled,
22+
skip_reason_spann_enabled,
2223
)
2324
from chromadb.test.utils.wait_for_version_increase import (
2425
get_collection_version,
@@ -29,7 +30,7 @@
2930
register_sparse_embedding_function,
3031
)
3132
from chromadb.api.models.Collection import Collection
32-
from chromadb.errors import InvalidArgumentError
33+
from chromadb.errors import InvalidArgumentError, InternalError
3334
from chromadb.execution.expression import Knn, Search
3435
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast
3536
from uuid import uuid4
@@ -94,8 +95,7 @@ def build_from_config(config: Dict[str, Any]) -> "RecordingSearchEmbeddingFuncti
9495
return RecordingSearchEmbeddingFunction(config.get("label", "default"))
9596

9697

97-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
98-
def test_schema_spann_vector_config_persistence(
98+
def test_schema_vector_config_persistence(
9999
client_factories: "ClientFactories",
100100
) -> None:
101101
"""Ensure schema-provided SPANN settings persist across client restarts."""
@@ -136,29 +136,46 @@ def test_schema_spann_vector_config_persistence(
136136
assert vector_index is not None
137137
assert vector_index.enabled is True
138138
assert vector_index.config is not None
139-
assert vector_index.config.spann is not None
140-
spann_config = vector_index.config.spann
141-
assert spann_config.search_nprobe == 16
142-
assert spann_config.write_nprobe == 32
143-
assert spann_config.ef_construction == 120
144-
assert spann_config.max_neighbors == 24
139+
140+
if not is_spann_disabled_mode:
141+
assert vector_index.config.spann is not None
142+
spann_config = vector_index.config.spann
143+
assert spann_config.search_nprobe == 16
144+
assert spann_config.write_nprobe == 32
145+
assert spann_config.ef_construction == 120
146+
assert spann_config.max_neighbors == 24
147+
else:
148+
assert vector_index.config.spann is None
149+
assert vector_index.config.hnsw is not None
150+
hnsw_config = vector_index.config.hnsw
151+
assert hnsw_config.ef_construction == 100
152+
assert hnsw_config.ef_search == 100
153+
assert hnsw_config.max_neighbors == 16
154+
assert hnsw_config.resize_factor == 1.2
145155

146156
ef = vector_index.config.embedding_function
147157
assert ef is not None
148158
assert ef.name() == "simple_ef"
149159
assert ef.get_config() == {"dim": 6}
150160

151161
persisted_json = persisted_schema.serialize_to_json()
152-
spann_json = persisted_json["keys"]["#embedding"]["float_list"]["vector_index"][
153-
"config"
154-
]["spann"]
155-
assert spann_json["search_nprobe"] == 16
156-
assert spann_json["write_nprobe"] == 32
162+
if not is_spann_disabled_mode:
163+
spann_json = persisted_json["keys"]["#embedding"]["float_list"]["vector_index"][
164+
"config"
165+
]["spann"]
166+
assert spann_json["search_nprobe"] == 16
167+
assert spann_json["write_nprobe"] == 32
168+
else:
169+
hnsw_json = persisted_json["keys"]["#embedding"]["float_list"]["vector_index"][
170+
"config"
171+
]["hnsw"]
172+
assert hnsw_json["ef_construction"] == 100
173+
assert hnsw_json["ef_search"] == 100
174+
assert hnsw_json["max_neighbors"] == 16
157175

158176
client_reloaded = client_factories.create_client_from_system()
159177
reloaded_collection = client_reloaded.get_collection(
160178
name=collection_name,
161-
embedding_function=SimpleEmbeddingFunction(dim=6), # type: ignore[arg-type]
162179
)
163180

164181
reloaded_schema = reloaded_collection.schema
@@ -168,9 +185,23 @@ def test_schema_spann_vector_config_persistence(
168185
reloaded_vector_index = reloaded_embedding_override.vector_index
169186
assert reloaded_vector_index is not None
170187
assert reloaded_vector_index.config is not None
171-
assert reloaded_vector_index.config.spann is not None
172-
assert reloaded_vector_index.config.spann.search_nprobe == 16
173-
assert reloaded_vector_index.config.spann.write_nprobe == 32
188+
if not is_spann_disabled_mode:
189+
assert reloaded_vector_index.config.spann is not None
190+
assert reloaded_vector_index.config.spann.search_nprobe == 16
191+
assert reloaded_vector_index.config.spann.write_nprobe == 32
192+
else:
193+
assert reloaded_vector_index.config.hnsw is not None
194+
assert reloaded_vector_index.config.hnsw.ef_construction == 100
195+
assert reloaded_vector_index.config.hnsw.ef_search == 100
196+
assert reloaded_vector_index.config.hnsw.max_neighbors == 16
197+
assert reloaded_vector_index.config.hnsw.resize_factor == 1.2
198+
199+
config = reloaded_collection.configuration
200+
assert config is not None
201+
config_ef = config.get("embedding_function")
202+
assert config_ef is not None
203+
assert config_ef.name() == "simple_ef"
204+
assert config_ef.get_config() == {"dim": 6}
174205

175206

176207
@register_sparse_embedding_function
@@ -258,7 +289,6 @@ def _collect_knn_queries(rank: Any) -> List[Any]:
258289
return queries
259290

260291

261-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
262292
def test_schema_defaults_enable_indexed_operations(
263293
client_factories: "ClientFactories",
264294
) -> None:
@@ -336,7 +366,6 @@ def test_schema_defaults_enable_indexed_operations(
336366
assert reloaded.schema.serialize_to_json() == schema.serialize_to_json()
337367

338368

339-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
340369
def test_get_or_create_and_get_collection_preserve_schema(
341370
client_factories: "ClientFactories",
342371
) -> None:
@@ -379,7 +408,6 @@ def test_get_or_create_and_get_collection_preserve_schema(
379408
assert set(stored["ids"]) == {"schema-preserve"}
380409

381410

382-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
383411
def test_delete_collection_resets_schema_configuration(
384412
client_factories: "ClientFactories",
385413
) -> None:
@@ -408,6 +436,19 @@ def test_delete_collection_resets_schema_configuration(
408436
assert set(recreated_json["keys"].keys()) == set(baseline_json["keys"].keys())
409437

410438

439+
@pytest.mark.skipif(not is_spann_disabled_mode, reason=skip_reason_spann_enabled)
def test_sparse_vector_not_allowed_locally(
    client_factories: "ClientFactories",
) -> None:
    """Sparse vector configs are not allowed to be created locally.

    Runs only when ``is_spann_disabled_mode`` is set (the test is skipped
    otherwise). Builds a schema that requests a sparse vector index on the
    ``sparse_metadata`` key and expects collection creation to fail with an
    ``InternalError`` whose message matches the pattern below.
    """
    schema = Schema()
    schema.create_index(key="sparse_metadata", config=SparseVectorIndexConfig())

    # The schema itself is built client-side without error; the rejection is
    # expected only once the collection is actually created.
    with pytest.raises(
        InternalError, match="Sparse vector indexing is not enabled in local"
    ):
        _create_isolated_collection(client_factories, schema=schema)
450+
451+
411452
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
412453
def test_sparse_vector_source_key_and_index_constraints(
413454
client_factories: "ClientFactories",
@@ -466,7 +507,6 @@ def test_sparse_vector_source_key_and_index_constraints(
466507
assert set(string_filter["ids"]) == {"sparse-1"}
467508

468509

469-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
470510
def test_schema_persistence_with_custom_overrides(
471511
client_factories: "ClientFactories",
472512
) -> None:
@@ -507,7 +547,6 @@ def test_schema_persistence_with_custom_overrides(
507547
assert set(fetched["ids"]) == {"persist-1"}
508548

509549

510-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
511550
def test_collection_embed_uses_schema_or_collection_embedding_function(
512551
client_factories: "ClientFactories",
513552
) -> None:
@@ -594,7 +633,6 @@ def test_search_embeds_string_queries_in_nested_ranks(
594633
assert all(isinstance(query, list) for query in all_queries)
595634

596635

597-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
598636
def test_schema_delete_index_and_restore(
599637
client_factories: "ClientFactories",
600638
) -> None:
@@ -651,6 +689,52 @@ def test_schema_delete_index_and_restore(
651689
assert set(search["ids"]) == {"key-enabled"}
652690

653691

692+
def test_disabled_metadata_index_filters_raise_invalid_argument_all_modes(
    client_factories: "ClientFactories",
) -> None:
    """Disabled metadata inverted index should block filter-based operations.

    Covers ``get``, ``query``, and ``delete`` for local, single node, and
    distributed. The string inverted index for ``restricted_tag`` is removed
    via the schema, one record carrying that key is added, and every
    ``where``-filtered operation on that key is then expected to raise
    ``InvalidArgumentError``.
    """
    schema = Schema().delete_index(
        key="restricted_tag", config=StringInvertedIndexConfig()
    )
    collection, _ = _create_isolated_collection(client_factories, schema=schema)

    collection.add(
        ids=["restricted-doc"],
        embeddings=cast(Embeddings, [[0.1, 0.2, 0.3, 0.4]]),
        metadatas=[{"restricted_tag": "blocked"}],
        documents=["doc"],
    )

    # Confirm the collection's schema reports the index as disabled before
    # exercising the filter paths.
    assert collection.schema is not None
    schema_entry = collection.schema.keys["restricted_tag"].string
    assert schema_entry is not None
    index_config = schema_entry.string_inverted_index
    assert index_config is not None
    assert index_config.enabled is False

    filter_payload: Dict[str, Any] = {"restricted_tag": "blocked"}

    def _expect_disabled_error(operation: Callable[[], Any]) -> None:
        """Run ``operation`` and require the disabled-index error message."""
        with pytest.raises(InvalidArgumentError) as exc_info:
            operation()
        assert "Cannot filter using metadata key 'restricted_tag'" in str(
            exc_info.value
        )

    # Each callable applies the same metadata filter through a different API.
    operations: List[Callable[[], Any]] = [
        lambda: collection.get(where=filter_payload),
        lambda: collection.query(
            query_embeddings=cast(Embeddings, [[0.1, 0.2, 0.3, 0.4]]),
            n_results=1,
            where=filter_payload,
        ),
        lambda: collection.delete(where=filter_payload),
    ]

    for operation in operations:
        _expect_disabled_error(operation)
736+
737+
654738
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
655739
def test_disabled_metadata_index_filters_raise_invalid_argument(
656740
client_factories: "ClientFactories",
@@ -1548,7 +1632,6 @@ def test_sparse_auto_embedding_with_empty_documents(
15481632
assert "empty_sparse" in metadata
15491633

15501634

1551-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
15521635
def test_default_embedding_function_when_no_schema_provided(
15531636
client_factories: "ClientFactories",
15541637
) -> None:
@@ -1604,7 +1687,6 @@ def test_default_embedding_function_when_no_schema_provided(
16041687
assert sparse_ef_config["type"] == "unknown"
16051688

16061689

1607-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
16081690
def test_custom_embedding_function_without_schema(
16091691
client_factories: "ClientFactories",
16101692
) -> None:
@@ -1679,7 +1761,6 @@ def test_custom_embedding_function_without_schema(
16791761
assert len(result["embeddings"][0]) == 8
16801762

16811763

1682-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
16831764
def test_custom_embedding_function_with_default_schema(
16841765
client_factories: "ClientFactories",
16851766
) -> None:
@@ -1755,7 +1836,6 @@ def test_custom_embedding_function_with_default_schema(
17551836
assert len(result["embeddings"][0]) == 8
17561837

17571838

1758-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
17591839
def test_conflicting_embedding_functions_in_schema_and_config_fails(
17601840
client_factories: "ClientFactories",
17611841
) -> None:

chromadb/test/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,9 @@
9191
skip_reason_spann_disabled = (
9292
"SPANN creation/modification disallowed in Rust bindings or integration test mode"
9393
)
94-
94+
# Skip reason for tests that only apply when SPANN is disabled; used with
# ``skipif(not is_spann_disabled_mode, ...)``.
skip_reason_spann_enabled = (
    "SPANN creation/modification allowed in Rust bindings or integration test mode"
)
9597

9698

9799
def reset(api: BaseAPI) -> None:

rust/cli/src/client/chroma_client.rs

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ use crate::client::prelude::CollectionModel;
44
use crate::client::utils::send_request;
55
use crate::utils::Profile;
66
use axum::http::Method;
7-
use chroma_types::{CollectionConfiguration, CreateCollectionPayload, Metadata};
7+
use chroma_types::{CollectionConfiguration, CreateCollectionPayload, Metadata, Schema};
88
use std::error::Error;
99
use std::ops::Deref;
1010
use thiserror::Error;
@@ -104,6 +104,7 @@ impl ChromaClient {
104104
name: String,
105105
metadata: Option<Metadata>,
106106
configuration: Option<CollectionConfiguration>,
107+
schema: Option<Schema>,
107108
) -> Result<Collection, Box<dyn Error>> {
108109
let route = format!(
109110
"/api/v2/tenants/{}/databases/{}/collections",
@@ -115,7 +116,7 @@ impl ChromaClient {
115116
configuration,
116117
metadata,
117118
get_or_create: false,
118-
schema: None,
119+
schema,
119120
};
120121
let response = send_request::<CreateCollectionPayload, CollectionModel>(
121122
&self.host,

rust/cli/src/commands/copy.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -239,6 +239,7 @@ async fn copy_collections(
239239
collection.name.clone(),
240240
collection.metadata.clone(),
241241
Some(CollectionConfiguration::from(collection.config.clone())),
242+
collection.schema.clone(),
242243
)
243244
.await
244245
.map_err(|_| ChromaClientError::CreateCollection(collection.name.clone()))?;

0 commit comments

Comments
 (0)