feat: Add retry on rate limit errors #1095

Closed · wants to merge 3 commits
51 changes: 49 additions & 2 deletions crates/goose/src/agents/truncate.rs
@@ -1,5 +1,6 @@
/// A truncate agent that truncates the conversation history when it exceeds the model's context limit
/// It makes no attempt to handle context limits, and cannot read resources
use std::time::Duration;

use async_trait::async_trait;
use futures::stream::BoxStream;
use tokio::sync::Mutex;
@@ -8,6 +9,7 @@ use tracing::{debug, error, instrument, warn};
use super::Agent;
use crate::agents::capabilities::Capabilities;
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
use crate::config::Config;
use crate::message::{Message, ToolRequest};
use crate::providers::base::Provider;
use crate::providers::base::ProviderUsage;
@@ -22,18 +24,32 @@ use serde_json::{json, Value};
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
const ESTIMATE_FACTOR_DECAY: f32 = 0.9;

// Multiplier to use in exponential backoff for retries on rate limit errors.
// Backoff duration will be calculated via `2**RETRY_COUNT * GOOSE_AGENT_RETRY_MULTIPLIER_MS`.
// So, for example, if you set `GOOSE_AGENT_RETRY_MULTIPLIER_MS` to `1000` (1s),
// requests will be retried after `2000` (2s), `4000` (4s), `8000` (8s), ... up to `MAX_RETRIES` times.
const GOOSE_AGENT_RETRY_MULTIPLIER_MS: &str = "GOOSE_AGENT_RETRY_MULTIPLIER_MS";
const MAX_RETRIES: u32 = 5;

/// Truncate implementation of an Agent
pub struct TruncateAgent {
    capabilities: Mutex<Capabilities>,
    token_counter: TokenCounter,
    // If this is set, `ProviderError::RateLimitExceeded` errors will be retried
    // with an exponential backoff using this number as the multiplier, up to `MAX_RETRIES` times.
    retry_multiplier: Option<u64>,
}

impl TruncateAgent {
    pub fn new(provider: Box<dyn Provider>) -> Self {
        let config = Config::global();
        let retry_multiplier = config.get::<u64>(GOOSE_AGENT_RETRY_MULTIPLIER_MS).ok();

        let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
        Self {
            capabilities: Mutex::new(Capabilities::new(provider)),
            token_counter,
            retry_multiplier,
        }
    }

@@ -131,6 +147,7 @@ impl Agent for TruncateAgent {
        let mut capabilities = self.capabilities.lock().await;
        let mut tools = capabilities.get_prefixed_tools().await?;
        let mut truncation_attempt: usize = 0;
        let mut retry_count: u32 = 0;

        // we add in the read_resource tool by default
        // TODO: make sure there is no collision with another extension's tool name
@@ -200,8 +217,9 @@ Ok((response, usage)) => {
                    Ok((response, usage)) => {
                        capabilities.record_usage(usage).await;

                        // Reset truncation attempt
                        // Reset truncation and retry attempts
                        truncation_attempt = 0;
                        retry_count = 0;

                        // Yield the assistant's response
                        yield response.clone();
@@ -267,6 +285,35 @@ break;
                            break;
                        }

                        // Re-acquire the lock
                        capabilities = self.capabilities.lock().await;

                        // Retry the loop after truncation
                        continue;
                    },
                    Err(ProviderError::RateLimitExceeded(err)) => {
                        let Some(multiplier_ms) = self.retry_multiplier else {
                            yield Message::assistant().with_text(
                                format!("Ran into rate limit error: {err}.\n\n\
                                Please retry if you think this is a transient error. \
                                If you want the agent to retry on rate limit errors, consider setting `{GOOSE_AGENT_RETRY_MULTIPLIER_MS}` to enable retries."));
                            break;
                        };

                        if retry_count >= MAX_RETRIES {
                            yield Message::assistant().with_text(format!("Error: Rate limit exceeded even after {MAX_RETRIES} retries with exponential backoff. Please try to increase your API limits or try again later."));
                            break;
                        }

                        retry_count += 1;
                        let delay = Duration::from_millis((2_u64.pow(retry_count)) * multiplier_ms);
                        warn!("Rate limit exceeded: {err}. Retrying after {delay:?}, retry count: {retry_count}/{MAX_RETRIES}.");

                        // Release the lock while sleeping so it is not held for the duration of the backoff
                        drop(capabilities);

                        // Back off before retrying
                        tokio::time::sleep(delay).await;

                        // Re-acquire the lock
                        capabilities = self.capabilities.lock().await;
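
For reference, here is a minimal standalone sketch (not part of the diff) of the backoff schedule these constants define, assuming `GOOSE_AGENT_RETRY_MULTIPLIER_MS` is set to `1000` as in the example comment above:

```rust
use std::time::Duration;

// Mirrors the PR's MAX_RETRIES constant; the multiplier below is a hypothetical example value.
const MAX_RETRIES: u32 = 5;

fn main() {
    let multiplier_ms: u64 = 1000; // e.g. GOOSE_AGENT_RETRY_MULTIPLIER_MS=1000

    // delay = 2^retry_count * multiplier, i.e. 2s, 4s, 8s, 16s, 32s
    for retry_count in 1..=MAX_RETRIES {
        let delay = Duration::from_millis(2_u64.pow(retry_count) * multiplier_ms);
        println!("retry {retry_count}/{MAX_RETRIES}: sleep {delay:?}");
    }
}
```

With the environment variable unset, `retry_multiplier` is `None` and the agent keeps the existing behavior: it surfaces the rate limit error to the user instead of retrying.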