Added model name label to metrics and added an optional argument --served-model-name #3064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into base: main
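Before the per-file diffs, here is a minimal sketch of the behavior this PR introduces, assuming only `clap` and the `metrics` facade crate already used by the router; the trimmed `Args` struct below is illustrative and omits every other launcher/router argument. `--served-model-name` is optional and falls back to `--model-id`, and the resolved name is attached to every metric as a `model_name` label.

```rust
use clap::Parser;

/// Illustrative subset of the router arguments (the real structs in this PR
/// carry many more fields).
#[derive(Parser, Debug)]
struct Args {
    /// Model to load.
    #[clap(long, env)]
    model_id: String,

    /// Name reported in metrics. Defaults to `model_id` if not provided.
    #[clap(long, env)]
    served_model_name: Option<String>,
}

fn main() {
    let args = Args::parse();

    // Fallback applied in every backend touched by this PR.
    let served_model_name = args
        .served_model_name
        .clone()
        .unwrap_or_else(|| args.model_id.clone());

    // Downstream, each metric call site attaches the name as a label, e.g.:
    //   metrics::counter!("tgi_batch_inference_count", "method" => "prefill", "model_name" => served_model_name.clone())
    //       .increment(1);
    println!("model_id={} served_model_name={}", args.model_id, served_model_name);
}
```

With `#[clap(long, env)]`, the flag can also be set through the `SERVED_MODEL_NAME` environment variable, so a deployment can relabel its metrics without changing which weights are loaded.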
9 changes: 9 additions & 0 deletions backends/llamacpp/src/main.rs
@@ -19,6 +19,10 @@ struct Args {
#[clap(long, env)]
model_id: String,

/// Name under which the model is served. Defaults to `model_id` if not provided.
#[clap(long, env)]
served_model_name: Option<String>,

/// Revision of the model.
#[clap(default_value = "main", long, env)]
revision: String,
@@ -152,6 +156,10 @@ struct Args {
async fn main() -> Result<(), RouterError> {
let args = Args::parse();

let served_model_name = args.served_model_name
.clone()
.unwrap_or_else(|| args.model_id.clone());

logging::init_logging(args.otlp_endpoint, args.otlp_service_name, args.json_output);

let n_threads = match args.n_threads {
Expand Down Expand Up @@ -264,6 +272,7 @@ async fn main() -> Result<(), RouterError> {
args.max_client_batch_size,
args.usage_stats,
args.payload_limit,
served_model_name
)
.await?;
Ok(())
8 changes: 8 additions & 0 deletions backends/trtllm/src/main.rs
@@ -45,6 +45,8 @@ struct Args {
revision: Option<String>,
#[clap(long, env)]
model_id: String,
#[clap(long, env)]
served_model_name: Option<String>,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
@@ -227,6 +229,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
tokenizer_config_path,
revision,
model_id,
served_model_name,
validation_workers,
json_output,
otlp_endpoint,
@@ -239,6 +242,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
payload_limit,
} = args;

let served_model_name = served_model_name.unwrap_or_else(|| model_id.clone());

// Launch Tokio runtime
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);

@@ -318,6 +325,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
max_client_batch_size,
usage_stats,
payload_limit,
served_model_name,
)
.await?;
Ok(())
70 changes: 38 additions & 32 deletions backends/v2/src/backend.rs
@@ -34,6 +34,7 @@ impl BackendV2 {
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
served_model_name: String,
) -> Self {
// Infer shared state
let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
@@ -44,7 +45,7 @@ impl BackendV2 {
_ => unreachable!(),
};

let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let queue = Queue::new(requires_padding, block_size, window_size, speculate, served_model_name.clone());
let batching_task_notifier = Arc::new(Notify::new());

// Spawn batching background task that contains all the inference logic
@@ -57,6 +58,7 @@ impl BackendV2 {
max_batch_size,
queue.clone(),
batching_task_notifier.clone(),
served_model_name.clone(),
));

Self {
@@ -128,6 +130,7 @@ pub(crate) async fn batching_task(
max_batch_size: Option<usize>,
queue: Queue,
notifier: Arc<Notify>,
served_model_name: String,
) {
// Infinite loop
loop {
@@ -146,7 +149,7 @@
)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries)
let mut cached_batch = prefill(&mut client, batch, &mut entries, served_model_name.clone())
.instrument(span)
.await;
let mut waiting_tokens = 1;
@@ -158,8 +161,8 @@
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(batch_max_tokens as f64);

let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
@@ -180,10 +183,10 @@
{
// Tracking metrics
if min_size.is_some() {
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
metrics::counter!("tgi_batch_concat", "reason" => "backpressure", "model_name" => served_model_name.clone())
.increment(1);
} else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded", "model_name" => served_model_name.clone())
.increment(1);
}

@@ -199,7 +202,7 @@
});

// Generate one token for this new batch to have the attention past in cache
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries, served_model_name.clone())
.instrument(span)
.await;
// Reset waiting counter
@@ -225,13 +228,13 @@
entry.temp_span = Some(entry_batch_span);
});

cached_batch = decode(&mut client, batches, &mut entries)
cached_batch = decode(&mut client, batches, &mut entries, served_model_name.clone())
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.clone()).set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.clone()).set(0.0);
}
}
}
@@ -241,36 +244,37 @@ async fn prefill(
client: &mut ShardedClient,
batch: Batch,
entries: &mut IntMap<u64, Entry>,
served_model_name: String,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
metrics::counter!("tgi_batch_inference_count", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);

match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
filter_send_generations(generations, entries, served_model_name.clone());

// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;

metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill", "model_name" => served_model_name.clone())
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill", "model_name" => served_model_name.clone())
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill", "model_name" => served_model_name.clone())
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill", "model_name" => served_model_name.clone())
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
metrics::counter!("tgi_batch_inference_success", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill", "model_name" => served_model_name.clone()).increment(1);
None
}
}
@@ -281,33 +285,34 @@ async fn decode(
client: &mut ShardedClient,
batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
served_model_name: String,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
metrics::counter!("tgi_batch_inference_count", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);

match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
filter_send_generations(generations, entries, served_model_name.clone());

// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;

if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode", "model_name" => served_model_name.clone())
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
metrics::counter!("tgi_batch_inference_success", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Expand All @@ -316,7 +321,7 @@ async fn decode(
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
metrics::counter!("tgi_batch_inference_failure", "method" => "decode", "model_name" => served_model_name.clone()).increment(1);
None
}
}
@@ -358,7 +363,7 @@ async fn filter_batch(
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>, served_model_name: String) {
generations.into_iter().for_each(|generation| {
let id = generation.request_id;
// Get entry
@@ -372,9 +377,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// Send generation responses back to the infer task
// If the receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).inspect_err(|_err| {
let stopped = send_responses(generation, entry, served_model_name.clone()).inspect_err(|_err| {
tracing::error!("Entry response channel error.");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
}).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");
@@ -386,10 +391,11 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
fn send_responses(
generation: Generation,
entry: &Entry,
served_model_name: String,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
metrics::counter!("tgi_request_failure", "err" => "dropped", "model_name" => served_model_name.clone()).increment(1);
return Ok(true);
}

@@ -415,7 +421,7 @@ fn send_responses(
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
metrics::histogram!("tgi_request_skipped_tokens", "model_name" => served_model_name.clone()).record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
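The change in `backend.rs` is uniform: every gauge, counter, and histogram keeps its existing labels and gains a `model_name` label, which is why `served_model_name` is cloned at each call site (the `metrics` macros take label values by value). A sketch of the pattern, using metric names from the diff but a hypothetical helper function:

```rust
/// Hypothetical helper showing the labeling pattern used throughout
/// `batching_task`, `prefill`, and `decode` in this diff.
fn record_batch_state(served_model_name: &str, batch_size: u32, batch_max_tokens: u32) {
    metrics::gauge!("tgi_batch_current_size", "model_name" => served_model_name.to_string())
        .set(batch_size as f64);
    metrics::gauge!("tgi_batch_current_max_tokens", "model_name" => served_model_name.to_string())
        .set(batch_max_tokens as f64);
    metrics::counter!("tgi_batch_inference_count", "method" => "prefill", "model_name" => served_model_name.to_string())
        .increment(1);
}
```

Because the name is fixed for the lifetime of the process, an alternative would be to hold it in an `Arc<str>` (or leak it once into a `&'static str`) and avoid the per-call `String` clones; the clones are cheap next to a prefill or decode step, so the simpler approach taken here is reasonable.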
4 changes: 3 additions & 1 deletion backends/v2/src/lib.rs
@@ -39,6 +39,7 @@ pub async fn connect_backend(
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
served_model_name: String,
) -> Result<(BackendV2, BackendInfo), V2Error> {
// Helper function
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
@@ -108,7 +109,7 @@ pub async fn connect_backend(
model_dtype: shard_info.dtype.clone(),
speculate: shard_info.speculate as usize,
};

let backend = BackendV2::new(
sharded_client,
waiting_served_ratio,
Expand All @@ -119,6 +120,7 @@ pub async fn connect_backend(
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
served_model_name,
);

tracing::info!("Using backend V3");
10 changes: 9 additions & 1 deletion backends/v2/src/main.rs
@@ -7,6 +7,9 @@ use thiserror::Error;
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(long, env)]
served_model_name: String,

#[command(subcommand)]
command: Option<Commands>,

@@ -83,8 +86,11 @@ enum Commands {
async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();

// Pattern match configuration
let Args {
served_model_name,
command,
max_concurrent_requests,
max_best_of,
@@ -170,8 +176,9 @@ async fn main() -> Result<(), RouterError> {
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
served_model_name.clone(),
)
.await?;
.await?;

// Run server
server::run(
@@ -198,6 +205,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
served_model_name.clone(),
)
.await?;
Ok(())
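To see the label end to end, here is a small self-contained sketch using `metrics-exporter-prometheus` (the same exporter family the router uses to expose its `/metrics` endpoint); the model name and the in-process recorder are illustrative only:

```rust
use metrics_exporter_prometheus::PrometheusBuilder;

fn main() {
    // Install an in-process Prometheus recorder; no HTTP listener is started,
    // we just render the exposition text directly.
    let handle = PrometheusBuilder::new()
        .install_recorder()
        .expect("failed to install Prometheus recorder");

    // Stand-in for the value resolved from --served-model-name / --model-id.
    let served_model_name = "my-org/my-model".to_string();

    // Same call shape as the sites changed in this PR.
    metrics::counter!("tgi_batch_inference_count", "method" => "prefill", "model_name" => served_model_name.clone())
        .increment(1);

    // Renders something like:
    //   tgi_batch_inference_count{method="prefill",model_name="my-org/my-model"} 1
    println!("{}", handle.render());
}
```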