diff --git a/rust/chroma/src/client/chroma_client.rs b/rust/chroma/src/client/chroma_client.rs index 3aaca9a7d2e..9bb33be60cc 100644 --- a/rust/chroma/src/client/chroma_client.rs +++ b/rust/chroma/src/client/chroma_client.rs @@ -1,6 +1,7 @@ use backon::ExponentialBuilder; use backon::Retryable; use chroma_error::ChromaValidationError; +use chroma_types::Collection; use parking_lot::Mutex; use reqwest::Method; use reqwest::StatusCode; @@ -9,7 +10,10 @@ use std::sync::Arc; use thiserror::Error; use crate::client::ChromaClientOptions; -use crate::types::{GetUserIdentityResponse, HeartbeatResponse, ListCollectionsRequest}; +use crate::collection::ChromaCollection; +use crate::types::{ + CreateCollectionRequest, GetUserIdentityResponse, HeartbeatResponse, ListCollectionsRequest, +}; const USER_AGENT: &str = concat!( "Chroma Rust Client v", @@ -35,7 +39,7 @@ pub struct ChromaClient { client: reqwest::Client, retry_policy: ExponentialBuilder, tenant_id: Arc>>, - default_database_id: Arc>>, + default_database_name: Arc>>, resolve_tenant_or_database_lock: Arc>, #[cfg(feature = "opentelemetry")] metrics: crate::client::metrics::Metrics, @@ -48,7 +52,7 @@ impl Clone for ChromaClient { client: self.client.clone(), retry_policy: self.retry_policy, tenant_id: Arc::new(Mutex::new(self.tenant_id.lock().clone())), - default_database_id: Arc::new(Mutex::new(self.default_database_id.lock().clone())), + default_database_name: Arc::new(Mutex::new(self.default_database_name.lock().clone())), resolve_tenant_or_database_lock: Arc::new(tokio::sync::Mutex::new(())), #[cfg(feature = "opentelemetry")] metrics: self.metrics.clone(), @@ -79,16 +83,16 @@ impl ChromaClient { client, retry_policy: options.retry_options.into(), tenant_id: Arc::new(Mutex::new(options.tenant_id)), - default_database_id: Arc::new(Mutex::new(options.default_database_id)), + default_database_name: Arc::new(Mutex::new(options.default_database_name)), resolve_tenant_or_database_lock: Arc::new(tokio::sync::Mutex::new(())), #[cfg(feature = "opentelemetry")] metrics: crate::client::metrics::Metrics::new(), } } - pub fn set_default_database_id(&self, database_id: String) { - let mut lock = self.default_database_id.lock(); - *lock = Some(database_id); + pub fn set_default_database_name(&self, database_name: String) { + let mut lock = self.default_database_name.lock(); + *lock = Some(database_name); } pub async fn create_database(&self, name: String) -> Result<(), ChromaClientError> { @@ -158,12 +162,26 @@ impl ChromaClient { .await } + pub async fn get_or_create_collection( + &self, + params: CreateCollectionRequest, + ) -> Result { + self.common_create_collection(params, true).await + } + + pub async fn create_collection( + &self, + params: CreateCollectionRequest, + ) -> Result { + self.common_create_collection(params, false).await + } + pub async fn list_collections( &self, params: ListCollectionsRequest, - ) -> Result, ChromaClientError> { + ) -> Result, ChromaClientError> { let tenant_id = self.get_tenant_id().await?; - let database_id = self.get_database_id(params.database_id).await?; + let database_name = self.get_database_name(params.database_name).await?; #[derive(Serialize)] struct QueryParams { @@ -171,43 +189,84 @@ impl ChromaClient { offset: Option, } - self.send::<(), _, _>( - "list_collections", - Method::GET, - format!( - "/api/v2/tenants/{}/databases/{}/collections", - tenant_id, database_id - ), - None, - Some(QueryParams { - limit: params.limit, - offset: params.offset, - }), - ) - .await + let collections = self + .send::<(), _, Vec>( + "list_collections", + Method::GET, + format!( + "/api/v2/tenants/{}/databases/{}/collections", + tenant_id, database_name + ), + None, + Some(QueryParams { + limit: params.limit, + offset: params.offset, + }), + ) + .await?; + + Ok(collections + .into_iter() + .map(|collection| ChromaCollection { + client: self.clone(), + collection: Arc::new(collection), + }) + .collect()) + } + + async fn common_create_collection( + &self, + params: CreateCollectionRequest, + get_or_create: bool, + ) -> Result { + let tenant_id = self.get_tenant_id().await?; + let database_name = self.get_database_name(params.database_name).await?; + + let collection: chroma_types::Collection = self + .send( + "create_collection", + Method::POST, + format!( + "/api/v2/tenants/{}/databases/{}/collections", + tenant_id, database_name + ), + Some(serde_json::json!({ + "name": params.name, + "configuration": params.configuration, + "metadata": params.metadata, + "get_or_create": get_or_create, + })), + None::<()>, + ) + .await?; + + Ok(ChromaCollection { + client: self.clone(), + collection: Arc::new(collection), + }) } - async fn get_database_id( + async fn get_database_name( &self, - id_override: Option, + name_override: Option, ) -> Result { - if let Some(id) = id_override { + if let Some(id) = name_override { return Ok(id); } { - let database_id_lock = self.default_database_id.lock(); - if let Some(database_id) = &*database_id_lock { - return Ok(database_id.clone()); + let database_name_lock = self.default_database_name.lock(); + if let Some(database_name) = &*database_name_lock { + return Ok(database_name.clone()); } } let _guard = self.resolve_tenant_or_database_lock.lock().await; { - let database_id_lock = self.default_database_id.lock(); - if let Some(database_id) = &*database_id_lock { - return Ok(database_id.clone()); + let database_name_lock = self.default_database_name.lock(); + if let Some(database_name) = &*database_name_lock { + return Ok(database_name.clone()); } } @@ -215,22 +274,23 @@ impl ChromaClient { if identity.databases.len() > 1 { return Err(ChromaClientError::CouldNotResolveDatabaseId( - "Client has access to multiple databases; please provide a database_id".to_string(), + "Client has access to multiple databases; please provide a database_name" + .to_string(), )); } - let database_id = identity.databases.first().ok_or_else(|| { + let database_name = identity.databases.first().ok_or_else(|| { ChromaClientError::CouldNotResolveDatabaseId( "Client has access to no databases".to_string(), ) })?; { - let mut database_id_lock = self.default_database_id.lock(); - *database_id_lock = Some(database_id.clone()); + let mut database_name_lock = self.default_database_name.lock(); + *database_name_lock = Some(database_name.clone()); } - Ok(database_id.clone()) + Ok(database_name.clone()) } async fn get_tenant_id(&self) -> Result { @@ -290,7 +350,7 @@ impl ChromaClient { #[cfg(feature = "opentelemetry")] let started_at = std::time::Instant::now(); - let response = request.send().await?; + let response = request.send().await.map_err(|err| (err, None))?; #[cfg(feature = "opentelemetry")] { @@ -305,13 +365,16 @@ impl ChromaClient { let _ = operation_name; } - response.error_for_status_ref()?; - Ok::<_, reqwest::Error>(response) + if let Err(err) = response.error_for_status_ref() { + return Err((err, Some(response))); + } + + Ok::)>(response) }; let response = attempt .retry(&self.retry_policy) - .notify(|err, _| { + .notify(|(err, _), _| { tracing::warn!( url = %url, method =? method, @@ -322,13 +385,33 @@ impl ChromaClient { #[cfg(feature = "opentelemetry")] self.metrics.increment_retry(operation_name); }) - .when(|err| { + .when(|(err, _)| { err.status() .map(|status| status == StatusCode::TOO_MANY_REQUESTS) .unwrap_or_default() || method == Method::GET }) - .await?; + .await; + + let response = match response { + Ok(response) => response, + Err((err, maybe_response)) => { + if let Some(response) = maybe_response { + let json = response.json::().await?; + + if tracing::enabled!(tracing::Level::TRACE) { + tracing::trace!( + url = %url, + method =? method, + "Received response: {}", + serde_json::to_string_pretty(&json).unwrap_or_else(|_| "".to_string()) + ); + } + } + + return Err(ChromaClientError::RequestError(err)); + } + }; let json = response.json::().await?; @@ -392,22 +475,15 @@ mod tests { // Create isolated database for test let database_name = format!("test_db_{}", uuid::Uuid::new_v4()); client.create_database(database_name.clone()).await.unwrap(); - let databases = client.list_databases().await.unwrap(); - let database_id = databases - .iter() - .find(|db| db.name == database_name) - .unwrap() - .id - .clone(); - client.set_default_database_id(database_id.clone()); + client.set_default_database_name(database_name.clone()); let result = std::panic::AssertUnwindSafe(callback(client.clone())) .catch_unwind() .await; // Delete test database - if let Err(err) = client.delete_database(database_name).await { - tracing::error!("Failed to delete test database {}: {}", database_id, err); + if let Err(err) = client.delete_database(database_name.clone()).await { + tracing::error!("Failed to delete test database {}: {}", database_name, err); } result.unwrap(); @@ -551,7 +627,62 @@ mod tests { .unwrap(); assert!(collections.is_empty()); - // todo: create collection and assert it's returned, test limit/offset + client + .create_collection( + CreateCollectionRequest::builder() + .name("first".to_string()) + .build(), + ) + .await + .unwrap(); + + client + .create_collection( + CreateCollectionRequest::builder() + .name("second".to_string()) + .build(), + ) + .await + .unwrap(); + + let collections = client + .list_collections(ListCollectionsRequest::builder().build()) + .await + .unwrap(); + assert_eq!(collections.len(), 2); + + let collections = client + .list_collections(ListCollectionsRequest::builder().limit(1).offset(1).build()) + .await + .unwrap(); + assert_eq!(collections.len(), 1); + assert_eq!(collections[0].collection.name, "second"); + }) + .await; + } + + #[tokio::test] + #[test_log::test] + async fn test_live_cloud_create_collection() { + with_client(|client| async move { + let collection = client + .create_collection( + CreateCollectionRequest::builder() + .name("foo".to_string()) + .build(), + ) + .await + .unwrap(); + assert_eq!(collection.collection.name, "foo"); + + client + .get_or_create_collection( + CreateCollectionRequest::builder() + .name("foo".to_string()) + .build(), + ) + .await + .unwrap(); }) .await; } diff --git a/rust/chroma/src/client/options.rs b/rust/chroma/src/client/options.rs index 5312f0045d9..1b86b6fbfc2 100644 --- a/rust/chroma/src/client/options.rs +++ b/rust/chroma/src/client/options.rs @@ -64,7 +64,7 @@ pub struct ChromaClientOptions { /// Will be automatically resolved at request time if not provided pub tenant_id: Option, /// Will be automatically resolved at request time if not provided. It can only be resolved automatically if this client has access to exactly one database. - pub default_database_id: Option, + pub default_database_name: Option, } impl Default for ChromaClientOptions { @@ -74,7 +74,7 @@ impl Default for ChromaClientOptions { auth_method: ChromaAuthMethod::None, retry_options: ChromaRetryOptions::default(), tenant_id: None, - default_database_id: None, + default_database_name: None, } } } diff --git a/rust/chroma/src/collection.rs b/rust/chroma/src/collection.rs index c3b9a7df1f9..b5e545cf9e0 100644 --- a/rust/chroma/src/collection.rs +++ b/rust/chroma/src/collection.rs @@ -14,8 +14,8 @@ use crate::{client::ChromaClientError, ChromaClient}; #[derive(Clone, Debug)] pub struct ChromaCollection { - client: ChromaClient, - collection: Arc, + pub(crate) client: ChromaClient, + pub(crate) collection: Arc, } impl ChromaCollection { diff --git a/rust/chroma/src/types.rs b/rust/chroma/src/types.rs index 37a743d4761..b099f3634a6 100644 --- a/rust/chroma/src/types.rs +++ b/rust/chroma/src/types.rs @@ -1,4 +1,5 @@ mod requests; pub use chroma_api_types::{GetUserIdentityResponse, HeartbeatResponse}; +pub use requests::CreateCollectionRequest; pub use requests::ListCollectionsRequest; diff --git a/rust/chroma/src/types/requests.rs b/rust/chroma/src/types/requests.rs index e4f9b836662..63cc3211c1b 100644 --- a/rust/chroma/src/types/requests.rs +++ b/rust/chroma/src/types/requests.rs @@ -1,9 +1,18 @@ use bon::Builder; +use chroma_types::{CollectionConfiguration, Metadata}; #[derive(Builder)] +pub struct CreateCollectionRequest { + pub name: String, + pub configuration: Option, + pub metadata: Option, + pub database_name: Option, +} + +#[derive(Default, Builder)] pub struct ListCollectionsRequest { #[builder(default = 100)] pub limit: usize, pub offset: Option, - pub database_id: Option, + pub database_name: Option, }