Commit a6a0c97

feat: prefill chunking (#2600)
* wip
* rollback
* refactor to use prefix/postfix namming + fix all_input_ids_tensor
* maybe patching vlms?
* fix filter and concat
* wip, no filter, no concat
* current
* add prepare_for_prefill
* working
* load tested
* re-create slots
* re-create slots
* fix slot_filtering_indices
* feedback loop
* remove log
* fix benchmarker
* fix vlm and seq2seq
* rename to cache and input lengths
* fix prefill logprobs
* fix launcher
* fix logprobs?
* idk at this point
* max input length
* omfg
* remove debugging lines
* fix tests
* fix mllama
* fix cargo tests
* remove support chunking for paged
* Fixing non blocked attentions
* Fixing dtype + AMD, Ipex targets.
* lint fix.
* rename
* Fix prefix_caching variable, remove defaults in server (confusing a lot of the times).
* Add simple resolution when user specifies ATTENTION=paged.
* Put back non default simple tests.
* Fix env name

---------

Co-authored-by: Nicolas Patry <[email protected]>
1 parent 704a58c commit a6a0c97

46 files changed: +1701 / -1130 lines

Dockerfile_amd

Lines changed: 2 additions & 1 deletion

@@ -327,7 +327,8 @@ ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
 ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
 ENV VLLM_MOE_PADDING=0
 ENV ATTENTION=paged
-ENV USE_PREFIX_CACHING=0
+ENV PREFIX_CACHING=0
+ENV PREFILL_CHUNKING=0
 ENV ROCM_USE_SKINNY_GEMM=1
 
 COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh

Dockerfile_intel

Lines changed: 2 additions & 1 deletion

@@ -218,7 +218,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
 
 FROM ${PLATFORM} AS final
 ENV ATTENTION=paged
-ENV USE_PREFIX_CACHING=0
+ENV PREFIX_CACHING=0
+ENV PREFILL_CHUNKING=0
 ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
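
Both images rename USE_PREFIX_CACHING to PREFIX_CACHING and add a new PREFILL_CHUNKING flag next to ATTENTION=paged. As a rough illustration only (not the launcher's actual parsing code), such 0/1 flags can be read on the Rust side as shown below; the env_flag helper and its defaults are assumptions of the sketch:

use std::env;

/// Treat "1" or "true" as enabled; anything else (or an unset variable) as the default.
fn env_flag(name: &str, default: bool) -> bool {
    env::var(name)
        .map(|v| matches!(v.as_str(), "1" | "true"))
        .unwrap_or(default)
}

fn main() {
    // The Dockerfiles above pin both flags to 0 for the AMD and Intel images.
    let prefix_caching = env_flag("PREFIX_CACHING", false);
    let prefill_chunking = env_flag("PREFILL_CHUNKING", false);
    println!("prefix_caching={prefix_caching} prefill_chunking={prefill_chunking}");
}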

backends/client/src/v3/client.rs

Lines changed: 8 additions & 2 deletions

@@ -158,7 +158,8 @@ impl Client {
                 // Blocks and slots will be set on the server side if we use paged attention
                 blocks: vec![],
                 slots: vec![],
-                prefix_len: 0,
+                cache_len: 0,
+                chunk_len: None,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -217,8 +218,13 @@ impl Client {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
-        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
+        let request = tonic::Request::new(PrefillRequest {
+            batch: Some(batch),
+            cached_batch,
+        })
+        .inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((
             response.generations,
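
The extra cached_batch argument is what enables prefill chunking: the first chunk is sent with None, and each later chunk passes back the CachedBatch returned by the previous call so the server can keep extending the same batch. A hypothetical driver loop, assuming the Client, Batch, CachedBatch and Result types from this module are in scope (a sketch, not code from the commit):

// Sketch only: feed the CachedBatch returned by one prefill call into the next one.
async fn prefill_in_chunks(
    client: &mut Client,
    chunks: Vec<Batch>,
) -> Result<Option<CachedBatch>> {
    let mut cached_batch: Option<CachedBatch> = None;
    for chunk in chunks {
        let (_generations, next_batch, _timings) = client.prefill(chunk, cached_batch).await?;
        cached_batch = next_batch;
    }
    Ok(cached_batch)
}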

backends/client/src/v3/sharded_client.rs

Lines changed: 5 additions & 3 deletions

@@ -134,11 +134,12 @@ impl ShardedClient {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| Box::pin(client.prefill(batch.clone())))
+            .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
             .collect();
         #[allow(clippy::type_complexity)]
         let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@@ -245,7 +246,8 @@ impl Health for ShardedClient {
             // Block 0 is reserved for health checks
             blocks: vec![0],
             slots: (0..16).collect(),
-            prefix_len: 0,
+            cache_len: 0,
+            chunk_len: None,
             adapter_id: None,
         };
         let batch = Batch {
@@ -255,7 +257,7 @@
             max_tokens: 2,
             max_blocks: 1,
         };
-        self.clone().prefill(batch).await?;
+        self.clone().prefill(batch, None).await?;
         Ok(())
     }
 }
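
The sharded client simply broadcasts the same batch and the same optional cached batch to every shard. A simplified, hypothetical version of that fan-out (the real ShardedClient also merges the per-shard timings):

// Sketch only: each shard gets its own clone of the batch and cached batch,
// all prefill calls run concurrently, and the first shard's result is kept.
async fn prefill_all_shards(
    clients: &mut [Client],
    batch: Batch,
    cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
    let futures: Vec<_> = clients
        .iter_mut()
        .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
        .collect();
    let mut results = futures::future::join_all(futures).await;
    results.swap_remove(0)
}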

backends/v2/src/backend.rs

Lines changed: 8 additions & 12 deletions

@@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
-use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
+use text_generation_router::{FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;
@@ -36,18 +36,14 @@ impl BackendV2 {
         speculate: u32,
     ) -> Self {
         // Infer shared state
-        let attention = if let Ok(attention) = std::env::var("ATTENTION") {
-            attention
-                .parse()
-                .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
-        } else {
-            Attention::Paged
-        };
-        let block_size = if attention == Attention::FlashDecoding {
-            256
-        } else {
-            16
+        let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
+        let block_size = match attention.as_str() {
+            "flashinfer" => 1,
+            "flashdecoding" => 256,
+            "paged" => 16,
+            _ => unreachable!(),
         };
+
         let queue = Queue::new(requires_padding, block_size, window_size, speculate);
         let batching_task_notifier = Arc::new(Notify::new());

backends/v3/src/backend.rs

Lines changed: 84 additions & 50 deletions

@@ -1,12 +1,14 @@
-use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
 /// Batching and inference logic
+use crate::client::{
+    Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient,
+};
 use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
-use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
+use text_generation_router::{FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;
@@ -31,27 +33,22 @@ impl BackendV3 {
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
         max_batch_size: Option<usize>,
-        requires_padding: bool,
-        window_size: Option<u32>,
-        speculate: u32,
+        shard_info: InfoResponse,
     ) -> Self {
-        let prefix_caching =
-            std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
-        let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
-        let attention: String = std::env::var("ATTENTION").expect("attention env var");
+        if shard_info.support_chunking {
+            tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
+        }
 
-        let attention: Attention = attention
-            .parse()
-            .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
-        let block_size = attention.block_size();
+        let block_size = shard_info.block_size;
 
         let queue = Queue::new(
-            requires_padding,
+            shard_info.requires_padding,
             block_size,
-            prefix_caching,
-            window_size,
-            speculate,
+            shard_info.use_prefix_caching,
+            shard_info.window_size,
+            shard_info.speculate,
             max_batch_total_tokens,
+            shard_info.support_chunking,
         );
         let batching_task_notifier = Arc::new(Notify::new());
 
@@ -63,6 +60,7 @@ impl BackendV3 {
            max_batch_total_tokens,
            max_waiting_tokens,
            max_batch_size,
+           shard_info.support_chunking,
            queue.clone(),
            batching_task_notifier.clone(),
        ));
@@ -127,6 +125,7 @@ pub(crate) async fn batching_task(
    max_batch_total_tokens: u32,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
+   support_chunking: bool,
    queue: Queue,
    notifier: Arc<Notify>,
 ) {
@@ -147,7 +146,7 @@ pub(crate) async fn batching_task(
        )
        .await
        {
-           let mut cached_batch = prefill(&mut client, batch, &mut entries)
+           let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
                .instrument(span)
                .await;
            let mut waiting_tokens = 1;
@@ -158,60 +157,90 @@ pub(crate) async fn batching_task(
            // Get current batch info
            let batch_size = batch.size;
            let batch_max_tokens = batch.max_tokens;
+           let current_tokens = batch.current_tokens;
            let mut batches = vec![batch];
            metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
            metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
 
-           let min_size = if waiting_tokens >= max_waiting_tokens {
-               // If we didn't onboard any new requests since >= max_waiting_tokens, we try
-               // to add a new batch even though its size might be small
-               None
+           let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
+
+           let (min_size, max_size, prefill_token_budget) = if support_chunking {
+               // Since the next batch will be concatenated with the current batch,
+               // the current batch tokens must be subtracted to the prefill budget
+               let prefill_token_budget =
+                   max_batch_prefill_tokens.saturating_sub(current_tokens);
+               // We can ignore min_size and max_size
+               // Models than rely on max_size cannot support chunking
+               // Regarding min_size, chunking allow us to consistently run at the compute
+               // bound, making min_size useless.
+               (None, None, prefill_token_budget)
            } else {
-               // Minimum batch size
-               // TODO: temporarily disable to avoid incorrect deallocation +
-               // reallocation when using prefix caching.
-               Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
-           };
+               let min_size = if waiting_tokens >= max_waiting_tokens {
+                   // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+                   // to add a new batch even though its size might be small
+                   None
+               } else {
+                   // Minimum batch size
+                   // TODO: temporarily disable to avoid incorrect deallocation +
+                   // reallocation when using prefix caching.
+                   Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
+               };
 
-           let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
-           let max_size =
-               max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
+               let max_size =
+                   max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
+
+               (min_size, max_size, max_batch_prefill_tokens)
+           };
 
            // Try to get a new batch
-           if let Some((mut new_entries, new_batch, span)) = queue
-               .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
+           if let Some((new_entries, new_batch, span)) = queue
+               .next_batch(min_size, max_size, prefill_token_budget, token_budget)
               .await
           {
               // Tracking metrics
               if min_size.is_some() {
                   metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
                       .increment(1);
               } else {
-                  metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
-                      .increment(1);
+                  let counter = if support_chunking {
+                      metrics::counter!("tgi_batch_concat", "reason" => "chunking")
+                  } else {
+                      metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
+                  };
+                  counter.increment(1);
               }
-
-              entries.iter_mut().for_each(|(_, entry)| {
-                  // Create a new span to add the info that this entry is waiting
-                  // because a new batch is being computed
-                  let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
-                  // Add relationships
-                  span.follows_from(&entry_waiting_span);
-                  entry_waiting_span.follows_from(&span);
-                  // Update entry
-                  entry.temp_span = Some(entry_waiting_span);
-              });
+              let cached_batch = if support_chunking {
+                  // Concat current batch to the new one
+                  batches.pop()
+              } else {
+                  // Request are waiting only if we don't support chunking
+                  entries.iter_mut().for_each(|(_, entry)| {
+                      // Create a new span to add the info that this entry is waiting
+                      // because a new batch is being computed
+                      let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
+                      // Add relationships
+                      span.follows_from(&entry_waiting_span);
+                      entry_waiting_span.follows_from(&span);
+                      // Update entry
+                      entry.temp_span = Some(entry_waiting_span);
+                  });
+                  None
+              };
+              entries.extend(new_entries);
 
              // Generate one token for this new batch to have the attention past in cache
-             let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
-                 .instrument(span)
-                 .await;
+             let new_cached_batch =
+                 prefill(&mut client, new_batch, cached_batch, &mut entries)
+                     .instrument(span)
+                     .await;
             // Reset waiting counter
             waiting_tokens = 1;
             // Extend current batch with the new batch
             if let Some(new_cached_batch) = new_cached_batch {
-                entries.extend(new_entries);
                batches.push(new_cached_batch);
+            } else if support_chunking {
+                // New cached batch is empty, no work left
+                break;
            }
        }
 
@@ -244,13 +273,14 @@ pub(crate) async fn batching_task(
 async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
+    cached_batch: Option<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
     metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
 
-    match client.prefill(batch).await {
+    match client.prefill(batch, cached_batch).await {
         Ok((generations, next_batch, timings)) => {
             let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
@@ -259,6 +289,10 @@ async fn prefill(
            // Filter next batch and remove requests that were stopped
            let next_batch = filter_batch(client, next_batch, entries).await;
 
+           if let Some(concat_duration) = timings.concat {
+               metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+                   .record(concat_duration.as_secs_f64());
+           }
            metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
                .record(timings.forward.as_secs_f64());
            metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
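
The key scheduling change is in the budget computation: when the model supports chunking, the tokens already held by the running batch are subtracted from max_batch_prefill_tokens, and min_size/max_size are dropped entirely. A small worked sketch of that arithmetic, with illustrative numbers only:

/// Prefill token budget for the next batch, following the logic above.
fn prefill_budget(support_chunking: bool, max_batch_prefill_tokens: u32, current_tokens: u32) -> u32 {
    if support_chunking {
        // The next batch is concatenated to the running one, so its tokens
        // are deducted from the prefill budget.
        max_batch_prefill_tokens.saturating_sub(current_tokens)
    } else {
        max_batch_prefill_tokens
    }
}

fn main() {
    assert_eq!(prefill_budget(true, 4096, 1024), 3072);
    assert_eq!(prefill_budget(true, 4096, 5000), 0); // saturating_sub never underflows
    assert_eq!(prefill_budget(false, 4096, 1024), 4096);
}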

backends/v3/src/client/grpc_client.rs

Lines changed: 17 additions & 4 deletions

@@ -158,7 +158,8 @@ impl Client {
                 // Blocks and slots will be set on the server side if we use paged attention
                 blocks: vec![],
                 slots: vec![],
-                prefix_len: 0,
+                cache_len: 0,
+                chunk_len: None,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -217,13 +218,23 @@ impl Client {
     pub async fn prefill(
         &mut self,
         batch: Batch,
+        cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
-        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
+        let request = tonic::Request::new(PrefillRequest {
+            batch: Some(batch),
+            cached_batch,
+        })
+        .inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((
             response.generations,
             response.batch,
-            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
+            PrefillTimings::new(
+                response.concat_ns,
+                response.forward_ns,
+                response.decode_ns,
+                response.total_ns,
+            ),
         ))
     }
 
@@ -252,14 +263,16 @@ impl Client {
 }
 
 pub struct PrefillTimings {
+    pub concat: Option<Duration>,
     pub forward: Duration,
     pub decode: Duration,
     pub total: Duration,
 }
 
 impl PrefillTimings {
-    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
         Self {
+            concat: concat_ns.map(Duration::from_nanos),
             forward: Duration::from_nanos(forward_ns),
             decode: Duration::from_nanos(decode_ns),
             total: Duration::from_nanos(total_ns),
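
PrefillTimings gains an optional concat duration: concat_ns is only reported when the server actually concatenated a cached batch into the prefill, and the router only records the concat histogram in that case. A minimal standalone illustration of the Option<u64> to Option<Duration> conversion used above:

use std::time::Duration;

// Mirrors the constructor change above: a missing concat_ns stays None.
fn concat_duration(concat_ns: Option<u64>) -> Option<Duration> {
    concat_ns.map(Duration::from_nanos)
}

fn main() {
    assert_eq!(concat_duration(None), None);
    assert_eq!(
        concat_duration(Some(1_500_000)),
        Some(Duration::from_nanos(1_500_000))
    );
}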
