diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs index 1af00e1b..7404e503 100644 --- a/async-openai/src/client.rs +++ b/async-openai/src/client.rs @@ -1,4 +1,4 @@ -use std::pin::Pin; +use std::{pin::Pin, time::Duration}; use bytes::Bytes; use futures::{stream::StreamExt, Stream}; @@ -360,6 +360,13 @@ impl Client { .map_err(OpenAIError::Reqwest) .map_err(backoff::Error::Permanent)?; + let retry_after = response + .headers() + .get("retry-after") + .and_then(|h| h.to_str().ok()) + .and_then(|s| s.parse::().ok()) + .map(Duration::from_secs); + let status = response.status(); match read_response(response).await { @@ -370,7 +377,7 @@ impl Client { if status.is_server_error() { Err(backoff::Error::Transient { err: OpenAIError::ApiError(api_error), - retry_after: None, + retry_after, }) } else if status.as_u16() == 429 && api_error.r#type != Some("insufficient_quota".to_string()) @@ -379,7 +386,7 @@ impl Client { tracing::warn!("Rate limited: {}", api_error.message); Err(backoff::Error::Transient { err: OpenAIError::ApiError(api_error), - retry_after: None, + retry_after, }) } else { Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))