11use  backon:: ExponentialBuilder ; 
22use  backon:: Retryable ; 
3+ use  chroma_types:: Collection ; 
34use  parking_lot:: Mutex ; 
45use  reqwest:: Method ; 
56use  reqwest:: StatusCode ; 
@@ -8,7 +9,10 @@ use std::sync::Arc;
89use  thiserror:: Error ; 
910
1011use  crate :: client:: ChromaClientOptions ; 
11- use  crate :: types:: { GetUserIdentityResponse ,  HeartbeatResponse ,  ListCollectionsRequest } ; 
12+ use  crate :: collection:: ChromaCollection ; 
13+ use  crate :: types:: { 
14+     CreateCollectionRequest ,  GetUserIdentityResponse ,  HeartbeatResponse ,  ListCollectionsRequest , 
15+ } ; 
1216
1317const  USER_AGENT :  & str  = concat ! ( 
1418    "Chroma Rust Client v" , 
@@ -32,7 +36,7 @@ pub struct ChromaClient {
3236    client :  reqwest:: Client , 
3337    retry_policy :  ExponentialBuilder , 
3438    tenant_id :  Arc < Mutex < Option < String > > > , 
35-     default_database_id :  Arc < Mutex < Option < String > > > , 
39+     default_database_name :  Arc < Mutex < Option < String > > > , 
3640    resolve_tenant_or_database_lock :  Arc < tokio:: sync:: Mutex < ( ) > > , 
3741    #[ cfg( feature = "opentelemetry" ) ]  
3842    metrics :  crate :: client:: metrics:: Metrics , 
@@ -45,7 +49,7 @@ impl Clone for ChromaClient {
4549            client :  self . client . clone ( ) , 
4650            retry_policy :  self . retry_policy , 
4751            tenant_id :  Arc :: new ( Mutex :: new ( self . tenant_id . lock ( ) . clone ( ) ) ) , 
48-             default_database_id :  Arc :: new ( Mutex :: new ( self . default_database_id . lock ( ) . clone ( ) ) ) , 
52+             default_database_name :  Arc :: new ( Mutex :: new ( self . default_database_name . lock ( ) . clone ( ) ) ) , 
4953            resolve_tenant_or_database_lock :  Arc :: new ( tokio:: sync:: Mutex :: new ( ( ) ) ) , 
5054            #[ cfg( feature = "opentelemetry" ) ]  
5155            metrics :  self . metrics . clone ( ) , 
@@ -76,16 +80,16 @@ impl ChromaClient {
7680            client, 
7781            retry_policy :  options. retry_options . into ( ) , 
7882            tenant_id :  Arc :: new ( Mutex :: new ( options. tenant_id ) ) , 
79-             default_database_id :  Arc :: new ( Mutex :: new ( options. default_database_id ) ) , 
83+             default_database_name :  Arc :: new ( Mutex :: new ( options. default_database_name ) ) , 
8084            resolve_tenant_or_database_lock :  Arc :: new ( tokio:: sync:: Mutex :: new ( ( ) ) ) , 
8185            #[ cfg( feature = "opentelemetry" ) ]  
8286            metrics :  crate :: client:: metrics:: Metrics :: new ( ) , 
8387        } 
8488    } 
8589
86-     pub  fn  set_default_database_id ( & self ,  database_id :  String )  { 
87-         let  mut  lock = self . default_database_id . lock ( ) ; 
88-         * lock = Some ( database_id ) ; 
90+     pub  fn  set_default_database_name ( & self ,  database_name :  String )  { 
91+         let  mut  lock = self . default_database_name . lock ( ) ; 
92+         * lock = Some ( database_name ) ; 
8993    } 
9094
9195    pub  async  fn  create_database ( & self ,  name :  String )  -> Result < ( ) ,  ChromaClientError >  { 
@@ -155,79 +159,135 @@ impl ChromaClient {
155159        . await 
156160    } 
157161
162+     pub  async  fn  get_or_create_collection ( 
163+         & self , 
164+         params :  CreateCollectionRequest , 
165+     )  -> Result < ChromaCollection ,  ChromaClientError >  { 
166+         self . common_create_collection ( params,  true ) . await 
167+     } 
168+ 
169+     pub  async  fn  create_collection ( 
170+         & self , 
171+         params :  CreateCollectionRequest , 
172+     )  -> Result < ChromaCollection ,  ChromaClientError >  { 
173+         self . common_create_collection ( params,  false ) . await 
174+     } 
175+ 
158176    pub  async  fn  list_collections ( 
159177        & self , 
160178        params :  ListCollectionsRequest , 
161-     )  -> Result < Vec < String > ,  ChromaClientError >  { 
179+     )  -> Result < Vec < ChromaCollection > ,  ChromaClientError >  { 
162180        let  tenant_id = self . get_tenant_id ( ) . await ?; 
163-         let  database_id  = self . get_database_id ( params. database_id ) . await ?; 
181+         let  database_name  = self . get_database_name ( params. database_name ) . await ?; 
164182
165183        #[ derive( Serialize ) ]  
166184        struct  QueryParams  { 
167185            limit :  usize , 
168186            offset :  Option < usize > , 
169187        } 
170188
171-         self . send :: < ( ) ,  _ ,  _ > ( 
172-             "list_collections" , 
173-             Method :: GET , 
174-             format ! ( 
175-                 "/api/v2/tenants/{}/databases/{}/collections" , 
176-                 tenant_id,  database_id
177-             ) , 
178-             None , 
179-             Some ( QueryParams  { 
180-                 limit :  params. limit , 
181-                 offset :  params. offset , 
182-             } ) , 
183-         ) 
184-         . await 
189+         let  collections = self 
190+             . send :: < ( ) ,  _ ,  Vec < Collection > > ( 
191+                 "list_collections" , 
192+                 Method :: GET , 
193+                 format ! ( 
194+                     "/api/v2/tenants/{}/databases/{}/collections" , 
195+                     tenant_id,  database_name
196+                 ) , 
197+                 None , 
198+                 Some ( QueryParams  { 
199+                     limit :  params. limit , 
200+                     offset :  params. offset , 
201+                 } ) , 
202+             ) 
203+             . await ?; 
204+ 
205+         Ok ( collections
206+             . into_iter ( ) 
207+             . map ( |collection| ChromaCollection  { 
208+                 client :  self . clone ( ) , 
209+                 collection :  Arc :: new ( collection) , 
210+             } ) 
211+             . collect ( ) ) 
212+     } 
213+ 
214+     async  fn  common_create_collection ( 
215+         & self , 
216+         params :  CreateCollectionRequest , 
217+         get_or_create :  bool , 
218+     )  -> Result < ChromaCollection ,  ChromaClientError >  { 
219+         let  tenant_id = self . get_tenant_id ( ) . await ?; 
220+         let  database_name = self . get_database_name ( params. database_name ) . await ?; 
221+ 
222+         let  collection:  chroma_types:: Collection  = self 
223+             . send ( 
224+                 "create_collection" , 
225+                 Method :: POST , 
226+                 format ! ( 
227+                     "/api/v2/tenants/{}/databases/{}/collections" , 
228+                     tenant_id,  database_name
229+                 ) , 
230+                 Some ( serde_json:: json!( { 
231+                     "name" :  params. name, 
232+                     "configuration" :  params. configuration, 
233+                     "metadata" :  params. metadata, 
234+                     "get_or_create" :  get_or_create, 
235+                 } ) ) , 
236+                 None :: < ( ) > , 
237+             ) 
238+             . await ?; 
239+ 
240+         Ok ( ChromaCollection  { 
241+             client :  self . clone ( ) , 
242+             collection :  Arc :: new ( collection) , 
243+         } ) 
185244    } 
186245
187-     async  fn  get_database_id ( 
246+     async  fn  get_database_name ( 
188247        & self , 
189-         id_override :  Option < String > , 
248+         name_override :  Option < String > , 
190249    )  -> Result < String ,  ChromaClientError >  { 
191-         if  let  Some ( id)  = id_override  { 
250+         if  let  Some ( id)  = name_override  { 
192251            return  Ok ( id) ; 
193252        } 
194253
195254        { 
196-             let  database_id_lock  = self . default_database_id . lock ( ) ; 
197-             if  let  Some ( database_id )  = & * database_id_lock  { 
198-                 return  Ok ( database_id . clone ( ) ) ; 
255+             let  database_name_lock  = self . default_database_name . lock ( ) ; 
256+             if  let  Some ( database_name )  = & * database_name_lock  { 
257+                 return  Ok ( database_name . clone ( ) ) ; 
199258            } 
200259        } 
201260
202261        let  _guard = self . resolve_tenant_or_database_lock . lock ( ) . await ; 
203262
204263        { 
205-             let  database_id_lock  = self . default_database_id . lock ( ) ; 
206-             if  let  Some ( database_id )  = & * database_id_lock  { 
207-                 return  Ok ( database_id . clone ( ) ) ; 
264+             let  database_name_lock  = self . default_database_name . lock ( ) ; 
265+             if  let  Some ( database_name )  = & * database_name_lock  { 
266+                 return  Ok ( database_name . clone ( ) ) ; 
208267            } 
209268        } 
210269
211270        let  identity = self . get_auth_identity ( ) . await ?; 
212271
213272        if  identity. databases . len ( )  > 1  { 
214273            return  Err ( ChromaClientError :: CouldNotResolveDatabaseId ( 
215-                 "Client has access to multiple databases; please provide a database_id" . to_string ( ) , 
274+                 "Client has access to multiple databases; please provide a database_name" 
275+                     . to_string ( ) , 
216276            ) ) ; 
217277        } 
218278
219-         let  database_id  = identity. databases . first ( ) . ok_or_else ( || { 
279+         let  database_name  = identity. databases . first ( ) . ok_or_else ( || { 
220280            ChromaClientError :: CouldNotResolveDatabaseId ( 
221281                "Client has access to no databases" . to_string ( ) , 
222282            ) 
223283        } ) ?; 
224284
225285        { 
226-             let  mut  database_id_lock  = self . default_database_id . lock ( ) ; 
227-             * database_id_lock  = Some ( database_id . clone ( ) ) ; 
286+             let  mut  database_name_lock  = self . default_database_name . lock ( ) ; 
287+             * database_name_lock  = Some ( database_name . clone ( ) ) ; 
228288        } 
229289
230-         Ok ( database_id . clone ( ) ) 
290+         Ok ( database_name . clone ( ) ) 
231291    } 
232292
233293    async  fn  get_tenant_id ( & self )  -> Result < String ,  ChromaClientError >  { 
@@ -287,7 +347,7 @@ impl ChromaClient {
287347            #[ cfg( feature = "opentelemetry" ) ]  
288348            let  started_at = std:: time:: Instant :: now ( ) ; 
289349
290-             let  response = request. send ( ) . await ?; 
350+             let  response = request. send ( ) . await . map_err ( |err|  ( err ,   None ) ) ?; 
291351
292352            #[ cfg( feature = "opentelemetry" ) ]  
293353            { 
@@ -302,13 +362,16 @@ impl ChromaClient {
302362                let  _ = operation_name; 
303363            } 
304364
305-             response. error_for_status_ref ( ) ?; 
306-             Ok :: < _ ,  reqwest:: Error > ( response) 
365+             if  let  Err ( err)  = response. error_for_status_ref ( )  { 
366+                 return  Err ( ( err,  Some ( response) ) ) ; 
367+             } 
368+ 
369+             Ok :: < reqwest:: Response ,  ( reqwest:: Error ,  Option < reqwest:: Response > ) > ( response) 
307370        } ; 
308371
309372        let  response = attempt
310373            . retry ( & self . retry_policy ) 
311-             . notify ( |err,  _| { 
374+             . notify ( |( err,  _ ) ,  _| { 
312375                tracing:: warn!( 
313376                    url = %url, 
314377                    method =? method, 
@@ -319,13 +382,33 @@ impl ChromaClient {
319382                #[ cfg( feature = "opentelemetry" ) ]  
320383                self . metrics . increment_retry ( operation_name) ; 
321384            } ) 
322-             . when ( |err| { 
385+             . when ( |( err,  _ ) | { 
323386                err. status ( ) 
324387                    . map ( |status| status == StatusCode :: TOO_MANY_REQUESTS ) 
325388                    . unwrap_or_default ( ) 
326389                    || method == Method :: GET 
327390            } ) 
328-             . await ?; 
391+             . await ; 
392+ 
393+         let  response = match  response { 
394+             Ok ( response)  => response, 
395+             Err ( ( err,  maybe_response) )  => { 
396+                 if  let  Some ( response)  = maybe_response { 
397+                     let  json = response. json :: < serde_json:: Value > ( ) . await ?; 
398+ 
399+                     if  tracing:: enabled!( tracing:: Level :: TRACE )  { 
400+                         tracing:: trace!( 
401+                             url = %url, 
402+                             method =? method, 
403+                             "Received response: {}" , 
404+                             serde_json:: to_string_pretty( & json) . unwrap_or_else( |_| "<failed to serialize>" . to_string( ) ) 
405+                         ) ; 
406+                     } 
407+                 } 
408+ 
409+                 return  Err ( ChromaClientError :: RequestError ( err) ) ; 
410+             } 
411+         } ; 
329412
330413        let  json = response. json :: < serde_json:: Value > ( ) . await ?; 
331414
@@ -389,22 +472,15 @@ mod tests {
389472        // Create isolated database for test 
390473        let  database_name = format ! ( "test_db_{}" ,  uuid:: Uuid :: new_v4( ) ) ; 
391474        client. create_database ( database_name. clone ( ) ) . await . unwrap ( ) ; 
392-         let  databases = client. list_databases ( ) . await . unwrap ( ) ; 
393-         let  database_id = databases
394-             . iter ( ) 
395-             . find ( |db| db. name  == database_name) 
396-             . unwrap ( ) 
397-             . id 
398-             . clone ( ) ; 
399-         client. set_default_database_id ( database_id. clone ( ) ) ; 
475+         client. set_default_database_name ( database_name. clone ( ) ) ; 
400476
401477        let  result = std:: panic:: AssertUnwindSafe ( callback ( client. clone ( ) ) ) 
402478            . catch_unwind ( ) 
403479            . await ; 
404480
405481        // Delete test database 
406-         if  let  Err ( err)  = client. delete_database ( database_name) . await  { 
407-             tracing:: error!( "Failed to delete test database {}: {}" ,  database_id ,  err) ; 
482+         if  let  Err ( err)  = client. delete_database ( database_name. clone ( ) ) . await  { 
483+             tracing:: error!( "Failed to delete test database {}: {}" ,  database_name ,  err) ; 
408484        } 
409485
410486        result. unwrap ( ) ; 
@@ -548,7 +624,62 @@ mod tests {
548624                . unwrap ( ) ; 
549625            assert ! ( collections. is_empty( ) ) ; 
550626
551-             // todo: create collection and assert it's returned, test limit/offset 
627+             client
628+                 . create_collection ( 
629+                     CreateCollectionRequest :: builder ( ) 
630+                         . name ( "first" . to_string ( ) ) 
631+                         . build ( ) , 
632+                 ) 
633+                 . await 
634+                 . unwrap ( ) ; 
635+ 
636+             client
637+                 . create_collection ( 
638+                     CreateCollectionRequest :: builder ( ) 
639+                         . name ( "second" . to_string ( ) ) 
640+                         . build ( ) , 
641+                 ) 
642+                 . await 
643+                 . unwrap ( ) ; 
644+ 
645+             let  collections = client
646+                 . list_collections ( ListCollectionsRequest :: builder ( ) . build ( ) ) 
647+                 . await 
648+                 . unwrap ( ) ; 
649+             assert_eq ! ( collections. len( ) ,  2 ) ; 
650+ 
651+             let  collections = client
652+                 . list_collections ( ListCollectionsRequest :: builder ( ) . limit ( 1 ) . offset ( 1 ) . build ( ) ) 
653+                 . await 
654+                 . unwrap ( ) ; 
655+             assert_eq ! ( collections. len( ) ,  1 ) ; 
656+             assert_eq ! ( collections[ 0 ] . collection. name,  "second" ) ; 
657+         } ) 
658+         . await ; 
659+     } 
660+ 
661+     #[ tokio:: test]  
662+     #[ test_log:: test]  
663+     async  fn  test_live_cloud_create_collection ( )  { 
664+         with_client ( |client| async  move  { 
665+             let  collection = client
666+                 . create_collection ( 
667+                     CreateCollectionRequest :: builder ( ) 
668+                         . name ( "foo" . to_string ( ) ) 
669+                         . build ( ) , 
670+                 ) 
671+                 . await 
672+                 . unwrap ( ) ; 
673+             assert_eq ! ( collection. collection. name,  "foo" ) ; 
674+ 
675+             client
676+                 . get_or_create_collection ( 
677+                     CreateCollectionRequest :: builder ( ) 
678+                         . name ( "foo" . to_string ( ) ) 
679+                         . build ( ) , 
680+                 ) 
681+                 . await 
682+                 . unwrap ( ) ; 
552683        } ) 
553684        . await ; 
554685    } 
0 commit comments