From d96eef2a026f90da9c51dd855fd5c7053c58ab45 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Thu, 24 Oct 2024 15:36:53 +0000 Subject: [PATCH 01/18] feat: add support for qwen2 vl model --- .../text_generation_server/layers/rotary.py | 28 + .../layers/tensor_parallel.py | 2 +- .../text_generation_server/models/__init__.py | 20 + .../custom_modeling/flash_qwen2_modeling.py | 43 +- .../models/custom_modeling/qwen2_vl.py | 544 ++++++++++++++++++ .../models/vlm_causal_lm.py | 22 +- 6 files changed, 653 insertions(+), 6 deletions(-) create mode 100644 server/text_generation_server/models/custom_modeling/qwen2_vl.py diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index a2076bb2078..6e2ba228332 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -89,6 +89,8 @@ def static(cls, config, dim, base, device): if rope_type == "linear": pass + elif rope_type == "default": + pass elif rope_type == "dynamic": scaling_factor = rope_scaling["factor"] return DynamicPositionRotaryEmbedding( @@ -275,6 +277,32 @@ def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow. return cos.unsqueeze(1), sin.unsqueeze(1) + def get_cos_sin_hack( + self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype + ): + # TODO: avoid always computing, use the cache and update it if necessary + inv_freq_expanded = ( + self.inv_freq[None, None, :, None] + .float() + .expand(3, position_ids.shape[1], -1, 1) + ) + + position_ids_expanded = position_ids[ + :, :, None, : + ].float() # shape (3, bs, 1, positions) + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose( + 2, 3 + ) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos().to(dtype) + sin = emb.sin().to(dtype) + + # Update cached values + self._cos_cached = cos + self._sin_cached = sin + + return cos, sin + class SuRotaryEmbedding(PositionRotaryEmbedding): def __init__( diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 13f12ef1ec1..febf28da71a 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -144,7 +144,7 @@ def load_qkv( num_key_value_heads=num_key_value_heads, ) if bias: - raise NotImplementedError("packed_qkv only implemented for baichuan") + bias = weights.get_tensor(f"{prefix}.bias") else: bias = None linear = get_linear(weight, bias) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 99e3d3430a0..6c633521090 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -146,6 +146,9 @@ from text_generation_server.models.custom_modeling.idefics2 import ( Idefics2ForConditionalGeneration, ) + from text_generation_server.models.custom_modeling.qwen2_vl import ( + Qwen2VLForConditionalGeneration, + ) from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}") @@ -275,6 +278,11 @@ class ModelType(enum.Enum): "name": "Qwen 2", "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f", } + QWEN2_VL = { + "type": "qwen2_vl", + "name": "Qwen 
2 VL", + "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d", + } OPT = { "type": "opt", "name": "Opt", @@ -1193,6 +1201,18 @@ def get_model( ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) + if model_type == QWEN2_VL: + return VlmCausalLM( + model_id=model_id, + model_class=Qwen2VLForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) if model_type == MLLAMA: if FLASH_ATTENTION: return MllamaCausalLM( diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index ab2a177db6a..e9be22b1c8b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -49,6 +49,13 @@ def _load_gqa(config, prefix: str, weights): ) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + class Qwen2Attention(torch.nn.Module): def __init__( self, @@ -61,6 +68,7 @@ def __init__( config.sliding_window if config.sliding_window is not None else -1 ) self.num_heads = config.num_attention_heads + self.mrope_section = config.rope_scaling.get("mrope_section", None) self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads @@ -122,7 +130,28 @@ def forward( query = query.view(-1, self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + # TODO: correctly handle the multimodal case + if False: + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + else: + # multimodal rotary + unsqueeze_dim = 1 + mrope_section = self.mrope_section * 2 + cos = torch.cat( + [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], + dim=-1, + ).unsqueeze(unsqueeze_dim) + sin = torch.cat( + [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], + dim=-1, + ).unsqueeze(unsqueeze_dim) + + _query = query.transpose(0, 1).unsqueeze(0) + _key = torch.select(kv, dim=1, index=0).transpose(0, 1).unsqueeze(0) + q_embed = (_query * cos) + (rotate_half(_query) * sin) + k_embed = (_key * cos) + (rotate_half(_key) * sin) + query = q_embed.squeeze(0).transpose(0, 1) + kv[:, 0] = k_embed.squeeze(0).transpose(0, 1) if prefill_cache_indices is not None: kv_to_cache = kv[prefill_cache_indices] @@ -306,12 +335,20 @@ def forward( max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + inputs_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + + # if inputs_embeds are supplied from an external model (vision model) then avoid embedding input_ids + if inputs_embeds is not None: + hidden_states = inputs_embeds.squeeze(0) + else: + hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + # TODO: fix how N-D position_ids are handled + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin_hack( position_ids, true_max_s, hidden_states.dtype ) diff --git 
a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py new file mode 100644 index 00000000000..28217d7dcfd --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -0,0 +1,544 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2 VL model.""" + +from typing import Optional, Tuple, List + +import torch +import torch.utils.checkpoint +from torch import nn +from text_generation_server.utils.import_utils import SYSTEM + +if SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex +else: + import flash_attn_2_cuda + +from transformers.activations import ACT2FN +import torch.nn.functional as F + +from text_generation_server.layers.layernorm import ( + FastLayerNorm, +) +from text_generation_server.layers import ( + TensorParallelColumnLinear, + FastLinear, +) +from text_generation_server.layers.attention import ( + Seqlen, +) +from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2Model, +) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + tensor: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2VLSdpaAttention(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.embed_dim = config.embed_dim + self.head_dim = config.hidden_size // config.num_heads + self.num_heads = config.num_heads // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.qkv", + weights=weights, + bias=True, + num_heads=self.num_heads, + num_key_value_heads=self.num_heads, + ) + + self.proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.proj", + weights=weights, + bias=True, + ) + + def forward( + self, + hidden_state: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # apply the qkv linear layer to the hidden state + qkv = self.qkv(hidden_state) + query, key, value = qkv.split( + [self.embed_dim, self.embed_dim, self.embed_dim], dim=1 + ) + + # reshape the query, key, and value tensors + _shape = ( + hidden_state.shape[0], + self.num_heads, + self.embed_dim // self.num_heads, + ) + query = query.view(*_shape) + key = key.view(*_shape) + value = value.view(*_shape) + + # apply rotary positional embeddings + query 
= apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( + 0 + ) + key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + # TODO: make use of existing RotatoryPositionEmbedding class + + # create the attention mask + attention_mask = torch.zeros( + [1, hidden_state.shape[0], hidden_state.shape[0]], + device=hidden_state.device, + dtype=torch.bool, + ) + # TODO: avoid creating the mask in the forward pass, instead define the largest possible mask and slice it + + # apply the cu_seqlens to the attention mask + for i in range(1, len(cu_seqlens)): + attention_mask[ + ..., + cu_seqlens[i - 1] : cu_seqlens[i], + cu_seqlens[i - 1] : cu_seqlens[i], + ] = True + + # transpose for the attention mechanism (batch, seqlen, hidden_dim) -> (seqlen, batch, hidden_dim) + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + # apply attention + attn_output = F.scaled_dot_product_attention( + query, key, value, attention_mask, dropout_p=0.0 + ) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(hidden_state.shape[0], -1) + # TODO: prefer flash attention + + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2VLVisionMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2VLVisionBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.attn = Qwen2VLSdpaAttention( + prefix=f"{prefix}.attn", + config=config, + weights=weights, + ) + self.norm1 = FastLayerNorm.load( + prefix=f"{prefix}.norm1", + weights=weights, + eps=1e-6, + ) + self.norm2 = FastLayerNorm.load( + prefix=f"{prefix}.norm2", + weights=weights, + eps=1e-6, + ) + self.mlp = Qwen2VLVisionMLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + ) + + def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + hidden_states_post_norm1, res = self.norm1(hidden_states) + hidden_states = hidden_states + self.attn( + hidden_states_post_norm1, cu_seqlens, rotary_pos_emb + ) + hidden_states_post_norm2, res = self.norm2(hidden_states) + hidden_states = hidden_states + self.mlp(hidden_states_post_norm2) + return hidden_states + + +class Qwen2VLPatchMerger(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + context_dim = 2560 + spatial_merge_size: int = 2 + self.hidden_size = 5120 # context_dim * (spatial_merge_size**2) + self.patch_merger_ln_q = FastLayerNorm.load( + prefix=f"{prefix}.ln_q", + weights=weights, + eps=1e-6, + ) + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states, grid_thw) -> torch.Tensor: + hidden_states, _ = self.patch_merger_ln_q(hidden_states) + hidden_states = hidden_states.view(-1, self.hidden_size) + hidden_states = self.fc1(hidden_states) + 
hidden_states = F.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2VisionModel(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.spatial_merge_size = config.spatial_merge_size + kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size] + self.patch_embedding = nn.Conv3d( + in_channels=config.in_chans, + out_channels=config.embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False + ) + head_dim = config.embed_dim // config.num_heads + # TODO: replace with static positional embeddings once implemented + theta = 10000.0 + dim = head_dim // 2 + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + self.blocks = nn.ModuleList( + [ + Qwen2VLVisionBlock( + prefix=f"{prefix}.blocks.{i}", + config=config, + weights=weights, + ) + for i in range(config.depth) + ] + ) + self.merger = Qwen2VLPatchMerger( + prefix=f"{prefix}.merger", + config=config, + weights=weights, + ) + + self.temporal_patch_size = config.temporal_patch_size + self.spatial_patch_size = config.spatial_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.embed_dim + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + grid_thw: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + # reshape the input tensor for processing + shape = ( + -1, + self.in_channels, + self.temporal_patch_size, + self.spatial_patch_size, + self.spatial_patch_size, + ) + pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype) + hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim) + # TODO: revisit to see if we can avoid some of these reshapes + + # find the position ids for the input tensor based on the grid_thw + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + + # apply the positional embeddings to the position ids + seq = torch.arange( + max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + rotary_pos_emb_full = torch.outer(seq, self.inv_freq) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype) + + # create a cu_seqlens tensor to be used in 
the attention mask + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # iterately apply the blocks to the hidden states + for block in self.blocks: + hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb) + + # apply the final patch merger to the hidden states + hidden_states = self.merger(hidden_states, grid_thw) + return hidden_states + + +class Qwen2VLForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + self.hidden_size = config.hidden_size + self.vision_start_token_id = config.vision_start_token_id + self.image_token_id = config.image_token_id + self.video_token_id = config.video_token_id + self.spatial_merge_size = config.vision_config.spatial_merge_size + + self.visual = Qwen2VisionModel( + prefix="visual", config=config.vision_config, weights=weights + ) + self.text_model = Qwen2Model(prefix=None, config=config, weights=weights) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor], + pixel_values: torch.FloatTensor = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, + adapter_data: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + + # make an attention_mask that is the same size as the input_ids + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + + inputs_embeds = self.text_model.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if pixel_values is not None and len(pixel_values) > 0: + if pixel_values is not None: + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + image_mask = ( + (input_ids == self.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to( + inputs_embeds.device, inputs_embeds.dtype + ) + # input embeddings are masked with image embeddings + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + # handle the position_ids taking the multimodal inputs into account + mrope_position_deltas = [] + if image_grid_thw is not None or video_grid_thw is not None: + total_input_ids = input_ids + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + + for i, input_ids in enumerate(total_input_ids): + if attention_mask is not None: + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere( + input_ids == self.vision_start_token_id + ).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + # determine the number of images and videos in the input + image_nums = (vision_tokens == self.image_token_id).sum() + video_nums = (vision_tokens == self.video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + 
remain_images, remain_videos = image_nums, video_nums + # process each input based on it's token type and grid size + for _ in range(image_nums + video_nums): + if self.image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(self.image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if self.video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(self.video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // self.spatial_merge_size, + w.item() // self.spatial_merge_size, + ) + text_len = ed - st + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + position_ids.device + ) + mrope_position_deltas.append( + llm_positions.max() + 1 - len(total_input_ids[i]) + ) + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + + # TODO: adjust model to accept 2D position_ids + outputs = self.text_model( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=prefill_cache_indices, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + ) + + return outputs, None diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 4bbddcfb4cd..1b8e7f88c6b 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -67,6 +67,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str elif config.model_type == "paligemma": return "" * config.text_config.num_image_tokens + elif config.model_type == "qwen2_vl": + return "" else: raise RuntimeError(f"Unknown config {config.model_type} for multimodal") @@ -137,6 +139,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch): pixel_values: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] + image_grid_thw: 
Optional[torch.Tensor] @classmethod @tracer.start_as_current_span("concatenate") @@ -145,6 +148,7 @@ def concatenate(cls, batches): batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None + batch.image_grid_thw = None return batch @tracer.start_as_current_span("filter") @@ -153,6 +157,7 @@ def filter(self, request_ids: List[int]): batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None + batch.image_grid_thw = None return batch @classmethod @@ -178,6 +183,10 @@ def batch_tokenized_inputs( raise RuntimeError(f"Invalid chunk type {chunk_type}") if images: + # TODO: REMOVE (this is for debugging purposes) + images = images[0][0].resize( + (images[0][0].width * 2, images[0][0].height * 2) + ) image_inputs = processor.image_processor(images, return_tensors="pt") else: image_inputs = None @@ -237,10 +246,15 @@ def from_pb_processor( batch.image_sizes = image_inputs["image_sizes"].to(device=device) else: batch.image_sizes = None + if "image_grid_thw" in image_inputs: + batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device) + else: + batch.image_grid_thw = None else: batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None + batch.image_grid_thw = None return batch @@ -381,8 +395,9 @@ def forward( max_k=batch.max_current_length, ) logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, + # TODO: remove the unsqueeze(0) + input_ids=input_ids.unsqueeze(0), + position_ids=position_ids.unsqueeze(0), cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, block_tables=block_tables, @@ -394,6 +409,7 @@ def forward( pixel_values=batch.pixel_values, pixel_attention_mask=batch.pixel_attention_mask, image_sizes=batch.image_sizes, + image_grid_thw=batch.image_grid_thw, ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None @@ -403,6 +419,8 @@ def forward( batch.pixel_attention_mask = None if batch.image_sizes is not None: batch.image_sizes = None + if batch.image_grid_thw is not None: + batch.image_grid_thw = None return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph From 09ac4fb6eba6f29bf1a82962958221ff34531a8d Mon Sep 17 00:00:00 2001 From: David Holtz Date: Thu, 24 Oct 2024 19:57:47 +0000 Subject: [PATCH 02/18] feat: fix token padding, enable warmup and process basic request --- router/src/config.rs | 29 +++++++++++++++++++ router/src/validation.rs | 8 ++++- .../custom_modeling/flash_qwen2_modeling.py | 4 +++ .../models/custom_modeling/qwen2_vl.py | 18 ++++++++---- .../models/vlm_causal_lm.py | 13 +++++---- 5 files changed, 60 insertions(+), 12 deletions(-) diff --git a/router/src/config.rs b/router/src/config.rs index ce066ad00ca..7fc27f960f6 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -138,10 +138,39 @@ impl Paligemma { } } +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct Qwen2VlVisionConfig { + pub(crate) depth: usize, + pub(crate) embed_dim: usize, + pub(crate) mlp_ratio: usize, + pub(crate) num_heads: usize, + pub(crate) in_chans: usize, + pub(crate) hidden_size: usize, + pub(crate) patch_size: usize, + pub(crate) spatial_merge_size: usize, + pub(crate) spatial_patch_size: usize, + pub(crate) temporal_patch_size: usize, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct Qwen2Vl { + pub(crate) vision_config: Qwen2VlVisionConfig, +} + +impl Qwen2Vl { + pub fn 
get_number_of_features(&self, _height: usize, _width: usize) -> usize { + // TODO: calculate number of features + 6000 / 4 + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(tag = "model_type")] #[serde(rename_all = "snake_case")] pub enum Config { + Qwen2Vl(Qwen2Vl), LlavaNext(LlavaNext), ClipVisionModel(ClipVisionModel), Mistral, diff --git a/router/src/validation.rs b/router/src/validation.rs index 8159ede40d4..5b2a153ce2a 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -594,6 +594,10 @@ fn image_tokens( } Paligemma(config) => "".repeat(config.get_number_of_features(height, width)), LlavaNext(config) => "".repeat(config.get_number_of_features(height, width)), + Qwen2Vl(config) => format!( + "<|vision_start|>{:?}<|vision_end|>", + "<|image_pad|>".repeat(config.get_number_of_features(height, width)) + ), _ => unimplemented!("Images tokens are not supported for this model configuration"), } } @@ -620,7 +624,9 @@ fn prepare_input( use Config::*; static RE: Lazy = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); let (tokenizer_query, input_chunks) = match config { - Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => { + Some( + config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)), + ) => { let mut input_chunks = Vec::new(); let mut tokenizer_query = String::with_capacity(inputs.len()); let mut start = 0; diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index e9be22b1c8b..7ae432563ac 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -348,6 +348,10 @@ def forward( # Get rotary cos and sin for this forward # Avoid to index in each layer # TODO: fix how N-D position_ids are handled + + if position_ids.dim() == 2: + position_ids = position_ids.unsqueeze(0) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin_hack( position_ids, true_max_s, hidden_states.dtype ) diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 28217d7dcfd..ac66695a75b 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -34,6 +34,7 @@ ) from text_generation_server.layers import ( TensorParallelColumnLinear, + TensorParallelRowLinear, FastLinear, ) from text_generation_server.layers.attention import ( @@ -352,6 +353,7 @@ def forward( class Qwen2VLForConditionalGeneration(nn.Module): def __init__(self, prefix, config, weights): super().__init__() + self.config = config config.vision_config.quantize = None config.vision_config.speculator = config.speculator self.hidden_size = config.hidden_size @@ -364,6 +366,10 @@ def __init__(self, prefix, config, weights): prefix="visual", config=config.vision_config, weights=weights ) self.text_model = Qwen2Model(prefix=None, config=config, weights=weights) + self.lm_head = FastLinear.load( + prefix="lm_head", weights=weights, config=config, bias=False + ) + self.device = weights.device def forward( self, @@ -386,10 +392,10 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, ): - - # make an attention_mask that is the same size as the input_ids - attention_mask = torch.ones_like(input_ids, dtype=torch.bool) 
- + # make an attention_mask that is (batch_size, sequence_length) + attention_mask = torch.ones_like( + input_ids, dtype=torch.bool, device=input_ids.device + ) inputs_embeds = self.text_model.embed_tokens(input_ids) # apply the visual model to the pixel values if they are provided @@ -525,7 +531,6 @@ def forward( mrope_position_deltas, device=input_ids.device ).unsqueeze(1) - # TODO: adjust model to accept 2D position_ids outputs = self.text_model( input_ids=input_ids, position_ids=position_ids, @@ -541,4 +546,5 @@ def forward( attention_mask=attention_mask, ) - return outputs, None + logits = self.lm_head(outputs) + return logits, None diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 1b8e7f88c6b..7625c3051c3 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -68,7 +68,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str elif config.model_type == "paligemma": return "" * config.text_config.num_image_tokens elif config.model_type == "qwen2_vl": - return "" + num_pads = image_input.pixel_values.shape[0] // 4 + padding = "<|image_pad|>" * num_pads + return f"<|vision_start|>{padding}<|vision_end|>" else: raise RuntimeError(f"Unknown config {config.model_type} for multimodal") @@ -183,10 +185,11 @@ def batch_tokenized_inputs( raise RuntimeError(f"Invalid chunk type {chunk_type}") if images: - # TODO: REMOVE (this is for debugging purposes) - images = images[0][0].resize( - (images[0][0].width * 2, images[0][0].height * 2) - ) + if images[0][0].width <= 20: + # TODO: provide a better way to handle the issue of the prefill image being too small + images = images[0][0].resize( + (images[0][0].width * 2, images[0][0].height * 2) + ) image_inputs = processor.image_processor(images, return_tensors="pt") else: image_inputs = None From 22fdf9344fd2e90bd8c1c26c765f6caa21b7ecd1 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 28 Oct 2024 02:15:48 +0000 Subject: [PATCH 03/18] fix: improve get_position_ids, add lift embed_tokens --- .../layers/tensor_parallel.py | 2 +- .../custom_modeling/flash_qwen2_modeling.py | 74 +++--- .../models/custom_modeling/qwen2_vl.py | 237 +++++++----------- 3 files changed, 128 insertions(+), 185 deletions(-) diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index febf28da71a..13f12ef1ec1 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -144,7 +144,7 @@ def load_qkv( num_key_value_heads=num_key_value_heads, ) if bias: - bias = weights.get_tensor(f"{prefix}.bias") + raise NotImplementedError("packed_qkv only implemented for baichuan") else: bias = None linear = get_linear(weight, bias) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 7ae432563ac..5fe39bc9147 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -130,28 +130,23 @@ def forward( query = query.view(-1, self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - # TODO: correctly handle the multimodal case - if False: - self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, 
sin) - else: - # multimodal rotary - unsqueeze_dim = 1 - mrope_section = self.mrope_section * 2 - cos = torch.cat( - [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], - dim=-1, - ).unsqueeze(unsqueeze_dim) - sin = torch.cat( - [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], - dim=-1, - ).unsqueeze(unsqueeze_dim) - - _query = query.transpose(0, 1).unsqueeze(0) - _key = torch.select(kv, dim=1, index=0).transpose(0, 1).unsqueeze(0) - q_embed = (_query * cos) + (rotate_half(_query) * sin) - k_embed = (_key * cos) + (rotate_half(_key) * sin) - query = q_embed.squeeze(0).transpose(0, 1) - kv[:, 0] = k_embed.squeeze(0).transpose(0, 1) + _query = query.clone() + _cos = cos.clone() + _sin = sin.clone() + + self.rotary_emb(_query, torch.select(kv, dim=1, index=0), cos, sin) + + _cos = torch.cat((_cos, _cos), dim=-1) + _sin = torch.cat((_sin, _sin), dim=-1) + q_emb = (_query * _cos).reshape(2, 1, -1) + ( + rotate_half(_query) * _sin + ).reshape(2, 1, -1) + k_emb = (torch.select(kv, dim=1, index=0) * _cos).reshape(2, 1, -1) + ( + rotate_half(torch.select(kv, dim=1, index=0)) * _sin + ).reshape(2, 1, -1) + + query = q_emb.reshape(-1, self.num_heads, self.head_size) + kv[:, 0] = k_emb.reshape(-1, self.num_key_value_heads, self.head_size) if prefill_cache_indices is not None: kv_to_cache = kv[prefill_cache_indices] @@ -299,9 +294,6 @@ def __init__(self, prefix: str, config, weights): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}.embed_tokens", weights=weights - ) self.layers = nn.ModuleList( [ Qwen2Layer( @@ -325,7 +317,7 @@ def __init__(self, prefix: str, config, weights): def forward( self, - input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -335,25 +327,12 @@ def forward( max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], - inputs_embeds: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: + hidden_states = inputs_embeds - # if inputs_embeds are supplied from an external model (vision model) then avoid embedding input_ids - if inputs_embeds is not None: - hidden_states = inputs_embeds.squeeze(0) - else: - hidden_states = self.embed_tokens(input_ids) - - # Get rotary cos and sin for this forward - # Avoid to index in each layer - # TODO: fix how N-D position_ids are handled - - if position_ids.dim() == 2: - position_ids = position_ids.unsqueeze(0) - - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin_hack( - position_ids, true_max_s, hidden_states.dtype + # TODO: ensure we are getting the correct positional embeddings + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids[0, 0, :], true_max_s, hidden_states.dtype ) residual = None @@ -393,6 +372,11 @@ def __init__(self, prefix: str, config, weights): prefix=f"{prefix}.{suffix}" if prefix else suffix, weights=weights, ) + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + self.max_past = config.sliding_window self.max_past_tensor = ( torch.tensor(config.sliding_window, device=weights.device) @@ -423,8 +407,10 @@ def forward( # kernel requires the true values seqlen = seqlen.clamp(max=self.max_past_tensor) + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = self.model( - input_ids, + 
inputs_embeds, position_ids, cu_seqlen_prefill, kv_cache, diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index ac66695a75b..2eb7b97805a 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -35,6 +35,7 @@ from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelRowLinear, + TensorParallelEmbedding, FastLinear, ) from text_generation_server.layers.attention import ( @@ -78,11 +79,11 @@ def __init__(self, *, prefix, config, weights): config, prefix=f"{prefix}.qkv", weights=weights, - bias=True, + bias=False, num_heads=self.num_heads, num_key_value_heads=self.num_heads, ) - + self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0) self.proj = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.proj", @@ -285,7 +286,6 @@ def forward( self, pixel_values: torch.Tensor, aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, grid_thw: Optional[torch.LongTensor] = None, ) -> torch.Tensor: # reshape the input tensor for processing @@ -361,7 +361,9 @@ def __init__(self, prefix, config, weights): self.image_token_id = config.image_token_id self.video_token_id = config.video_token_id self.spatial_merge_size = config.vision_config.spatial_merge_size - + self.embed_tokens = TensorParallelEmbedding( + prefix=f"model.embed_tokens", weights=weights + ) self.visual = Qwen2VisionModel( prefix="visual", config=config.vision_config, weights=weights ) @@ -371,6 +373,93 @@ def __init__(self, prefix, config, weights): ) self.device = weights.device + def get_position_ids( + self, + batch_input_ids: torch.Tensor, + image_grid_thw: Optional[torch.LongTensor], + # video_grid_thw is not implemented yet as we do not accept video inputs at the moment + ) -> Tuple[torch.Tensor, torch.Tensor]: + position_ids = torch.ones( + 3, + batch_input_ids.shape[0], + batch_input_ids.shape[1], + dtype=batch_input_ids.dtype, + device=batch_input_ids.device, + ) + d = batch_input_ids.device + if image_grid_thw is not None: + image_index = 0 + llm_pos_ids_list = [] + + for i, input_ids in enumerate(batch_input_ids): + vision_start_indices = torch.argwhere( + input_ids == self.vision_start_token_id + ).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + # only copy the sum of the image tokens GPU<->CPU + image_count = (vision_tokens == self.image_token_id).sum().item() + + current_pos = 0 + for _ in range(image_count): + # copy the value position of the next image token from GPU<->CPU + next_image_pos = ( + (input_ids[current_pos:] == self.image_token_id) + .nonzero()[0] + .item() + ) + # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop + time_steps, height, width = image_grid_thw[image_index] + height //= self.spatial_merge_size + width //= self.spatial_merge_size + + # calculate the length of the text and image tokens + text_length = next_image_pos - current_pos + start_idx = ( + llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + ) + + # text position ids + text_pos_ids = torch.arange(text_length, device=d) + text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx + llm_pos_ids_list.append(text_pos_ids) + + # image position ids + t_indices = torch.arange(time_steps, device=d).repeat_interleave( + height * width + ) + h_indices = ( + torch.arange(height, device=d) + 
.repeat_interleave(width) + .repeat(time_steps) + ) + w_indices = torch.arange(width, device=d).repeat( + height * time_steps + ) + + image_pos_ids = ( + torch.stack([t_indices, h_indices, w_indices]) + + text_length + + start_idx + ) + llm_pos_ids_list.append(image_pos_ids) + + current_pos = next_image_pos + time_steps * height * width + image_index += 1 + + if current_pos < batch_input_ids.size(1): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + text_len = batch_input_ids.size(1) - current_pos + llm_pos_ids_list.append( + torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[:, i, :] = llm_positions.to(position_ids.device) + + return position_ids + def forward( self, input_ids: torch.Tensor, @@ -392,147 +481,17 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, image_indices=None, ): - # make an attention_mask that is (batch_size, sequence_length) - attention_mask = torch.ones_like( - input_ids, dtype=torch.bool, device=input_ids.device - ) - inputs_embeds = self.text_model.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input_ids) # apply the visual model to the pixel values if they are provided if pixel_values is not None and len(pixel_values) > 0: if pixel_values is not None: image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - image_mask = ( - (input_ids == self.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_embeds = image_embeds.to( - inputs_embeds.device, inputs_embeds.dtype - ) - # input embeddings are masked with image embeddings - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - - # handle the position_ids taking the multimodal inputs into account - mrope_position_deltas = [] - if image_grid_thw is not None or video_grid_thw is not None: - total_input_ids = input_ids - position_ids = torch.ones( - 3, - input_ids.shape[0], - input_ids.shape[1], - dtype=input_ids.dtype, - device=input_ids.device, - ) - image_index, video_index = 0, 0 - - for i, input_ids in enumerate(total_input_ids): - if attention_mask is not None: - input_ids = input_ids[attention_mask[i] == 1] - image_nums, video_nums = 0, 0 - vision_start_indices = torch.argwhere( - input_ids == self.vision_start_token_id - ).squeeze(1) - vision_tokens = input_ids[vision_start_indices + 1] - # determine the number of images and videos in the input - image_nums = (vision_tokens == self.image_token_id).sum() - video_nums = (vision_tokens == self.video_token_id).sum() - input_tokens = input_ids.tolist() - llm_pos_ids_list: list = [] - st = 0 - remain_images, remain_videos = image_nums, video_nums - # process each input based on it's token type and grid size - for _ in range(image_nums + video_nums): - if self.image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(self.image_token_id, st) - else: - ed_image = len(input_tokens) + 1 - if self.video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(self.video_token_id, st) - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - video_index += 1 - 
remain_videos -= 1 - ed = ed_video - llm_grid_t, llm_grid_h, llm_grid_w = ( - t.item(), - h.item() // self.spatial_merge_size, - w.item() // self.spatial_merge_size, - ) - text_len = ed - st - - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - - t_index = ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .flatten() - ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx - ) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) - text_len = len(input_tokens) - st - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( - position_ids.device - ) - mrope_position_deltas.append( - llm_positions.max() + 1 - len(total_input_ids[i]) - ) - mrope_position_deltas = torch.tensor( - mrope_position_deltas, device=input_ids.device - ).unsqueeze(1) + inputs_embeds[input_ids == self.image_token_id] = image_embeds + position_ids = self.get_position_ids(input_ids, image_grid_thw) outputs = self.text_model( - input_ids=input_ids, + inputs_embeds=inputs_embeds.squeeze(0), position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, @@ -542,8 +501,6 @@ def forward( max_s=max_s, true_max_s=max_s, prefill_cache_indices=prefill_cache_indices, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, ) logits = self.lm_head(outputs) From ec933282b23296a5b418bbfabe6014f37e7585a9 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 28 Oct 2024 02:20:00 +0000 Subject: [PATCH 04/18] fix: remove get_cos_sin_hack dev function --- .../text_generation_server/layers/rotary.py | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 6e2ba228332..123bbadbb9e 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -277,32 +277,6 @@ def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow. 
return cos.unsqueeze(1), sin.unsqueeze(1) - def get_cos_sin_hack( - self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype - ): - # TODO: avoid always computing, use the cache and update it if necessary - inv_freq_expanded = ( - self.inv_freq[None, None, :, None] - .float() - .expand(3, position_ids.shape[1], -1, 1) - ) - - position_ids_expanded = position_ids[ - :, :, None, : - ].float() # shape (3, bs, 1, positions) - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose( - 2, 3 - ) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos().to(dtype) - sin = emb.sin().to(dtype) - - # Update cached values - self._cos_cached = cos - self._sin_cached = sin - - return cos, sin - class SuRotaryEmbedding(PositionRotaryEmbedding): def __init__( From 80ea4f061063fc98a33a272095757e87bd81357d Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 28 Oct 2024 03:06:35 +0000 Subject: [PATCH 05/18] feat: add simple test chat with meesage and text --- .../test_flash_qwen2_vl_simple.json | 26 +++++++++ .../models/test_flash_qwen2_vl.py | 54 +++++++++++++++++++ 2 files changed, 80 insertions(+) create mode 100644 integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json create mode 100644 integration-tests/models/test_flash_qwen2_vl.py diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json new file mode 100644 index 00000000000..1c74a405741 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "The image shows a rabbit with a is on floating in outer a a in outer and seems a as an in the be an astronaut suit a a a have crew the front ag a suit the chalet", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1730084696, + "id": "", + "model": "Qwen/Qwen2-VL-7B-Instruct", + "object": "chat.completion", + "system_fingerprint": "2.3.2-dev0-native", + "usage": { + "completion_tokens": 41, + "prompt_tokens": 349, + "total_tokens": 390 + } +} diff --git a/integration-tests/models/test_flash_qwen2_vl.py b/integration-tests/models/test_flash_qwen2_vl.py new file mode 100644 index 00000000000..986e442efde --- /dev/null +++ b/integration-tests/models/test_flash_qwen2_vl.py @@ -0,0 +1,54 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_qwen2_vl_handle(launcher): + with launcher( + "Qwen/Qwen2-VL-7B-Instruct", + max_batch_prefill_tokens=2000, + max_input_length=2000, + max_total_tokens=2001, + cuda_graphs=[0], + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_qwen2(flash_qwen2_vl_handle): + await flash_qwen2_vl_handle.health(300) + return flash_qwen2_vl_handle.client + + +@pytest.mark.private +async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot): + response = await flash_qwen2.chat( + max_tokens=100, + seed=42, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" + }, + }, + {"type": "text", "text": "Describe this image."}, + ], + }, + ], + ) + + assert ( + response.choices[0].message.content + == "The image shows a rabbit with a is on floating in outer a a in 
outer and seems a as an in the be an astronaut suit a a a have crew the front ag a suit the chalet" + ) + + # # TODO: return reference response + # assert ( + # response.choices[0].message.content + # == "The image depicts an astronaut with a rabbit's head standing on a rocky, reddish terrain. The astronaut is wearing a space suit with various buttons and" + # ) + + assert response == response_snapshot From e1114c2726fc1727fafd3754ec219266275763a9 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 28 Oct 2024 03:07:15 +0000 Subject: [PATCH 06/18] fix: lint test --- .../models/test_flash_qwen2_vl.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/integration-tests/models/test_flash_qwen2_vl.py b/integration-tests/models/test_flash_qwen2_vl.py index 986e442efde..8658267318f 100644 --- a/integration-tests/models/test_flash_qwen2_vl.py +++ b/integration-tests/models/test_flash_qwen2_vl.py @@ -25,18 +25,18 @@ async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot): max_tokens=100, seed=42, messages=[ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" + }, }, - }, - {"type": "text", "text": "Describe this image."}, - ], - }, + {"type": "text", "text": "Describe this image."}, + ], + }, ], ) From 279b114ab34270dadebd175155fd3a8dbb2c8fc7 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 28 Oct 2024 14:06:18 +0000 Subject: [PATCH 07/18] fix: adjust positional embeddings for multi dimensional position ids --- .../custom_modeling/flash_qwen2_modeling.py | 34 +++++++++---------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 5fe39bc9147..f411c849a42 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -130,23 +130,18 @@ def forward( query = query.view(-1, self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - _query = query.clone() - _cos = cos.clone() - _sin = sin.clone() - - self.rotary_emb(_query, torch.select(kv, dim=1, index=0), cos, sin) - - _cos = torch.cat((_cos, _cos), dim=-1) - _sin = torch.cat((_sin, _sin), dim=-1) - q_emb = (_query * _cos).reshape(2, 1, -1) + ( - rotate_half(_query) * _sin - ).reshape(2, 1, -1) - k_emb = (torch.select(kv, dim=1, index=0) * _cos).reshape(2, 1, -1) + ( - rotate_half(torch.select(kv, dim=1, index=0)) * _sin - ).reshape(2, 1, -1) + if self.mrope_section is not None: + # if mrope_section is set, we need to split the cos and sin into 3 parts and concatenate them in a specific order + cos = torch.cat( + [m[i % 3] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], + dim=-1, + ) + sin = torch.cat( + [m[i % 3] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], + dim=-1, + ) - query = q_emb.reshape(-1, self.num_heads, self.head_size) - kv[:, 0] = k_emb.reshape(-1, self.num_key_value_heads, self.head_size) + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) if prefill_cache_indices is not None: 
kv_to_cache = kv[prefill_cache_indices] @@ -330,10 +325,13 @@ def forward( ) -> torch.Tensor: hidden_states = inputs_embeds - # TODO: ensure we are getting the correct positional embeddings + # flatten position ids from 2D to 1D cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids[0, 0, :], true_max_s, hidden_states.dtype + position_ids.flatten(), true_max_s, hidden_states.dtype ) + # reshape cos and sin for the number of position ids present in the input + cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2) + sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2) residual = None for i, layer in enumerate(self.layers): From 670d75b872016ed775d5bcee25581b05a8a38d0d Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 28 Oct 2024 14:36:54 +0000 Subject: [PATCH 08/18] fix: update docs and lint unused vars --- docs/source/supported_models.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index ede1fc778f3..55449e473b6 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -24,6 +24,7 @@ Text Generation Inference enables serving optimized models. The following sectio - [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct) - [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1) - [Qwen 2](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f) +- [Qwen 2 VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d) - [Opt](https://huggingface.co/facebook/opt-6.7b) - [T5](https://huggingface.co/google/flan-t5-xxl) - [Galactica](https://huggingface.co/facebook/galactica-120b) From aa2aa9f915ef30ca75856d18e107fb15cce72d9f Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 28 Oct 2024 14:37:59 +0000 Subject: [PATCH 09/18] fix: include linted file --- .../text_generation_server/models/custom_modeling/qwen2_vl.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 2eb7b97805a..e4fd3325066 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -209,9 +209,7 @@ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: class Qwen2VLPatchMerger(nn.Module): def __init__(self, *, prefix, config, weights): super().__init__() - context_dim = 2560 - spatial_merge_size: int = 2 - self.hidden_size = 5120 # context_dim * (spatial_merge_size**2) + self.hidden_size = config.embed_dim * (config.spatial_merge_size**2) self.patch_merger_ln_q = FastLayerNorm.load( prefix=f"{prefix}.ln_q", weights=weights, From 65558b32f4f92d2da80a1becb25ab03d4fe9b0e2 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 28 Oct 2024 15:14:02 +0000 Subject: [PATCH 10/18] fix: add norm after text output --- .../models/custom_modeling/qwen2_vl.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index e4fd3325066..907fa16369e 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -31,6 +31,7 @@ from text_generation_server.layers.layernorm import ( FastLayerNorm, + FastRMSNorm ) from 
text_generation_server.layers import ( TensorParallelColumnLinear, @@ -369,6 +370,11 @@ def __init__(self, prefix, config, weights): self.lm_head = FastLinear.load( prefix="lm_head", weights=weights, config=config, bias=False ) + self.norm = FastRMSNorm.load( + prefix="model.norm", + weights=weights, + eps=config.rms_norm_eps, + ) self.device = weights.device def get_position_ids( @@ -488,7 +494,7 @@ def forward( inputs_embeds[input_ids == self.image_token_id] = image_embeds position_ids = self.get_position_ids(input_ids, image_grid_thw) - outputs = self.text_model( + hidden_states = self.text_model( inputs_embeds=inputs_embeds.squeeze(0), position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, @@ -500,6 +506,6 @@ def forward( true_max_s=max_s, prefill_cache_indices=prefill_cache_indices, ) - - logits = self.lm_head(outputs) + hidden_states, _ = self.norm(hidden_states) + logits = self.lm_head(hidden_states) return logits, None From 6208d10c536e46f9e72eb4f30c50f137c7dfedfd Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 28 Oct 2024 15:24:32 +0000 Subject: [PATCH 11/18] fix: format model file --- .../models/custom_modeling/qwen2_vl.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 907fa16369e..4edd336d96d 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -29,10 +29,7 @@ from transformers.activations import ACT2FN import torch.nn.functional as F -from text_generation_server.layers.layernorm import ( - FastLayerNorm, - FastRMSNorm -) +from text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelRowLinear, From f2a1b1b3fcc3441a5a1f7212c0ac4135a9594965 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 28 Oct 2024 12:30:03 -0400 Subject: [PATCH 12/18] fix: adjust for ruff lints --- .../models/custom_modeling/qwen2_vl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 4edd336d96d..3bb29b9bca3 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -22,9 +22,9 @@ from text_generation_server.utils.import_utils import SYSTEM if SYSTEM == "ipex": - import intel_extension_for_pytorch as ipex + pass else: - import flash_attn_2_cuda + pass from transformers.activations import ACT2FN import torch.nn.functional as F @@ -358,7 +358,7 @@ def __init__(self, prefix, config, weights): self.video_token_id = config.video_token_id self.spatial_merge_size = config.vision_config.spatial_merge_size self.embed_tokens = TensorParallelEmbedding( - prefix=f"model.embed_tokens", weights=weights + prefix="model.embed_tokens", weights=weights ) self.visual = Qwen2VisionModel( prefix="visual", config=config.vision_config, weights=weights From 831a07f9900520bdee7cd1f694680b9cc84d29f2 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 28 Oct 2024 16:33:07 +0000 Subject: [PATCH 13/18] fix: remove unused rotate_half --- .../models/custom_modeling/flash_qwen2_modeling.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py 
b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index f411c849a42..8c2c31d6230 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -49,13 +49,6 @@ def _load_gqa(config, prefix: str, weights): ) -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - class Qwen2Attention(torch.nn.Module): def __init__( self, From fb1ae6d24ca729218e1dc16e2a313c6f70104154 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Mon, 28 Oct 2024 16:57:35 +0000 Subject: [PATCH 14/18] feat: refactors and calc num features --- integration-tests/models/test_flash_qwen2_vl.py | 8 +------- router/src/config.rs | 8 +++++--- .../models/custom_modeling/qwen2_vl.py | 2 +- server/text_generation_server/models/vlm_causal_lm.py | 5 ++--- 4 files changed, 9 insertions(+), 14 deletions(-) diff --git a/integration-tests/models/test_flash_qwen2_vl.py b/integration-tests/models/test_flash_qwen2_vl.py index 8658267318f..73413eb0d6e 100644 --- a/integration-tests/models/test_flash_qwen2_vl.py +++ b/integration-tests/models/test_flash_qwen2_vl.py @@ -3,13 +3,7 @@ @pytest.fixture(scope="module") def flash_qwen2_vl_handle(launcher): - with launcher( - "Qwen/Qwen2-VL-7B-Instruct", - max_batch_prefill_tokens=2000, - max_input_length=2000, - max_total_tokens=2001, - cuda_graphs=[0], - ) as handle: + with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle: yield handle diff --git a/router/src/config.rs b/router/src/config.rs index 7fc27f960f6..eb16e88b003 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -160,9 +160,11 @@ pub struct Qwen2Vl { } impl Qwen2Vl { - pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize { - // TODO: calculate number of features - 6000 / 4 + pub fn get_number_of_features(&self, height: usize, width: usize) -> usize { + let num_pixels = height * width; + let num_image_tokens = num_pixels / self.vision_config.patch_size.pow(2); + let start_and_end_tokens = 2; + num_image_tokens + start_and_end_tokens } } diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 3bb29b9bca3..8eee045a9a0 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -490,7 +490,7 @@ def forward( image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) inputs_embeds[input_ids == self.image_token_id] = image_embeds - position_ids = self.get_position_ids(input_ids, image_grid_thw) + position_ids = self.get_position_ids(input_ids.unsqueeze(0), image_grid_thw) hidden_states = self.text_model( inputs_embeds=inputs_embeds.squeeze(0), position_ids=position_ids, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 7625c3051c3..a8467059162 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -398,9 +398,8 @@ def forward( max_k=batch.max_current_length, ) logits, speculative_logits = self.model.forward( - # TODO: remove the unsqueeze(0) - input_ids=input_ids.unsqueeze(0), - position_ids=position_ids.unsqueeze(0), + input_ids=input_ids, + position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, 
kv_cache=kv_cache, block_tables=block_tables, From 77c81a29cb0967eb8020318570c43e37ccf6e83d Mon Sep 17 00:00:00 2001 From: David Holtz Date: Tue, 29 Oct 2024 01:13:17 +0000 Subject: [PATCH 15/18] fix: prefer position_ids passed from vlm causal lm and reset ids on batch --- .../test_flash_qwen2_vl_simple.json | 10 +++++----- integration-tests/models/test_flash_qwen2_vl.py | 10 ++-------- .../models/custom_modeling/qwen2_vl.py | 9 +++++---- server/text_generation_server/models/vlm_causal_lm.py | 10 ++++++++++ 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json index 1c74a405741..2f7ffb08494 100644 --- a/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json +++ b/integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": "The image shows a rabbit with a is on floating in outer a a in outer and seems a as an in the be an astronaut suit a a a have crew the front ag a suit the chalet", + "content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape.", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1730084696, + "created": 1730164250, "id": "", "model": "Qwen/Qwen2-VL-7B-Instruct", "object": "chat.completion", - "system_fingerprint": "2.3.2-dev0-native", + "system_fingerprint": "2.4.1-dev0-native", "usage": { - "completion_tokens": 41, + "completion_tokens": 58, "prompt_tokens": 349, - "total_tokens": 390 + "total_tokens": 407 } } diff --git a/integration-tests/models/test_flash_qwen2_vl.py b/integration-tests/models/test_flash_qwen2_vl.py index 73413eb0d6e..357de2b14b3 100644 --- a/integration-tests/models/test_flash_qwen2_vl.py +++ b/integration-tests/models/test_flash_qwen2_vl.py @@ -3,7 +3,7 @@ @pytest.fixture(scope="module") def flash_qwen2_vl_handle(launcher): - with launcher("Qwen/Qwen2-VL-7B-Instruct") as handle: + with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle: yield handle @@ -36,13 +36,7 @@ async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot): assert ( response.choices[0].message.content - == "The image shows a rabbit with a is on floating in outer a a in outer and seems a as an in the be an astronaut suit a a a have crew the front ag a suit the chalet" + == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape." ) - # # TODO: return reference response - # assert ( - # response.choices[0].message.content - # == "The image depicts an astronaut with a rabbit's head standing on a rocky, reddish terrain. 
The astronaut is wearing a space suit with various buttons and" - # ) - assert response == response_snapshot diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 8eee045a9a0..6ebc3d4ef8c 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -409,7 +409,7 @@ def get_position_ids( .item() ) # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop - time_steps, height, width = image_grid_thw[image_index] + time_steps, height, width = image_grid_thw[image_index].clone() height //= self.spatial_merge_size width //= self.spatial_merge_size @@ -487,12 +487,13 @@ def forward( # apply the visual model to the pixel values if they are provided if pixel_values is not None and len(pixel_values) > 0: if pixel_values is not None: - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + image_embeds = self.visual( + pixel_values, grid_thw=image_grid_thw + ).squeeze(0) inputs_embeds[input_ids == self.image_token_id] = image_embeds - position_ids = self.get_position_ids(input_ids.unsqueeze(0), image_grid_thw) hidden_states = self.text_model( - inputs_embeds=inputs_embeds.squeeze(0), + inputs_embeds=inputs_embeds, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index a8467059162..fc813b30696 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -360,6 +360,16 @@ def forward( max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices + if self.model.get_position_ids: + if position_ids.shape[0] != 1: + position_ids = self.model.get_position_ids( + input_ids.unsqueeze(0), batch.image_grid_thw + ) + batch.position_ids = position_ids[0, 0, :] + else: + position_ids = position_ids.repeat(3, 1, 1).clone() + batch.position_ids = position_ids[0, 0, :] + if cu_seqlen_prefill is None and self.max_past() is not None: # In decode, not prefill, we're actually overwriting the KV-cache # in a circular buffer mode. 
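Illustrative sketch (not part of the patch series): the mrope_section handling introduced in patch 07 above splits the per-axis rotary tables and stitches a single table back together section by section. The shapes and the [16, 24, 24] split below are assumptions chosen to match a 64-wide rotary half-dimension, not values read from this diff.

    import torch

    mrope_section = [16, 24, 24]          # assumed split; sums to 64
    seq_len = 5
    # one rotary table per position axis: 0 = temporal, 1 = height, 2 = width
    cos = torch.randn(3, seq_len, 1, sum(mrope_section))

    # take section i from axis i % 3, then re-concatenate along the feature dim
    merged = torch.cat(
        [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))],
        dim=-1,
    )
    assert merged.shape == (seq_len, 1, sum(mrope_section))

With exactly three sections, i % 3 simply selects axis i; the modulo only becomes meaningful if the section list is repeated.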
From 4f90db47be547f5e4ca70ae360a5d6706a533ff6 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Tue, 29 Oct 2024 15:26:41 +0000 Subject: [PATCH 16/18] fix: adjust get_position_ids if not available and add required args to signatures --- router/src/config.rs | 4 +--- .../models/custom_modeling/flash_pali_gemma_modeling.py | 1 + .../text_generation_server/models/custom_modeling/idefics2.py | 1 + .../models/custom_modeling/llava_next.py | 1 + server/text_generation_server/models/vlm_causal_lm.py | 2 +- 5 files changed, 5 insertions(+), 4 deletions(-) diff --git a/router/src/config.rs b/router/src/config.rs index eb16e88b003..9c31e6e8c75 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -162,9 +162,7 @@ pub struct Qwen2Vl { impl Qwen2Vl { pub fn get_number_of_features(&self, height: usize, width: usize) -> usize { let num_pixels = height * width; - let num_image_tokens = num_pixels / self.vision_config.patch_size.pow(2); - let start_and_end_tokens = 2; - num_image_tokens + start_and_end_tokens + num_pixels / self.vision_config.patch_size.pow(2) } } diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 0024f2bb92b..b1f89eff484 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -80,6 +80,7 @@ def forward( pixel_attention_mask: Optional[torch.BoolTensor] = None, image_sizes: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.text_model.embed_tokens(input_ids) # TODO This is odd but apparently pali gemma position ids start at 1. 
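Aside on the get_number_of_features change earlier in this patch: the new Rust body reduces to the pixel count divided by the squared patch size. A quick Python check of that arithmetic follows; patch_size = 14 is an assumption about the Qwen2-VL vision config, not a value taken from this diff.

    # hypothetical inputs for illustration only
    patch_size = 14
    height, width = 224, 224
    num_features = (height * width) // patch_size**2   # 50176 // 196 = 256
    print(num_features)                                 # -> 256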
diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index a829c374128..923123d61b6 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -750,6 +750,7 @@ def forward( # Unused here image_sizes: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, ): inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None: diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 32e9d3348b3..df7366eafa6 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -180,6 +180,7 @@ def forward( pixel_attention_mask=None, image_sizes: Optional[torch.LongTensor] = None, adapter_data: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, ): inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None and len(pixel_values) > 0: diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index fc813b30696..df2c2a2c336 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -360,7 +360,7 @@ def forward( max_s = batch.max_current_length lm_head_indices = batch.prefill_head_indices - if self.model.get_position_ids: + if hasattr(self.model, "get_position_ids"): if position_ids.shape[0] != 1: position_ids = self.model.get_position_ids( input_ids.unsqueeze(0), batch.image_grid_thw From 77eb07f73b21a785ba7b4b8f31c725ed01656d4d Mon Sep 17 00:00:00 2001 From: David Holtz Date: Tue, 29 Oct 2024 15:47:32 +0000 Subject: [PATCH 17/18] fix: adjust resize case for qwen2_vl warmup --- .../text_generation_server/models/vlm_causal_lm.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index df2c2a2c336..9a3db502b75 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -177,6 +177,14 @@ def batch_tokenized_inputs( pass elif chunk_type == "image": image = Image.open(BytesIO(chunk.image.data)) + # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the + # default warmup image is 20x20 + if config.model_type == "qwen2_vl": + if image.width <= 20: + w = image.width * 2 + h = image.height * 2 + image = image.resize((w, h)) + if config.model_type == "llava_next": images.append(image) else: @@ -185,11 +193,6 @@ def batch_tokenized_inputs( raise RuntimeError(f"Invalid chunk type {chunk_type}") if images: - if images[0][0].width <= 20: - # TODO: provide a better way to handle the issue of the prefill image being too small - images = images[0][0].resize( - (images[0][0].width * 2, images[0][0].height * 2) - ) image_inputs = processor.image_processor(images, return_tensors="pt") else: image_inputs = None From 620769e380099f4f3f2fdf8630cfbb817bd6d28f Mon Sep 17 00:00:00 2001 From: David Holtz Date: Tue, 29 Oct 2024 17:49:50 +0000 Subject: [PATCH 18/18] fix: avoid qwen2 vl specific paths with qwen2 --- 
.../custom_modeling/flash_qwen2_modeling.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 8c2c31d6230..cc4039b1cbc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -61,7 +61,11 @@ def __init__( config.sliding_window if config.sliding_window is not None else -1 ) self.num_heads = config.num_attention_heads - self.mrope_section = config.rope_scaling.get("mrope_section", None) + self.mrope_section = ( + config.rope_scaling.get("mrope_section", None) + if config.rope_scaling is not None + else None + ) self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads @@ -322,9 +326,10 @@ def forward( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( position_ids.flatten(), true_max_s, hidden_states.dtype ) - # reshape cos and sin for the number of position ids present in the input - cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2) - sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2) + # reshape back to 2D if the position_ids were 2D + if position_ids.size(0) != cos.size(0): + cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2) + sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2) residual = None for i, layer in enumerate(self.layers): @@ -365,7 +370,8 @@ def __init__(self, prefix: str, config, weights): ) self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens", + weights=weights, ) self.max_past = config.sliding_window
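Illustrative sketch (not part of the patch series) of the shape guard added in patch 18 above: cos/sin are computed for the flattened position ids and only reshaped back when the ids carry the extra mrope axis, so plain Qwen2 text inputs skip the reshape entirely. The concrete shapes below are assumptions for demonstration, not values read from the diff.

    import torch

    def maybe_reshape(cos: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        # cos: (total_flattened_positions, 1, dim)
        # position_ids: (seq,) for plain text, or (3, seq) for (t, h, w) mrope ids
        if position_ids.size(0) != cos.size(0):
            cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
        return cos

    seq, dim = 7, 64
    text_ids = torch.arange(seq)
    mrope_ids = torch.arange(seq).repeat(3, 1)

    assert maybe_reshape(torch.randn(seq, 1, dim), text_ids).shape == (seq, 1, dim)
    assert maybe_reshape(torch.randn(3 * seq, 1, dim), mrope_ids).shape == (3, seq, 1, dim)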