diff --git a/chromadb/api/collection_configuration.py b/chromadb/api/collection_configuration.py index 029d4283912..e18ca62cf7d 100644 --- a/chromadb/api/collection_configuration.py +++ b/chromadb/api/collection_configuration.py @@ -812,12 +812,12 @@ def update_schema_from_collection_configuration( Returns: Updated Schema object """ - # TODO: Remove this check once schema is enabled in local. - if schema.defaults.float_list is None: - return schema # Get the vector index from defaults and #embedding key - if schema.defaults.float_list is None or schema.defaults.float_list.vector_index is None: + if ( + schema.defaults.float_list is None + or schema.defaults.float_list.vector_index is None + ): raise ValueError("Schema is missing defaults.float_list.vector_index") embedding_key = "#embedding" @@ -825,8 +825,13 @@ def update_schema_from_collection_configuration( raise ValueError(f"Schema is missing keys[{embedding_key}]") embedding_value_types = schema.keys[embedding_key] - if embedding_value_types.float_list is None or embedding_value_types.float_list.vector_index is None: - raise ValueError(f"Schema is missing keys[{embedding_key}].float_list.vector_index") + if ( + embedding_value_types.float_list is None + or embedding_value_types.float_list.vector_index is None + ): + raise ValueError( + f"Schema is missing keys[{embedding_key}].float_list.vector_index" + ) # Update vector index config in both locations for vector_index in [ @@ -868,7 +873,10 @@ def update_schema_from_collection_configuration( spann_config.ef_search = update_spann["ef_search"] # Update embedding function if present - if "embedding_function" in configuration and configuration["embedding_function"] is not None: + if ( + "embedding_function" in configuration + and configuration["embedding_function"] is not None + ): vector_index.config.embedding_function = configuration["embedding_function"] return schema diff --git a/chromadb/api/rust.py b/chromadb/api/rust.py index 3aae75d030e..cf87f6ed728 100644 --- a/chromadb/api/rust.py +++ b/chromadb/api/rust.py @@ -193,7 +193,7 @@ def list_collections( CollectionModel( id=collection.id, name=collection.name, - serialized_schema=None, + serialized_schema=collection.schema, configuration_json=collection.configuration, metadata=collection.metadata, dimension=collection.dimension, @@ -229,14 +229,25 @@ def create_collection( else: configuration_json_str = None + if schema: + schema_str = json.dumps(schema.serialize_to_json()) + else: + schema_str = None + collection = self.bindings.create_collection( - name, configuration_json_str, metadata, get_or_create, tenant, database + name, + configuration_json_str, + schema_str, + metadata, + get_or_create, + tenant, + database, ) collection_model = CollectionModel( id=collection.id, name=collection.name, configuration_json=collection.configuration, - serialized_schema=None, + serialized_schema=collection.schema, metadata=collection.metadata, dimension=collection.dimension, tenant=collection.tenant, @@ -256,7 +267,7 @@ def get_collection( id=collection.id, name=collection.name, configuration_json=collection.configuration, - serialized_schema=None, + serialized_schema=collection.schema, metadata=collection.metadata, dimension=collection.dimension, tenant=collection.tenant, diff --git a/chromadb/chromadb_rust_bindings.pyi b/chromadb/chromadb_rust_bindings.pyi index a001425ed9c..26377dc65fa 100644 --- a/chromadb/chromadb_rust_bindings.pyi +++ b/chromadb/chromadb_rust_bindings.pyi @@ -95,6 +95,7 @@ class Bindings: self, name: str, configuration_json_str: Optional[str] = None, + schema_str: Optional[str] = None, metadata: Optional[CollectionMetadata] = None, get_or_create: bool = False, tenant: str = DEFAULT_TENANT, diff --git a/chromadb/test/api/test_schema_e2e.py b/chromadb/test/api/test_schema_e2e.py index 5f2b1d1154d..65a86ad15c1 100644 --- a/chromadb/test/api/test_schema_e2e.py +++ b/chromadb/test/api/test_schema_e2e.py @@ -19,6 +19,7 @@ is_spann_disabled_mode, skip_if_not_cluster, skip_reason_spann_disabled, + skip_reason_spann_enabled, ) from chromadb.test.utils.wait_for_version_increase import ( get_collection_version, @@ -29,7 +30,7 @@ register_sparse_embedding_function, ) from chromadb.api.models.Collection import Collection -from chromadb.errors import InvalidArgumentError +from chromadb.errors import InvalidArgumentError, InternalError from chromadb.execution.expression import Knn, Search from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast from uuid import uuid4 @@ -94,8 +95,7 @@ def build_from_config(config: Dict[str, Any]) -> "RecordingSearchEmbeddingFuncti return RecordingSearchEmbeddingFunction(config.get("label", "default")) -@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) -def test_schema_spann_vector_config_persistence( +def test_schema_vector_config_persistence( client_factories: "ClientFactories", ) -> None: """Ensure schema-provided SPANN settings persist across client restarts.""" @@ -136,12 +136,22 @@ def test_schema_spann_vector_config_persistence( assert vector_index is not None assert vector_index.enabled is True assert vector_index.config is not None - assert vector_index.config.spann is not None - spann_config = vector_index.config.spann - assert spann_config.search_nprobe == 16 - assert spann_config.write_nprobe == 32 - assert spann_config.ef_construction == 120 - assert spann_config.max_neighbors == 24 + + if not is_spann_disabled_mode: + assert vector_index.config.spann is not None + spann_config = vector_index.config.spann + assert spann_config.search_nprobe == 16 + assert spann_config.write_nprobe == 32 + assert spann_config.ef_construction == 120 + assert spann_config.max_neighbors == 24 + else: + assert vector_index.config.spann is None + assert vector_index.config.hnsw is not None + hnsw_config = vector_index.config.hnsw + assert hnsw_config.ef_construction == 100 + assert hnsw_config.ef_search == 100 + assert hnsw_config.max_neighbors == 16 + assert hnsw_config.resize_factor == 1.2 ef = vector_index.config.embedding_function assert ef is not None @@ -149,16 +159,23 @@ def test_schema_spann_vector_config_persistence( assert ef.get_config() == {"dim": 6} persisted_json = persisted_schema.serialize_to_json() - spann_json = persisted_json["keys"]["#embedding"]["float_list"]["vector_index"][ - "config" - ]["spann"] - assert spann_json["search_nprobe"] == 16 - assert spann_json["write_nprobe"] == 32 + if not is_spann_disabled_mode: + spann_json = persisted_json["keys"]["#embedding"]["float_list"]["vector_index"][ + "config" + ]["spann"] + assert spann_json["search_nprobe"] == 16 + assert spann_json["write_nprobe"] == 32 + else: + hnsw_json = persisted_json["keys"]["#embedding"]["float_list"]["vector_index"][ + "config" + ]["hnsw"] + assert hnsw_json["ef_construction"] == 100 + assert hnsw_json["ef_search"] == 100 + assert hnsw_json["max_neighbors"] == 16 client_reloaded = client_factories.create_client_from_system() reloaded_collection = client_reloaded.get_collection( name=collection_name, - embedding_function=SimpleEmbeddingFunction(dim=6), # type: ignore[arg-type] ) reloaded_schema = reloaded_collection.schema @@ -168,9 +185,23 @@ def test_schema_spann_vector_config_persistence( reloaded_vector_index = reloaded_embedding_override.vector_index assert reloaded_vector_index is not None assert reloaded_vector_index.config is not None - assert reloaded_vector_index.config.spann is not None - assert reloaded_vector_index.config.spann.search_nprobe == 16 - assert reloaded_vector_index.config.spann.write_nprobe == 32 + if not is_spann_disabled_mode: + assert reloaded_vector_index.config.spann is not None + assert reloaded_vector_index.config.spann.search_nprobe == 16 + assert reloaded_vector_index.config.spann.write_nprobe == 32 + else: + assert reloaded_vector_index.config.hnsw is not None + assert reloaded_vector_index.config.hnsw.ef_construction == 100 + assert reloaded_vector_index.config.hnsw.ef_search == 100 + assert reloaded_vector_index.config.hnsw.max_neighbors == 16 + assert reloaded_vector_index.config.hnsw.resize_factor == 1.2 + + config = reloaded_collection.configuration + assert config is not None + config_ef = config.get("embedding_function") + assert config_ef is not None + assert config_ef.name() == "simple_ef" + assert config_ef.get_config() == {"dim": 6} @register_sparse_embedding_function @@ -258,7 +289,6 @@ def _collect_knn_queries(rank: Any) -> List[Any]: return queries -@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) def test_schema_defaults_enable_indexed_operations( client_factories: "ClientFactories", ) -> None: @@ -336,7 +366,6 @@ def test_schema_defaults_enable_indexed_operations( assert reloaded.schema.serialize_to_json() == schema.serialize_to_json() -@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) def test_get_or_create_and_get_collection_preserve_schema( client_factories: "ClientFactories", ) -> None: @@ -379,7 +408,6 @@ def test_get_or_create_and_get_collection_preserve_schema( assert set(stored["ids"]) == {"schema-preserve"} -@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) def test_delete_collection_resets_schema_configuration( client_factories: "ClientFactories", ) -> None: @@ -408,6 +436,19 @@ def test_delete_collection_resets_schema_configuration( assert set(recreated_json["keys"].keys()) == set(baseline_json["keys"].keys()) +@pytest.mark.skipif(not is_spann_disabled_mode, reason=skip_reason_spann_enabled) +def test_sparse_vector_not_allowed_locally( + client_factories: "ClientFactories", +) -> None: + """Sparse vector configs are not allowed to be created locally.""" + schema = Schema() + schema.create_index(key="sparse_metadata", config=SparseVectorIndexConfig()) + with pytest.raises( + InternalError, match="Sparse vector indexing is not enabled in local" + ): + _create_isolated_collection(client_factories, schema=schema) + + @pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) def test_sparse_vector_source_key_and_index_constraints( client_factories: "ClientFactories", @@ -466,7 +507,6 @@ def test_sparse_vector_source_key_and_index_constraints( assert set(string_filter["ids"]) == {"sparse-1"} -@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) def test_schema_persistence_with_custom_overrides( client_factories: "ClientFactories", ) -> None: @@ -507,7 +547,6 @@ def test_schema_persistence_with_custom_overrides( assert set(fetched["ids"]) == {"persist-1"} -@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) def test_collection_embed_uses_schema_or_collection_embedding_function( client_factories: "ClientFactories", ) -> None: @@ -594,7 +633,6 @@ def test_search_embeds_string_queries_in_nested_ranks( assert all(isinstance(query, list) for query in all_queries) -@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) def test_schema_delete_index_and_restore( client_factories: "ClientFactories", ) -> None: @@ -651,6 +689,52 @@ def test_schema_delete_index_and_restore( assert set(search["ids"]) == {"key-enabled"} +def test_disabled_metadata_index_filters_raise_invalid_argument_all_modes( + client_factories: "ClientFactories", +) -> None: + """Disabled metadata inverted index should block filter-based operations in get, query, and delete for local, single node, and distributed.""" + schema = Schema().delete_index( + key="restricted_tag", config=StringInvertedIndexConfig() + ) + collection, _ = _create_isolated_collection(client_factories, schema=schema) + + collection.add( + ids=["restricted-doc"], + embeddings=cast(Embeddings, [[0.1, 0.2, 0.3, 0.4]]), + metadatas=[{"restricted_tag": "blocked"}], + documents=["doc"], + ) + + assert collection.schema is not None + schema_entry = collection.schema.keys["restricted_tag"].string + assert schema_entry is not None + index_config = schema_entry.string_inverted_index + assert index_config is not None + assert index_config.enabled is False + + filter_payload: Dict[str, Any] = {"restricted_tag": "blocked"} + + def _expect_disabled_error(operation: Callable[[], Any]) -> None: + with pytest.raises(InvalidArgumentError) as exc_info: + operation() + assert "Cannot filter using metadata key 'restricted_tag'" in str( + exc_info.value + ) + + operations: List[Callable[[], Any]] = [ + lambda: collection.get(where=filter_payload), + lambda: collection.query( + query_embeddings=cast(Embeddings, [[0.1, 0.2, 0.3, 0.4]]), + n_results=1, + where=filter_payload, + ), + lambda: collection.delete(where=filter_payload), + ] + + for operation in operations: + _expect_disabled_error(operation) + + @pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) def test_disabled_metadata_index_filters_raise_invalid_argument( client_factories: "ClientFactories", @@ -1548,7 +1632,6 @@ def test_sparse_auto_embedding_with_empty_documents( assert "empty_sparse" in metadata -@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) def test_default_embedding_function_when_no_schema_provided( client_factories: "ClientFactories", ) -> None: @@ -1604,7 +1687,6 @@ def test_default_embedding_function_when_no_schema_provided( assert sparse_ef_config["type"] == "unknown" -@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) def test_custom_embedding_function_without_schema( client_factories: "ClientFactories", ) -> None: @@ -1679,7 +1761,6 @@ def test_custom_embedding_function_without_schema( assert len(result["embeddings"][0]) == 8 -@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) def test_custom_embedding_function_with_default_schema( client_factories: "ClientFactories", ) -> None: @@ -1755,7 +1836,6 @@ def test_custom_embedding_function_with_default_schema( assert len(result["embeddings"][0]) == 8 -@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled) def test_conflicting_embedding_functions_in_schema_and_config_fails( client_factories: "ClientFactories", ) -> None: diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index 87c0d48cd0a..583d3d497d6 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -91,7 +91,9 @@ skip_reason_spann_disabled = ( "SPANN creation/modification disallowed in Rust bindings or integration test mode" ) - +skip_reason_spann_enabled = ( + "SPANN creation/modification allowed in Rust bindings or integration test mode" +) def reset(api: BaseAPI) -> None: diff --git a/rust/cli/src/client/chroma_client.rs b/rust/cli/src/client/chroma_client.rs index 8b6164342a9..bb0bab498ed 100644 --- a/rust/cli/src/client/chroma_client.rs +++ b/rust/cli/src/client/chroma_client.rs @@ -4,7 +4,7 @@ use crate::client::prelude::CollectionModel; use crate::client::utils::send_request; use crate::utils::Profile; use axum::http::Method; -use chroma_types::{CollectionConfiguration, CreateCollectionPayload, Metadata}; +use chroma_types::{CollectionConfiguration, CreateCollectionPayload, Metadata, Schema}; use std::error::Error; use std::ops::Deref; use thiserror::Error; @@ -104,6 +104,7 @@ impl ChromaClient { name: String, metadata: Option, configuration: Option, + schema: Option, ) -> Result> { let route = format!( "/api/v2/tenants/{}/databases/{}/collections", @@ -115,7 +116,7 @@ impl ChromaClient { configuration, metadata, get_or_create: false, - schema: None, + schema, }; let response = send_request::( &self.host, diff --git a/rust/cli/src/commands/copy.rs b/rust/cli/src/commands/copy.rs index 8e77705cd71..b6bc0bfdd95 100644 --- a/rust/cli/src/commands/copy.rs +++ b/rust/cli/src/commands/copy.rs @@ -239,6 +239,7 @@ async fn copy_collections( collection.name.clone(), collection.metadata.clone(), Some(CollectionConfiguration::from(collection.config.clone())), + collection.schema.clone(), ) .await .map_err(|_| ChromaClientError::CreateCollection(collection.name.clone()))?; diff --git a/rust/cli/src/commands/vacuum.rs b/rust/cli/src/commands/vacuum.rs index b14c3b91955..9f505193364 100644 --- a/rust/cli/src/commands/vacuum.rs +++ b/rust/cli/src/commands/vacuum.rs @@ -11,7 +11,7 @@ use chroma_segment::local_segment_manager::LocalSegmentManager; use chroma_sqlite::db::SqliteDb; use chroma_sysdb::SysDb; use chroma_system::System; -use chroma_types::{CollectionUuid, ListCollectionsRequest}; +use chroma_types::{CollectionUuid, KnnIndex, ListCollectionsRequest}; use clap::Parser; use colored::Colorize; use dialoguer::Confirm; @@ -101,11 +101,17 @@ async fn trigger_vector_segments_max_seq_id_migration( sqlite: &SqliteDb, sysdb: &mut SysDb, segment_manager: &LocalSegmentManager, + default_knn_index: KnnIndex, ) -> Result<(), Box> { let collection_ids = get_collection_ids_to_migrate(sqlite).await?; for collection_id in collection_ids { - let collection = sysdb.get_collection_with_segments(collection_id).await?; + let mut collection = sysdb.get_collection_with_segments(collection_id).await?; + + collection + .collection + .reconcile_schema_with_config(default_knn_index) + .map_err(|e| Box::new(e) as Box)?; // If collection is uninitialized, that means nothing has been written yet. let dim = match collection.collection.dimension { @@ -138,7 +144,7 @@ async fn configure_sql_embedding_queue(log: &SqliteLog) -> Result<(), Box Result<(), Box> { let system = System::new(); let registry = Registry::new(); - let mut frontend = Frontend::try_from_config(&(config, system), ®istry).await?; + let mut frontend = Frontend::try_from_config(&(config.clone(), system), ®istry).await?; let sqlite = registry.get::()?; let segment_manager = registry.get::()?; @@ -147,7 +153,13 @@ pub async fn vacuum_chroma(config: FrontendConfig) -> Result<(), Box> println!("Purging the log...\n"); - trigger_vector_segments_max_seq_id_migration(&sqlite, &mut sysdb, &segment_manager).await?; + trigger_vector_segments_max_seq_id_migration( + &sqlite, + &mut sysdb, + &segment_manager, + config.default_knn_index, + ) + .await?; let tenant = String::from("default_tenant"); let database = String::from("default_database"); diff --git a/rust/frontend/src/config.rs b/rust/frontend/src/config.rs index 57b55e11d77..cdf5fa59a85 100644 --- a/rust/frontend/src/config.rs +++ b/rust/frontend/src/config.rs @@ -142,7 +142,7 @@ fn default_enable_span_indexing() -> bool { } fn default_enable_schema() -> bool { - false + true } pub fn default_min_records_for_task() -> u64 { diff --git a/rust/frontend/src/executor/local.rs b/rust/frontend/src/executor/local.rs index feffbd716d0..609d486e401 100644 --- a/rust/frontend/src/executor/local.rs +++ b/rust/frontend/src/executor/local.rs @@ -204,17 +204,21 @@ impl LocalExecutor { allowed_offset_ids.push(offset_id); } - let distance_function = match collection_and_segments + let hnsw_config = collection_and_segments .collection - .config - .get_hnsw_config_with_legacy_fallback(&plan.scan.collection_and_segments.vector_segment) - { - Ok(Some(config)) => config.space, - Ok(None) => return Err(ExecutorError::CollectionMissingHnswConfiguration), - Err(err) => { - return Err(ExecutorError::Internal(Box::new(err))); - } - }; + .schema + .as_ref() + .map(|schema| { + schema.get_internal_hnsw_config_with_legacy_fallback( + &plan.scan.collection_and_segments.vector_segment, + ) + }) + .transpose() + .map_err(|err| ExecutorError::Internal(Box::new(err)))? + .flatten() + .ok_or(ExecutorError::CollectionMissingHnswConfiguration)?; + + let distance_function = hnsw_config.space; let mut results = Vec::new(); let mut returned_user_ids = Vec::new(); diff --git a/rust/frontend/src/get_collection_with_segments_provider.rs b/rust/frontend/src/get_collection_with_segments_provider.rs index 7eef28b12c1..c5622220acd 100644 --- a/rust/frontend/src/get_collection_with_segments_provider.rs +++ b/rust/frontend/src/get_collection_with_segments_provider.rs @@ -4,7 +4,8 @@ use chroma_config::Configurable; use chroma_error::{ChromaError, ErrorCodes}; use chroma_sysdb::SysDb; use chroma_types::{ - CollectionAndSegments, CollectionUuid, GetCollectionWithSegmentsError, Schema, SchemaError, + CollectionAndSegments, CollectionUuid, GetCollectionWithSegmentsError, KnnIndex, Schema, + SchemaError, }; use serde::{Deserialize, Serialize}; use std::{ @@ -142,6 +143,7 @@ impl CollectionsWithSegmentsProvider { pub(crate) async fn get_collection_with_segments( &mut self, collection_id: CollectionUuid, + knn_index: KnnIndex, ) -> Result { if let Some(collection_and_segments_with_ttl) = self .collections_with_segments_cache @@ -187,6 +189,7 @@ impl CollectionsWithSegmentsProvider { let reconciled_schema = Schema::reconcile_schema_and_config( collection_and_segments_sysdb.collection.schema.as_ref(), Some(&collection_and_segments_sysdb.collection.config), + knn_index, ) .map_err(CollectionsWithSegmentsProviderError::InvalidSchema)?; collection_and_segments_sysdb.collection.schema = Some(reconciled_schema); diff --git a/rust/frontend/src/impls/in_memory_frontend.rs b/rust/frontend/src/impls/in_memory_frontend.rs index 7665b1fd344..e28c8287c73 100644 --- a/rust/frontend/src/impls/in_memory_frontend.rs +++ b/rust/frontend/src/impls/in_memory_frontend.rs @@ -6,7 +6,7 @@ use chroma_types::operator::{Filter, KnnBatch, KnnProjection, Limit, Projection, use chroma_types::plan::{Count, Get, Knn}; use chroma_types::{ test_segment, Collection, CollectionAndSegments, CreateCollectionError, Database, Include, - IncludeList, InternalCollectionConfiguration, Segment, VectorIndexConfiguration, + IncludeList, InternalCollectionConfiguration, KnnIndex, Segment, VectorIndexConfiguration, }; use std::collections::HashSet; @@ -221,16 +221,22 @@ impl InMemoryFrontend { )); } - let collection = Collection { - name: request.name, - tenant: request.tenant_id, - database: request.database_name, + let mut collection = Collection { + name: request.name.clone(), + tenant: request.tenant_id.clone(), + database: request.database_name.clone(), + metadata: request.metadata, config: request .configuration .unwrap_or(InternalCollectionConfiguration::default_hnsw()), + schema: request.schema, ..Default::default() }; + collection + .reconcile_schema_with_config(KnnIndex::Hnsw) + .map_err(CreateCollectionError::InvalidSchema)?; + // Prevent SPANN usage in InMemoryFrontend if matches!( collection.config.vector_index, @@ -620,10 +626,15 @@ impl InMemoryFrontend { let params = collection .collection - .config - .get_hnsw_config_with_legacy_fallback(&collection.vector_segment) + .schema + .as_ref() + .map(|schema| { + schema.get_internal_hnsw_config_with_legacy_fallback(&collection.vector_segment) + }) + .transpose() .map_err(|e| e.boxed())? - .unwrap(); + .flatten() + .expect("HNSW configuration missing for collection schema"); let distance_function: DistanceFunction = params.space.into(); let query_response = collection diff --git a/rust/frontend/src/impls/service_based_frontend.rs b/rust/frontend/src/impls/service_based_frontend.rs index 37f4d5b66ee..e5f747b4e52 100644 --- a/rust/frontend/src/impls/service_based_frontend.rs +++ b/rust/frontend/src/impls/service_based_frontend.rs @@ -37,7 +37,7 @@ use chroma_types::{ ListCollectionsRequest, ListCollectionsResponse, ListDatabasesError, ListDatabasesRequest, ListDatabasesResponse, Operation, OperationRecord, QueryError, QueryRequest, QueryResponse, RemoveTaskError, RemoveTaskRequest, RemoveTaskResponse, ResetError, ResetResponse, Schema, - SearchRequest, SearchResponse, Segment, SegmentScope, SegmentType, SegmentUuid, + SchemaError, SearchRequest, SearchResponse, Segment, SegmentScope, SegmentType, SegmentUuid, UpdateCollectionError, UpdateCollectionRecordsError, UpdateCollectionRecordsRequest, UpdateCollectionRecordsResponse, UpdateCollectionRequest, UpdateCollectionResponse, UpdateTenantError, UpdateTenantRequest, UpdateTenantResponse, UpsertCollectionRecordsError, @@ -176,7 +176,7 @@ impl ServiceBasedFrontend { ) -> Result { Ok(self .collections_with_segments_provider - .get_collection_with_segments(collection_id) + .get_collection_with_segments(collection_id, self.default_knn_index) .await .map_err(|err| Box::new(err) as Box)? .collection) @@ -188,7 +188,7 @@ impl ServiceBasedFrontend { ) -> Result, GetCollectionError> { Ok(self .collections_with_segments_provider - .get_collection_with_segments(collection_id) + .get_collection_with_segments(collection_id, self.default_knn_index) .await .map_err(|err| Box::new(err) as Box)? .collection @@ -381,7 +381,7 @@ impl ServiceBasedFrontend { if self.enable_schema { for collection in collections.iter_mut() { collection - .reconcile_schema_with_config() + .reconcile_schema_with_config(self.default_knn_index) .map_err(GetCollectionsError::InvalidSchema)?; } } @@ -425,7 +425,7 @@ impl ServiceBasedFrontend { if self.enable_schema { for collection in &mut collections { collection - .reconcile_schema_with_config() + .reconcile_schema_with_config(self.default_knn_index) .map_err(GetCollectionError::InvalidSchema)?; } } @@ -450,7 +450,7 @@ impl ServiceBasedFrontend { if self.enable_schema { collection - .reconcile_schema_with_config() + .reconcile_schema_with_config(self.default_knn_index) .map_err(GetCollectionByCrnError::InvalidSchema)?; } Ok(collection) @@ -517,6 +517,7 @@ impl ServiceBasedFrontend { match Schema::reconcile_schema_and_config( schema.as_ref(), config_for_reconcile.as_ref(), + self.default_knn_index, ) { Ok(schema) => Some(schema), Err(e) => { @@ -571,6 +572,19 @@ impl ServiceBasedFrontend { ] } Executor::Local(_) => { + if self.enable_schema { + if let Some(schema) = reconciled_schema.as_ref() { + if schema.is_sparse_index_enabled() { + return Err(CreateCollectionError::InvalidSchema( + SchemaError::InvalidSchema { + reason: "Sparse vector indexing is not enabled in local" + .to_string(), + }, + )); + } + } + } + vec![ Segment { id: SegmentUuid::new(), @@ -616,7 +630,7 @@ impl ServiceBasedFrontend { // that was retrieved from sysdb, rather than the one that was passed in if self.enable_schema { collection - .reconcile_schema_with_config() + .reconcile_schema_with_config(self.default_knn_index) .map_err(CreateCollectionError::InvalidSchema)?; } Ok(collection) @@ -721,7 +735,7 @@ impl ServiceBasedFrontend { .await?; collection_and_segments .collection - .reconcile_schema_with_config() + .reconcile_schema_with_config(self.default_knn_index) .map_err(ForkCollectionError::InvalidSchema)?; let collection = collection_and_segments.collection.clone(); let latest_collection_logical_size_bytes = collection_and_segments @@ -1085,7 +1099,7 @@ impl ServiceBasedFrontend { let read_event = if let Some(where_clause) = r#where { let collection_and_segments = self .collections_with_segments_provider - .get_collection_with_segments(collection_id) + .get_collection_with_segments(collection_id, self.default_knn_index) .await .map_err(|err| Box::new(err) as Box)?; if self.enable_schema { @@ -1295,7 +1309,7 @@ impl ServiceBasedFrontend { ) -> Result { let collection_and_segments = self .collections_with_segments_provider - .get_collection_with_segments(collection_id) + .get_collection_with_segments(collection_id, self.default_knn_index) .await .map_err(|err| Box::new(err) as Box)?; let latest_collection_logical_size_bytes = collection_and_segments @@ -1410,7 +1424,7 @@ impl ServiceBasedFrontend { ) -> Result { let collection_and_segments = self .collections_with_segments_provider - .get_collection_with_segments(collection_id) + .get_collection_with_segments(collection_id, self.default_knn_index) .await .map_err(|err| Box::new(err) as Box)?; if self.enable_schema { @@ -1555,7 +1569,7 @@ impl ServiceBasedFrontend { ) -> Result { let collection_and_segments = self .collections_with_segments_provider - .get_collection_with_segments(collection_id) + .get_collection_with_segments(collection_id, self.default_knn_index) .await .map_err(|err| Box::new(err) as Box)?; if self.enable_schema { @@ -1712,7 +1726,7 @@ impl ServiceBasedFrontend { // Get collection and segments once for all queries let collection_and_segments = self .collections_with_segments_provider - .get_collection_with_segments(request.collection_id) + .get_collection_with_segments(request.collection_id, self.default_knn_index) .await .map_err(|err| QueryError::Other(Box::new(err) as Box))?; if self.enable_schema { diff --git a/rust/log/src/local_compaction_manager.rs b/rust/log/src/local_compaction_manager.rs index 3658b79fa6d..13556e386eb 100644 --- a/rust/log/src/local_compaction_manager.rs +++ b/rust/log/src/local_compaction_manager.rs @@ -14,7 +14,9 @@ use chroma_sqlite::db::SqliteDb; use chroma_sysdb::SysDb; use chroma_system::Handler; use chroma_system::{Component, ComponentContext}; -use chroma_types::{Chunk, CollectionUuid, GetCollectionWithSegmentsError, LogRecord}; +use chroma_types::{ + Chunk, CollectionUuid, GetCollectionWithSegmentsError, KnnIndex, LogRecord, SchemaError, +}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -94,6 +96,8 @@ pub enum CompactionManagerError { HnswReaderConstructionError(#[from] LocalSegmentManagerError), #[error("Error purging logs")] PurgeLogsFailure, + #[error("Failed to reconcile collection schema: {0}")] + SchemaReconcileError(#[from] SchemaError), } impl ChromaError for CompactionManagerError { @@ -108,6 +112,7 @@ impl ChromaError for CompactionManagerError { CompactionManagerError::HnswReaderError(e) => e.code(), CompactionManagerError::HnswReaderConstructionError(e) => e.code(), CompactionManagerError::PurgeLogsFailure => ErrorCodes::Internal, + CompactionManagerError::SchemaReconcileError(e) => e.code(), } } } @@ -131,10 +136,13 @@ impl Handler for LocalCompactionManager { message: BackfillMessage, _: &ComponentContext, ) -> Self::Result { - let collection_and_segments = self + let mut collection_and_segments = self .sysdb .get_collection_with_segments(message.collection_id) .await?; + collection_and_segments + .collection + .reconcile_schema_with_config(KnnIndex::Hnsw)?; // If collection is uninitialized, that means nothing has been written yet. let dim = match collection_and_segments.collection.dimension { Some(dim) => dim, @@ -240,8 +248,10 @@ impl Handler for LocalCompactionManager { .sysdb .get_collection_with_segments(message.collection_id) .await?; + let mut collection = collection_segments.collection.clone(); + collection.reconcile_schema_with_config(KnnIndex::Hnsw)?; // If dimension is None, that means nothing has been written yet. - let dim = match collection_segments.collection.dimension { + let dim = match collection.dimension { Some(dim) => dim, None => return Ok(()), }; @@ -252,7 +262,7 @@ impl Handler for LocalCompactionManager { let hnsw_reader = self .hnsw_segment_manager .get_hnsw_reader( - &collection_segments.collection, + &collection, &collection_segments.vector_segment, dim as usize, ) diff --git a/rust/python_bindings/src/bindings.rs b/rust/python_bindings/src/bindings.rs index 3a97f2f2426..646f28cd586 100644 --- a/rust/python_bindings/src/bindings.rs +++ b/rust/python_bindings/src/bindings.rs @@ -113,7 +113,7 @@ impl Bindings { let executor_config = ExecutorConfig::Local(LocalExecutorConfig {}); let knn_index = KnnIndex::Hnsw; - let enable_schema = false; + let enable_schema = true; let frontend_config = FrontendConfig { allow_reset, @@ -252,12 +252,13 @@ impl Bindings { #[allow(clippy::too_many_arguments)] #[pyo3( - signature = (name, configuration_json_str, metadata = None, get_or_create = false, tenant = DEFAULT_TENANT.to_string(), database = DEFAULT_DATABASE.to_string()) + signature = (name, configuration_json_str = None, schema_str = None, metadata = None, get_or_create = false, tenant = DEFAULT_TENANT.to_string(), database = DEFAULT_DATABASE.to_string()) )] fn create_collection( &self, name: String, configuration_json_str: Option, + schema_str: Option, metadata: Option, get_or_create: bool, tenant: String, @@ -265,9 +266,8 @@ impl Bindings { ) -> ChromaPyResult { let configuration_json = match configuration_json_str { Some(configuration_json_str) => { - let configuration_json = - serde_json::from_str::(&configuration_json_str) - .map_err(WrappedSerdeJsonError::SerdeJsonError)?; + let configuration_json = serde_json::from_str(&configuration_json_str) + .map_err(WrappedSerdeJsonError::SerdeJsonError)?; Some(configuration_json) } @@ -291,13 +291,20 @@ impl Bindings { )?), }; + let schema = match schema_str { + Some(schema_str) => { + serde_json::from_str(&schema_str).map_err(WrappedSerdeJsonError::SerdeJsonError)? + } + None => None, + }; + let request = CreateCollectionRequest::try_new( tenant, database, name, metadata, configuration, - None, + schema, get_or_create, )?; diff --git a/rust/segment/src/distributed_spann.rs b/rust/segment/src/distributed_spann.rs index f71abf2c142..e6f5fceb9f9 100644 --- a/rust/segment/src/distributed_spann.rs +++ b/rust/segment/src/distributed_spann.rs @@ -15,6 +15,7 @@ use chroma_index::spann::types::{ use chroma_index::IndexUuid; use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter}; use chroma_types::Collection; +use chroma_types::KnnIndex; use chroma_types::Schema; use chroma_types::SchemaError; use chroma_types::SegmentUuid; @@ -114,6 +115,7 @@ impl SpannSegmentWriter { let reconciled_schema = Schema::reconcile_schema_and_config( collection.schema.as_ref(), Some(&collection.config), + KnnIndex::Spann, ) .map_err(SpannSegmentWriterError::InvalidSchema)?; @@ -619,8 +621,8 @@ mod test { use chroma_storage::{local::LocalStorage, Storage}; use chroma_types::{ Chunk, Collection, CollectionUuid, DatabaseUuid, InternalCollectionConfiguration, - InternalSpannConfiguration, LogRecord, Operation, OperationRecord, Schema, SegmentUuid, - SpannPostingList, + InternalSpannConfiguration, KnnIndex, LogRecord, Operation, OperationRecord, Schema, + SegmentUuid, SpannPostingList, }; use crate::{ @@ -688,7 +690,7 @@ mod test { ..Default::default() }; collection.schema = Some( - Schema::reconcile_schema_and_config(None, Some(&collection.config)) + Schema::reconcile_schema_and_config(None, Some(&collection.config), KnnIndex::Spann) .expect("Error reconciling schema for test collection"), ); @@ -925,7 +927,7 @@ mod test { ..Default::default() }; collection.schema = Some( - Schema::reconcile_schema_and_config(None, Some(&collection.config)) + Schema::reconcile_schema_and_config(None, Some(&collection.config), KnnIndex::Spann) .expect("Error reconciling schema for test collection"), ); @@ -1087,7 +1089,7 @@ mod test { ..Default::default() }; collection.schema = Some( - Schema::reconcile_schema_and_config(None, Some(&collection.config)) + Schema::reconcile_schema_and_config(None, Some(&collection.config), KnnIndex::Spann) .expect("Error reconciling schema for test collection"), ); @@ -1218,7 +1220,7 @@ mod test { ..Default::default() }; collection.schema = Some( - Schema::reconcile_schema_and_config(None, Some(&collection.config)) + Schema::reconcile_schema_and_config(None, Some(&collection.config), KnnIndex::Spann) .expect("Error reconciling schema for test collection"), ); diff --git a/rust/segment/src/local_hnsw.rs b/rust/segment/src/local_hnsw.rs index c81041df73e..206e47ebfde 100644 --- a/rust/segment/src/local_hnsw.rs +++ b/rust/segment/src/local_hnsw.rs @@ -106,8 +106,11 @@ impl LocalHnswSegmentReader { sql_db: SqliteDb, ) -> Result { let hnsw_configuration = collection - .config - .get_hnsw_config_with_legacy_fallback(segment)? + .schema + .as_ref() + .map(|schema| schema.get_internal_hnsw_config_with_legacy_fallback(segment)) + .transpose()? + .flatten() .ok_or(LocalHnswSegmentReaderError::MissingHnswConfiguration)?; match persist_root { @@ -490,8 +493,11 @@ impl LocalHnswSegmentWriter { sql_db: SqliteDb, ) -> Result { let hnsw_configuration = collection - .config - .get_hnsw_config_with_legacy_fallback(segment)? + .schema + .as_ref() + .map(|schema| schema.get_internal_hnsw_config_with_legacy_fallback(segment)) + .transpose()? + .flatten() .ok_or(LocalHnswSegmentWriterError::MissingHnswConfiguration)?; match persist_root { diff --git a/rust/segment/src/test.rs b/rust/segment/src/test.rs index 6919e07347c..3c12f94cd8e 100644 --- a/rust/segment/src/test.rs +++ b/rust/segment/src/test.rs @@ -20,10 +20,10 @@ use chroma_types::{ }, plan::{Count, Get, Knn}, test_segment, BooleanOperator, Chunk, Collection, CollectionAndSegments, CompositeExpression, - DocumentExpression, DocumentOperator, LogRecord, Metadata, MetadataComparison, + DocumentExpression, DocumentOperator, KnnIndex, LogRecord, Metadata, MetadataComparison, MetadataExpression, MetadataSetValue, MetadataValue, Operation, OperationRecord, - PrimitiveOperator, Segment, SegmentScope, SegmentUuid, SetOperator, UpdateMetadata, Where, - CHROMA_KEY, + PrimitiveOperator, Schema, Segment, SegmentScope, SegmentUuid, SetOperator, UpdateMetadata, + Where, CHROMA_KEY, }; use regex::Regex; use std::collections::BinaryHeap; @@ -48,7 +48,8 @@ pub struct TestDistributedSegment { impl TestDistributedSegment { pub async fn new_with_dimension(dimension: usize) -> Self { - let collection = Collection::test_collection(dimension as i32); + let mut collection = Collection::test_collection(dimension as i32); + collection.schema = Some(Schema::new_default(KnnIndex::Hnsw)); let collection_uuid = collection.collection_id; let (blockfile_dir, blockfile_provider) = test_arrow_blockfile_provider(2 << 22); let (hnsw_dir, hnsw_provider) = test_hnsw_index_provider(); diff --git a/rust/sqlite/migrations/sysdb/00010-collection-schema.sqlite.sql b/rust/sqlite/migrations/sysdb/00010-collection-schema.sqlite.sql new file mode 100644 index 00000000000..055fe9147a1 --- /dev/null +++ b/rust/sqlite/migrations/sysdb/00010-collection-schema.sqlite.sql @@ -0,0 +1,2 @@ +-- Stores collection schema as stringified json +ALTER TABLE collections ADD COLUMN schema_str TEXT; diff --git a/rust/sqlite/src/table.rs b/rust/sqlite/src/table.rs index bb16b6a5862..8af31e49e1c 100644 --- a/rust/sqlite/src/table.rs +++ b/rust/sqlite/src/table.rs @@ -43,6 +43,7 @@ pub enum Collections { Dimension, DatabaseId, ConfigJsonStr, + SchemaStr, } #[derive(Iden)] diff --git a/rust/sysdb/src/sqlite.rs b/rust/sysdb/src/sqlite.rs index 9b23982ab53..1b13fca4b4b 100644 --- a/rust/sysdb/src/sqlite.rs +++ b/rust/sysdb/src/sqlite.rs @@ -13,8 +13,8 @@ use chroma_types::{ DeleteDatabaseError, DeleteDatabaseResponse, GetCollectionWithSegmentsError, GetCollectionsError, GetDatabaseError, GetSegmentsError, GetTenantError, GetTenantResponse, InternalCollectionConfiguration, InternalUpdateCollectionConfiguration, ListDatabasesError, - Metadata, MetadataValue, ResetError, ResetResponse, Segment, SegmentScope, SegmentType, - SegmentUuid, UpdateCollectionError, UpdateTenantError, UpdateTenantResponse, + Metadata, MetadataValue, ResetError, ResetResponse, Schema, SchemaError, Segment, SegmentScope, + SegmentType, SegmentUuid, UpdateCollectionError, UpdateTenantError, UpdateTenantResponse, }; use futures::TryStreamExt; use sea_query_binder::SqlxBinder; @@ -250,7 +250,8 @@ impl SqliteSysDb { collection_id: CollectionUuid, name: String, segments: Vec, - configuration: InternalCollectionConfiguration, + configuration: Option, + schema: Option, metadata: Option, dimension: Option, get_or_create: bool, @@ -304,16 +305,34 @@ impl SqliteSysDb { let database_uuid = DatabaseUuid::from_str(database_id) .map_err(|_| CreateCollectionError::DatabaseIdParseError)?; + let configuration_json_str = match configuration { + Some(configuration) => serde_json::to_string(&configuration) + .map_err(CreateCollectionError::Configuration)?, + None => "{}".to_string(), + }; + + let schema_json = schema + .as_ref() + .map(|schema| { + serde_json::to_string(schema).map_err(|e| { + CreateCollectionError::Schema(SchemaError::InvalidSchema { + reason: e.to_string(), + }) + }) + }) + .transpose()?; + sqlx::query( r#" INSERT INTO collections - (id, name, config_json_str, dimension, database_id) - VALUES ($1, $2, $3, $4, $5) + (id, name, config_json_str, schema_str, dimension, database_id) + VALUES ($1, $2, $3, $4, $5, $6) "#, ) .bind(collection_id.to_string()) .bind(&name) - .bind(serde_json::to_string(&configuration).map_err(CreateCollectionError::Configuration)?) + .bind(configuration_json_str.clone()) + .bind(schema_json) .bind(dimension) .bind(database_id) .execute(&mut *tx) @@ -345,9 +364,10 @@ impl SqliteSysDb { name, tenant, database, - config: configuration, + config: serde_json::from_str(&configuration_json_str) + .map_err(CreateCollectionError::Configuration)?, metadata, - schema: None, + schema, dimension, log_position: 0, total_records_post_compaction: 0, @@ -378,18 +398,29 @@ impl SqliteSysDb { .map_err(|e| UpdateCollectionError::Internal(e.into()))?; let mut configuration_json_str = None; + let mut schema_str = None; if let Some(configuration) = configuration { let collections = self .get_collections_with_conn(&mut *tx, Some(collection_id), None, None, None, None, 0) .await; let collections = collections.unwrap(); let collection = collections.into_iter().next().unwrap(); - let mut existing_configuration = collection.config; - existing_configuration.update(&configuration); - configuration_json_str = Some( - serde_json::to_string(&existing_configuration) - .map_err(UpdateCollectionError::Configuration)?, - ); + // if schema exists, update schema instead of configuration + if collection.schema.is_some() { + let mut existing_schema = collection.schema.unwrap(); + existing_schema.update(&configuration); + schema_str = Some( + serde_json::to_string(&existing_schema) + .map_err(UpdateCollectionError::Schema)?, + ); + } else { + let mut existing_configuration = collection.config; + existing_configuration.update(&configuration); + configuration_json_str = Some( + serde_json::to_string(&existing_configuration) + .map_err(UpdateCollectionError::Configuration)?, + ); + } } if name.is_some() || dimension.is_some() { @@ -436,6 +467,25 @@ impl SqliteSysDb { return Err(UpdateCollectionError::NotFound(collection_id.to_string())); } } + + if let Some(schema_str) = schema_str { + let mut query = sea_query::Query::update(); + let mut query = query.table(table::Collections::Table).cond_where( + sea_query::Expr::col((table::Collections::Table, table::Collections::Id)) + .eq(collection_id.to_string()), + ); + query = query.value(table::Collections::SchemaStr, schema_str); + + let (sql, values) = query.build_sqlx(sea_query::SqliteQueryBuilder); + + let result = sqlx::query_with(&sql, values) + .execute(&mut *tx) + .await + .map_err(|e| UpdateCollectionError::Internal(e.into()))?; + if result.rows_affected() == 0 { + return Err(UpdateCollectionError::NotFound(collection_id.to_string())); + } + } if let Some(metadata) = metadata { delete_metadata::(&mut *tx, collection_id.to_string()) .await @@ -685,6 +735,7 @@ impl SqliteSysDb { .column((table::Collections::Table, table::Collections::ConfigJsonStr)) .column((table::Collections::Table, table::Collections::Dimension)) .column((table::Collections::Table, table::Collections::DatabaseId)) + .column((table::Collections::Table, table::Collections::SchemaStr)) .inner_join( table::Databases::Table, sea_query::Expr::col((table::Databases::Table, table::Databases::Id)) @@ -739,6 +790,7 @@ impl SqliteSysDb { .column((table::Databases::Table, table::Databases::TenantId)) .column((table::Databases::Table, table::Databases::Name)) .column((table::Collections::Table, table::Collections::DatabaseId)) + .column((table::Collections::Table, table::Collections::SchemaStr)) .columns([ table::CollectionMetadata::Key, table::CollectionMetadata::StrValue, @@ -788,6 +840,18 @@ impl SqliteSysDb { } None => InternalCollectionConfiguration::default_hnsw(), }; + let schema = match first_row.get::, _>(7) { + Some(json_str) if !json_str.trim().is_empty() && json_str.trim() != "null" => { + match serde_json::from_str::(json_str) + .map_err(GetCollectionsError::Schema) + { + Ok(schema) => Some(schema), + Err(e) => return Some(Err(e)), + } + } + None => None, + _ => None, + }; let database_id = match DatabaseUuid::from_str(first_row.get(6)) { Ok(db_id) => db_id, Err(_) => return Some(Err(GetCollectionsError::DatabaseId)), @@ -796,7 +860,7 @@ impl SqliteSysDb { Some(Ok(Collection { collection_id, config: configuration, - schema: None, + schema, metadata, total_records_post_compaction: 0, version: 0, @@ -1112,7 +1176,7 @@ mod tests { use super::*; use chroma_sqlite::db::test_utils::get_new_sqlite_db; use chroma_types::{ - InternalUpdateCollectionConfiguration, SegmentScope, SegmentType, SegmentUuid, + InternalUpdateCollectionConfiguration, KnnIndex, SegmentScope, SegmentType, SegmentUuid, UpdateHnswConfiguration, UpdateMetadata, UpdateMetadataValue, UpdateVectorIndexConfiguration, VectorIndexConfiguration, }; @@ -1294,7 +1358,8 @@ mod tests { collection_id, "test_collection".to_string(), segments.clone(), - InternalCollectionConfiguration::default_hnsw(), + Some(InternalCollectionConfiguration::default_hnsw()), + Some(Schema::new_default(KnnIndex::Hnsw)), Some(collection_metadata.clone()), None, false, @@ -1337,7 +1402,8 @@ mod tests { collection_id, "test_collection".to_string(), segments.clone(), - InternalCollectionConfiguration::default_hnsw(), + Some(InternalCollectionConfiguration::default_hnsw()), + Some(Schema::new_default(KnnIndex::Hnsw)), None, None, false, @@ -1354,7 +1420,8 @@ mod tests { collection_id, "test_collection".to_string(), segments, - InternalCollectionConfiguration::default_hnsw(), + Some(InternalCollectionConfiguration::default_hnsw()), + Some(Schema::new_default(KnnIndex::Hnsw)), None, None, false, @@ -1384,7 +1451,8 @@ mod tests { collection_id, "test_collection".to_string(), segments.clone(), - InternalCollectionConfiguration::default_hnsw(), + Some(InternalCollectionConfiguration::default_hnsw()), + Some(Schema::new_default(KnnIndex::Hnsw)), None, None, false, @@ -1401,7 +1469,8 @@ mod tests { CollectionUuid::new(), "test_collection".to_string(), vec![], - InternalCollectionConfiguration::default_hnsw(), + Some(InternalCollectionConfiguration::default_hnsw()), + Some(Schema::new_default(KnnIndex::Hnsw)), None, None, true, @@ -1424,7 +1493,8 @@ mod tests { collection_id, "test_collection".to_string(), vec![], - InternalCollectionConfiguration::default_hnsw(), + Some(InternalCollectionConfiguration::default_hnsw()), + None, None, None, false, @@ -1497,7 +1567,8 @@ mod tests { collection_id, "test_collection".to_string(), vec![], - InternalCollectionConfiguration::default_hnsw(), + Some(InternalCollectionConfiguration::default_hnsw()), + Some(Schema::new_default(KnnIndex::Hnsw)), None, None, false, @@ -1578,7 +1649,8 @@ mod tests { collection_id, "test_collection".to_string(), segments.clone(), - InternalCollectionConfiguration::default_hnsw(), + Some(InternalCollectionConfiguration::default_hnsw()), + Some(Schema::new_default(KnnIndex::Hnsw)), Some(collection_metadata.clone()), None, false, @@ -1628,7 +1700,8 @@ mod tests { collection_id, "test_collection".to_string(), segments.clone(), - InternalCollectionConfiguration::default_hnsw(), + Some(InternalCollectionConfiguration::default_hnsw()), + Some(Schema::new_default(KnnIndex::Hnsw)), Some(collection_metadata.clone()), None, false, @@ -1658,7 +1731,8 @@ mod tests { collection_id, "test_collection".to_string(), vec![], - InternalCollectionConfiguration::default_hnsw(), + Some(InternalCollectionConfiguration::default_hnsw()), + Some(Schema::new_default(KnnIndex::Hnsw)), None, None, false, diff --git a/rust/sysdb/src/sysdb.rs b/rust/sysdb/src/sysdb.rs index a035be2df68..15008f6be8e 100644 --- a/rust/sysdb/src/sysdb.rs +++ b/rust/sysdb/src/sysdb.rs @@ -330,7 +330,8 @@ impl SysDb { collection_id, name, segments, - configuration.unwrap_or(InternalCollectionConfiguration::default_hnsw()), + configuration, + schema.clone(), metadata, dimension, get_or_create, diff --git a/rust/types/src/api_types.rs b/rust/types/src/api_types.rs index f407f0ff6d9..b0c9405c503 100644 --- a/rust/types/src/api_types.rs +++ b/rust/types/src/api_types.rs @@ -782,11 +782,13 @@ pub enum GetCollectionsError { #[error(transparent)] Internal(#[from] Box), #[error("Could not deserialize configuration")] - Configuration(#[from] serde_json::Error), + Configuration(#[source] serde_json::Error), #[error("Could not deserialize collection ID")] CollectionId(#[from] uuid::Error), #[error("Could not deserialize database ID")] DatabaseId, + #[error("Could not deserialize schema")] + Schema(#[source] serde_json::Error), } impl ChromaError for GetCollectionsError { @@ -797,6 +799,7 @@ impl ChromaError for GetCollectionsError { GetCollectionsError::Configuration(_) => ErrorCodes::Internal, GetCollectionsError::CollectionId(_) => ErrorCodes::Internal, GetCollectionsError::DatabaseId => ErrorCodes::Internal, + GetCollectionsError::Schema(_) => ErrorCodes::Internal, } } } @@ -913,13 +916,15 @@ pub enum UpdateCollectionError { #[error("Metadata reset unsupported")] MetadataResetUnsupported, #[error("Could not serialize configuration")] - Configuration(#[from] serde_json::Error), + Configuration(#[source] serde_json::Error), #[error(transparent)] Internal(#[from] Box), #[error("Could not parse config: {0}")] InvalidConfig(#[from] CollectionConfigurationToInternalConfigurationError), #[error("SPANN is still in development. Not allowed to created spann indexes")] SpannNotImplemented, + #[error("Could not serialize schema: {0}")] + Schema(#[source] serde_json::Error), } impl ChromaError for UpdateCollectionError { @@ -931,6 +936,7 @@ impl ChromaError for UpdateCollectionError { UpdateCollectionError::Internal(err) => err.code(), UpdateCollectionError::InvalidConfig(_) => ErrorCodes::InvalidArgument, UpdateCollectionError::SpannNotImplemented => ErrorCodes::InvalidArgument, + UpdateCollectionError::Schema(_) => ErrorCodes::Internal, } } } diff --git a/rust/types/src/collection.rs b/rust/types/src/collection.rs index 91a079e5318..4189c0a6abe 100644 --- a/rust/types/src/collection.rs +++ b/rust/types/src/collection.rs @@ -2,8 +2,8 @@ use std::str::FromStr; use super::{Metadata, MetadataValueConversionError}; use crate::{ - chroma_proto, test_segment, CollectionConfiguration, InternalCollectionConfiguration, Schema, - SchemaError, Segment, SegmentScope, UpdateCollectionConfiguration, UpdateMetadata, + chroma_proto, test_segment, CollectionConfiguration, InternalCollectionConfiguration, KnnIndex, + Schema, SchemaError, Segment, SegmentScope, UpdateCollectionConfiguration, UpdateMetadata, }; use chroma_error::{ChromaError, ErrorCodes}; use serde::{Deserialize, Serialize}; @@ -12,7 +12,7 @@ use thiserror::Error; use uuid::Uuid; #[cfg(feature = "pyo3")] -use pyo3::types::PyAnyMethods; +use pyo3::{exceptions::PyValueError, types::PyAnyMethods}; /// CollectionUuid is a wrapper around Uuid to provide a type for the collection id. #[derive( @@ -183,6 +183,24 @@ impl Collection { Ok(res) } + #[getter] + fn schema<'py>( + &self, + py: pyo3::Python<'py>, + ) -> pyo3::PyResult>> { + match self.schema.as_ref() { + Some(schema) => { + let schema_json = serde_json::to_string(schema) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + let res = pyo3::prelude::PyModule::import(py, "json")? + .getattr("loads")? + .call1((schema_json,))?; + Ok(Some(res)) + } + None => Ok(None), + } + } + #[getter] pub fn name(&self) -> &str { &self.name @@ -211,9 +229,12 @@ impl Collection { impl Collection { /// Reconcile the collection schema and configuration, ensuring both are consistent. - pub fn reconcile_schema_with_config(&mut self) -> Result<(), SchemaError> { - let reconciled_schema = - Schema::reconcile_schema_and_config(self.schema.as_ref(), Some(&self.config))?; + pub fn reconcile_schema_with_config(&mut self, knn_index: KnnIndex) -> Result<(), SchemaError> { + let reconciled_schema = Schema::reconcile_schema_and_config( + self.schema.as_ref(), + Some(&self.config), + knn_index, + )?; self.config = InternalCollectionConfiguration::try_from(&reconciled_schema) .map_err(|reason| SchemaError::InvalidSchema { reason })?; diff --git a/rust/types/src/collection_configuration.rs b/rust/types/src/collection_configuration.rs index e31d44b66d5..80a24a19003 100644 --- a/rust/types/src/collection_configuration.rs +++ b/rust/types/src/collection_configuration.rs @@ -292,30 +292,40 @@ impl InternalCollectionConfiguration { (Some(_), Some(_)) => Err(CollectionConfigurationToInternalConfigurationError::MultipleVectorIndexConfigurations), (Some(hnsw), None) => { match default_knn_index { - // Create a spann index. Only inherit the space if it exists in the hnsw config. + // Create a spann index. Only inherit the space if it exists in the hnsw config or legacy metadata. // This is for backwards compatibility so that users who migrate to distributed // from local don't break their code. KnnIndex::Spann => { - let internal_config = if let Some(space) = hnsw.space { - InternalSpannConfiguration { - space, - ..Default::default() - } - } else { - InternalSpannConfiguration::default() + let mut hnsw: InternalHnswConfiguration = hnsw.into(); + let temp_config = InternalCollectionConfiguration { + vector_index: VectorIndexConfiguration::Hnsw(hnsw.clone()), + embedding_function: None, + }; + let hnsw_params = temp_config.get_hnsw_config_from_legacy_metadata(&metadata)?; + if let Some(hnsw_params) = hnsw_params { + hnsw = hnsw_params; + } + let spann_config = InternalSpannConfiguration { + space: hnsw.space, + ..Default::default() }; Ok(InternalCollectionConfiguration { - vector_index: VectorIndexConfiguration::Spann(internal_config), + vector_index: VectorIndexConfiguration::Spann(spann_config), embedding_function: value.embedding_function, }) }, KnnIndex::Hnsw => { let hnsw: InternalHnswConfiguration = hnsw.into(); - Ok(InternalCollectionConfiguration { - vector_index: hnsw.into(), + let mut internal_config = InternalCollectionConfiguration { + vector_index: VectorIndexConfiguration::Hnsw(hnsw), embedding_function: value.embedding_function, - }) + }; + let hnsw_params = internal_config.get_hnsw_config_from_legacy_metadata(&metadata)?; + if let Some(hnsw_params) = hnsw_params { + internal_config.vector_index = VectorIndexConfiguration::Hnsw(hnsw_params); + } + Ok(internal_config) } } } @@ -578,8 +588,10 @@ impl TryFrom for InternalUpdateCollectionConfigur #[cfg(test)] mod tests { + use crate::collection_schema::Schema; use crate::hnsw_configuration::HnswConfiguration; use crate::hnsw_configuration::Space; + use crate::metadata::MetadataValue; use crate::spann_configuration::SpannConfiguration; use crate::{test_segment, CollectionUuid, Metadata}; @@ -633,6 +645,59 @@ mod tests { assert_eq!(overridden_config.ef_construction, 2); } + #[test] + fn metadata_populates_config_when_not_set() { + let mut metadata = Metadata::new(); + metadata.insert("hnsw:sync_threshold".to_string(), MetadataValue::Int(10)); + metadata.insert("hnsw:batch_size".to_string(), MetadataValue::Int(7)); + + let config = InternalCollectionConfiguration::try_from_config( + CollectionConfiguration { + hnsw: None, + spann: None, + embedding_function: None, + }, + KnnIndex::Hnsw, + Some(metadata), + ) + .expect("config from metadata should succeed"); + + match config.vector_index { + VectorIndexConfiguration::Hnsw(hnsw) => { + assert_eq!(hnsw.sync_threshold, 10); + assert_eq!(hnsw.batch_size, 7); + } + _ => panic!("expected HNSW configuration"), + } + } + + #[test] + fn schema_reconcile_preserves_metadata_overrides() { + let mut metadata = Metadata::new(); + metadata.insert("hnsw:sync_threshold".to_string(), MetadataValue::Int(10)); + metadata.insert("hnsw:batch_size".to_string(), MetadataValue::Int(7)); + + let config = InternalCollectionConfiguration::try_from_config( + CollectionConfiguration { + hnsw: None, + spann: None, + embedding_function: None, + }, + KnnIndex::Hnsw, + Some(metadata), + ) + .expect("config from metadata should succeed"); + + let schema = Schema::reconcile_schema_and_config(None, Some(&config), KnnIndex::Hnsw) + .expect("schema reconcile should succeed"); + + let hnsw_config = schema + .get_internal_hnsw_config() + .expect("schema should contain hnsw config"); + assert_eq!(hnsw_config.sync_threshold, 10); + assert_eq!(hnsw_config.batch_size, 7); + } + #[test] fn test_hnsw_config_with_hnsw_default() { let hnsw_config = HnswConfiguration { diff --git a/rust/types/src/collection_schema.rs b/rust/types/src/collection_schema.rs index 651ce60a5f2..e4c8c467aaf 100644 --- a/rust/types/src/collection_schema.rs +++ b/rust/types/src/collection_schema.rs @@ -5,7 +5,8 @@ use thiserror::Error; use validator::Validate; use crate::collection_configuration::{ - EmbeddingFunctionConfiguration, InternalCollectionConfiguration, VectorIndexConfiguration, + EmbeddingFunctionConfiguration, InternalCollectionConfiguration, + UpdateVectorIndexConfiguration, VectorIndexConfiguration, }; use crate::hnsw_configuration::Space; use crate::metadata::{MetadataComparison, MetadataValueType, Where}; @@ -18,7 +19,8 @@ use crate::{ default_search_ef_spann, default_search_nprobe, default_search_rng_epsilon, default_search_rng_factor, default_space, default_split_threshold, default_sync_threshold, default_write_nprobe, default_write_rng_epsilon, default_write_rng_factor, - InternalSpannConfiguration, KnnIndex, + HnswParametersFromSegmentError, InternalHnswConfiguration, InternalSpannConfiguration, + InternalUpdateCollectionConfiguration, KnnIndex, Segment, }; impl ChromaError for SchemaError { @@ -140,6 +142,101 @@ pub struct Schema { pub keys: HashMap, } +impl Schema { + pub fn update(&mut self, configuration: &InternalUpdateCollectionConfiguration) { + if let Some(vector_update) = &configuration.vector_index { + if let Some(default_vector_index) = self.defaults_vector_index_mut() { + Self::apply_vector_index_update(default_vector_index, vector_update); + } + if let Some(embedding_vector_index) = self.embedding_vector_index_mut() { + Self::apply_vector_index_update(embedding_vector_index, vector_update); + } + } + + if let Some(embedding_function) = configuration.embedding_function.as_ref() { + if let Some(default_vector_index) = self.defaults_vector_index_mut() { + default_vector_index.config.embedding_function = Some(embedding_function.clone()); + } + if let Some(embedding_vector_index) = self.embedding_vector_index_mut() { + embedding_vector_index.config.embedding_function = Some(embedding_function.clone()); + } + } + } + + fn defaults_vector_index_mut(&mut self) -> Option<&mut VectorIndexType> { + self.defaults + .float_list + .as_mut() + .and_then(|float_list| float_list.vector_index.as_mut()) + } + + fn embedding_vector_index_mut(&mut self) -> Option<&mut VectorIndexType> { + self.keys + .get_mut(EMBEDDING_KEY) + .and_then(|value_types| value_types.float_list.as_mut()) + .and_then(|float_list| float_list.vector_index.as_mut()) + } + + fn apply_vector_index_update( + vector_index: &mut VectorIndexType, + update: &UpdateVectorIndexConfiguration, + ) { + match update { + UpdateVectorIndexConfiguration::Hnsw(Some(hnsw_update)) => { + if let Some(hnsw_config) = vector_index.config.hnsw.as_mut() { + if let Some(ef_search) = hnsw_update.ef_search { + hnsw_config.ef_search = Some(ef_search); + } + if let Some(max_neighbors) = hnsw_update.max_neighbors { + hnsw_config.max_neighbors = Some(max_neighbors); + } + if let Some(num_threads) = hnsw_update.num_threads { + hnsw_config.num_threads = Some(num_threads); + } + if let Some(resize_factor) = hnsw_update.resize_factor { + hnsw_config.resize_factor = Some(resize_factor); + } + if let Some(sync_threshold) = hnsw_update.sync_threshold { + hnsw_config.sync_threshold = Some(sync_threshold); + } + if let Some(batch_size) = hnsw_update.batch_size { + hnsw_config.batch_size = Some(batch_size); + } + } + } + UpdateVectorIndexConfiguration::Hnsw(None) => {} + UpdateVectorIndexConfiguration::Spann(Some(spann_update)) => { + if let Some(spann_config) = vector_index.config.spann.as_mut() { + if let Some(search_nprobe) = spann_update.search_nprobe { + spann_config.search_nprobe = Some(search_nprobe); + } + if let Some(ef_search) = spann_update.ef_search { + spann_config.ef_search = Some(ef_search); + } + } + } + UpdateVectorIndexConfiguration::Spann(None) => {} + } + } + + pub fn is_sparse_index_enabled(&self) -> bool { + let defaults_enabled = self + .defaults + .sparse_vector + .as_ref() + .and_then(|sv| sv.sparse_vector_index.as_ref()) + .is_some_and(|idx| idx.enabled); + let key_enabled = self.keys.values().any(|value_types| { + value_types + .sparse_vector + .as_ref() + .and_then(|sv| sv.sparse_vector_index.as_ref()) + .is_some_and(|idx| idx.enabled) + }); + defaults_enabled || key_enabled + } +} + impl Default for Schema { /// Create a default Schema that matches Python's behavior exactly. /// @@ -655,28 +752,76 @@ impl Schema { }) } + pub fn get_internal_hnsw_config(&self) -> Option { + let to_internal = |vector_index: &VectorIndexType| { + if vector_index.config.spann.is_some() { + return None; + } + let space = vector_index.config.space.as_ref(); + let hnsw_config = vector_index.config.hnsw.as_ref(); + Some((space, hnsw_config).into()) + }; + + self.keys + .get(EMBEDDING_KEY) + .and_then(|value_types| value_types.float_list.as_ref()) + .and_then(|float_list| float_list.vector_index.as_ref()) + .and_then(to_internal) + .or_else(|| { + self.defaults + .float_list + .as_ref() + .and_then(|float_list| float_list.vector_index.as_ref()) + .and_then(to_internal) + }) + } + + pub fn get_internal_hnsw_config_with_legacy_fallback( + &self, + segment: &Segment, + ) -> Result, HnswParametersFromSegmentError> { + if let Some(config) = self.get_internal_hnsw_config() { + let config_from_metadata = + InternalHnswConfiguration::from_legacy_segment_metadata(&segment.metadata)?; + + if config == InternalHnswConfiguration::default() && config != config_from_metadata { + return Ok(Some(config_from_metadata)); + } + + return Ok(Some(config)); + } + + Ok(None) + } + /// Reconcile user-provided schema with system defaults /// /// This method merges user configurations with system defaults, ensuring that: /// - User overrides take precedence over defaults /// - Missing user configurations fall back to system defaults /// - Field-level merging for complex configurations (Vector, HNSW, SPANN, etc.) - pub fn reconcile_with_defaults(user_schema: Option<&Schema>) -> Result { - let default_schema = Schema::new_default(KnnIndex::Spann); + pub fn reconcile_with_defaults( + user_schema: Option<&Schema>, + knn_index: KnnIndex, + ) -> Result { + let default_schema = Schema::new_default(knn_index); match user_schema { Some(user) => { // Merge defaults with user overrides let merged_defaults = - Self::merge_value_types(&default_schema.defaults, &user.defaults)?; + Self::merge_value_types(&default_schema.defaults, &user.defaults, knn_index)?; // Merge key overrides let mut merged_keys = default_schema.keys.clone(); for (key, user_value_types) in &user.keys { if let Some(default_value_types) = merged_keys.get(key) { // Merge with existing default key override - let merged_value_types = - Self::merge_value_types(default_value_types, user_value_types)?; + let merged_value_types = Self::merge_value_types( + default_value_types, + user_value_types, + knn_index, + )?; merged_keys.insert(key.clone(), merged_value_types); } else { // New key override from user @@ -884,10 +1029,14 @@ impl Schema { fn merge_value_types( default: &ValueTypes, user: &ValueTypes, + knn_index: KnnIndex, ) -> Result { // Merge float_list first - let float_list = - Self::merge_float_list_type(default.float_list.as_ref(), user.float_list.as_ref()); + let float_list = Self::merge_float_list_type( + default.float_list.as_ref(), + user.float_list.as_ref(), + knn_index, + ); // Validate the merged float_list (covers all merge cases) if let Some(ref fl) = float_list { @@ -987,12 +1136,14 @@ impl Schema { fn merge_float_list_type( default: Option<&FloatListValueType>, user: Option<&FloatListValueType>, + knn_index: KnnIndex, ) -> Option { match (default, user) { (Some(default), Some(user)) => Some(FloatListValueType { vector_index: Self::merge_vector_index_type( default.vector_index.as_ref(), user.vector_index.as_ref(), + knn_index, ), }), (Some(default), None) => Some(default.clone()), @@ -1100,11 +1251,12 @@ impl Schema { fn merge_vector_index_type( default: Option<&VectorIndexType>, user: Option<&VectorIndexType>, + knn_index: KnnIndex, ) -> Option { match (default, user) { (Some(default), Some(user)) => Some(VectorIndexType { enabled: user.enabled, - config: Self::merge_vector_index_config(&default.config, &user.config), + config: Self::merge_vector_index_config(&default.config, &user.config, knn_index), }), (Some(default), None) => Some(default.clone()), (None, Some(user)) => Some(user.clone()), @@ -1145,16 +1297,29 @@ impl Schema { fn merge_vector_index_config( default: &VectorIndexConfig, user: &VectorIndexConfig, + knn_index: KnnIndex, ) -> VectorIndexConfig { - VectorIndexConfig { - space: user.space.clone().or(default.space.clone()), - embedding_function: user - .embedding_function - .clone() - .or(default.embedding_function.clone()), - source_key: user.source_key.clone().or(default.source_key.clone()), - hnsw: Self::merge_hnsw_configs(default.hnsw.as_ref(), user.hnsw.as_ref()), - spann: Self::merge_spann_configs(default.spann.as_ref(), user.spann.as_ref()), + match knn_index { + KnnIndex::Hnsw => VectorIndexConfig { + space: user.space.clone().or(default.space.clone()), + embedding_function: user + .embedding_function + .clone() + .or(default.embedding_function.clone()), + source_key: user.source_key.clone().or(default.source_key.clone()), + hnsw: Self::merge_hnsw_configs(default.hnsw.as_ref(), user.hnsw.as_ref()), + spann: None, + }, + KnnIndex::Spann => VectorIndexConfig { + space: user.space.clone().or(default.space.clone()), + embedding_function: user + .embedding_function + .clone() + .or(default.embedding_function.clone()), + source_key: user.source_key.clone().or(default.source_key.clone()), + hnsw: None, + spann: Self::merge_spann_configs(default.spann.as_ref(), user.spann.as_ref()), + }, } } @@ -1262,6 +1427,7 @@ impl Schema { pub fn reconcile_schema_and_config( schema: Option<&Schema>, configuration: Option<&InternalCollectionConfiguration>, + knn_index: KnnIndex, ) -> Result { // Early validation: check if both user-provided schema and config are non-default if let (Some(user_schema), Some(config)) = (schema, configuration) { @@ -1270,7 +1436,7 @@ impl Schema { } } - let reconciled_schema = Self::reconcile_with_defaults(schema)?; + let reconciled_schema = Self::reconcile_with_defaults(schema, knn_index)?; if let Some(config) = configuration { Self::reconcile_with_collection_config(&reconciled_schema, config) } else { @@ -2502,7 +2668,7 @@ mod tests { #[test] fn test_reconcile_with_defaults_none_user_schema() { // Test that when no user schema is provided, we get the default schema - let result = Schema::reconcile_with_defaults(None).unwrap(); + let result = Schema::reconcile_with_defaults(None, KnnIndex::Spann).unwrap(); let expected = Schema::new_default(KnnIndex::Spann); assert_eq!(result, expected); } @@ -2515,7 +2681,7 @@ mod tests { keys: HashMap::new(), }; - let result = Schema::reconcile_with_defaults(Some(&user_schema)).unwrap(); + let result = Schema::reconcile_with_defaults(Some(&user_schema), KnnIndex::Spann).unwrap(); let expected = Schema::new_default(KnnIndex::Spann); assert_eq!(result, expected); } @@ -2536,7 +2702,7 @@ mod tests { fts_index: None, }); - let result = Schema::reconcile_with_defaults(Some(&user_schema)).unwrap(); + let result = Schema::reconcile_with_defaults(Some(&user_schema), KnnIndex::Spann).unwrap(); // Check that the user override took precedence assert!( @@ -2587,13 +2753,21 @@ mod tests { // Use HNSW defaults for this test so we have HNSW config to merge with let result = { let default_schema = Schema::new_default(KnnIndex::Hnsw); - let merged_defaults = - Schema::merge_value_types(&default_schema.defaults, &user_schema.defaults).unwrap(); + let merged_defaults = Schema::merge_value_types( + &default_schema.defaults, + &user_schema.defaults, + KnnIndex::Hnsw, + ) + .unwrap(); let mut merged_keys = default_schema.keys.clone(); for (key, user_value_types) in user_schema.keys { if let Some(default_value_types) = merged_keys.get(&key) { - let merged_value_types = - Schema::merge_value_types(default_value_types, &user_value_types).unwrap(); + let merged_value_types = Schema::merge_value_types( + default_value_types, + &user_value_types, + KnnIndex::Hnsw, + ) + .unwrap(); merged_keys.insert(key, merged_value_types); } else { merged_keys.insert(key, user_value_types); @@ -2658,7 +2832,7 @@ mod tests { .keys .insert("custom_key".to_string(), custom_key_types); - let result = Schema::reconcile_with_defaults(Some(&user_schema)).unwrap(); + let result = Schema::reconcile_with_defaults(Some(&user_schema), KnnIndex::Spann).unwrap(); // Check that default key overrides are preserved assert!(result.keys.contains_key(EMBEDDING_KEY)); @@ -2707,7 +2881,7 @@ mod tests { .keys .insert(EMBEDDING_KEY.to_string(), embedding_override); - let result = Schema::reconcile_with_defaults(Some(&user_schema)).unwrap(); + let result = Schema::reconcile_with_defaults(Some(&user_schema), KnnIndex::Spann).unwrap(); let embedding_config = result.keys.get(EMBEDDING_KEY).unwrap(); let vector_config = &embedding_config @@ -3155,7 +3329,8 @@ mod tests { }), // Add SPANN config }; - let result = Schema::merge_vector_index_config(&default_config, &user_config); + let result = + Schema::merge_vector_index_config(&default_config, &user_config, KnnIndex::Hnsw); // Check field-level merging assert_eq!(result.space, Some(Space::L2)); // User override @@ -3169,9 +3344,8 @@ mod tests { assert_eq!(result.hnsw.as_ref().unwrap().ef_construction, Some(300)); // User override assert_eq!(result.hnsw.as_ref().unwrap().max_neighbors, Some(16)); // Default preserved - // Check SPANN was added from user - assert!(result.spann.is_some()); - assert_eq!(result.spann.as_ref().unwrap().search_nprobe, Some(15)); + // Check SPANN is not present, since merging in the context of HNSW + assert!(result.spann.is_none()); } #[test] @@ -3259,13 +3433,21 @@ mod tests { // Use HNSW defaults for this test so we have HNSW config to merge with let result = { let default_schema = Schema::new_default(KnnIndex::Hnsw); - let merged_defaults = - Schema::merge_value_types(&default_schema.defaults, &user_schema.defaults).unwrap(); + let merged_defaults = Schema::merge_value_types( + &default_schema.defaults, + &user_schema.defaults, + KnnIndex::Hnsw, + ) + .unwrap(); let mut merged_keys = default_schema.keys.clone(); for (key, user_value_types) in user_schema.keys { if let Some(default_value_types) = merged_keys.get(&key) { - let merged_value_types = - Schema::merge_value_types(default_value_types, &user_value_types).unwrap(); + let merged_value_types = Schema::merge_value_types( + default_value_types, + &user_value_types, + KnnIndex::Hnsw, + ) + .unwrap(); merged_keys.insert(key, merged_value_types); } else { merged_keys.insert(key, user_value_types); @@ -3382,7 +3564,11 @@ mod tests { } // Use reconcile_schema_and_config which has the early validation - let result = Schema::reconcile_schema_and_config(Some(&schema), Some(&collection_config)); + let result = Schema::reconcile_schema_and_config( + Some(&schema), + Some(&collection_config), + KnnIndex::Spann, + ); assert!(result.is_err()); assert!(matches!( result.unwrap_err(), diff --git a/rust/worker/src/execution/orchestration/knn_filter.rs b/rust/worker/src/execution/orchestration/knn_filter.rs index 63860210a00..7afa4f349d7 100644 --- a/rust/worker/src/execution/orchestration/knn_filter.rs +++ b/rust/worker/src/execution/orchestration/knn_filter.rs @@ -375,10 +375,18 @@ impl Handler> for KnnFilterOrchestrator { .ok_or_terminate( self.collection_and_segments .collection - .config - .get_hnsw_config_with_legacy_fallback( - &self.collection_and_segments.vector_segment, - ), + .schema + .as_ref() + .ok_or(KnnError::InvalidSchema(SchemaError::InvalidSchema { + reason: "Schema is None".to_string(), + })) + .and_then(|schema| { + schema + .get_internal_hnsw_config_with_legacy_fallback( + &self.collection_and_segments.vector_segment, + ) + .map_err(KnnError::from) + }), ctx, ) .await