Choosing input/total tokens automatically based on available VRAM? #2673

Merged: 13 commits, Oct 28, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -5,6 +5,8 @@ router/tokenizer.json
 
 backends/v2/src/client/pb
 backends/v3/src/client/pb
+backends/client/src/v2/pb
+backends/client/src/v3/pb
 
 # ROCm auto-generated files
 *.hip
36 changes: 24 additions & 12 deletions backends/client/src/v3/client.rs
@@ -107,20 +107,22 @@ impl Client {
     #[instrument(skip_all)]
     pub async fn warmup(
         &mut self,
-        max_input_length: u32,
+        max_input_tokens: Option<u32>,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
+        max_total_tokens: Option<u32>,
         max_batch_size: Option<usize>,
-    ) -> Result<Option<u32>> {
+    ) -> Result<(Option<u32>, u32, u32)> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
         // Create requests
         while n_tokens < max_prefill_tokens {
-            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
+            let mut truncate = max_prefill_tokens - n_tokens;
+            if let Some(max_input_tokens) = max_input_tokens {
+                truncate = min(max_input_tokens, truncate);
+            }
 
             let mut input_chunks = Vec::new();
-            input_chunks
-                .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
+            input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
             if n_tokens == 0 {
                 input_chunks.push(
                     Chunk::Image(Image {
@@ -136,7 +138,7 @@ impl Client {
             // been updated to support chunks.
 
             let mut inputs = String::new();
-            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
+            inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
             if n_tokens == 0 {
                 // 1 request is enough to test vision heads.
                 // Sending images on other queries messes up easily with truncation.
@@ -145,6 +147,12 @@
                 ));
             }
 
+            let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
+                max_total_tokens - truncate
+            } else {
+                1
+            };
+
             requests.push(Request {
                 id: 0,
                 inputs,
@@ -175,15 +183,15 @@
                     grammar_type: GrammarType::None as i32,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: max_total_tokens - truncate,
+                    max_new_tokens,
                     stop_sequences: vec![],
                     ignore_eos_token: true,
                 }),
                 prefill_logprobs: true,
                 top_n_tokens: 20,
                 adapter_id: None,
             });
-            n_tokens += max_input_length;
+            n_tokens += truncate;
 
             // Check max_batch_size
             if Some(requests.len()) == max_batch_size {
@@ -195,19 +203,23 @@
             id: 0,
             size: requests.len() as u32,
             requests,
-            max_tokens: max_input_length,
+            max_tokens: max_input_tokens.unwrap_or(0),
             max_blocks: 0,
         };
 
         let request = tonic::Request::new(WarmupRequest {
             batch: Some(batch),
-            max_input_length,
+            max_input_tokens,
             max_prefill_tokens,
             max_total_tokens,
         })
         .inject_context();
         let response = self.stub.warmup(request).await?.into_inner();
-        Ok(response.max_supported_total_tokens)
+        Ok((
+            response.max_supported_total_tokens,
+            response.max_input_tokens,
+            response.max_total_tokens,
+        ))
     }
 
     /// Generate one token for each request in the given batch
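For readers following the hunks above, here is a minimal, self-contained sketch of the per-request budgeting that the new `warmup` signature implies. The name `warmup_budget` and the example numbers are illustrative only; the PR inlines this logic in the request loop.

```rust
use std::cmp::min;

/// Sketch: how much to truncate the synthetic prompt and how many decode tokens
/// to request when the input/total limits may be absent (to be inferred from VRAM).
fn warmup_budget(
    max_input_tokens: Option<u32>,
    max_total_tokens: Option<u32>,
    max_prefill_tokens: u32,
    n_tokens: u32,
) -> (u32, u32) {
    // Spend whatever prefill budget remains; clamp only if an explicit input limit exists.
    let mut truncate = max_prefill_tokens - n_tokens;
    if let Some(max_input_tokens) = max_input_tokens {
        truncate = min(max_input_tokens, truncate);
    }
    // Without an explicit total limit, ask for a single decode token and let the
    // shard report the real capacity back in the warmup response.
    let max_new_tokens = match max_total_tokens {
        Some(max_total_tokens) => max_total_tokens - truncate,
        None => 1,
    };
    (truncate, max_new_tokens)
}

fn main() {
    // Pinned limits behave as before: 4096-token prompt, 4096 decode tokens.
    assert_eq!(warmup_budget(Some(4096), Some(8192), 4096, 0), (4096, 4096));
    // No limits: fill the remaining prefill budget, request one decode token.
    assert_eq!(warmup_budget(None, None, 4096, 0), (4096, 1));
}
```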
18 changes: 13 additions & 5 deletions backends/client/src/v3/sharded_client.rs
@@ -101,11 +101,11 @@ impl ShardedClient {
     #[instrument(skip(self))]
     pub async fn warmup(
         &mut self,
-        max_input_length: u32,
+        max_input_length: Option<u32>,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
+        max_total_tokens: Option<u32>,
         max_batch_size: Option<usize>,
-    ) -> Result<Option<u32>> {
+    ) -> Result<(Option<u32>, u32, u32)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
@@ -122,8 +122,16 @@
         let results = join_all(futures)
             .await
             .into_iter()
-            .collect::<Result<Vec<Option<u32>>>>()?;
-        Ok(results.into_iter().flatten().min())
+            .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
+
+        // Take the minimum value
+        // Different shards hold different parts of vocab, might yield
+        // different available block size.
+        let min = results
+            .iter()
+            .min()
+            .expect("Expect at least 1 warmup result");
+        Ok(*min)
     }
 
     /// Generate one token for each request in the given batch
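The sharded client now reduces one `(Option<u32>, u32, u32)` tuple per shard rather than a single optional value. A hedged, standalone sketch of that reduction follows; the helper name and sample values are invented for illustration.

```rust
/// Each shard reports (max_supported_total_tokens, max_input_tokens, max_total_tokens).
/// Tuples compare lexicographically and `None < Some(_)`, so taking the minimum keeps
/// the most conservative shard's answer, mirroring `ShardedClient::warmup` above.
fn reduce_warmup(results: &[(Option<u32>, u32, u32)]) -> (Option<u32>, u32, u32) {
    *results
        .iter()
        .min()
        .expect("Expect at least 1 warmup result")
}

fn main() {
    let results = [
        (Some(65_536), 4_096, 8_192),
        (Some(61_440), 4_096, 8_192), // the shard with less headroom wins
    ];
    assert_eq!(reduce_warmup(&results), (Some(61_440), 4_096, 8_192));
}
```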
36 changes: 24 additions & 12 deletions backends/v3/src/client/grpc_client.rs
@@ -108,20 +108,22 @@ impl Client {
     #[instrument(skip_all)]
     pub async fn warmup(
         &mut self,
-        max_input_length: u32,
+        max_input_tokens: Option<u32>,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
+        max_total_tokens: Option<u32>,
         max_batch_size: Option<usize>,
-    ) -> Result<Option<u32>> {
+    ) -> Result<(Option<u32>, u32, u32)> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
         // Create requests
         while n_tokens < max_prefill_tokens {
-            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
+            let mut truncate = max_prefill_tokens - n_tokens;
+            if let Some(max_input_tokens) = max_input_tokens {
+                truncate = min(max_input_tokens, truncate);
+            }
 
             let mut input_chunks = Vec::new();
-            input_chunks
-                .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
+            input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
             if n_tokens == 0 {
                 input_chunks.push(
                     Chunk::Image(Image {
@@ -137,7 +139,7 @@ impl Client {
             // been updated to support chunks.
 
             let mut inputs = String::new();
-            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
+            inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
             if n_tokens == 0 {
                 // 1 request is enough to test vision heads.
                 // Sending images on other queries messes up easily with truncation.
@@ -146,6 +148,12 @@
                 ));
             }
 
+            let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
+                max_total_tokens - truncate
+            } else {
+                1
+            };
+
             requests.push(Request {
                 id: 0,
                 inputs,
@@ -175,15 +183,15 @@
                     grammar_type: GrammarType::None as i32,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: max_total_tokens - truncate,
+                    max_new_tokens,
                     stop_sequences: vec![],
                     ignore_eos_token: true,
                 }),
                 prefill_logprobs: true,
                 top_n_tokens: 20,
                 adapter_id: None,
             });
-            n_tokens += max_input_length;
+            n_tokens += truncate;
 
             // Check max_batch_size
             if Some(requests.len()) == max_batch_size {
@@ -195,19 +203,23 @@
             id: 0,
             size: requests.len() as u32,
             requests,
-            max_tokens: max_input_length,
+            max_tokens: max_input_tokens.unwrap_or(0),
             max_blocks: 0,
         };
 
         let request = tonic::Request::new(WarmupRequest {
             batch: Some(batch),
-            max_input_length,
+            max_input_tokens,
             max_prefill_tokens,
             max_total_tokens,
         })
         .inject_context();
         let response = self.stub.warmup(request).await?.into_inner();
-        Ok(response.max_supported_total_tokens)
+        Ok((
+            response.max_supported_total_tokens,
+            response.max_input_tokens,
+            response.max_total_tokens,
+        ))
     }
 
     /// Generate one token for each request in the given batch
19 changes: 13 additions & 6 deletions backends/v3/src/client/sharded_client.rs
@@ -102,11 +102,11 @@ impl ShardedClient {
     #[instrument(skip(self))]
     pub async fn warmup(
         &mut self,
-        max_input_length: u32,
+        max_input_length: Option<u32>,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
+        max_total_tokens: Option<u32>,
         max_batch_size: Option<usize>,
-    ) -> Result<Option<u32>> {
+    ) -> Result<(Option<u32>, u32, u32)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
@@ -119,12 +119,19 @@
                 ))
             })
             .collect();
-        // Take the minimum value
         let results = join_all(futures)
             .await
             .into_iter()
-            .collect::<Result<Vec<Option<u32>>>>()?;
-        Ok(results.into_iter().flatten().min())
+            .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
+
+        // Take the minimum value
+        // Different shards hold different parts of vocab, might yield
+        // different available block size.
+        let min = results
+            .iter()
+            .min()
+            .expect("Expect at least 1 warmup result");
+        Ok(*min)
     }
 
     /// Generate one token for each request in the given batch
69 changes: 49 additions & 20 deletions backends/v3/src/lib.rs
@@ -37,12 +37,17 @@ pub struct BackendInfo {
     pub attention_impl: String,
     #[schema(example = "1")]
     pub block_size: u32,
+
+    #[schema(example = "30000")]
+    pub max_input_tokens: usize,
+    #[schema(example = "32000")]
+    pub max_total_tokens: usize,
 }
 
 #[allow(clippy::too_many_arguments)]
 pub async fn connect_backend(
-    max_input_tokens: usize,
-    max_total_tokens: usize,
+    max_input_tokens: Option<usize>,
+    max_total_tokens: Option<usize>,
     master_shard_uds_path: String,
     waiting_served_ratio: f32,
     max_batch_prefill_tokens: u32,
@@ -51,14 +56,32 @@ pub async fn connect_backend(
     max_batch_size: Option<usize>,
 ) -> Result<(BackendV3, BackendInfo), V3Error> {
     // Helper function
-    let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
+    let check_max_batch_total_tokens = |(
+        max_supported_batch_total_tokens,
+        shard_max_input_tokens,
+        shard_max_total_tokens,
+    ): (Option<u32>, u32, u32)|
+     -> Result<(u32, usize, usize), V3Error> {
+        if let Some(max_input_tokens) = max_input_tokens {
+            assert_eq!(max_input_tokens as u32, shard_max_input_tokens);
+        }
+        if let Some(max_total_tokens) = max_total_tokens {
+            assert_eq!(max_total_tokens as u32, shard_max_total_tokens);
+        }
         match max_supported_batch_total_tokens {
             // Older models do not support automatic max-batch-total-tokens
             None => {
-                let max_batch_total_tokens = max_batch_total_tokens
-                    .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
+                let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
+                    16000
+                        .max(shard_max_total_tokens)
+                        .max(max_batch_prefill_tokens),
+                );
                 tracing::warn!("Model does not support automatic max batch total tokens");
-                Ok(max_batch_total_tokens)
+                Ok((
+                    max_batch_total_tokens,
+                    shard_max_input_tokens as usize,
+                    shard_max_total_tokens as usize,
+                ))
             }
             // Flash attention models return their max supported total tokens
             Some(max_supported_batch_total_tokens) => {
@@ -72,11 +95,15 @@
                         "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
                     );
                 }
-                if max_total_tokens as u32 > max_supported_batch_total_tokens {
-                    return Err(V3Error::NotEnoughMemory(max_total_tokens));
+                if shard_max_total_tokens > max_supported_batch_total_tokens {
+                    return Err(V3Error::NotEnoughMemory(shard_max_total_tokens as usize));
                 }
 
-                Ok(max_supported_batch_total_tokens)
+                Ok((
+                    max_supported_batch_total_tokens,
+                    shard_max_input_tokens as usize,
+                    shard_max_total_tokens as usize,
+                ))
             }
         }
     };
@@ -96,23 +123,25 @@
 
     // Warmup model
     tracing::info!("Warming up model");
-    let max_batch_total_tokens = check_max_batch_total_tokens(
-        sharded_client
-            .warmup(
-                max_input_tokens as u32,
-                max_batch_prefill_tokens,
-                max_total_tokens as u32,
-                max_batch_size,
-            )
-            .await
-            .map_err(V3Error::Warmup)?,
-    )?;
+    let answer = sharded_client
+        .warmup(
+            max_input_tokens.map(|p| p as u32),
+            max_batch_prefill_tokens,
+            max_total_tokens.map(|p| p as u32),
+            max_batch_size,
+        )
+        .await
+        .map_err(V3Error::Warmup)?;
+    let (max_batch_total_tokens, max_input_tokens, max_total_tokens) =
+        check_max_batch_total_tokens(answer)?;
     tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
     metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);
 
     let backend_info = BackendInfo {
         waiting_served_ratio,
         max_batch_total_tokens,
+        max_input_tokens,
+        max_total_tokens,
         max_waiting_tokens,
         max_batch_size,
         model_device_type: shard_info.device_type.clone(),
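As a reading aid for the `None` branch above: when a model cannot report a supported batch total, the value falls back to the user-supplied flag or to the largest of a 16k floor, the shard's total-token limit, and the prefill budget. A minimal sketch with invented numbers:

```rust
/// Sketch of the fallback used when `max_supported_batch_total_tokens` is `None`.
/// The function name and sample figures are illustrative, not part of the PR.
fn fallback_batch_total_tokens(
    user_value: Option<u32>,
    shard_max_total_tokens: u32,
    max_batch_prefill_tokens: u32,
) -> u32 {
    user_value.unwrap_or(
        16000
            .max(shard_max_total_tokens)
            .max(max_batch_prefill_tokens),
    )
}

fn main() {
    // No override: the shard-derived 32k total wins over the 16k floor.
    assert_eq!(fallback_batch_total_tokens(None, 32_000, 4_096), 32_000);
    // An explicit --max-batch-total-tokens always takes precedence.
    assert_eq!(fallback_batch_total_tokens(Some(20_000), 32_000, 4_096), 20_000);
}
```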