diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 8c407e81a0a..076cf8496ec 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -60,10 +60,13 @@ jobs: authkey: ${{ secrets.TAILSCALE_AUTHKEY }} slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }} slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} - - name: Initialize Docker Buildx - uses: docker/setup-buildx-action@v2.0.0 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 with: install: true + config-inline: | + [registry."docker.io"] + mirrors = ["registry.github-runners.huggingface.tech"] - name: Login to GitHub Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v2 diff --git a/Cargo.toml b/Cargo.toml index bc2da5a1124..552c0bffb30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,10 +20,6 @@ tokenizers = { version = "0.19.1", features = ["http"] } hf-hub = { version = "0.3.1", features = ["tokio"] } [profile.release] -incremental = true - -[profile.release-binary] -inherits = "release" debug = 1 incremental = true panic = "abort" diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index b82d23ba41b..e5fbdca4a2a 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -156,7 +156,6 @@ async fn prefill( }), top_n_tokens: top_n_tokens.unwrap_or(0), blocks: vec![], - slots: vec![], }) .collect(); diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 01cc43fd320..afd5e005bf1 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -132,8 +132,6 @@ message Request { uint32 top_n_tokens = 7; /// Paged attention blocks repeated uint32 blocks = 9; - /// Paged attention slots - repeated uint32 slots = 10; } message Batch { @@ -164,6 +162,7 @@ enum FinishReason { FINISH_REASON_LENGTH = 0; FINISH_REASON_EOS_TOKEN = 1; FINISH_REASON_STOP_SEQUENCE = 2; + FINISH_REASON_TERMINATED = 3; } message GeneratedText { @@ -198,18 +197,43 @@ message Generation { optional GeneratedText generated_text = 4; /// Top tokens repeated Tokens top_tokens = 5; + /// Current length of the cache: prompt tokens + number of generated tokens until this point + uint32 cache_length = 6; } +message KeptRequest { + /// Request ID + uint64 id = 1; + /// Paged attention blocks + repeated uint32 blocks = 2; + /// Paged attention blocks padded to max blocks for this batch + repeated uint32 padded_blocks = 3; +} + +/// kept_requests + terminated_request_ids might not cover all requests from the +/// cached batch as some requests can be filtered out without requiring to generate text +/// for example if the client dropped its connection to the router message FilterBatchRequest { /// Batch ID uint64 batch_id = 1; /// Requests to keep - repeated uint64 request_ids = 2; + repeated KeptRequest kept_requests = 2; + /// Requests to terminate and generate text for + repeated uint64 terminated_request_ids = 3; +} + +message TerminatedGeneration { + // Request ID + uint64 id = 1; + // Generated text + GeneratedText generated_text = 2; } message FilterBatchResponse { /// Filtered Batch (cached) CachedBatch batch = 1; + /// Terminated generations + repeated TerminatedGeneration terminated_generations = 2; } diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs index 9a3892fb186..03efd4f5d47 100644 --- a/router/client/src/v3/client.rs +++ b/router/client/src/v3/client.rs @@ -90,15 +90,17 @@ impl Client { pub async fn filter_batch( &mut self, batch_id: u64, - request_ids: Vec, - ) -> Result> { + 
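
The proto changes above replace the flat `request_ids` filter with `KeptRequest` entries plus a separate list of requests to terminate. A minimal sketch of how a caller assembles that request; the structs here are simplified stand-ins mirroring the message fields shown in `generate.proto` (the real code uses the prost-generated types):

```rust
// Simplified stand-ins for the prost-generated message types (illustrative only).
#[derive(Debug, Default)]
struct KeptRequest {
    id: u64,
    blocks: Vec<u32>,
    padded_blocks: Vec<u32>,
}

#[derive(Debug, Default)]
struct FilterBatchRequest {
    batch_id: u64,
    kept_requests: Vec<KeptRequest>,
    terminated_request_ids: Vec<u64>,
}

fn main() {
    // Keep request 0 with its block table (zero-padded to the batch-wide max width),
    // and ask the server to terminate request 42 and return its generated text so far.
    let request = FilterBatchRequest {
        batch_id: 7,
        kept_requests: vec![KeptRequest {
            id: 0,
            blocks: vec![1, 2, 3],
            padded_blocks: vec![1, 2, 3, 0],
        }],
        terminated_request_ids: vec![42],
    };
    println!("{request:?}");
}
```
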
kept_requests: Vec, + terminated_request_ids: Vec, + ) -> Result<(Option, Vec)> { let request = tonic::Request::new(FilterBatchRequest { batch_id, - request_ids, + kept_requests, + terminated_request_ids, }) .inject_context(); let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); - Ok(filtered_batch.batch) + Ok((filtered_batch.batch, filtered_batch.terminated_generations)) } /// Warmup on a max size batch @@ -155,7 +157,6 @@ impl Client { truncate, // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], - slots: vec![], // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/router/client/src/v3/mod.rs b/router/client/src/v3/mod.rs index 4a1296a2247..9df17c50947 100644 --- a/router/client/src/v3/mod.rs +++ b/router/client/src/v3/mod.rs @@ -7,7 +7,7 @@ mod sharded_client; pub use client::Client; pub use pb::generate::v3::{ input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, - HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, - StoppingCriteriaParameters, Tokens, + HealthResponse, Image, InfoResponse, Input, InputChunk, KeptRequest, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, TerminatedGeneration, Tokens, }; pub use sharded_client::ShardedClient; diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs index 94002f55064..f89bf75defd 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/router/client/src/v3/sharded_client.rs @@ -2,14 +2,15 @@ use crate::{v3, Health, ShardInfo}; use crate::{ClientError, Result}; -use crate::v3::{Chunk, InfoResponse, Input}; +use crate::v3::{Chunk, InfoResponse, Input, TerminatedGeneration}; use async_trait::async_trait; -use futures::future::join_all; +use futures::stream::FuturesUnordered; +use futures::stream::StreamExt; use tonic::transport::Uri; use tracing::instrument; use v3::client::{DecodeTimings, PrefillTimings}; use v3::{ - Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, + Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, KeptRequest, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; @@ -29,8 +30,12 @@ impl ShardedClient { async fn from_master_client(mut master_client: Client) -> Result { // Get all uris/unix sockets from the master client let uris = master_client.service_discovery().await?; - let futures = uris.into_iter().map(Client::connect_uds); - let clients: Result> = join_all(futures).await.into_iter().collect(); + let futures: FuturesUnordered<_> = uris.into_iter().map(Client::connect_uds).collect(); + let clients: Result> = futures + .collect::>>() + .await + .into_iter() + .collect(); Ok(Self::new(clients?)) } @@ -49,34 +54,43 @@ impl ShardedClient { /// Get the model info #[instrument(skip(self))] pub async fn info(&mut self) -> Result { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| client.info()) .collect(); - join_all(futures).await.pop().unwrap().map(ShardInfo::from) + futures + .collect::>>() + .await + .pop() + .unwrap() + .map(ShardInfo::from) } /// GRPC health check #[instrument(skip(self))] pub async fn health(&mut self) -> Result { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| client.health()) .collect(); - join_all(futures).await.pop().unwrap() + 
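
The sharded client now fans calls out through `FuturesUnordered` instead of `join_all`, polling the in-flight RPCs as a stream and collecting results as they complete rather than in submission order. A minimal, self-contained sketch of the pattern; the per-shard call is a stand-in closure, not the real client RPC:

```rust
use futures::stream::{FuturesUnordered, StreamExt};

// Fan a request out to every "shard" and gather the results.
// Results stream back in completion order; since all shards return the same
// message, popping one reply after collecting is still valid.
async fn fan_out(shards: Vec<u32>) -> Vec<Result<u32, String>> {
    let futures: FuturesUnordered<_> = shards
        .into_iter()
        .map(|shard_id| async move { Ok::<u32, String>(shard_id * 2) }) // stand-in RPC
        .collect();

    futures.collect::<Vec<Result<u32, String>>>().await
}

#[tokio::main]
async fn main() {
    let results = fan_out(vec![0, 1, 2]).await;
    println!("{results:?}");
}
```
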
futures.collect::>>().await.pop().unwrap() } /// Clear the past generations cache #[instrument(skip(self))] pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| client.clear_cache(batch_id)) .collect(); - join_all(futures).await.into_iter().collect() + futures + .collect::>>() + .await + .into_iter() + .collect() } /// Filter a cached batch @@ -84,15 +98,22 @@ impl ShardedClient { pub async fn filter_batch( &mut self, batch_id: u64, - request_ids: Vec, - ) -> Result> { - let futures: Vec<_> = self + kept_requests: Vec, + terminated_request_ids: Vec, + ) -> Result<(Option, Vec)> { + let futures: FuturesUnordered<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone()))) + .map(|client| { + Box::pin(client.filter_batch( + batch_id, + kept_requests.clone(), + terminated_request_ids.clone(), + )) + }) .collect(); // all shards return the same message - join_all(futures).await.pop().unwrap() + futures.collect::>>().await.pop().unwrap() } /// Warmup on a max size batch @@ -106,7 +127,7 @@ impl ShardedClient { max_total_tokens: u32, max_batch_size: Option, ) -> Result> { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| { @@ -119,7 +140,8 @@ impl ShardedClient { }) .collect(); // Take the minimum value - let results = join_all(futures) + let results = futures + .collect::>>() .await .into_iter() .collect::>>>()?; @@ -135,14 +157,17 @@ impl ShardedClient { &mut self, batch: Batch, ) -> Result<(Vec, Option, PrefillTimings)> { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| Box::pin(client.prefill(batch.clone()))) .collect(); #[allow(clippy::type_complexity)] - let results: Result, Option, PrefillTimings)>> = - join_all(futures).await.into_iter().collect(); + let results: Result, Option, PrefillTimings)>> = futures + .collect::>>() + .await + .into_iter() + .collect(); let mut results = results?; let (mut generations, next_batch, mut timings) = @@ -168,14 +193,17 @@ impl ShardedClient { &mut self, batches: Vec, ) -> Result<(Vec, Option, DecodeTimings)> { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| Box::pin(client.decode(batches.clone()))) .collect(); #[allow(clippy::type_complexity)] - let results: Result, Option, DecodeTimings)>> = - join_all(futures).await.into_iter().collect(); + let results: Result, Option, DecodeTimings)>> = futures + .collect::>>() + .await + .into_iter() + .collect(); let mut results = results?; let (mut generations, next_batch, mut timings) = @@ -243,7 +271,6 @@ impl Health for ShardedClient { top_n_tokens: 0, // Block 0 is reserved for health checks blocks: vec![0], - slots: (0..16).collect(), }; let batch = Batch { id: u64::MAX, diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 07c334a3647..e13bedb522b 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -503,6 +503,8 @@ pub enum InferError { TemplateError(#[from] minijinja::Error), #[error("Tool error: {0}")] ToolError(String), + #[error("Request could not be re-allocated: out of pages")] + OutOfPages, } impl InferError { @@ -514,6 +516,7 @@ impl InferError { InferError::IncompleteGeneration => "incomplete_generation", InferError::TemplateError(_) => "template_error", InferError::ToolError(_) => "tool_error", + InferError::OutOfPages => 
"out_of_pages", } } } diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs index 7467fd85997..45ed1f07dcb 100644 --- a/router/src/infer/v3/block_allocator.rs +++ b/router/src/infer/v3/block_allocator.rs @@ -1,23 +1,88 @@ use std::cmp::min; -use tokio::sync::{mpsc, oneshot}; +use std::fmt::Formatter; +use std::sync::{Arc, Mutex, TryLockError}; +use thiserror::Error; -#[derive(Debug, Clone)] +#[derive(Clone)] pub(crate) struct BlockAllocation { - pub blocks: Vec, - pub slots: Vec, + block_size: usize, + allocated_blocks: Vec, + required_blocks: usize, + required_slots: usize, block_allocator: BlockAllocator, } +impl BlockAllocation { + pub(crate) fn len(&self) -> usize { + self.allocated_blocks.len() * self.block_size + } + + pub(crate) fn blocks(&self) -> &[u32] { + &self.allocated_blocks + } + + /// Extend an allocation by adding new blocks + /// If the allocation length > window size, repeats blocks and slots to cover the + /// whole `required_blocks` and `required_slots` + pub(crate) fn extend(&mut self) -> Result<(), AllocationError> { + let required_blocks = match self.block_allocator.window_size { + None => self.required_blocks, + Some(window_size) => min( + (window_size as usize + self.block_size - 1) / self.block_size, + self.required_blocks, + ), + }; + let remaining_blocks = required_blocks.saturating_sub(self.allocated_blocks.len()); + let new_blocks = min(remaining_blocks, 16); + + // Try to allocate all remaining blocks + let blocks = match self.block_allocator.allocate_blocks(new_blocks) { + Ok(blocks) => blocks, + // Failed, try to allocate one block + Err(_) => self.block_allocator.allocate_blocks(1)?, + }; + // Add block and slots to current allocation + self.allocated_blocks.extend(blocks); + + if let Some(window_size) = self.block_allocator.window_size { + // if we have more slots than the window size, + // we will never need to re-allocate and we can just repeat the blocks/slots + let window_size = window_size as usize; + if self.len() > window_size { + let repeats = (self.required_slots + window_size - 1) / window_size; + self.allocated_blocks = self.allocated_blocks.repeat(repeats); + self.allocated_blocks.truncate(self.required_blocks); + } + } + + Ok(()) + } +} + impl Drop for BlockAllocation { + /// Free the blocks fn drop(&mut self) { - self.block_allocator.free(self.blocks.clone()) + let allocated_blocks = std::mem::take(&mut self.allocated_blocks); + self.block_allocator.free(allocated_blocks) + } +} + +impl std::fmt::Debug for BlockAllocation { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BlockAllocation") + .field("allocated_blocks", &self.allocated_blocks.len()) + .field("required_blocks", &self.required_blocks) + .field("required_slots", &self.required_slots) + .field("block_allocator", &self.block_allocator) + .finish() } } -#[derive(Debug, Clone)] +#[derive(Clone)] pub(crate) struct BlockAllocator { - /// Channel to communicate with the background task - block_allocator: mpsc::UnboundedSender, + free_blocks: Arc>>, + block_size: u32, + window_size: Option, } impl BlockAllocator { @@ -26,111 +91,129 @@ impl BlockAllocator { block_size: u32, window_size: Option, ) -> Self { - // Create channel - let (sender, receiver) = mpsc::unbounded_channel(); + let blocks = max_batch_total_tokens / block_size; + // Block 0 is reserved for health checks + let free_blocks: Vec = (1..blocks).collect(); - // Launch background queue task - tokio::spawn(block_allocator_task( - max_batch_total_tokens / 
block_size, + Self { + free_blocks: Arc::new(Mutex::new(free_blocks)), block_size, window_size, - receiver, - )); + } + } - Self { - block_allocator: sender, + fn allocate_blocks(&self, blocks: usize) -> Result, AllocationError> { + let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired"); + + if blocks > free_blocks.len() { + // Not enough blocks to cover this allocation + // Early return + return Err(AllocationError::NotEnoughPages); } + + // Take the blocks + let n_free_blocks = free_blocks.len(); + Ok(free_blocks.split_off(n_free_blocks - blocks)) } - pub(crate) async fn allocate(&self, tokens: u32) -> Option { - let (response_sender, response_receiver) = oneshot::channel(); - self.block_allocator - .send(BlockAllocatorCommand::Allocate { - tokens, - response_sender, - }) - .unwrap(); - - response_receiver - .await - .unwrap() - .map(|(blocks, slots)| BlockAllocation { - blocks, - slots, - block_allocator: self.clone(), - }) + /// For prompt tokens, we allocate enough blocks to cover all tokens + /// For decode tokens, we allocate min(decode_blocks, 16) blocks + /// + /// If allocation > window size, we repeat blocks and slots + pub(crate) fn block_allocation( + &self, + prompt_tokens: u32, + decode_tokens: u32, + ) -> Result { + let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size; + // prompt blocks + 16 blocks for decode + let decode_blocks = (decode_tokens + self.block_size - 1) / self.block_size; + let required_blocks = required_prompt_blocks + min(decode_blocks, 16); + let required_slots = required_blocks * self.block_size; + + // Slots and blocks required for the whole request + let total_slots = prompt_tokens + decode_tokens; + let total_required_blocks = (total_slots + self.block_size - 1) / self.block_size; + + let (clipped_required_blocks, repeats) = match self.window_size { + Some(window_size) if required_slots >= window_size => { + // Number of blocks for this window size + let window_size_blocks = (window_size + self.block_size - 1) / self.block_size; + // Number of times we will need to repeat blocks to cover the total allocation + let repeats = (total_slots + window_size - 1) / window_size; + (window_size_blocks, repeats) + } + // Nothing to do + _ => (required_blocks, 1), + }; + + // Scoped to drop the lock early + let allocated_blocks = { + let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired"); + let clipped_required_blocks = clipped_required_blocks as usize; + + if clipped_required_blocks > free_blocks.len() { + // Not enough blocks to cover this allocation + // Early return + return Err(AllocationError::NotEnoughPages); + } + + // Take the blocks + let n_free_blocks = free_blocks.len(); + free_blocks.split_off(n_free_blocks - clipped_required_blocks) + }; + + let repeats = repeats as usize; + let total_slots = total_slots as usize; + let total_required_blocks = total_required_blocks as usize; + + let allocated_blocks = if repeats != 1 { + let mut allocated_blocks = allocated_blocks.repeat(repeats); + allocated_blocks.truncate(total_required_blocks); + allocated_blocks + } else { + allocated_blocks + }; + + Ok(BlockAllocation { + block_size: self.block_size as usize, + allocated_blocks, + required_blocks: total_required_blocks, + required_slots: total_slots, + block_allocator: self.clone(), + }) } pub(crate) fn free(&self, blocks: Vec) { - self.block_allocator - .send(BlockAllocatorCommand::Free { blocks }) - .unwrap(); + self.free_blocks + .lock() + .expect("Lock could not be 
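
A self-contained miniature of the new synchronous allocator, keeping only what is needed to show the sizing rule: the prompt is fully covered, the decode phase gets at most 16 blocks up front, and blocks come off the tail of a mutex-guarded free list with block 0 reserved. The sliding-window repeat logic is omitted, and the numbers in the comments are a worked example rather than values from the diff:

```rust
use std::sync::{Arc, Mutex};

struct MiniAllocator {
    free_blocks: Arc<Mutex<Vec<u32>>>,
    block_size: u32,
}

impl MiniAllocator {
    fn new(max_batch_total_tokens: u32, block_size: u32) -> Self {
        // Block 0 is reserved for health checks, so the free list starts at 1.
        let free_blocks: Vec<u32> = (1..max_batch_total_tokens / block_size).collect();
        Self { free_blocks: Arc::new(Mutex::new(free_blocks)), block_size }
    }

    fn allocate(&self, prompt_tokens: u32, decode_tokens: u32) -> Option<Vec<u32>> {
        let ceil_div = |tokens: u32| (tokens + self.block_size - 1) / self.block_size;
        // Cover the whole prompt, but pre-allocate at most 16 blocks for decoding.
        let required = (ceil_div(prompt_tokens) + ceil_div(decode_tokens).min(16)) as usize;

        let mut free = self.free_blocks.lock().unwrap();
        if required > free.len() {
            return None; // not enough pages
        }
        let n = free.len();
        Some(free.split_off(n - required)) // take blocks from the tail
    }

    fn free(&self, blocks: Vec<u32>) {
        self.free_blocks.lock().unwrap().extend(blocks);
    }
}

fn main() {
    let alloc = MiniAllocator::new(4096, 16);
    // 35 prompt tokens -> 3 blocks; 300 decode tokens -> 19 blocks, capped at 16.
    let blocks = alloc.allocate(35, 300).unwrap();
    assert_eq!(blocks.len(), 19);
    alloc.free(blocks);
}
```
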
acquired. This is a bug.") + .extend(blocks) } } -async fn block_allocator_task( - blocks: u32, - block_size: u32, - window_size: Option, - mut receiver: mpsc::UnboundedReceiver, -) { - // Block 0 is reserved for health checks - let mut free_blocks: Vec = (1..blocks).collect(); - while let Some(cmd) = receiver.recv().await { - match cmd { - BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks), - BlockAllocatorCommand::Allocate { - tokens, - response_sender, - } => { - // Apply window size - let (required_blocks, repeats) = { - let (tokens, repeats) = match window_size { - None => (tokens, 1), - Some(window_size) => { - let repeats = (tokens + window_size - 1) / window_size; - let tokens = min(tokens, window_size); - (tokens, repeats as usize) - } - }; - // Pad to a multiple of block size - let required_blocks = (tokens + block_size - 1) / block_size; - (required_blocks, repeats) - }; - - let tokens = tokens as usize; - let allocation = if required_blocks > free_blocks.len() as u32 { - None - } else { - let blocks = - free_blocks.split_off(free_blocks.len() - required_blocks as usize); - let mut slots = Vec::with_capacity( - (required_blocks * block_size * repeats as u32) as usize, - ); - - 'slots: for block_id in blocks.repeat(repeats).iter() { - for s in (block_id * block_size)..((block_id + 1) * block_size) { - slots.push(s); - if slots.len() == tokens { - break 'slots; - } - } - } - Some((blocks, slots)) - }; - response_sender.send(allocation).unwrap(); +impl std::fmt::Debug for BlockAllocator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut d = f.debug_struct("BlockAllocator"); + d.field("block_size", &self.block_size) + .field("window_size", &self.window_size); + match self.free_blocks.try_lock() { + Ok(guard) => { + d.field("free_blocks", &(*guard).len()); } - } + Err(TryLockError::Poisoned(err)) => { + d.field("free_blocks", &(**err.get_ref()).len()); + } + Err(TryLockError::WouldBlock) => { + d.field("free_blocks", &format_args!("")); + } + }; + d.finish() } } -#[derive(Debug)] -enum BlockAllocatorCommand { - Free { - blocks: Vec, - }, - Allocate { - tokens: u32, - response_sender: oneshot::Sender, Vec)>>, - }, +#[derive(Error, Debug)] +pub enum AllocationError { + #[error("Not enough pages")] + NotEnoughPages, } diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index 0b66142abb1..43d2bdd84f3 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -5,7 +5,7 @@ use crate::validation::{ ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, }; use nohash_hasher::{BuildNoHashHasher, IntMap}; -use std::cmp::{max, min}; +use std::cmp::max; use std::collections::VecDeque; use text_generation_client::v3::{ Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, @@ -33,6 +33,8 @@ pub(crate) struct Entry { pub batch_time: Option, /// Block Allocation pub block_allocation: Option, + /// Cache length (in tokens) of the request (prompt tokens + generated_tokens) + pub cache_length: u32, } /// Request Queue @@ -162,9 +164,6 @@ struct State { /// Paged Attention block size block_size: u32, - /// Sliding window - window_size: Option, - /// Speculation amount speculate: u32, @@ -188,7 +187,6 @@ impl State { next_id: 0, next_batch_id: 0, block_size, - window_size, speculate, block_allocator, } @@ -226,6 +224,11 @@ impl State { } } + // Check if max_size == 0 + if max_size == Some(0) { + return None; + } + // Pad prefill_token_budget to be a multiple of block 
size let prefill_token_budget = ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; @@ -274,41 +277,31 @@ impl State { } Some(block_allocator) => { prefill_tokens += entry.request.input_length; - let max_new_tokens = match self.window_size { - None => entry.request.stopping_parameters.max_new_tokens, - Some(window_size) => min( - window_size.saturating_sub(entry.request.input_length), - entry.request.stopping_parameters.max_new_tokens, - ), - }; - decode_tokens += max_new_tokens; - - if prefill_tokens > prefill_token_budget - || (prefill_tokens + decode_tokens + self.speculate) > token_budget - { + if prefill_tokens > prefill_token_budget { // Entry is over budget // Add it back to the front - tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); + tracing::debug!( + "Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget}" + ); self.entries.push_front((id, entry)); break; } - let tokens = entry.request.input_length - + entry.request.stopping_parameters.max_new_tokens - + self.speculate - - 1; - - match block_allocator.allocate(tokens).await { - None => { + let decode_tokens = + entry.request.stopping_parameters.max_new_tokens + self.speculate; + match block_allocator + .block_allocation(entry.request.input_length, decode_tokens) + { + Err(_) => { // Entry is over budget // Add it back to the front tracing::debug!("Over budget: not enough free blocks"); self.entries.push_front((id, entry)); break 'entry_loop; } - Some(block_allocation) => { + Ok(block_allocation) => { tracing::debug!("Allocation: {block_allocation:?}"); - max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); + max_blocks = max(max_blocks, block_allocation.blocks().len() as u32); Some(block_allocation) } } @@ -324,14 +317,10 @@ impl State { // Update entry entry.temp_span = Some(entry_batch_span); - let (blocks, slots) = match &block_allocation { - None => (Vec::new(), Vec::new()), - Some(block_allocation) => ( - block_allocation.blocks.clone(), - block_allocation.slots.clone(), - ), - }; - + let blocks = block_allocation + .as_ref() + .map(|block_allocation| block_allocation.blocks().to_vec()) + .unwrap_or_default(); entry.block_allocation = block_allocation; batch_requests.push(Request { @@ -350,7 +339,6 @@ impl State { )), top_n_tokens: entry.request.top_n_tokens, blocks, - slots, }); // Set batch_time entry.batch_time = Some(Instant::now()); @@ -470,7 +458,7 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { inputs: vec![], - input_length: 0, + input_length: 1, truncate: 0, decoder_input_details: false, parameters: ValidParameters { @@ -498,6 +486,7 @@ mod tests { queue_time: Instant::now(), batch_time: None, block_allocation: None, + cache_length: 0, }; (entry, receiver_tx) } @@ -580,7 +569,7 @@ mod tests { #[tokio::test] async fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, None, 0, 2); + let mut state = State::new(false, 1, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -702,7 +691,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, None, 2, 16); + let queue = Queue::new(true, 1, None, 2, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); diff --git a/router/src/infer/v3/scheduler.rs 
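
The queue change above rounds the prefill token budget up to a whole number of paged-attention blocks before comparing against it. A tiny worked sketch of that rounding (values are illustrative):

```rust
/// Round a token budget up to the next multiple of the block size so that
/// budget checks line up with whole paged-attention blocks.
fn pad_to_block_multiple(budget: u32, block_size: u32) -> u32 {
    ((budget + block_size - 1) / block_size) * block_size
}

fn main() {
    // With 16-token blocks, a 100-token budget is treated as 112 tokens (7 blocks).
    assert_eq!(pad_to_block_multiple(100, 16), 112);
    assert_eq!(pad_to_block_multiple(112, 16), 112);
}
```
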
b/router/src/infer/v3/scheduler.rs index ad03dd8375f..3c7c59f59af 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -5,12 +5,14 @@ use crate::infer::{ }; use crate::validation::ValidGenerateRequest; use crate::{FinishReason, PrefillToken, Token}; -use nohash_hasher::IntMap; +use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; -use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient}; +use text_generation_client::v3::{ + Batch, CachedBatch, Generation, KeptRequest, ShardedClient, TerminatedGeneration, +}; use text_generation_client::ClientError; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit}; @@ -88,6 +90,7 @@ impl Scheduler for SchedulerV3 { queue_time: Instant::now(), batch_time: None, block_allocation: None, + cache_length: 0, }); // Notify the background task that we have a new entry in the queue that needs @@ -161,7 +164,8 @@ pub(crate) async fn batching_task( }; let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); + let max_size = + max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); // Try to get a new batch if let Some((mut new_entries, new_batch, span)) = queue @@ -242,11 +246,34 @@ async fn prefill( generation_health.store(true, Ordering::SeqCst); let start_filtering_time = Instant::now(); - // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); + // Filter and send finished generations + let filtered_stream_responses = filter_send_ended_generations(generations, entries); + + // Iterate on intermediate generations + for (id, stream_responses) in filtered_stream_responses { + // Get entry + let entry = entries + .get_mut(&id) + .expect("ID not found in entries. This is a bug."); + + // Send intermediate responses + if send_stream_responses(stream_responses, entry).is_err() { + // Sending failed, remove entry + entries + .remove(&id) + .expect("ID not found in entries. 
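
The prefill path now buffers per-request stream messages and pushes them through the entry's response channel in one place, dropping the entry as soon as the receiver has gone away. A minimal, self-contained sketch of that send loop, assuming a tokio unbounded channel like the one the router uses for `response_tx` (message types simplified):

```rust
use tokio::sync::mpsc;

#[derive(Debug)]
enum StreamMsg {
    Intermediate(String),
    End(String),
}

// Send buffered messages to a client; stop at the first failure so the caller
// can drop the entry (a send error means the receiver, i.e. the client, is gone).
fn send_stream_responses(
    messages: Vec<StreamMsg>,
    tx: &mpsc::UnboundedSender<StreamMsg>,
) -> Result<(), mpsc::error::SendError<StreamMsg>> {
    for msg in messages {
        tx.send(msg)?;
    }
    Ok(())
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel();
    let sent = send_stream_responses(
        vec![StreamMsg::Intermediate("tok".into()), StreamMsg::End("done".into())],
        &tx,
    );
    if sent.is_err() {
        // In the router this is where the entry would be removed from `entries`.
    }
    drop(tx); // close the channel so the receive loop below terminates
    while let Some(msg) = rx.recv().await {
        println!("{msg:?}");
    }
}
```
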
This is a bug."); + } + } // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; + let next_batch = match next_batch { + Some(batch) if batch.size as usize != entries.len() => { + let (filtered_batch, _) = + filter_batch(client, batch, entries, &IntMap::default()).await; + filtered_batch + } + batch => batch, + }; metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); @@ -284,11 +311,39 @@ async fn decode( generation_health.store(true, Ordering::SeqCst); let start_filtering_time = Instant::now(); - // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); - // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; + // Filter and send finished generations + let mut filtered_stream_responses = filter_send_ended_generations(generations, entries); + + tracing::info!("filtered_stream: {:?}", start_filtering_time.elapsed()); + + // Send `StreamResponseInfer::Intermediate` messages for entries that don't need to be + // re-allocated, + // Allocated new blocks for entries that go over their allocation + // Filter entries that couldn't be re-allocated and add them to `terminated_entries` + let (force_update, terminated_entries) = + filter_send_update_allocations(entries, &mut filtered_stream_responses); + + tracing::info!("filtered_update: {:?}", start_filtering_time.elapsed()); + + let next_batch = match next_batch { + // Run Only on re-allocation or if entries were filtered + Some(batch) if batch.size as usize != entries.len() || force_update => { + // Filter next batch: remove requests that were stopped and update blocks/slots + let (filtered_batch, terminated_generations) = + filter_batch(client, batch, entries, &terminated_entries).await; + tracing::info!("filter_batch: {:?}", start_filtering_time.elapsed()); + send_terminated_generations( + terminated_generations, + terminated_entries, + filtered_stream_responses, + ); + tracing::info!("send_terminated: {:?}", start_filtering_time.elapsed()); + + filtered_batch + } + batch => batch, + }; if let Some(concat_duration) = timings.concat { metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); @@ -314,78 +369,269 @@ async fn decode( } /// Filter a `batch` and remove all requests not present in `entries` +/// Ask the server to generate the full texts for entries in `terminated_entries` #[instrument(skip_all)] async fn filter_batch( client: &mut ShardedClient, - next_batch: Option, + batch: CachedBatch, entries: &IntMap, -) -> Option { - let mut batch = next_batch?; - - // No need to filter - if batch.size as usize == entries.len() { - return Some(batch); - } - + terminated_entries: &IntMap, +) -> (Option, Vec) { let id = batch.id; - - // Retain only requests that are still in entries - batch.request_ids.retain(|id| entries.contains_key(id)); - - if batch.request_ids.is_empty() { + if entries.is_empty() && terminated_entries.is_empty() { // All requests have been filtered out // Next batch is now empty // Clear it from the Python shards cache // We unwrap here as we need to panic since we cannot recover if this method fails client.clear_cache(Some(id)).await.unwrap(); - None + Default::default() } else { + let max_blocks = entries + .iter() + .map(|(_, entry)| { + entry 
+ .block_allocation + .as_ref() + .map(|alloc| alloc.blocks().len()) + }) + .max() + .flatten(); + + let start_time = Instant::now(); + + // Collect new blocks + let updated_requests = entries + .iter() + .map(|(request_id, entry)| { + let (blocks, padded_blocks) = entry + .block_allocation + .as_ref() + .map(|alloc| { + let blocks = alloc.blocks().to_vec(); + let mut padded_blocks = blocks.clone(); + + if let Some(max_blocks) = max_blocks { + padded_blocks.resize(max_blocks, 0); + } + + (blocks, padded_blocks) + }) + .unwrap_or_default(); + + KeptRequest { + id: *request_id, + blocks, + padded_blocks, + } + }) + .collect(); + + tracing::info!("Collect blocks: {:?}", start_time.elapsed()); + // Filter Python shard cache // We unwrap here as we need to panic since we cannot recover if this method fails - client.filter_batch(id, batch.request_ids).await.unwrap() + client + .filter_batch( + id, + updated_requests, + terminated_entries.keys().copied().collect(), + ) + .await + .unwrap() } } -/// Send one or multiple `InferStreamResponse` to Infer for all `entries` -/// and filter entries +/// Send `InferStreamResponse::Intermediate` and the final `InferStreamResponse::End` messages +/// to terminated requests +/// It modifies the last `InferStreamResponse::Intermediate` to add the final full text in +/// `terminated_generations` #[instrument(skip_all)] -fn filter_send_generations(generations: Vec, entries: &mut IntMap) { - generations.into_iter().for_each(|generation| { +fn send_terminated_generations( + terminated_generations: Vec, + terminated_entries: IntMap, + mut stream_responses: IntMap>, +) { + // Receive final message for terminated generations + 'terminated_generations: for terminated_generation in terminated_generations { + let id = terminated_generation.id; + // Get entry for this generation + let entry = terminated_entries + .get(&id) + .expect("ID not found in entries. This is a bug."); + // Get previous `InferStreamResponse` for this generation + let stream_responses = stream_responses + .remove(&id) + .expect("ID not found in stream_responses. This is a bug."); + + // Peekable iterator to know when we are at the last `InferStreamResponse` + let mut iterator = stream_responses.into_iter().peekable(); + + while let Some(stream_response) = iterator.next() { + let response = if iterator.peek().is_none() { + // Last `InferStreamResponse::Intermediate` + let (token, top_tokens) = match stream_response { + InferStreamResponse::Intermediate { token, top_tokens } => (token, top_tokens), + _ => unreachable!(), + }; + // Modify it to be a `InferStreamResponse::End` with the new OutOfResources finish + // reason + InferStreamResponse::End { + token, + top_tokens, + generated_text: GeneratedText::from( + terminated_generation + .generated_text + .clone() + .expect("Generated Text is None. 
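
`padded_blocks` lets every kept request ship a block table of the same width, so the server can build a rectangular tensor from them directly. A small sketch of the padding step, using zero as the filler value as in the code above:

```rust
// Pad each request's block list with zeros up to the widest block table in the batch.
fn pad_block_tables(block_tables: Vec<Vec<u32>>) -> Vec<Vec<u32>> {
    let max_blocks = block_tables.iter().map(Vec::len).max().unwrap_or(0);
    block_tables
        .into_iter()
        .map(|mut blocks| {
            blocks.resize(max_blocks, 0);
            blocks
        })
        .collect()
}

fn main() {
    let padded = pad_block_tables(vec![vec![4, 5], vec![7, 8, 9]]);
    assert_eq!(padded, vec![vec![4, 5, 0], vec![7, 8, 9]]);
}
```
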
This is a bug."), + ), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + } + } else { + stream_response + }; + + // Send responses + let send_result = entry.response_tx.send(Ok(response)).map_err(|err| { + tracing::error!("Entry response channel error."); + metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + err + }); + + if send_result.is_err() { + // The channel is dropped, skip the rest of the messages + continue 'terminated_generations; + } + } + } +} + +/// Send `InferStreamResponse::End` to `Infer` for finished entries and remove them from `entries` +/// Returns filtered `InferStreamResponse::Intermediate` generations +#[instrument(skip_all)] +fn filter_send_ended_generations( + generations: Vec, + entries: &mut IntMap, +) -> IntMap> { + generations.into_iter().filter_map(|generation| { let id = generation.request_id; // Get entry // We can `expect` here as the request id should always be in the entries let entry = entries - .get(&id) + .get_mut(&id) .expect("ID not found in entries. This is a bug."); // Create and enter a span to link this function back to the entry let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered(); - // Send generation responses back to the infer task - // If the receive an error from the Flume channel, it means that the client dropped the - // request and we need to stop generating hence why we unwrap_or(true) - let stopped = send_responses(generation, entry).map_err(|err| { - tracing::error!("Entry response channel error."); + + // Return directly if the channel is disconnected + if entry.response_tx.is_closed() { metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); - err - }).unwrap_or(true); - if stopped { + // Remove from entries and filter entries.remove(&id).expect("ID not found in entries. This is a bug."); + return None; } - }); + + // Update cache length + entry.cache_length = generation.cache_length; + + let (finished, stream_responses) = map_generation(generation, entry); + // If the generation has ended for this request, we send the responses to the channel and + // remove the entry to drop it and free its blocks + if finished { + let _ = send_stream_responses(stream_responses, entry); + // Remove from entries and filter + entries.remove(&id).expect("ID not found in entries. This is a bug."); + return None; + } + + Some((id, stream_responses)) + }).collect() } -/// Send responses through the `entry` response channel -fn send_responses( - generation: Generation, +/// Send `InferStreamResponse` to `Infer` through an `Entry` response channel +#[instrument(skip_all)] +fn send_stream_responses( + stream_responses: Vec, entry: &Entry, -) -> Result>>> { - // Return directly if the channel is disconnected - if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); - return Ok(true); +) -> Result<(), Box>>> { + for response in stream_responses { + entry.response_tx.send(Ok(response)).map_err(|err| { + tracing::error!("Entry response channel error."); + metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + err + })?; + } + Ok(()) +} + +/// Check if block allocations need to be extended +/// If we don't have enough blocks, request will be filtered and added to an IntMap of +/// terminated entries. 
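
Rewriting the last buffered `Intermediate` message into an `End` message relies on a peekable iterator to detect the final element. The pattern in isolation, with the message types reduced to a token id:

```rust
#[derive(Debug, PartialEq)]
enum Msg {
    Intermediate(u32),
    End(u32),
}

// Walk buffered messages and turn only the last one into an `End`,
// leaving everything before it untouched.
fn finalize(messages: Vec<Msg>) -> Vec<Msg> {
    let mut out = Vec::with_capacity(messages.len());
    let mut iter = messages.into_iter().peekable();
    while let Some(msg) = iter.next() {
        if iter.peek().is_none() {
            // Last message: replace the intermediate token with a final one.
            if let Msg::Intermediate(token) = msg {
                out.push(Msg::End(token));
                continue;
            }
        }
        out.push(msg);
    }
    out
}

fn main() {
    let out = finalize(vec![Msg::Intermediate(1), Msg::Intermediate(2)]);
    assert_eq!(out, vec![Msg::Intermediate(1), Msg::End(2)]);
}
```
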
+/// If at least one entry allocation was extended, we return true to force an update +#[instrument(skip_all)] +fn filter_send_update_allocations( + entries: &mut IntMap, + stream_responses: &mut IntMap>, +) -> (bool, IntMap) { + let mut updated = false; + + let ids: Vec = entries.keys().copied().collect(); + let mut terminated_entries = + IntMap::with_capacity_and_hasher(entries.len(), BuildNoHashHasher::default()); + + for id in &ids { + let entry = entries + .get_mut(id) + .expect("ID not found in entries. This is a bug."); + + if let Some(block_allocation) = entry.block_allocation.as_mut() { + // Check if allocation can handle the current cache_length + if entry.cache_length > block_allocation.len() as u32 { + updated = true; + + // Extend allocation by asking for a new block + if let Err(err) = block_allocation.extend() { + // Failed to extend allocation + tracing::error!("Failed to extend allocation: {err}"); + metrics::increment_counter!("tgi_request_failure", "err" => "out_of_resources"); + + // Remove entry + let mut entry = entries + .remove(id) + .expect("ID not found in entries. This is a bug."); + // Clear block allocation + entry.block_allocation = None; + // Add it to terminated entries + terminated_entries.insert(*id, entry); + // Skip the rest of the logic to not send the intermediate messages + // This entry will be terminated and we will need to edit the last intermediate + // response to add the complete generated text + continue; + } + } + } + let stream_response = stream_responses + .remove(id) + .expect("ID not found in stream_responses. This is a bug."); + + // Send intermediate responses + if send_stream_responses(stream_response, entry).is_err() { + // Sending failed, remove entry + entries + .remove(id) + .expect("ID not found in entries. 
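
The re-allocation check above compares the cache length reported by the server against the capacity of the current allocation, only grows when the request has actually outrun its blocks, and moves requests that cannot be grown into the terminated set. A compact sketch of that decision, with the allocation reduced to a capacity counter (the real `extend` tries a chunk of blocks before falling back to one):

```rust
// Simplified stand-in for a block allocation: capacity in tokens plus a grow step.
struct Allocation {
    capacity: u32,
    block_size: u32,
}

enum Outcome {
    Keep,      // allocation still covers the cache, send responses as usual
    Extended,  // grew the allocation, the batch needs a filter/update round-trip
    Terminate, // out of pages, finish this request early
}

fn check(alloc: &mut Allocation, cache_length: u32, free_blocks: &mut u32) -> Outcome {
    if cache_length <= alloc.capacity {
        return Outcome::Keep;
    }
    if *free_blocks == 0 {
        return Outcome::Terminate;
    }
    // Grow by one block.
    *free_blocks -= 1;
    alloc.capacity += alloc.block_size;
    Outcome::Extended
}

fn main() {
    let mut alloc = Allocation { capacity: 32, block_size: 16 };
    let mut free_blocks = 1;
    assert!(matches!(check(&mut alloc, 33, &mut free_blocks), Outcome::Extended));
    assert!(matches!(check(&mut alloc, 49, &mut free_blocks), Outcome::Terminate));
}
```
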
This is a bug."); + } } - let mut stopped = false; + (updated, terminated_entries) +} + +/// Map `Generation` to `<(bool, Vec<(u64, InferStreamResponse)>)>` +/// `bool` is `true` if the generation is finished +fn map_generation(generation: Generation, entry: &Entry) -> (bool, Vec) { + let mut finished = false; + let mut stream_responses = Vec::with_capacity(16); if let Some(prefill_tokens) = generation.prefill_tokens { // Create Token objects @@ -398,10 +644,8 @@ fn send_responses( .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) .collect(); - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; + // Push to stream_responses + stream_responses.push(InferStreamResponse::Prefill(prefill_tokens)); } // Create last Token @@ -443,26 +687,24 @@ fn send_responses( match (&generation.generated_text, iterator.peek()) { (Some(generated_text), None) => { // Generation has ended - stopped = true; - // Send message - entry.response_tx.send(Ok(InferStreamResponse::End { + finished = true; + // Push to stream_responses + stream_responses.push(InferStreamResponse::End { token, top_tokens, generated_text: GeneratedText::from(generated_text.clone()), queued: entry.queue_time, start: entry.batch_time.unwrap(), - }))?; + }); } _ => { - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; + // Push to stream_responses + stream_responses.push(InferStreamResponse::Intermediate { token, top_tokens }); } } } - Ok(stopped) + (finished, stream_responses) } /// Send errors to Infer for all `entries` @@ -488,6 +730,7 @@ impl From for GeneratedText { let v3_finish_reason = text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap(); let finish_reason = match v3_finish_reason { + text_generation_client::v3::FinishReason::Terminated => FinishReason::OutOfResources, text_generation_client::v3::FinishReason::Length => FinishReason::Length, text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken, text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence, diff --git a/router/src/lib.rs b/router/src/lib.rs index b0b93c13ae1..6a90da4a0fa 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1096,6 +1096,8 @@ pub(crate) enum FinishReason { EndOfSequenceToken, #[schema(rename = "stop_sequence")] StopSequence, + #[schema(rename = "out_of_resources")] + OutOfResources, } impl std::fmt::Display for FinishReason { @@ -1104,6 +1106,7 @@ impl std::fmt::Display for FinishReason { FinishReason::Length => write!(f, "length"), FinishReason::EndOfSequenceToken => write!(f, "eos_token"), FinishReason::StopSequence => write!(f, "stop_sequence"), + FinishReason::OutOfResources => write!(f, "out_of_resources"), } } } diff --git a/router/src/server.rs b/router/src/server.rs index aa872df98b7..ddc3310ee76 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1933,6 +1933,7 @@ impl From for (StatusCode, Json) { InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY, + InferError::OutOfPages => StatusCode::TOO_MANY_REQUESTS, }; ( diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 32ee6686b6b..79cd00ccfde 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -197,7 +197,11 @@ def 
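
From a caller's point of view the new failure mode surfaces in two places: a request rejected for lack of pages maps to HTTP 429 (`InferError::OutOfPages`), and a request evicted mid-generation ends with `finish_reason: "out_of_resources"` while still carrying the text generated so far. A hedged sketch of how a client might branch on that; the response shape is reduced to the two fields involved and the retry policy is illustrative, not part of the diff:

```rust
struct FinalChunk {
    generated_text: String,
    finish_reason: String, // "length" | "eos_token" | "stop_sequence" | "out_of_resources"
}

enum Next {
    Done(String),
    RetryLater, // server was out of pages (HTTP 429) or evicted the request mid-stream
}

fn handle(status: u16, last_chunk: Option<FinalChunk>) -> Next {
    match (status, last_chunk) {
        (429, _) => Next::RetryLater,
        (_, Some(chunk)) if chunk.finish_reason == "out_of_resources" => {
            // Partial text is still returned; the caller decides whether to keep or retry it.
            Next::RetryLater
        }
        (_, Some(chunk)) => Next::Done(chunk.generated_text),
        _ => Next::RetryLater,
    }
}

fn main() {
    let next = handle(
        200,
        Some(FinalChunk {
            generated_text: "hello".into(),
            finish_reason: "out_of_resources".into(),
        }),
    );
    assert!(matches!(next, Next::RetryLater));
}
```
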
test_causal_lm_generate_token_completion_multi( # Copy stopping_criterias before filtering stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy() - next_batch = next_batch.filter([next_batch.requests[0].id]) + next_batch, _ = next_batch.filter( + default_bloom, + [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[])], + [], + ) for _ in range( stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1 @@ -305,8 +309,13 @@ def test_batch_concatenate( == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) - next_batch = next_batch.filter( - [next_batch.requests[0].id, next_batch.requests[1].id] + next_batch, _ = next_batch.filter( + default_bloom, + [ + generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[]), + generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]), + ], + [], ) for _ in range( @@ -330,7 +339,11 @@ def test_batch_concatenate( == default_bloom_batch.stopping_criterias[0].max_new_tokens ) - next_batch = next_batch.filter([next_batch.requests[1].id]) + next_batch, _ = next_batch.filter( + default_bloom, + [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[])], + [], + ) for _ in range( default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 6e6463bc948..c807a15e048 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -198,7 +198,11 @@ def test_causal_lm_generate_token_completion_multi( default_multi_requests_causal_lm_batch.stopping_criterias.copy() ) - next_batch = next_batch.filter([next_batch.requests[0].id]) + next_batch, _ = next_batch.filter( + default_causal_lm, + [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[])], + [], + ) for _ in range( stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1 @@ -305,8 +309,13 @@ def test_batch_concatenate( == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) - next_batch = next_batch.filter( - [next_batch.requests[0].id, next_batch.requests[1].id] + next_batch, _ = next_batch.filter( + default_causal_lm, + [ + generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[]), + generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]), + ], + [], ) for _ in range( @@ -328,7 +337,13 @@ def test_batch_concatenate( == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) - next_batch = next_batch.filter([next_batch.requests[1].id]) + next_batch, _ = next_batch.filter( + default_causal_lm, + [ + generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]), + ], + [], + ) for _ in range( default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 943c3b0820d..2eea96275e7 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -206,7 +206,11 @@ def test_seq2seq_lm_generate_token_completion_multi( ) assert generations[1].generated_text.generated_tokens == 5 - next_batch = next_batch.filter([next_batch.requests[0].id]) + next_batch, _ = next_batch.filter( + default_seq2seq_lm, + [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[])], + [], + ) generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) assert len(generations) == len(next_batch) @@ -339,8 +343,13 @@ def test_batch_concatenate( ) 
assert generations[2].generated_text.generated_tokens == 5 - next_batch = next_batch.filter( - [next_batch.requests[0].id, next_batch.requests[1].id] + next_batch, _ = next_batch.filter( + default_seq2seq_lm, + [ + generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[]), + generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[]), + ], + [], ) generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) @@ -351,7 +360,11 @@ def test_batch_concatenate( assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id assert generations[0].generated_text.generated_tokens == 7 - next_batch = next_batch.filter([next_batch.requests[1].id]) + next_batch, _ = next_batch.filter( + default_seq2seq_lm, + [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[])], + [], + ) generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch) assert next_batch is None diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index e896c831bd3..f3b94e8cbbb 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -158,11 +158,49 @@ def from_pb( ) @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") + def filter( + self, + model: "CausalLM", + kept_requests: List[generate_pb2.KeptRequest], + terminated_request_ids: List[int], + ) -> Tuple[Optional["CausalLMBatch"], List[generate_pb2.TerminatedGeneration]]: + terminated_generations = [] + for request_id in terminated_request_ids: + idx = self.requests_idx_mapping[request_id] + all_input_ids = self.all_input_ids[idx] + stopping_criteria = self.stopping_criterias[idx] + next_token_chooser = self.next_token_choosers[idx] + + # Decode generated tokens + output_text, _, _ = model.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + terminated_generations.append( + generate_pb2.TerminatedGeneration( + id=request_id, + generated_text=generate_pb2.GeneratedText( + text=output_text, + generated_tokens=stopping_criteria.current_tokens, + finish_reason=generate_pb2.FINISH_REASON_TERMINATED, + seed=seed, + ), + ) + ) + if not kept_requests: + return None, terminated_generations + + request_ids = [r.id for r in kept_requests] if len(request_ids) == len(self): - return self + return self, terminated_generations keep_indices = [] @@ -258,7 +296,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: self.padding_right_offset = new_padding_right_offset self.max_tokens = max_tokens - return self + return self, terminated_generations @classmethod @tracer.start_as_current_span("concatenate") @@ -746,6 +784,7 @@ def generate_token( ), generated_text, top_tokens, + new_input_length, ) generations.append(generation) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index d16d371068e..1182f3d4b29 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -79,13 +79,9 @@ class FlashCausalLMBatch(Batch): 
# Paged Attention values # Set when creating the batch - # CPU tensor of length b indicating the start of each sequence in slots - start_slots: torch.Tensor # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode slot_indices: torch.Tensor - # list of length b of list of length s_i // block_size - block_tables: List[List[int]] # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences @@ -154,7 +150,6 @@ def from_tokenized( sliding_window = get_sliding_windows() position_ids = [] cu_seqlen_prefill = [0] - start_slots = [] slot_indices = [] prefill_cache_indices = [] @@ -176,7 +171,6 @@ def from_tokenized( # Cumulative length cumulative_length = 0 - cumulative_max_length = 0 prefill_out_cumulative_length = 0 num_blocks = 0 @@ -185,7 +179,7 @@ def from_tokenized( max_blocks = 0 block_tables = [] - slots = [] + flat_blocks = [] # Parse batch for i, (r, tokenized_input) in enumerate( @@ -204,6 +198,9 @@ def from_tokenized( input_length = len(tokenized_input) input_lengths.append(input_length) + speculative_length = get_speculate() + speculative_length = 0 if speculative_length is None else speculative_length + prefix_offsets.append(input_length - 5) read_offsets.append(input_length) @@ -226,36 +223,26 @@ def from_tokenized( top_n_tokens.append(r.top_n_tokens) # Paged attention - # Remove one as the first token des not have a past - speculative_length = get_speculate() - speculative_length = 0 if speculative_length is None else speculative_length - total_tokens = input_length + max_new_tokens - 1 + speculative_length - # blocks and slots can be empty (for example in warmup) if not r.blocks: + # Remove one as the first token des not have a past + total_tokens = input_length + max_new_tokens - 1 + speculative_length needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) request_blocks = [ b for b in range(num_blocks, num_blocks + needed_blocks) ] - request_slots = [ - s - for b in request_blocks - for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) - ] else: request_blocks = r.blocks - request_slots = r.slots block_tables.append(request_blocks) - slots.extend(request_slots[:total_tokens]) num_blocks += len(request_blocks) - start_slots.append(cumulative_max_length) request_slot_indices = torch.arange( - cumulative_max_length, - cumulative_max_length + input_length, + len(flat_blocks) * BLOCK_SIZE, + (len(flat_blocks) * BLOCK_SIZE) + input_length, dtype=torch.int64, ) + flat_blocks.extend(request_blocks) slot_indices.append(request_slot_indices) # Create tensor to slice into the kv tensor in prefill @@ -289,7 +276,6 @@ def from_tokenized( # Update cumulative_length += input_length - cumulative_max_length += total_tokens max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, len(request_blocks)) max_length = max( @@ -299,7 +285,6 @@ def from_tokenized( next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device, tokenizer ) - start_slots = torch.tensor(start_slots, dtype=torch.int64) # Padded all_input_ids_tensor all_input_ids_tensor = np.zeros( @@ -356,7 +341,13 @@ def from_tokenized( top_n_tokens, device=device, dtype=torch.int64 ) - slots = torch.tensor(slots, dtype=torch.int64, device=device) + flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64, device=device) + + slots = ( + 
(flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T + + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64) + ).flatten() + block_tables_tensor = torch.zeros( (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" ) @@ -372,9 +363,7 @@ def from_tokenized( position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, prefill_cache_indices=prefill_cache_indices, - start_slots=start_slots, slot_indices=slot_indices, - block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, max_seqlen=max_seqlen, @@ -408,12 +397,47 @@ def from_pb( return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") - # We assume that if len(requests) == len(self) then the requests are the same - if len(request_ids) == len(self): - return self + def filter( + self, + model: "FlashCausalLM", + kept_requests: List[generate_pb2.KeptRequest], + terminated_request_ids: List[int], + ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]: + start = time.time_ns() + + terminated_generations = [] + for request_id in terminated_request_ids: + idx = self.requests_idx_mapping[request_id] + all_input_ids = self.all_input_ids[idx] + stopping_criteria = self.stopping_criterias[idx] + do_sample = self.next_token_chooser.do_sample[idx] + seed = self.next_token_chooser.seeds[idx] + + # Decode generated tokens + output_text, _, _ = model.decode_token( + all_input_ids, + prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + terminated_generations.append( + generate_pb2.TerminatedGeneration( + id=request_id, + generated_text=generate_pb2.GeneratedText( + text=output_text, + generated_tokens=stopping_criteria.current_tokens, + finish_reason=generate_pb2.FINISH_REASON_TERMINATED, + seed=seed if do_sample else None, + ), + ) + ) + + from loguru import logger + + logger.info(f"terminated generations {(time.time_ns() - start)/1e6}") + + if not kept_requests: + return None, terminated_generations device = self.input_ids.device @@ -423,18 +447,12 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Used to index into tensors indices = [] - # slots to keep after filtering - slot_filtering_indices = torch.zeros( - self.slots.shape[0], dtype=torch.bool, device=device - ) - - # Create on CPU to only move to GPU once instead of at every copy - slot_indices = torch.empty(len(request_ids), dtype=torch.int64) + slot_indices = [] max_seqlen = 0 requests = [] - start_slots = [] - block_tables = [] + flat_blocks = [] + padded_blocks = [] all_input_ids = [] input_lengths = [] @@ -446,10 +464,10 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": num_blocks = 0 max_blocks = 0 - # Cumulative length - cumulative_max_length = 0 - for i, request_id in enumerate(request_ids): + for i, request in enumerate(kept_requests): + request_id = request.id + idx = self.requests_idx_mapping[request_id] indices.append(idx) requests_idx_mapping[request_id] = i @@ -471,49 +489,53 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": top_n_tokens.append(self.top_n_tokens[idx]) - remaining_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - - request_block_table = self.block_tables[idx] 
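
With the per-request `slots` list gone from the protocol, the server rebuilds slot ids from the flat block list. The tensor expression `(flat_blocks * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T + arange(BLOCK_SIZE)` enumerates, for each block, the contiguous slot ids it covers; the same arithmetic written out in plain Rust (the block size of 16 is the usual value, used here purely for the example):

```rust
const BLOCK_SIZE: u32 = 16;

/// Expand a flat list of block ids into the slot ids they cover, in block order.
/// Block `b` owns slots `b * BLOCK_SIZE .. (b + 1) * BLOCK_SIZE`.
fn slots_from_blocks(flat_blocks: &[u32]) -> Vec<u32> {
    flat_blocks
        .iter()
        .flat_map(|&b| (b * BLOCK_SIZE)..((b + 1) * BLOCK_SIZE))
        .collect()
}

fn main() {
    // Blocks [2, 5] expand to slots 32..48 followed by 80..96.
    let slots = slots_from_blocks(&[2, 5]);
    assert_eq!(slots.len(), 32);
    assert_eq!(slots[0], 32);
    assert_eq!(slots[16], 80);
}
```
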
- num_blocks += len(request_block_table) - block_tables.append(request_block_table) - start_slots.append(cumulative_max_length) - - # Copy to tensor (CPU) - slot_indices[i] = cumulative_max_length + request_input_length - 1 - - # Set slice - slot_filtering_indices[ - self.start_slots[idx] : self.start_slots[idx] - + request_input_length - + remaining_tokens - - 1 - ] = True + request_block_table = request.blocks + flat_blocks.extend(request_block_table) + padded_blocks.extend(request.padded_blocks) - cumulative_max_length += request_input_length + remaining_tokens - 1 + # Index + slot_indices.append((num_blocks * BLOCK_SIZE) + request_input_length - 1) + num_blocks += len(request_block_table) max_blocks = max(max_blocks, len(request_block_table)) + logger.info(f"for loop requests: {(time.time_ns() - start)/1e6}") + # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] - block_tables_tensor = self.block_tables_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices] - slots = self.slots[slot_filtering_indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] speculative_ids = ( self.speculative_ids[indices] if self.speculative_ids is not None else None ) - start_slots = torch.tensor(start_slots, dtype=torch.int64) + logger.info(f"slice objects: {(time.time_ns() - start)/1e6}") - # Move to GPU now that we have the whole tensor - slot_indices = slot_indices.to(device) + # Create block_tables_tensor on GPU + block_tables_tensor = torch.tensor( + padded_blocks, dtype=torch.int32, device=device + ).view(len(requests), -1) - return type(self)( + logger.info(f"allocate block table: {(time.time_ns() - start)/1e6}") + + # Allocate on GPU + slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device) + + # Move to GPU + block_tables_tensor = block_tables_tensor.to(device) + flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64, device=device) + + slots = ( + (flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T + + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64) + ).flatten() + + logger.info(f"done allocation: {(time.time_ns() - start)/1e6}") + + filtered_batch = type(self)( batch_id=self.batch_id, requests=requests, requests_idx_mapping=requests_idx_mapping, @@ -521,9 +543,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": position_ids=position_ids, cu_seqlen_prefill=None, prefill_cache_indices=None, - start_slots=start_slots, slot_indices=slot_indices, - block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, max_seqlen=max_seqlen, @@ -544,6 +564,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": max_blocks=max_blocks, speculative_ids=speculative_ids, ) + return filtered_batch, terminated_generations @classmethod @tracer.start_as_current_span("concatenate") @@ -567,6 +588,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch ) max_blocks = max(max_blocks, b.max_blocks) max_seqlen = max(max_seqlen, b.max_seqlen) + # When we filter, we do not recompute this value so we do so here max_length = max( max_length, max( @@ -597,8 +619,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch total_batch_size, ) - start_slots = [] - block_tables = [] all_input_ids = [] input_lengths = [] @@ -645,9 +665,6 @@ def concatenate(cls, batches: 
List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] - start_slots.append(batch.start_slots + cumulative_slots) - - block_tables.extend(batch.block_tables) all_input_ids.extend(batch.all_input_ids) input_lengths.extend(batch.input_lengths) @@ -664,8 +681,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch cumulative_batch_size += len(batch) cumulative_slots += len(batch.slots) - start_slots = torch.concat(start_slots) - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype=batches[0].next_token_chooser.dtype, @@ -688,9 +703,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch position_ids=position_ids, cu_seqlen_prefill=None, prefill_cache_indices=None, - start_slots=start_slots, slot_indices=slot_indices, - block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, max_seqlen=max_seqlen, @@ -1374,6 +1387,7 @@ def generate_token( ), generated_text, top_tokens, + input_length + n_accepted_ids, ) generations.append(generation) diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index f507d669936..f92378cbcac 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -214,12 +214,52 @@ def from_pb_processor( ) @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]: - # It deletes requests from the batch. For instance when client lost connection - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") + def filter( + self, + model: "IdeficsCausalLM", + kept_requests: List[generate_pb2.KeptRequest], + terminated_request_ids: List[int], + ) -> Tuple[ + Optional["IdeficsCausalLMBatch"], List[generate_pb2.TerminatedGeneration] + ]: + terminated_generations = [] + for request_id in terminated_request_ids: + idx = self.requests_idx_mapping[request_id] + all_input_ids = self.all_input_ids[idx] + stopping_criteria = self.stopping_criterias[idx] + next_token_chooser = self.next_token_choosers[idx] + + # Decode generated tokens + output_text, _, _ = model.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + terminated_generations.append( + generate_pb2.TerminatedGeneration( + id=request_id, + generated_text=generate_pb2.GeneratedText( + text=output_text, + generated_tokens=stopping_criteria.current_tokens, + finish_reason=generate_pb2.FINISH_REASON_TERMINATED, + seed=seed, + ), + ) + ) + + if not kept_requests: + return None, terminated_generations + + request_ids = [r.id for r in kept_requests] if len(request_ids) == len(self): - return self + return self, terminated_generations keep_indices = [] @@ -326,7 +366,7 @@ def filter(self, request_ids: List[int]) -> Optional["IdeficsCausalLMBatch"]: self.padding_right_offset = new_padding_right_offset self.max_tokens = max_tokens - return self + return self, terminated_generations @classmethod @tracer.start_as_current_span("concatenate") @@ -829,6 +869,7 @@ def generate_token( ), 
generated_text, top_tokens, + new_input_length, ) generations.append(generation) diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 3133a137d9a..64cb739e1dd 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -195,11 +195,49 @@ def from_pb( max_tokens=max_tokens, ) - def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]: - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") + def filter( + self, + model: "Mamba", + kept_requests: List[generate_pb2.KeptRequest], + terminated_request_ids: List[int], + ) -> Tuple[Optional["MambaBatch"], List[generate_pb2.TerminatedGeneration]]: + terminated_generations = [] + for request_id in terminated_request_ids: + idx = self.requests_idx_mapping[request_id] + all_input_ids = self.all_input_ids[idx] + stopping_criteria = self.stopping_criterias[idx] + next_token_chooser = self.next_token_choosers[idx] + + # Decode generated tokens + output_text, _, _ = model.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + terminated_generations.append( + generate_pb2.TerminatedGeneration( + id=request_id, + generated_text=generate_pb2.GeneratedText( + text=output_text, + generated_tokens=stopping_criteria.current_tokens, + finish_reason=generate_pb2.FINISH_REASON_TERMINATED, + seed=seed, + ), + ) + ) + if not kept_requests: + return None, terminated_generations + + request_ids = [r.id for r in kept_requests] if len(request_ids) == len(self): - return self + return self, terminated_generations keep_indices = [] @@ -274,7 +312,7 @@ def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]: :, indices ] self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices] - return self + return self, terminated_generations @classmethod def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch": @@ -775,6 +813,7 @@ def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, in ), generated_text, top_tokens, + new_input_length, ) generations.append(generation) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 3bd095564c3..3cf874fac26 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -166,11 +166,50 @@ def from_pb( ) @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]: - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") + def filter( + self, + model: "Seq2SeqLM", + kept_requests: List[generate_pb2.KeptRequest], + terminated_request_ids: List[int], + ) -> Tuple[Optional["Seq2SeqLMBatch"], List[generate_pb2.TerminatedGeneration]]: + terminated_generations = [] + for request_id in terminated_request_ids: + idx = self.requests_idx_mapping[request_id] + all_decoder_input_ids = self.all_decoder_input_ids[idx] + decoder_input_length = self.decoder_input_lengths[idx] + stopping_criteria = self.stopping_criterias[idx] + next_token_chooser = self.next_token_choosers[idx] + + # Decode generated tokens + output_text, _, _ = model.decode_token( + 
all_decoder_input_ids, + prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1, + read_offset=len(all_decoder_input_ids) - decoder_input_length, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + terminated_generations.append( + generate_pb2.TerminatedGeneration( + id=request_id, + generated_text=generate_pb2.GeneratedText( + text=output_text, + generated_tokens=stopping_criteria.current_tokens, + finish_reason=generate_pb2.FINISH_REASON_TERMINATED, + seed=seed, + ), + ) + ) + if not kept_requests: + return None, terminated_generations + + request_ids = [r.id for r in kept_requests] if len(request_ids) == len(self): - return self + return self, terminated_generations keep_indices = [] @@ -277,7 +316,7 @@ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]: self.padding_right_offset = padding_right_offset self.max_tokens = max_tokens - return self + return self, terminated_generations @classmethod @tracer.start_as_current_span("concatenate") @@ -801,6 +840,7 @@ def generate_token( ), generated_text, top_tokens, + new_decoder_input_length, ) generations.append(generation) diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index 339b733b5f6..0b7868fceb3 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Tuple from transformers import PreTrainedTokenizerBase @@ -28,7 +28,12 @@ def from_pb( raise NotImplementedError @abstractmethod - def filter(self, request_ids: List[int]) -> "Batch": + def filter( + self, + model, + kept_requests: List[generate_pb2.KeptRequest], + terminated_request_ids: List[int], + ) -> Tuple[Optional["Batch"], List[generate_pb2.TerminatedGeneration]]: raise NotImplementedError @classmethod @@ -84,6 +89,7 @@ class Generation: generated_text: Optional[GeneratedText] # Optional for now, since it's not yet supported for every model. 
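The abstract `filter` signature above is the contract every model batch now implements, and the terminated-request handling it requires is nearly identical across the Idefics, Mamba and Seq2Seq implementations earlier in this diff. A hedged sketch of how that repeated construction could be factored into one helper (hypothetical, not part of this change; import paths assumed):

```python
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import Sampling  # assumed re-export location


def build_terminated_generation(
    request_id,
    output_text,
    stopping_criteria,
    next_token_chooser,
) -> generate_pb2.TerminatedGeneration:
    # Report whatever was generated before the request was dropped,
    # using the new FINISH_REASON_TERMINATED finish reason.
    if isinstance(next_token_chooser.choice, Sampling):
        seed = next_token_chooser.choice.seed
    else:
        seed = None
    return generate_pb2.TerminatedGeneration(
        id=request_id,
        generated_text=generate_pb2.GeneratedText(
            text=output_text,
            generated_tokens=stopping_criteria.current_tokens,
            finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
            seed=seed,
        ),
    )
```

Each `filter` implementation would then only be responsible for decoding `output_text` for its own tensor layout before delegating to such a helper.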
top_tokens: Optional[List[Tokens]] + cache_length: int def to_pb(self) -> generate_pb2.Generation: return generate_pb2.Generation( @@ -100,4 +106,5 @@ def to_pb(self) -> generate_pb2.Generation: if self.top_tokens is not None else None ), + cache_length=self.cache_length, ) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 8b5819d17cd..cd5dd3ea14b 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -124,12 +124,20 @@ def concatenate(cls, batches): return batch @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]): - batch = super().filter(request_ids) - batch.pixel_values = None - batch.pixel_attention_mask = None - batch.image_sizes = None - return batch + def filter( + self, + model: "VlmCausalLM", + kept_requests: List[generate_pb2.KeptRequest], + terminated_request_ids: List[int], + ) -> Tuple[Optional["VlmCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]: + batch, terminated_generations = super().filter( + model, kept_requests, terminated_request_ids + ) + if batch is not None: + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + return batch, terminated_generations @classmethod def batch_tokenized_inputs( diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index a0347cd8e73..195d042b747 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -83,10 +83,16 @@ async def FilterBatch(self, request, context): batch = self.cache.pop(request.batch_id) if batch is None: raise ValueError(f"Batch ID {request.batch_id} not found in cache.") - filtered_batch = batch.filter(request.request_ids) - self.cache.set(filtered_batch) + filtered_batch, terminated_generations = batch.filter( + self.model, request.kept_requests, request.terminated_request_ids + ) + if filtered_batch is not None: + self.cache.set(filtered_batch) - return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) + return generate_pb2.FilterBatchResponse( + batch=filtered_batch.to_pb() if filtered_batch is not None else None, + terminated_generations=terminated_generations, + ) async def Warmup(self, request, context): if self.quantize in {"exl2", "gptq"}:
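The updated `FilterBatch` handler above only re-caches a batch when at least one request survives and now always returns the terminated generations, so callers must tolerate an unset `batch` field. A rough sketch of exercising that contract from the generated Python stubs (the socket path, batch ID and request IDs are made up for illustration):

```python
import asyncio

import grpc

from text_generation_server.pb import generate_pb2, generate_pb2_grpc


async def filter_batch_example() -> None:
    # Hypothetical shard socket; the router connects over the uds paths it discovered.
    async with grpc.aio.insecure_channel("unix:///tmp/text-generation-server-0") as channel:
        stub = generate_pb2_grpc.TextGenerationServiceStub(channel)
        response = await stub.FilterBatch(
            generate_pb2.FilterBatchRequest(
                batch_id=0,
                # Keep request 1 together with its paged-attention blocks...
                kept_requests=[
                    generate_pb2.KeptRequest(id=1, blocks=[3, 7], padded_blocks=[3, 7, 0, 0])
                ],
                # ...and ask the server to finalize request 2 (e.g. the client disconnected).
                terminated_request_ids=[2],
            )
        )
        # `batch` may be unset when every request was terminated.
        kept_batch = response.batch if response.HasField("batch") else None
        for terminated in response.terminated_generations:
            print(terminated.id, terminated.generated_text.text)
        _ = kept_batch


asyncio.run(filter_batch_example())
```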