Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 185 additions & 54 deletions rust/chroma/src/client/chroma_client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use backon::ExponentialBuilder;
use backon::Retryable;
use chroma_error::ChromaValidationError;
use chroma_types::Collection;
use parking_lot::Mutex;
use reqwest::Method;
use reqwest::StatusCode;
Expand All @@ -9,7 +10,10 @@ use std::sync::Arc;
use thiserror::Error;

use crate::client::ChromaClientOptions;
use crate::types::{GetUserIdentityResponse, HeartbeatResponse, ListCollectionsRequest};
use crate::collection::ChromaCollection;
use crate::types::{
CreateCollectionRequest, GetUserIdentityResponse, HeartbeatResponse, ListCollectionsRequest,
};

const USER_AGENT: &str = concat!(
"Chroma Rust Client v",
Expand All @@ -35,7 +39,7 @@ pub struct ChromaClient {
client: reqwest::Client,
retry_policy: ExponentialBuilder,
tenant_id: Arc<Mutex<Option<String>>>,
default_database_id: Arc<Mutex<Option<String>>>,
default_database_name: Arc<Mutex<Option<String>>>,
resolve_tenant_or_database_lock: Arc<tokio::sync::Mutex<()>>,
#[cfg(feature = "opentelemetry")]
metrics: crate::client::metrics::Metrics,
Expand All @@ -48,7 +52,7 @@ impl Clone for ChromaClient {
client: self.client.clone(),
retry_policy: self.retry_policy,
tenant_id: Arc::new(Mutex::new(self.tenant_id.lock().clone())),
default_database_id: Arc::new(Mutex::new(self.default_database_id.lock().clone())),
default_database_name: Arc::new(Mutex::new(self.default_database_name.lock().clone())),
resolve_tenant_or_database_lock: Arc::new(tokio::sync::Mutex::new(())),
#[cfg(feature = "opentelemetry")]
metrics: self.metrics.clone(),
Expand Down Expand Up @@ -79,16 +83,16 @@ impl ChromaClient {
client,
retry_policy: options.retry_options.into(),
tenant_id: Arc::new(Mutex::new(options.tenant_id)),
default_database_id: Arc::new(Mutex::new(options.default_database_id)),
default_database_name: Arc::new(Mutex::new(options.default_database_name)),
resolve_tenant_or_database_lock: Arc::new(tokio::sync::Mutex::new(())),
#[cfg(feature = "opentelemetry")]
metrics: crate::client::metrics::Metrics::new(),
}
}

pub fn set_default_database_id(&self, database_id: String) {
let mut lock = self.default_database_id.lock();
*lock = Some(database_id);
pub fn set_default_database_name(&self, database_name: String) {
let mut lock = self.default_database_name.lock();
*lock = Some(database_name);
}

pub async fn create_database(&self, name: String) -> Result<(), ChromaClientError> {
Expand Down Expand Up @@ -158,79 +162,135 @@ impl ChromaClient {
.await
}

pub async fn get_or_create_collection(
&self,
params: CreateCollectionRequest,
) -> Result<ChromaCollection, ChromaClientError> {
self.common_create_collection(params, true).await
}

pub async fn create_collection(
&self,
params: CreateCollectionRequest,
) -> Result<ChromaCollection, ChromaClientError> {
self.common_create_collection(params, false).await
}

pub async fn list_collections(
&self,
params: ListCollectionsRequest,
) -> Result<Vec<String>, ChromaClientError> {
) -> Result<Vec<ChromaCollection>, ChromaClientError> {
let tenant_id = self.get_tenant_id().await?;
let database_id = self.get_database_id(params.database_id).await?;
let database_name = self.get_database_name(params.database_name).await?;

#[derive(Serialize)]
struct QueryParams {
limit: usize,
offset: Option<usize>,
}

self.send::<(), _, _>(
"list_collections",
Method::GET,
format!(
"/api/v2/tenants/{}/databases/{}/collections",
tenant_id, database_id
),
None,
Some(QueryParams {
limit: params.limit,
offset: params.offset,
}),
)
.await
let collections = self
.send::<(), _, Vec<Collection>>(
"list_collections",
Method::GET,
format!(
"/api/v2/tenants/{}/databases/{}/collections",
tenant_id, database_name
),
None,
Some(QueryParams {
limit: params.limit,
offset: params.offset,
}),
)
.await?;

Ok(collections
.into_iter()
.map(|collection| ChromaCollection {
client: self.clone(),
collection: Arc::new(collection),
})
.collect())
}

async fn common_create_collection(
&self,
params: CreateCollectionRequest,
get_or_create: bool,
) -> Result<ChromaCollection, ChromaClientError> {
let tenant_id = self.get_tenant_id().await?;
let database_name = self.get_database_name(params.database_name).await?;

let collection: chroma_types::Collection = self
.send(
"create_collection",
Method::POST,
format!(
"/api/v2/tenants/{}/databases/{}/collections",
tenant_id, database_name
),
Some(serde_json::json!({
"name": params.name,
"configuration": params.configuration,
"metadata": params.metadata,
"get_or_create": get_or_create,
})),
None::<()>,
)
.await?;

Ok(ChromaCollection {
client: self.clone(),
collection: Arc::new(collection),
})
}

async fn get_database_id(
async fn get_database_name(
&self,
id_override: Option<String>,
name_override: Option<String>,
) -> Result<String, ChromaClientError> {
if let Some(id) = id_override {
if let Some(id) = name_override {
return Ok(id);
}

{
let database_id_lock = self.default_database_id.lock();
if let Some(database_id) = &*database_id_lock {
return Ok(database_id.clone());
let database_name_lock = self.default_database_name.lock();
if let Some(database_name) = &*database_name_lock {
return Ok(database_name.clone());
}
}

let _guard = self.resolve_tenant_or_database_lock.lock().await;

{
let database_id_lock = self.default_database_id.lock();
if let Some(database_id) = &*database_id_lock {
return Ok(database_id.clone());
let database_name_lock = self.default_database_name.lock();
if let Some(database_name) = &*database_name_lock {
return Ok(database_name.clone());
}
}

let identity = self.get_auth_identity().await?;

if identity.databases.len() > 1 {
return Err(ChromaClientError::CouldNotResolveDatabaseId(
"Client has access to multiple databases; please provide a database_id".to_string(),
"Client has access to multiple databases; please provide a database_name"
.to_string(),
));
}

let database_id = identity.databases.first().ok_or_else(|| {
let database_name = identity.databases.first().ok_or_else(|| {
ChromaClientError::CouldNotResolveDatabaseId(
"Client has access to no databases".to_string(),
)
})?;

{
let mut database_id_lock = self.default_database_id.lock();
*database_id_lock = Some(database_id.clone());
let mut database_name_lock = self.default_database_name.lock();
*database_name_lock = Some(database_name.clone());
}

Ok(database_id.clone())
Ok(database_name.clone())
}

async fn get_tenant_id(&self) -> Result<String, ChromaClientError> {
Expand Down Expand Up @@ -290,7 +350,7 @@ impl ChromaClient {
#[cfg(feature = "opentelemetry")]
let started_at = std::time::Instant::now();

let response = request.send().await?;
let response = request.send().await.map_err(|err| (err, None))?;

#[cfg(feature = "opentelemetry")]
{
Expand All @@ -305,13 +365,16 @@ impl ChromaClient {
let _ = operation_name;
}

response.error_for_status_ref()?;
Ok::<_, reqwest::Error>(response)
if let Err(err) = response.error_for_status_ref() {
return Err((err, Some(response)));
}

Ok::<reqwest::Response, (reqwest::Error, Option<reqwest::Response>)>(response)
};

let response = attempt
.retry(&self.retry_policy)
.notify(|err, _| {
.notify(|(err, _), _| {
tracing::warn!(
url = %url,
method =? method,
Expand All @@ -322,13 +385,33 @@ impl ChromaClient {
#[cfg(feature = "opentelemetry")]
self.metrics.increment_retry(operation_name);
})
.when(|err| {
.when(|(err, _)| {
err.status()
.map(|status| status == StatusCode::TOO_MANY_REQUESTS)
.unwrap_or_default()
|| method == Method::GET
})
.await?;
.await;

let response = match response {
Ok(response) => response,
Err((err, maybe_response)) => {
if let Some(response) = maybe_response {
let json = response.json::<serde_json::Value>().await?;

if tracing::enabled!(tracing::Level::TRACE) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now logs response JSON on failed requests
will add a custom error type for when we receive parsable JSON errors in a follow up PR

tracing::trace!(
url = %url,
method =? method,
"Received response: {}",
serde_json::to_string_pretty(&json).unwrap_or_else(|_| "<failed to serialize>".to_string())
);
}
}

return Err(ChromaClientError::RequestError(err));
}
};

let json = response.json::<serde_json::Value>().await?;

Expand Down Expand Up @@ -392,22 +475,15 @@ mod tests {
// Create isolated database for test
let database_name = format!("test_db_{}", uuid::Uuid::new_v4());
client.create_database(database_name.clone()).await.unwrap();
let databases = client.list_databases().await.unwrap();
let database_id = databases
.iter()
.find(|db| db.name == database_name)
.unwrap()
.id
.clone();
client.set_default_database_id(database_id.clone());
client.set_default_database_name(database_name.clone());

let result = std::panic::AssertUnwindSafe(callback(client.clone()))
.catch_unwind()
.await;

// Delete test database
if let Err(err) = client.delete_database(database_name).await {
tracing::error!("Failed to delete test database {}: {}", database_id, err);
if let Err(err) = client.delete_database(database_name.clone()).await {
tracing::error!("Failed to delete test database {}: {}", database_name, err);
}

result.unwrap();
Expand Down Expand Up @@ -551,7 +627,62 @@ mod tests {
.unwrap();
assert!(collections.is_empty());

// todo: create collection and assert it's returned, test limit/offset
client
.create_collection(
CreateCollectionRequest::builder()
.name("first".to_string())
.build(),
)
.await
.unwrap();

client
.create_collection(
CreateCollectionRequest::builder()
.name("second".to_string())
.build(),
)
.await
.unwrap();

let collections = client
.list_collections(ListCollectionsRequest::builder().build())
.await
.unwrap();
assert_eq!(collections.len(), 2);

let collections = client
.list_collections(ListCollectionsRequest::builder().limit(1).offset(1).build())
.await
.unwrap();
assert_eq!(collections.len(), 1);
assert_eq!(collections[0].collection.name, "second");
})
.await;
}

#[tokio::test]
#[test_log::test]
async fn test_live_cloud_create_collection() {
with_client(|client| async move {
let collection = client
.create_collection(
CreateCollectionRequest::builder()
.name("foo".to_string())
.build(),
)
.await
.unwrap();
assert_eq!(collection.collection.name, "foo");

client
.get_or_create_collection(
CreateCollectionRequest::builder()
.name("foo".to_string())
.build(),
)
.await
.unwrap();
})
.await;
}
Expand Down
4 changes: 2 additions & 2 deletions rust/chroma/src/client/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub struct ChromaClientOptions {
/// Will be automatically resolved at request time if not provided
pub tenant_id: Option<String>,
/// Will be automatically resolved at request time if not provided. It can only be resolved automatically if this client has access to exactly one database.
pub default_database_id: Option<String>,
pub default_database_name: Option<String>,
}

impl Default for ChromaClientOptions {
Expand All @@ -74,7 +74,7 @@ impl Default for ChromaClientOptions {
auth_method: ChromaAuthMethod::None,
retry_options: ChromaRetryOptions::default(),
tenant_id: None,
default_database_id: None,
default_database_name: None,
}
}
}
Expand Down
Loading
Loading