1414    Qwen2_5_VLAttention ,
1515    Qwen2_5_VLConfig ,
1616    Qwen2_5_VLDecoderLayer ,
17+     Qwen2_5_VLModelOutputWithPast ,
1718    Qwen2_5_VLRotaryEmbedding ,
19+     Qwen2_5_VLTextModel ,
1820    Qwen2_5_VLVisionAttention ,
1921    apply_rotary_pos_emb_vision ,
2022    repeat_kv ,
@@ -393,6 +395,7 @@ def forward(
393395        batch_index : Optional [torch .LongTensor ] =  None ,
394396        output_attentions : bool  =  False ,
395397        use_cache : bool  =  False ,
398+         cache_position : Optional [torch .LongTensor ] =  None ,
396399        ** kwargs ,
397400    ) ->  Tuple [torch .Tensor , Optional [torch .Tensor ], Optional [Tuple [torch .Tensor ]]]:
398401        bsz , q_len , _  =  hidden_states .size ()
@@ -406,7 +409,7 @@ def forward(
406409        value_states  =  value_states .view (bsz , q_len , - 1 , self .head_dim ).transpose (1 , 2 )
407410
408411        kv_seq_len  =  key_states .shape [- 2 ]
409-         kv_seq_len  =  past_key_value .get_usable_length ( kv_seq_len ,  self .layer_idx )
412+         kv_seq_len  =  past_key_value .get_seq_length ( self .layer_idx ,  cache_position )
410413
411414        cos , sin  =  self .rotary_emb (value_states , seq_len = kv_seq_len )
412415
@@ -490,7 +493,7 @@ def forward(
490493            output_attentions = output_attentions ,
491494            use_cache = use_cache ,
492495            cache_position = cache_position ,
493-             # position_embeddings=position_embeddings ,
496+             ** kwargs ,
494497        )
495498        hidden_states  =  residual  +  hidden_states 
496499
@@ -511,7 +514,7 @@ def forward(
511514        return  outputs 
512515
513516
514- class  QEffQwen2_5_VLModel ( Qwen2_5_VLModel ):
517+ class  QEffQwen2_5_VLTextModel ( Qwen2_5_VLTextModel ):
515518    def  forward (
516519        self ,
517520        input_ids : Optional [torch .LongTensor ] =  None ,
@@ -525,6 +528,7 @@ def forward(
525528        output_hidden_states : Optional [bool ] =  None ,
526529        return_dict : Optional [bool ] =  None ,
527530        cache_position : Optional [torch .LongTensor ] =  None ,
531+         ** kwargs ,
528532    ) ->  Union [Tuple , BaseModelOutputWithPast ]:
529533        output_attentions  =  output_attentions  if  output_attentions  is  not None  else  self .config .output_attentions 
530534        output_hidden_states  =  (
@@ -571,6 +575,7 @@ def forward(
571575                output_attentions = output_attentions ,
572576                use_cache = use_cache ,
573577                cache_position = cache_position ,
578+                 ** kwargs ,
574579            )
575580
576581            hidden_states  =  layer_outputs [0 ]
@@ -587,13 +592,66 @@ def forward(
587592        if  return_legacy_cache :
588593            past_key_values  =  past_key_values .to_legacy_cache ()
589594
590-         # Cast to INT32 to avoid issue while running in ONNXRT 
591-         logit_index  =  position_ids [0 ].to (torch .int32 ).argmax (1 , keepdim = True )
592-         hidden_states  =  hidden_states [torch .arange (position_ids [0 ].shape [0 ]).view (- 1 , 1 ), logit_index ]
595+         return  BaseModelOutputWithPast (
596+             last_hidden_state = hidden_states ,
597+             past_key_values = past_key_values ,
598+             hidden_states = all_hidden_states ,
599+             attentions = all_self_attns ,
600+         )
593601
594602        return  (hidden_states , past_key_values )
595603
596604
class QEffQwen2_5_VLModel(Qwen2_5_VLModel):
    """QEfficient wrapper around ``Qwen2_5_VLModel``.

    Routes the decoding pass through ``self.language_model`` (the text
    backbone) and re-wraps the result as a ``Qwen2_5_VLModelOutputWithPast``
    carrying ``self.rope_deltas``.  Compared to upstream it additionally
    accepts and forwards a ``batch_index`` tensor (presumably used for
    batched KV-cache addressing downstream -- TODO confirm against the
    QEff text-model implementation).
    """

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        # Extra vs upstream signature; passed straight through to the text model.
        batch_index: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, Qwen2_5_VLModelOutputWithPast]:
        """Run the language model over ``input_ids`` / ``inputs_embeds``.

        Returns:
            ``Qwen2_5_VLModelOutputWithPast`` when ``return_dict`` resolves
            truthy, otherwise the equivalent tuple via ``to_tuple()``.
        """
        # Resolve output flags from the config when the caller left them unset.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Embed token ids only when the caller did not supply embeddings
        # (the multimodal path hands in pre-merged text+vision embeddings).
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Delegate to the text backbone. ``input_ids=None`` is deliberate:
        # ``inputs_embeds`` is guaranteed to be set at this point.
        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            batch_index=batch_index,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,  # force dict output so attribute access below is safe
            cache_position=cache_position,
            **kwargs,
        )

        # Re-wrap in the VL-specific output type so callers also see rope_deltas.
        # NOTE(review): assumes self.rope_deltas was populated by an earlier
        # rope computation on this instance -- confirm with the caller path.
        output = Qwen2_5_VLModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )
        return output if return_dict else output.to_tuple()
654+ 
597655class  QEffQwen_2_5_vl_EncoderWrapper (nn .Module ):
598656    def  __init__ (self , model ):
599657        super ().__init__ ()
@@ -613,7 +671,7 @@ class QEffQwen_2_5_vl_DecoderWrapper(nn.Module):
    def __init__(self, model):
        """Wrap ``model`` for decoder-only export/inference.

        ``model`` is expected to be a Qwen2.5-VL conditional-generation model
        whose ``.model`` attribute is the multimodal wrapper and whose
        ``.model.language_model`` is the text-only decoder stack.
        """
        super().__init__()
        self.model = model
        # Bind the inner text decoder directly (one extra level of nesting
        # vs the previous layout, where ``model.model`` was the text model).
        self.language_model = self.model.model.language_model
617675
618676    def  forward (self , input_ids , vision_embeds , position_ids , image_idx , past_key_values ):
619677        inputs_embeds  =  self .model .get_input_embeddings ()(input_ids )
@@ -628,10 +686,13 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va
628686        outputs  =  self .model .model (
629687            inputs_embeds = inputs_embeds , position_ids = position_ids , past_key_values = past_key_values , use_cache = True 
630688        )
631-         logits  =  self .model .lm_head (outputs [0 ])
689+ 
690+         logit_index  =  position_ids [0 ].to (torch .int32 ).argmax (1 , keepdim = True )
691+         hidden_states  =  outputs .last_hidden_state [torch .arange (position_ids [0 ].shape [0 ]).view (- 1 , 1 ), logit_index ]
692+         logits  =  self .model .lm_head (hidden_states )
632693        image_idx  =  (indices1 .max () +  1 ).unsqueeze (0 ).unsqueeze (0 )
633694
634-         return  logits , vision_embeds , image_idx , outputs [ 1 ] 
695+         return  logits , vision_embeds , image_idx , outputs . past_key_values 
635696
636697
637698class  QEffQwen_2_5_vl_ForConditionalGeneration (Qwen2_5_VLForConditionalGeneration ):
0 commit comments