Improve vlm support (add idefics3 support) (#2437)
* feat: expand vlm support and add image token logic and tests

* fix: avoid unused perceiver config

* feat: integrate image tokens into inputs embeds

* feat: add simple idefics3 test

* feat: update docs, image token logic and weight names

* fix: improve image processing

* feat: improve prefix for idefics3

* fix: bump idefics3 tests and snapshots

* fix: improve text model loading

* feat: consolidate changes with existing vlms and add support and test for smolvlm

* fix: create new idefics3 file, simplify logic and adjust llama weight loading

* fix: lint with ruff

* fix: clean up idefics 3 and improve prefix handling

* fix: improve typing

* fix: improve prompt_split_image with ref to original impl

* fix: adjust ruff lints and small refactors

* fix: adjust FlashLlamaModel prefix logic
drbh authored Jan 9, 2025
1 parent a9c7d2e commit da5ab46
Showing 15 changed files with 988 additions and 43 deletions.
1 change: 1 addition & 0 deletions docs/source/supported_models.md
@@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models. The following sectio

- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
- [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal)
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
4 changes: 4 additions & 0 deletions integration-tests/conftest.py
@@ -354,6 +354,7 @@ def local_launcher(
kv_cache_dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_input_tokens: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None,
@@ -402,6 +403,9 @@ def local_launcher(
if max_input_length:
args.append("--max-input-length")
args.append(str(max_input_length))
if max_input_tokens:
args.append("--max-input-tokens")
args.append(str(max_input_tokens))
if max_batch_prefill_tokens:
args.append("--max-batch-prefill-tokens")
args.append(str(max_batch_prefill_tokens))
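A minimal sketch of how the new `max_input_tokens` launcher argument could be exercised from an integration-test fixture; the model id and token limits below are illustrative, not part of this diff:

```python
import pytest


@pytest.fixture(scope="module")
def vlm_capped_handle(launcher):
    # Hypothetical fixture: forwards the new --max-input-tokens flag through
    # the launcher helper extended in this diff. Values are placeholders.
    with launcher(
        "HuggingFaceM4/Idefics3-8B-Llama3",
        max_input_tokens=4096,
        max_batch_prefill_tokens=4096,
        max_total_tokens=8192,
    ) as handle:
        yield handle
```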
@@ -0,0 +1,67 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 9,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 2684,
"logprob": -0.24902344,
"special": false,
"text": " There"
},
{
"id": 374,
"logprob": -0.0703125,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.23535156,
"special": false,
"text": " a"
},
{
"id": 35372,
"logprob": -0.125,
"special": false,
"text": " statue"
},
{
"id": 304,
"logprob": -0.30273438,
"special": false,
"text": " in"
},
{
"id": 279,
"logprob": -0.20507812,
"special": false,
"text": " the"
},
{
"id": 2217,
"logprob": -0.076171875,
"special": false,
"text": " image"
},
{
"id": 13,
"logprob": -0.053710938,
"special": false,
"text": "."
},
{
"id": 128258,
"logprob": -0.011352539,
"special": true,
"text": "<end_of_utterance>"
}
],
"top_tokens": null
},
"generated_text": " There is a statue in the image."
}
@@ -0,0 +1,61 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 8,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 330,
"logprob": -0.118652344,
"special": false,
"text": " A"
},
{
"id": 11426,
"logprob": -0.28320312,
"special": false,
"text": " bee"
},
{
"id": 335,
"logprob": -0.95703125,
"special": false,
"text": " on"
},
{
"id": 253,
"logprob": -0.06982422,
"special": false,
"text": " a"
},
{
"id": 11986,
"logprob": -0.49414062,
"special": false,
"text": " pink"
},
{
"id": 8525,
"logprob": -0.07763672,
"special": false,
"text": " flower"
},
{
"id": 30,
"logprob": -1.0703125,
"special": false,
"text": "."
},
{
"id": 49154,
"logprob": -0.092285156,
"special": true,
"text": "<end_of_utterance>"
}
],
"top_tokens": null
},
"generated_text": " A bee on a pink flower."
}
31 changes: 31 additions & 0 deletions integration-tests/models/test_idefics3.py
@@ -0,0 +1,31 @@
import pytest


@pytest.fixture(scope="module")
def flash_idefics3_next_handle(launcher):
with launcher("HuggingFaceM4/Idefics3-8B-Llama3") as handle:
yield handle


@pytest.fixture(scope="module")
async def flash_idefics3_next(flash_idefics3_next_handle):
await flash_idefics3_next_handle.health(300)
return flash_idefics3_next_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics3_next_simple_url(flash_idefics3_next, response_snapshot):
ny_skyline = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
query = "What is in this image?"
response = await flash_idefics3_next.generate(
f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}<end_of_utterance>\nAssistant:",
max_new_tokens=10,
seed=1337,
)
print(response)
assert (
response.generated_text == " There is a statue in the image."
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 9
assert response == response_snapshot
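Outside the test harness, the same Idefics3 prompt layout (image passed inline as a markdown-style `![](url)` reference) can be sent to a running TGI instance over its `/generate` route. A minimal sketch, assuming a server already listening on localhost:8080:

```python
import requests

# Sketch only: the endpoint URL and port are assumptions, the prompt format
# mirrors the integration test above.
url = "http://localhost:8080/generate"
image = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
prompt = (
    "<|begin_of_text|><|begin_of_text|>User:"
    f"![]({image})What is in this image?<end_of_utterance>\nAssistant:"
)
resp = requests.post(
    url,
    json={"inputs": prompt, "parameters": {"max_new_tokens": 10}},
)
print(resp.json()["generated_text"])
```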
31 changes: 31 additions & 0 deletions integration-tests/models/test_smolvlm.py
@@ -0,0 +1,31 @@
import pytest


@pytest.fixture(scope="module")
def flash_smolvlm_next_handle(launcher):
with launcher("HuggingFaceTB/SmolVLM-Instruct") as handle:
yield handle


@pytest.fixture(scope="module")
async def flash_smolvlm_next(flash_smolvlm_next_handle):
await flash_smolvlm_next_handle.health(300)
return flash_smolvlm_next_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_smolvlm_next_simple_url(flash_smolvlm_next, response_snapshot):
ny_skyline = "https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg"
query = "What is in this image?"
response = await flash_smolvlm_next.generate(
f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}<end_of_utterance>\nAssistant:",
max_new_tokens=10,
seed=1337,
)
print(response)
assert (
response.generated_text == " A bee on a pink flower."
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 8
assert response == response_snapshot
19 changes: 19 additions & 0 deletions router/src/config.rs
@@ -110,6 +110,24 @@ pub struct ClipVisionModel {
patch_size: usize,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Idefics3 {}

impl Idefics3 {
pub fn get_max_longest_edge(&self) -> usize {
364
}

pub fn get_number_of_features(&self) -> usize {
169
}

pub fn get_max_longest_edge_for_image_resize(&self) -> usize {
1456
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Idefics2 {}
@@ -178,6 +196,7 @@ pub enum Config {
Idefics,
Mllama,
Idefics2(Idefics2),
Idefics3(Idefics3),
Ssm,
GptBigcode,
Granite,
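The hard-coded constants in `Idefics3` hang together arithmetically: assuming the vision encoder uses a 14-pixel patch size and a pixel-shuffle (scale) factor of 2 — neither of which appears in this diff — a 364-pixel tile yields exactly the 169 features returned by `get_number_of_features`. A small sketch under those assumptions:

```python
# Only the 364 and 169 constants come from the diff above; patch_size and
# scale_factor are assumptions about the Idefics3 vision encoder.
longest_edge = 364
patch_size = 14      # assumed patch size
scale_factor = 2     # assumed pixel-shuffle factor

patches_per_side = longest_edge // patch_size                # 26
features = (patches_per_side ** 2) // (scale_factor ** 2)    # 676 // 4 = 169
assert features == 169
```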
1 change: 1 addition & 0 deletions router/src/lib.rs
@@ -170,6 +170,7 @@ impl TokenizerConfigToken {
#[serde(tag = "processor_class")]
pub enum HubPreprocessorConfig {
Idefics2Processor(Idefics2Preprocessor),
Idefics3Processor(Idefics2Preprocessor),
}

impl HubPreprocessorConfig {
70 changes: 69 additions & 1 deletion router/src/validation.rs
@@ -614,6 +614,73 @@ fn image_tokens(

image_string
}
Idefics3(config) => {
const FAKE: &str = "<fake_token_around_image>";
const IMAGE: &str = "<image>";
const GLOBAL_IMG: &str = "<global-img>";

let max_longest_edge_for_image_resize = config.get_max_longest_edge_for_image_resize();

// resize image if it is larger than max_longest_edge_for_image_resize keeping aspect ratio
let (height, width) = if height > max_longest_edge_for_image_resize
|| width > max_longest_edge_for_image_resize
{
let aspect_ratio = height as f32 / width as f32;
if height > width {
(
max_longest_edge_for_image_resize,
(max_longest_edge_for_image_resize as f32 / aspect_ratio) as usize,
)
} else {
(
(max_longest_edge_for_image_resize as f32 * aspect_ratio) as usize,
max_longest_edge_for_image_resize,
)
}
} else {
(height, width)
};

let image_seq_len = config.get_number_of_features();
let max_edge = config.get_max_longest_edge();

let (image_rows, image_cols) = if height > max_edge || width > max_edge {
(
(height as f32 / max_edge as f32).ceil() as usize,
(width as f32 / max_edge as f32).ceil() as usize,
)
} else {
(0, 0)
};

let mut image_string = String::new();

if image_rows == 0 && image_cols == 0 {
// Single image case
image_string.push_str(FAKE);
image_string.push_str(GLOBAL_IMG);
image_string.push_str(&IMAGE.repeat(image_seq_len));
image_string.push_str(FAKE);
} else {
// Split image case
for n_h in 0..image_rows {
for n_w in 0..image_cols {
image_string.push_str(FAKE);
image_string.push_str(&format!("<row_{}_col_{}>", n_h + 1, n_w + 1));
image_string.push_str(&IMAGE.repeat(image_seq_len));
}
image_string.push('\n');
}

image_string.push('\n');
image_string.push_str(FAKE);
image_string.push_str(GLOBAL_IMG);
image_string.push_str(&IMAGE.repeat(image_seq_len));
image_string.push_str(FAKE);
}

image_string
}
Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
Qwen2Vl(config) => format!(
@@ -647,7 +714,8 @@ fn prepare_input<T: TokenizerTrait>(
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
Some(
config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)),
config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_)
| Qwen2Vl(_)),
) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
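A short Python sketch of the same resize-then-tile rules makes the resulting prompt expansion concrete; the 1000×2000 example image below is arbitrary, and the constants mirror the Rust code above:

```python
import math

FAKE, IMAGE, GLOBAL_IMG = "<fake_token_around_image>", "<image>", "<global-img>"
MAX_RESIZE, MAX_EDGE, SEQ_LEN = 1456, 364, 169


def idefics3_image_tokens(height: int, width: int) -> str:
    # Illustrative mirror of the logic added to validation.rs.
    if height > MAX_RESIZE or width > MAX_RESIZE:
        aspect_ratio = height / width
        if height > width:
            height, width = MAX_RESIZE, int(MAX_RESIZE / aspect_ratio)
        else:
            height, width = int(MAX_RESIZE * aspect_ratio), MAX_RESIZE
    if height > MAX_EDGE or width > MAX_EDGE:
        rows, cols = math.ceil(height / MAX_EDGE), math.ceil(width / MAX_EDGE)
    else:
        rows, cols = 0, 0

    if rows == 0 and cols == 0:
        # Single-image case: only the global image.
        return FAKE + GLOBAL_IMG + IMAGE * SEQ_LEN + FAKE
    # Split-image case: one tile per (row, col), then the global image.
    parts = []
    for r in range(rows):
        for c in range(cols):
            parts.append(FAKE + f"<row_{r + 1}_col_{c + 1}>" + IMAGE * SEQ_LEN)
        parts.append("\n")
    parts.append("\n" + FAKE + GLOBAL_IMG + IMAGE * SEQ_LEN + FAKE)
    return "".join(parts)


# A 1000x2000 image is first resized to 728x1456, then split into 2x4 tiles
# plus the global image, i.e. 9 * 169 = 1521 <image> tokens.
tokens = idefics3_image_tokens(1000, 2000)
assert tokens.count(IMAGE) == 9 * SEQ_LEN
```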
27 changes: 27 additions & 0 deletions server/text_generation_server/models/__init__.py
@@ -152,6 +152,9 @@
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.idefics3 import (
Idefics3ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.qwen2_vl import (
Qwen2VLForConditionalGeneration,
)
@@ -188,6 +191,12 @@ class ModelType(enum.Enum):
"url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
"multimodal": True,
}
IDEFICS3 = {
"type": "idefics3",
"name": "Idefics 3",
"url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
"multimodal": True,
}
LLAVA_NEXT = {
"type": "llava_next",
"name": "Llava Next (1.6)",
@@ -1253,6 +1262,24 @@ def get_model(
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == IDEFICS3:
if FLASH_ATTENTION:
return VlmCausalLM(
model_id=model_id,
model_class=Idefics3ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 1456}},
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == PALIGEMMA:
if FLASH_ATTENTION:
return VlmCausalLM(
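The `{"size": {"longest_edge": 1456}}` cap bounds how many image tokens a single image can contribute to the prefill. A back-of-the-envelope sketch, assuming the 364-pixel tiling and 169 features per tile from the router changes above:

```python
import math

# Worst case for an image at the 1456-pixel cap: it splits into a 4x4 grid
# of 364-pixel tiles plus one global image.
tiles = math.ceil(1456 / 364) ** 2 + 1    # 16 + 1 = 17
image_tokens = tiles * 169                # 2873 image tokens per image
print(image_tokens)
```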