
Commit eaa396b

feat: sync audio api + fix error types (#464)
* updates to CreateSpeechRequest
* types::audio; streaming audio types and api
* fix imports in the examples
* fix import
* fix api; implement trait
* add audio-speech-stream example
* updated CreateTranscriptionRequest
* updates for CreateTranscriptionResponseJson
* updated CreateTranscriptionResponseVerboseJson
* updated CreateTranscriptionResponseDiarizedJson
* update types for diarized
* update transcription example
* streaming from form submission
* add streaming example to audio-transcribe
* updates for translation
* update to example
* update audio api groups
* update audio examples
* fix webhooks error type
* fix errors reported by clippy
* fix based on clippy
* fix for: https://rust-lang.github.io/rust-clippy/master/index.html#result_large_err
* cargo fmt
1 parent 8e9639a commit eaa396b

27 files changed (+1233, -472 lines)

async-openai/src/audio.rs

Lines changed: 10 additions & 93 deletions
```diff
@@ -1,16 +1,4 @@
-use bytes::Bytes;
-
-use crate::{
-    config::Config,
-    error::OpenAIError,
-    types::{
-        CreateSpeechRequest, CreateSpeechResponse, CreateTranscriptionRequest,
-        CreateTranscriptionResponseJson, CreateTranscriptionResponseVerboseJson,
-        CreateTranslationRequest, CreateTranslationResponseJson,
-        CreateTranslationResponseVerboseJson,
-    },
-    Client,
-};
+use crate::{config::Config, Client, Speech, Transcriptions, Translations};
 
 /// Turn audio into text or text into audio.
 /// Related guide: [Speech to text](https://platform.openai.com/docs/guides/speech-to-text)
@@ -23,89 +11,18 @@ impl<'c, C: Config> Audio<'c, C> {
         Self { client }
     }
 
-    /// Transcribes audio into the input language.
-    #[crate::byot(
-        T0 = Clone,
-        R = serde::de::DeserializeOwned,
-        where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
-    )]
-    pub async fn transcribe(
-        &self,
-        request: CreateTranscriptionRequest,
-    ) -> Result<CreateTranscriptionResponseJson, OpenAIError> {
-        self.client
-            .post_form("/audio/transcriptions", request)
-            .await
-    }
-
-    /// Transcribes audio into the input language.
-    #[crate::byot(
-        T0 = Clone,
-        R = serde::de::DeserializeOwned,
-        where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
-    )]
-    pub async fn transcribe_verbose_json(
-        &self,
-        request: CreateTranscriptionRequest,
-    ) -> Result<CreateTranscriptionResponseVerboseJson, OpenAIError> {
-        self.client
-            .post_form("/audio/transcriptions", request)
-            .await
-    }
-
-    /// Transcribes audio into the input language.
-    pub async fn transcribe_raw(
-        &self,
-        request: CreateTranscriptionRequest,
-    ) -> Result<Bytes, OpenAIError> {
-        self.client
-            .post_form_raw("/audio/transcriptions", request)
-            .await
+    /// APIs in Speech group.
+    pub fn speech(&self) -> Speech<'_, C> {
+        Speech::new(self.client)
     }
 
-    /// Translates audio into English.
-    #[crate::byot(
-        T0 = Clone,
-        R = serde::de::DeserializeOwned,
-        where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
-    )]
-    pub async fn translate(
-        &self,
-        request: CreateTranslationRequest,
-    ) -> Result<CreateTranslationResponseJson, OpenAIError> {
-        self.client.post_form("/audio/translations", request).await
+    /// APIs in Transcription group.
+    pub fn transcription(&self) -> Transcriptions<'_, C> {
+        Transcriptions::new(self.client)
    }
 
-    /// Translates audio into English.
-    #[crate::byot(
-        T0 = Clone,
-        R = serde::de::DeserializeOwned,
-        where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
-    )]
-    pub async fn translate_verbose_json(
-        &self,
-        request: CreateTranslationRequest,
-    ) -> Result<CreateTranslationResponseVerboseJson, OpenAIError> {
-        self.client.post_form("/audio/translations", request).await
-    }
-
-    /// Transcribes audio into the input language.
-    pub async fn translate_raw(
-        &self,
-        request: CreateTranslationRequest,
-    ) -> Result<Bytes, OpenAIError> {
-        self.client
-            .post_form_raw("/audio/translations", request)
-            .await
-    }
-
-    /// Generates audio from the input text.
-    pub async fn speech(
-        &self,
-        request: CreateSpeechRequest,
-    ) -> Result<CreateSpeechResponse, OpenAIError> {
-        let bytes = self.client.post_raw("/audio/speech", request).await?;
-
-        Ok(CreateSpeechResponse { bytes })
+    /// APIs in Translation group.
+    pub fn translation(&self) -> Translations<'_, C> {
+        Translations::new(self.client)
     }
 }
```
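For orientation, a minimal caller-side sketch of the regrouped API. `Client::audio()`, the `OpenAIConfig` type, and the external `types::audio` path are assumptions not shown in this diff; only the grouped accessors and `Speech::create` come from this commit.

```rust
// Sketch only: Client::audio(), OpenAIConfig, and the external `types::audio`
// path are assumed; the grouped accessor chain is what this commit introduces.
use async_openai::{config::OpenAIConfig, types::audio::CreateSpeechRequest, Client};

async fn synthesize(
    client: &Client<OpenAIConfig>,
    request: CreateSpeechRequest,
) -> Result<(), async_openai::error::OpenAIError> {
    // Previously: client.audio().speech(request).await?
    // Now the endpoint lives in the Speech group returned by Audio::speech().
    let response = client.audio().speech().create(request).await?;

    // CreateSpeechResponse still carries the raw audio bytes.
    std::fs::write("speech.mp3", &response.bytes).ok();
    Ok(())
}
```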

async-openai/src/client.rs

Lines changed: 74 additions & 1 deletion
```diff
@@ -351,6 +351,79 @@ impl<C: Config> Client<C> {
         self.execute(request_maker).await
     }
 
+    pub(crate) async fn post_form_stream<O, F>(
+        &self,
+        path: &str,
+        form: F,
+    ) -> Result<Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>, OpenAIError>
+    where
+        F: Clone,
+        Form: AsyncTryFrom<F, Error = OpenAIError>,
+        O: DeserializeOwned + std::marker::Send + 'static,
+    {
+        // Build and execute request manually since multipart::Form is not Clone
+        // and .eventsource() requires cloneability
+        let response = self
+            .http_client
+            .post(self.config.url(path))
+            .query(&self.config.query())
+            .multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
+            .headers(self.config.headers())
+            .send()
+            .await
+            .map_err(OpenAIError::Reqwest)?;
+
+        // Check for error status
+        if !response.status().is_success() {
+            return Err(read_response(response).await.unwrap_err());
+        }
+
+        // Convert response body to EventSource stream
+        let stream = response
+            .bytes_stream()
+            .map(|result| result.map_err(std::io::Error::other));
+        let event_stream = eventsource_stream::EventStream::new(stream);
+
+        // Convert EventSource stream to our expected format
+        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
+
+        tokio::spawn(async move {
+            use futures::StreamExt;
+            let mut event_stream = std::pin::pin!(event_stream);
+
+            while let Some(event_result) = event_stream.next().await {
+                match event_result {
+                    Err(e) => {
+                        if let Err(_e) = tx.send(Err(OpenAIError::StreamError(Box::new(
+                            StreamError::EventStream(e.to_string()),
+                        )))) {
+                            break;
+                        }
+                    }
+                    Ok(event) => {
+                        // eventsource_stream::Event is a struct with data field
+                        if event.data == "[DONE]" {
+                            break;
+                        }
+
+                        let response = match serde_json::from_str::<O>(&event.data) {
+                            Err(e) => Err(map_deserialization_error(e, event.data.as_bytes())),
+                            Ok(output) => Ok(output),
+                        };
+
+                        if let Err(_e) = tx.send(response) {
+                            break;
+                        }
+                    }
+                }
+            }
+        });
+
+        Ok(Box::pin(
+            tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
+        ))
+    }
+
     /// Execute a HTTP request and retry on rate limit
     ///
     /// request_maker serves one purpose: to be able to create request again
@@ -524,7 +597,7 @@ async fn map_stream_error(value: EventSourceError) -> OpenAIError {
                 "Unreachable because read_response returns err when status_code {status_code} is invalid"
             ))
         }
-        _ => OpenAIError::StreamError(StreamError::ReqwestEventSource(value)),
+        _ => OpenAIError::StreamError(Box::new(StreamError::ReqwestEventSource(value))),
     }
 }
 
```
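The helper above is `pub(crate)`, so it cannot be called from outside the crate. The standalone sketch below only restates the pattern it uses (parse an SSE byte stream with `eventsource_stream`, relay deserialized events to the caller over a Tokio channel), with error handling collapsed to `String`; it assumes the `bytes`, `eventsource_stream`, `futures`, `serde_json`, `tokio`, and `tokio-stream` crates as dependencies.

```rust
// Pattern sketch, not crate code: SSE bytes -> typed items via a relay task.
use std::pin::Pin;

use eventsource_stream::EventStream;
use futures::{Stream, StreamExt};
use serde::de::DeserializeOwned;
use tokio_stream::wrappers::UnboundedReceiverStream;

fn sse_to_typed_stream<S, O>(bytes: S) -> Pin<Box<dyn Stream<Item = Result<O, String>> + Send>>
where
    S: Stream<Item = Result<bytes::Bytes, std::io::Error>> + Send + 'static,
    O: DeserializeOwned + Send + 'static,
{
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

    // A detached task pulls SSE events and forwards parsed items, giving the
    // caller an owned, Send stream even though the parser is pinned on the
    // task's stack.
    tokio::spawn(async move {
        let mut events = std::pin::pin!(EventStream::new(bytes));
        while let Some(event) = events.next().await {
            let item = match event {
                Err(e) => Err(e.to_string()),
                Ok(ev) if ev.data == "[DONE]" => break, // end-of-stream sentinel
                Ok(ev) => serde_json::from_str::<O>(&ev.data).map_err(|e| e.to_string()),
            };
            if tx.send(item).is_err() {
                break; // receiver dropped, stop relaying
            }
        }
    });

    Box::pin(UnboundedReceiverStream::new(rx))
}
```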

async-openai/src/error.rs

Lines changed: 4 additions & 1 deletion
```diff
@@ -21,7 +21,7 @@ pub enum OpenAIError {
     FileReadError(String),
     /// Error on SSE streaming
     #[error("stream failed: {0}")]
-    StreamError(StreamError),
+    StreamError(Box<StreamError>),
     /// Error from client side validation
     /// or when builder fails to build request before making API call
     #[error("invalid args: {0}")]
@@ -36,6 +36,9 @@ pub enum StreamError {
     /// Error when a stream event does not match one of the expected values
     #[error("Unknown event: {0:#?}")]
     UnknownEvent(eventsource_stream::Event),
+    /// Error from eventsource_stream when parsing SSE
+    #[error("EventStream error: {0}")]
+    EventStream(String),
 }
 
 /// OpenAI API returns error object on failure
```
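Boxing the payload keeps `OpenAIError` (and therefore every `Result<T, OpenAIError>`) small, which is what `clippy::result_large_err` flagged. A hedged caller-side sketch of what changes: matching now goes through one extra dereference (module paths assume the published crate layout).

```rust
// Sketch of matching against the now-boxed stream error variant.
use async_openai::error::{OpenAIError, StreamError};

fn describe(err: &OpenAIError) -> String {
    match err {
        // The variant holds Box<StreamError>, so match through as_ref().
        OpenAIError::StreamError(boxed) => match boxed.as_ref() {
            StreamError::EventStream(msg) => format!("SSE parse failure: {msg}"),
            other => format!("stream failed: {other}"),
        },
        other => other.to_string(),
    }
}
```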

async-openai/src/lib.rs

Lines changed: 6 additions & 0 deletions
```diff
@@ -168,9 +168,12 @@ mod project_users;
 mod projects;
 mod responses;
 mod runs;
+mod speech;
 mod steps;
 mod threads;
 pub mod traits;
+mod transcriptions;
+mod translations;
 pub mod types;
 mod uploads;
 mod users;
@@ -207,8 +210,11 @@ pub use project_users::ProjectUsers;
 pub use projects::Projects;
 pub use responses::Responses;
 pub use runs::Runs;
+pub use speech::Speech;
 pub use steps::Steps;
 pub use threads::Threads;
+pub use transcriptions::Transcriptions;
+pub use translations::Translations;
 pub use uploads::Uploads;
 pub use users::Users;
 pub use vector_store_file_batches::VectorStoreFileBatches;
```
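With the new modules re-exported at the crate root, the group handles can be named directly in downstream signatures. A small sketch; the `OpenAIConfig` path is an assumption, while the root re-exports and the `<'c, C>` shape come from this commit.

```rust
// Sketch: naming the new group handles via the crate-root re-exports.
use async_openai::{config::OpenAIConfig, Speech, Transcriptions, Translations};

#[allow(dead_code)]
struct AudioGroups<'c> {
    // Handles borrow the Client; obtain them via client.audio().speech(),
    // .transcription(), and .translation().
    speech: Speech<'c, OpenAIConfig>,
    transcriptions: Transcriptions<'c, OpenAIConfig>,
    translations: Translations<'c, OpenAIConfig>,
}
```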

async-openai/src/speech.rs

Lines changed: 54 additions & 0 deletions
```diff
@@ -0,0 +1,54 @@
+use crate::{
+    config::Config,
+    error::OpenAIError,
+    types::audio::{CreateSpeechRequest, CreateSpeechResponse, SpeechResponseStream},
+    Client,
+};
+
+pub struct Speech<'c, C: Config> {
+    client: &'c Client<C>,
+}
+
+impl<'c, C: Config> Speech<'c, C> {
+    pub fn new(client: &'c Client<C>) -> Self {
+        Self { client }
+    }
+
+    /// Generates audio from the input text.
+    pub async fn create(
+        &self,
+        request: CreateSpeechRequest,
+    ) -> Result<CreateSpeechResponse, OpenAIError> {
+        let bytes = self.client.post_raw("/audio/speech", request).await?;
+
+        Ok(CreateSpeechResponse { bytes })
+    }
+
+    /// Generates audio from the input text in SSE stream format.
+    #[crate::byot(
+        T0 = serde::Serialize,
+        R = serde::de::DeserializeOwned,
+        stream = "true",
+        where_clause = "R: std::marker::Send + 'static"
+    )]
+    #[allow(unused_mut)]
+    pub async fn create_stream(
+        &self,
+        mut request: CreateSpeechRequest,
+    ) -> Result<SpeechResponseStream, OpenAIError> {
+        #[cfg(not(feature = "byot"))]
+        {
+            use crate::types::audio::StreamFormat;
+            if let Some(stream_format) = request.stream_format {
+                if stream_format != StreamFormat::SSE {
+                    return Err(OpenAIError::InvalidArgument(
+                        "When stream_format is not SSE, use Audio::speech".into(),
+                    ));
+                }
+            }
+
+            request.stream_format = Some(StreamFormat::SSE);
+        }
+        Ok(self.client.post_stream("/audio/speech", request).await)
+    }
+}
```
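A hedged consumption sketch for the new streaming endpoint. `Client::audio()`, the `OpenAIConfig` path, and the assumption that `SpeechResponseStream` yields unpinnable `Result` items whose payload implements `Debug` are not shown in this diff; `Speech::create_stream` and its SSE-only guard are.

```rust
// Sketch only: the event payload type behind SpeechResponseStream is not part
// of this diff, so items are just debug-printed.
use async_openai::{config::OpenAIConfig, types::audio::CreateSpeechRequest, Client};
use futures::StreamExt;

async fn stream_speech(
    client: &Client<OpenAIConfig>,
    request: CreateSpeechRequest,
) -> Result<(), async_openai::error::OpenAIError> {
    // If request.stream_format is unset it is forced to SSE; any other value
    // is rejected with InvalidArgument before the request is sent.
    let mut stream = client.audio().speech().create_stream(request).await?;

    while let Some(event) = stream.next().await {
        match event {
            Ok(chunk) => println!("{chunk:?}"), // e.g. an audio delta or done event
            Err(e) => eprintln!("stream error: {e}"),
        }
    }
    Ok(())
}
```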
