From 409da2498f20e3e084b47d8af6ca91d79d8519ee Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Thu, 23 Oct 2025 16:35:26 -0700 Subject: [PATCH 01/11] Extend on-device sampling support for dual QPC VLMs Signed-off-by: quic-xiyushi --- .../transformers/models/modeling_auto.py | 122 +++++++++++++++++- .../transformers/models/pytorch_transforms.py | 4 + QEfficient/transformers/sampler/sampler.py | 56 +++++--- 3 files changed, 164 insertions(+), 18 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 60f60c768..8b021314e 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -721,7 +721,12 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model, **kwargs): + def __init__( + self, + model, + qaic_config: Optional[dict] = None, + **kwargs + ): """ Initializes the language decoder component for multimodal models. @@ -729,12 +734,24 @@ def __init__(self, model, **kwargs): ---------- model : nn.Module The full HuggingFace multimodal model from which the language decoder is extracted. + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. + Only the following keys are supported by the text model of the dual QPC multimodal model: + - **include_sampler** (bool): If True, enables on-device sampling of next tokens. + - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. + Additional keys will be ignored. **kwargs : Additional keyword arguments passed to the base class constructor. """ super().__init__(model, **kwargs) self.model = model.get_qeff_language_decoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ + self.model.qaic_config = qaic_config + # ---Sampling--- + # Note: SamplerTransform should be applied after all other transforms + # are done. The role of the sampler is to just add nodes at the output of the + # previous transform function. + self.model, _ = SamplerTransform.apply(self.model, qaic_config, **kwargs) def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True): """ @@ -758,10 +775,95 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt str Path to the generated ONNX graph file for the language decoder. """ + if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False): + inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(inputs, output_names, dynamic_axes) return self._export( inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights ) + def get_sampling_inputs_and_outputs( + self, + example_inputs: Dict[str, torch.Tensor], + output_names: List[str], + dynamic_axes: Dict[str, Dict[int, str]], + ): + """ + Updates the example inputs, output names, and dynamic axes to include + parameters relevant for on-device sampling during ONNX export. + + Parameters + ---------- + example_inputs : Dict[str, torch.Tensor] + Current dictionary of example inputs. + output_names : List[str] + Current list of output names. + dynamic_axes : Dict[str, Dict[int, str]] + Current dictionary of dynamic axes configurations. + + Returns + ------- + Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]] + Updated example inputs, output names, and dynamic axes including + sampling-related parameters. 
+ """ + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + + assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling" + + logits_index = output_names.index("logits") + output_names[logits_index] = "next_tokens" + + example_inputs["last_accepted_output_tokens"] = torch.zeros( + (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 + ) + dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} + + example_inputs["past_repetition_penalty_buffer"] = torch.zeros( + (bs, self.model.language_model.config.vocab_size), dtype=torch.bool + ) + dynamic_axes["past_repetition_penalty_buffer"] = { + 0: "batch_size", + } + output_names.append("past_repetition_penalty_buffer_RetainedState") + + example_inputs["repetition_penalties"] = ( + torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES + ) + dynamic_axes["repetition_penalties"] = {0: "batch_size"} + + example_inputs["past_presence_penalty_buffer"] = torch.zeros( + (bs, self.model.language_model.config.vocab_size), dtype=torch.bool + ) + dynamic_axes["past_presence_penalty_buffer"] = { + 0: "batch_size", + } + output_names.append("past_presence_penalty_buffer_RetainedState") + + example_inputs["presence_penalties"] = ( + torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES + ) + dynamic_axes["presence_penalties"] = {0: "batch_size"} + + example_inputs["temperatures"] = ( + torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES + ) + dynamic_axes["temperatures"] = {0: "batch_size"} + + max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) + example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) + dynamic_axes["top_ks"] = {0: "batch_size"} + + example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS + dynamic_axes["top_ps"] = {0: "batch_size"} + + example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS + dynamic_axes["min_ps"] = {0: "batch_size"} + + example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float) + dynamic_axes["random_numbers"] = {0: "batch_size"} + + return example_inputs, output_names, dynamic_axes + def compile( self, compile_dir, @@ -1499,6 +1601,8 @@ def __init__( """ if kwargs.pop("full_batch_size", None): raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + if kwargs.pop("qaic_config", None): + raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.") super().__init__(model, **kwargs) # to handle internvl models @@ -2023,6 +2127,7 @@ def from_pretrained( pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, continuous_batching: bool = False, + qaic_config: Optional[dict] = None, **kwargs, ): """ @@ -2036,6 +2141,12 @@ def from_pretrained( If True, uses the dual QPC approach (vision encoder KV offloaded). If False, uses the single QPC approach (entire model in one QPC). If None, the default behavior of the internal classes is used (typically dual QPC). + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. + Only the following keys are supported by the text model of the dual QPC multimodal model: + - **include_sampler** (bool): If True, enables on-device sampling of next tokens. 
+ - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. + Additional keys will be ignored. **kwargs : Additional arguments passed to HuggingFace's ``from_pretrained``. @@ -2063,11 +2174,14 @@ def from_pretrained( logger.warning("Updating low_cpu_mem_usage=False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + if qaic_config is not None: + qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls( model, kv_offload=kv_offload, continuous_batching=continuous_batching, + qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs, ) @@ -2273,7 +2387,11 @@ def from_pretrained( if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( - model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs + model, + kv_offload=kv_offload, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + **kwargs ) return cls( model, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 773ce178c..c750a8c66 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -289,6 +289,7 @@ QEffGrok1MultiHeadAttention, ) from QEfficient.transformers.models.internvl.modeling_internvl import ( + QEffInternDecoderWrapper, QEffInternVisionEmbeddings, QEffInternVLModel, ) @@ -392,6 +393,7 @@ QEffQwen2_5_VLModel, QEffQwen2_5_VLTextModel, QEffQwen2_5_VLVisionAttention, + QEffQwen_2_5_vl_DecoderWrapper, QEffQwen_2_5_vl_ForConditionalGeneration, ) from QEfficient.transformers.models.qwen3.modeling_qwen3 import ( @@ -707,10 +709,12 @@ class SamplerTransform: QEffGPTJForCausalLM, QEffGraniteForCausalLM, QEffGraniteMoeForCausalLM, + QEffInternDecoderWrapper, QEffLlamaForCausalLM, QEffMptForCausalLM, QEffPhi3ForCausalLM, QEffQwen2ForCausalLM, + QEffQwen_2_5_vl_DecoderWrapper, } @classmethod diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 96846e712..4a9aa6034 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -24,6 +24,8 @@ class SamplerOutput(ModelOutput): probs: torch.FloatTensor = None next_tokens: torch.IntTensor = None + vision_embeds: Optional[torch.FloatTensor] = None # For VLMs + image_idx: Optional[torch.IntTensor] = None # for VLMs past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_repetition_penalty_buffer: Optional[torch.Tensor] = None past_presence_penalty_buffer: Optional[torch.Tensor] = None @@ -122,6 +124,8 @@ def sampler_forward( top_ps: Optional[torch.Tensor] = None, min_ps: Optional[torch.Tensor] = None, random_numbers: Optional[torch.Tensor] = None, + vision_embeds: Optional[torch.Tensor] = None, + image_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, SamplerOutput]: r""" Perform the sampling of next tokens on the QAIC device (instead of the host) @@ -170,20 +174,36 @@ def sampler_forward( Sampling parameter that represents the random seeds to use for random sampling. Must be in [-1, 1]. 
""" - - outputs = self.old_forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - batch_index=batch_index, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) + if vision_embeds is not None: + logits, vision_embeds, image_idx, past_key_values = self.old_forward( + input_ids=input_ids, + vision_embeds=vision_embeds, + position_ids=position_ids, + image_idx=image_idx, + past_key_values=past_key_values + ) + outputs = dict( + logits=logits, + vision_embeds=vision_embeds, + image_idx=image_idx, + past_key_values=past_key_values + ) + if position_ids.dim() == 3: # For models using m-rope + position_ids = position_ids[0] + else: + outputs = self.old_forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) logits = outputs.get("logits", None) assert logits is not None, f"{self.model.__class__.__name__} does not return logits." @@ -230,7 +250,9 @@ def sampler_forward( return SamplerOutput( probs=None, next_tokens=greedy_samples.reshape(-1, spec_length, 1), # Return sampled next tokens instead of logits - past_key_values=outputs.past_key_values, + vision_embeds=outputs.get("vision_embeds", None), + image_idx=outputs.get("image_idx", None), + past_key_values=outputs.get("past_key_values", None), past_repetition_penalty_buffer=past_repetition_penalty_buffer, past_presence_penalty_buffer=past_presence_penalty_buffer, ) @@ -314,7 +336,9 @@ def sampler_forward( return SamplerOutput( probs=probs, next_tokens=next_tokens, # Return sampled next tokens instead of logits - past_key_values=outputs.past_key_values, + vision_embeds=outputs.get("vision_embeds", None), + image_idx=outputs.get("image_idx", None), + past_key_values=outputs.get("past_key_values", None), past_repetition_penalty_buffer=past_repetition_penalty_buffer, past_presence_penalty_buffer=past_presence_penalty_buffer, ) From e06e1758bad19feae3dbb7c38a2f349c82b0a585 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Thu, 30 Oct 2025 00:04:01 -0700 Subject: [PATCH 02/11] Fix random_numbers shape Signed-off-by: quic-xiyushi --- .../transformers/models/modeling_auto.py | 25 ++++++++----------- QEfficient/transformers/sampler/sampler.py | 22 ++++++---------- 2 files changed, 19 insertions(+), 28 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 8b021314e..6168f4492 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -721,12 +721,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__( - self, - model, - qaic_config: Optional[dict] = None, - **kwargs - ): + def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): """ Initializes the language decoder component for multimodal models. @@ -735,7 +730,7 @@ def __init__( model : nn.Module The full HuggingFace multimodal model from which the language decoder is extracted. qaic_config : dict, optional - A dictionary for QAIC-specific configurations. 
+ A dictionary for QAIC-specific configurations. Only the following keys are supported by the text model of the dual QPC multimodal model: - **include_sampler** (bool): If True, enables on-device sampling of next tokens. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. @@ -776,7 +771,9 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt Path to the generated ONNX graph file for the language decoder. """ if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False): - inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(inputs, output_names, dynamic_axes) + inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs( + inputs, output_names, dynamic_axes + ) return self._export( inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights ) @@ -807,7 +804,7 @@ def get_sampling_inputs_and_outputs( sampling-related parameters. """ bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - + assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling" logits_index = output_names.index("logits") @@ -859,7 +856,7 @@ def get_sampling_inputs_and_outputs( example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS dynamic_axes["min_ps"] = {0: "batch_size"} - example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float) + example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) dynamic_axes["random_numbers"] = {0: "batch_size"} return example_inputs, output_names, dynamic_axes @@ -2142,7 +2139,7 @@ def from_pretrained( If False, uses the single QPC approach (entire model in one QPC). If None, the default behavior of the internal classes is used (typically dual QPC). qaic_config : dict, optional - A dictionary for QAIC-specific configurations. + A dictionary for QAIC-specific configurations. Only the following keys are supported by the text model of the dual QPC multimodal model: - **include_sampler** (bool): If True, enables on-device sampling of next tokens. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. 
@@ -2181,7 +2178,7 @@ def from_pretrained( model, kv_offload=kv_offload, continuous_batching=continuous_batching, - qaic_config=qaic_config, + qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs, ) @@ -2391,7 +2388,7 @@ def from_pretrained( kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, - **kwargs + **kwargs, ) return cls( model, @@ -2594,7 +2591,7 @@ def get_sampling_inputs_and_outputs( example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS dynamic_axes["min_ps"] = {0: "batch_size"} - example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float) + example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) dynamic_axes["random_numbers"] = {0: "batch_size"} return example_inputs, output_names, dynamic_axes diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 4a9aa6034..a15e156ff 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -24,8 +24,8 @@ class SamplerOutput(ModelOutput): probs: torch.FloatTensor = None next_tokens: torch.IntTensor = None - vision_embeds: Optional[torch.FloatTensor] = None # For VLMs - image_idx: Optional[torch.IntTensor] = None # for VLMs + vision_embeds: Optional[torch.FloatTensor] = None # For VLMs + image_idx: Optional[torch.IntTensor] = None # for VLMs past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_repetition_penalty_buffer: Optional[torch.Tensor] = None past_presence_penalty_buffer: Optional[torch.Tensor] = None @@ -176,19 +176,14 @@ def sampler_forward( """ if vision_embeds is not None: logits, vision_embeds, image_idx, past_key_values = self.old_forward( - input_ids=input_ids, - vision_embeds=vision_embeds, - position_ids=position_ids, - image_idx=image_idx, - past_key_values=past_key_values - ) - outputs = dict( - logits=logits, + input_ids=input_ids, vision_embeds=vision_embeds, + position_ids=position_ids, image_idx=image_idx, - past_key_values=past_key_values + past_key_values=past_key_values, ) - if position_ids.dim() == 3: # For models using m-rope + outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values) + if position_ids.dim() == 3: # For models using m-rope position_ids = position_ids[0] else: outputs = self.old_forward( @@ -322,9 +317,8 @@ def sampler_forward( ) # (batch_size, spec_length, vocab_size) # Random Sampling - topk_probs_asc = torch.softmax(topk_values_asc, dim=1) # (batch_size * spec_length, max_top_k_ids) gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1))) # Gumbel-Max Trick - y = topk_probs_asc + gumbel_noise + y = topk_values_asc + gumbel_noise # (batch_size * spec_length, max_top_k_ids) random_samples_indices = torch.argmax(y, dim=1, keepdim=True) random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices) # (batch_size * spec_length, 1) From 3e242ce85b42e5babcee1ee87ce2dacb0c0565e9 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Thu, 30 Oct 2025 00:04:01 -0700 Subject: [PATCH 03/11] Update example with new random sampling logic Signed-off-by: quic-sanising Signed-off-by: sanising --- examples/on_device_sampling.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/on_device_sampling.py b/examples/on_device_sampling.py index 00d8c2430..108e5390e 100644 --- 
a/examples/on_device_sampling.py +++ b/examples/on_device_sampling.py @@ -28,6 +28,7 @@ def main(args, **kwargs): if include_sampler is not None: return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true" max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512)) + np.random.seed(int(args.random_number)) sampling_params = { "repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), "presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), @@ -36,7 +37,9 @@ def main(args, **kwargs): "top_ks": np.array(args.top_k, dtype=np.int32).repeat(bs).reshape(-1, 1), "top_ps": np.array(args.top_p, dtype=np.float32).repeat(bs).reshape(-1, 1), "min_ps": np.array(args.min_p, dtype=np.float32).repeat(bs).reshape(-1, 1), - "random_numbers": np.array(args.random_number, dtype=np.float32).repeat(bs).reshape(-1, 1), + "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=max_top_k_ids), (bs, 1)).astype( + np.float32 + ), } qaic_config = { k: v @@ -110,10 +113,10 @@ def main(args, **kwargs): --repetition-penalty 1.9 \ --presence-penalty 0.8 \ --temperature 0.67 \ - --top-k 54720 \ + --top-k 54 \ --top-p 0.89 \ --min-p 0.6 \ - --random-number 0.26 + --random-number 26 2. For non-continuous batching: python3.10 examples/on_device_sampling.py \ @@ -130,10 +133,10 @@ def main(args, **kwargs): --repetition-penalty 1.9 \ --presence-penalty 0.8 \ --temperature 0.67 \ - --top-k 54720 \ + --top-k 54 \ --top-p 0.89 \ --min-p 0.6 \ - --random-number 0.26 + --random-number 26 """ parser = argparse.ArgumentParser(description="Run QEfficient model with On Device Sampling") From 1a01d57a9d737c0e06ea1c87cf91ce3408dbb324 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Mon, 10 Nov 2025 16:33:06 -0800 Subject: [PATCH 04/11] Update to align with recent VLM CB changes Signed-off-by: quic-xiyushi --- QEfficient/transformers/models/modeling_auto.py | 17 +++++++++++------ QEfficient/transformers/sampler/sampler.py | 6 +++++- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 6168f4492..c110b3ce5 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -721,7 +721,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): + def __init__(self, model, continuous_batching: bool = False, qaic_config: Optional[dict] = None, **kwargs): """ Initializes the language decoder component for multimodal models. @@ -729,6 +729,9 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): ---------- model : nn.Module The full HuggingFace multimodal model from which the language decoder is extracted. + continuous_batching : bool, optional + If True, enables continuous batching mode for future compilation and execution. + This setting must be consistent across `from_pretrained` and `compile` calls. Default is False. qaic_config : dict, optional A dictionary for QAIC-specific configurations. 
Only the following keys are supported by the text model of the dual QPC multimodal model: @@ -741,6 +744,7 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): super().__init__(model, **kwargs) self.model = model.get_qeff_language_decoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ + self.continuous_batching = continuous_batching self.model.qaic_config = qaic_config # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms @@ -804,6 +808,7 @@ def get_sampling_inputs_and_outputs( sampling-related parameters. """ bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling" @@ -816,10 +821,10 @@ def get_sampling_inputs_and_outputs( dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} example_inputs["past_repetition_penalty_buffer"] = torch.zeros( - (bs, self.model.language_model.config.vocab_size), dtype=torch.bool + (fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool ) dynamic_axes["past_repetition_penalty_buffer"] = { - 0: "batch_size", + 0: "full_batch_size" if self.continuous_batching else "batch_size", } output_names.append("past_repetition_penalty_buffer_RetainedState") @@ -829,10 +834,10 @@ def get_sampling_inputs_and_outputs( dynamic_axes["repetition_penalties"] = {0: "batch_size"} example_inputs["past_presence_penalty_buffer"] = torch.zeros( - (bs, self.model.language_model.config.vocab_size), dtype=torch.bool + (fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool ) dynamic_axes["past_presence_penalty_buffer"] = { - 0: "batch_size", + 0: "full_batch_size" if self.continuous_batching else "batch_size", } output_names.append("past_presence_penalty_buffer_RetainedState") @@ -981,7 +986,7 @@ def __init__( self.model = model self.config = model.config self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) - self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) + self.lang_model = QEffCausalLMForTextImageToTextModel(model, continuous_batching=continuous_batching, **kwargs) self.continuous_batching = continuous_batching self.input_shapes, self.output_names = None, None diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index a15e156ff..1075db784 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -175,13 +175,17 @@ def sampler_forward( Must be in [-1, 1]. 
""" if vision_embeds is not None: - logits, vision_embeds, image_idx, past_key_values = self.old_forward( + forward_kwargs = dict( input_ids=input_ids, vision_embeds=vision_embeds, position_ids=position_ids, image_idx=image_idx, past_key_values=past_key_values, ) + if batch_index is not None: + forward_kwargs["batch_index"] = batch_index + + logits, vision_embeds, image_idx, past_key_values = self.old_forward(**forward_kwargs) outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values) if position_ids.dim() == 3: # For models using m-rope position_ids = position_ids[0] From 30d60618e43b7931d7a2a090e2fb4268510d1337 Mon Sep 17 00:00:00 2001 From: sanising Date: Mon, 10 Nov 2025 19:28:38 -0600 Subject: [PATCH 05/11] Update tests with new random sampling logic Signed-off-by: sanising --- tests/transformers/sampler/test_sampler.py | 52 +++++++++++----------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 9335e1d91..8d437eee8 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -211,7 +211,7 @@ def test_greedy_sampling( "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "random_numbers": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.zeros((full_batch_size, 512), dtype=np.float32), }, ) model_wo_sampler_exec_info = model_wo_sampler.generate( @@ -233,7 +233,6 @@ def test_greedy_sampling( @pytest.mark.on_qaic -@pytest.mark.skip @pytest.mark.parametrize( "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", random_sampling_configs, @@ -291,6 +290,7 @@ def test_random_sampling( # Generate texts from prompts tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model) + np.random.seed(0) model_w_sampler_exec_info = model_w_sampler.generate( tokenizer=tokenizer, prompts=prompts, @@ -301,11 +301,13 @@ def test_random_sampling( "repetition_penalties": np.array(20.2, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "presence_penalties": np.array(10.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), # "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "temperatures": np.array(100.1, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "top_ks": np.array(54720, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), + "temperatures": np.array(4.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), "top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "random_numbers": np.array(0.26, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=512), (full_batch_size, 1)).astype( + np.float32 + ), }, ) model_wo_sampler_exec_info = model_wo_sampler.generate( @@ -319,32 +321,32 @@ def test_random_sampling( # Compare generated texts golden_texts = { - "w_sampler": "Raymond and my favorite color, alongside reds or purples (I can’t have them both", + "w_sampler": "Aiden and I am a 
freelance writer who loves to explore the world. With over", "wo_sampler": "John Smith and I am a software engineer. I have been working in the industry for the past ", } golden_ids = { "w_sampler": [ [ - 21380, + 319, + 3615, 322, - 590, - 25448, - 2927, - 29892, - 19963, - 2654, - 29879, - 470, - 3708, - 2701, - 313, - 29902, - 508, - 30010, - 29873, - 505, - 963, - 1716, + 306, + 626, + 263, + 3005, + 295, + 749, + 9227, + 1058, + 12355, + 267, + 304, + 26987, + 278, + 3186, + 29889, + 2973, + 975, ] ], "wo_sampler": [ From 7cf106e39f7a448aed031cfb66852227348e9215 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Wed, 19 Nov 2025 10:53:24 -0800 Subject: [PATCH 06/11] Refactor Signed-off-by: quic-xiyushi --- .../transformers/models/modeling_auto.py | 209 ++---------------- QEfficient/utils/sampler_utils.py | 91 +++++++- 2 files changed, 114 insertions(+), 186 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index a1a333317..242063ee9 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -61,6 +61,7 @@ ) from QEfficient.utils.check_ccl_specializations import process_ccl_specializations from QEfficient.utils.logging_utils import logger +from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs class QEFFTransformersBase(QEFFBaseModel): @@ -730,28 +731,12 @@ def __init__(self, model, continuous_batching: bool = False, qaic_config: Option ---------- model : nn.Module The full HuggingFace multimodal model from which the language decoder is extracted. - continuous_batching : bool, optional - If True, enables continuous batching mode for future compilation and execution. - This setting must be consistent across `from_pretrained` and `compile` calls. Default is False. - qaic_config : dict, optional - A dictionary for QAIC-specific configurations. - Only the following keys are supported by the text model of the dual QPC multimodal model: - - **include_sampler** (bool): If True, enables on-device sampling of next tokens. - - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. - Additional keys will be ignored. **kwargs : Additional keyword arguments passed to the base class constructor. """ super().__init__(model, **kwargs) self.model = model.get_qeff_language_decoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ - self.continuous_batching = continuous_batching - self.model.qaic_config = qaic_config - # ---Sampling--- - # Note: SamplerTransform should be applied after all other transforms - # are done. The role of the sampler is to just add nodes at the output of the - # previous transform function. - self.model, _ = SamplerTransform.apply(self.model, qaic_config, **kwargs) def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True): """ @@ -775,98 +760,10 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt str Path to the generated ONNX graph file for the language decoder. 
""" - if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False): - inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs( - inputs, output_names, dynamic_axes - ) return self._export( inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights ) - def get_sampling_inputs_and_outputs( - self, - example_inputs: Dict[str, torch.Tensor], - output_names: List[str], - dynamic_axes: Dict[str, Dict[int, str]], - ): - """ - Updates the example inputs, output names, and dynamic axes to include - parameters relevant for on-device sampling during ONNX export. - - Parameters - ---------- - example_inputs : Dict[str, torch.Tensor] - Current dictionary of example inputs. - output_names : List[str] - Current list of output names. - dynamic_axes : Dict[str, Dict[int, str]] - Current dictionary of dynamic axes configurations. - - Returns - ------- - Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]] - Updated example inputs, output names, and dynamic axes including - sampling-related parameters. - """ - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS - - assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling" - - logits_index = output_names.index("logits") - output_names[logits_index] = "next_tokens" - - example_inputs["last_accepted_output_tokens"] = torch.zeros( - (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 - ) - dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} - - example_inputs["past_repetition_penalty_buffer"] = torch.zeros( - (fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool - ) - dynamic_axes["past_repetition_penalty_buffer"] = { - 0: "full_batch_size" if self.continuous_batching else "batch_size", - } - output_names.append("past_repetition_penalty_buffer_RetainedState") - - example_inputs["repetition_penalties"] = ( - torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES - ) - dynamic_axes["repetition_penalties"] = {0: "batch_size"} - - example_inputs["past_presence_penalty_buffer"] = torch.zeros( - (fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool - ) - dynamic_axes["past_presence_penalty_buffer"] = { - 0: "full_batch_size" if self.continuous_batching else "batch_size", - } - output_names.append("past_presence_penalty_buffer_RetainedState") - - example_inputs["presence_penalties"] = ( - torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES - ) - dynamic_axes["presence_penalties"] = {0: "batch_size"} - - example_inputs["temperatures"] = ( - torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES - ) - dynamic_axes["temperatures"] = {0: "batch_size"} - - max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) - example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) - dynamic_axes["top_ks"] = {0: "batch_size"} - - example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS - dynamic_axes["top_ps"] = {0: "batch_size"} - - example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS - dynamic_axes["min_ps"] = {0: "batch_size"} - - example_inputs["random_numbers"] = 
torch.rand((bs, max_top_k_ids), dtype=torch.float) - dynamic_axes["random_numbers"] = {0: "batch_size"} - - return example_inputs, output_names, dynamic_axes - def compile( self, compile_dir, @@ -993,7 +890,13 @@ def __init__( self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, continuous_batching=continuous_batching, **kwargs) self.continuous_batching = continuous_batching + self.lang_model.model.qaic_config = qaic_config self.input_shapes, self.output_names = None, None + # ---Sampling--- + # Note: SamplerTransform should be applied after all other transforms + # are done. The role of the sampler is to just add nodes at the output of the + # previous transform function. + self.lang_model.model, _ = SamplerTransform.apply(self.lang_model.model, qaic_config, **kwargs) @property def model_name(self) -> str: @@ -1115,6 +1018,19 @@ def export( kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode ) output_names = self.model.get_output_names(kv_offload=True) + if self.lang_model.model.qaic_config is not None and self.lang_model.model.qaic_config.get( + "include_sampler", False + ): + logits_index = output_names["lang"].index("logits") + output_names["lang"][logits_index] = "next_tokens" + inputs["lang"], output_names["lang"], dynamic_axes["lang"] = get_sampling_inputs_and_outputs( + example_inputs=inputs["lang"], + output_names=output_names["lang"], + dynamic_axes=dynamic_axes["lang"], + continuous_batching=self.continuous_batching, + vocab_size=self.lang_model.model.config.vocab_size, + qaic_config=self.lang_model.model.qaic_config, + ) self.vision_model.export( inputs["vision"], @@ -2300,7 +2216,6 @@ def from_pretrained( model, kv_offload=kv_offload, continuous_batching=continuous_batching, - qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, **kwargs, @@ -2634,10 +2549,13 @@ def export(self, export_dir: Optional[str] = None) -> str: dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"} if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False): - example_inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs( + example_inputs, output_names, dynamic_axes = get_sampling_inputs_and_outputs( example_inputs=example_inputs, output_names=output_names, dynamic_axes=dynamic_axes, + continuous_batching=self.continuous_batching, + vocab_size=self.model.config.vocab_size, + qaic_config=self.model.qaic_config, ) return self._export( @@ -2647,85 +2565,6 @@ def export(self, export_dir: Optional[str] = None) -> str: export_dir=export_dir, ) - def get_sampling_inputs_and_outputs( - self, - example_inputs: Dict[str, torch.Tensor], - output_names: List[str], - dynamic_axes: Dict[str, Dict[int, str]], - ): - """ - Updates the example inputs, output names, and dynamic axes to include - parameters relevant for on-device sampling during ONNX export. - - Parameters - ---------- - example_inputs : Dict[str, torch.Tensor] - Current dictionary of example inputs. - output_names : List[str] - Current list of output names. - dynamic_axes : Dict[str, Dict[int, str]] - Current dictionary of dynamic axes configurations. - - Returns - ------- - Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]] - Updated example inputs, output names, and dynamic axes including - sampling-related parameters. 
- """ - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS - - example_inputs["last_accepted_output_tokens"] = torch.zeros( - (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 - ) - dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} - - example_inputs["past_repetition_penalty_buffer"] = torch.zeros( - (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool - ) - dynamic_axes["past_repetition_penalty_buffer"] = { - 0: "full_batch_size" if self.continuous_batching else "batch_size", - } - output_names.append("past_repetition_penalty_buffer_RetainedState") - - example_inputs["repetition_penalties"] = ( - torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES - ) - dynamic_axes["repetition_penalties"] = {0: "batch_size"} - - example_inputs["past_presence_penalty_buffer"] = torch.zeros( - (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool - ) - dynamic_axes["past_presence_penalty_buffer"] = { - 0: "full_batch_size" if self.continuous_batching else "batch_size", - } - output_names.append("past_presence_penalty_buffer_RetainedState") - - example_inputs["presence_penalties"] = ( - torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES - ) - dynamic_axes["presence_penalties"] = {0: "batch_size"} - - example_inputs["temperatures"] = ( - torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES - ) - dynamic_axes["temperatures"] = {0: "batch_size"} - - max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) - example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) - dynamic_axes["top_ks"] = {0: "batch_size"} - - example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS - dynamic_axes["top_ps"] = {0: "batch_size"} - - example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS - dynamic_axes["min_ps"] = {0: "batch_size"} - - example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) - dynamic_axes["random_numbers"] = {0: "batch_size"} - - return example_inputs, output_names, dynamic_axes - def build_prefill_specialization( self, prefill_seq_len: int = 32, diff --git a/QEfficient/utils/sampler_utils.py b/QEfficient/utils/sampler_utils.py index 6fb1b326f..0460eeb3a 100644 --- a/QEfficient/utils/sampler_utils.py +++ b/QEfficient/utils/sampler_utils.py @@ -5,8 +5,11 @@ # # ----------------------------------------------------------------------------- -from typing import Optional, Set +from typing import Dict, List, Optional, Set +import torch + +from QEfficient.utils import constants from QEfficient.utils.constants import Constants from QEfficient.utils.logging_utils import logger @@ -56,3 +59,89 @@ def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[ ) return session_includes_sampler + + +def get_sampling_inputs_and_outputs( + example_inputs: Dict[str, torch.Tensor], + output_names: List[str], + dynamic_axes: Dict[str, Dict[int, str]], + continuous_batching: bool, + vocab_size: int, + qaic_config: Dict, +): + """ + Updates the example inputs, output names, and dynamic axes to include + parameters relevant for on-device sampling during ONNX export. 
+ + Parameters + ---------- + example_inputs : Dict[str, torch.Tensor] + Current dictionary of example inputs. + output_names : List[str] + Current list of output names. + dynamic_axes : Dict[str, Dict[int, str]] + Current dictionary of dynamic axes configurations. + continuous_batching : bool + Whether this model will be used for continuous batching in the future. + vocab_size: int + Vocabulary size for this model. + qaic_config : Dict + QAIC config dictionary. + + Returns + ------- + Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]] + Updated example inputs, output names, and dynamic axes including + sampling-related parameters. + """ + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + + example_inputs["last_accepted_output_tokens"] = torch.zeros( + (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 + ) + dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} + + example_inputs["past_repetition_penalty_buffer"] = torch.zeros( + (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool + ) + dynamic_axes["past_repetition_penalty_buffer"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + } + output_names.append("past_repetition_penalty_buffer_RetainedState") + + example_inputs["repetition_penalties"] = ( + torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES + ) + dynamic_axes["repetition_penalties"] = {0: "batch_size"} + + example_inputs["past_presence_penalty_buffer"] = torch.zeros( + (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool + ) + dynamic_axes["past_presence_penalty_buffer"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + } + output_names.append("past_presence_penalty_buffer_RetainedState") + + example_inputs["presence_penalties"] = ( + torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES + ) + dynamic_axes["presence_penalties"] = {0: "batch_size"} + + example_inputs["temperatures"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES + dynamic_axes["temperatures"] = {0: "batch_size"} + + max_top_k_ids = qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) + example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) + dynamic_axes["top_ks"] = {0: "batch_size"} + + example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS + dynamic_axes["top_ps"] = {0: "batch_size"} + + example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS + dynamic_axes["min_ps"] = {0: "batch_size"} + + example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) + dynamic_axes["random_numbers"] = {0: "batch_size"} + + return example_inputs, output_names, dynamic_axes From 45aed11cf908615eadb1366416e5df6f5953d48b Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Thu, 20 Nov 2025 10:45:58 -0800 Subject: [PATCH 07/11] Add unit tests Signed-off-by: quic-xiyushi --- QEfficient/generation/vlm_generation.py | 13 ++ .../transformers/models/modeling_auto.py | 18 ++- tests/transformers/sampler/test_sampler.py | 142 ++++++++++++++---- 3 files changed, 142 insertions(+), 31 deletions(-) diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 5eb91d142..6c028a12f 100644 --- a/QEfficient/generation/vlm_generation.py +++ 
b/QEfficient/generation/vlm_generation.py @@ -36,6 +36,7 @@ write_io_files, ) from QEfficient.utils import LRUCache +from QEfficient.utils.constants import Constants from QEfficient.utils.logging_utils import logger @@ -303,6 +304,13 @@ def _execute_chunked_prefill( prefill_ccl_id = 0 lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + if self.include_sampler: + for op in Constants.SAMPLER_OPS: + if decode_batch_id is not None: + lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] + else: + lang_inputs[op] = self.sampling_params[op] + for i in range(num_chunks): input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len] position_ids_slice = lang_inputs["position_ids"][ @@ -328,6 +336,11 @@ def _execute_chunked_prefill( chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"] + if self.include_sampler: + chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] + for op in Constants.SAMPLER_OPS: + chunk_inputs[op] = lang_inputs[op] + outputs = self._session.run(chunk_inputs) if "image_idx_output" in outputs: diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 242063ee9..2bf81f68f 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -881,7 +881,10 @@ def __init__( If `full_batch_size` is provided. """ if kwargs.pop("full_batch_size", None): - raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + continuous_batching = True + warnings.warn( + "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 + ) self.model = model self.config = model.config @@ -1028,7 +1031,7 @@ def export( output_names=output_names["lang"], dynamic_axes=dynamic_axes["lang"], continuous_batching=self.continuous_batching, - vocab_size=self.lang_model.model.config.vocab_size, + vocab_size=self.config.vocab_size, qaic_config=self.lang_model.model.qaic_config, ) @@ -1235,6 +1238,7 @@ def generate( device_ids: List[int] = None, runtime_ai100: bool = True, generation_len: Optional[int] = None, + **kwargs, ) -> Union[torch.Tensor, np.ndarray]: """ Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards. @@ -1293,6 +1297,7 @@ def generate( full_batch_size=fbs, comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, + **kwargs, ) # Call generate method @@ -1572,11 +1577,16 @@ def __init__( Raises ------ NotImplementedError - If `full_batch_size` is provided. + If `full_batch_size` is provided or `continuous_batching` is True or `include_sampler` is True. """ if kwargs.pop("full_batch_size", None): + warnings.warn( + "full_batch_size argument is deprecated. 
Use continuous_batching=True instead.", DeprecationWarning, 2 + ) + raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + if kwargs.pop("continuous_batching", None): raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") - if kwargs.pop("qaic_config", None): + if qaic_config is not None and qaic_config.pop("include_sampler", False): raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.") super().__init__(model, **kwargs) diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 8d437eee8..d31dfef37 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -5,12 +5,13 @@ # # ----------------------------------------------------------------------------- -from typing import List +from typing import List, Union +from transformers import AutoConfig, AutoProcessor import numpy as np import pytest -from QEfficient import QEFFAutoModelForCausalLM +from QEfficient import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.utils import load_hf_tokenizer from QEfficient.utils.constants import Constants @@ -24,6 +25,20 @@ 20, # generation_len 2, # full_batch_size 1, # spec_length + False, # is_vlm + ), + pytest.param( + "Qwen/Qwen2.5-VL-3B-Instruct", # model + ( + ["https://picsum.photos/id/237/536/354"] * 2, + ["Can you describe the image in detail."] * 2, + ), # images and prompts + 128, # prefill_seq_len + 4096, # ctx_len + 20, # generation_len + 2, # full_batch_size + None, # spec_length + True, # is_vlm ), ] greedy_sampling_configs = [ @@ -35,6 +50,20 @@ 20, # generation_len 4, # full_batch_size 1, # spec_length + False, # is_vlm + ), + pytest.param( + "Qwen/Qwen2.5-VL-3B-Instruct", # model + ( + ["https://picsum.photos/id/237/536/354"] * 2, + ["Can you describe the image in detail."] * 2, + ), # images and prompts + 128, # prefill_seq_len + 4096, # ctx_len + 20, # generation_len + 2, # full_batch_size + None, # spec_length + True, # is_vlm ), ] random_sampling_configs = [ @@ -46,23 +75,38 @@ 20, # generation_len 4, # full_batch_size 1, # spec_length + False, # is_vlm ), + # pytest.param( + # "Qwen/Qwen2.5-VL-3B-Instruct", # model + # ( + # ["https://picsum.photos/id/237/536/354"] * 2, + # ["Can you describe the image in detail."] * 2, + # ), # images and prompts + # 128, # prefill_seq_len + # 4096, # ctx_len + # 20, # generation_len + # 2, # full_batch_size + # None, # spec_length + # True, # is_vlm + # ), ] @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", sampler_transform_configs, ) def test_sampler_transform( model: str, - prompts: List[str], + prompts: Union[List[str], tuple[List[str], List[str]]], prefill_seq_len: int, ctx_len: int, generation_len: int, full_batch_size: int, spec_length: int, + is_vlm: bool, ): """ Test if `SamplerTransform` adds nodes at the output of a `QEffForCausalLM model` to enable the @@ -70,45 +114,56 @@ def test_sampler_transform( next tokens and/or probability distributions. 
""" # Export and compile QEfficient models - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + additional_configs = {} + if is_vlm: + additional_configs["kv_offload"] = True + qeff_class = QEFFAutoModelForImageTextToText + else: + additional_configs["num_hidden_layers"] = 2 + qeff_class = QEFFAutoModelForCausalLM + spec_length -= 1 + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=2, qaic_config={ "include_sampler": True, "return_pdfs": False, "max_top_k_ids": 512, }, + **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=2, qaic_config={ "include_sampler": False, "return_pdfs": False, }, + **additional_configs, ) - model_w_sampler_qpc_path: str = model_w_sampler.compile( + model_w_sampler_qpc_path = model_w_sampler.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) - model_wo_sampler_qpc_path: str = model_wo_sampler.compile( + model_wo_sampler_qpc_path = model_wo_sampler.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) + if is_vlm: + model_w_sampler_qpc_path = model_w_sampler_qpc_path[1] + model_wo_sampler_qpc_path = model_wo_sampler_qpc_path[1] # Init qaic session model_w_sampler_session = QAICInferenceSession(model_w_sampler_qpc_path) @@ -139,40 +194,54 @@ def test_sampler_transform( @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", - greedy_sampling_configs, + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", + sampler_transform_configs, ) def test_greedy_sampling( model: str, - prompts: List[str], + prompts: Union[List[str], tuple[List[str], List[str]]], prefill_seq_len: int, ctx_len: int, generation_len: int, full_batch_size: int, spec_length: int, + is_vlm: bool, ): """ Test greedy sampling with QPC compiled with and without On Device Sampling. 
""" # Export and compile QEfficient models - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + additional_configs = {} + additional_params = {} + if is_vlm: + additional_configs["kv_offload"] = True + qeff_class = QEFFAutoModelForImageTextToText + assert isinstance(prompts, tuple) + additional_params["images"] = prompts[0] + additional_params["processor"] = AutoProcessor.from_pretrained(model) + prompts = prompts[1] + else: + additional_configs["num_hidden_layers"] = 2 + qeff_class = QEFFAutoModelForCausalLM + spec_length -= 1 + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=4, qaic_config={ "include_sampler": True, "return_pdfs": False, "max_top_k_ids": 512, }, + **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=4, qaic_config={ "include_sampler": False, "return_pdfs": False, }, + **additional_configs, ) model_w_sampler.compile( prefill_seq_len=prefill_seq_len, @@ -180,7 +249,7 @@ def test_greedy_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -190,7 +259,7 @@ def test_greedy_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -213,6 +282,7 @@ def test_greedy_sampling( "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "random_numbers": np.zeros((full_batch_size, 512), dtype=np.float32), }, + **additional_params, ) model_wo_sampler_exec_info = model_wo_sampler.generate( tokenizer=tokenizer, @@ -221,6 +291,7 @@ def test_greedy_sampling( include_sampler=False, return_pdfs=False, sampling_params=None, + **additional_params, ) # Compare generated texts and ids @@ -234,23 +305,36 @@ def test_greedy_sampling( @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", random_sampling_configs, ) def test_random_sampling( model: str, - prompts: List[str], + prompts: Union[List[str], tuple[List[str], List[str]]], prefill_seq_len: int, ctx_len: int, generation_len: int, full_batch_size: int, spec_length: int, + is_vlm: bool, ): """ Test random sampling with QPC compiled with and without On Device Sampling. 
""" # Export and compile QEfficient models - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + additional_configs = {} + additional_params = {} + if is_vlm: + additional_configs["kv_offload"] = True + qeff_class = QEFFAutoModelForImageTextToText + assert isinstance(prompts, tuple) + additional_params["images"] = prompts[0] + additional_params["processor"] = AutoProcessor.from_pretrained(model) + prompts = prompts[1] + else: + qeff_class = QEFFAutoModelForCausalLM + spec_length -= 1 + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, qaic_config={ @@ -258,14 +342,16 @@ def test_random_sampling( "return_pdfs": False, "max_top_k_ids": 512, }, + **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, qaic_config={ "include_sampler": False, "return_pdfs": False, }, + **additional_configs, ) model_w_sampler.compile( prefill_seq_len=prefill_seq_len, @@ -273,7 +359,7 @@ def test_random_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -283,7 +369,7 @@ def test_random_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -309,6 +395,7 @@ def test_random_sampling( np.float32 ), }, + **additional_params, ) model_wo_sampler_exec_info = model_wo_sampler.generate( tokenizer=tokenizer, @@ -317,6 +404,7 @@ def test_random_sampling( include_sampler=False, return_pdfs=False, sampling_params=None, + **additional_params, ) # Compare generated texts From 6273ab5c156ee53a6134872c29dd05067e055aa9 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Thu, 20 Nov 2025 11:07:52 -0800 Subject: [PATCH 08/11] Clean up Signed-off-by: quic-xiyushi --- .../transformers/models/modeling_auto.py | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 2bf81f68f..189017507 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -723,7 +723,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model, continuous_batching: bool = False, qaic_config: Optional[dict] = None, **kwargs): + def __init__(self, model, **kwargs): """ Initializes the language decoder component for multimodal models. @@ -872,13 +872,10 @@ def __init__( ---------- model : nn.Module The full HuggingFace multimodal model. + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. **kwargs : - Additional keyword arguments. `full_batch_size` is not supported here. - - Raises - ------ - NotImplementedError - If `full_batch_size` is provided. + Additional keyword arguments. 
""" if kwargs.pop("full_batch_size", None): continuous_batching = True @@ -891,7 +888,7 @@ def __init__( self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config) self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) - self.lang_model = QEffCausalLMForTextImageToTextModel(model, continuous_batching=continuous_batching, **kwargs) + self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) self.continuous_batching = continuous_batching self.lang_model.model.qaic_config = qaic_config self.input_shapes, self.output_names = None, None @@ -1577,15 +1574,13 @@ def __init__( Raises ------ NotImplementedError - If `full_batch_size` is provided or `continuous_batching` is True or `include_sampler` is True. + If `full_batch_size` is provided or `include_sampler` is True. """ if kwargs.pop("full_batch_size", None): warnings.warn( "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 ) raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") - if kwargs.pop("continuous_batching", None): - raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") if qaic_config is not None and qaic_config.pop("include_sampler", False): raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.") super().__init__(model, **kwargs) @@ -2189,10 +2184,6 @@ def from_pretrained( If None, the default behavior of the internal classes is used (typically dual QPC). qaic_config : dict, optional A dictionary for QAIC-specific configurations. - Only the following keys are supported by the text model of the dual QPC multimodal model: - - **include_sampler** (bool): If True, enables on-device sampling of next tokens. - - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. - Additional keys will be ignored. **kwargs : Additional arguments passed to HuggingFace's ``from_pretrained``. 
From 3789d5a36f4a268251ac26e9f1f3c3e907c77c55 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Thu, 20 Nov 2025 13:24:45 -0800 Subject: [PATCH 09/11] Update test_sampler.py Signed-off-by: quic-xiyushi --- tests/transformers/sampler/test_sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index d31dfef37..ca4a3abef 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -195,7 +195,7 @@ def test_sampler_transform( @pytest.mark.on_qaic @pytest.mark.parametrize( "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", - sampler_transform_configs, + greedy_sampling_configs, ) def test_greedy_sampling( model: str, @@ -221,7 +221,7 @@ def test_greedy_sampling( additional_params["processor"] = AutoProcessor.from_pretrained(model) prompts = prompts[1] else: - additional_configs["num_hidden_layers"] = 2 + additional_configs["num_hidden_layers"] = 4 qeff_class = QEFFAutoModelForCausalLM spec_length -= 1 model_w_sampler = qeff_class.from_pretrained( From 5e2afb7ad7b9f891ab93577eb4cd65faa4481f18 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Thu, 20 Nov 2025 18:13:13 -0800 Subject: [PATCH 10/11] Fix hash for VLM's language decoder to include qaic_config Signed-off-by: quic-xiyushi --- QEfficient/transformers/models/modeling_auto.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 1b2189c6e..b3a58f669 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -752,7 +752,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model, **kwargs): + def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): """ Initializes the language decoder component for multimodal models. @@ -760,11 +760,14 @@ def __init__(self, model, **kwargs): ---------- model : nn.Module The full HuggingFace multimodal model from which the language decoder is extracted. + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. **kwargs : Additional keyword arguments passed to the base class constructor. 
""" - super().__init__(model, **kwargs) + super().__init__(model, qaic_config=qaic_config, **kwargs) self.model = model.get_qeff_language_decoder() + self.model.qaic_config = qaic_config self.hash_params["qeff_auto_class"] = self.__class__.__name__ def export( @@ -936,9 +939,8 @@ def __init__( self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config) self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) - self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) + self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs) self.continuous_batching = continuous_batching - self.lang_model.model.qaic_config = qaic_config self.input_shapes, self.output_names = None, None # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms From 10990a9650e260e196a4c8001d894ec3a03ffbaf Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Tue, 25 Nov 2025 14:07:06 -0800 Subject: [PATCH 11/11] Fix bug in getting vocab_size and missing ccl in forward Signed-off-by: quic-xiyushi --- QEfficient/transformers/models/modeling_auto.py | 4 ++-- QEfficient/transformers/sampler/sampler.py | 3 +++ tests/transformers/sampler/test_sampler.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index b3a58f669..82f172f3d 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -8,7 +8,7 @@ import warnings from pathlib import Path from time import perf_counter -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union import numpy as np import torch @@ -1081,7 +1081,7 @@ def export( output_names=output_names["lang"], dynamic_axes=dynamic_axes["lang"], continuous_batching=self.continuous_batching, - vocab_size=self.config.vocab_size, + vocab_size=self.model.language_model.config.vocab_size, qaic_config=self.lang_model.model.qaic_config, ) diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 1075db784..f7473cbd0 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -105,6 +105,7 @@ def sampler_forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -181,6 +182,7 @@ def sampler_forward( position_ids=position_ids, image_idx=image_idx, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, ) if batch_index is not None: forward_kwargs["batch_index"] = batch_index @@ -195,6 +197,7 @@ def sampler_forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index ca4a3abef..99eb98a73 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -6,10 +6,10 @@ # ----------------------------------------------------------------------------- from typing import 
List, Union -from transformers import AutoConfig, AutoProcessor import numpy as np import pytest +from transformers import AutoProcessor from QEfficient import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText from QEfficient.generation.cloud_infer import QAICInferenceSession
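
To round out the picture, a condensed sketch of the compile-and-generate flow these tests exercise for the dual QPC sampler path; qeff_model and model_card are as in the earlier sketch, the shapes, prompts, and image list are placeholders, and indexing the compile() result at [1] follows the tests above (assumed to be the language decoder QPC):

    import numpy as np
    from transformers import AutoProcessor, AutoTokenizer

    from QEfficient.generation.cloud_infer import QAICInferenceSession

    fbs = 2  # full batch size; must match the batch dimension of the sampling tensors below

    qpc_paths = qeff_model.compile(
        prefill_seq_len=32,
        ctx_len=256,
        full_batch_size=fbs,
        num_devices=1,
        num_cores=16,
        mxint8_kv_cache=True,
        mxfp6_matmul=True,
    )
    lang_qpc_path = qpc_paths[1]                     # assumed language decoder QPC, as indexed in the tests
    lang_session = QAICInferenceSession(lang_qpc_path)  # inspect sampler I/O, as in test_sampler_transform

    prompts = ["Describe the image.", "What is shown here?"]  # placeholder prompts
    image_urls = ["<image-url-1>", "<image-url-2>"]           # placeholder images, one per prompt

    sampling_params = {
        "repetition_penalties": np.full((fbs, 1), 1.0, dtype=np.float32),
        "presence_penalties": np.full((fbs, 1), 0.0, dtype=np.float32),
        "temperatures": np.full((fbs, 1), 0.0, dtype=np.float32),  # 0.0 keeps decoding greedy
        "top_ks": np.full((fbs, 1), 512, dtype=np.int32),
        "top_ps": np.full((fbs, 1), 1.0, dtype=np.float32),
        "min_ps": np.full((fbs, 1), 0.0, dtype=np.float32),
        "random_numbers": np.zeros((fbs, 512), dtype=np.float32),
    }

    exec_info = qeff_model.generate(
        tokenizer=AutoTokenizer.from_pretrained(model_card),
        processor=AutoProcessor.from_pretrained(model_card),
        images=image_urls,
        prompts=prompts,
        generation_len=32,
        include_sampler=True,
        return_pdfs=False,
        sampling_params=sampling_params,
    )
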