-use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
 /// Batching and inference logic
+use crate::client::{
+    Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient,
+};
 use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
-use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
+use text_generation_router::{FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;
@@ -31,27 +33,22 @@ impl BackendV3 {
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
         max_batch_size: Option<usize>,
-        requires_padding: bool,
-        window_size: Option<u32>,
-        speculate: u32,
+        shard_info: InfoResponse,
     ) -> Self {
-        let prefix_caching =
-            std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
-        let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
-        let attention: String = std::env::var("ATTENTION").expect("attention env var");
+        if shard_info.support_chunking {
+            tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
+        }

-        let attention: Attention = attention
-            .parse()
-            .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
-        let block_size = attention.block_size();
+        let block_size = shard_info.block_size;

         let queue = Queue::new(
-            requires_padding,
+            shard_info.requires_padding,
             block_size,
-            prefix_caching,
-            window_size,
-            speculate,
+            shard_info.use_prefix_caching,
+            shard_info.window_size,
+            shard_info.speculate,
             max_batch_total_tokens,
+            shard_info.support_chunking,
         );
         let batching_task_notifier = Arc::new(Notify::new());

@@ -63,6 +60,7 @@ impl BackendV3 {
             max_batch_total_tokens,
             max_waiting_tokens,
             max_batch_size,
+            shard_info.support_chunking,
             queue.clone(),
             batching_task_notifier.clone(),
         ));
@@ -127,6 +125,7 @@ pub(crate) async fn batching_task(
     max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
     max_batch_size: Option<usize>,
+    support_chunking: bool,
     queue: Queue,
     notifier: Arc<Notify>,
 ) {
@@ -147,7 +146,7 @@ pub(crate) async fn batching_task(
             )
             .await
         {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -158,60 +157,90 @@ pub(crate) async fn batching_task(
                 // Get current batch info
                 let batch_size = batch.size;
                 let batch_max_tokens = batch.max_tokens;
+                let current_tokens = batch.current_tokens;
                 let mut batches = vec![batch];
                 metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
                 metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);

-                let min_size = if waiting_tokens >= max_waiting_tokens {
-                    // If we didn't onboard any new requests since >= max_waiting_tokens, we try
-                    // to add a new batch even though its size might be small
-                    None
+                let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
+
+                let (min_size, max_size, prefill_token_budget) = if support_chunking {
+                    // Since the next batch will be concatenated with the current batch,
+                    // the current batch tokens must be subtracted from the prefill budget
+                    let prefill_token_budget =
+                        max_batch_prefill_tokens.saturating_sub(current_tokens);
+                    // We can ignore min_size and max_size:
+                    // models that rely on max_size cannot support chunking, and
+                    // regarding min_size, chunking allows us to consistently run at the compute
+                    // bound, making min_size useless.
+                    (None, None, prefill_token_budget)
                 } else {
-                    // Minimum batch size
-                    // TODO: temporarily disable to avoid incorrect deallocation +
-                    // reallocation when using prefix caching.
-                    Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
-                };
+                    let min_size = if waiting_tokens >= max_waiting_tokens {
+                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+                        // to add a new batch even though its size might be small
+                        None
+                    } else {
+                        // Minimum batch size
+                        // TODO: temporarily disable to avoid incorrect deallocation +
+                        // reallocation when using prefix caching.
+                        Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
+                    };

-                let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
-                let max_size =
-                    max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
+                    let max_size =
+                        max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
+
+                    (min_size, max_size, max_batch_prefill_tokens)
+                };

                 // Try to get a new batch
-                if let Some((mut new_entries, new_batch, span)) = queue
-                    .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
+                if let Some((new_entries, new_batch, span)) = queue
+                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)
                     .await
                 {
                     // Tracking metrics
                     if min_size.is_some() {
                         metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
                             .increment(1);
                     } else {
-                        metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
-                            .increment(1);
+                        let counter = if support_chunking {
+                            metrics::counter!("tgi_batch_concat", "reason" => "chunking")
+                        } else {
+                            metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
+                        };
+                        counter.increment(1);
                     }
-
-                    entries.iter_mut().for_each(|(_, entry)| {
-                        // Create a new span to add the info that this entry is waiting
-                        // because a new batch is being computed
-                        let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
-                        // Add relationships
-                        span.follows_from(&entry_waiting_span);
-                        entry_waiting_span.follows_from(&span);
-                        // Update entry
-                        entry.temp_span = Some(entry_waiting_span);
-                    });
+                    let cached_batch = if support_chunking {
+                        // Concat the current batch to the new one
+                        batches.pop()
+                    } else {
+                        // Requests are waiting only if we don't support chunking
+                        entries.iter_mut().for_each(|(_, entry)| {
+                            // Create a new span to add the info that this entry is waiting
+                            // because a new batch is being computed
+                            let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
+                            // Add relationships
+                            span.follows_from(&entry_waiting_span);
+                            entry_waiting_span.follows_from(&span);
+                            // Update entry
+                            entry.temp_span = Some(entry_waiting_span);
+                        });
+                        None
+                    };
+                    entries.extend(new_entries);

                     // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
-                        .instrument(span)
-                        .await;
+                    let new_cached_batch =
+                        prefill(&mut client, new_batch, cached_batch, &mut entries)
+                            .instrument(span)
+                            .await;
                     // Reset waiting counter
                     waiting_tokens = 1;
                     // Extend current batch with the new batch
                     if let Some(new_cached_batch) = new_cached_batch {
-                        entries.extend(new_entries);
                         batches.push(new_cached_batch);
+                    } else if support_chunking {
+                        // New cached batch is empty, no work left
+                        break;
                     }
                 }

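The `(min_size, max_size, prefill_token_budget)` selection introduced above is the core of this hunk. The sketch below restates that arithmetic as a minimal, self-contained function for illustration only; the standalone `next_batch_budget` helper is hypothetical, while the parameter names mirror the diff.

```rust
/// Hypothetical helper mirroring the budget selection in `batching_task`.
/// Returns `(min_size, max_size, prefill_token_budget)` for the next queue pull.
fn next_batch_budget(
    support_chunking: bool,
    batch_size: u32,
    current_tokens: u32,
    waiting_tokens: usize,
    max_waiting_tokens: usize,
    waiting_served_ratio: f32,
    max_batch_prefill_tokens: u32,
    max_batch_size: Option<usize>,
) -> (Option<usize>, Option<usize>, u32) {
    if support_chunking {
        // The new batch will be concatenated with the running one, so tokens already
        // in flight are subtracted from the prefill budget; size limits are ignored.
        let prefill_token_budget = max_batch_prefill_tokens.saturating_sub(current_tokens);
        (None, None, prefill_token_budget)
    } else {
        // Without chunking, only enforce a minimum size while we are still willing
        // to wait, and cap the new batch by the remaining request slots.
        let min_size = if waiting_tokens >= max_waiting_tokens {
            None
        } else {
            Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
        };
        let max_size = max_batch_size.map(|m| m.saturating_sub(batch_size as usize));
        (min_size, max_size, max_batch_prefill_tokens)
    }
}
```

For example, with `max_batch_prefill_tokens = 4096` and `current_tokens = 1024`, the chunking branch leaves a prefill budget of 3072 tokens for the next pull, while the non-chunking branch would still offer the full 4096.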
@@ -244,13 +273,14 @@ pub(crate) async fn batching_task(
 async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
+    cached_batch: Option<CachedBatch>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
     metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);

-    match client.prefill(batch).await {
+    match client.prefill(batch, cached_batch).await {
         Ok((generations, next_batch, timings)) => {
             let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
@@ -259,6 +289,10 @@ async fn prefill(
             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;

+            if let Some(concat_duration) = timings.concat {
+                metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+                    .record(concat_duration.as_secs_f64());
+            }
             metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
                 .record(timings.forward.as_secs_f64());
             metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
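Taken together, the `batching_task` and `prefill` changes hand the running cached batch back to the shards for concatenation instead of parking existing entries in a "waiting" span. The following is a simplified, self-contained sketch of that control flow under stated assumptions: `CachedBatch` and `prefill_stub` are stand-ins written for illustration, not the real `ShardedClient` API, and the loop is synchronous for brevity.

```rust
// Stand-in for the server-side cached batch; the real type carries request state.
struct CachedBatch {
    id: u64,
}

// Stand-in for `ShardedClient::prefill(batch, cached_batch)`: the real call prefills
// the new chunk, concatenates it with `cached`, and returns `None` once no work is left.
fn prefill_stub(new_id: u64, cached: Option<CachedBatch>) -> Option<CachedBatch> {
    let _ = cached;
    if new_id < 3 {
        Some(CachedBatch { id: new_id })
    } else {
        None
    }
}

fn main() {
    let support_chunking = true;
    let mut batches = vec![CachedBatch { id: 0 }];
    for new_id in 1..5 {
        // With chunking, pop the current cached batch and pass it along so the
        // server concatenates it; without chunking, pass `None` as before.
        let cached = if support_chunking { batches.pop() } else { None };
        match prefill_stub(new_id, cached) {
            Some(next) => batches.push(next),
            // With chunking, an empty result means there is no work left.
            None if support_chunking => break,
            None => {}
        }
    }
    println!(
        "remaining cached batch ids: {:?}",
        batches.iter().map(|b| b.id).collect::<Vec<_>>()
    );
}
```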