Skip to content

Commit f171bdc

Browse files
authored
Inline images for multimodal models. (#1666)
1 parent 66914f7 commit f171bdc

File tree

4 files changed: 24 additions and 6 deletions.

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

integration-tests/models/test_idefics.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@ async def idefics(idefics_handle):
2020
def get_chicken():
2121
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
2222
encoded_string = base64.b64encode(image_file.read())
23-
return f"data:image/png;base64,{encoded_string}"
23+
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
2424

2525

2626
@pytest.mark.asyncio
2727
async def test_idefics(idefics, response_snapshot):
28+
chicken = get_chicken()
2829
response = await idefics.generate(
29-
"User:![](https://huggingface.co/spaces/HuggingFaceM4/idefics_playground/resolve/main/example_images/chicken_on_money.png?download=true)Can you tell me a very short story based on the image?",
30+
f"User:![]({chicken})Can you tell me a very short story based on the image?",
3031
max_new_tokens=10,
3132
decoder_input_details=True,
3233
)
@@ -37,9 +38,10 @@ async def test_idefics(idefics, response_snapshot):
3738

3839
@pytest.mark.asyncio
3940
async def test_idefics_load(idefics, generate_load, response_snapshot):
41+
chicken = get_chicken()
4042
responses = await generate_load(
4143
idefics,
42-
"User:![](https://huggingface.co/spaces/HuggingFaceM4/idefics_playground/resolve/main/example_images/chicken_on_money.png?download=true)Can you tell me a very short story based on the image?",
44+
f"User:![]({chicken})Can you tell me a very short story based on the image?",
4345
max_new_tokens=10,
4446
n=4,
4547
)

router/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
4646
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
4747
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", branch = "main", commit = "5cd4efb" }
4848
futures-util = "0.3.30"
49+
regex = "1.10.3"
50+
once_cell = "1.19.0"
4951

5052
[build-dependencies]
5153
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }

router/src/validation.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use tokenizers::TruncationDirection;
1313
use tokio::sync::mpsc;
1414
use tokio::sync::oneshot;
1515
use tracing::{instrument, Span};
16+
use {once_cell::sync::Lazy, regex::Regex};
1617

1718
/// Validation
1819
#[derive(Debug, Clone)]
@@ -409,10 +410,14 @@ async fn round_robin_task(
409410
/// Start tokenization workers
410411
fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
411412
// Loop over requests
413+
let is_multimodal = {
414+
let vocab = tokenizer.get_vocab(true);
415+
vocab.contains_key("<image>")
416+
};
412417
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
413418
parent_span.in_scope(|| {
414419
response_tx
415-
.send(prepare_input(inputs, truncate, &tokenizer))
420+
.send(prepare_input(inputs, truncate, &tokenizer, is_multimodal))
416421
.unwrap_or(())
417422
})
418423
}
@@ -423,15 +428,22 @@ fn prepare_input(
423428
mut inputs: String,
424429
truncate: Option<usize>,
425430
tokenizer: &Tokenizer,
431+
is_multimodal: bool,
426432
) -> Result<(tokenizers::Encoding, String), ValidationError> {
433+
let simplified_query = if is_multimodal {
434+
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
435+
RE.replace_all(&inputs, "<image>").into()
436+
} else {
437+
inputs.clone()
438+
};
427439
// Get the number of tokens in the input
428440
let mut encoding = tokenizer
429-
.encode(inputs.clone(), true)
441+
.encode(simplified_query, true)
430442
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
431443

432444
// Optionally truncate
433445
if let Some(truncate) = truncate {
434-
if truncate < encoding.len() {
446+
if truncate < encoding.len() && !is_multimodal {
435447
encoding.truncate(truncate, 0, TruncationDirection::Left);
436448
inputs = tokenizer
437449
.decode(encoding.get_ids(), false)

0 commit comments

Comments
 (0)