Skip to content

Commit 703a271

Browse files
committed
[ENH]: add create_collection() to Rust client
1 parent 3c52c9b commit 703a271

File tree

5 files changed

+200
-59
lines changed

5 files changed

+200
-59
lines changed

rust/chroma/src/client/chroma_client.rs

Lines changed: 185 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use backon::ExponentialBuilder;
22
use backon::Retryable;
3+
use chroma_types::Collection;
34
use parking_lot::Mutex;
45
use reqwest::Method;
56
use reqwest::StatusCode;
@@ -8,7 +9,10 @@ use std::sync::Arc;
89
use thiserror::Error;
910

1011
use crate::client::ChromaClientOptions;
11-
use crate::types::{GetUserIdentityResponse, HeartbeatResponse, ListCollectionsRequest};
12+
use crate::collection::ChromaCollection;
13+
use crate::types::{
14+
CreateCollectionRequest, GetUserIdentityResponse, HeartbeatResponse, ListCollectionsRequest,
15+
};
1216

1317
const USER_AGENT: &str = concat!(
1418
"Chroma Rust Client v",
@@ -32,7 +36,7 @@ pub struct ChromaClient {
3236
client: reqwest::Client,
3337
retry_policy: ExponentialBuilder,
3438
tenant_id: Arc<Mutex<Option<String>>>,
35-
default_database_id: Arc<Mutex<Option<String>>>,
39+
default_database_name: Arc<Mutex<Option<String>>>,
3640
resolve_tenant_or_database_lock: Arc<tokio::sync::Mutex<()>>,
3741
#[cfg(feature = "opentelemetry")]
3842
metrics: crate::client::metrics::Metrics,
@@ -45,7 +49,7 @@ impl Clone for ChromaClient {
4549
client: self.client.clone(),
4650
retry_policy: self.retry_policy,
4751
tenant_id: Arc::new(Mutex::new(self.tenant_id.lock().clone())),
48-
default_database_id: Arc::new(Mutex::new(self.default_database_id.lock().clone())),
52+
default_database_name: Arc::new(Mutex::new(self.default_database_name.lock().clone())),
4953
resolve_tenant_or_database_lock: Arc::new(tokio::sync::Mutex::new(())),
5054
#[cfg(feature = "opentelemetry")]
5155
metrics: self.metrics.clone(),
@@ -76,16 +80,16 @@ impl ChromaClient {
7680
client,
7781
retry_policy: options.retry_options.into(),
7882
tenant_id: Arc::new(Mutex::new(options.tenant_id)),
79-
default_database_id: Arc::new(Mutex::new(options.default_database_id)),
83+
default_database_name: Arc::new(Mutex::new(options.default_database_name)),
8084
resolve_tenant_or_database_lock: Arc::new(tokio::sync::Mutex::new(())),
8185
#[cfg(feature = "opentelemetry")]
8286
metrics: crate::client::metrics::Metrics::new(),
8387
}
8488
}
8589

86-
pub fn set_default_database_id(&self, database_id: String) {
87-
let mut lock = self.default_database_id.lock();
88-
*lock = Some(database_id);
90+
pub fn set_default_database_name(&self, database_name: String) {
91+
let mut lock = self.default_database_name.lock();
92+
*lock = Some(database_name);
8993
}
9094

9195
pub async fn create_database(&self, name: String) -> Result<(), ChromaClientError> {
@@ -155,79 +159,135 @@ impl ChromaClient {
155159
.await
156160
}
157161

162+
pub async fn get_or_create_collection(
163+
&self,
164+
params: CreateCollectionRequest,
165+
) -> Result<ChromaCollection, ChromaClientError> {
166+
self.common_create_collection(params, true).await
167+
}
168+
169+
pub async fn create_collection(
170+
&self,
171+
params: CreateCollectionRequest,
172+
) -> Result<ChromaCollection, ChromaClientError> {
173+
self.common_create_collection(params, false).await
174+
}
175+
158176
pub async fn list_collections(
159177
&self,
160178
params: ListCollectionsRequest,
161-
) -> Result<Vec<String>, ChromaClientError> {
179+
) -> Result<Vec<ChromaCollection>, ChromaClientError> {
162180
let tenant_id = self.get_tenant_id().await?;
163-
let database_id = self.get_database_id(params.database_id).await?;
181+
let database_name = self.get_database_name(params.database_name).await?;
164182

165183
#[derive(Serialize)]
166184
struct QueryParams {
167185
limit: usize,
168186
offset: Option<usize>,
169187
}
170188

171-
self.send::<(), _, _>(
172-
"list_collections",
173-
Method::GET,
174-
format!(
175-
"/api/v2/tenants/{}/databases/{}/collections",
176-
tenant_id, database_id
177-
),
178-
None,
179-
Some(QueryParams {
180-
limit: params.limit,
181-
offset: params.offset,
182-
}),
183-
)
184-
.await
189+
let collections = self
190+
.send::<(), _, Vec<Collection>>(
191+
"list_collections",
192+
Method::GET,
193+
format!(
194+
"/api/v2/tenants/{}/databases/{}/collections",
195+
tenant_id, database_name
196+
),
197+
None,
198+
Some(QueryParams {
199+
limit: params.limit,
200+
offset: params.offset,
201+
}),
202+
)
203+
.await?;
204+
205+
Ok(collections
206+
.into_iter()
207+
.map(|collection| ChromaCollection {
208+
client: self.clone(),
209+
collection: Arc::new(collection),
210+
})
211+
.collect())
212+
}
213+
214+
async fn common_create_collection(
215+
&self,
216+
params: CreateCollectionRequest,
217+
get_or_create: bool,
218+
) -> Result<ChromaCollection, ChromaClientError> {
219+
let tenant_id = self.get_tenant_id().await?;
220+
let database_name = self.get_database_name(params.database_name).await?;
221+
222+
let collection: chroma_types::Collection = self
223+
.send(
224+
"create_collection",
225+
Method::POST,
226+
format!(
227+
"/api/v2/tenants/{}/databases/{}/collections",
228+
tenant_id, database_name
229+
),
230+
Some(serde_json::json!({
231+
"name": params.name,
232+
"configuration": params.configuration,
233+
"metadata": params.metadata,
234+
"get_or_create": get_or_create,
235+
})),
236+
None::<()>,
237+
)
238+
.await?;
239+
240+
Ok(ChromaCollection {
241+
client: self.clone(),
242+
collection: Arc::new(collection),
243+
})
185244
}
186245

187-
async fn get_database_id(
246+
async fn get_database_name(
188247
&self,
189-
id_override: Option<String>,
248+
name_override: Option<String>,
190249
) -> Result<String, ChromaClientError> {
191-
if let Some(id) = id_override {
250+
if let Some(id) = name_override {
192251
return Ok(id);
193252
}
194253

195254
{
196-
let database_id_lock = self.default_database_id.lock();
197-
if let Some(database_id) = &*database_id_lock {
198-
return Ok(database_id.clone());
255+
let database_name_lock = self.default_database_name.lock();
256+
if let Some(database_name) = &*database_name_lock {
257+
return Ok(database_name.clone());
199258
}
200259
}
201260

202261
let _guard = self.resolve_tenant_or_database_lock.lock().await;
203262

204263
{
205-
let database_id_lock = self.default_database_id.lock();
206-
if let Some(database_id) = &*database_id_lock {
207-
return Ok(database_id.clone());
264+
let database_name_lock = self.default_database_name.lock();
265+
if let Some(database_name) = &*database_name_lock {
266+
return Ok(database_name.clone());
208267
}
209268
}
210269

211270
let identity = self.get_auth_identity().await?;
212271

213272
if identity.databases.len() > 1 {
214273
return Err(ChromaClientError::CouldNotResolveDatabaseId(
215-
"Client has access to multiple databases; please provide a database_id".to_string(),
274+
"Client has access to multiple databases; please provide a database_name"
275+
.to_string(),
216276
));
217277
}
218278

219-
let database_id = identity.databases.first().ok_or_else(|| {
279+
let database_name = identity.databases.first().ok_or_else(|| {
220280
ChromaClientError::CouldNotResolveDatabaseId(
221281
"Client has access to no databases".to_string(),
222282
)
223283
})?;
224284

225285
{
226-
let mut database_id_lock = self.default_database_id.lock();
227-
*database_id_lock = Some(database_id.clone());
286+
let mut database_name_lock = self.default_database_name.lock();
287+
*database_name_lock = Some(database_name.clone());
228288
}
229289

230-
Ok(database_id.clone())
290+
Ok(database_name.clone())
231291
}
232292

233293
async fn get_tenant_id(&self) -> Result<String, ChromaClientError> {
@@ -287,7 +347,7 @@ impl ChromaClient {
287347
#[cfg(feature = "opentelemetry")]
288348
let started_at = std::time::Instant::now();
289349

290-
let response = request.send().await?;
350+
let response = request.send().await.map_err(|err| (err, None))?;
291351

292352
#[cfg(feature = "opentelemetry")]
293353
{
@@ -302,13 +362,16 @@ impl ChromaClient {
302362
let _ = operation_name;
303363
}
304364

305-
response.error_for_status_ref()?;
306-
Ok::<_, reqwest::Error>(response)
365+
if let Err(err) = response.error_for_status_ref() {
366+
return Err((err, Some(response)));
367+
}
368+
369+
Ok::<reqwest::Response, (reqwest::Error, Option<reqwest::Response>)>(response)
307370
};
308371

309372
let response = attempt
310373
.retry(&self.retry_policy)
311-
.notify(|err, _| {
374+
.notify(|(err, _), _| {
312375
tracing::warn!(
313376
url = %url,
314377
method =? method,
@@ -319,13 +382,33 @@ impl ChromaClient {
319382
#[cfg(feature = "opentelemetry")]
320383
self.metrics.increment_retry(operation_name);
321384
})
322-
.when(|err| {
385+
.when(|(err, _)| {
323386
err.status()
324387
.map(|status| status == StatusCode::TOO_MANY_REQUESTS)
325388
.unwrap_or_default()
326389
|| method == Method::GET
327390
})
328-
.await?;
391+
.await;
392+
393+
let response = match response {
394+
Ok(response) => response,
395+
Err((err, maybe_response)) => {
396+
if let Some(response) = maybe_response {
397+
let json = response.json::<serde_json::Value>().await?;
398+
399+
if tracing::enabled!(tracing::Level::TRACE) {
400+
tracing::trace!(
401+
url = %url,
402+
method =? method,
403+
"Received response: {}",
404+
serde_json::to_string_pretty(&json).unwrap_or_else(|_| "<failed to serialize>".to_string())
405+
);
406+
}
407+
}
408+
409+
return Err(ChromaClientError::RequestError(err));
410+
}
411+
};
329412

330413
let json = response.json::<serde_json::Value>().await?;
331414

@@ -389,22 +472,15 @@ mod tests {
389472
// Create isolated database for test
390473
let database_name = format!("test_db_{}", uuid::Uuid::new_v4());
391474
client.create_database(database_name.clone()).await.unwrap();
392-
let databases = client.list_databases().await.unwrap();
393-
let database_id = databases
394-
.iter()
395-
.find(|db| db.name == database_name)
396-
.unwrap()
397-
.id
398-
.clone();
399-
client.set_default_database_id(database_id.clone());
475+
client.set_default_database_name(database_name.clone());
400476

401477
let result = std::panic::AssertUnwindSafe(callback(client.clone()))
402478
.catch_unwind()
403479
.await;
404480

405481
// Delete test database
406-
if let Err(err) = client.delete_database(database_name).await {
407-
tracing::error!("Failed to delete test database {}: {}", database_id, err);
482+
if let Err(err) = client.delete_database(database_name.clone()).await {
483+
tracing::error!("Failed to delete test database {}: {}", database_name, err);
408484
}
409485

410486
result.unwrap();
@@ -548,7 +624,62 @@ mod tests {
548624
.unwrap();
549625
assert!(collections.is_empty());
550626

551-
// todo: create collection and assert it's returned, test limit/offset
627+
client
628+
.create_collection(
629+
CreateCollectionRequest::builder()
630+
.name("first".to_string())
631+
.build(),
632+
)
633+
.await
634+
.unwrap();
635+
636+
client
637+
.create_collection(
638+
CreateCollectionRequest::builder()
639+
.name("second".to_string())
640+
.build(),
641+
)
642+
.await
643+
.unwrap();
644+
645+
let collections = client
646+
.list_collections(ListCollectionsRequest::builder().build())
647+
.await
648+
.unwrap();
649+
assert_eq!(collections.len(), 2);
650+
651+
let collections = client
652+
.list_collections(ListCollectionsRequest::builder().limit(1).offset(1).build())
653+
.await
654+
.unwrap();
655+
assert_eq!(collections.len(), 1);
656+
assert_eq!(collections[0].collection.name, "second");
657+
})
658+
.await;
659+
}
660+
661+
#[tokio::test]
662+
#[test_log::test]
663+
async fn test_live_cloud_create_collection() {
664+
with_client(|client| async move {
665+
let collection = client
666+
.create_collection(
667+
CreateCollectionRequest::builder()
668+
.name("foo".to_string())
669+
.build(),
670+
)
671+
.await
672+
.unwrap();
673+
assert_eq!(collection.collection.name, "foo");
674+
675+
client
676+
.get_or_create_collection(
677+
CreateCollectionRequest::builder()
678+
.name("foo".to_string())
679+
.build(),
680+
)
681+
.await
682+
.unwrap();
552683
})
553684
.await;
554685
}

rust/chroma/src/client/options.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ pub struct ChromaClientOptions {
6464
/// Will be automatically resolved at request time if not provided
6565
pub tenant_id: Option<String>,
6666
/// Will be automatically resolved at request time if not provided. It can only be resolved automatically if this client has access to exactly one database.
67-
pub default_database_id: Option<String>,
67+
pub default_database_name: Option<String>,
6868
}
6969

7070
impl Default for ChromaClientOptions {
@@ -74,7 +74,7 @@ impl Default for ChromaClientOptions {
7474
auth_method: ChromaAuthMethod::None,
7575
retry_options: ChromaRetryOptions::default(),
7676
tenant_id: None,
77-
default_database_id: None,
77+
default_database_name: None,
7878
}
7979
}
8080
}

0 commit comments

Comments
 (0)