diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index cf8664236..a20fc4cb3 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -177,4 +177,4 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: return model, transformed -VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration"} +VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM"} diff --git a/QEfficient/transformers/models/gemma3/__init__.py b/QEfficient/transformers/models/gemma3/__init__.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/transformers/models/gemma3/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py new file mode 100644 index 000000000..44aa06c83 --- /dev/null +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -0,0 +1,778 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import copy +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from transformers.cache_utils import Cache, HybridCache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.models.gemma3.modeling_gemma3 import ( + Gemma3Attention, + Gemma3Config, + Gemma3DecoderLayer, + Gemma3ForCausalLM, + Gemma3ForConditionalGeneration, + Gemma3TextModel, + logger, + repeat_kv, + rotate_half, +) + +from QEfficient.customop.rms_norm import CustomRMSNorm +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils import constants +from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config + + +class GemmaRMSNormFunc(torch.autograd.Function): + @staticmethod + def forward(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float): + hidden_states = hidden_states.to(torch.float32) + div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=torch.float32)) + variance = div_first.pow(2).sum(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + epsilon) + return weight * hidden_states + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, hidden_states: torch.Value, weight: torch.Value, epsilon: torch.Value) -> torch.Value: + return g.onnxscript_op(CustomRMSNorm, hidden_states, weight, epsilon_f=epsilon).setTypeAs(hidden_states) + + +class QEffGemma3CustomRMSNormAIC(nn.Module): + """ + RMSNorm module that works by replacing the current module with compiler known custom-op. 
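+
+    A minimal eager-mode sketch (illustrative only; in practice this module is
+    swapped in for an existing ``Gemma3RMSNorm`` by CustomOpsTransform, which
+    supplies ``weight`` and ``eps``):
+
+        >>> norm = QEffGemma3CustomRMSNormAIC()
+        >>> norm.weight = nn.Parameter(torch.zeros(8))  # stored weight is offset by +1.0 in forward
+        >>> norm.eps = 1e-6
+        >>> norm(torch.randn(2, 4, 8)).shape
+        torch.Size([2, 4, 8])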
+ """ + + def forward(self, hidden_states): + return GemmaRMSNormFunc.apply( + hidden_states, + self.weight.float() + 1.0, + self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps, + ) + + +class QEffGemma3RotaryEmbedding(nn.Module): + """ + Copied from Gemma2RotaryEmbedding: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + # Apply rotation + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + # Cast back to original dtype + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if softcap is not None: + attn_weights = attn_weights / softcap + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * softcap + + if attention_mask is not None: + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class QEffGemma3Attention(Gemma3Attention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Gemma3Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + # Define the general __qeff_init__() for any changes in the init calls + # Set the init in the module mapping pytorch transforms + self.__qeff_init__() + + def __qeff_init__(self): + self.rotary_emb = QEffGemma3RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.config.max_position_embeddings, + base=self.config.rope_theta, + ) + + config = copy.deepcopy(self.config) + config.rope_theta = config.rope_local_base_freq + config.rope_scaling = {"rope_type": "default"} + + self.rotary_emb_local = QEffGemma3RotaryEmbedding( + self.head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + if self.is_sliding: + cos, sin = self.rotary_emb_local(value_states, seq_len=kv_seq_len) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + # import ipdb; ipdb.set_trace() + if self.config.attn_logit_softcapping is not None: + attn_weights = attn_weights / self.config.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * self.config.attn_logit_softcapping + + if attention_mask is not None: # no matter the length, we just slice it + attn_weights = torch.where(attention_mask.bool(), torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.config.num_attention_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.config.num_attention_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class QEffGemma3DecoderLayer(Gemma3DecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings_global: torch.Tensor, + position_embeddings_local: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + last_cache_position: int = 0, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # apply global RoPE to non-sliding layer only + if self.self_attn.is_sliding: + position_embeddings = position_embeddings_local + else: + position_embeddings = position_embeddings_global + + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = 
+        hidden_states = self.pre_feedforward_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+
+        hidden_states = self.post_feedforward_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        return outputs
+
+
+class QEffGemma3TextModel(Gemma3TextModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma3DecoderLayer`]
+
+    Args:
+        config: Gemma3TextConfig
+    """
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[HybridCache] = None,
+        batch_index: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        last_cache_position: Optional[int] = None,
+        **flash_attn_kwargs,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + # return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = _create_causal_mask( + position_ids=position_ids, target_length=past_seen_tokens, sliding_window=self.config.sliding_window + ) + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings_global = self.rotary_emb(hidden_states, position_ids) + position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + position_embeddings_global, + position_embeddings_local, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + last_cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + last_cache_position=last_cache_position, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if use_cache: + next_cache = past_key_values.to_legacy_cache() + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + +class QEffGemma3ForCausalLMModel(Gemma3ForCausalLM): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Gemma3ForCausalLM + + >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + + if self.training and self.config._attn_implementation != "eager": + logger.warning_once( + "It is strongly recommended to train Gemma3 models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **loss_kwargs, + ) + + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class QEffGemma3EncoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.model.vision_model = self.model.vision_tower + + def forward(self, pixel_values): + image_features = 
self.model.get_image_features(pixel_values=pixel_values)
+        return image_features
+
+
+class QEffGemma3DecoderWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.language_model = self.model.language_model
+        self.config = self.model.config
+
+    def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
+        B, N, C = inputs_embeds.shape
+        selected = input_ids == self.model.config.image_token_index
+        indices1 = selected.to(torch.int64).cumsum(1) - 1
+        indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
+        indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
+        image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds)
+        outputs = self.model.language_model(
+            inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+        )
+        image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
+        return outputs.logits, vision_embeds, image_idx, outputs.past_key_values
+
+
+class QEffGemma3ForConditionalGeneration(Gemma3ForConditionalGeneration):
+    def get_qeff_vision_encoder(self):
+        return QEffGemma3EncoderWrapper(self)
+
+    def get_qeff_language_decoder(self):
+        return QEffGemma3DecoderWrapper(self)
+
+    def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values):
+        image_features = self.get_image_features(pixel_values=pixel_values)
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+        B, N, C = inputs_embeds.shape
+        selected = input_ids == self.config.image_token_index
+        indices1 = selected.to(torch.int64).cumsum(1) - 1
+        indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
+        indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
+        image_features_expanded = image_features.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds)
+        outputs = self.language_model(
+            inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+        )
+        image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
+        return outputs.logits, pixel_values, image_idx, outputs.past_key_values
+
+    def get_specializations(
+        self,
+        batch_size: int,
+        prefill_seq_len: int,
+        ctx_len: int,
+        img_size: int,
+        kv_offload: bool = False,
+        **compiler_options,
+    ):
+        prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
+        ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
+        if img_size is None and hasattr(self.config.vision_config, "image_size"):
+            img_size = getattr(self.config.vision_config, "image_size")
+        elif img_size is None:
+            img_size = 896  # FIXME: default based on the Gemma3 image size
+            logger.warning("Setting img_size to be 896, as it was neither passed nor found in vision_config")
+        mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256)
+
+        vision = [
+            {
+                "batch_size": batch_size,
+                "img_size": img_size,
+                "seq_len": prefill_seq_len,
+                "ctx_len": ctx_len,
+            }
+        ]
+        lang = [
+            {
+                "batch_size": batch_size,
+                "seq_len": prefill_seq_len,
+                "ctx_len": ctx_len,
+                "img_size": img_size,
+
"mm_tokens_per_image": mm_tokens_per_image, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + }, + ] + + specializations = {} + + if kv_offload: + specializations["vision"] = vision + specializations["lang"] = lang + return specializations, compiler_options + else: + return lang, compiler_options + + def get_onnx_dynamic_axes(self, kv_offload: bool = False): + # Define dynamic axes + vision_dynamic_axes = {} + lang_dynamic_axes = {} + lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} + lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} + lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "mm_tokens_per_image"} + vision_dynamic_axes["pixel_values"] = {0: "batch_size", 2: "img_size", 3: "img_size"} + + pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"} + for i in range(self.language_model.config.num_hidden_layers): + for kv in ["key", "value"]: + lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + + dynamic_axes = {} + if kv_offload: + dynamic_axes["vision"] = vision_dynamic_axes + dynamic_axes["lang"] = lang_dynamic_axes + else: + dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes} + return dynamic_axes + + def get_output_names(self, kv_offload: bool = False): + vision_output_names = ["vision_embeds"] + lang_output_names = ["logits"] + for i in range(self.language_model.config.num_hidden_layers): + for kv in ["key", "value"]: + lang_output_names.append(f"past_{kv}.{i}_RetainedState") + + output_names = {} + if kv_offload: + lang_output_names.insert(1, "vision_embeds_RetainedState") + lang_output_names.insert(2, "image_idx_output") + output_names["vision"] = vision_output_names + output_names["lang"] = lang_output_names + else: + lang_output_names.insert(1, "pixel_values_RetainedState") + lang_output_names.insert(2, "image_idx_output") + return lang_output_names + return output_names + + def get_dummy_inputs(self, kv_offload: bool = False): + if vis_cfg := getattr(self.config, "vision_config", None): + img_size = getattr(vis_cfg, "image_size", 896) + else: + img_size = 896 + + mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256) + # Define shapes + inputs_shapes = {} + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + inputs_shapes["vision_embeds"] = ( + 1, # constants.INTERN_NUM_PATCHES, + mm_tokens_per_image, # constants.INTERN_FEATURE_SIZE, + self.language_model.config.hidden_size, # 5120 + ) + inputs_shapes["position_ids"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + inputs_shapes["pixel_values"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + constants.INTERN_NUM_CHANNELS, + img_size, + img_size, + ) + inputs_shapes["image_idx"] = (1, 1) + + # Define inputs + vision_inputs = {} + lang_inputs = {} + vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32) + lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) + lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32) + lang_inputs["position_ids"] = ( + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) + ) + lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + # Add data for KV + 
kv_cache_shape = get_padding_shape_from_config( + config=self.language_model.config, + batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + + lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)] + for i in range(self.language_model.config.num_hidden_layers): + for kv in ["key", "value"]: + lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + + inputs = {} + if kv_offload: + inputs["vision"] = vision_inputs + inputs["lang"] = lang_inputs + else: + lang_inputs.pop("vision_embeds") + inputs = {**vision_inputs, **lang_inputs} + + return inputs + + def get_inputs_info(self): + return [ + IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo( + name="pixel_values", + datatype=torch.float32, + shape=("batch_size", 3, "img_size", "img_size"), + ), + ] diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index ab3e834c8..729d11462 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -769,12 +769,13 @@ def kv_offload_generate( device_ids: List[int] = None, generation_len: int = None, ): - if not self.vision_model.qpc_path or not self.lang_model.qpc_path: - raise TypeError("Please run compile API for vision and language model first!") + if not self.lang_model.qpc_path: + raise TypeError("Please run compile API for language model first!") lang_session = QAICInferenceSession(self.lang_model.qpc_path, device_ids, activate=False) - vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids) + if self.vision_model.qpc_path: + vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids) batch_size, ctx_len, fbs = get_compilation_dims(self.lang_model.qpc_path) @@ -849,7 +850,8 @@ def kv_offload_generate( if not_mllama: lang_inputs["image_idx"] = np.array([[0]]) - vision_session.deactivate() + if self.vision_model.qpc_path: + vision_session.deactivate() lang_session.activate() lang_session.set_buffers(vision_outputs) @@ -859,6 +861,8 @@ def kv_offload_generate( prefill_start = perf_counter() # Run prefill + chunk_inputs = lang_inputs.copy() + chunk_inputs["index"] = np.array([[0]]) for i in range(num_chunks): chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] chunk_inputs["position_ids"] = lang_inputs["position_ids"][ @@ -1087,11 +1091,8 @@ def cloud_ai_100_generate( qpc_session = QAICInferenceSession( self.qpc_path, device_ids, enable_debug_logs=enable_debug_logs, activate=False ) - batch_size, ctx_len, fbs = get_compilation_dims(self.qpc_path) - pad_token_id = 1 - # Skip inputs/outputs qpc_session.skip_buffers( [ @@ -1699,6 +1700,7 @@ def build_decode_specialization( "ctx_len": ctx_len, "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None, } + if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size else: @@ -1785,7 +1787,6 @@ def compile( # --- Specializations --- specializations = [] - if prefill_only is None or prefill_only or prefill_seq_len == 1: specializations.append( self.build_prefill_specialization( @@ -1833,11 +1834,6 @@ def compile( **compiler_options, ) - if compiler_options.get("io_encrypt", None): - logger.warning( - "Compilation for IO-Encrypt has been successfully completed. 
However, Efficient-Transformers do not support IO-Encrypt execution. Please run the execution separately with QPC compiled without io-encrypt."
-            )
-
         return qpc_path
 
     # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
@@ -1890,7 +1886,7 @@ def check_and_get_num_speculative_tokens(self, num_speculative_tokens: Optional[
         elif num_speculative_tokens is None:
             raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` instance variable is True.")
 
-        if not isinstance(num_speculative_tokens, int) and num_speculative_tokens < 2:
-            ValueError(
+        if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2:
+            raise ValueError(
                 f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}"
             )
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index 34d26f104..502ea0456 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -36,6 +36,14 @@
     Gemma2Model,
     Gemma2RMSNorm,
 )
+from transformers.models.gemma3.modeling_gemma3 import (
+    Gemma3Attention,
+    Gemma3DecoderLayer,
+    Gemma3ForCausalLM,
+    Gemma3ForConditionalGeneration,
+    Gemma3RMSNorm,
+    Gemma3TextModel,
+)
 from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model
 from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
     GPTBigCodeAttention,
@@ -171,6 +179,14 @@
     QEffGemma2ForCausalLM,
     QEffGemma2Model,
 )
+from QEfficient.transformers.models.gemma3.modeling_gemma3 import (
+    QEffGemma3Attention,
+    QEffGemma3CustomRMSNormAIC,
+    QEffGemma3DecoderLayer,
+    QEffGemma3ForCausalLMModel,
+    QEffGemma3ForConditionalGeneration,
+    QEffGemma3TextModel,
+)
 from QEfficient.transformers.models.gpt2.modeling_gpt2 import (
     QEffGPT2Attention,
     QEffGPT2Block,
@@ -319,6 +335,7 @@ class CustomOpsTransform(ModuleMappingTransform):
         MllamaTextRMSNorm: CustomRMSNormAIC,
         GraniteRMSNorm: CustomRMSNormAIC,
         GraniteMoeRMSNorm: CustomRMSNormAIC,
+        Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC,
     }
 
 
@@ -373,6 +390,12 @@ class KVCacheTransform(ModuleMappingTransform):
         Gemma2DecoderLayer: QEffGemma2DecoderLayer,
         Gemma2Model: QEffGemma2Model,
         Gemma2ForCausalLM: QEffGemma2ForCausalLM,
+        # Gemma3
+        Gemma3Attention: QEffGemma3Attention,
+        Gemma3DecoderLayer: QEffGemma3DecoderLayer,
+        Gemma3TextModel: QEffGemma3TextModel,
+        Gemma3ForCausalLM: QEffGemma3ForCausalLMModel,
+        Gemma3ForConditionalGeneration: QEffGemma3ForConditionalGeneration,
         # Granite
         GraniteModel: QEffGraniteModel,
         GraniteForCausalLM: QEffGraniteForCausalLM,
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index a0e7d6893..9bda72d4e 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -138,7 +138,7 @@ class QnnConstants:
     # Converter Arguments
     FLOAT_BITWIDTH = 16
     FLOAT_BIAS_BITWIDTH = 32
-    CONVERTER_DEFAULT_ARGS = "--preserve_io_datatype --onnx_skip_simplification "
+    CONVERTER_DEFAULT_ARGS = "--preserve_io_datatype --onnx_skip_simplification --target_backend AIC "
 
     # Context-Binary-Generator Arguments
     LOG_LEVEL = "error"
diff --git a/QEfficient/utils/generate_qnn_network_specialization_config.py b/QEfficient/utils/generate_qnn_network_specialization_config.py
index ab057fd46..14d83efda 100644
--- a/QEfficient/utils/generate_qnn_network_specialization_config.py
+++ b/QEfficient/utils/generate_qnn_network_specialization_config.py
@@ -66,55 +66,44 @@ def generate_qnn_specialization(
                     raise AttributeError(f"ERROR: {input_shape} Shape not Found")
             shapes.append(shape)
 
-        # Filling shape value for nodes with shape size != 2, example: past_key / past_value nodes.
-        if len(shapes) != 2:
+        shape_list = []
+        prefill_decode_shapes = False
+        if len(specializations) > 1 and (node.name in ["input_ids", "position_ids"]):
+            prefill_decode_shapes = True
+        for input_shape in shapes:
+            # If shape contains the parameter string, its value is extracted from the specialization file.
+            if isinstance(input_shape, str):
+                if input_shape in specializations[0]:
+                    shape_list.append(int(specializations[0][input_shape]))
+                    if (
+                        not prefill_decode_shapes
+                        and len(specializations) > 1
+                        and input_shape in specializations[1]
+                        and specializations[0][input_shape] != specializations[1][input_shape]
+                    ):
+                        prefill_decode_shapes = True
+                else:
+                    raise AttributeError(f"ERROR: {input_shape} is required in specializations")
+            # If shape contains the value, then that value is used as-is.
+            else:
+                shape_list.append(input_shape)
+        # Calculated shape is now assigned to the input node.
+        input_info["Shape"] = str(shape_list).replace("[", "(").replace("]", ")")
+
+        if prefill_decode_shapes:
             shape_list = []
             for input_shape in shapes:
                 # If shape contains the parameter string, it value is extracted from the specialization file.
                 if isinstance(input_shape, str):
-                    if input_shape in specializations[0]:
-                        shape_list.append(int(specializations[0][input_shape]))
+                    if input_shape in specializations[1]:
+                        shape_list.append(int(specializations[1][input_shape]))
                     else:
                         raise AttributeError(f"ERROR: {input_shape} is required in specializations")
                 # If shape contains the value, then that value is used as it is.
                 else:
                     shape_list.append(input_shape)
-            # Calculated shape is now assigned to the input node.
-            input_info["Shape"] = str(shape_list).replace("[", "(").replace("]", ")")
-
-        # If shape value for nodes is with shape size == 2, example: input_ids, position_ids, etc.
-        else:
-            shape_list = []
-            for input_shape in shapes:
-                if isinstance(input_shape, str):
-                    if input_shape in specializations[0]:
-                        shape_list.append(int(specializations[0][input_shape]))
-                    else:
-                        raise AttributeError(f"ERROR: {input_shape} is required in specializations")
-                else:
-                    shape_list.append(input_shape)
-            # If specializations file contains more than one parameters list, then first list is used for prefill and second one for decode graph.
-            if len(specializations) > 1:
-                prefill_shape_list = shape_list
-                decode_shape_list = []
-                for input_shape in shapes:
-                    if isinstance(input_shape, str):
-                        if input_shape in specializations[1]:
-                            decode_shape_list.append(int(specializations[1][input_shape]))
-                        else:
-                            raise AttributeError(f"ERROR: {input_shape} is required in specializations")
-                    else:
-                        decode_shape_list.append(input_shape)
-
-                input_info["Shape"] = (
-                    str(prefill_shape_list).replace("[", "(").replace("]", ")")
-                    + ", "
-                    + str(decode_shape_list).replace("[", "(").replace("]", ")")
-                )
-
-            # If specializations file contains only one parameters list, then that list is used for decode graph information.
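+            # Illustrative example (assumed values): for input_ids with ONNX dims
+            # ("batch_size", "seq_len") and specializations
+            # [{"batch_size": 1, "seq_len": 128}, {"batch_size": 1, "seq_len": 1}],
+            # the first pass above sets Shape to "(1, 128)", and the `+=` statement
+            # below appends the decode shape, giving "(1, 128), (1, 1)".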
- else: - input_info["Shape"] = str(shape_list).replace("[", "(").replace("]", ")") + input_info["Shape"] += ", " + str(shape_list).replace("[", "(").replace("]", ")") # Finally, input node is created with its name, and desired model parameters {DataType, Shape} input_nodes_info.append({"Name": node.name, "Desired Model Parameters": input_info}) diff --git a/examples/causal_lm_examples/fp32_nodes_gemma3_27b_image.yaml b/examples/causal_lm_examples/fp32_nodes_gemma3_27b_image.yaml new file mode 100644 index 000000000..b8dc5fe5e --- /dev/null +++ b/examples/causal_lm_examples/fp32_nodes_gemma3_27b_image.yaml @@ -0,0 +1,499 @@ + FP32NodeInstanceNames: + - /language_model/model/layers.0/Add_1_output_0 + - /language_model/model/layers.0/Add_output_0 + - /language_model/model/layers.1/Add_1_output_0 + - /language_model/model/layers.1/Add_output_0 + - /language_model/model/layers.10/Add_1_output_0 + - /language_model/model/layers.10/Add_output_0 + - /language_model/model/layers.11/Add_1_output_0 + - /language_model/model/layers.11/Add_output_0 + - /language_model/model/layers.12/Add_1_output_0 + - /language_model/model/layers.12/Add_output_0 + - /language_model/model/layers.13/Add_1_output_0 + - /language_model/model/layers.13/Add_output_0 + - /language_model/model/layers.14/Add_1_output_0 + - /language_model/model/layers.14/Add_output_0 + - /language_model/model/layers.15/Add_1_output_0 + - /language_model/model/layers.15/Add_output_0 + - /language_model/model/layers.16/Add_1_output_0 + - /language_model/model/layers.16/Add_output_0 + - /language_model/model/layers.17/Add_1_output_0 + - /language_model/model/layers.17/Add_output_0 + - /language_model/model/layers.18/Add_1_output_0 + - /language_model/model/layers.18/Add_output_0 + - /language_model/model/layers.19/Add_1_output_0 + - /language_model/model/layers.19/Add_output_0 + - /language_model/model/layers.2/Add_1_output_0 + - /language_model/model/layers.2/Add_output_0 + - /language_model/model/layers.20/Add_1_output_0 + - /language_model/model/layers.20/Add_output_0 + - /language_model/model/layers.21/Add_1_output_0 + - /language_model/model/layers.21/Add_output_0 + - /language_model/model/layers.22/Add_1_output_0 + - /language_model/model/layers.22/Add_output_0 + - /language_model/model/layers.23/Add_1_output_0 + - /language_model/model/layers.23/Add_output_0 + - /language_model/model/layers.24/Add_1_output_0 + - /language_model/model/layers.24/Add_output_0 + - /language_model/model/layers.25/Add_1_output_0 + - /language_model/model/layers.25/Add_output_0 + - /language_model/model/layers.26/Add_1_output_0 + - /language_model/model/layers.26/Add_output_0 + - /language_model/model/layers.27/Add_1_output_0 + - /language_model/model/layers.27/Add_output_0 + - /language_model/model/layers.28/Add_1_output_0 + - /language_model/model/layers.28/Add_output_0 + - /language_model/model/layers.29/Add_1_output_0 + - /language_model/model/layers.29/Add_output_0 + - /language_model/model/layers.3/Add_1_output_0 + - /language_model/model/layers.3/Add_output_0 + - /language_model/model/layers.30/Add_1_output_0 + - /language_model/model/layers.30/Add_output_0 + - /language_model/model/layers.31/Add_1_output_0 + - /language_model/model/layers.31/Add_output_0 + - /language_model/model/layers.32/Add_1_output_0 + - /language_model/model/layers.32/Add_output_0 + - /language_model/model/layers.33/Add_1_output_0 + - /language_model/model/layers.33/Add_output_0 + - /language_model/model/layers.4/Add_1_output_0 + - /language_model/model/layers.4/Add_output_0 + - 
/language_model/model/layers.5/Add_1_output_0 + - /language_model/model/layers.5/Add_output_0 + - /language_model/model/layers.6/Add_1_output_0 + - /language_model/model/layers.6/Add_output_0 + - /language_model/model/layers.7/Add_1_output_0 + - /language_model/model/layers.7/Add_output_0 + - /language_model/model/layers.8/Add_1_output_0 + - /language_model/model/layers.8/Add_output_0 + - /language_model/model/layers.9/Add_1_output_0 + - /language_model/model/layers.9/Add_output_0 + - /language_model/model/norm/Add_output_0 + - /language_model/model/layers.0/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.0/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.0/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.0/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.0/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.1/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.1/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.1/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.1/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.1/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.10/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.10/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.10/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.10/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.10/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.11/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.11/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.11/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.11/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.11/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.12/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.12/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.12/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.12/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.12/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.13/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.13/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.13/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.13/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.13/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.14/input_layernorm/CustomRMSNorm_output_0 + - 
/language_model/model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.14/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.14/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.14/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.14/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.15/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.15/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.15/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.15/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.15/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.16/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.16/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.16/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.16/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.16/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.17/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.17/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.17/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.17/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.17/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.18/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.18/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.18/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.18/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.18/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.19/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.19/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.19/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.19/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.19/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.2/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.2/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.2/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.2/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.2/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.20/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.20/post_feedforward_layernorm/CustomRMSNorm_output_0 + - 
/language_model/model/layers.20/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.20/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.20/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.21/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.21/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.21/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.21/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.21/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.22/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.22/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.22/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.22/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.22/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.23/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.23/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.23/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.23/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.23/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.24/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.24/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.24/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.24/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.24/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.25/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.25/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.25/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.25/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.25/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.26/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.26/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.26/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.26/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.26/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.27/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.27/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.27/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.27/self_attn/k_norm/CustomRMSNorm_output_0 + - 
/language_model/model/layers.27/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.28/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.28/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.28/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.28/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.28/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.29/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.29/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.29/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.29/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.29/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.3/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.3/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.3/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.3/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.3/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.30/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.30/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.30/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.30/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.30/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.31/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.31/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.31/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.31/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.31/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.32/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.32/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.32/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.32/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.32/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.33/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.33/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.33/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.33/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.33/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.4/input_layernorm/CustomRMSNorm_output_0 + - 
/language_model/model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.4/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.4/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.4/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.4/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.5/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.5/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.5/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.5/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.5/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.6/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.6/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.6/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.6/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.6/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.7/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.7/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.7/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.7/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.7/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.8/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.8/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.8/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.8/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.8/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.9/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.9/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.9/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.9/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.9/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/norm/CustomRMSNorm_output_0 + - /language_model/model/layers.34/Add_1_output_0 + - /language_model/model/layers.34/Add_output_0 + - /language_model/model/layers.34/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.34/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.34/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.34/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.34/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.34/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.35/Add_1_output_0 + - /language_model/model/layers.35/Add_output_0 + - 
/language_model/model/layers.35/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.35/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.35/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.35/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.35/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.35/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.36/Add_1_output_0 + - /language_model/model/layers.36/Add_output_0 + - /language_model/model/layers.36/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.36/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.36/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.36/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.36/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.36/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.37/Add_1_output_0 + - /language_model/model/layers.37/Add_output_0 + - /language_model/model/layers.37/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.37/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.37/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.37/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.37/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.37/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.38/Add_1_output_0 + - /language_model/model/layers.38/Add_output_0 + - /language_model/model/layers.38/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.38/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.38/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.38/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.38/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.38/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.39/Add_1_output_0 + - /language_model/model/layers.39/Add_output_0 + - /language_model/model/layers.39/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.39/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.39/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.39/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.39/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.39/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.40/Add_1_output_0 + - /language_model/model/layers.40/Add_output_0 + - /language_model/model/layers.40/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.40/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.40/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.40/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.40/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.40/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.41/Add_1_output_0 + - /language_model/model/layers.41/Add_output_0 + - 
/language_model/model/layers.41/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.41/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.41/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.41/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.41/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.41/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.42/Add_1_output_0 + - /language_model/model/layers.42/Add_output_0 + - /language_model/model/layers.42/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.42/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.42/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.42/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.42/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.42/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.43/Add_1_output_0 + - /language_model/model/layers.43/Add_output_0 + - /language_model/model/layers.43/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.43/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.43/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.43/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.43/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.43/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.44/Add_1_output_0 + - /language_model/model/layers.44/Add_output_0 + - /language_model/model/layers.44/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.44/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.44/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.44/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.44/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.44/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.45/Add_1_output_0 + - /language_model/model/layers.45/Add_output_0 + - /language_model/model/layers.45/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.45/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.45/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.45/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.45/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.45/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.46/Add_1_output_0 + - /language_model/model/layers.46/Add_output_0 + - /language_model/model/layers.46/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.46/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.46/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.46/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.46/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.46/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.47/Add_1_output_0 + - /language_model/model/layers.47/Add_output_0 + - 
/language_model/model/layers.47/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.47/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.47/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.47/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.47/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.47/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.48/Add_1_output_0 + - /language_model/model/layers.48/Add_output_0 + - /language_model/model/layers.48/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.48/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.48/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.48/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.48/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.48/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.49/Add_1_output_0 + - /language_model/model/layers.49/Add_output_0 + - /language_model/model/layers.49/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.49/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.49/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.49/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.49/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.49/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.50/Add_1_output_0 + - /language_model/model/layers.50/Add_output_0 + - /language_model/model/layers.50/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.50/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.50/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.50/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.50/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.50/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.51/Add_1_output_0 + - /language_model/model/layers.51/Add_output_0 + - /language_model/model/layers.51/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.51/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.51/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.51/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.51/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.51/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.52/Add_1_output_0 + - /language_model/model/layers.52/Add_output_0 + - /language_model/model/layers.52/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.52/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.52/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.52/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.52/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.52/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.53/Add_1_output_0 + - /language_model/model/layers.53/Add_output_0 + - 
/language_model/model/layers.53/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.53/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.53/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.53/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.53/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.53/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.54/Add_1_output_0 + - /language_model/model/layers.54/Add_output_0 + - /language_model/model/layers.54/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.54/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.54/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.54/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.54/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.54/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.55/Add_1_output_0 + - /language_model/model/layers.55/Add_output_0 + - /language_model/model/layers.55/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.55/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.55/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.55/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.55/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.55/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.56/Add_1_output_0 + - /language_model/model/layers.56/Add_output_0 + - /language_model/model/layers.56/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.56/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.56/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.56/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.56/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.56/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.57/Add_1_output_0 + - /language_model/model/layers.57/Add_output_0 + - /language_model/model/layers.57/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.57/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.57/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.57/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.57/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.57/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.58/Add_1_output_0 + - /language_model/model/layers.58/Add_output_0 + - /language_model/model/layers.58/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.58/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.58/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.58/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.58/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.58/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.59/Add_1_output_0 + - /language_model/model/layers.59/Add_output_0 + - 
/language_model/model/layers.59/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.59/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.59/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.59/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.59/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.59/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.60/Add_1_output_0 + - /language_model/model/layers.60/Add_output_0 + - /language_model/model/layers.60/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.60/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.60/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.60/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.60/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.60/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.61/Add_1_output_0 + - /language_model/model/layers.61/Add_output_0 + - /language_model/model/layers.61/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.61/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.61/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.61/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.61/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.61/self_attn/q_norm/CustomRMSNorm_output_0 \ No newline at end of file diff --git a/examples/causal_lm_examples/fp32_nodes_gemma3_27b_text.yaml b/examples/causal_lm_examples/fp32_nodes_gemma3_27b_text.yaml new file mode 100644 index 000000000..2d1fc6763 --- /dev/null +++ b/examples/causal_lm_examples/fp32_nodes_gemma3_27b_text.yaml @@ -0,0 +1,500 @@ +FP32NodeInstanceNames: + - /model/layers.0/Add_1_output_0 + - /model/layers.0/Add_output_0 + - /model/layers.1/Add_1_output_0 + - /model/layers.1/Add_output_0 + - /model/layers.2/Add_1_output_0 + - /model/layers.2/Add_output_0 + - /model/layers.3/Add_1_output_0 + - /model/layers.3/Add_output_0 + - /model/layers.4/Add_1_output_0 + - /model/layers.4/Add_output_0 + - /model/layers.5/Add_1_output_0 + - /model/layers.5/Add_output_0 + - /model/layers.6/Add_1_output_0 + - /model/layers.6/Add_output_0 + - /model/layers.7/Add_1_output_0 + - /model/layers.7/Add_output_0 + - /model/layers.8/Add_1_output_0 + - /model/layers.8/Add_output_0 + - /model/layers.9/Add_1_output_0 + - /model/layers.9/Add_output_0 + - /model/layers.10/Add_1_output_0 + - /model/layers.10/Add_output_0 + - /model/layers.11/Add_1_output_0 + - /model/layers.11/Add_output_0 + - /model/layers.12/Add_1_output_0 + - /model/layers.12/Add_output_0 + - /model/layers.13/Add_1_output_0 + - /model/layers.13/Add_output_0 + - /model/layers.14/Add_1_output_0 + - /model/layers.14/Add_output_0 + - /model/layers.15/Add_1_output_0 + - /model/layers.15/Add_output_0 + - /model/layers.16/Add_1_output_0 + - /model/layers.16/Add_output_0 + - /model/layers.17/Add_1_output_0 + - /model/layers.17/Add_output_0 + - /model/layers.18/Add_1_output_0 + - /model/layers.18/Add_output_0 + - /model/layers.19/Add_1_output_0 + - /model/layers.19/Add_output_0 + - /model/layers.20/Add_1_output_0 + - /model/layers.20/Add_output_0 + - /model/layers.21/Add_1_output_0 + - /model/layers.21/Add_output_0 + - /model/layers.22/Add_1_output_0 + 
- /model/layers.22/Add_output_0 + - /model/layers.23/Add_1_output_0 + - /model/layers.23/Add_output_0 + - /model/layers.24/Add_1_output_0 + - /model/layers.24/Add_output_0 + - /model/layers.25/Add_1_output_0 + - /model/layers.25/Add_output_0 + - /model/layers.26/Add_1_output_0 + - /model/layers.26/Add_output_0 + - /model/layers.27/Add_1_output_0 + - /model/layers.27/Add_output_0 + - /model/layers.28/Add_1_output_0 + - /model/layers.28/Add_output_0 + - /model/layers.29/Add_1_output_0 + - /model/layers.29/Add_output_0 + - /model/layers.30/Add_1_output_0 + - /model/layers.30/Add_output_0 + - /model/layers.31/Add_1_output_0 + - /model/layers.31/Add_output_0 + - /model/layers.32/Add_1_output_0 + - /model/layers.32/Add_output_0 + - /model/layers.33/Add_1_output_0 + - /model/layers.33/Add_output_0 + - /model/layers.34/Add_1_output_0 + - /model/layers.34/Add_output_0 + - /model/layers.35/Add_1_output_0 + - /model/layers.35/Add_output_0 + - /model/layers.36/Add_1_output_0 + - /model/layers.36/Add_output_0 + - /model/layers.37/Add_1_output_0 + - /model/layers.37/Add_output_0 + - /model/layers.38/Add_1_output_0 + - /model/layers.38/Add_output_0 + - /model/layers.39/Add_1_output_0 + - /model/layers.39/Add_output_0 + - /model/layers.40/Add_1_output_0 + - /model/layers.40/Add_output_0 + - /model/layers.41/Add_1_output_0 + - /model/layers.41/Add_output_0 + - /model/layers.42/Add_1_output_0 + - /model/layers.42/Add_output_0 + - /model/layers.43/Add_1_output_0 + - /model/layers.43/Add_output_0 + - /model/layers.44/Add_1_output_0 + - /model/layers.44/Add_output_0 + - /model/layers.45/Add_1_output_0 + - /model/layers.45/Add_output_0 + - /model/layers.46/Add_1_output_0 + - /model/layers.46/Add_output_0 + - /model/layers.47/Add_1_output_0 + - /model/layers.47/Add_output_0 + - /model/layers.48/Add_1_output_0 + - /model/layers.48/Add_output_0 + - /model/layers.49/Add_1_output_0 + - /model/layers.49/Add_output_0 + - /model/layers.50/Add_1_output_0 + - /model/layers.50/Add_output_0 + - /model/layers.51/Add_1_output_0 + - /model/layers.51/Add_output_0 + - /model/layers.52/Add_1_output_0 + - /model/layers.52/Add_output_0 + - /model/layers.53/Add_1_output_0 + - /model/layers.53/Add_output_0 + - /model/layers.54/Add_1_output_0 + - /model/layers.54/Add_output_0 + - /model/layers.55/Add_1_output_0 + - /model/layers.55/Add_output_0 + - /model/layers.56/Add_1_output_0 + - /model/layers.56/Add_output_0 + - /model/layers.57/Add_1_output_0 + - /model/layers.57/Add_output_0 + - /model/layers.58/Add_1_output_0 + - /model/layers.58/Add_output_0 + - /model/layers.59/Add_1_output_0 + - /model/layers.59/Add_output_0 + - /model/layers.60/Add_1_output_0 + - /model/layers.60/Add_output_0 + - /model/layers.61/Add_1_output_0 + - /model/layers.61/Add_output_0 + - /model/norm/Add_output_0 + - /model/layers.0/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.0/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.0/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.0/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.0/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.1/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.1/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.1/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.1/self_attn/k_norm/CustomRMSNorm_output_0 + - 
/model/layers.1/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.2/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.2/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.2/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.2/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.2/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.3/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.3/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.3/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.3/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.3/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.4/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.4/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.4/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.4/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.4/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.5/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.5/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.5/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.5/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.5/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.6/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.6/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.6/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.6/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.6/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.7/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.7/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.7/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.7/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.7/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.8/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.8/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.8/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.8/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.8/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.9/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.9/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.9/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.9/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.9/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.10/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.10/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.10/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.10/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.10/self_attn/q_norm/CustomRMSNorm_output_0 + - 
/model/layers.11/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.11/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.11/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.11/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.11/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.12/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.12/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.12/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.12/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.12/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.13/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.13/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.13/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.13/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.13/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.14/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.14/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.14/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.14/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.14/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.15/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.15/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.15/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.15/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.15/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.16/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.16/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.16/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.16/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.16/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.17/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.17/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.17/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.17/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.17/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.18/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.18/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.18/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.18/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.18/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.19/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.19/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.19/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.19/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.19/self_attn/q_norm/CustomRMSNorm_output_0 + - 
/model/layers.20/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.20/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.20/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.20/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.20/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.21/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.21/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.21/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.21/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.21/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.22/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.22/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.22/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.22/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.22/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.23/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.23/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.23/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.23/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.23/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.24/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.24/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.24/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.24/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.24/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.25/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.25/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.25/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.25/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.25/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.26/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.26/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.26/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.26/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.26/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.27/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.27/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.27/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.27/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.27/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.28/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.28/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.28/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.28/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.28/self_attn/q_norm/CustomRMSNorm_output_0 + - 
/model/layers.29/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.29/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.29/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.29/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.29/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.30/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.30/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.30/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.30/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.30/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.31/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.31/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.31/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.31/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.31/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.32/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.32/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.32/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.32/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.32/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.33/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.33/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.33/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.33/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.33/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.34/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.34/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.34/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.34/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.34/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.34/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.35/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.35/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.35/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.35/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.35/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.35/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.36/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.36/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.36/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.36/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.36/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.36/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.37/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.37/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.37/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.37/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.37/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.37/self_attn/q_norm/CustomRMSNorm_output_0 + - 
/model/layers.38/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.38/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.38/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.38/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.38/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.38/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.39/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.39/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.39/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.39/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.39/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.39/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.40/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.40/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.40/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.40/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.40/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.40/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.41/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.41/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.41/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.41/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.41/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.41/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.42/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.42/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.42/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.42/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.42/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.42/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.43/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.43/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.43/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.43/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.43/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.43/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.44/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.44/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.44/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.44/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.44/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.44/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.45/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.45/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.45/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.45/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.45/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.45/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.46/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.46/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.46/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.46/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.46/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.46/self_attn/q_norm/CustomRMSNorm_output_0 + - 
/model/layers.47/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.47/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.47/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.47/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.47/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.47/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.48/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.48/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.48/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.48/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.48/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.48/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.49/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.49/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.49/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.49/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.49/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.49/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.50/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.50/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.50/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.50/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.50/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.50/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.51/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.51/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.51/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.51/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.51/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.51/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.52/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.52/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.52/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.52/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.52/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.52/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.53/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.53/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.53/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.53/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.53/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.53/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.54/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.54/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.54/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.54/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.54/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.54/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.55/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.55/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.55/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.55/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.55/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.55/self_attn/q_norm/CustomRMSNorm_output_0 + - 
/model/layers.56/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.56/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.56/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.56/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.56/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.56/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.57/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.57/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.57/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.57/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.57/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.57/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.58/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.58/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.58/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.58/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.58/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.58/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.59/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.59/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.59/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.59/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.59/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.59/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.60/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.60/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.60/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.60/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.60/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.60/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/layers.61/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.61/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.61/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.61/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /model/layers.61/self_attn/k_norm/CustomRMSNorm_output_0 + - /model/layers.61/self_attn/q_norm/CustomRMSNorm_output_0 + - /model/norm/CustomRMSNorm_output_0 + diff --git a/examples/causal_lm_examples/fp32_nodes_gemma3_4b_image.yaml b/examples/causal_lm_examples/fp32_nodes_gemma3_4b_image.yaml new file mode 100644 index 000000000..22b7b9636 --- /dev/null +++ b/examples/causal_lm_examples/fp32_nodes_gemma3_4b_image.yaml @@ -0,0 +1,275 @@ +FP32NodeInstanceNames: + - /language_model/model/layers.0/Add_1_output_0 + - /language_model/model/layers.0/Add_output_0 + - /language_model/model/layers.1/Add_1_output_0 + - /language_model/model/layers.1/Add_output_0 + - /language_model/model/layers.10/Add_1_output_0 + - /language_model/model/layers.10/Add_output_0 + - /language_model/model/layers.11/Add_1_output_0 + - /language_model/model/layers.11/Add_output_0 + - /language_model/model/layers.12/Add_1_output_0 + - /language_model/model/layers.12/Add_output_0 + - /language_model/model/layers.13/Add_1_output_0 + - /language_model/model/layers.13/Add_output_0 + - /language_model/model/layers.14/Add_1_output_0 + - /language_model/model/layers.14/Add_output_0 + - /language_model/model/layers.15/Add_1_output_0 + - /language_model/model/layers.15/Add_output_0 + - /language_model/model/layers.16/Add_1_output_0 + - 
/language_model/model/layers.16/Add_output_0 + - /language_model/model/layers.17/Add_1_output_0 + - /language_model/model/layers.17/Add_output_0 + - /language_model/model/layers.18/Add_1_output_0 + - /language_model/model/layers.18/Add_output_0 + - /language_model/model/layers.19/Add_1_output_0 + - /language_model/model/layers.19/Add_output_0 + - /language_model/model/layers.2/Add_1_output_0 + - /language_model/model/layers.2/Add_output_0 + - /language_model/model/layers.20/Add_1_output_0 + - /language_model/model/layers.20/Add_output_0 + - /language_model/model/layers.21/Add_1_output_0 + - /language_model/model/layers.21/Add_output_0 + - /language_model/model/layers.22/Add_1_output_0 + - /language_model/model/layers.22/Add_output_0 + - /language_model/model/layers.23/Add_1_output_0 + - /language_model/model/layers.23/Add_output_0 + - /language_model/model/layers.24/Add_1_output_0 + - /language_model/model/layers.24/Add_output_0 + - /language_model/model/layers.25/Add_1_output_0 + - /language_model/model/layers.25/Add_output_0 + - /language_model/model/layers.26/Add_1_output_0 + - /language_model/model/layers.26/Add_output_0 + - /language_model/model/layers.27/Add_1_output_0 + - /language_model/model/layers.27/Add_output_0 + - /language_model/model/layers.28/Add_1_output_0 + - /language_model/model/layers.28/Add_output_0 + - /language_model/model/layers.29/Add_1_output_0 + - /language_model/model/layers.29/Add_output_0 + - /language_model/model/layers.3/Add_1_output_0 + - /language_model/model/layers.3/Add_output_0 + - /language_model/model/layers.30/Add_1_output_0 + - /language_model/model/layers.30/Add_output_0 + - /language_model/model/layers.31/Add_1_output_0 + - /language_model/model/layers.31/Add_output_0 + - /language_model/model/layers.32/Add_1_output_0 + - /language_model/model/layers.32/Add_output_0 + - /language_model/model/layers.33/Add_1_output_0 + - /language_model/model/layers.33/Add_output_0 + - /language_model/model/layers.4/Add_1_output_0 + - /language_model/model/layers.4/Add_output_0 + - /language_model/model/layers.5/Add_1_output_0 + - /language_model/model/layers.5/Add_output_0 + - /language_model/model/layers.6/Add_1_output_0 + - /language_model/model/layers.6/Add_output_0 + - /language_model/model/layers.7/Add_1_output_0 + - /language_model/model/layers.7/Add_output_0 + - /language_model/model/layers.8/Add_1_output_0 + - /language_model/model/layers.8/Add_output_0 + - /language_model/model/layers.9/Add_1_output_0 + - /language_model/model/layers.9/Add_output_0 + - /language_model/model/norm/Add_output_0 + - /language_model/model/layers.0/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.0/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.0/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.0/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.0/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.1/input_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.1/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.1/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/model/layers.1/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/model/layers.1/self_attn/q_norm/CustomRMSNorm_output_0 + - 
/language_model/model/layers.10/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.10/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.10/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.10/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.10/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.11/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.11/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.11/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.11/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.11/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.12/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.12/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.12/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.12/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.12/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.13/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.13/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.13/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.13/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.13/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.14/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.14/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.14/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.14/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.14/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.15/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.15/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.15/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.15/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.15/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.16/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.16/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.16/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.16/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.16/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.17/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.17/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.17/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.17/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.17/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.18/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.18/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.18/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.18/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.18/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.19/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.19/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.19/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.19/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.19/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.2/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.2/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.2/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.2/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.2/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.20/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.20/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.20/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.20/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.20/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.21/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.21/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.21/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.21/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.21/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.22/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.22/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.22/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.22/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.22/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.23/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.23/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.23/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.23/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.23/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.24/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.24/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.24/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.24/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.24/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.25/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.25/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.25/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.25/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.25/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.26/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.26/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.26/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.26/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.26/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.27/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.27/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.27/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.27/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.27/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.28/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.28/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.28/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.28/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.28/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.29/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.29/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.29/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.29/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.29/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.3/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.3/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.3/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.3/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.3/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.30/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.30/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.30/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.30/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.30/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.31/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.31/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.31/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.31/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.31/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.32/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.32/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.32/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.32/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.32/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.33/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.33/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.33/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.33/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.33/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.4/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.4/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.4/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.4/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.4/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.5/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.5/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.5/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.5/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.5/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.6/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.6/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.6/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.6/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.6/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.7/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.7/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.7/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.7/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.7/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.8/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.8/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.8/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.8/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.8/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.9/input_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.9/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.9/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /language_model/model/layers.9/self_attn/k_norm/CustomRMSNorm_output_0
+  - /language_model/model/layers.9/self_attn/q_norm/CustomRMSNorm_output_0
+  - /language_model/model/norm/CustomRMSNorm_output_0
\ No newline at end of file
diff --git a/examples/causal_lm_examples/fp32_nodes_gemma3_4b_text.yaml b/examples/causal_lm_examples/fp32_nodes_gemma3_4b_text.yaml
new file mode 100644
index 000000000..494486e68
--- /dev/null
+++ b/examples/causal_lm_examples/fp32_nodes_gemma3_4b_text.yaml
@@ -0,0 +1,275 @@
+FP32NodeInstanceNames:
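+  # Note (explanatory comment, not in the upstream file): the node names below
+  # keep the accuracy-sensitive ops of every Gemma3 decoder layer in FP32 when
+  # the rest of the graph is compiled at reduced precision - the two residual
+  # Add outputs and all CustomRMSNorm outputs per layer, plus the final norm.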
+  - /model/layers.0/Add_1_output_0
+  - /model/layers.0/Add_output_0
+  - /model/layers.1/Add_1_output_0
+  - /model/layers.1/Add_output_0
+  - /model/layers.10/Add_1_output_0
+  - /model/layers.10/Add_output_0
+  - /model/layers.11/Add_1_output_0
+  - /model/layers.11/Add_output_0
+  - /model/layers.12/Add_1_output_0
+  - /model/layers.12/Add_output_0
+  - /model/layers.13/Add_1_output_0
+  - /model/layers.13/Add_output_0
+  - /model/layers.14/Add_1_output_0
+  - /model/layers.14/Add_output_0
+  - /model/layers.15/Add_1_output_0
+  - /model/layers.15/Add_output_0
+  - /model/layers.16/Add_1_output_0
+  - /model/layers.16/Add_output_0
+  - /model/layers.17/Add_1_output_0
+  - /model/layers.17/Add_output_0
+  - /model/layers.18/Add_1_output_0
+  - /model/layers.18/Add_output_0
+  - /model/layers.19/Add_1_output_0
+  - /model/layers.19/Add_output_0
+  - /model/layers.2/Add_1_output_0
+  - /model/layers.2/Add_output_0
+  - /model/layers.20/Add_1_output_0
+  - /model/layers.20/Add_output_0
+  - /model/layers.21/Add_1_output_0
+  - /model/layers.21/Add_output_0
+  - /model/layers.22/Add_1_output_0
+  - /model/layers.22/Add_output_0
+  - /model/layers.23/Add_1_output_0
+  - /model/layers.23/Add_output_0
+  - /model/layers.24/Add_1_output_0
+  - /model/layers.24/Add_output_0
+  - /model/layers.25/Add_1_output_0
+  - /model/layers.25/Add_output_0
+  - /model/layers.26/Add_1_output_0
+  - /model/layers.26/Add_output_0
+  - /model/layers.27/Add_1_output_0
+  - /model/layers.27/Add_output_0
+  - /model/layers.28/Add_1_output_0
+  - /model/layers.28/Add_output_0
+  - /model/layers.29/Add_1_output_0
+  - /model/layers.29/Add_output_0
+  - /model/layers.3/Add_1_output_0
+  - /model/layers.3/Add_output_0
+  - /model/layers.30/Add_1_output_0
+  - /model/layers.30/Add_output_0
+  - /model/layers.31/Add_1_output_0
+  - /model/layers.31/Add_output_0
+  - /model/layers.32/Add_1_output_0
+  - /model/layers.32/Add_output_0
+  - /model/layers.33/Add_1_output_0
+  - /model/layers.33/Add_output_0
+  - /model/layers.4/Add_1_output_0
+  - /model/layers.4/Add_output_0
+  - /model/layers.5/Add_1_output_0
+  - /model/layers.5/Add_output_0
+  - /model/layers.6/Add_1_output_0
+  - /model/layers.6/Add_output_0
+  - /model/layers.7/Add_1_output_0
+  - /model/layers.7/Add_output_0
+  - /model/layers.8/Add_1_output_0
+  - /model/layers.8/Add_output_0
+  - /model/layers.9/Add_1_output_0
+  - /model/layers.9/Add_output_0
+  - /model/norm/Add_output_0
+  - /model/layers.0/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.0/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.0/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.0/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.0/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.1/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.1/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.1/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.1/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.1/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.10/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.10/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.10/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.10/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.10/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.11/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.11/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.11/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.11/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.11/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.12/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.12/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.12/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.12/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.12/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.13/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.13/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.13/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.13/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.13/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.14/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.14/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.14/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.14/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.14/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.15/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.15/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.15/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.15/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.15/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.16/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.16/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.16/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.16/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.16/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.17/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.17/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.17/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.17/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.17/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.18/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.18/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.18/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.18/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.18/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.19/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.19/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.19/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.19/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.19/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.2/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.2/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.2/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.2/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.2/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.20/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.20/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.20/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.20/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.20/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.21/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.21/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.21/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.21/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.21/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.22/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.22/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.22/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.22/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.22/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.23/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.23/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.23/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.23/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.23/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.24/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.24/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.24/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.24/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.24/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.25/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.25/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.25/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.25/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.25/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.26/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.26/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.26/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.26/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.26/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.27/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.27/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.27/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.27/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.27/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.28/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.28/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.28/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.28/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.28/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.29/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.29/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.29/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.29/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.29/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.3/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.3/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.3/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.3/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.3/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.30/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.30/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.30/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.30/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.30/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.31/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.31/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.31/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.31/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.31/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.32/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.32/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.32/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.32/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.32/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.33/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.33/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.33/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.33/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.33/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.4/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.4/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.4/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.4/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.4/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.5/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.5/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.5/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.5/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.5/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.6/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.6/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.6/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.6/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.6/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.7/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.7/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.7/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.7/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.7/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.8/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.8/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.8/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.8/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.8/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/layers.9/input_layernorm/CustomRMSNorm_output_0
+  - /model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0
+  - /model/layers.9/post_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.9/pre_feedforward_layernorm/CustomRMSNorm_output_0
+  - /model/layers.9/self_attn/k_norm/CustomRMSNorm_output_0
+  - /model/layers.9/self_attn/q_norm/CustomRMSNorm_output_0
+  - /model/norm/CustomRMSNorm_output_0
\ No newline at end of file
diff --git a/examples/causal_lm_examples/gemma3_mm.py b/examples/causal_lm_examples/gemma3_mm.py
new file mode 100644
index 000000000..761f43627
--- /dev/null
+++ b/examples/causal_lm_examples/gemma3_mm.py
@@ -0,0 +1,69 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+import transformers
+from transformers import AutoConfig, AutoModelForImageTextToText, AutoProcessor, TextStreamer
+
+from QEfficient import QEFFAutoModelForImageTextToText
+
+model_id = "google/gemma-3-4b-it"
+config = AutoConfig.from_pretrained(model_id)
+# For testing purposes only:
+# config.text_config.num_hidden_layers = 1
+# config.vision_config.num_hidden_layers = 2
+
+model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager", config=config)
+model.eval()
+
+qeff_model = QEFFAutoModelForImageTextToText(model, kv_offload=True)
+# TODO: Map the Vision Encoder to FP16 Only and Disable MXFP6 For Better Accuracy.
+qeff_model.compile(
+    prefill_seq_len=128,
+    ctx_len=3072,
+    img_size=896,
+    num_cores=16,
+    num_devices=1,
+    mxfp6_matmul=False,
+    mxint8_kv_cache=False,
+    aic_enable_depth_first=True,
+    mos=1,
+    node_precision_info="fp32_nodes_gemma3_4b_image.yaml",
+)
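+
+# Note (explanatory comment, assumptions about the QEfficient API): ctx_len
+# must cover the prompt (text plus image tokens) and all generated tokens, and
+# img_size=896 is assumed to match the Gemma3 vision tower's input resolution.
+# node_precision_info points at the FP32 node list shipped with these examples
+# so the RMSNorm and residual-add outputs listed there stay in full precision.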
+
+image_url = (
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
+)
+
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image", "url": image_url},
+            {"type": "text", "text": "Can you describe the image in detail."},
+        ],
+    },
+]
+
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+processor = AutoProcessor.from_pretrained(model_id)
+inputs = processor.apply_chat_template(
+    messages,
+    add_generation_prompt=True,
+    tokenize=True,
+    return_dict=True,
+    return_tensors="pt",
+)
+for key, value in inputs.items():
+    print(f"key: {key}, value shape: {value.shape}")
+
+inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
+streamer = TextStreamer(tokenizer)
+output = qeff_model.generate(inputs=inputs, device_ids=[0], generation_len=100)
+print(output.generated_ids)
+print(tokenizer.batch_decode(output.generated_ids))
+print(output)
diff --git a/examples/causal_lm_examples/gemma3_text.py b/examples/causal_lm_examples/gemma3_text.py
new file mode 100644
index 000000000..bc39146b2
--- /dev/null
+++ b/examples/causal_lm_examples/gemma3_text.py
@@ -0,0 +1,49 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+from transformers import Gemma3ForCausalLM
+from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm
+
+from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient.utils._utils import load_hf_tokenizer
+from QEfficient.utils.constants import Constants
+
+
+def add_named_scopes(model):
+    for name, module in model.named_modules():
+        if isinstance(module, Gemma3RMSNorm):
+            module._onnx_scope_name = f"/{name}"
+
+
+torch.manual_seed(42)
+model_id = "google/gemma-3-4b-it"
+model = Gemma3ForCausalLM.from_pretrained(
+    model_id, torch_dtype=torch.float32, use_cache=True, attn_implementation="eager"
+)
+model.eval()
+
+tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_id)
+qeff_model = QEFFAutoModelForCausalLM(model, continuous_batching=True)
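+# Note (explanatory comment): with continuous_batching=True the compile call
+# below is assumed to also require full_batch_size (set to 1 here).
+# add_named_scopes() above is an optional helper that tags Gemma3RMSNorm
+# modules with ONNX scope names; this script defines it but does not call it.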
+
+onnx_model_path = qeff_model.export()
+
+qpc_path = qeff_model.compile(
+    prefill_seq_len=Constants.PROMPT_LEN,
+    ctx_len=Constants.CTX_LEN,
+    num_cores=16,
+    mxfp6_matmul=False,
+    mxint8_kv_cache=False,
+    num_devices=1,
+    mos=1,
+    full_batch_size=1,
+    aic_enable_depth_first=True,
+    num_speculative_tokens=None,
+    node_precision_info="fp32_nodes_gemma3_4b_text.yaml",
+)
+print(f"QPC path: {qpc_path}")
+exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR, device_ids=[0])
diff --git a/examples/llama4_example.py b/examples/llama4_example.py
index ad271de99..a9ed37a38 100644
--- a/examples/llama4_example.py
+++ b/examples/llama4_example.py
@@ -19,77 +19,93 @@
 model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager", config=config)
 model.eval()
-
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+processor = AutoProcessor.from_pretrained(model_id)
 qeff_model = QEFFAutoModelForImageTextToText(model, kv_offload=True)
-# TODO: Map the Vision Encoder to FP16 Only and Disable MXFP6 For Better Accuracy.
-qeff_model.compile(
-    prefill_seq_len=128,
-    ctx_len=3072,
-    img_size=336,
-    num_cores=16,
-    num_devices=8,
-    max_num_tiles=17,
-    mxfp6_matmul=True,
-    mxint8_kv_cache=True,
-    aic_enable_depth_first=True,
-    mos=1,
-)
+### Set skip_vision=True to run text-only inference; set it to False to run image + text ###
+skip_vision = True
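+# When skip_vision is True, only the language model is compiled and executed
+# (note the extra skip_vision=True compile flag below), so prompts must be
+# text-only. Set it to False to also compile the vision encoder and run the
+# image + text path.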
 
-### IMAGE + TEXT ###
+if skip_vision:
+    ## Only Text ##
+    qeff_model.compile(
+        prefill_seq_len=128,
+        ctx_len=3072,
+        img_size=336,
+        num_cores=16,
+        num_devices=8,
+        max_num_tiles=17,
+        mxfp6_matmul=True,
+        mxint8_kv_cache=True,
+        aic_enable_depth_first=True,
+        skip_vision=True,
+        mos=1,
+    )
 
-image_url = (
-    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
-)
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Tell me about yourself."},
+            ],
+        },
+    ]
 
-messages = [
-    {
-        "role": "user",
-        "content": [
-            {"type": "image", "url": image_url},
-            {"type": "text", "text": "Can you describe the image in detail."},
-        ],
-    },
-]
+    inputs = processor.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    )
 
-tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-processor = AutoProcessor.from_pretrained(model_id)
-inputs = processor.apply_chat_template(
-    messages,
-    add_generation_prompt=True,
-    tokenize=True,
-    return_dict=True,
-    return_tensors="pt",
-)
-inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
-streamer = TextStreamer(tokenizer)
-output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3, 4, 5, 6, 7], generation_len=100)
-print(output.generated_ids)
-print(tokenizer.batch_decode(output.generated_ids))
-print(output)
-print()
+    streamer = TextStreamer(tokenizer)
+    output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3, 4, 5, 6, 7], generation_len=100)
+    print(output.generated_ids)
+    print(tokenizer.batch_decode(output.generated_ids))
+    print(output)
 
-### ONLY TEXT ###
+else:
+    ## Vision + Text ##
+    qeff_model.compile(
+        prefill_seq_len=128,
+        ctx_len=3072,
+        img_size=336,
+        num_cores=16,
+        num_devices=8,
+        max_num_tiles=17,
+        mxfp6_matmul=True,
+        mxint8_kv_cache=True,
+        aic_enable_depth_first=True,
+        mos=1,
+    )
 
-messages = [
-    {
-        "role": "user",
-        "content": [
-            {"type": "text", "text": "Tell me about yourself."},
-        ],
-    },
-]
+    ### IMAGE + TEXT ###
+    image_url = (
+        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
+    )
 
-inputs = processor.apply_chat_template(
-    messages,
-    add_generation_prompt=True,
-    tokenize=True,
-    return_dict=True,
-    return_tensors="pt",
-)
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "url": image_url},
+                {"type": "text", "text": "Can you describe the image in detail."},
+            ],
+        },
+    ]
 
-streamer = TextStreamer(tokenizer)
-output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3, 4, 5, 6, 7], generation_len=100)
-print(output.generated_ids)
-print(tokenizer.batch_decode(output.generated_ids))
-print(output)
+    inputs = processor.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    )
+    inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
+    streamer = TextStreamer(tokenizer)
+    output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3, 4, 5, 6, 7], generation_len=100)
+    print(output.generated_ids)
+    print(tokenizer.batch_decode(output.generated_ids))
+    print(output)
+    print()
diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile
index 4dae9507a..e6a69d5fb 100644
--- a/scripts/Jenkinsfile
+++ b/scripts/Jenkinsfile
@@ -13,7 +13,7 @@ pipeline {
         steps {
           sh '''
             . ~/.bashrc
-            sudo docker run --privileged -dit --name ${BUILD_TAG} -v ./:/efficient-transformers -v ${HF_PATH}:${DOCKER_HF_PATH} ${DOCKER_LATEST}:master_latest
+            sudo docker run --privileged -dit --name ${BUILD_TAG} -e HF_TOKEN=${HF_TOKEN} -v ./:/efficient-transformers -v ${HF_PATH}:${DOCKER_HF_PATH} ${DOCKER_LATEST}:master_latest
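+            # HF_TOKEN is forwarded into the container so that gated
+            # checkpoints (e.g. google/gemma-3-4b-it) can be downloaded
+            # during the test run.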
             sudo docker exec ${BUILD_TAG} bash -c "
             cd /efficient-transformers &&
             apt update &&
@@ -106,7 +106,7 @@ pipeline {
     stage('vLLM Tests') {
       steps {
        catchError(buildResult: 'FAILURE', stageResult: 'FAILURE') {
-          build job: 'qefficient_vllm_upstream',
+          build job: 'qefficient_vllm_upstream',
           parameters: [string(name: 'NAME', value: "${BUILD_TAG}")],
           propagate: true,
           wait: true
@@ -144,13 +144,32 @@
             mkdir -p $PWD/Qnn_non_cli &&
             export TOKENIZERS_PARALLELISM=false &&
             export QEFF_HOME=$PWD/Qnn_non_cli &&
-            pytest tests -m '(not cli) and (qnn) and (on_qaic)' --ignore tests/vllm --junitxml=tests/tests_log5.xml &&
+            pytest tests -m '(not cli) and (qnn) and (on_qaic) and (not multimodal)' --ignore tests/vllm --junitxml=tests/tests_log5.xml &&
             junitparser merge tests/tests_log5.xml tests/tests_log.xml &&
             deactivate"
          '''
        }
      }
    }
+    stage('QNN MultiModal Tests') {
+      steps {
+        timeout(time: 60, unit: 'MINUTES') {
+          sh '''
+            sudo docker exec ${BUILD_TAG} bash -c "
+            source /qnn_sdk/bin/envsetup.sh &&
+            source /qnn_sdk/bin/envcheck -c &&
+            cd /efficient-transformers &&
+            . preflight_qeff/bin/activate &&
+            mkdir -p $PWD/Non_cli_qnn_multimodal &&
+            export TOKENIZERS_PARALLELISM=false &&
+            export QEFF_HOME=$PWD/Non_cli_qnn_multimodal &&
+            pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (qnn)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log7.xml &&
+            junitparser merge tests/tests_log7.xml tests/tests_log.xml &&
+            deactivate"
+          '''
+        }
+      }
+    }
   }
 
   post {
diff --git a/tests/transformers/models/test_image_text_to_text_models.py b/tests/transformers/models/test_image_text_to_text_models.py
index 16f9945ee..658d8d68f 100644
--- a/tests/transformers/models/test_image_text_to_text_models.py
+++ b/tests/transformers/models/test_image_text_to_text_models.py
@@ -5,8 +5,9 @@
 #
 # ----------------------------------------------------------------------------
 
+import os
 from io import BytesIO
-from typing import List
+from typing import List, Optional
 
 import pytest
 import requests
@@ -23,7 +24,8 @@
 from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText
 from QEfficient.utils import hf_download
-from QEfficient.utils._utils import get_num_layers_vlm
+from QEfficient.utils._utils import create_json, get_num_layers_vlm
+from QEfficient.utils.constants import QnnConstants
 from QEfficient.utils.device_utils import get_available_device_id
 from QEfficient.utils.run_utils import ApiRunnerInternVL, ApiRunnerVlm
 from QEfficient.utils.test_utils import InternProcessor
@@ -86,6 +88,28 @@
         "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
         1,
     ),
+    (
+        "google/gemma-3-4b-it",
+        False,
+        1,
+        128,
+        1000,
+        896,
+        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
+        "Can you describe the image in detail.",
+        1,
+    ),
+    (
+        "google/gemma-3-4b-it",
+        True,
+        1,
+        128,
+        1000,
+        896,
+        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
+        "Can you describe the image in detail.",
+        1,
+    ),
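+    # gemma-3-4b-it is exercised twice: once with kv_offload=False (a single
+    # QPC) and once with kv_offload=True (separate vision and language QPCs).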
     # (
     #     "meta-llama/Llama-3.2-11B-Vision-Instruct",
     #     True,
@@ -176,6 +200,8 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
     n_layer: int = 1,
     kv_offload: bool = False,
     num_devices: int = 1,
+    enable_qnn: Optional[bool] = False,
+    qnn_config: Optional[str] = None,
 ):
     model_config = {"model_name": model_name}
     model_config["img_size"] = img_size
@@ -237,6 +263,8 @@
         prefill_seq_len=prompt_len,
         ctx_len=ctx_len,
         mxfp6=False,
+        enable_qnn=enable_qnn,
+        qnn_config=qnn_config,
     )
     inputs = processor(images=image, text=prompt, return_tensors="pt")
     if "pixel_values" in inputs:
@@ -259,6 +287,8 @@ def check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
     n_layer: int = 1,
     kv_offload: bool = False,
     num_devices: int = 1,
+    enable_qnn: Optional[bool] = False,
+    qnn_config: Optional[str] = None,
 ):
     model_config = {"model_name": model_name}
@@ -324,6 +354,8 @@
         prefill_seq_len=prompt_len,
         ctx_len=ctx_len,
         mxfp6=False,
+        enable_qnn=enable_qnn,
+        qnn_config=qnn_config,
     )
     print("QPC Outputs (QAIC):")
     output = qeff_model.generate(inputs=inputs, generation_len=NEW_GENERATION_TOKENS, streamer=streamer)
@@ -359,6 +391,42 @@
     )
 
 
+@pytest.mark.on_qaic
+@pytest.mark.qnn
+@pytest.mark.multimodal
+@pytest.mark.parametrize(
+    "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer", test_models_config
+)
+def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100_qnn(
+    model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer
+):
+    """
+    Validate the PyTorch model, the PyTorch model after KV changes, the ONNX
+    model, and the Cloud AI 100 model compiled through the QNN flow, without
+    continuous batching.
+
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
+    """
+    if model_name in ("meta-llama/Llama-4-Scout-17B-16E-Instruct", "google/gemma-3-4b-it"):
+        pytest.skip("QNN is not supported for these models yet.")
+
+    qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json")
+    create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG)
+
+    check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
+        model_name=model_name,
+        prompt_len=prompt_len,
+        ctx_len=ctx_len,
+        max_gen_len=NEW_GENERATION_TOKENS,
+        img_size=img_size,
+        img_url=img_url,
+        query=query,
+        n_layer=n_layer,
+        batch_size=batch_size,
+        kv_offload=kv_offload,
+        enable_qnn=True,
+        qnn_config=qnn_config_json_path,
+    )
+
+
 @pytest.mark.on_qaic
 @pytest.mark.multimodal
 @pytest.mark.parametrize(
@@ -378,3 +446,30 @@
         batch_size=batch_size,
         kv_offload=kv_offload,
     )
+
+
+@pytest.mark.on_qaic
+@pytest.mark.qnn
+@pytest.mark.multimodal
+@pytest.mark.parametrize(
+    "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, query, n_layer", intern_model_config
+)
+def test_image_text_to_text_intern_pytorch_vs_kv_vs_ort_vs_ai100_qnn(
+    model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, query, n_layer
+):
+    qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json")
+    create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG)
+
+    check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
+        model_name=model_name,
+        prompt_len=prompt_len,
+        ctx_len=ctx_len,
+        max_gen_len=NEW_GENERATION_TOKENS,
+        img_url=img_url,
+        query=query,
+        n_layer=n_layer,
+        batch_size=batch_size,
+        kv_offload=kv_offload,
+        enable_qnn=True,
+        qnn_config=qnn_config_json_path,
+    )
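+
+
+# Note: both *_qnn tests above write a fresh qnn_config.json (from
+# QnnConstants.QNN_SAMPLE_CONFIG) into the current working directory each run.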