Skip to content

Commit a43db01

Browse files
committed
[ENH] recognize and flush new metadata keys to schema on local compaction
1 parent df8c8c8 commit a43db01

File tree

4 files changed

+150
-22
lines changed

4 files changed

+150
-22
lines changed

chromadb/test/api/test_schema_e2e.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -363,7 +363,8 @@ def test_schema_defaults_enable_indexed_operations(
363363
# Ensure underlying schema persisted across fetches
364364
reloaded = client.get_collection(collection.name)
365365
assert reloaded.schema is not None
366-
assert reloaded.schema.serialize_to_json() == schema.serialize_to_json()
366+
if not is_spann_disabled_mode:
367+
assert reloaded.schema.serialize_to_json() == schema.serialize_to_json()
367368

368369

369370
def test_get_or_create_and_get_collection_preserve_schema(
@@ -541,7 +542,8 @@ def test_schema_persistence_with_custom_overrides(
541542
reloaded_client = client_factories.create_client_from_system()
542543
reloaded_collection = reloaded_client.get_collection(name=collection.name)
543544
assert reloaded_collection.schema is not None
544-
assert reloaded_collection.schema.serialize_to_json() == expected_schema_json
545+
if not is_spann_disabled_mode:
546+
assert reloaded_collection.schema.serialize_to_json() == expected_schema_json
545547

546548
fetched = reloaded_collection.get(where={"title": "Schema Persistence"})
547549
assert set(fetched["ids"]) == {"persist-1"}
@@ -784,7 +786,6 @@ def _expect_disabled_error(operation: Callable[[], Any]) -> None:
784786
_expect_disabled_error(operation)
785787

786788

787-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
788789
def test_schema_discovers_new_keys_after_compaction(
789790
client_factories: "ClientFactories",
790791
) -> None:
@@ -802,7 +803,8 @@ def test_schema_discovers_new_keys_after_compaction(
802803

803804
collection.add(ids=ids, documents=documents, metadatas=metadatas)
804805

805-
wait_for_version_increase(client, collection.name, initial_version)
806+
if not is_spann_disabled_mode:
807+
wait_for_version_increase(client, collection.name, initial_version)
806808

807809
reloaded = client.get_collection(collection.name)
808810
assert reloaded.schema is not None
@@ -828,7 +830,8 @@ def test_schema_discovers_new_keys_after_compaction(
828830
metadatas=upsert_metadatas,
829831
)
830832

831-
wait_for_version_increase(client, collection.name, next_version)
833+
if not is_spann_disabled_mode:
834+
wait_for_version_increase(client, collection.name, next_version)
832835

833836
post_upsert = client.get_collection(collection.name)
834837
assert post_upsert.schema is not None
@@ -852,7 +855,6 @@ def test_schema_discovers_new_keys_after_compaction(
852855
assert "discover_upsert" in persisted.schema.keys
853856

854857

855-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
856858
def test_schema_rejects_conflicting_discoverable_key_types(
857859
client_factories: "ClientFactories",
858860
) -> None:
@@ -868,7 +870,8 @@ def test_schema_rejects_conflicting_discoverable_key_types(
868870
documents = [f"doc {i}" for i in range(251)]
869871
collection.add(ids=ids, documents=documents, metadatas=metadatas)
870872

871-
wait_for_version_increase(client, collection.name, initial_version)
873+
if not is_spann_disabled_mode:
874+
wait_for_version_increase(client, collection.name, initial_version)
872875

873876
collection.upsert(
874877
ids=["conflict-bad"],
@@ -1029,7 +1032,6 @@ def test_schema_embedding_configuration_enforced(
10291032
assert "sparse_auto" not in numeric_metadata
10301033

10311034

1032-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
10331035
def test_schema_precedence_for_overrides_discoverables_and_defaults(
10341036
client_factories: "ClientFactories",
10351037
) -> None:
@@ -1054,7 +1056,9 @@ def test_schema_precedence_for_overrides_discoverables_and_defaults(
10541056

10551057
initial_version = get_collection_version(client, collection.name)
10561058
collection.add(ids=ids, documents=documents, metadatas=metadatas)
1057-
wait_for_version_increase(client, collection.name, initial_version)
1059+
1060+
if not is_spann_disabled_mode:
1061+
wait_for_version_increase(client, collection.name, initial_version)
10581062

10591063
schema_state = client.get_collection(collection.name).schema
10601064
assert schema_state is not None

rust/log/src/local_compaction_manager.rs

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,7 @@ impl Handler<BackfillMessage> for LocalCompactionManager {
140140
.sysdb
141141
.get_collection_with_segments(message.collection_id)
142142
.await?;
143+
let schema_previously_persisted = collection_and_segments.collection.schema.is_some();
143144
collection_and_segments
144145
.collection
145146
.reconcile_schema_with_config(KnnIndex::Hnsw)?;
@@ -206,17 +207,33 @@ impl Handler<BackfillMessage> for LocalCompactionManager {
206207
.begin()
207208
.await
208209
.map_err(|_| CompactionManagerError::MetadataApplyLogsFailed)?;
209-
metadata_writer
210+
let apply_outcome = metadata_writer
210211
.apply_logs(
211212
mt_data_chunk,
212213
collection_and_segments.metadata_segment.id,
214+
if schema_previously_persisted {
215+
collection_and_segments.collection.schema.clone()
216+
} else {
217+
None
218+
},
213219
&mut *tx,
214220
)
215221
.await
216222
.map_err(|_| CompactionManagerError::MetadataApplyLogsFailed)?;
217223
tx.commit()
218224
.await
219225
.map_err(|_| CompactionManagerError::MetadataApplyLogsFailed)?;
226+
if schema_previously_persisted {
227+
if let Some(updated_schema) = apply_outcome.schema_update {
228+
metadata_writer
229+
.update_collection_schema(
230+
collection_and_segments.collection.collection_id,
231+
&updated_schema,
232+
)
233+
.await
234+
.map_err(|_| CompactionManagerError::MetadataApplyLogsFailed)?;
235+
}
236+
}
220237
// Next apply it to the hnsw writer.
221238
let mut hnsw_writer = self
222239
.hnsw_segment_manager

rust/segment/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@ roaring = { workspace = true }
1111
sea-query = { workspace = true }
1212
sea-query-binder = { workspace = true, features = ["sqlx-sqlite"] }
1313
serde = { workspace = true }
14+
serde_json = { workspace = true }
1415
sqlx = { workspace = true }
1516
serde-pickle = "1.2.0"
1617
tantivy = { workspace = true }

rust/segment/src/sqlite_metadata.rs

Lines changed: 118 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -7,17 +7,17 @@ use chroma_error::{ChromaError, ErrorCodes};
77
use chroma_sqlite::{
88
db::SqliteDb,
99
helpers::{delete_metadata, update_metadata},
10-
table::{EmbeddingFulltextSearch, EmbeddingMetadata, Embeddings, MaxSeqId},
10+
table::{Collections, EmbeddingFulltextSearch, EmbeddingMetadata, Embeddings, MaxSeqId},
1111
};
1212
use chroma_types::{
1313
operator::{
1414
CountResult, Filter, GetResult, Limit, Projection, ProjectionOutput, ProjectionRecord, Scan,
1515
},
1616
plan::{Count, Get},
17-
BooleanOperator, Chunk, CompositeExpression, DocumentExpression, DocumentOperator, LogRecord,
18-
MetadataComparison, MetadataExpression, MetadataSetValue, MetadataValue,
19-
MetadataValueConversionError, Operation, OperationRecord, PrimitiveOperator, SegmentUuid,
20-
SetOperator, UpdateMetadataValue, Where, CHROMA_DOCUMENT_KEY,
17+
BooleanOperator, Chunk, CollectionUuid, CompositeExpression, DocumentExpression,
18+
DocumentOperator, LogRecord, MetadataComparison, MetadataExpression, MetadataSetValue,
19+
MetadataValue, MetadataValueConversionError, Operation, OperationRecord, PrimitiveOperator,
20+
Schema, SegmentUuid, SetOperator, UpdateMetadataValue, Where, CHROMA_DOCUMENT_KEY,
2121
};
2222
use sea_query::{
2323
Alias, DeleteStatement, Expr, ExprTrait, Func, InsertStatement, LikeExpr, OnConflict, Query,
@@ -41,6 +41,8 @@ pub enum SqliteMetadataError {
4141
SeaQuery(#[from] sea_query::error::Error),
4242
#[error(transparent)]
4343
Sqlx(#[from] sqlx::Error),
44+
#[error("Could not serialize schema: {0}")]
45+
SerializeSchema(#[from] serde_json::Error),
4446
}
4547

4648
impl ChromaError for SqliteMetadataError {
@@ -53,6 +55,11 @@ pub struct SqliteMetadataWriter {
5355
pub db: SqliteDb,
5456
}
5557

58+
pub struct ApplyLogsOutcome {
59+
pub schema_update: Option<Schema>,
60+
pub max_seq_id: Option<u64>,
61+
}
62+
5663
impl SqliteMetadataWriter {
5764
pub fn new(db: SqliteDb) -> Self {
5865
Self { db }
@@ -278,19 +285,67 @@ impl SqliteMetadataWriter {
278285
Ok(self.db.get_conn().begin().await?)
279286
}
280287

288+
pub async fn update_collection_schema(
289+
&self,
290+
collection_id: CollectionUuid,
291+
schema: &Schema,
292+
) -> Result<(), SqliteMetadataError> {
293+
let schema_str = serde_json::to_string(schema)?;
294+
let (sql, values) = Query::update()
295+
.table(Collections::Table)
296+
.value(Collections::SchemaStr, schema_str)
297+
.and_where(
298+
Expr::col((Collections::Table, Collections::Id)).eq(collection_id.to_string()),
299+
)
300+
.build_sqlx(SqliteQueryBuilder);
301+
sqlx::query_with(&sql, values)
302+
.execute(self.db.get_conn())
303+
.await?;
304+
Ok(())
305+
}
306+
307+
fn ensure_schema_for_update_value(
308+
schema: &mut Option<Schema>,
309+
key: &str,
310+
value: &UpdateMetadataValue,
311+
) -> bool {
312+
if key == CHROMA_DOCUMENT_KEY {
313+
return false;
314+
}
315+
match value {
316+
UpdateMetadataValue::None => false,
317+
_ => {
318+
if let Some(schema_mut) = schema.as_mut() {
319+
if let Ok(metadata_value) = MetadataValue::try_from(value) {
320+
return schema_mut
321+
.ensure_key_from_metadata(key, metadata_value.value_type());
322+
}
323+
}
324+
false
325+
}
326+
}
327+
}
328+
281329
pub async fn apply_logs<C>(
282330
&self,
283331
logs: Chunk<LogRecord>,
284332
segment_id: SegmentUuid,
333+
schema: Option<Schema>,
285334
tx: &mut C,
286-
) -> Result<(), SqliteMetadataError>
335+
) -> Result<ApplyLogsOutcome, SqliteMetadataError>
287336
where
288337
for<'connection> &'connection mut C: sqlx::Executor<'connection, Database = sqlx::Sqlite>,
289338
{
290339
if logs.is_empty() {
291-
return Ok(());
340+
return Ok(ApplyLogsOutcome {
341+
schema_update: None,
342+
max_seq_id: None,
343+
});
292344
}
345+
let mut schema = schema;
346+
let mut schema_modified = false;
293347
let mut max_seq_id = u64::MIN;
348+
let mut saw_log = false;
294349
for (
295350
LogRecord {
296351
log_offset,
@@ -307,6 +362,7 @@ impl SqliteMetadataWriter {
307362
) in logs.iter()
308363
{
309364
let log_offset_unsigned = (*log_offset).try_into()?;
365+
saw_log = true;
310366
max_seq_id = max_seq_id.max(log_offset_unsigned);
311367
let mut metadata_owned = metadata.clone();
312368
if let Some(doc) = document {
@@ -323,6 +379,11 @@ impl SqliteMetadataWriter {
323379
Self::add_record(tx, segment_id, log_offset_unsigned, id.clone()).await?
324380
{
325381
if let Some(meta) = metadata_owned {
382+
for (key, value) in meta.iter() {
383+
if Self::ensure_schema_for_update_value(&mut schema, key, value) {
384+
schema_modified = true;
385+
}
386+
}
326387
update_metadata::<EmbeddingMetadata, _, _>(tx, offset_id, meta).await?;
327388
}
328389

@@ -336,6 +397,11 @@ impl SqliteMetadataWriter {
336397
Self::update_record(tx, segment_id, log_offset_unsigned, id.clone()).await?
337398
{
338399
if let Some(meta) = metadata_owned {
400+
for (key, value) in meta.iter() {
401+
if Self::ensure_schema_for_update_value(&mut schema, key, value) {
402+
schema_modified = true;
403+
}
404+
}
339405
update_metadata::<EmbeddingMetadata, _, _>(tx, offset_id, meta).await?;
340406
}
341407

@@ -351,6 +417,11 @@ impl SqliteMetadataWriter {
351417
.await?;
352418

353419
if let Some(meta) = metadata_owned {
420+
for (key, value) in meta.iter() {
421+
if Self::ensure_schema_for_update_value(&mut schema, key, value) {
422+
schema_modified = true;
423+
}
424+
}
354425
update_metadata::<EmbeddingMetadata, _, _>(tx, offset_id, meta).await?;
355426
}
356427

@@ -371,7 +442,12 @@ impl SqliteMetadataWriter {
371442

372443
Self::upsert_max_seq_id(tx, segment_id, max_seq_id).await?;
373444

374-
Ok(())
445+
let max_seq_id = if saw_log { Some(max_seq_id) } else { None };
446+
447+
Ok(ApplyLogsOutcome {
448+
schema_update: if schema_modified { schema } else { None },
449+
max_seq_id,
450+
})
375451
}
376452
}
377453

@@ -910,7 +986,17 @@ mod tests {
910986
ref_seg.apply_logs(test_data.logs.clone(), metadata_seg_id);
911987
let mut tx = runtime.block_on(sqlite_seg_writer.begin()).expect("Should be able to start transaction");
912988
let data: Chunk<LogRecord> = Chunk::new(test_data.logs.clone().into());
913-
runtime.block_on(sqlite_seg_writer.apply_logs(data, metadata_seg_id, &mut *tx)).expect("Should be able to apply logs");
989+
runtime.block_on(sqlite_seg_writer.apply_logs(
990+
data,
991+
metadata_seg_id,
992+
test_data
993+
.collection_and_segments
994+
.collection
995+
.schema
996+
.clone(),
997+
&mut *tx,
998+
))
999+
.expect("Should be able to apply logs");
9141000
runtime.block_on(tx.commit()).expect("Should be able to commit log");
9151001

9161002
let sqlite_seg_reader = SqliteMetadataReader {
@@ -938,7 +1024,17 @@ mod tests {
9381024
ref_seg.apply_logs(test_data.logs.clone(), metadata_seg_id);
9391025
let mut tx = runtime.block_on(sqlite_seg_writer.begin()).expect("Should be able to start transaction");
9401026
let data: Chunk<LogRecord> = Chunk::new(test_data.logs.clone().into());
941-
runtime.block_on(sqlite_seg_writer.apply_logs(data, metadata_seg_id, &mut *tx)).expect("Should be able to apply logs");
1027+
runtime.block_on(sqlite_seg_writer.apply_logs(
1028+
data,
1029+
metadata_seg_id,
1030+
test_data
1031+
.collection_and_segments
1032+
.collection
1033+
.schema
1034+
.clone(),
1035+
&mut *tx,
1036+
))
1037+
.expect("Should be able to apply logs");
9421038
runtime.block_on(tx.commit()).expect("Should be able to commit log");
9431039

9441040
let sqlite_seg_reader = SqliteMetadataReader {
@@ -1020,7 +1116,12 @@ mod tests {
10201116
.expect("Should be able to start transaction");
10211117
let data: Chunk<LogRecord> = Chunk::new(logs.into());
10221118
sqlite_seg_writer
1023-
.apply_logs(data, metadata_seg_id, &mut *tx)
1119+
.apply_logs(
1120+
data,
1121+
metadata_seg_id,
1122+
collection_and_segments.collection.schema.clone(),
1123+
&mut *tx,
1124+
)
10241125
.await
10251126
.expect("Should be able to apply logs");
10261127
tx.commit().await.expect("Should be able to commit log");
@@ -1140,7 +1241,12 @@ mod tests {
11401241
.expect("Should be able to start transaction");
11411242
let data: Chunk<LogRecord> = Chunk::new(logs.into());
11421243
sqlite_seg_writer
1143-
.apply_logs(data, metadata_seg_id, &mut *tx)
1244+
.apply_logs(
1245+
data,
1246+
metadata_seg_id,
1247+
collection_and_segments.collection.schema.clone(),
1248+
&mut *tx,
1249+
)
11441250
.await
11451251
.expect("Should be able to apply logs");
11461252
tx.commit().await.expect("Should be able to commit log");

0 commit comments

Comments
 (0)