From c1a83d55f46abf37fcda36c461b499818ec19101 Mon Sep 17 00:00:00 2001
From: Ryan Olson
Date: Thu, 12 Dec 2024 05:03:51 -0700
Subject: [PATCH] prototype moving pub(crate) methods to ClientExt; extending Chat with a ChatExt using ClientExt and ClientProvider

---
 async-openai/Cargo.toml        |   1 +
 async-openai/src/chat.rs       |  62 +++++++++++++++++
 async-openai/src/client.rs     | 121 +++++++++++++++++++++------------
 async-openai/src/completion.rs |   2 +-
 async-openai/src/config.rs     |   2 +-
 async-openai/src/lib.rs        |   2 +-
 async-openai/src/runs.rs       |   1 +
 async-openai/src/threads.rs    |   1 +
 8 files changed, 145 insertions(+), 47 deletions(-)

diff --git a/async-openai/Cargo.toml b/async-openai/Cargo.toml
index 5c4cb94d..10872001 100644
--- a/async-openai/Cargo.toml
+++ b/async-openai/Cargo.toml
@@ -25,6 +25,7 @@ native-tls-vendored = ["reqwest/native-tls-vendored"]
 realtime = ["dep:tokio-tungstenite"]
 
 [dependencies]
+async-trait = "0.1"
 backoff = { version = "0.4.0", features = ["tokio"] }
 base64 = "0.22.1"
 futures = "0.3.30"
diff --git a/async-openai/src/chat.rs b/async-openai/src/chat.rs
index c7f9b962..da86bced 100644
--- a/async-openai/src/chat.rs
+++ b/async-openai/src/chat.rs
@@ -1,4 +1,5 @@
 use crate::{
+    client::{ClientExt, ClientProvider},
     config::Config,
     error::OpenAIError,
     types::{
@@ -52,3 +53,64 @@ impl<'c, C: Config> Chat<'c, C> {
         Ok(self.client.post_stream("/chat/completions", request).await)
     }
 }
+
+impl<'c, C: Config + Send> ClientProvider<'c, C> for Chat<'c, C> {
+    fn client(&self) -> &'c Client<C> {
+        self.client
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use crate::types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
+
+    use super::*;
+
+    #[async_trait::async_trait]
+    trait ChatExt {
+        async fn create_annotated_stream(
+            &self,
+            mut request: CreateChatCompletionRequest,
+        ) -> Result<ChatCompletionResponseStream, OpenAIError>;
+    }
+
+    #[async_trait::async_trait]
+    impl<'c, C: Config> ChatExt for Chat<'c, C> {
+        async fn create_annotated_stream(
+            &self,
+            mut request: CreateChatCompletionRequest,
+        ) -> Result<ChatCompletionResponseStream, OpenAIError> {
+            if request.stream.is_some() && !request.stream.unwrap() {
+                return Err(OpenAIError::InvalidArgument(
+                    "When stream is false, use Chat::create".into(),
+                ));
+            }
+
+            request.stream = Some(true);
+
+            Ok(self.client.post_stream("/chat/completions", request).await)
+        }
+    }
+
+    #[tokio::test]
+    async fn test() {
+        let client = Client::new();
+        let chat = client.chat();
+
+        let request = CreateChatCompletionRequestArgs::default()
+            .model("gpt-3.5-turbo")
+            .max_tokens(512u32)
+            .messages([ChatCompletionRequestUserMessageArgs::default()
+                .content(
+                    "Write a marketing blog praising and introducing Rust library async-openai",
+                )
+                .build()
+                .unwrap()
+                .into()])
+            .build()
+            .unwrap();
+
+        let _stream = chat.create_annotated_stream(request).await.unwrap();
+    }
+}
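The prototype test above only builds the stream and drops it. A minimal consumption sketch, assuming the test's ChatExt trait is in scope and futures::StreamExt is imported; the helper name print_stream and the printing logic are illustrative, not part of the patch:

use futures::StreamExt;

// Sketch: drain the stream returned by the prototype `create_annotated_stream`
// and print each streamed content delta as it arrives.
async fn print_stream<C: Config>(
    chat: &Chat<'_, C>,
    request: CreateChatCompletionRequest,
) -> Result<(), OpenAIError> {
    let mut stream = chat.create_annotated_stream(request).await?;
    while let Some(result) = stream.next().await {
        let response = result?;
        for choice in response.choices {
            if let Some(content) = choice.delta.content {
                print!("{content}");
            }
        }
    }
    Ok(())
}

Each item is the same Result<CreateChatCompletionStreamResponse, OpenAIError> that Chat::create_stream already yields, so existing consumer code should carry over unchanged.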
diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs
index 22c74a24..ecc511f1 100644
--- a/async-openai/src/client.rs
+++ b/async-openai/src/client.rs
@@ -377,50 +377,6 @@ impl<C: Config> Client<C> {
         Ok(response)
     }
 
-    /// Make HTTP POST request to receive SSE
-    pub(crate) async fn post_stream<I, O>(
-        &self,
-        path: &str,
-        request: I,
-    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
-    where
-        I: Serialize,
-        O: DeserializeOwned + std::marker::Send + 'static,
-    {
-        let event_source = self
-            .http_client
-            .post(self.config.url(path))
-            .query(&self.config.query())
-            .headers(self.config.headers())
-            .json(&request)
-            .eventsource()
-            .unwrap();
-
-        stream(event_source).await
-    }
-
-    pub(crate) async fn post_stream_mapped_raw_events<I, O>(
-        &self,
-        path: &str,
-        request: I,
-        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
-    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
-    where
-        I: Serialize,
-        O: DeserializeOwned + std::marker::Send + 'static,
-    {
-        let event_source = self
-            .http_client
-            .post(self.config.url(path))
-            .query(&self.config.query())
-            .headers(self.config.headers())
-            .json(&request)
-            .eventsource()
-            .unwrap();
-
-        stream_mapped_raw_events(event_source, event_mapper).await
-    }
-
     /// Make HTTP GET request to receive SSE
     pub(crate) async fn _get_stream<Q, O>(
         &self,
@@ -537,3 +493,80 @@ where
 
     Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
 }
+
+pub trait ClientProvider<'c, C: Config + Send> {
+    fn client(&self) -> &'c Client<C>;
+}
+
+#[async_trait::async_trait]
+pub trait ClientExt: Send {
+    /// Make HTTP POST request to receive SSE
+    async fn post_stream<I, O>(
+        &self,
+        path: &str,
+        request: I,
+    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
+    where
+        I: Serialize + Send,
+        O: DeserializeOwned + std::marker::Send + 'static;
+
+    /// Make HTTP POST request to receive SSE with a custom event source handler
+    async fn post_stream_mapped_raw_events<I, O>(
+        &self,
+        path: &str,
+        request: I,
+        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
+    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
+    where
+        I: Serialize + Send,
+        O: DeserializeOwned + std::marker::Send + 'static;
+}
+
+#[async_trait::async_trait]
+impl<C: Config> ClientExt for Client<C>
+where
+    C: Send,
+{
+    async fn post_stream<I, O>(
+        &self,
+        path: &str,
+        request: I,
+    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
+    where
+        I: Serialize + Send,
+        O: DeserializeOwned + std::marker::Send + 'static,
+    {
+        let event_source = self
+            .http_client
+            .post(self.config.url(path))
+            .query(&self.config.query())
+            .headers(self.config.headers())
+            .json(&request)
+            .eventsource()
+            .unwrap();
+
+        stream(event_source).await
+    }
+
+    async fn post_stream_mapped_raw_events<I, O>(
+        &self,
+        path: &str,
+        request: I,
+        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
+    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
+    where
+        I: Serialize + Send,
+        O: DeserializeOwned + std::marker::Send + 'static,
+    {
+        let event_source = self
+            .http_client
+            .post(self.config.url(path))
+            .query(&self.config.query())
+            .headers(self.config.headers())
+            .json(&request)
+            .eventsource()
+            .unwrap();
+
+        stream_mapped_raw_events(event_source, event_mapper).await
+    }
+}
diff --git a/async-openai/src/completion.rs b/async-openai/src/completion.rs
index 6e8468fd..1caa4957 100644
--- a/async-openai/src/completion.rs
+++ b/async-openai/src/completion.rs
@@ -1,5 +1,5 @@
 use crate::{
-    client::Client,
+    client::{Client, ClientExt},
     config::Config,
     error::OpenAIError,
     types::{CompletionResponseStream, CreateCompletionRequest, CreateCompletionResponse},
diff --git a/async-openai/src/config.rs b/async-openai/src/config.rs
index 91b3699a..f97073a2 100644
--- a/async-openai/src/config.rs
+++ b/async-openai/src/config.rs
@@ -15,7 +15,7 @@ pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
 
 /// [crate::Client] relies on this for every API call on OpenAI
 /// or Azure OpenAI service
-pub trait Config: Clone {
+pub trait Config: Clone + Send + Sync {
     fn headers(&self) -> HeaderMap;
     fn url(&self, path: &str) -> String;
     fn query(&self) -> Vec<(&str, &str)>;
diff --git a/async-openai/src/lib.rs b/async-openai/src/lib.rs
index c8a06edd..9ff720e5 100644
--- a/async-openai/src/lib.rs
+++ b/async-openai/src/lib.rs
@@ -83,7 +83,7 @@ mod audio;
 mod audit_logs;
 mod batches;
 mod chat;
-mod client;
+pub mod client;
 mod completion;
 pub mod config;
 mod download;
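With lib.rs now exposing pub mod client, downstream crates can reach ClientExt (and ClientProvider) directly. A minimal sketch of driving an SSE endpoint through the now-public trait, assuming the stock OpenAIConfig; the helper name raw_chat_stream, the hard-coded model, and the prompt are illustrative:

use async_openai::{
    client::ClientExt,
    config::OpenAIConfig,
    error::OpenAIError,
    types::{
        ChatCompletionRequestUserMessageArgs, ChatCompletionResponseStream,
        CreateChatCompletionRequestArgs,
    },
    Client,
};

// Sketch: call the now-public ClientExt::post_stream against a streaming
// endpoint; the output type `O` is fixed by the annotation on `stream`.
async fn raw_chat_stream(
    client: &Client<OpenAIConfig>,
) -> Result<ChatCompletionResponseStream, OpenAIError> {
    let request = CreateChatCompletionRequestArgs::default()
        .model("gpt-3.5-turbo")
        .stream(true)
        .messages([ChatCompletionRequestUserMessageArgs::default()
            .content("Say hello")
            .build()?
            .into()])
        .build()?;

    let stream: ChatCompletionResponseStream =
        client.post_stream("/chat/completions", request).await;
    Ok(stream)
}

Pinning down O with the ChatCompletionResponseStream annotation on the binding avoids a turbofish at the call site.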
diff --git a/async-openai/src/runs.rs b/async-openai/src/runs.rs
index a0d68152..65227b48 100644
--- a/async-openai/src/runs.rs
+++ b/async-openai/src/runs.rs
@@ -1,6 +1,7 @@
 use serde::Serialize;
 
 use crate::{
+    client::ClientExt,
     config::Config,
     error::OpenAIError,
     steps::Steps,
diff --git a/async-openai/src/threads.rs b/async-openai/src/threads.rs
index 31c2c6e0..1c5d2613 100644
--- a/async-openai/src/threads.rs
+++ b/async-openai/src/threads.rs
@@ -1,4 +1,5 @@
 use crate::{
+    client::ClientExt,
     config::Config,
     error::OpenAIError,
     types::{
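runs.rs and threads.rs only gain a use crate::client::ClientExt; import, so the streaming helpers in those modules keep resolving post_stream_mapped_raw_events once the method lives on the trait. For code outside the crate, a rough sketch of the mapper shape; the endpoint path, the MyRequest/MyEvent types, and the plain-JSON mapping are assumptions for illustration, and a downstream crate would need its own eventsource_stream dependency for the Event type:

use async_openai::{client::ClientExt, config::OpenAIConfig, error::OpenAIError, Client};
use futures::StreamExt;
use serde::{Deserialize, Serialize};

// Hypothetical request/response payloads for some streaming endpoint.
#[derive(Serialize)]
struct MyRequest {
    stream: bool,
}

#[derive(Debug, Deserialize)]
struct MyEvent {
    // fields elided for the sketch
}

// Sketch: every raw SSE event is handed to the closure, which decides how to
// turn it into an `O` (here: plain JSON deserialization of the data field).
async fn stream_custom(client: &Client<OpenAIConfig>) {
    let mut events = client
        .post_stream_mapped_raw_events(
            "/some/streaming/endpoint", // illustrative path, not a real route
            MyRequest { stream: true },
            |raw: eventsource_stream::Event| -> Result<MyEvent, OpenAIError> {
                serde_json::from_str(&raw.data).map_err(OpenAIError::JSONDeserialize)
            },
        )
        .await;

    while let Some(event) = events.next().await {
        println!("{event:?}");
    }
}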