
Commit 7bd1de7

Author: Mohit Soni (committed)
TF_upgrade
Signed-off-by: Mohit Soni <[email protected]>
1 parent 3603fb0 commit 7bd1de7

File tree: 5 files changed, +109 −32 lines


QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 22 additions & 19 deletions
@@ -139,6 +139,18 @@
     Qwen2Model,
     Qwen2RMSNorm,
 )
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VisionTransformerPretrainedModel,
+    Qwen2_5_VLAttention,
+    Qwen2_5_VLDecoderLayer,
+    Qwen2_5_VLForConditionalGeneration,
+    Qwen2_5_VLModel,
+    Qwen2_5_VLTextModel,
+    Qwen2_5_VLVisionAttention,
+)
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2RMSNorm as Qwen2_5RMSNorm,
+)
 from transformers.models.qwen3.modeling_qwen3 import (
     Qwen3Attention,
     Qwen3DecoderLayer,

@@ -155,17 +167,6 @@
     Qwen3MoeRotaryEmbedding,
     Qwen3MoeSparseMoeBlock,
 )
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-    Qwen2_5_VisionTransformerPretrainedModel,
-    Qwen2_5_VLAttention,
-    Qwen2_5_VLDecoderLayer,
-    Qwen2_5_VLForConditionalGeneration,
-    Qwen2_5_VLModel,
-    Qwen2_5_VLVisionAttention,
-)
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-    Qwen2RMSNorm as Qwen2_5RMSNorm,
-)
 from transformers.models.starcoder2.modeling_starcoder2 import (
     Starcoder2Attention,
     Starcoder2DecoderLayer,

@@ -336,6 +337,15 @@
     QEffQwen2ForCausalLM,
     QEffQwen2Model,
 )
+from QEfficient.transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    QEffQwen2_5_VisionTransformerPretrainedModel,
+    QEffQwen2_5_VLAttention,
+    QEffQwen2_5_VLDecoderLayer,
+    QEffQwen2_5_VLModel,
+    QEffQwen2_5_VLTextModel,
+    QEffQwen2_5_VLVisionAttention,
+    QEffQwen_2_5_vl_ForConditionalGeneration,
+)
 from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
     QEffQwen3Attention,
     QEffQwen3DecoderLayer,

@@ -350,14 +360,6 @@
     QEffQwen3MoeRotaryEmbedding,
     QEffQwen3MoeSparseMoeBlock,
 )
-from QEfficient.transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-    QEffQwen2_5_VisionTransformerPretrainedModel,
-    QEffQwen2_5_VLAttention,
-    QEffQwen2_5_VLDecoderLayer,
-    QEffQwen2_5_VLModel,
-    QEffQwen2_5_VLVisionAttention,
-    QEffQwen_2_5_vl_ForConditionalGeneration,
-)
 from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import (
     QEffStarcoder2Attention,
     QEFFStarcoder2DecoderLayer,

@@ -532,6 +534,7 @@ class KVCacheTransform(ModuleMappingTransform):
         Qwen2_5_VLDecoderLayer: QEffQwen2_5_VLDecoderLayer,
         Qwen2_5_VisionTransformerPretrainedModel: QEffQwen2_5_VisionTransformerPretrainedModel,
         Qwen2_5_VLVisionAttention: QEffQwen2_5_VLVisionAttention,
+        Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel,
         # Starcoder2
         Starcoder2Attention: QEffStarcoder2Attention,
         Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer,
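
The substantive change in this file is the relocation of the Qwen2.5-VL import blocks, the new Qwen2_5_VLTextModel / QEffQwen2_5_VLTextModel imports, and the matching Qwen2_5_VLTextModel -> QEffQwen2_5_VLTextModel entry in the KVCacheTransform map. As a minimal sketch of what a module-mapping pass of this kind does (illustrative only, not QEfficient's actual ModuleMappingTransform), each matched module keeps its weights while its class is rebound so the QEff forward runs:

# Illustrative sketch of a {HF class -> QEff class} mapping pass (assumption:
# this mirrors the spirit of ModuleMappingTransform, not its exact code).
import torch.nn as nn

def apply_class_mapping(model: nn.Module, mapping: dict) -> nn.Module:
    for module in model.modules():
        replacement_cls = mapping.get(type(module))
        if replacement_cls is not None:
            # Keep the module's parameters and buffers; only rebind the class so
            # the QEff (KV-cache aware, export-friendly) forward is used.
            module.__class__ = replacement_cls
    return model

With the new entry, a Qwen2_5_VLTextModel instance inside the loaded checkpoint would be served by QEffQwen2_5_VLTextModel's forward once the transform has run.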

QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

Lines changed: 70 additions & 9 deletions
@@ -14,7 +14,9 @@
     Qwen2_5_VLAttention,
     Qwen2_5_VLConfig,
     Qwen2_5_VLDecoderLayer,
+    Qwen2_5_VLModelOutputWithPast,
     Qwen2_5_VLRotaryEmbedding,
+    Qwen2_5_VLTextModel,
     Qwen2_5_VLVisionAttention,
     apply_rotary_pos_emb_vision,
     repeat_kv,

@@ -393,6 +395,7 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

@@ -406,7 +409,7 @@ def forward(
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
-        kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
 
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

@@ -490,7 +493,7 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
-            # position_embeddings=position_embeddings,
+            **kwargs,
         )
         hidden_states = residual + hidden_states

@@ -511,7 +514,7 @@ def forward(
         return outputs
 
 
-class QEffQwen2_5_VLModel(Qwen2_5_VLModel):
+class QEffQwen2_5_VLTextModel(Qwen2_5_VLTextModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

@@ -525,6 +528,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (

@@ -571,6 +575,7 @@ def forward(
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
+                **kwargs,
             )
 
             hidden_states = layer_outputs[0]

@@ -587,13 +592,66 @@ def forward(
         if return_legacy_cache:
             past_key_values = past_key_values.to_legacy_cache()
 
-        # Cast to INT32 to avoid issue while running in ONNXRT
-        logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True)
-        hidden_states = hidden_states[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index]
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
 
         return (hidden_states, past_key_values)
 
 
+class QEffQwen2_5_VLModel(Qwen2_5_VLModel):
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        batch_index: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        outputs = self.language_model(
+            input_ids=None,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            batch_index=batch_index,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        output = Qwen2_5_VLModelOutputWithPast(
+            last_hidden_state=outputs.last_hidden_state,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            rope_deltas=self.rope_deltas,
+        )
+        return output if return_dict else output.to_tuple()
+
+
 class QEffQwen_2_5_vl_EncoderWrapper(nn.Module):
     def __init__(self, model):
         super().__init__()

@@ -613,7 +671,7 @@ class QEffQwen_2_5_vl_DecoderWrapper(nn.Module):
     def __init__(self, model):
         super().__init__()
         self.model = model
-        self.language_model = self.model.model
+        self.language_model = self.model.model.language_model
 
     def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
         inputs_embeds = self.model.get_input_embeddings()(input_ids)

@@ -628,10 +686,13 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va
         outputs = self.model.model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
-        logits = self.model.lm_head(outputs[0])
+
+        logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True)
+        hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index]
+        logits = self.model.lm_head(hidden_states)
         image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
 
-        return logits, vision_embeds, image_idx, outputs[1]
+        return logits, vision_embeds, image_idx, outputs.past_key_values
 
 
 class QEffQwen_2_5_vl_ForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
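
With this change, the last-token gather that previously lived in the text model's forward (originally commented "# Cast to INT32 to avoid issue while running in ONNXRT") now runs in QEffQwen_2_5_vl_DecoderWrapper, right before lm_head. A standalone sketch of the indexing pattern, with made-up shapes (in the wrapper the leading position_ids[0] is used because Qwen2.5-VL carries 3D mrope position ids; the toy below uses plain 2D ids):

import torch

# Toy example of the gather above; shapes and values are illustrative only.
position_ids = torch.tensor([[0, 1, 2, 3, 4],
                             [0, 1, 2, 0, 0]])  # second sequence padded after index 2
hidden = torch.randn(2, 5, 8)                   # (batch, seq_len, hidden_size)

# argmax over int32 position ids gives the last "real" position per sequence.
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)   # tensor([[4], [2]])
rows = torch.arange(position_ids.shape[0]).view(-1, 1)               # tensor([[0], [1]])
last_hidden = hidden[rows, logit_index]                               # shape (2, 1, 8)
# Only these per-sequence vectors are fed to lm_head, keeping the logits tensor small.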

README.md

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,9 @@
 
 *Latest news* :fire: <br>
 
+- [10/2025] Added support for the Qwen2.5VL Multi-Modal model [Qwen/Qwen2.5-VL-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-32B-Instruct)
+- [10/2025] Added support for the Mistral3 Multi-Modal model [mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
+- [10/2025] Added support for the Molmo Multi-Modal model [allenai/Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924)
 - [06/2025] Added support for Llama4 Multi-Model [meta-llama/Llama-4-Scout-17B-16E-Instruct](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct)
 - [06/2025] Added support for Gemma3 Multi-Modal-Model [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it)
 - [06/2025] Added support of model `hpcai-tech/grok-1` [hpcai-tech/grok-1](https://huggingface.co/hpcai-tech/grok-1)

examples/qwen2_5_vl_example.py

Lines changed: 4 additions & 4 deletions
@@ -15,11 +15,11 @@
 
 from QEfficient import QEFFAutoModelForImageTextToText
 
+## For AWQ models, update the PyTorch version to 2.8.*
 model_id = "Qwen/Qwen2.5-VL-32B-Instruct"
 config = AutoConfig.from_pretrained(model_id)
 
-# For Testing Purpose Only
-config.num_hidden_layers = 1
+## Use the complete model without changing num_hidden_layers; reducing it does not work with transformers 4.55.0 for the Qwen2.5VL model
 
 qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
     model_id, attn_implementation="eager", kv_offload=True, config=config

@@ -28,7 +28,7 @@
 processor = AutoProcessor.from_pretrained(model_id)
 
 ### Set skip_vision=True to run text only, otherwise False ###
-skip_vision = True
+skip_vision = False
 
 if skip_vision:
     ## Only Text ##

@@ -152,7 +152,7 @@
 
 inputs["position_ids"] = torch.arange(input_ids_length).view(1, 1, input_ids_length).expand(-1, batch_size, -1)
 
-pos_ids, rope_deltas = qeff_model.model.get_rope_index(
+pos_ids, rope_deltas = qeff_model.model.model.get_rope_index(
     inputs["input_ids"],
     inputs["image_grid_thw"],
     video_grid_thw=None,
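
The switch from qeff_model.model.get_rope_index to qeff_model.model.model.get_rope_index follows the transformers 4.55.x restructuring of Qwen2.5-VL, where the ForConditionalGeneration class delegates to an inner Qwen2_5_VLModel that owns get_rope_index and holds the decoder stack as language_model. A rough sketch of the assumed nesting (names inferred from this diff, not a guaranteed API):

# Assumed attribute nesting after this commit; purely illustrative.
core = qeff_model.model                 # QEffQwen_2_5_vl_ForConditionalGeneration
vl_model = core.model                   # QEffQwen2_5_VLModel: get_rope_index / rope_deltas live here
text_model = vl_model.language_model    # QEffQwen2_5_VLTextModel: the decoder layers
lm_head = core.lm_head                  # projection applied to the gathered hidden state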
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+{
+    "SKIP_QNN_CONVERTER_STEP": false,
+    "context_binary_generator_args_extension": "--log_level debug",
+    "converter_args_extension": "--onnx_defer_loading",
+    "qnn_compilation_backend": {
+        "compiler_enable_depth_first": true,
+        "compiler_printDDRStats": false,
+        "compiler_printPerfMetrics": false
+    }
+}
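
This new file is a QNN compilation config (its path is not shown in this view). As a sketch of how such a JSON is typically consumed, based on the pattern used elsewhere in QEfficient (the filename and exact compile arguments below are assumptions, not taken from this commit):

# Hypothetical usage; the config filename and compile kwargs are assumptions,
# shown only to indicate where a QNN config JSON plugs into the compile flow.
qeff_model.compile(
    num_cores=16,
    num_devices=4,
    enable_qnn=True,
    qnn_config="qwen2_5_vl_qnn_config.json",  # hypothetical path to the file added above
)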
