Skip to content

Commit

Permalink
router: send the input as chunks to the backend
Browse files Browse the repository at this point in the history
Before this change, the generation input was sent to the backend as a
single string, encoding images as Base64 and packing them in
Markdown-style links.

This change adds a new chunked input representation that separates text
chunks from images chunks. Image chunks contain binary data (for smaller
message sizes) and the image's MIME type.

The stringly-typed inputs are still sent to support backends that do not
support chunked inputs yet.
  • Loading branch information
danieldk authored and Daniël de Kok committed Jun 3, 2024
1 parent d1d724b commit df71aaf
Show file tree
Hide file tree
Showing 12 changed files with 222 additions and 69 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"

[workspace.dependencies]
base64 = "0.22.0"
tokenizers = { version = "0.19.1", features = ["http"] }
hf-hub = { version = "0.3.1", features = ["tokio"] }

Expand Down
7 changes: 5 additions & 2 deletions benchmark/src/generation.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::time::{Duration, Instant};
use text_generation_client::{
Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
StoppingCriteriaParameters,
Batch, CachedBatch, Chunk, ClientError, Input, NextTokenChooserParameters, Request,
ShardedClient, StoppingCriteriaParameters,
};
use tokenizers::{Tokenizer, TruncationDirection};
use tokio::sync::{broadcast, mpsc};
Expand Down Expand Up @@ -142,6 +142,9 @@ async fn prefill(
.map(|id| Request {
id: id.into(),
prefill_logprobs: false,
input_chunks: Some(Input {
chunks: vec![Chunk::Text(sequence.clone()).into()],
}),
inputs: sequence.clone(),
truncate: sequence_length,
parameters: Some(parameters.clone()),
Expand Down
25 changes: 24 additions & 1 deletion proto/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,27 @@ message ClearCacheRequest {
/// Empty response
message ClearCacheResponse {}

message Image {
/// Binary image data.
bytes data = 1;

/// Image MIME type.
string mimetype = 2;
}

message InputChunk {
oneof chunk {
/// Plain text data
string text = 1;
/// Image data
Image image = 2;
}
}

message Input {
repeated InputChunk chunks = 1;
}

enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
Expand Down Expand Up @@ -95,7 +116,9 @@ message StoppingCriteriaParameters {
message Request {
/// Request ID
uint64 id = 1;
/// The generation context
/// The generation context as chunks
Input input_chunks = 8;
/// The generation context, stringified input_chunks
string inputs = 2;
/// Context truncation
uint32 truncate = 3;
Expand Down
2 changes: 1 addition & 1 deletion router/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ futures-util = "0.3.30"
regex = "1.10.3"
once_cell = "1.19.0"
image = "0.25.1"
base64 = "0.22.0"
base64 = { workspace = true }

[build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
Expand Down
1 change: 1 addition & 0 deletions router/client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ authors.workspace = true
homepage.workspace = true

[dependencies]
base64 = { workspace = true }
futures = "^0.3"
grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.12"
Expand Down
31 changes: 28 additions & 3 deletions router/client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
/// Single shard Client
use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v2::*;
use crate::Result;
use crate::{Chunk, Result};
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;

static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";

/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
Expand Down Expand Up @@ -113,18 +117,39 @@ impl Client {
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);

let mut input_chunks = Vec::new();
input_chunks
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
if n_tokens == 0 {
input_chunks.push(
Chunk::Image(Image {
// Safe unwrap, because we control the data.
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
mimetype: "image/jpeg;base64".to_string(),
})
.into(),
);
}

// Send stringly-typed inputs for compatibility for backends that haven't
// been updated to support chunks.
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
inputs.push_str(&format!(
"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
));
}

requests.push(Request {
id: 0,
// We truncate the input on the server side to be sure that it has the correct size
input_chunks: Some(Input {
chunks: input_chunks,
}),
inputs,
// We truncate the input on the server side to be sure that it has the correct size
truncate,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
Expand Down
35 changes: 34 additions & 1 deletion router/client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ mod client;
mod pb;
mod sharded_client;

use base64::{engine::general_purpose::STANDARD, Engine};
pub use client::Client;
pub use pb::generate::v2::input_chunk::Chunk;
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::Image;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, Input, InputChunk,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
Expand Down Expand Up @@ -44,3 +47,33 @@ impl From<transport::Error> for ClientError {
}

pub type Result<T> = std::result::Result<T, ClientError>;

// Small convenience re-wrapping of `Chunk`.
impl From<Chunk> for InputChunk {
fn from(chunk: Chunk) -> Self {
InputChunk { chunk: Some(chunk) }
}
}

/// Convert input chunks to a stringly-typed input for backwards
/// compat for backends that haven't implemented chunked inputs.
pub trait ChunksToString {
/// Convert chunks to string.
fn chunks_to_string(&self) -> String;
}

impl ChunksToString for Vec<InputChunk> {
fn chunks_to_string(&self) -> String {
let mut output = String::new();
self.iter().for_each(|c| match &c.chunk {
Some(Chunk::Text(text)) => output.push_str(text),
Some(Chunk::Image(Image { data, mimetype })) => {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
// We don't create empty chunks, so this should be unreachable.
None => unreachable!("Chunks should never be empty"),
});
output
}
}
14 changes: 7 additions & 7 deletions router/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize};
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct LlavaNext {
text_config: TextConfig,
vision_config: VisionConfig,
image_grid_pinpoints: Vec<(usize, usize)>,
pub(crate) text_config: TextConfig,
pub(crate) vision_config: VisionConfig,
pub(crate) image_grid_pinpoints: Vec<(usize, usize)>,
}

fn get_anyres_image_grid_shape(
Expand Down Expand Up @@ -119,13 +119,13 @@ impl Idefics2 {
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct PaliTextConfig {
num_image_tokens: usize,
pub(crate) num_image_tokens: usize,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Paligemma {
text_config: PaliTextConfig,
pub(crate) text_config: PaliTextConfig,
}

impl Paligemma {
Expand Down Expand Up @@ -175,8 +175,8 @@ pub struct TextConfig {}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct VisionConfig {
image_size: usize,
patch_size: usize,
pub(crate) image_size: usize,
pub(crate) patch_size: usize,
}

#[cfg(test)]
Expand Down
7 changes: 5 additions & 2 deletions router/src/health.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::GrammarType as ProtoGrammarType;
use text_generation_client::{
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
Batch, Input, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
};
use text_generation_client::{Chunk, GrammarType as ProtoGrammarType};

// Note: Request ids and batch ids cannot collide.
const LIVENESS_ID: u64 = u64::MAX;
Expand Down Expand Up @@ -33,6 +33,9 @@ impl Health {
// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: LIVENESS_ID,
input_chunks: Some(Input {
chunks: vec![Chunk::Text("liveness".into()).into()],
}),
inputs: "liveness".to_string(),
truncate: 10,
prefill_logprobs: false,
Expand Down
9 changes: 7 additions & 2 deletions router/src/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use text_generation_client::ChunksToString;
use text_generation_client::Input;
use text_generation_client::{Batch, Request};
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
Expand Down Expand Up @@ -278,7 +280,10 @@ impl State {
batch_requests.push(Request {
id,
prefill_logprobs: entry.request.decoder_input_details,
inputs: entry.request.inputs.clone(),
input_chunks: Some(Input {
chunks: entry.request.inputs.clone(),
}),
inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
Expand Down Expand Up @@ -366,7 +371,7 @@ mod tests {

let entry = Entry {
request: ValidGenerateRequest {
inputs: String::new(),
inputs: vec![],
input_length: 0,
truncate: 0,
decoder_input_details: false,
Expand Down
Loading

0 comments on commit df71aaf

Please sign in to comment.