Skip to content

Commit 32d4b95

Browse files
committed
[ENH] Add local support for schema
1 parent c8846eb commit 32d4b95

File tree

26 files changed

+687
-163
lines changed

26 files changed

+687
-163
lines changed

chromadb/api/rust.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def list_collections(
193193
CollectionModel(
194194
id=collection.id,
195195
name=collection.name,
196-
serialized_schema=None,
196+
serialized_schema=collection.schema,
197197
configuration_json=collection.configuration,
198198
metadata=collection.metadata,
199199
dimension=collection.dimension,
@@ -229,14 +229,25 @@ def create_collection(
229229
else:
230230
configuration_json_str = None
231231

232+
if schema:
233+
schema_str = json.dumps(schema.serialize_to_json())
234+
else:
235+
schema_str = None
236+
232237
collection = self.bindings.create_collection(
233-
name, configuration_json_str, metadata, get_or_create, tenant, database
238+
name,
239+
configuration_json_str,
240+
schema_str,
241+
metadata,
242+
get_or_create,
243+
tenant,
244+
database,
234245
)
235246
collection_model = CollectionModel(
236247
id=collection.id,
237248
name=collection.name,
238249
configuration_json=collection.configuration,
239-
serialized_schema=None,
250+
serialized_schema=collection.schema,
240251
metadata=collection.metadata,
241252
dimension=collection.dimension,
242253
tenant=collection.tenant,
@@ -256,7 +267,7 @@ def get_collection(
256267
id=collection.id,
257268
name=collection.name,
258269
configuration_json=collection.configuration,
259-
serialized_schema=None,
270+
serialized_schema=collection.schema,
260271
metadata=collection.metadata,
261272
dimension=collection.dimension,
262273
tenant=collection.tenant,

chromadb/chromadb_rust_bindings.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class Bindings:
9595
self,
9696
name: str,
9797
configuration_json_str: Optional[str] = None,
98+
schema_str: Optional[str] = None,
9899
metadata: Optional[CollectionMetadata] = None,
99100
get_or_create: bool = False,
100101
tenant: str = DEFAULT_TENANT,

chromadb/test/api/test_schema_e2e.py

Lines changed: 92 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ def build_from_config(config: Dict[str, Any]) -> "RecordingSearchEmbeddingFuncti
9393
return RecordingSearchEmbeddingFunction(config.get("label", "default"))
9494

9595

96-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
9796
def test_schema_spann_vector_config_persistence(
9897
client_factories: "ClientFactories",
9998
) -> None:
@@ -135,29 +134,46 @@ def test_schema_spann_vector_config_persistence(
135134
assert vector_index is not None
136135
assert vector_index.enabled is True
137136
assert vector_index.config is not None
138-
assert vector_index.config.spann is not None
139-
spann_config = vector_index.config.spann
140-
assert spann_config.search_nprobe == 16
141-
assert spann_config.write_nprobe == 32
142-
assert spann_config.ef_construction == 120
143-
assert spann_config.max_neighbors == 24
137+
138+
if not is_spann_disabled_mode:
139+
assert vector_index.config.spann is not None
140+
spann_config = vector_index.config.spann
141+
assert spann_config.search_nprobe == 16
142+
assert spann_config.write_nprobe == 32
143+
assert spann_config.ef_construction == 120
144+
assert spann_config.max_neighbors == 24
145+
else:
146+
assert vector_index.config.spann is None
147+
assert vector_index.config.hnsw is not None
148+
hnsw_config = vector_index.config.hnsw
149+
assert hnsw_config.ef_construction == 100
150+
assert hnsw_config.ef_search == 100
151+
assert hnsw_config.max_neighbors == 16
152+
assert hnsw_config.resize_factor == 1.2
144153

145154
ef = vector_index.config.embedding_function
146155
assert ef is not None
147156
assert ef.name() == "simple_ef"
148157
assert ef.get_config() == {"dim": 6}
149158

150159
persisted_json = persisted_schema.serialize_to_json()
151-
spann_json = persisted_json["keys"]["#embedding"]["float_list"]["vector_index"][
152-
"config"
153-
]["spann"]
154-
assert spann_json["search_nprobe"] == 16
155-
assert spann_json["write_nprobe"] == 32
160+
if not is_spann_disabled_mode:
161+
spann_json = persisted_json["keys"]["#embedding"]["float_list"]["vector_index"][
162+
"config"
163+
]["spann"]
164+
assert spann_json["search_nprobe"] == 16
165+
assert spann_json["write_nprobe"] == 32
166+
else:
167+
hnsw_json = persisted_json["keys"]["#embedding"]["float_list"]["vector_index"][
168+
"config"
169+
]["hnsw"]
170+
assert hnsw_json["ef_construction"] == 100
171+
assert hnsw_json["ef_search"] == 100
172+
assert hnsw_json["max_neighbors"] == 16
156173

157174
client_reloaded = client_factories.create_client_from_system()
158175
reloaded_collection = client_reloaded.get_collection(
159176
name=collection_name,
160-
embedding_function=SimpleEmbeddingFunction(dim=6), # type: ignore[arg-type]
161177
)
162178

163179
reloaded_schema = reloaded_collection.schema
@@ -167,9 +183,23 @@ def test_schema_spann_vector_config_persistence(
167183
reloaded_vector_index = reloaded_embedding_override.vector_index
168184
assert reloaded_vector_index is not None
169185
assert reloaded_vector_index.config is not None
170-
assert reloaded_vector_index.config.spann is not None
171-
assert reloaded_vector_index.config.spann.search_nprobe == 16
172-
assert reloaded_vector_index.config.spann.write_nprobe == 32
186+
if not is_spann_disabled_mode:
187+
assert reloaded_vector_index.config.spann is not None
188+
assert reloaded_vector_index.config.spann.search_nprobe == 16
189+
assert reloaded_vector_index.config.spann.write_nprobe == 32
190+
else:
191+
assert reloaded_vector_index.config.hnsw is not None
192+
assert reloaded_vector_index.config.hnsw.ef_construction == 100
193+
assert reloaded_vector_index.config.hnsw.ef_search == 100
194+
assert reloaded_vector_index.config.hnsw.max_neighbors == 16
195+
assert reloaded_vector_index.config.hnsw.resize_factor == 1.2
196+
197+
config = reloaded_collection.configuration
198+
assert config is not None
199+
config_ef = config.get("embedding_function")
200+
assert config_ef is not None
201+
assert config_ef.name() == "simple_ef"
202+
assert config_ef.get_config() == {"dim": 6}
173203

174204

175205
@register_sparse_embedding_function
@@ -650,6 +680,52 @@ def test_schema_delete_index_and_restore(
650680
assert set(search["ids"]) == {"key-enabled"}
651681

652682

683+
def test_disabled_metadata_index_filters_raise_invalid_argument_all_modes(
684+
client_factories: "ClientFactories",
685+
) -> None:
686+
"""Disabled metadata inverted index should block filter-based operations in get, query, and delete for local, single node, and distributed."""
687+
schema = Schema().delete_index(
688+
key="restricted_tag", config=StringInvertedIndexConfig()
689+
)
690+
collection, _ = _create_isolated_collection(client_factories, schema=schema)
691+
692+
collection.add(
693+
ids=["restricted-doc"],
694+
embeddings=cast(Embeddings, [[0.1, 0.2, 0.3, 0.4]]),
695+
metadatas=[{"restricted_tag": "blocked"}],
696+
documents=["doc"],
697+
)
698+
699+
assert collection.schema is not None
700+
schema_entry = collection.schema.keys["restricted_tag"].string
701+
assert schema_entry is not None
702+
index_config = schema_entry.string_inverted_index
703+
assert index_config is not None
704+
assert index_config.enabled is False
705+
706+
filter_payload: Dict[str, Any] = {"restricted_tag": "blocked"}
707+
708+
def _expect_disabled_error(operation: Callable[[], Any]) -> None:
709+
with pytest.raises(InvalidArgumentError) as exc_info:
710+
operation()
711+
assert "Cannot filter using metadata key 'restricted_tag'" in str(
712+
exc_info.value
713+
)
714+
715+
operations: List[Callable[[], Any]] = [
716+
lambda: collection.get(where=filter_payload),
717+
lambda: collection.query(
718+
query_embeddings=cast(Embeddings, [[0.1, 0.2, 0.3, 0.4]]),
719+
n_results=1,
720+
where=filter_payload,
721+
),
722+
lambda: collection.delete(where=filter_payload),
723+
]
724+
725+
for operation in operations:
726+
_expect_disabled_error(operation)
727+
728+
653729
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
654730
def test_disabled_metadata_index_filters_raise_invalid_argument(
655731
client_factories: "ClientFactories",

rust/cli/src/client/chroma_client.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::client::prelude::CollectionModel;
44
use crate::client::utils::send_request;
55
use crate::utils::Profile;
66
use axum::http::Method;
7-
use chroma_types::{CollectionConfiguration, CreateCollectionPayload, Metadata};
7+
use chroma_types::{CollectionConfiguration, CreateCollectionPayload, Metadata, Schema};
88
use std::error::Error;
99
use std::ops::Deref;
1010
use thiserror::Error;
@@ -104,6 +104,7 @@ impl ChromaClient {
104104
name: String,
105105
metadata: Option<Metadata>,
106106
configuration: Option<CollectionConfiguration>,
107+
schema: Option<Schema>,
107108
) -> Result<Collection, Box<dyn Error>> {
108109
let route = format!(
109110
"/api/v2/tenants/{}/databases/{}/collections",
@@ -115,7 +116,7 @@ impl ChromaClient {
115116
configuration,
116117
metadata,
117118
get_or_create: false,
118-
schema: None,
119+
schema,
119120
};
120121
let response = send_request::<CreateCollectionPayload, CollectionModel>(
121122
&self.host,

rust/cli/src/commands/copy.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ async fn copy_collections(
239239
collection.name.clone(),
240240
collection.metadata.clone(),
241241
Some(CollectionConfiguration::from(collection.config.clone())),
242+
collection.schema.clone(),
242243
)
243244
.await
244245
.map_err(|_| ChromaClientError::CreateCollection(collection.name.clone()))?;

rust/cli/src/commands/vacuum.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use chroma_segment::local_segment_manager::LocalSegmentManager;
1111
use chroma_sqlite::db::SqliteDb;
1212
use chroma_sysdb::SysDb;
1313
use chroma_system::System;
14-
use chroma_types::{CollectionUuid, ListCollectionsRequest};
14+
use chroma_types::{CollectionUuid, KnnIndex, ListCollectionsRequest};
1515
use clap::Parser;
1616
use colored::Colorize;
1717
use dialoguer::Confirm;
@@ -101,11 +101,17 @@ async fn trigger_vector_segments_max_seq_id_migration(
101101
sqlite: &SqliteDb,
102102
sysdb: &mut SysDb,
103103
segment_manager: &LocalSegmentManager,
104+
default_knn_index: KnnIndex,
104105
) -> Result<(), Box<dyn Error>> {
105106
let collection_ids = get_collection_ids_to_migrate(sqlite).await?;
106107

107108
for collection_id in collection_ids {
108-
let collection = sysdb.get_collection_with_segments(collection_id).await?;
109+
let mut collection = sysdb.get_collection_with_segments(collection_id).await?;
110+
111+
collection
112+
.collection
113+
.reconcile_schema_with_config(default_knn_index)
114+
.map_err(|e| Box::new(e) as Box<dyn Error>)?;
109115

110116
// If collection is uninitialized, that means nothing has been written yet.
111117
let dim = match collection.collection.dimension {
@@ -138,7 +144,7 @@ async fn configure_sql_embedding_queue(log: &SqliteLog) -> Result<(), Box<dyn Er
138144
pub async fn vacuum_chroma(config: FrontendConfig) -> Result<(), Box<dyn Error>> {
139145
let system = System::new();
140146
let registry = Registry::new();
141-
let mut frontend = Frontend::try_from_config(&(config, system), &registry).await?;
147+
let mut frontend = Frontend::try_from_config(&(config.clone(), system), &registry).await?;
142148

143149
let sqlite = registry.get::<SqliteDb>()?;
144150
let segment_manager = registry.get::<LocalSegmentManager>()?;
@@ -147,7 +153,13 @@ pub async fn vacuum_chroma(config: FrontendConfig) -> Result<(), Box<dyn Error>>
147153

148154
println!("Purging the log...\n");
149155

150-
trigger_vector_segments_max_seq_id_migration(&sqlite, &mut sysdb, &segment_manager).await?;
156+
trigger_vector_segments_max_seq_id_migration(
157+
&sqlite,
158+
&mut sysdb,
159+
&segment_manager,
160+
config.default_knn_index,
161+
)
162+
.await?;
151163

152164
let tenant = String::from("default_tenant");
153165
let database = String::from("default_database");

rust/frontend/src/config.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ fn default_enable_span_indexing() -> bool {
142142
}
143143

144144
fn default_enable_schema() -> bool {
145-
false
145+
true
146146
}
147147

148148
pub fn default_min_records_for_task() -> u64 {

rust/frontend/src/executor/local.rs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -204,17 +204,21 @@ impl LocalExecutor {
204204
allowed_offset_ids.push(offset_id);
205205
}
206206

207-
let distance_function = match collection_and_segments
207+
let hnsw_config = collection_and_segments
208208
.collection
209-
.config
210-
.get_hnsw_config_with_legacy_fallback(&plan.scan.collection_and_segments.vector_segment)
211-
{
212-
Ok(Some(config)) => config.space,
213-
Ok(None) => return Err(ExecutorError::CollectionMissingHnswConfiguration),
214-
Err(err) => {
215-
return Err(ExecutorError::Internal(Box::new(err)));
216-
}
217-
};
209+
.schema
210+
.as_ref()
211+
.map(|schema| {
212+
schema.get_internal_hnsw_config_with_legacy_fallback(
213+
&plan.scan.collection_and_segments.vector_segment,
214+
)
215+
})
216+
.transpose()
217+
.map_err(|err| ExecutorError::Internal(Box::new(err)))?
218+
.flatten()
219+
.ok_or(ExecutorError::CollectionMissingHnswConfiguration)?;
220+
221+
let distance_function = hnsw_config.space;
218222

219223
let mut results = Vec::new();
220224
let mut returned_user_ids = Vec::new();

rust/frontend/src/get_collection_with_segments_provider.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ use chroma_config::Configurable;
44
use chroma_error::{ChromaError, ErrorCodes};
55
use chroma_sysdb::SysDb;
66
use chroma_types::{
7-
CollectionAndSegments, CollectionUuid, GetCollectionWithSegmentsError, Schema, SchemaError,
7+
CollectionAndSegments, CollectionUuid, GetCollectionWithSegmentsError, KnnIndex, Schema,
8+
SchemaError,
89
};
910
use serde::{Deserialize, Serialize};
1011
use std::{
@@ -142,6 +143,7 @@ impl CollectionsWithSegmentsProvider {
142143
pub(crate) async fn get_collection_with_segments(
143144
&mut self,
144145
collection_id: CollectionUuid,
146+
knn_index: KnnIndex,
145147
) -> Result<CollectionAndSegments, CollectionsWithSegmentsProviderError> {
146148
if let Some(collection_and_segments_with_ttl) = self
147149
.collections_with_segments_cache
@@ -187,6 +189,7 @@ impl CollectionsWithSegmentsProvider {
187189
let reconciled_schema = Schema::reconcile_schema_and_config(
188190
collection_and_segments_sysdb.collection.schema.as_ref(),
189191
Some(&collection_and_segments_sysdb.collection.config),
192+
knn_index,
190193
)
191194
.map_err(CollectionsWithSegmentsProviderError::InvalidSchema)?;
192195
collection_and_segments_sysdb.collection.schema = Some(reconciled_schema);

0 commit comments

Comments
 (0)