Skip to content

Commit 4170cd0

Browse files
committed
[ENH] Add local support for schema
1 parent a1642ad commit 4170cd0

File tree

9 files changed

+73
-18
lines changed

9 files changed

+73
-18
lines changed

chromadb/api/rust.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def list_collections(
193193
CollectionModel(
194194
id=collection.id,
195195
name=collection.name,
196-
serialized_schema=None,
196+
serialized_schema=collection.schema,
197197
configuration_json=collection.configuration,
198198
metadata=collection.metadata,
199199
dimension=collection.dimension,
@@ -229,14 +229,25 @@ def create_collection(
229229
else:
230230
configuration_json_str = None
231231

232+
if schema:
233+
schema_str = schema.serialize_to_json()
234+
else:
235+
schema_str = None
236+
232237
collection = self.bindings.create_collection(
233-
name, configuration_json_str, metadata, get_or_create, tenant, database
238+
name,
239+
configuration_json_str,
240+
schema_str,
241+
metadata,
242+
get_or_create,
243+
tenant,
244+
database,
234245
)
235246
collection_model = CollectionModel(
236247
id=collection.id,
237248
name=collection.name,
238249
configuration_json=collection.configuration,
239-
serialized_schema=None,
250+
serialized_schema=collection.schema,
240251
metadata=collection.metadata,
241252
dimension=collection.dimension,
242253
tenant=collection.tenant,
@@ -256,7 +267,7 @@ def get_collection(
256267
id=collection.id,
257268
name=collection.name,
258269
configuration_json=collection.configuration,
259-
serialized_schema=None,
270+
serialized_schema=collection.schema,
260271
metadata=collection.metadata,
261272
dimension=collection.dimension,
262273
tenant=collection.tenant,

chromadb/chromadb_rust_bindings.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class Bindings:
9595
self,
9696
name: str,
9797
configuration_json_str: Optional[str] = None,
98+
schema_str: Optional[str] = None,
9899
metadata: Optional[CollectionMetadata] = None,
99100
get_or_create: bool = False,
100101
tenant: str = DEFAULT_TENANT,

rust/frontend/src/config.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ fn default_enable_span_indexing() -> bool {
142142
}
143143

144144
fn default_enable_schema() -> bool {
145-
false
145+
true
146146
}
147147

148148
pub fn default_min_records_for_task() -> u64 {

rust/python_bindings/src/bindings.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -252,22 +252,22 @@ impl Bindings {
252252

253253
#[allow(clippy::too_many_arguments)]
254254
#[pyo3(
255-
signature = (name, configuration_json_str, metadata = None, get_or_create = false, tenant = DEFAULT_TENANT.to_string(), database = DEFAULT_DATABASE.to_string())
255+
signature = (name, configuration_json_str = None, schema_str = None, metadata = None, get_or_create = false, tenant = DEFAULT_TENANT.to_string(), database = DEFAULT_DATABASE.to_string())
256256
)]
257257
fn create_collection(
258258
&self,
259259
name: String,
260260
configuration_json_str: Option<String>,
261+
schema_str: Option<String>,
261262
metadata: Option<Metadata>,
262263
get_or_create: bool,
263264
tenant: String,
264265
database: String,
265266
) -> ChromaPyResult<Collection> {
266267
let configuration_json = match configuration_json_str {
267268
Some(configuration_json_str) => {
268-
let configuration_json =
269-
serde_json::from_str::<CollectionConfiguration>(&configuration_json_str)
270-
.map_err(WrappedSerdeJsonError::SerdeJsonError)?;
269+
let configuration_json = serde_json::from_str(&configuration_json_str)
270+
.map_err(WrappedSerdeJsonError::SerdeJsonError)?;
271271

272272
Some(configuration_json)
273273
}
@@ -291,13 +291,20 @@ impl Bindings {
291291
)?),
292292
};
293293

294+
let schema = match schema_str {
295+
Some(schema_str) => {
296+
serde_json::from_str(&schema_str).map_err(WrappedSerdeJsonError::SerdeJsonError)?
297+
}
298+
None => None,
299+
};
300+
294301
let request = CreateCollectionRequest::try_new(
295302
tenant,
296303
database,
297304
name,
298305
metadata,
299306
configuration,
300-
None,
307+
schema,
301308
get_or_create,
302309
)?;
303310

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
-- Stores collection schema as stringified json
2+
ALTER TABLE collections ADD COLUMN schema_str TEXT;

rust/sqlite/src/table.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ pub enum Collections {
4343
Dimension,
4444
DatabaseId,
4545
ConfigJsonStr,
46+
SchemaStr,
4647
}
4748

4849
#[derive(Iden)]

rust/sysdb/src/sqlite.rs

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ use chroma_types::{
1313
DeleteDatabaseError, DeleteDatabaseResponse, GetCollectionWithSegmentsError,
1414
GetCollectionsError, GetDatabaseError, GetSegmentsError, GetTenantError, GetTenantResponse,
1515
InternalCollectionConfiguration, InternalUpdateCollectionConfiguration, ListDatabasesError,
16-
Metadata, MetadataValue, ResetError, ResetResponse, Segment, SegmentScope, SegmentType,
17-
SegmentUuid, UpdateCollectionError, UpdateTenantError, UpdateTenantResponse,
16+
Metadata, MetadataValue, ResetError, ResetResponse, Schema, SchemaError, Segment, SegmentScope,
17+
SegmentType, SegmentUuid, UpdateCollectionError, UpdateTenantError, UpdateTenantResponse,
1818
};
1919
use futures::TryStreamExt;
2020
use sea_query_binder::SqlxBinder;
@@ -251,6 +251,7 @@ impl SqliteSysDb {
251251
name: String,
252252
segments: Vec<Segment>,
253253
configuration: InternalCollectionConfiguration,
254+
schema: Option<Schema>,
254255
metadata: Option<Metadata>,
255256
dimension: Option<i32>,
256257
get_or_create: bool,
@@ -307,13 +308,18 @@ impl SqliteSysDb {
307308
sqlx::query(
308309
r#"
309310
INSERT INTO collections
310-
(id, name, config_json_str, dimension, database_id)
311-
VALUES ($1, $2, $3, $4, $5)
311+
(id, name, config_json_str, schema_str, dimension, database_id)
312+
VALUES ($1, $2, $3, $4, $5, $6)
312313
"#,
313314
)
314315
.bind(collection_id.to_string())
315316
.bind(&name)
316317
.bind(serde_json::to_string(&configuration).map_err(CreateCollectionError::Configuration)?)
318+
.bind(serde_json::to_string(&schema).map_err(|e| {
319+
CreateCollectionError::Schema(SchemaError::InvalidSchema {
320+
reason: e.to_string(),
321+
})
322+
})?)
317323
.bind(dimension)
318324
.bind(database_id)
319325
.execute(&mut *tx)
@@ -347,7 +353,7 @@ impl SqliteSysDb {
347353
database,
348354
config: configuration,
349355
metadata,
350-
schema: None,
356+
schema,
351357
dimension,
352358
log_position: 0,
353359
total_records_post_compaction: 0,
@@ -685,6 +691,7 @@ impl SqliteSysDb {
685691
.column((table::Collections::Table, table::Collections::ConfigJsonStr))
686692
.column((table::Collections::Table, table::Collections::Dimension))
687693
.column((table::Collections::Table, table::Collections::DatabaseId))
694+
.column((table::Collections::Table, table::Collections::SchemaStr))
688695
.inner_join(
689696
table::Databases::Table,
690697
sea_query::Expr::col((table::Databases::Table, table::Databases::Id))
@@ -739,6 +746,7 @@ impl SqliteSysDb {
739746
.column((table::Databases::Table, table::Databases::TenantId))
740747
.column((table::Databases::Table, table::Databases::Name))
741748
.column((table::Collections::Table, table::Collections::DatabaseId))
749+
.column((table::Collections::Table, table::Collections::SchemaStr))
742750
.columns([
743751
table::CollectionMetadata::Key,
744752
table::CollectionMetadata::StrValue,
@@ -788,6 +796,17 @@ impl SqliteSysDb {
788796
}
789797
None => InternalCollectionConfiguration::default_hnsw(),
790798
};
799+
let schema = match first_row.get::<Option<&str>, _>(7) {
800+
Some(json_str) => {
801+
match serde_json::from_str::<Schema>(json_str)
802+
.map_err(GetCollectionsError::Schema)
803+
{
804+
Ok(schema) => Some(schema),
805+
Err(e) => return Some(Err(e)),
806+
}
807+
}
808+
None => None,
809+
};
791810
let database_id = match DatabaseUuid::from_str(first_row.get(6)) {
792811
Ok(db_id) => db_id,
793812
Err(_) => return Some(Err(GetCollectionsError::DatabaseId)),
@@ -796,7 +815,7 @@ impl SqliteSysDb {
796815
Some(Ok(Collection {
797816
collection_id,
798817
config: configuration,
799-
schema: None,
818+
schema,
800819
metadata,
801820
total_records_post_compaction: 0,
802821
version: 0,
@@ -1112,7 +1131,7 @@ mod tests {
11121131
use super::*;
11131132
use chroma_sqlite::db::test_utils::get_new_sqlite_db;
11141133
use chroma_types::{
1115-
InternalUpdateCollectionConfiguration, SegmentScope, SegmentType, SegmentUuid,
1134+
InternalUpdateCollectionConfiguration, KnnIndex, SegmentScope, SegmentType, SegmentUuid,
11161135
UpdateHnswConfiguration, UpdateMetadata, UpdateMetadataValue,
11171136
UpdateVectorIndexConfiguration, VectorIndexConfiguration,
11181137
};
@@ -1295,6 +1314,7 @@ mod tests {
12951314
"test_collection".to_string(),
12961315
segments.clone(),
12971316
InternalCollectionConfiguration::default_hnsw(),
1317+
Some(Schema::new_default(KnnIndex::Hnsw)),
12981318
Some(collection_metadata.clone()),
12991319
None,
13001320
false,
@@ -1338,6 +1358,7 @@ mod tests {
13381358
"test_collection".to_string(),
13391359
segments.clone(),
13401360
InternalCollectionConfiguration::default_hnsw(),
1361+
Some(Schema::new_default(KnnIndex::Hnsw)),
13411362
None,
13421363
None,
13431364
false,
@@ -1355,6 +1376,7 @@ mod tests {
13551376
"test_collection".to_string(),
13561377
segments,
13571378
InternalCollectionConfiguration::default_hnsw(),
1379+
Some(Schema::new_default(KnnIndex::Hnsw)),
13581380
None,
13591381
None,
13601382
false,
@@ -1385,6 +1407,7 @@ mod tests {
13851407
"test_collection".to_string(),
13861408
segments.clone(),
13871409
InternalCollectionConfiguration::default_hnsw(),
1410+
Some(Schema::new_default(KnnIndex::Hnsw)),
13881411
None,
13891412
None,
13901413
false,
@@ -1402,6 +1425,7 @@ mod tests {
14021425
"test_collection".to_string(),
14031426
vec![],
14041427
InternalCollectionConfiguration::default_hnsw(),
1428+
Some(Schema::new_default(KnnIndex::Hnsw)),
14051429
None,
14061430
None,
14071431
true,
@@ -1425,6 +1449,7 @@ mod tests {
14251449
"test_collection".to_string(),
14261450
vec![],
14271451
InternalCollectionConfiguration::default_hnsw(),
1452+
Some(Schema::new_default(KnnIndex::Hnsw)),
14281453
None,
14291454
None,
14301455
false,
@@ -1498,6 +1523,7 @@ mod tests {
14981523
"test_collection".to_string(),
14991524
vec![],
15001525
InternalCollectionConfiguration::default_hnsw(),
1526+
Some(Schema::new_default(KnnIndex::Hnsw)),
15011527
None,
15021528
None,
15031529
false,
@@ -1579,6 +1605,7 @@ mod tests {
15791605
"test_collection".to_string(),
15801606
segments.clone(),
15811607
InternalCollectionConfiguration::default_hnsw(),
1608+
Some(Schema::new_default(KnnIndex::Hnsw)),
15821609
Some(collection_metadata.clone()),
15831610
None,
15841611
false,
@@ -1629,6 +1656,7 @@ mod tests {
16291656
"test_collection".to_string(),
16301657
segments.clone(),
16311658
InternalCollectionConfiguration::default_hnsw(),
1659+
Some(Schema::new_default(KnnIndex::Hnsw)),
16321660
Some(collection_metadata.clone()),
16331661
None,
16341662
false,
@@ -1659,6 +1687,7 @@ mod tests {
16591687
"test_collection".to_string(),
16601688
vec![],
16611689
InternalCollectionConfiguration::default_hnsw(),
1690+
Some(Schema::new_default(KnnIndex::Hnsw)),
16621691
None,
16631692
None,
16641693
false,

rust/sysdb/src/sysdb.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ impl SysDb {
328328
name,
329329
segments,
330330
configuration.unwrap_or(InternalCollectionConfiguration::default_hnsw()),
331+
schema.clone(),
331332
metadata,
332333
dimension,
333334
get_or_create,

rust/types/src/api_types.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -782,11 +782,13 @@ pub enum GetCollectionsError {
782782
#[error(transparent)]
783783
Internal(#[from] Box<dyn ChromaError>),
784784
#[error("Could not deserialize configuration")]
785-
Configuration(#[from] serde_json::Error),
785+
Configuration(#[source] serde_json::Error),
786786
#[error("Could not deserialize collection ID")]
787787
CollectionId(#[from] uuid::Error),
788788
#[error("Could not deserialize database ID")]
789789
DatabaseId,
790+
#[error("Could not deserialize schema")]
791+
Schema(#[source] serde_json::Error),
790792
}
791793

792794
impl ChromaError for GetCollectionsError {
@@ -797,6 +799,7 @@ impl ChromaError for GetCollectionsError {
797799
GetCollectionsError::Configuration(_) => ErrorCodes::Internal,
798800
GetCollectionsError::CollectionId(_) => ErrorCodes::Internal,
799801
GetCollectionsError::DatabaseId => ErrorCodes::Internal,
802+
GetCollectionsError::Schema(_) => ErrorCodes::Internal,
800803
}
801804
}
802805
}

0 commit comments

Comments
 (0)