diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 16767fbe2..bbd937d52 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -6,10 +6,11 @@ # ----------------------------------------------------------------------------- +from collections.abc import Iterable from typing import Any, Dict, List, Optional, Tuple import torch -from transformers.cache_utils import DynamicCache, EncoderDecoderCache, HybridCache, HybridChunkedCache +from transformers.cache_utils import DynamicCache, DynamicLayer, EncoderDecoderCache, HybridCache, HybridChunkedCache from QEfficient.customop import ( CtxGatherFunc, @@ -23,72 +24,20 @@ ) -class QEffDynamicCache(DynamicCache): - """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. - - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. - - - Optimized implementation for the Cloud AI 100 to reuse KV Cache. - - get the position_ids input using kwargs. - - Use custom Onnxscript ops to write optimized version to generate Onnx model. - - """ - - def write_only(self, key_states, value_states, layer_idx, cache_kwargs): +class QEffDynamicLayer(DynamicLayer): + def read_only(self, cache_kwargs): """ - Write in the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Reads the `key_states` and `value_states` for the layer. Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - """ - # Update the cache - if len(self.key_cache) <= layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - else: - position_ids = cache_kwargs.get("position_ids") - batch_index = cache_kwargs.get("batch_index", None) - - # Scatter - if batch_index is not None: - invalid_scatter_index = torch.iinfo(torch.int32).max - scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) - - self.key_cache[layer_idx] = CtxScatterFuncCB.apply( - self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states - ) - self.value_cache[layer_idx] = CtxScatterFuncCB.apply( - self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states - ) - else: - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( - self.value_cache[layer_idx], position_ids, value_states - ) - - def read_only(self, layer_idx, cache_kwargs): - """ - Reads the `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - layer_idx (`int`): - The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. Return: A tuple containing the updated key and value states. 
""" - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + # Gather + k_out, v_out = self.keys, self.values position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) ctx_len = k_out.shape[2] @@ -113,23 +62,51 @@ def read_only(self, layer_idx, cache_kwargs): v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out + def write_only(self, key_states, value_states, cache_kwargs): + """ + Write in the cache with the new `key_states` and `value_states` for the layer. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + """ + # Update the cache + if self.keys is None: + self.keys = key_states + self.values = value_states + else: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs + + # Scatter + if batch_index is not None: + invalid_scatter_index = torch.iinfo(torch.int32).max + scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) + + self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states) + self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states) + else: + self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states) + self.values = CtxScatterFunc.apply(self.values, position_ids, value_states) + def update( self, key_states: torch.Tensor, value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Updates the cache with the new `key_states` and `value_states` for the layer. Parameters: key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. @@ -137,10 +114,10 @@ def update( A tuple containing the updated key and value states. 
""" # Update the cache - if len(self.key_cache) <= layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - k_out, v_out = key_states, value_states + if self.keys is None: + self.keys = key_states + self.values = value_states + k_out, v_out = self.keys, self.values else: position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs @@ -150,20 +127,14 @@ def update( invalid_scatter_index = torch.iinfo(torch.int32).max scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) - self.key_cache[layer_idx] = CtxScatterFuncCB.apply( - self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states - ) + self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFuncCB.apply( - self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states - ) + self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states) else: - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( - self.value_cache[layer_idx], position_ids, value_states - ) + self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states) + self.values = CtxScatterFunc.apply(self.values, position_ids, value_states) - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + k_out, v_out = self.keys, self.values # Gather ctx_len = k_out.shape[2] @@ -187,23 +158,21 @@ def update( return k_out, v_out + # TODO:This function will be depercated in future. def update3D( self, key_states: torch.Tensor, value_states: torch.Tensor, - layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Updates the cache with the new `key_states` and `value_states` for the layer. Parameters: key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. @@ -211,32 +180,27 @@ def update3D( A tuple containing the updated key and value states. 
""" # Update the cache - if len(self.key_cache) <= layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - k_out, v_out = key_states, value_states + if self.keys is None: + self.keys = key_states + self.values = value_states + k_out, v_out = self.keys, self.values else: position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) + # Scatter if batch_index is not None: invalid_scatter_index = torch.iinfo(torch.int32).max scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids) - self.key_cache[layer_idx] = CtxScatterFuncCB3D.apply( - self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states - ) - - self.value_cache[layer_idx] = CtxScatterFuncCB3D.apply( - self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states - ) + self.keys = CtxScatterFuncCB3D.apply(self.keys, batch_index, scatter_position_ids, key_states) + self.values = CtxScatterFuncCB3D.apply(self.values, batch_index, scatter_position_ids, value_states) else: - self.key_cache[layer_idx] = CtxScatterFunc3D.apply(self.key_cache[layer_idx], position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc3D.apply( - self.value_cache[layer_idx], position_ids, value_states - ) - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + self.keys = CtxScatterFunc3D.apply(self.keys, position_ids, key_states) + self.values = CtxScatterFunc3D.apply(self.values, position_ids, value_states) + + k_out, v_out = self.keys, self.values # Gather ctx_len = k_out.shape[1] @@ -260,6 +224,89 @@ def update3D( return k_out, v_out +class QEffDynamicCache(DynamicCache): + """ + A cache that grows dynamically as more tokens are generated. This is the default for generative models. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + + - Optimized implementation for the Cloud AI 100 to reuse KV Cache. + - get the position_ids input using kwargs. + - Use custom Onnxscript ops to write optimized version to generate Onnx model. + + """ + + def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): + # Remove layer_classes if present to avoid duplicate argument + kwargs.pop("layer_classes", None) + from transformers.cache_utils import Cache # Import here to avoid circular import + + Cache.__init__(self, layer_classes=QEffDynamicLayer, *args, **kwargs) + if ddp_cache_data is not None: + for key_states, value_states in ddp_cache_data: + self.layers.append(QEffDynamicLayer.from_tensors(key_states, value_states)) + + def read_only(self, layer_idx, cache_kwargs): + """ + Reads the `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + return self.layers[layer_idx].read_only(cache_kwargs) + + def write_only(self, key_states, value_states, layer_idx, cache_kwargs): + """ + Write in the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. 
+ layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + """ + self.append_new_layers(layer_idx) + return self.layers[layer_idx].write_only(key_states, value_states, cache_kwargs) + + # TODO: This function will be deprecated in the future. + def update3D( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + self.append_new_layers(layer_idx) + return self.layers[layer_idx].update3D(key_states, value_states, cache_kwargs) + + class QEffEncoderDecoderCache(EncoderDecoderCache): """ Updated the `EncoderDecoderCache` to use the `QEffDynamicCache` for both self-attention and cross-attention caches. @@ -285,6 +332,7 @@ def from_legacy_cache( return cache +# TODO: This class will be deprecated in the future. class QEffHybridCache(HybridCache): def __init__(self, config, batch_size, max_cache_len): super().__init__(config, batch_size, max_cache_len=max_cache_len) @@ -388,6 +436,7 @@ def update( return k_out, v_out +# TODO: This class will be deprecated in the future.
class QEffHybridChunkedCache(HybridChunkedCache): def __len__(self): """ diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index 72b7acd98..c692d1beb 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -50,6 +50,7 @@ LlamaForCausalLM, LlamaModel, LlamaRMSNorm, + LlamaRotaryEmbedding, ) from transformers.models.mistral.modeling_mistral import ( MistralAttention, @@ -93,7 +94,7 @@ # Placeholder for all non-transformer models from .models.codegen.modeling_codegen import ( QEffCodeGenAttention, - QeffCodeGenBlock, + QEffCodeGenBlock, QEffCodeGenForCausalLM, QEffCodeGenModel, ) @@ -122,6 +123,7 @@ QEffLlamaDecoderLayer, QEffLlamaForCausalLM, QEffLlamaModel, + QEffLlamaRotaryEmbedding, ) from .models.mistral.modeling_mistral import ( QEffMistralAttention, @@ -203,6 +205,7 @@ LlamaForCausalLM: QEffLlamaForCausalLM, LlamaDecoderLayer: QEffLlamaDecoderLayer, LlamaRMSNorm: CustomRMSNormAIC, + LlamaRotaryEmbedding: QEffLlamaRotaryEmbedding, # Gemma model layers GemmaModel: QEffGemmaModel, GemmaAttention: QEffGemmaAttention, @@ -224,7 +227,7 @@ CodeGenAttention: QEffCodeGenAttention, CodeGenModel: QEffCodeGenModel, CodeGenForCausalLM: QEffCodeGenForCausalLM, - CodeGenBlock: QeffCodeGenBlock, + CodeGenBlock: QEffCodeGenBlock, # Mistral model layers MistralAttention: QEffMistralAttention, MistralDecoderLayer: QEffMistralDecoderLayer, diff --git a/QEfficient/transformers/models/codegen/modeling_codegen.py b/QEfficient/transformers/models/codegen/modeling_codegen.py index e75181424..776bfce43 100644 --- a/QEfficient/transformers/models/codegen/modeling_codegen.py +++ b/QEfficient/transformers/models/codegen/modeling_codegen.py @@ -47,8 +47,6 @@ def _attn( attn_weights = torch.matmul(query, key.transpose(-1, -2)) - attn_weights = attn_weights / self.scale_attn - # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` mask_value = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype).to(attn_weights.device) @@ -57,6 +55,7 @@ def _attn( # Apply the attention mask attn_weights = torch.where(attention_mask, mask_value, attn_weights) + attn_weights = attn_weights / self.scale_attn attn_weights = nn.Softmax(dim=-1)(attn_weights) attn_weights = attn_weights.to(value.dtype) attn_weights = self.attn_dropout(attn_weights) @@ -124,23 +123,8 @@ def forward( query = query.permute(0, 2, 1, 3) if layer_past is not None: - # Update the cache_kwargs with position_ids for Cloud AI 100 - past_key_value = layer_past - cache_kwargs = { - "position_ids": position_ids, - "batch_index": batch_index, - } - pkv = QEffDynamicCache() - pkv.key_cache.append(past_key_value[0]) - pkv.value_cache.append(past_key_value[1]) - key, value = pkv.update(key, value, 0, cache_kwargs) - - if use_cache is True: - # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32. 
- # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38 - present = (pkv.key_cache[0].to(hidden_states.dtype), pkv.value_cache[0]) - else: - present = None + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs) # compute self-attention: V x Softmax(QK^T) attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) @@ -148,12 +132,7 @@ def forward( attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) + return attn_output, attn_weights class QEffCodeGenModel(CodeGenModel): @@ -167,7 +146,7 @@ class QEffCodeGenModel(CodeGenModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -179,7 +158,8 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + **kwargs, # NOOP kwargs, for now + ) -> Union[tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -200,20 +180,21 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + seq_length = inputs_embeds.shape[1] + if cache_position is None: + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + position_ids = cache_position.unsqueeze(0) # Attention mask. 
if attention_mask is not None: @@ -237,7 +218,7 @@ def forward( elif attention_mask is None: # 4d mask is passed through the layers - attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_length) + attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -245,32 +226,25 @@ def forward( # head_mask has shape n_layer x batch x num_attention_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - hidden_states = inputs_embeds if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, seq_length) token_type_embeds = self.wte(token_type_ids) hidden_states = hidden_states + token_type_embeds hidden_states = self.drop(hidden_states) + output_shape = (-1, seq_length, hidden_states.size(-1)) - output_shape = input_shape + (hidden_states.size(-1),) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + for i, block in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) outputs = block( - hidden_states=hidden_states, - layer_past=layer_past, + hidden_states, + layer_past=past_key_values, batch_index=batch_index, attention_mask=attention_mask, position_ids=position_ids, @@ -281,11 +255,9 @@ def forward( ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) hidden_states = self.ln_f(hidden_states) @@ -294,12 +266,17 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, ) @@ -330,12 +307,6 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( @@ -372,9 +343,7 @@ def forward( ) -class QeffCodeGenBlock(CodeGenBlock): - # Ignore copy - +class QEffCodeGenBlock(CodeGenBlock): def forward( self, hidden_states: Optional[torch.FloatTensor], @@ -389,7 +358,7 @@ def forward( ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( + attn_outputs, attn_weights = self.attn( hidden_states=hidden_states, layer_past=layer_past, attention_mask=attention_mask, @@ -400,15 +369,8 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) - outputs = attn_outputs[1:] feed_forward_hidden_states = self.mlp(hidden_states) - hidden_states = attn_output + feed_forward_hidden_states + residual - - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] + hidden_states = attn_outputs + feed_forward_hidden_states + residual - return outputs # hidden_states, present, (attentions) + return hidden_states, attn_weights diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 79e3ebc01..8f2c3730d 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -135,13 +135,12 @@ def forward( key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - kv_seq_len = key_layer.shape[-2] - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) if layer_past is not None: - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) if attention_mask is not None: @@ -162,10 +161,7 @@ def forward( attn_output = self.dense(attn_output) - if output_attentions: - return attn_output, layer_past, attention_scores - else: - return attn_output, layer_past + return attn_output, attention_scores class QEffFalconDecoderLayer(FalconDecoderLayer): @@ -193,7 +189,7 @@ def forward( attention_layernorm_out = self.input_layernorm(hidden_states) # Self attention. 
- attn_outputs = self.self_attention( + attention_output, attn_weights = self.self_attention( attention_layernorm_out, layer_past=layer_past, attention_mask=attention_mask, @@ -207,8 +203,6 @@ def forward( cache_position=cache_position, ) - attention_output = attn_outputs[0] - if not self.config.new_decoder_architecture: if self.config.parallel_attn: mlp_layernorm_out = attention_layernorm_out @@ -225,8 +219,6 @@ def forward( ): mlp_layernorm_out = attention_layernorm_out - outputs = attn_outputs[1:] - # MLP. mlp_output = self.mlp(mlp_layernorm_out) @@ -235,12 +227,7 @@ def forward( output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training) - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] - - return outputs # hidden_states, past_kv, attentions + return output, attn_weights class QEffFalconModel(FalconModel): @@ -367,22 +354,13 @@ def forward( past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 0cefbcfee..eea1e3898 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import nn @@ -104,7 +104,6 @@ def eager_attention_forward( value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, - **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -149,18 +148,15 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + 
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward + attention_interface = eager_attention_forward attn_output, attn_weights = attention_interface( self, @@ -169,12 +165,12 @@ def forward( value_states, attention_mask, scaling=self.scaling, - **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + + return attn_output, attn_weights class QEffGemmaDecoderLayer(GemmaDecoderLayer): @@ -191,7 +187,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -202,9 +197,6 @@ def forward( attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -217,13 +209,12 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, batch_index=batch_index, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -236,15 +227,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs + return hidden_states class QEffGemmaModel(GemmaModel): @@ -263,18 +246,14 @@ def forward( batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( @@ -310,27 +289,21 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = 
decoder_layer( + hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, batch_index=batch_index, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) @@ -341,13 +314,11 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, - attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() class QEffGemmaForCausalLM(GemmaForCausalLM): @@ -365,21 +336,14 @@ def forward( past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -389,19 +353,15 @@ def forward( batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) # Cast to INT32 to avoid issue while running in ONNXRT logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) - hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] - - logits = self.lm_head(hidden_states).float() - logits = logits.float() + hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states).float() return CausalLMOutputWithPast( loss=None, diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 173da1798..be3ba942d 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -155,9 +155,7 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py
b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 851bb9436..20b7036fd 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -10,7 +10,7 @@ import torch from torch import nn -from transformers.cache_utils import Cache, HybridCache +from transformers.cache_utils import Cache from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -28,7 +28,7 @@ ) from QEfficient.customop.rms_norm import CustomRMSNorm -from QEfficient.transformers.cache_utils import QEffHybridCache +from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils import constants from QEfficient.utils._utils import IOInfo @@ -231,7 +231,6 @@ def forward( query_states = self.q_norm(query_states) key_states = self.k_norm(key_states) - kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( @@ -239,7 +238,6 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if self.is_sliding: cos, sin = self.rotary_emb_local(value_states, seq_len=self.config.max_position_embeddings) else: @@ -309,7 +307,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 - if self.is_sliding: + if self.self_attn.is_sliding: attention_mask = _create_causal_mask( position_ids=position_ids, target_length=past_seen_tokens, sliding_window=self.config.sliding_window ) @@ -364,7 +362,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, + past_key_values: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -396,8 +394,7 @@ def forward( if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) # return_legacy_cache = True - past_key_values = QEffHybridCache.from_legacy_cache(self.config, past_key_values) - + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( @@ -427,31 +424,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - last_cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - batch_index=batch_index, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - last_cache_position=last_cache_position, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + 
batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + last_cache_position=last_cache_position, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] @@ -481,7 +465,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, + past_key_values: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -606,6 +590,7 @@ def __init__(self, model): self.model = model self.language_model = self.model.language_model self.config = self.model.config + self.lm_head = self.model.lm_head def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): inputs_embeds = self.model.get_input_embeddings()(input_ids) @@ -617,11 +602,15 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) - outputs = self.model.language_model( + outputs = self.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) - return outputs.logits, vision_embeds, image_idx, outputs.past_key_values + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits, vision_embeds, image_idx, outputs.past_key_values class QEffGemma3ForConditionalGeneration(Gemma3ForConditionalGeneration): @@ -646,7 +635,11 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) - return outputs.logits, pixel_values, image_idx, outputs.past_key_values + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits, pixel_values, image_idx, outputs.past_key_values def get_specializations( self, diff --git a/QEfficient/transformers/models/gpt2/modeling_gpt2.py b/QEfficient/transformers/models/gpt2/modeling_gpt2.py index a2b84c139..d68a65430 100644 --- a/QEfficient/transformers/models/gpt2/modeling_gpt2.py +++ b/QEfficient/transformers/models/gpt2/modeling_gpt2.py @@ -9,13 +9,14 @@ import torch from torch import nn +from transformers import Cache from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, ) from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model -from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.cache_utils import QEffDynamicCache, QEffEncoderDecoderCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import 
MIN_MASKED_ATTENTION_VALUE @@ -63,18 +64,29 @@ class QEffGPT2Attention(GPT2Attention): def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, **kwargs, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - if encoder_hidden_states is not None: + is_cross_attention = encoder_hidden_states is not None + if past_key_value is not None: + if isinstance(past_key_value, QEffEncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + if is_cross_attention: if not hasattr(self, "q_attn"): raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. " @@ -82,31 +94,39 @@ def forward( ) query_states = self.q_attn(hidden_states) - key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) attention_mask = encoder_attention_mask + + # Try to get key/value states from cache if possible + if past_key_value is not None and is_updated: + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + shape_kv = (*key_states.shape[:-1], -1, self.head_dim) + key_states = key_states.view(shape_kv).transpose(1, 2) + value_states = value_states.view(shape_kv).transpose(1, 2) + else: query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2) + shape_kv = (*key_states.shape[:-1], -1, self.head_dim) + key_states = key_states.view(shape_kv).transpose(1, 2) + value_states = value_states.view(shape_kv).transpose(1, 2) shape_q = (*query_states.shape[:-1], -1, self.head_dim) - shape_kv = (*key_states.shape[:-1], -1, self.head_dim) - query_states = query_states.view(shape_q).transpose(1, 2) - key_states = key_states.view(shape_kv).transpose(1, 2) - value_states = value_states.view(shape_kv).transpose(1, 2) - if layer_past is not None: - # Added for optimized GPT Attention for AI 100 KV Retention + if (past_key_value is not None and not is_cross_attention) or ( + past_key_value is not None and is_cross_attention and not is_updated + ): + # save all key/value_layer to cache to be re-used for fast auto-regressive generation # Update the cache_kwargs with position_ids for Cloud AI 100 cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} - pkv = QEffDynamicCache() - pkv.key_cache.append(layer_past[0]) - pkv.value_cache.append(layer_past[1]) - key_states, value_states = pkv.update(key_states, value_states, 0, cache_kwargs) - - if use_cache is True: - present = (pkv.key_cache[0], pkv.value_cache[0]) - else: - present = None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, 
cache_kwargs + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward attn_output, attn_weights = attention_interface( @@ -122,11 +142,7 @@ def forward( attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) + return attn_output, attn_weights class QEffGPT2Block(GPT2Block): @@ -139,7 +155,7 @@ class QEffGPT2Block(GPT2Block): def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -154,9 +170,9 @@ def forward( ]: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( + attn_output, self_attn_weights = self.attn( hidden_states, - layer_past=layer_past, + past_key_value=past_key_value, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, @@ -164,8 +180,6 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) - outputs = attn_outputs[1:] # residual connection hidden_states = attn_output + residual @@ -180,18 +194,17 @@ def forward( hidden_states = self.ln_cross_attn(hidden_states) - cross_attn_outputs = self.crossattention( + cross_attn_outputs, cross_attn_weights = self.crossattention( hidden_states, + past_key_value=past_key_value, attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, ) - attn_output = cross_attn_outputs[0] # residual connection hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights residual = hidden_states hidden_states = self.ln_2(hidden_states) @@ -199,10 +212,11 @@ def forward( # residual connection hidden_states = residual + feed_forward_hidden_states - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + if encoder_hidden_states is not None: + outputs += (cross_attn_weights,) return outputs # hidden_states, present, (attentions, cross_attentions) @@ -256,14 +270,23 @@ def forward( if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) + return_legacy_cache = False + if use_cache: + if past_key_values is None: + return_legacy_cache = True + past_key_values = QEffDynamicCache() + elif not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if self.config.add_cross_attention and not isinstance(past_key_values, QEffEncoderDecoderCache): + past_key_values = QEffEncoderDecoderCache(past_key_values, QEffDynamicCache()) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 if 
position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + position_ids = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ).unsqueeze(0) # Attention mask. if attention_mask is not None: @@ -271,9 +294,10 @@ def forward( attention_mask = attention_mask[:, None, None, :] attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + elif attention_mask is None: # update attention mask for Cloud Ai 100 - attention_mask = _create_causal_mask(position_ids, past_length, None) + attention_mask = _create_causal_mask(position_ids, past_seen_tokens, None) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -306,19 +330,17 @@ def forward( output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) - presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None - for i in range(len(self.h)): - block, layer_past = self.h[i], past_key_values[i] + for i, block in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) outputs = block( hidden_states, - layer_past=layer_past, + past_key_value=past_key_values, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, @@ -329,8 +351,6 @@ def forward( output_attentions=output_attentions, ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) @@ -344,9 +364,17 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + past_key_values = past_key_values if use_cache else None + if return_legacy_cache: + past_key_values = ( + past_key_values.self_attention_cache.to_legacy_cache() + if self.config.add_cross_attention + else past_key_values.to_legacy_cache() + ) + return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 5dd9362ee..af233870b 100644 --- a/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -7,11 +7,12 @@ """PyTorch GPTBigCode model.""" -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint -from transformers.cache_utils import Cache +from torch import nn +from transformers.cache_utils import Cache, EncoderDecoderCache from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -25,6 +26,7 @@ from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE 
# Fused kernels @@ -55,71 +57,43 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor return x -class QEffGPTBigCodeAttention(GPTBigCodeAttention): - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - dtype = query.dtype - softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype - upcast = dtype != softmax_dtype - unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1 - scale_factor = torch.tensor(1 / self.head_dim**0.5, dtype=torch.float32) - # MQA models: (batch_size, query_length, num_heads * head_dim) - # MHA models: (batch_size, num_heads, query_length, head_dim) - query_shape = query.shape - batch_size = query_shape[0] - key_length = key.size(-1) - if self.multi_query: - # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) - # -> (batch_size, query_length, num_heads, key_length) - query_length = query_shape[1] - attn_shape = (batch_size, query_length, self.num_heads, key_length) - attn_view = (batch_size, query_length * self.num_heads, key_length) - # No copy needed for MQA 2, or when layer_past is provided. - query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) - else: - # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) - # -> (batch_size, num_heads, query_length, key_length) - query_length = query_shape[2] - attn_shape = (batch_size, self.num_heads, query_length, key_length) - attn_view = (batch_size * self.num_heads, query_length, key_length) - # Always copies - query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) - # No copy when layer_past is provided. - key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) - - attn_weights = (scale_factor * torch.bmm(query, key)).view(attn_shape) - - if upcast: - # Use a fused kernel to prevent a large overhead from casting and scaling. - # Sub-optimal when the key length is not a multiple of 8. - if attention_mask is None: - attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype) - else: - mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) - attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype) - else: - if attention_mask is not None: - mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) - - # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion. - attn_weights = torch.where(attention_mask, mask_value, attn_weights) - - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = self.attn_dropout(attn_weights) + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - # Mask heads if we want to - if head_mask is not None: - if self.multi_query: - head_mask = head_mask.transpose(1, 2) - attn_weights = attn_weights * head_mask + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() - if self.multi_query: - attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) - else: - attn_output = torch.matmul(attn_weights, value) + return attn_output, attn_weights - return attn_output, attn_weights +class QEffGPTBigCodeAttention(GPTBigCodeAttention): def forward( self, hidden_states: torch.Tensor, @@ -136,49 +110,69 @@ def forward( Tuple[torch.Tensor, Optional[torch.Tensor]], Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], ]: - if encoder_hidden_states is not None: + input_shape = hidden_states.shape[:-1] + + if layer_past is not None: + if isinstance(layer_past, EncoderDecoderCache): + if self.is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = layer_past.cross_attention_cache + else: + curr_past_key_value = layer_past.self_attention_cache + else: + curr_past_key_value = layer_past + + if self.is_cross_attention: if not hasattr(self, "q_attn") or not self.is_cross_attention: raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. " "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." ) + if layer_past is not None: + # reuse k,v, cross_attentions + key = curr_past_key_value.layers[self.layer_idx].keys + value = curr_past_key_value.layers[self.layer_idx].values + else: + query = self.q_attn(hidden_states).view(*input_shape, -1, self.head_dim).transpose(1, 2) + key, value = self.c_attn(encoder_hidden_states).split((self.head_dim, self.head_dim), dim=-1) - query = self.q_attn(hidden_states) - key_value = self.c_attn(encoder_hidden_states) - attention_mask = encoder_attention_mask - elif self.multi_query: - query, key, value = self.c_attn(hidden_states).split((self.embed_dim, self.kv_dim, self.kv_dim), dim=2) else: - # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), - # i.e., the memory layout is not the same as GPT2. - # This makes the concatenation with past_key_value more efficient. 
- query, key_value = ( - self.c_attn(hidden_states) - .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) - .transpose(1, 2) - .split((self.head_dim, 2 * self.head_dim), dim=3) - ) + if self.multi_query: + query, key, value = ( + self.c_attn(hidden_states).unsqueeze(1).split((self.embed_dim, self.kv_dim, self.kv_dim), dim=3) + ) + query = query.view(*input_shape, -1, self.head_dim).transpose(1, 2) + else: + query, key, value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split(3 * [self.head_dim], dim=3) + ) if layer_past is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} - key, value = layer_past.update3D(key, value, self.layer_idx, cache_kwargs) - present = (layer_past.key_cache[self.layer_idx], layer_past.value_cache[self.layer_idx]) - - attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) + key, value = curr_past_key_value.update(key, value, self.layer_idx, cache_kwargs) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if self.is_cross_attention: + layer_past.is_updated[self.layer_idx] = True + + attention_interface = eager_attention_forward + + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask, + scaling=self.scaling, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() - if not self.multi_query: - attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - if self.multi_query: - # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) - attn_weights = attn_weights.transpose(1, 2) - outputs += (attn_weights,) - - return outputs # a, present, (attentions) + return attn_output, attn_weights class QEffGPTBigCodeBlock(GPTBigCodeBlock): @@ -232,27 +226,22 @@ def forward( attn_output = cross_attn_outputs[0] # residual connection hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights residual = hidden_states hidden_states = self.ln_2(hidden_states) feed_forward_hidden_states = self.mlp(hidden_states) - # residual connection - hidden_states = residual + feed_forward_hidden_states - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] + hidden_states = residual + feed_forward_hidden_states - return outputs # hidden_states, present, (attentions, cross_attentions) + return (hidden_states,) + outputs class QEffGPTBigCodeModel(GPTBigCodeModel): def forward( self, input_ids: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.Tensor]] = None, + past_key_values: Optional[list[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, @@ -273,10 +262,9 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if input_ids is not None and inputs_embeds is not 
None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") elif input_ids is not None: - self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) batch_size = input_ids.shape[0] @@ -291,23 +279,19 @@ def forward( if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - + return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 - position_ids = position_ids.view(-1, seq_length).long() + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) - # update attention mask for Cloud Ai 100 - self_attention_mask = _create_causal_mask(position_ids, past_length) - # MQA models: (batch_size, query_length, n_heads, key_length) - # MHA models: (batch_size, n_heads, query_length, key_length) - if self.multi_query: - self_attention_mask = self_attention_mask.transpose(1, 2) + position_ids = position_ids.view(-1, seq_length).long() - attention_mask = self_attention_mask + # update attention mask for Cloud AI 100 + attention_mask = _create_causal_mask(position_ids, past_length) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -325,14 +309,13 @@ def forward( # head_mask has shape n_layer x batch x n_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) position_ids1 = position_ids.clone() position_ids1[position_ids1 == -1] = 0 position_embeds = self.wpe(position_ids1) hidden_states = inputs_embeds + position_embeds if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) token_type_embeds = self.wte(token_type_ids) hidden_states = hidden_states + token_type_embeds @@ -340,11 +323,10 @@ def forward( output_shape = input_shape + (hidden_states.size(-1),) - presents = [] if use_cache else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + for i, block in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -362,13 +344,11 @@ def forward( ) hidden_states = outputs[0] - if use_cache: - presents.append(outputs[1]) if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + all_cross_attentions = all_cross_attentions + (outputs[2],) hidden_states = self.ln_f(hidden_states) @@ -377,9 +357,12 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = 
past_key_values.to_legacy_cache() + return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -390,7 +373,7 @@ class QEffGPTBigCodeForCausalLM(GPTBigCodeForCausalLM): def forward( self, input_ids: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, @@ -442,14 +425,3 @@ def forward( attentions=transformer_outputs.attentions, cross_attentions=transformer_outputs.cross_attentions, ) - - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - """ - return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values) diff --git a/QEfficient/transformers/models/gptj/modeling_gptj.py b/QEfficient/transformers/models/gptj/modeling_gptj.py index 6b11e3f4f..dc3e5e6d2 100644 --- a/QEfficient/transformers/models/gptj/modeling_gptj.py +++ b/QEfficient/transformers/models/gptj/modeling_gptj.py @@ -134,14 +134,7 @@ def forward( query = query.permute(0, 2, 1, 3) if layer_past is not None: - cache_kwargs = { - "sin": sin, - "cos": cos, - "position_ids": position_ids, - "batch_index": batch_index, - "partial_rotation_size": self.rotary_dim, - "cache_position": cache_position, - } + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs) # compute self-attention: V x Softmax(QK^T) @@ -150,11 +143,7 @@ def forward( attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, layer_past) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) + return attn_output, attn_weights class QEffGPTJBlock(GPTJBlock): @@ -172,7 +161,7 @@ def forward( ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( + attn_outputs, attn_weights = self.attn( hidden_states=hidden_states, layer_past=layer_past, attention_mask=attention_mask, @@ -183,18 +172,11 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) - outputs = attn_outputs[1:] feed_forward_hidden_states = self.mlp(hidden_states) - hidden_states = attn_output + feed_forward_hidden_states + residual + hidden_states = attn_outputs + feed_forward_hidden_states + residual - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] - - return outputs # hidden_states, present, (attentions) + return hidden_states, attn_weights class QEffGPTJModel(GPTJModel): @@ -311,6 +293,7 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() if use_cache else None + return 
BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index 13b308547..2a2d47d6d 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -140,10 +140,8 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 8f840b4b4..c085f6a5e 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -9,6 +9,7 @@ import torch import torch.nn.functional as F +from torch import nn from transformers.cache_utils import Cache, StaticCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast @@ -127,15 +128,17 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] + bsz, q_len, _ = hidden_states.size() - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -149,23 +152,46 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + attention_interface = eager_attention_forward - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - if attention_mask is not None: - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) + 
attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + ) - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - dropout = 0.0 if not self.training else self.attention_dropout - attn_weights = F.dropout(attn_weights, p=dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + + return attn_output, attn_weights + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights class QEffGraniteMoeModel(GraniteMoeModel): @@ -212,9 +238,14 @@ def forward( inputs_embeds = inputs_embeds * self.embedding_multiplier # main diff with Llama - return_legacy_cache = False + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + # if not isinstance(past_key_values, (type(None), Cache)): + # raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + # if use_cache and past_key_values is None: + # past_key_values = QEffDynamicCache() + if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True if past_key_values is None: past_key_values = QEffDynamicCache() else: @@ -230,39 +261,26 @@ def forward( cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, position_ids, past_key_values, output_attentions ) - hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = None + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - elif batch_index is not None: + if batch_index is not None: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, @@ -287,9 +305,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - 
next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -298,15 +313,15 @@ def forward( # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() + if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) output = MoeModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -486,6 +501,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" @@ -537,6 +553,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -544,7 +561,8 @@ def forward( logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] - logits = self.lm_head(hidden_states) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) logits = logits / self.config.logits_scaling loss = None diff --git a/QEfficient/transformers/models/grok_1/modeling_grok1.py b/QEfficient/transformers/models/grok_1/modeling_grok1.py index 21516ff5f..567a8e070 100644 --- a/QEfficient/transformers/models/grok_1/modeling_grok1.py +++ b/QEfficient/transformers/models/grok_1/modeling_grok1.py @@ -86,20 +86,14 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, layer_idx) + kv_seq_len = past_key_value.get_seq_length(layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = { - "sin": sin, - "cos": cos, - "batch_index": batch_index, - "position_ids": position_ids, - } # Specific to RoPE models + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 5106e2dc4..f2a68f80e 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import nn @@ -113,7 +113,6 @@ def 
eager_attention_forward( attn_weights = torch.where( attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights ) - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() @@ -134,7 +133,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, - output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -142,26 +140,24 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states, **kwargs) - key_states = self.k_proj(hidden_states, **kwargs) - value_states = self.v_proj(hidden_states, **kwargs) - - query_states = query_states.view(hidden_shape).transpose(1, 2) - key_states = key_states.view(hidden_shape).transpose(1, 2) - value_states = value_states.view(hidden_shape).transpose(1, 2) + kwargs.pop("output_attentions", None) + kwargs.pop("return_dict", None) + kwargs.pop("labels", None) + kwargs.pop("position_embeddings", None) - kv_seq_len = key_states.shape[-2] + query_states = self.q_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward + attention_interface = eager_attention_forward attn_output, attn_weights = attention_interface( self, @@ -172,10 +168,10 @@ def forward( scaling=self.scaling, **kwargs, ) - attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output, **kwargs) - return attn_output, attn_weights, past_key_value + + return attn_output, attn_weights class QEffLlamaDecoderLayer(LlamaDecoderLayer): @@ -192,7 +188,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -202,13 +197,12 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, batch_index=batch_index, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -221,21 +215,12 
@@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - - return outputs + return hidden_states class QEffLlamaModel(LlamaModel): """ Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py - The only differences are: - - add new args cache idx for the kv retention """ def forward( @@ -247,18 +232,14 @@ def forward( batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -286,29 +267,22 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, batch_index=batch_index, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer @@ -318,20 +292,16 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, + past_key_values=past_key_values, hidden_states=all_hidden_states, - attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() class QEffLlamaForCausalLM(LlamaForCausalLM): """ Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py - The only differences are: - - add new args cache idx for the kv retention """ def forward( @@ -342,21 +312,14 @@ def forward( past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = output_attentions if 
output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -366,19 +329,15 @@ def forward( batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) # Cast to INT32 to avoid issue while running in ONNXRT logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) - hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] - - logits = self.lm_head(hidden_states) - logits = logits.float() + hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states).float() return CausalLMOutputWithPast( loss=None, diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 4b957ebec..212fe16ae 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -20,6 +20,7 @@ from transformers.models.llama4.modeling_llama4 import ( Llama4ForCausalLM, Llama4ForConditionalGeneration, + Llama4Router, Llama4TextAttention, Llama4TextConfig, Llama4TextDecoderLayer, @@ -402,6 +403,11 @@ def __qeff_init__(self): self.up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, self.expert_dim)) +class QEffLlama4Router(Llama4Router): + def forward(self, hidden_states): + return torch.matmul(hidden_states, self.weight.T) + + class QEffLlama4TextMoe(Llama4TextMoe): def forward(self, hidden: torch.Tensor): B, S, H = hidden.shape @@ -475,10 +481,6 @@ def forward( key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - ## if self.use_rope: # the 16E model skips rope for long context on certain layers query_states, key_states = qeff_apply_rotary_emb( query_states, key_states, position_embeddings.to(query_states.device) @@ -539,7 +541,6 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - chunk_causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, @@ -552,10 +553,6 @@ def forward( ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - # use local attention mask for ROPE layers - if self.use_chunked_attention: - attention_mask = chunk_causal_mask - hidden_states = self.input_layernorm(hidden_states) # Self Attention @@ -654,13 +651,17 @@ def forward( position_ids = cache_position.unsqueeze(0) causal_mask = _create_causal_mask( - position_ids=position_ids, target_length=past_key_values.key_cache[3].shape[-2] + position_ids=position_ids, target_length=past_key_values.layers[3].keys.shape[-2] ) chunk_position_ids = torch.where( position_ids != -1, position_ids % self.config.attention_chunk_size, position_ids ) - 
target_length = min(past_key_values.key_cache[0].shape[-2], torch.tensor(self.config.attention_chunk_size)) + target_length = min(past_key_values.layers[0].keys.shape[-2], torch.tensor(self.config.attention_chunk_size)) chunk_causal_mask = _create_causal_mask(position_ids=chunk_position_ids, target_length=target_length) + causal_mask_mapping = { + "full_attention": causal_mask, + "chunked_attention": chunk_causal_mask, + } # embed positions hidden_states = inputs_embeds @@ -678,8 +679,7 @@ def forward( layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, - chunk_causal_mask=chunk_causal_mask, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, batch_index=batch_index, @@ -720,6 +720,11 @@ class QEffLlama4ForCausalLM(Llama4ForCausalLM): - add new args cache idx for the kv retention """ + def __qeff_init__(self): + logger.warning( + "Current output differs from HF output due to a bug in TF v_4.55. Switch to branch release/v_1.20 for TF match. Refer link: https://github.com/huggingface/transformers/pull/39501" + ) + def forward( self, input_ids: torch.LongTensor = None, diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index f5e60c5de..9fd1ed782 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -98,7 +98,6 @@ def forward( # Reshape the query, key, and value tensors. query_states = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - kv_seq_len = position_ids.shape[-1] if past_key_value is not None: if self.layer_idx is None: raise ValueError( @@ -106,7 +105,7 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
) - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs) @@ -120,13 +119,12 @@ def forward( value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: # no matter the length, we just slice it + if attention_mask is not None: attn_weights = torch.where( attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights ) # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): @@ -179,7 +177,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - return hidden_states, past_key_values + return hidden_states class QEffLlamaSwiftKVModel(nn.Module): @@ -207,9 +205,7 @@ def _run_swiftkv_layers( ) -> torch.Tensor: for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers): layer = self.layers[layer_idx] - hidden_states, past_key_values = layer( - hidden_states, position_ids, past_key_values, causal_mask, batch_index - ) + hidden_states = layer(hidden_states, position_ids, past_key_values, causal_mask, batch_index) hidden_states = self.norm(hidden_states) return hidden_states, past_key_values @@ -327,7 +323,7 @@ def forward( for layer_idx in range(self.config.num_key_value_layers): layer = self.layers[layer_idx] - hidden_states, next_decoder_cache = layer( + hidden_states = layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, @@ -351,7 +347,6 @@ def forward( 1, 2 ) - kv_seq_len = key_states.shape[-2] if past_key_values is not None: if self_attn.layer_idx is None: raise ValueError( @@ -359,11 +354,11 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
) - kv_seq_len = past_key_values.get_usable_length(kv_seq_len, self_attn.layer_idx) + kv_seq_len = past_key_values.get_seq_length(self_attn.layer_idx) cos, sin = self_attn.rotary_emb(value_states, seq_len=kv_seq_len) _, key_states = qeff_apply_rotary_pos_emb(torch.empty_like(key_states), key_states, cos, sin, position_ids) - cache_kwargs = {"sin": sin, "cos": cos, "position_ids": position_ids, "batch_index": batch_index} + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} past_key_values.write_only(key_states, value_states, self_attn.layer_idx, cache_kwargs) last_pos_id = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -392,7 +387,7 @@ def forward( return hidden_states, next_cache -class QEffLlamaSwiftKVForCausalLM(PreTrainedModel): # +class QEffLlamaSwiftKVForCausalLM(PreTrainedModel): config_class = QEffLlamaSwiftKVConfig def __init__(self, config: QEffLlamaSwiftKVConfig): diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index 99384cb55..e260beb05 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -49,6 +49,7 @@ def __init__(self, model): self.model = model self.config = self.model.config self.language_model = self.model.language_model + self.lm_head = self.model.lm_head def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): inputs_embeds = self.model.get_input_embeddings()(input_ids) @@ -60,14 +61,19 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va vision_embeds_expanded = vision_embeds[indices0, indices1] vision_embeds_expanded = torch.where(mask.unsqueeze(-1), vision_embeds_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, vision_embeds_expanded) - outputs = self.model.language_model( + outputs = self.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + return_dict=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) - return outputs.logits, vision_embeds, image_idx, outputs.past_key_values + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits, vision_embeds, image_idx, outputs.past_key_values class QEffLlavaForConditionalGeneration(LlavaForConditionalGeneration): @@ -104,9 +110,15 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val position_ids=position_ids, past_key_values=past_key_values, ) + + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + logits = logits.float() + next_image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx) - return outputs.logits, pixel_values, image_idx, outputs.past_key_values + return logits, pixel_values, image_idx, outputs.past_key_values def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): num_layers = self.config.text_config.num_hidden_layers diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py index 23434fc18..2fa1d9234 100755 --- 
a/QEfficient/transformers/models/llava_next/modeling_llava_next.py +++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py @@ -92,11 +92,11 @@ def forward(self, pixel_values, image_sizes): new_height = int(round(original_height * scale_factor, 7)) padding = (current_height - new_height) // 2 image_feature = image_feature[:, padding : current_height - padding, :] - if self.model.image_newline is not None: + if self.model.model.image_newline is not None: image_feature = torch.cat( ( image_feature, - self.model.image_newline[:, None, None] + self.model.model.image_newline[:, None, None] .expand(*image_feature.shape[:-1], 1) .to(image_feature.device, image_feature.dtype), ), @@ -106,8 +106,10 @@ def forward(self, pixel_values, image_sizes): image_feature = torch.cat((base_image_feature, image_feature), dim=0) else: image_feature = image_feature[0] - if self.model.image_newline is not None: - image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature)), dim=0) + if self.model.model.image_newline is not None: + image_feature = torch.cat( + (image_feature, self.model.model.image_newline[None].to(image_feature)), dim=0 + ) new_image_features.append(image_feature) image_features = torch.cat(new_image_features, dim=0) return image_features.unsqueeze(0) @@ -119,6 +121,7 @@ def __init__(self, model): self.model = model self.config = self.model.config self.language_model = self.model.language_model + self.lm_head = self.model.lm_head def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): inputs_embeds = self.model.get_input_embeddings()(input_ids) @@ -137,7 +140,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va past_key_values=past_key_values, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) - return outputs.logits, vision_embeds, image_idx, outputs.past_key_values + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + logits = logits.float() + return logits, vision_embeds, image_idx, outputs.past_key_values class QEffLlavaNextForConditionalGeneration(LlavaNextForConditionalGeneration): diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 60b1c929d..ca23cc144 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -24,7 +24,6 @@ MistralForCausalLM, MistralModel, MistralRotaryEmbedding, - logger, repeat_kv, rotate_half, ) @@ -108,7 +107,6 @@ def eager_attention_forward( value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, - **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -146,6 +144,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] = None, # kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -159,15 +158,12 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - kv_seq_len = 
past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -179,12 +175,12 @@ def forward( value_states, attention_mask, scaling=self.scaling, - **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + + return attn_output, attn_weights class QEffMistralDecoderLayer(MistralDecoderLayer): @@ -225,7 +221,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -244,15 +240,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs + return hidden_states class QEffMistralModel(MistralModel): @@ -298,10 +286,6 @@ def forward( if use_cache and not isinstance(past_key_values, Cache) and not self.training: past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) return_legacy_cache = True - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" - "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" - ) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -327,7 +311,7 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, @@ -339,11 +323,6 @@ def forward( **kwargs, ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer @@ -410,10 +389,8 @@ def forward( # Cast to int32 to avoid ONNXRT issue logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) - hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_idx] - - logits = self.lm_head(hidden_states) - logits = logits.float() + hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_idx] + logits = self.lm_head(hidden_states).float() return CausalLMOutputWithPast( loss=None, diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index ef51c3421..9b9e3448a 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -7,7 +7,7 @@ """PyTorch Mixtral model.""" -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -109,7 +109,6 @@ def eager_attention_forward( value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, - **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -148,7 +147,6 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( @@ -156,21 +154,15 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
) - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "batch_index": batch_index, - "position_ids": position_ids, - } + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward + attention_interface = eager_attention_forward attn_output, attn_weights = attention_interface( self, @@ -179,13 +171,12 @@ def forward( value_states, attention_mask, scaling=self.scaling, - **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights MIXTRAL_ATTENTION_CLASSES = { @@ -255,7 +246,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -268,9 +258,6 @@ def forward( attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. output_router_logits (`bool`, *optional*): Whether or not to return the logits of all the routers. They are useful for computing the router loss, and should not be returned during inference. 
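Across the decoder-layer rewrites in this patch (Llama, Mistral, Mixtral, GraniteMoE), the per-layer output tuple disappears: the shared `Cache` passed in as `past_key_value` is mutated in place by `update`, so each layer returns only `hidden_states` and the model loop no longer rebuilds `presents`/`next_decoder_cache`. A minimal sketch of that calling convention with a toy layer and the stock `transformers.DynamicCache` (the layer class and tensor shapes are illustrative, not QEfficient code):

```python
import torch
from transformers.cache_utils import DynamicCache


class ToyLayer(torch.nn.Module):
    """Stand-in for a decoder layer: caches KV in place, returns only hidden_states."""

    def __init__(self, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx

    def forward(self, hidden_states, past_key_value):
        # Pretend hidden_states already has the [batch, heads, seq, head_dim] KV layout.
        past_key_value.update(hidden_states, hidden_states, self.layer_idx)
        return hidden_states  # no (hidden_states, present, attentions) tuple


cache = DynamicCache()
layers = [ToyLayer(i) for i in range(2)]
hidden = torch.randn(1, 4, 8, 16)

for layer in layers:
    hidden = layer(hidden, past_key_value=cache)

# The single Cache object now holds every layer's KV; nothing was returned per layer.
print(cache.get_seq_length())  # 8
```

In the QEfficient paths the same `update` entry point additionally receives `position_ids` and `batch_index` through `cache_kwargs`, as the hunks above show, which is why the layers can also stop returning `present_key_value`.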
@@ -289,14 +276,13 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, batch_index=batch_index, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -309,18 +295,7 @@ def forward( hidden_states, router_logits = self.block_sparse_moe(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - if output_router_logits: - outputs += (router_logits,) - - return outputs + return hidden_states # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral @@ -342,14 +317,12 @@ def forward( batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **flash_attn_kwargs, + **kwargs, ) -> Union[Tuple, MoeModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) @@ -391,35 +364,24 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_router_logits = () if output_router_logits else None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_values, - output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - **flash_attn_kwargs, + **kwargs, ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if output_router_logits: - all_router_logits += (layer_outputs[-1],) - hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer @@ -429,16 +391,12 @@ def forward( if use_legacy_cache: past_key_values = past_key_values.to_legacy_cache() - output = MoeModelOutputWithPast( + return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, ) - return output if return_dict else output.to_tuple() - class QEffMixtralForCausalLM(MixtralForCausalLM): """ @@ -456,46 +414,16 @@ def forward( past_key_values: Optional[List[torch.FloatTensor]] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = 
None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, MixtralForCausalLM - - >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) - output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -510,7 +438,6 @@ def forward( batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, @@ -520,9 +447,8 @@ def forward( # Cast to int32 to avoid ONNXRT issue logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) - hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_idx] - logits = self.lm_head(hidden_states) - logits = logits.float() + hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_idx] + logits = self.lm_head(hidden_states).float() aux_loss = None if output_router_logits: diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 8a98c4c96..cb24f1de4 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -7,7 +7,6 @@ """PyTorch Mllama model.""" -import math from typing import List, Optional, Tuple, Union import torch @@ -24,6 +23,7 @@ MllamaCrossAttentionDecoderLayer, MllamaForCausalLM, MllamaForConditionalGeneration, + MllamaModel, MllamaRotaryEmbedding, MllamaSelfAttentionDecoderLayer, MllamaTextCrossAttention, @@ -49,6 +49,53 @@ NUM_CHANNEL = 3 +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, 
dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def eager_self_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class QEffMllamaRotaryEmbedding(MllamaRotaryEmbedding): """ Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py @@ -132,7 +179,6 @@ def forward( past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, use_cache: bool = None, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -144,8 +190,8 @@ def forward( # elif past_key_value is not None: # Fetch old cache - key_states_old = past_key_value.key_cache[self.layer_idx] - value_states_old = past_key_value.value_cache[self.layer_idx] + key_states_old = past_key_value.layers[self.layer_idx].keys + value_states_old = past_key_value.layers[self.layer_idx].values # if cross_attention_states is not None: # Compute new KV states @@ -166,30 +212,25 @@ def forward( value_states = torch.where(torch.tensor(q_len == 1), value_states_old, value_states_new) # Update the image cache - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + past_key_value.layers[self.layer_idx].keys = key_states + past_key_value.layers[self.layer_idx].values = value_states key_states = self.k_norm(key_states) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attention_interface = eager_attention_forward - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + ) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class 
QEffMllamaTextSelfAttention(MllamaTextSelfAttention): @@ -210,7 +251,6 @@ def forward( past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, position_embeddings: torch.Tensor = None, - output_attentions: bool = False, use_cache: bool = False, cache_position=None, **kwargs, @@ -225,7 +265,6 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( @@ -233,44 +272,32 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, "batch_index": batch_index, "position_ids": position_ids, } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) + attention_interface = eager_self_attention_forward + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + ) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class QEffMllamaSelfAttentionDecoderLayer(MllamaSelfAttentionDecoderLayer): @@ -290,7 +317,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -301,9 +327,6 @@ def forward( attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under - returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -322,13 +345,12 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, batch_index=batch_index, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -341,15 +363,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs + return hidden_states class QEffMllamaTextCrossAttentionTwoQPC(MllamaTextCrossAttention): @@ -367,7 +381,6 @@ def forward( past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, use_cache: bool = None, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -393,35 +406,30 @@ def forward( ) elif past_key_value is not None: key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], + past_key_value.layers[self.layer_idx].keys, + past_key_value.layers[self.layer_idx].values, ) else: raise ValueError( "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" 
) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - key_states = self.k_norm(key_states) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask + attention_interface = eager_attention_forward - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + ) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class QEffMllamaCrossAttentionDecoderLayer(MllamaCrossAttentionDecoderLayer): @@ -441,7 +449,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.Tensor] = None, @@ -449,13 +456,12 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states, attn_weights, past_key_value = self.cross_attn( + hidden_states, attn_weights = self.cross_attn( hidden_states=hidden_states, attention_mask=cross_attention_mask, cross_attention_states=cross_attention_states, past_key_value=past_key_value, batch_index=batch_index, - output_attentions=output_attentions, cache_position=cache_position, ) hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states @@ -467,15 +473,7 @@ def forward( hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - if use_cache: - outputs += (past_key_value,) - - return outputs + return hidden_states class QEffMllamaVisionModel(MllamaVisionModel): @@ -484,16 +482,7 @@ def forward( pixel_values: torch.Tensor, aspect_ratio_ids: torch.Tensor, aspect_ratio_mask: torch.Tensor, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape pixel_values = pixel_values.reshape(batch_size * num_concurrent_media * num_tiles, num_channels, height, width) @@ -546,10 +535,8 @@ def forward( output = self.transformer( hidden_state, attention_mask=attention_mask, - output_hidden_states=True, - 
output_attentions=output_attentions, ) - hidden_state = output[0] + hidden_state = output.last_hidden_state hidden_state = self.layernorm_post(hidden_state) @@ -564,10 +551,8 @@ def forward( global_output = self.global_transformer( hidden_state, attention_mask=attention_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, ) - hidden_state = global_output[0] + hidden_state = global_output.last_hidden_state # Remove padding form hidden state hidden_state = hidden_state.reshape( @@ -577,9 +562,8 @@ def forward( hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, num_tiles, num_patches, dim) # Collect intermediate layer outputs from encoder output - all_intermediate_hidden_states = output[1] + all_intermediate_hidden_states = [output.last_hidden_state for _ in self.intermediate_layers_indices] intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1) - intermediate_hidden_states = intermediate_hidden_states[..., self.intermediate_layers_indices] # Remove padding from intermediate hidden states intermediate_hidden_states = intermediate_hidden_states.reshape( @@ -589,29 +573,11 @@ def forward( intermediate_hidden_states = intermediate_hidden_states.reshape( batch_size, num_concurrent_media, num_tiles, num_patches, -1 ) - # Concatenate final hidden state and intermediate hidden states hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) - if output_hidden_states: - hidden_states = tuple(all_intermediate_hidden_states) + tuple(global_output[1]) - else: - hidden_states = None - - if output_attentions: - # global transformer in contrast to `self.transformer` doesn't always return hidden states so we might go index out-of-range - global_attn = tuple(global_output[2]) if output_hidden_states else tuple(global_output[1]) - attentions = tuple(output[2]) + global_attn - else: - attentions = None - - if not return_dict: - return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None) - return BaseModelOutput( last_hidden_state=hidden_state, - hidden_states=hidden_states, - attentions=attentions, ) @@ -634,17 +600,9 @@ def forward( full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -664,6 +622,7 @@ def forward( cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -672,14 +631,7 @@ def forward( # embed positions hidden_states = inputs_embeds - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - for idx, 
decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): - if output_hidden_states: - all_hidden_states += (hidden_states,) - # For text-only path we should skip cross attention layers. # Let's check if the layer is cross attention layer and if we have cross attention states # or cached cross attention states. @@ -698,7 +650,7 @@ def forward( if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: continue - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, cross_attention_states=cross_attention_states, cross_attention_mask=cross_attention_mask, @@ -706,21 +658,13 @@ def forward( full_text_row_masked_out_mask=full_text_row_masked_out_mask, position_ids=position_ids, past_key_value=past_key_values, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() @@ -728,8 +672,6 @@ def forward( return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_self_attns, ) @@ -753,17 +695,8 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -776,9 +709,6 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, ) @@ -813,14 +743,14 @@ def forward( aspect_ratio_mask=aspect_ratio_mask, ) cross_attention_states = vision_outputs[0] - cross_attention_states = self.model.multi_modal_projector(cross_attention_states).reshape( - -1, cross_attention_states.shape[-2], self.model.hidden_size + cross_attention_states = self.model.model.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.model.model.hidden_size ) bsz = pixel_values.shape[0] outputs = [] for i in self.cross_attention_layers: - cross_attn = self.model.language_model.model.layers[i].cross_attn + cross_attn = self.model.language_model.layers[i].cross_attn key_states = cross_attn.k_proj(cross_attention_states) value_states = cross_attn.v_proj(cross_attention_states) key_states = key_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose(1, 2) @@ -831,13 +761,7 @@ def forward( return outputs -class QEffMllamaForConditionalGeneration(MllamaForConditionalGeneration): - def 
get_qeff_vision_encoder(self): - return QEffMllamaVisionEncoder(self) - - def get_qeff_language_decoder(self): - return self - +class QEffMllamaModel(MllamaModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -850,30 +774,13 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, - batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - + **kwargs, + ) -> BaseModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if pixel_values is not None and cross_attention_states is not None: raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") @@ -886,11 +793,8 @@ def forward( pixel_values=pixel_values, aspect_ratio_ids=aspect_ratio_ids, aspect_ratio_mask=aspect_ratio_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, ) - cross_attention_states = vision_outputs[0] + cross_attention_states = vision_outputs.last_hidden_state cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( -1, cross_attention_states.shape[-2], self.hidden_size ) @@ -916,17 +820,64 @@ def forward( cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, - batch_index=batch_index, use_cache=use_cache, inputs_embeds=inputs_embeds, - labels=labels, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, cache_position=cache_position, + **kwargs, ) - return outputs.logits, image_idx, outputs.past_key_values, pixel_values + return BaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class QEffMllamaForConditionalGeneration(MllamaForConditionalGeneration): + def get_qeff_vision_encoder(self): + return QEffMllamaVisionEncoder(self) + + def get_qeff_language_decoder(self): + return self + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + image_idx: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + 
cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + aspect_ratio_mask=aspect_ratio_mask, + aspect_ratio_ids=aspect_ratio_ids, + cross_attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + ) + + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states).float() + return logits, image_idx, outputs.past_key_values, pixel_values def get_dummy_inputs(self, kv_offload: bool = False): BS = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE @@ -974,22 +925,21 @@ def get_dummy_inputs(self, kv_offload: bool = False): ) lang_inputs["past_key_values"] = QEffDynamicCache() - lang_inputs["past_key_values"].key_cache = [0] * num_hidden_layers - lang_inputs["past_key_values"].value_cache = [0] * num_hidden_layers + lang_inputs["past_key_values"].append_new_layers(num_hidden_layers - 1) for i in range(num_hidden_layers): if i in cross_attention_layers: idx = cross_attention_layers.index(i) assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" - lang_inputs["past_key_values"].key_cache[i] = torch.zeros( + lang_inputs["past_key_values"].layers[i].keys = torch.zeros( 1, num_key_value_heads, image_tokens_len, head_dim ) - lang_inputs["past_key_values"].value_cache[i] = torch.zeros( + lang_inputs["past_key_values"].layers[i].values = torch.zeros( 1, num_key_value_heads, image_tokens_len, head_dim ) else: - lang_inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, CTX_LEN, head_dim) - lang_inputs["past_key_values"].value_cache[i] = torch.zeros(1, num_key_value_heads, CTX_LEN, head_dim) + lang_inputs["past_key_values"].layers[i].keys = torch.zeros(1, num_key_value_heads, CTX_LEN, head_dim) + lang_inputs["past_key_values"].layers[i].values = torch.zeros(1, num_key_value_heads, CTX_LEN, head_dim) lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1) diff --git a/QEfficient/transformers/models/mpt/modeling_mpt.py b/QEfficient/transformers/models/mpt/modeling_mpt.py index 89d474e15..9bf6a4422 100644 --- a/QEfficient/transformers/models/mpt/modeling_mpt.py +++ b/QEfficient/transformers/models/mpt/modeling_mpt.py @@ -12,6 +12,7 @@ import torch import torch.utils.checkpoint from torch import nn +from transformers.cache_utils import Cache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -50,20 +51,12 @@ def forward( value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: - if len(past_key_value) != 0: - 
cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} - pkv = QEffDynamicCache() - pkv.key_cache.append(past_key_value[0]) - pkv.value_cache.append(past_key_value[1]) - key_states, value_states = pkv.update(key_states, value_states, 0, cache_kwargs) - if use_cache: - past_key_value = (pkv.key_cache[0], pkv.value_cache[0]) - else: - past_key_value = None + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale - query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2] + query_length = seq_length if past_key_value is None else seq_length + past_key_value.get_seq_length() if position_bias is not None: if len(position_bias.shape) != 3: @@ -137,15 +130,7 @@ def forward( # MLP. output = self.ffn(layernorm_output, residual) - outputs = (output,) - - if use_cache: - outputs += (past_key_value,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs # hidden_states, present, attentions + return output, attn_weights class QEFfMptModel(MptModel): @@ -190,18 +175,18 @@ def forward( if inputs_embeds is None: inputs_embeds = self.wte(input_ids) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + hidden_states = inputs_embeds - presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None # Compute alibi tensor: check build_alibi_tensor documentation - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 alibi = self.build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device) @@ -213,13 +198,13 @@ def forward( elif attention_mask is None: causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values_length) - for block, layer_past in zip(self.blocks, past_key_values): + for block in self.blocks: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) outputs = block( hidden_states, - layer_past=layer_past, + layer_past=past_key_values, attention_mask=causal_mask, position_ids=position_ids, batch_index=batch_index, @@ -228,24 +213,27 @@ def forward( position_bias=alibi, ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) # Add last hidden state hidden_states = self.norm_f(hidden_states) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None + ) return 
BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, ) diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index 602a73c84..4b5234a5a 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -155,18 +155,14 @@ def forward( query_states = query_states.view(hidden_shape).transpose(1, 2) key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, - "cache_position": cache_position, "batch_index": batch_index, "position_ids": position_ids, } @@ -203,7 +199,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -218,9 +213,6 @@ def forward( position_ids (`torch.LongTensor` of shape `({0})`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
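[Editor note, not part of the patch] A note on the single-token logit gather that recurs in this patch (the Mixtral and Mllama forwards above, the Phi3 causal-LM forward in the hunks below): `position_ids` is cast to int32 and its per-row argmax picks the last valid position before `lm_head` runs. A minimal sketch with assumed toy shapes; padded slots are taken to hold `-1` here, which is an assumption for illustration (any padding value no larger than the last real position id works the same way).

```python
import torch

# hidden: [batch, seq_len, dim] decoder output; position_ids: [batch, seq_len].
batch, seq_len, dim = 2, 4, 8
hidden = torch.randn(batch, seq_len, dim)
position_ids = torch.tensor([[0, 1, 2, 3],
                             [0, 1, -1, -1]])  # second row has two padded slots

# Cast to int32 (the patch does this to avoid an ONNXRT issue) and take the
# per-row argmax, which lands on the last real token of each sequence.
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)  # [[3], [1]]
rows = torch.arange(batch).view(-1, 1)
last_hidden = hidden[rows, logit_index]  # [batch, 1, dim]
```

Only that one position per sequence then flows through `lm_head`, matching the rationale in the removed Phi3 docstring that only last-token logits are needed for generation.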
@@ -243,7 +235,6 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_value, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -257,10 +248,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + self.resid_mlp_dropout(hidden_states) - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - return outputs + return hidden_states class QEffPhi3Model(Phi3Model): @@ -280,20 +268,15 @@ def forward( batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" @@ -320,29 +303,22 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_values, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer @@ -352,13 +328,11 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() if use_cache else None - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, - attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() class QEffPhi3ForCausalLM(Phi3ForCausalLM): @@ -377,45 +351,14 @@ def forward( batch_index: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - - Returns: - Example: - ```python - >>> from transformers import AutoTokenizer, Phi3ForCausalLM - >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct") - >>> prompt = "This is an example script ." - >>> inputs = tokenizer(prompt, return_tensors="pt") - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum' - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -425,18 +368,15 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) # Cast to INT32 to avoid issue while running in ONNXRT logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) - hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] - logits = self.lm_head(hidden_states) - logits = logits.float() + hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states).float() return CausalLMOutputWithPast( loss=None, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 7be72078c..99cd5374a 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -1,661 +1,670 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
-# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -import warnings -from types import MethodType -from typing import Callable, Optional, Tuple, Union - -from torch import nn -from transformers.models.codegen.modeling_codegen import ( - CodeGenAttention, - CodeGenBlock, - CodeGenForCausalLM, - CodeGenModel, -) -from transformers.models.falcon.modeling_falcon import ( - FalconAttention, - FalconDecoderLayer, - FalconForCausalLM, - FalconModel, -) -from transformers.models.gemma.modeling_gemma import ( - GemmaAttention, - GemmaDecoderLayer, - GemmaForCausalLM, - GemmaModel, - GemmaRMSNorm, -) -from transformers.models.gemma2.modeling_gemma2 import ( - Gemma2Attention, - Gemma2DecoderLayer, - Gemma2ForCausalLM, - Gemma2Model, - Gemma2RMSNorm, -) -from transformers.models.gemma3.modeling_gemma3 import ( - Gemma3Attention, - Gemma3DecoderLayer, - Gemma3ForCausalLM, - Gemma3ForConditionalGeneration, - Gemma3RMSNorm, - Gemma3TextModel, -) -from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model -from transformers.models.gpt_bigcode.modeling_gpt_bigcode import ( - GPTBigCodeAttention, - GPTBigCodeBlock, - GPTBigCodeForCausalLM, - GPTBigCodeModel, -) -from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJForCausalLM, GPTJModel -from transformers.models.granite.modeling_granite import ( - GraniteAttention, - GraniteForCausalLM, - GraniteModel, - GraniteRMSNorm, -) -from transformers.models.granitemoe.modeling_granitemoe import ( - GraniteMoeAttention, - GraniteMoeForCausalLM, - GraniteMoeModel, - GraniteMoeMoE, - GraniteMoeParallelExperts, - GraniteMoeRMSNorm, - GraniteMoeRotaryEmbedding, - GraniteMoeTopKGating, -) -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaForCausalLM, - LlamaModel, - LlamaRMSNorm, -) -from transformers.models.llama4.modeling_llama4 import ( - Llama4ForCausalLM, - Llama4ForConditionalGeneration, - Llama4TextAttention, - Llama4TextDecoderLayer, - Llama4TextExperts, - Llama4TextModel, - Llama4TextMoe, - Llama4TextRMSNorm, - Llama4VisionAttention, - Llama4VisionModel, -) -from transformers.models.llava.modeling_llava import ( - LlavaForConditionalGeneration, -) -from transformers.models.llava_next.modeling_llava_next import ( - LlavaNextForConditionalGeneration, -) -from transformers.models.mistral.modeling_mistral import ( - MistralAttention, - MistralDecoderLayer, - MistralForCausalLM, - MistralModel, - MistralRMSNorm, -) -from transformers.models.mixtral.modeling_mixtral import ( - MixtralAttention, - MixtralDecoderLayer, - MixtralForCausalLM, - MixtralModel, - MixtralRMSNorm, - MixtralSparseMoeBlock, -) -from transformers.models.mllama.modeling_mllama import ( - MllamaCrossAttentionDecoderLayer, - MllamaForCausalLM, - MllamaForConditionalGeneration, - MllamaRotaryEmbedding, - MllamaSelfAttentionDecoderLayer, - MllamaTextCrossAttention, - MllamaTextModel, - MllamaTextRMSNorm, - MllamaTextSelfAttention, - MllamaVisionModel, -) -from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel -from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel -from transformers.models.phi3.modeling_phi3 import ( - Phi3Attention, - Phi3DecoderLayer, - Phi3ForCausalLM, - Phi3Model, - Phi3RMSNorm, -) -from transformers.models.qwen2.modeling_qwen2 import ( - Qwen2Attention, - Qwen2DecoderLayer, - 
Qwen2ForCausalLM, - Qwen2Model, - Qwen2RMSNorm, -) -from transformers.models.qwen3_moe.modeling_qwen3_moe import ( - Qwen3MoeAttention, - Qwen3MoeDecoderLayer, - Qwen3MoeForCausalLM, - Qwen3MoeModel, - Qwen3MoeRMSNorm, - Qwen3MoeRotaryEmbedding, - Qwen3MoeSparseMoeBlock, -) -from transformers.models.starcoder2.modeling_starcoder2 import ( - Starcoder2Attention, - Starcoder2DecoderLayer, - Starcoder2ForCausalLM, - Starcoder2Model, -) -from transformers.models.whisper.modeling_whisper import ( - WhisperAttention, - WhisperDecoder, - WhisperDecoderLayer, - WhisperEncoder, - WhisperForConditionalGeneration, - WhisperModel, - WhisperPositionalEmbedding, -) - -from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform -from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC -from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function -from QEfficient.transformers.models.codegen.modeling_codegen import ( - QEffCodeGenAttention, - QeffCodeGenBlock, - QEffCodeGenForCausalLM, - QEffCodeGenModel, -) -from QEfficient.transformers.models.falcon.modeling_falcon import ( - QEffFalconAttention, - QEffFalconDecoderLayer, - QEffFalconForCausalLM, - QEffFalconModel, -) -from QEfficient.transformers.models.gemma.modeling_gemma import ( - QEffGemmaAttention, - QEffGemmaDecoderLayer, - QEffGemmaForCausalLM, - QEffGemmaModel, -) -from QEfficient.transformers.models.gemma2.modeling_gemma2 import ( - QEffGemma2Attention, - QEffGemma2DecoderLayer, - QEffGemma2ForCausalLM, - QEffGemma2Model, -) -from QEfficient.transformers.models.gemma3.modeling_gemma3 import ( - QEffGemma3Attention, - QEffGemma3CustomRMSNormAIC, - QEffGemma3DecoderLayer, - QEffGemma3ForCausalLMModel, - QEffGemma3ForConditionalGeneration, - QEffGemma3TextModel, -) -from QEfficient.transformers.models.gpt2.modeling_gpt2 import ( - QEffGPT2Attention, - QEffGPT2Block, - QEffGPT2LMHeadModel, - QEffGPT2Model, -) -from QEfficient.transformers.models.gpt_bigcode.modeling_gpt_bigcode import ( - QEffGPTBigCodeAttention, - QEffGPTBigCodeBlock, - QEffGPTBigCodeForCausalLM, - QEffGPTBigCodeModel, -) -from QEfficient.transformers.models.gptj.modeling_gptj import ( - QEffGPTJAttention, - QEffGPTJBlock, - QEffGPTJForCausalLM, - QEffGPTJModel, -) -from QEfficient.transformers.models.granite.modeling_granite import ( - QEffGraniteAttention, - QEffGraniteForCausalLM, - QEffGraniteModel, -) -from QEfficient.transformers.models.granitemoe.modeling_granitemoe import ( - QEffGraniteMoeAttention, - QEffGraniteMoeForCausalLM, - QEffGraniteMoeModel, - QEffGraniteMoeMoE, - QEffGraniteMoeParallelExperts, - QEffGraniteMoeRotaryEmbedding, - QEffGraniteMoeTopKGating, -) -from QEfficient.transformers.models.grok_1.modeling_grok1 import ( - QEFFGrok1CustomRMSNormAIC, - QEffGrok1DecoderLayer, - QEffGrok1Model, - QEffGrok1ModelForCausalLM, - QEffGrok1MoeBlock, - QEffGrok1MultiHeadAttention, -) -from QEfficient.transformers.models.internvl.modeling_internvl import ( - QEffInternVisionEmbeddings, - QEffInternVLModel, -) -from QEfficient.transformers.models.llama.modeling_llama import ( - QEffLlamaAttention, - QEffLlamaDecoderLayer, - QEffLlamaForCausalLM, - QEffLlamaModel, -) -from QEfficient.transformers.models.llama4.modeling_llama4 import ( - QEffLlama4ForCausalLM, - QEffLlama4ForConditionalGeneration, - QEffLlama4TextAttention, - QEffLlama4TextDecoderLayer, - QEffLlama4TextExperts, - QEffLlama4TextModel, - QEffLlama4TextMoe, - QEffLlama4VisionAttention, - 
QEffLlama4VisionModel, -) -from QEfficient.transformers.models.llava.modeling_llava import ( - QEffLlavaForConditionalGeneration, -) -from QEfficient.transformers.models.llava_next.modeling_llava_next import ( - QEffLlavaNextForConditionalGeneration, -) -from QEfficient.transformers.models.mistral.modeling_mistral import ( - QEffMistralAttention, - QEffMistralDecoderLayer, - QEffMistralForCausalLM, - QEffMistralModel, -) -from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import ( - QEffMixtralAttention, - QeffMixtralDecoderLayer, - QEffMixtralForCausalLM, - QEffMixtralModel, - QEffMixtralSparseMoeBlock, -) -from QEfficient.transformers.models.mllama.modeling_mllama import ( - QEffMllamaCrossAttentionDecoderLayer, - QEffMllamaForCausalLM, - QEffMllamaForConditionalGeneration, - QEffMllamaRotaryEmbedding, - QEffMllamaSelfAttentionDecoderLayer, - QEffMllamaTextCrossAttentionSingleQPC, - QEffMllamaTextCrossAttentionTwoQPC, - QEffMllamaTextModel, - QEffMllamaTextSelfAttention, - QEffMllamaVisionModel, -) -from QEfficient.transformers.models.mpt.modeling_mpt import ( - QEffMptAttention, - QEffMptBlock, - QEffMptForCausalLM, - QEFfMptModel, -) -from QEfficient.transformers.models.phi.modeling_phi import ( - QEffPhiAttention, - QEffPhiDecoderLayer, - QEffPhiForCausalLM, - QEffPhiModel, -) -from QEfficient.transformers.models.phi3.modeling_phi3 import ( - QEffPhi3Attention, - QEffPhi3DecoderLayer, - QEffPhi3ForCausalLM, - QEffPhi3Model, -) -from QEfficient.transformers.models.qwen2.modeling_qwen2 import ( - QEffQwen2Attention, - QEffQwen2DecoderLayer, - QEffQwen2ForCausalLM, - QEffQwen2Model, -) -from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import ( - QEffQwen3MoeAttention, - QEffQwen3MoeDecoderLayer, - QEffQwen3MoeForCausalLM, - QEffQwen3MoeModel, - QEffQwen3MoeRotaryEmbedding, - QEffQwen3MoeSparseMoeBlock, -) -from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import ( - QEffStarcoder2Attention, - QEFFStarcoder2DecoderLayer, - QEffStarcoder2ForCausalLM, - QEffStarcoder2Model, -) -from QEfficient.transformers.models.whisper.modeling_whisper import ( - QEffWhisperAttention, - QEffWhisperDecoder, - QEffWhisperDecoderLayer, - QEffWhisperEncoder, - QEffWhisperForConditionalGeneration, - QEffWhisperModel, - QEffWhisperPositionalEmbedding, -) -from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry -from QEfficient.transformers.sampler.sampler import sampler_forward -from QEfficient.transformers.spd.spd_transform_forward import tlm_forward - -SPD_TARGET = "target" - - -class CustomOpsTransform(ModuleMappingTransform): - _module_mapping = { - GemmaRMSNorm: GemmaCustomRMSNormAIC, - Gemma2RMSNorm: GemmaCustomRMSNormAIC, - LlamaRMSNorm: CustomRMSNormAIC, - Llama4TextRMSNorm: CustomRMSNormAIC, - MistralRMSNorm: CustomRMSNormAIC, - MixtralRMSNorm: CustomRMSNormAIC, - Phi3RMSNorm: CustomRMSNormAIC, - Qwen2RMSNorm: CustomRMSNormAIC, - MllamaTextRMSNorm: CustomRMSNormAIC, - GraniteRMSNorm: CustomRMSNormAIC, - GraniteMoeRMSNorm: CustomRMSNormAIC, - Qwen3MoeRMSNorm: CustomRMSNormAIC, - Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC, - } - - -class KVCacheTransform(ModuleMappingTransform): - _module_mapping = { - # CodeGen - CodeGenAttention: QEffCodeGenAttention, - CodeGenBlock: QeffCodeGenBlock, - CodeGenModel: QEffCodeGenModel, - CodeGenForCausalLM: QEffCodeGenForCausalLM, - # Falcon - FalconAttention: QEffFalconAttention, - FalconDecoderLayer: QEffFalconDecoderLayer, - FalconModel: QEffFalconModel, - FalconForCausalLM: 
QEffFalconForCausalLM, - # GPT2 - GPT2Attention: QEffGPT2Attention, - GPT2Block: QEffGPT2Block, - GPT2Model: QEffGPT2Model, - GPT2LMHeadModel: QEffGPT2LMHeadModel, - # GPTJ - GPTJAttention: QEffGPTJAttention, - GPTJBlock: QEffGPTJBlock, - GPTJModel: QEffGPTJModel, - GPTJForCausalLM: QEffGPTJForCausalLM, - # Llama - LlamaAttention: QEffLlamaAttention, - LlamaDecoderLayer: QEffLlamaDecoderLayer, - LlamaModel: QEffLlamaModel, - LlamaForCausalLM: QEffLlamaForCausalLM, - # Llama4 - Llama4TextAttention: QEffLlama4TextAttention, - Llama4ForCausalLM: QEffLlama4ForCausalLM, - Llama4TextDecoderLayer: QEffLlama4TextDecoderLayer, - Llama4TextModel: QEffLlama4TextModel, - Llama4TextMoe: QEffLlama4TextMoe, - Llama4ForConditionalGeneration: QEffLlama4ForConditionalGeneration, - Llama4VisionAttention: QEffLlama4VisionAttention, - Llama4VisionModel: QEffLlama4VisionModel, - Llama4TextExperts: QEffLlama4TextExperts, - # Llava - LlavaForConditionalGeneration: QEffLlavaForConditionalGeneration, - # Llava Next - LlavaNextForConditionalGeneration: QEffLlavaNextForConditionalGeneration, - # Gemma - GemmaAttention: QEffGemmaAttention, - GemmaDecoderLayer: QEffGemmaDecoderLayer, - GemmaModel: QEffGemmaModel, - GemmaForCausalLM: QEffGemmaForCausalLM, - # Qwen3Moe - Qwen3MoeForCausalLM: QEffQwen3MoeForCausalLM, - Qwen3MoeModel: QEffQwen3MoeModel, - Qwen3MoeDecoderLayer: QEffQwen3MoeDecoderLayer, - Qwen3MoeAttention: QEffQwen3MoeAttention, - Qwen3MoeRotaryEmbedding: QEffQwen3MoeRotaryEmbedding, - Qwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock, - # Gemma2 - Gemma2Attention: QEffGemma2Attention, - Gemma2DecoderLayer: QEffGemma2DecoderLayer, - Gemma2Model: QEffGemma2Model, - Gemma2ForCausalLM: QEffGemma2ForCausalLM, - # Gemma3 - Gemma3Attention: QEffGemma3Attention, - Gemma3DecoderLayer: QEffGemma3DecoderLayer, - Gemma3TextModel: QEffGemma3TextModel, - Gemma3ForCausalLM: QEffGemma3ForCausalLMModel, - Gemma3ForConditionalGeneration: QEffGemma3ForConditionalGeneration, - # Granite - GraniteModel: QEffGraniteModel, - GraniteForCausalLM: QEffGraniteForCausalLM, - GraniteAttention: QEffGraniteAttention, - # GraniteMoe - GraniteMoeModel: QEffGraniteMoeModel, - GraniteMoeForCausalLM: QEffGraniteMoeForCausalLM, - GraniteMoeAttention: QEffGraniteMoeAttention, - GraniteMoeRotaryEmbedding: QEffGraniteMoeRotaryEmbedding, - GraniteMoeParallelExperts: QEffGraniteMoeParallelExperts, - GraniteMoeTopKGating: QEffGraniteMoeTopKGating, - GraniteMoeMoE: QEffGraniteMoeMoE, - # mllama - MllamaTextRMSNorm: CustomRMSNormAIC, - MllamaTextSelfAttention: QEffMllamaTextSelfAttention, - MllamaSelfAttentionDecoderLayer: QEffMllamaSelfAttentionDecoderLayer, - MllamaCrossAttentionDecoderLayer: QEffMllamaCrossAttentionDecoderLayer, - MllamaRotaryEmbedding: QEffMllamaRotaryEmbedding, - MllamaVisionModel: QEffMllamaVisionModel, - MllamaTextModel: QEffMllamaTextModel, - MllamaForCausalLM: QEffMllamaForCausalLM, - MllamaForConditionalGeneration: QEffMllamaForConditionalGeneration, - # Mistral - MistralAttention: QEffMistralAttention, - MistralDecoderLayer: QEffMistralDecoderLayer, - MistralModel: QEffMistralModel, - MistralForCausalLM: QEffMistralForCausalLM, - # Mixtral - MixtralAttention: QEffMixtralAttention, - MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock, - MixtralDecoderLayer: QeffMixtralDecoderLayer, - MixtralModel: QEffMixtralModel, - MixtralForCausalLM: QEffMixtralForCausalLM, - # Mpt - MptAttention: QEffMptAttention, - MptBlock: QEffMptBlock, - MptModel: QEFfMptModel, - MptForCausalLM: QEffMptForCausalLM, - # Phi3 - Phi3Attention: 
QEffPhi3Attention, - Phi3DecoderLayer: QEffPhi3DecoderLayer, - Phi3Model: QEffPhi3Model, - Phi3ForCausalLM: QEffPhi3ForCausalLM, - # Phi - PhiAttention: QEffPhiAttention, - PhiDecoderLayer: QEffPhiDecoderLayer, - PhiModel: QEffPhiModel, - PhiForCausalLM: QEffPhiForCausalLM, - # Qwen2 - Qwen2Attention: QEffQwen2Attention, - Qwen2DecoderLayer: QEffQwen2DecoderLayer, - Qwen2Model: QEffQwen2Model, - Qwen2ForCausalLM: QEffQwen2ForCausalLM, - # Starcoder2 - Starcoder2Attention: QEffStarcoder2Attention, - Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer, - Starcoder2Model: QEffStarcoder2Model, - Starcoder2ForCausalLM: QEffStarcoder2ForCausalLM, - # GptBigcode - GPTBigCodeAttention: QEffGPTBigCodeAttention, - GPTBigCodeBlock: QEffGPTBigCodeBlock, - GPTBigCodeModel: QEffGPTBigCodeModel, - GPTBigCodeForCausalLM: QEffGPTBigCodeForCausalLM, - # Whisper encoder and decoder layers - WhisperPositionalEmbedding: QEffWhisperPositionalEmbedding, - WhisperAttention: QEffWhisperAttention, - WhisperDecoderLayer: QEffWhisperDecoderLayer, - WhisperEncoder: QEffWhisperEncoder, - WhisperDecoder: QEffWhisperDecoder, - WhisperModel: QEffWhisperModel, - WhisperForConditionalGeneration: QEffWhisperForConditionalGeneration, - } - - @classmethod - def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: - model, transformed = super().apply(model) - return model, transformed - - -class SpDTransform: - """ - Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill. - This is only needed if user is exporting Target Language Model (TLM) for Speculative Decoding to validate output logits - against the speculated tokens from a smaller model. - Other than the computed logits, there should be no difference between the SpD Transformed model and its corresponding cunterpart. - - ``Mandatory`` Args: - :model (nn.Module): PyTorch model. - - Returns: - :model (nn.Module): PyTorch model. - :transformed (bool): whether transformation was applied successfully. - """ - - # supported architectures - _module_mapping = { - # Llama - QEffLlamaForCausalLM, - QEffQwen2ForCausalLM, - } - - @classmethod - def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]: - transformed = False - pretrained_model_name_or_path_temp = kwargs.pop("pretrained_model_name_or_path", None) - if qaic_config is None or (speculative_model_type := qaic_config.get("speculative_model_type")) is None: - return model, transformed - elif speculative_model_type not in ( - supported_spd_model_types := [SPD_TARGET] + list(model_type_registry.keys()) - ): - raise ValueError( - f"Specualtive model type {speculative_model_type} is not supported. we currently only support {supported_spd_model_types}" - ) - elif (model_class := model.__class__) in cls._module_mapping: - model.forward = MethodType(tlm_forward, model) - if speculative_model_type != SPD_TARGET: - # build and attach draft mlp - pretrained_model_name_or_path = qaic_config["pretrained_model_name_or_path"] - model = build_and_attach_mlp( - model, pretrained_model_name_or_path, speculative_model_type=speculative_model_type, **kwargs - ) - transformed = True - else: - raise NotImplementedError( - f"model class {model_class} does not yet support returning multiple logits to keep." 
- ) - kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path_temp - return model, transformed - - -class SamplerTransform: - """ - Add nodes at the output of any generic QEffForCausalLM model to enable the - sampling of next tokens at the device (instead of the host) and return the - next tokens and/or probability distributions. - - Note: To achieve this, the generic QEffForCausalLM model must provide the - logits as output. - - ``Mandatory`` Args: - :model (nn.Module): PyTorch model. - - Returns: - :model (nn.Module): PyTorch model. - :transformed (bool): whether transformation was applied successfully. - """ - - # supported architectures - _module_mapping = { - # Llama - QEffLlamaForCausalLM, - } - - @classmethod - def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]: - transformed = False - if qaic_config is None or not qaic_config.get("include_sampler", False): - return model, transformed - elif (model_class := model.__class__) in cls._module_mapping: - model.old_forward = model.forward - model.forward = MethodType(sampler_forward, model) - transformed = True - else: - raise NotImplementedError(f"Model class {model_class} does not support on device sampling.") - return model, transformed - - -class VlmKVOffloadTransform(ModuleMappingTransform): - # supported architectures - _module_mapping = { - # Llama - MllamaTextCrossAttention: QEffMllamaTextCrossAttentionTwoQPC, - } - - -class VlmNoKVOffloadTransform(ModuleMappingTransform): - # supported architectures - _module_mapping = { - # Llama - MllamaTextCrossAttention: QEffMllamaTextCrossAttentionSingleQPC, - } - - -class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): - _match_string_replace_method = { - "InternVLChatModel": { - "forward": QEffInternVLModel.forward, - "get_dummy_inputs": QEffInternVLModel.get_dummy_inputs, - "get_specializations": QEffInternVLModel.get_specializations, - "get_onnx_dynamic_axes": QEffInternVLModel.get_onnx_dynamic_axes, - "get_output_names": QEffInternVLModel.get_output_names, - "get_inputs_info": QEffInternVLModel.get_inputs_info, - "get_qeff_vision_encoder": QEffInternVLModel.get_qeff_vision_encoder, - "get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder, - }, - "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward}, - # Mapping for grok1 model - "Grok1ModelForCausalLM": {"forward": QEffGrok1ModelForCausalLM.forward}, - "Grok1Model": { - "forward": QEffGrok1Model.forward, - "__qeff_init__": QEffGrok1Model.__qeff_init__, - }, - "DecoderLayer": { - "forward": QEffGrok1DecoderLayer.forward, - "__qeff_init__": QEffGrok1DecoderLayer.__qeff_init__, - }, - "MoeBlock": {"forward": QEffGrok1MoeBlock.forward}, - "MultiHeadAttention": { - "forward": QEffGrok1MultiHeadAttention.forward, - }, - "RMSNorm": { - "forward": QEFFGrok1CustomRMSNormAIC.forward, - }, - } - - _match_class_replace_method = {} - - -class PoolingTransform: - """ - Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, allowing for the reduction of spatial dimensions in the output. - The pooling layer can be configured to use different pooling methods, such as max pooling or average pooling. 
- """ - - @classmethod - def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Module, bool]: - transformed = False - pooling_method = ( - POOLING_MAP[pooling] - if isinstance(pooling, str) and pooling in POOLING_MAP - else validate_user_pooling_function(pooling) - ) - model = PooledModel(model, pooling_method) - warnings.warn("Pooling is applied to the model.") - return model, transformed +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import warnings +from types import MethodType +from typing import Callable, Optional, Tuple, Union + +from torch import nn +from transformers.models.codegen.modeling_codegen import ( + CodeGenAttention, + CodeGenBlock, + CodeGenForCausalLM, + CodeGenModel, +) +from transformers.models.falcon.modeling_falcon import ( + FalconAttention, + FalconDecoderLayer, + FalconForCausalLM, + FalconModel, +) +from transformers.models.gemma.modeling_gemma import ( + GemmaAttention, + GemmaDecoderLayer, + GemmaForCausalLM, + GemmaModel, + GemmaRMSNorm, +) +from transformers.models.gemma2.modeling_gemma2 import ( + Gemma2Attention, + Gemma2DecoderLayer, + Gemma2ForCausalLM, + Gemma2Model, + Gemma2RMSNorm, +) +from transformers.models.gemma3.modeling_gemma3 import ( + Gemma3Attention, + Gemma3DecoderLayer, + Gemma3ForCausalLM, + Gemma3ForConditionalGeneration, + Gemma3RMSNorm, + Gemma3TextModel, +) +from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model +from transformers.models.gpt_bigcode.modeling_gpt_bigcode import ( + GPTBigCodeAttention, + GPTBigCodeBlock, + GPTBigCodeForCausalLM, + GPTBigCodeModel, +) +from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJForCausalLM, GPTJModel +from transformers.models.granite.modeling_granite import ( + GraniteAttention, + GraniteForCausalLM, + GraniteModel, + GraniteRMSNorm, +) +from transformers.models.granitemoe.modeling_granitemoe import ( + GraniteMoeAttention, + GraniteMoeForCausalLM, + GraniteMoeModel, + GraniteMoeMoE, + GraniteMoeParallelExperts, + GraniteMoeRMSNorm, + GraniteMoeRotaryEmbedding, + GraniteMoeTopKGating, +) +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) +from transformers.models.llama4.modeling_llama4 import ( + Llama4ForCausalLM, + Llama4ForConditionalGeneration, + Llama4Router, + Llama4TextAttention, + Llama4TextDecoderLayer, + Llama4TextExperts, + Llama4TextModel, + Llama4TextMoe, + Llama4TextRMSNorm, + Llama4VisionAttention, + Llama4VisionModel, +) +from transformers.models.llava.modeling_llava import ( + LlavaForConditionalGeneration, +) +from transformers.models.llava_next.modeling_llava_next import ( + LlavaNextForConditionalGeneration, +) +from transformers.models.mistral.modeling_mistral import ( + MistralAttention, + MistralDecoderLayer, + MistralForCausalLM, + MistralModel, + MistralRMSNorm, +) +from transformers.models.mixtral.modeling_mixtral import ( + MixtralAttention, + MixtralDecoderLayer, + MixtralForCausalLM, + MixtralModel, + MixtralRMSNorm, + MixtralSparseMoeBlock, +) +from transformers.models.mllama.modeling_mllama import ( + MllamaCrossAttentionDecoderLayer, + MllamaForCausalLM, + MllamaForConditionalGeneration, + MllamaModel, + 
MllamaRotaryEmbedding, + MllamaSelfAttentionDecoderLayer, + MllamaTextCrossAttention, + MllamaTextModel, + MllamaTextRMSNorm, + MllamaTextSelfAttention, + MllamaVisionModel, +) +from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel +from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel +from transformers.models.phi3.modeling_phi3 import ( + Phi3Attention, + Phi3DecoderLayer, + Phi3ForCausalLM, + Phi3Model, + Phi3RMSNorm, +) +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2Attention, + Qwen2DecoderLayer, + Qwen2ForCausalLM, + Qwen2Model, + Qwen2RMSNorm, +) +from transformers.models.qwen3_moe.modeling_qwen3_moe import ( + Qwen3MoeAttention, + Qwen3MoeDecoderLayer, + Qwen3MoeForCausalLM, + Qwen3MoeModel, + Qwen3MoeRMSNorm, + Qwen3MoeRotaryEmbedding, + Qwen3MoeSparseMoeBlock, +) +from transformers.models.starcoder2.modeling_starcoder2 import ( + Starcoder2Attention, + Starcoder2DecoderLayer, + Starcoder2ForCausalLM, + Starcoder2Model, +) +from transformers.models.whisper.modeling_whisper import ( + WhisperAttention, + WhisperDecoder, + WhisperDecoderLayer, + WhisperEncoder, + WhisperForConditionalGeneration, + WhisperModel, + WhisperPositionalEmbedding, +) + +from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform +from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC +from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function +from QEfficient.transformers.models.codegen.modeling_codegen import ( + QEffCodeGenAttention, + QEffCodeGenBlock, + QEffCodeGenForCausalLM, + QEffCodeGenModel, +) +from QEfficient.transformers.models.falcon.modeling_falcon import ( + QEffFalconAttention, + QEffFalconDecoderLayer, + QEffFalconForCausalLM, + QEffFalconModel, +) +from QEfficient.transformers.models.gemma.modeling_gemma import ( + QEffGemmaAttention, + QEffGemmaDecoderLayer, + QEffGemmaForCausalLM, + QEffGemmaModel, +) +from QEfficient.transformers.models.gemma2.modeling_gemma2 import ( + QEffGemma2Attention, + QEffGemma2DecoderLayer, + QEffGemma2ForCausalLM, + QEffGemma2Model, +) +from QEfficient.transformers.models.gemma3.modeling_gemma3 import ( + QEffGemma3Attention, + QEffGemma3CustomRMSNormAIC, + QEffGemma3DecoderLayer, + QEffGemma3ForCausalLMModel, + QEffGemma3ForConditionalGeneration, + QEffGemma3TextModel, +) +from QEfficient.transformers.models.gpt2.modeling_gpt2 import ( + QEffGPT2Attention, + QEffGPT2Block, + QEffGPT2LMHeadModel, + QEffGPT2Model, +) +from QEfficient.transformers.models.gpt_bigcode.modeling_gpt_bigcode import ( + QEffGPTBigCodeAttention, + QEffGPTBigCodeBlock, + QEffGPTBigCodeForCausalLM, + QEffGPTBigCodeModel, +) +from QEfficient.transformers.models.gptj.modeling_gptj import ( + QEffGPTJAttention, + QEffGPTJBlock, + QEffGPTJForCausalLM, + QEffGPTJModel, +) +from QEfficient.transformers.models.granite.modeling_granite import ( + QEffGraniteAttention, + QEffGraniteForCausalLM, + QEffGraniteModel, +) +from QEfficient.transformers.models.granitemoe.modeling_granitemoe import ( + QEffGraniteMoeAttention, + QEffGraniteMoeForCausalLM, + QEffGraniteMoeModel, + QEffGraniteMoeMoE, + QEffGraniteMoeParallelExperts, + QEffGraniteMoeRotaryEmbedding, + QEffGraniteMoeTopKGating, +) +from QEfficient.transformers.models.grok_1.modeling_grok1 import ( + QEFFGrok1CustomRMSNormAIC, + QEffGrok1DecoderLayer, + QEffGrok1Model, + QEffGrok1ModelForCausalLM, + 
QEffGrok1MoeBlock, + QEffGrok1MultiHeadAttention, +) +from QEfficient.transformers.models.internvl.modeling_internvl import ( + QEffInternVisionEmbeddings, + QEffInternVLModel, +) +from QEfficient.transformers.models.llama.modeling_llama import ( + QEffLlamaAttention, + QEffLlamaDecoderLayer, + QEffLlamaForCausalLM, + QEffLlamaModel, + QEffLlamaRotaryEmbedding, +) +from QEfficient.transformers.models.llama4.modeling_llama4 import ( + QEffLlama4ForCausalLM, + QEffLlama4ForConditionalGeneration, + QEffLlama4Router, + QEffLlama4TextAttention, + QEffLlama4TextDecoderLayer, + QEffLlama4TextExperts, + QEffLlama4TextModel, + QEffLlama4TextMoe, + QEffLlama4VisionAttention, + QEffLlama4VisionModel, +) +from QEfficient.transformers.models.llava.modeling_llava import ( + QEffLlavaForConditionalGeneration, +) +from QEfficient.transformers.models.llava_next.modeling_llava_next import ( + QEffLlavaNextForConditionalGeneration, +) +from QEfficient.transformers.models.mistral.modeling_mistral import ( + QEffMistralAttention, + QEffMistralDecoderLayer, + QEffMistralForCausalLM, + QEffMistralModel, +) +from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import ( + QEffMixtralAttention, + QeffMixtralDecoderLayer, + QEffMixtralForCausalLM, + QEffMixtralModel, + QEffMixtralSparseMoeBlock, +) +from QEfficient.transformers.models.mllama.modeling_mllama import ( + QEffMllamaCrossAttentionDecoderLayer, + QEffMllamaForCausalLM, + QEffMllamaForConditionalGeneration, + QEffMllamaModel, + QEffMllamaRotaryEmbedding, + QEffMllamaSelfAttentionDecoderLayer, + QEffMllamaTextCrossAttentionSingleQPC, + QEffMllamaTextCrossAttentionTwoQPC, + QEffMllamaTextModel, + QEffMllamaTextSelfAttention, + QEffMllamaVisionModel, +) +from QEfficient.transformers.models.mpt.modeling_mpt import ( + QEffMptAttention, + QEffMptBlock, + QEffMptForCausalLM, + QEFfMptModel, +) +from QEfficient.transformers.models.phi.modeling_phi import ( + QEffPhiAttention, + QEffPhiDecoderLayer, + QEffPhiForCausalLM, + QEffPhiModel, +) +from QEfficient.transformers.models.phi3.modeling_phi3 import ( + QEffPhi3Attention, + QEffPhi3DecoderLayer, + QEffPhi3ForCausalLM, + QEffPhi3Model, +) +from QEfficient.transformers.models.qwen2.modeling_qwen2 import ( + QEffQwen2Attention, + QEffQwen2DecoderLayer, + QEffQwen2ForCausalLM, + QEffQwen2Model, +) +from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import ( + QEffQwen3MoeAttention, + QEffQwen3MoeDecoderLayer, + QEffQwen3MoeForCausalLM, + QEffQwen3MoeModel, + QEffQwen3MoeRotaryEmbedding, + QEffQwen3MoeSparseMoeBlock, +) +from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import ( + QEffStarcoder2Attention, + QEFFStarcoder2DecoderLayer, + QEffStarcoder2ForCausalLM, + QEffStarcoder2Model, +) +from QEfficient.transformers.models.whisper.modeling_whisper import ( + QEffWhisperAttention, + QEffWhisperDecoder, + QEffWhisperDecoderLayer, + QEffWhisperEncoder, + QEffWhisperForConditionalGeneration, + QEffWhisperModel, + QEffWhisperPositionalEmbedding, +) +from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry +from QEfficient.transformers.sampler.sampler import sampler_forward +from QEfficient.transformers.spd.spd_transform_forward import tlm_forward + +SPD_TARGET = "target" + + +class CustomOpsTransform(ModuleMappingTransform): + _module_mapping = { + GemmaRMSNorm: GemmaCustomRMSNormAIC, + Gemma2RMSNorm: GemmaCustomRMSNormAIC, + LlamaRMSNorm: CustomRMSNormAIC, + Llama4TextRMSNorm: CustomRMSNormAIC, + MistralRMSNorm: CustomRMSNormAIC, + 
MixtralRMSNorm: CustomRMSNormAIC, + Phi3RMSNorm: CustomRMSNormAIC, + Qwen2RMSNorm: CustomRMSNormAIC, + MllamaTextRMSNorm: CustomRMSNormAIC, + GraniteRMSNorm: CustomRMSNormAIC, + GraniteMoeRMSNorm: CustomRMSNormAIC, + Qwen3MoeRMSNorm: CustomRMSNormAIC, + Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC, + } + + +class KVCacheTransform(ModuleMappingTransform): + _module_mapping = { + # CodeGen + CodeGenAttention: QEffCodeGenAttention, + CodeGenBlock: QEffCodeGenBlock, + CodeGenModel: QEffCodeGenModel, + CodeGenForCausalLM: QEffCodeGenForCausalLM, + # Falcon + FalconAttention: QEffFalconAttention, + FalconDecoderLayer: QEffFalconDecoderLayer, + FalconModel: QEffFalconModel, + FalconForCausalLM: QEffFalconForCausalLM, + # GPT2 + GPT2Attention: QEffGPT2Attention, + GPT2Block: QEffGPT2Block, + GPT2Model: QEffGPT2Model, + GPT2LMHeadModel: QEffGPT2LMHeadModel, + # GPTJ + GPTJAttention: QEffGPTJAttention, + GPTJBlock: QEffGPTJBlock, + GPTJModel: QEffGPTJModel, + GPTJForCausalLM: QEffGPTJForCausalLM, + # Llama + LlamaAttention: QEffLlamaAttention, + LlamaDecoderLayer: QEffLlamaDecoderLayer, + LlamaModel: QEffLlamaModel, + LlamaForCausalLM: QEffLlamaForCausalLM, + LlamaRotaryEmbedding: QEffLlamaRotaryEmbedding, + # Llama4 + Llama4TextAttention: QEffLlama4TextAttention, + Llama4ForCausalLM: QEffLlama4ForCausalLM, + Llama4TextDecoderLayer: QEffLlama4TextDecoderLayer, + Llama4TextModel: QEffLlama4TextModel, + Llama4TextMoe: QEffLlama4TextMoe, + Llama4ForConditionalGeneration: QEffLlama4ForConditionalGeneration, + Llama4VisionAttention: QEffLlama4VisionAttention, + Llama4VisionModel: QEffLlama4VisionModel, + Llama4TextExperts: QEffLlama4TextExperts, + Llama4Router: QEffLlama4Router, + # Llava + LlavaForConditionalGeneration: QEffLlavaForConditionalGeneration, + # Llava Next + LlavaNextForConditionalGeneration: QEffLlavaNextForConditionalGeneration, + # Gemma + GemmaAttention: QEffGemmaAttention, + GemmaDecoderLayer: QEffGemmaDecoderLayer, + GemmaModel: QEffGemmaModel, + GemmaForCausalLM: QEffGemmaForCausalLM, + # Qwen3Moe + Qwen3MoeForCausalLM: QEffQwen3MoeForCausalLM, + Qwen3MoeModel: QEffQwen3MoeModel, + Qwen3MoeDecoderLayer: QEffQwen3MoeDecoderLayer, + Qwen3MoeAttention: QEffQwen3MoeAttention, + Qwen3MoeRotaryEmbedding: QEffQwen3MoeRotaryEmbedding, + Qwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock, + # Gemma2 + Gemma2Attention: QEffGemma2Attention, + Gemma2DecoderLayer: QEffGemma2DecoderLayer, + Gemma2Model: QEffGemma2Model, + Gemma2ForCausalLM: QEffGemma2ForCausalLM, + # Gemma3 + Gemma3Attention: QEffGemma3Attention, + Gemma3DecoderLayer: QEffGemma3DecoderLayer, + Gemma3TextModel: QEffGemma3TextModel, + Gemma3ForCausalLM: QEffGemma3ForCausalLMModel, + Gemma3ForConditionalGeneration: QEffGemma3ForConditionalGeneration, + # Granite + GraniteModel: QEffGraniteModel, + GraniteForCausalLM: QEffGraniteForCausalLM, + GraniteAttention: QEffGraniteAttention, + # GraniteMoe + GraniteMoeModel: QEffGraniteMoeModel, + GraniteMoeForCausalLM: QEffGraniteMoeForCausalLM, + GraniteMoeAttention: QEffGraniteMoeAttention, + GraniteMoeRotaryEmbedding: QEffGraniteMoeRotaryEmbedding, + GraniteMoeParallelExperts: QEffGraniteMoeParallelExperts, + GraniteMoeTopKGating: QEffGraniteMoeTopKGating, + GraniteMoeMoE: QEffGraniteMoeMoE, + # mllama + MllamaTextRMSNorm: CustomRMSNormAIC, + MllamaTextSelfAttention: QEffMllamaTextSelfAttention, + MllamaSelfAttentionDecoderLayer: QEffMllamaSelfAttentionDecoderLayer, + MllamaModel: QEffMllamaModel, + MllamaCrossAttentionDecoderLayer: QEffMllamaCrossAttentionDecoderLayer, + 
MllamaRotaryEmbedding: QEffMllamaRotaryEmbedding,
+        MllamaVisionModel: QEffMllamaVisionModel,
+        MllamaTextModel: QEffMllamaTextModel,
+        MllamaForCausalLM: QEffMllamaForCausalLM,
+        MllamaForConditionalGeneration: QEffMllamaForConditionalGeneration,
+        # Mistral
+        MistralAttention: QEffMistralAttention,
+        MistralDecoderLayer: QEffMistralDecoderLayer,
+        MistralModel: QEffMistralModel,
+        MistralForCausalLM: QEffMistralForCausalLM,
+        # Mixtral
+        MixtralAttention: QEffMixtralAttention,
+        MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock,
+        MixtralDecoderLayer: QeffMixtralDecoderLayer,
+        MixtralModel: QEffMixtralModel,
+        MixtralForCausalLM: QEffMixtralForCausalLM,
+        # Mpt
+        MptAttention: QEffMptAttention,
+        MptBlock: QEffMptBlock,
+        MptModel: QEFfMptModel,
+        MptForCausalLM: QEffMptForCausalLM,
+        # Phi3
+        Phi3Attention: QEffPhi3Attention,
+        Phi3DecoderLayer: QEffPhi3DecoderLayer,
+        Phi3Model: QEffPhi3Model,
+        Phi3ForCausalLM: QEffPhi3ForCausalLM,
+        # Phi
+        PhiAttention: QEffPhiAttention,
+        PhiDecoderLayer: QEffPhiDecoderLayer,
+        PhiModel: QEffPhiModel,
+        PhiForCausalLM: QEffPhiForCausalLM,
+        # Qwen2
+        Qwen2Attention: QEffQwen2Attention,
+        Qwen2DecoderLayer: QEffQwen2DecoderLayer,
+        Qwen2Model: QEffQwen2Model,
+        Qwen2ForCausalLM: QEffQwen2ForCausalLM,
+        # Starcoder2
+        Starcoder2Attention: QEffStarcoder2Attention,
+        Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer,
+        Starcoder2Model: QEffStarcoder2Model,
+        Starcoder2ForCausalLM: QEffStarcoder2ForCausalLM,
+        # GptBigcode
+        GPTBigCodeAttention: QEffGPTBigCodeAttention,
+        GPTBigCodeBlock: QEffGPTBigCodeBlock,
+        GPTBigCodeModel: QEffGPTBigCodeModel,
+        GPTBigCodeForCausalLM: QEffGPTBigCodeForCausalLM,
+        # Whisper encoder and decoder layers
+        WhisperPositionalEmbedding: QEffWhisperPositionalEmbedding,
+        WhisperAttention: QEffWhisperAttention,
+        WhisperDecoderLayer: QEffWhisperDecoderLayer,
+        WhisperEncoder: QEffWhisperEncoder,
+        WhisperDecoder: QEffWhisperDecoder,
+        WhisperModel: QEffWhisperModel,
+        WhisperForConditionalGeneration: QEffWhisperForConditionalGeneration,
+    }
+
+    @classmethod
+    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+        model, transformed = super().apply(model)
+        return model, transformed
+
+
+class SpDTransform:
+    """
+    Apply the generic QEffForCausalLM forward pass so that `num_speculative_tokens+1` hidden states are extracted
+    before computing logits during the decode phase, and only the last predicted token is extracted during prefill.
+    This is only needed if the user is exporting a Target Language Model (TLM) for Speculative Decoding to validate
+    output logits against the speculated tokens from a smaller draft model.
+    Other than the computed logits, there should be no difference between the SpD-transformed model and its
+    corresponding counterpart.
+
+    ``Mandatory`` Args:
+        :model (nn.Module): PyTorch model.
+
+    Returns:
+        :model (nn.Module): PyTorch model.
+        :transformed (bool): whether transformation was applied successfully.
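+
+    Example (illustrative sketch only; assumes ``model`` is already a ``QEffLlamaForCausalLM`` or
+    ``QEffQwen2ForCausalLM`` instance, and ``"target"`` mirrors the ``SPD_TARGET`` constant defined in this module):
+
+    ```python
+    model, transformed = SpDTransform.apply(
+        model, qaic_config={"speculative_model_type": "target"}
+    )
+    ```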
+    """
+
+    # supported architectures
+    _module_mapping = {
+        # Llama, Qwen2
+        QEffLlamaForCausalLM,
+        QEffQwen2ForCausalLM,
+    }
+
+    @classmethod
+    def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]:
+        transformed = False
+        pretrained_model_name_or_path_temp = kwargs.pop("pretrained_model_name_or_path", None)
+        if qaic_config is None or (speculative_model_type := qaic_config.get("speculative_model_type")) is None:
+            return model, transformed
+        elif speculative_model_type not in (
+            supported_spd_model_types := [SPD_TARGET] + list(model_type_registry.keys())
+        ):
+            raise ValueError(
+                f"Speculative model type {speculative_model_type} is not supported. We currently only support {supported_spd_model_types}"
+            )
+        elif (model_class := model.__class__) in cls._module_mapping:
+            model.forward = MethodType(tlm_forward, model)
+            if speculative_model_type != SPD_TARGET:
+                # Build and attach the draft MLP
+                pretrained_model_name_or_path = qaic_config["pretrained_model_name_or_path"]
+                model = build_and_attach_mlp(
+                    model, pretrained_model_name_or_path, speculative_model_type=speculative_model_type, **kwargs
+                )
+            transformed = True
+        else:
+            raise NotImplementedError(
+                f"Model class {model_class} does not yet support returning multiple logits to keep."
+            )
+        kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path_temp
+        return model, transformed
+
+
+class SamplerTransform:
+    """
+    Add nodes at the output of any generic QEffForCausalLM model to enable
+    sampling of next tokens on the device (instead of the host) and return the
+    next tokens and/or probability distributions.
+
+    Note: To achieve this, the generic QEffForCausalLM model must provide the
+    logits as output.
+
+    ``Mandatory`` Args:
+        :model (nn.Module): PyTorch model.
+
+    Returns:
+        :model (nn.Module): PyTorch model.
+        :transformed (bool): whether transformation was applied successfully.
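+
+    Example (illustrative sketch only; assumes ``model`` is already a ``QEffLlamaForCausalLM`` instance):
+
+    ```python
+    model, transformed = SamplerTransform.apply(
+        model, qaic_config={"include_sampler": True}
+    )
+    # The original forward pass remains available as `model.old_forward`.
+    ```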
+    """
+
+    # supported architectures
+    _module_mapping = {
+        # Llama
+        QEffLlamaForCausalLM,
+    }
+
+    @classmethod
+    def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]:
+        transformed = False
+        if qaic_config is None or not qaic_config.get("include_sampler", False):
+            return model, transformed
+        elif (model_class := model.__class__) in cls._module_mapping:
+            model.old_forward = model.forward
+            model.forward = MethodType(sampler_forward, model)
+            transformed = True
+        else:
+            raise NotImplementedError(f"Model class {model_class} does not support on-device sampling.")
+        return model, transformed
+
+
+class VlmKVOffloadTransform(ModuleMappingTransform):
+    # supported architectures
+    _module_mapping = {
+        # Mllama
+        MllamaTextCrossAttention: QEffMllamaTextCrossAttentionTwoQPC,
+    }
+
+
+class VlmNoKVOffloadTransform(ModuleMappingTransform):
+    # supported architectures
+    _module_mapping = {
+        # Mllama
+        MllamaTextCrossAttention: QEffMllamaTextCrossAttentionSingleQPC,
+    }
+
+
+class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform):
+    _match_string_replace_method = {
+        "InternVLChatModel": {
+            "forward": QEffInternVLModel.forward,
+            "get_dummy_inputs": QEffInternVLModel.get_dummy_inputs,
+            "get_specializations": QEffInternVLModel.get_specializations,
+            "get_onnx_dynamic_axes": QEffInternVLModel.get_onnx_dynamic_axes,
+            "get_output_names": QEffInternVLModel.get_output_names,
+            "get_inputs_info": QEffInternVLModel.get_inputs_info,
+            "get_qeff_vision_encoder": QEffInternVLModel.get_qeff_vision_encoder,
+            "get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder,
+        },
+        "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward},
+        # Mapping for grok1 model
+        "Grok1ModelForCausalLM": {"forward": QEffGrok1ModelForCausalLM.forward},
+        "Grok1Model": {
+            "forward": QEffGrok1Model.forward,
+            "__qeff_init__": QEffGrok1Model.__qeff_init__,
+        },
+        "DecoderLayer": {
+            "forward": QEffGrok1DecoderLayer.forward,
+            "__qeff_init__": QEffGrok1DecoderLayer.__qeff_init__,
+        },
+        "MoeBlock": {"forward": QEffGrok1MoeBlock.forward},
+        "MultiHeadAttention": {
+            "forward": QEffGrok1MultiHeadAttention.forward,
+        },
+        "RMSNorm": {
+            "forward": QEFFGrok1CustomRMSNormAIC.forward,
+        },
+    }
+
+    _match_class_replace_method = {}
+
+
+class PoolingTransform:
+    """
+    Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, so the
+    output hidden states can be reduced along the sequence dimension into a single embedding per input.
+    The pooling layer can be configured to use different pooling methods, such as max pooling or average pooling.
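+
+    Example (illustrative sketch only; the checkpoint name is arbitrary, "mean" is assumed to be one of the keys
+    in ``POOLING_MAP``, and any user-provided callable is checked via ``validate_user_pooling_function``):
+
+    ```python
+    from transformers import AutoModel
+
+    model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+    pooled_model, _ = PoolingTransform.apply(model, pooling="mean")
+    ```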
+ """ + + @classmethod + def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Module, bool]: + transformed = False + pooling_method = ( + POOLING_MAP[pooling] + if isinstance(pooling, str) and pooling in POOLING_MAP + else validate_user_pooling_function(pooling) + ) + model = PooledModel(model, pooling_method) + warnings.warn("Pooling is applied to the model.") + return model, transformed diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 00a3989d8..24e8df46c 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -7,7 +7,7 @@ """PyTorch Qwen2 model.""" -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -118,7 +118,6 @@ def eager_attention_forward( value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, - **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -128,7 +127,6 @@ def eager_attention_forward( attn_weights = torch.where( attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights ) - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() @@ -163,18 +161,15 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward + attention_interface = eager_attention_forward attn_output, attn_weights = attention_interface( self, @@ -183,12 +178,12 @@ def forward( value_states, attention_mask, scaling=self.scaling, - **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + + return attn_output, attn_weights class QEffQwen2DecoderLayer(Qwen2DecoderLayer): @@ -206,7 +201,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -216,9 +210,6 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. 
- output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). @@ -235,13 +226,12 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, batch_index=batch_index, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -254,15 +244,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs + return hidden_states class QEffQwen2Model(Qwen2Model): @@ -282,20 +264,15 @@ def forward( batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" @@ -326,28 +303,21 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, batch_index=batch_index, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer @@ -357,13 +327,11 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, - attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() class QEffQwen2ForCausalLM(Qwen2ForCausalLM): @@ -382,21 +350,14 @@ def forward( past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - 
output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -406,17 +367,13 @@ def forward( batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, ) # Cast to INT32 to avoid issue while running in ONNXRT logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) - hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] - - logits = self.lm_head(hidden_states) - logits = logits.float() + hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states).float() return CausalLMOutputWithPast( loss=None, diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index bf3defc1a..591f7c1b0 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch import torch.nn.functional as F @@ -100,7 +100,6 @@ def eager_attention_forward( value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, - **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) @@ -212,19 +211,16 @@ def forward( query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward + attention_interface = eager_attention_forward attn_output, attn_weights = attention_interface( self, @@ -233,12 +229,11 @@ def 
forward( value_states, attention_mask, scaling=self.scaling, - **kwargs, ) - attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + + return attn_output, attn_weights class QEffQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer): @@ -248,12 +243,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - output_router_logits: Optional[bool] = False, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - # position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -261,21 +253,12 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_router_logits (`bool`, *optional*): - Whether or not to return the logits of all the routers. They are useful for computing the router loss, - and should not be returned during inference. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. 
kwargs (`dict`, *optional*): Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model @@ -286,13 +269,12 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, batch_index=batch_index, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) @@ -304,21 +286,11 @@ def forward( hidden_states = self.mlp(hidden_states) if isinstance(hidden_states, tuple): - hidden_states, router_logits = hidden_states - else: - router_logits = None + hidden_states, _ = hidden_states hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if output_router_logits: - outputs += (router_logits,) - - return outputs + return hidden_states class QEffQwen3MoeModel(Qwen3MoeModel): @@ -330,16 +302,10 @@ def forward( past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, batch_index: Optional[torch.LongTensor] = None, output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> MoeModelOutputWithPast: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -366,33 +332,21 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_router_logits = () if output_router_logits else None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, batch_index=batch_index, - output_attentions=output_attentions, - output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if output_router_logits: - all_router_logits += (layer_outputs[-1],) - hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer @@ -405,8 +359,6 @@ def forward( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, ) @@ -419,58 +371,16 @@ def forward( past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, batch_index: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> MoeCausalLMOutputWithPast: - r""" - labels 
(`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM - - >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) - output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: MoeModelOutputWithPast = self.model( + outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -478,18 +388,15 @@ def forward( inputs_embeds=inputs_embeds, batch_index=batch_index, use_cache=use_cache, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, - output_router_logits=output_router_logits, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) - hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_idx] - logits = self.lm_head(hidden_states) - logits = logits.float() + hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_idx] + logits = self.lm_head(hidden_states).float() return MoeCausalLMOutputWithPast( logits=logits, diff --git a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py index e3db4b490..9a327761d 100644 --- a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py +++ b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py @@ -7,7 +7,7 @@ """PyTorch Starcoder2 model.""" -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import nn @@ -37,8 +37,6 @@ def eager_attention_forward( value: torch.Tensor, 
attention_mask: Optional[torch.Tensor], scaling: float, - dropout: float = 0.0, - **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -48,10 +46,10 @@ def eager_attention_forward( attn_weights = torch.where( attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights ) - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -82,17 +80,14 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward + attention_interface = eager_attention_forward attn_output, attn_weights = attention_interface( self, @@ -101,13 +96,12 @@ def forward( value_states, attention_mask, scaling=self.scaling, - **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class QEFFStarcoder2DecoderLayer(Starcoder2DecoderLayer): @@ -125,7 +119,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, batch_index: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -136,9 +129,6 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
@@ -158,13 +148,12 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, batch_index=batch_index, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -178,15 +167,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs + return hidden_states class QEffStarcoder2Model(Starcoder2Model): @@ -206,19 +187,14 @@ def forward( batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -251,29 +227,22 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, batch_index=batch_index, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer @@ -283,13 +252,11 @@ def forward( if use_legacy_cache: past_key_values = past_key_values.to_legacy_cache() - output = BaseModelOutputWithPast( + return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, + past_key_values=past_key_values, hidden_states=all_hidden_states, - attentions=all_self_attns, ) - return output if return_dict else output.to_tuple() class QEffStarcoder2ForCausalLM(Starcoder2ForCausalLM): @@ -308,46 +275,14 @@ def forward( past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels 
(`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Starcoder2ForCausalLM - - >>> model = Starcoder2ForCausalLM.from_pretrained("meta-starcoder2/Starcoder2-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-starcoder2/Starcoder2-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -357,18 +292,15 @@ def forward( batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) # Cast to INT32 to avoid issue while running in ONNXRT logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) - hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] - logits = self.lm_head(hidden_states) - logits = logits.float() + hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states).float() return CausalLMOutputWithPast( loss=None, diff --git a/QEfficient/transformers/models/whisper/modeling_whisper.py b/QEfficient/transformers/models/whisper/modeling_whisper.py index afa2a6b07..e078493a7 100644 --- a/QEfficient/transformers/models/whisper/modeling_whisper.py +++ b/QEfficient/transformers/models/whisper/modeling_whisper.py @@ -63,18 +63,23 @@ def forward( is_cross_attention: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - # get query proj - query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz) + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(*q_input_shape) + query_states = query_states.transpose(1, 2).contiguous() if self.is_decoder: if is_cross_attention and past_key_value: # cross_attentions key_states_old = past_key_value[self.layer_idx][0] value_states_old = past_key_value[self.layer_idx][1] - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(bsz, 
-1, self.num_heads, self.head_dim) + value_states = self.v_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() indices = (torch.arange(bsz),) key_states_new = torch.index_put(key_states_old, indices, key_states) value_states_new = torch.index_put(value_states_old, indices, value_states) @@ -85,12 +90,14 @@ def forward( input_features.shape[2] == torch.tensor(1), value_states_old, value_states_new ) - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states + past_key_value.layers[self.layer_idx].keys = key_states + past_key_value.layers[self.layer_idx].values = value_states else: # self attention decoder - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() if past_key_value is not None: cache_kwargs = {"position_ids": position_ids_layer} key_states, value_states = past_key_value.update( @@ -98,8 +105,10 @@ def forward( ) else: # self_attention Encoder - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() src_len = key_states.size(2) @@ -150,7 +159,7 @@ def forward( attn_output = self.out_proj(attn_output) - return [attn_output, attn_weights, past_key_value] + return [attn_output, attn_weights] class QEffWhisperDecoderLayer(WhisperDecoderLayer): @@ -203,7 +212,7 @@ def forward( # Self Attention self_attn_past_key_value = past_key_value.self_attention_cache if past_key_value is not None else None - hidden_states, self_attn_weights, self_attn_present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, @@ -217,13 +226,12 @@ def forward( hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if is_encoder_decoder: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) cross_attn_past_key_value = past_key_value.cross_attention_cache if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -237,13 +245,6 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout) hidden_states = residual + hidden_states - # update the cached past_key_values accordingly - past_key_value.self_attention_cache = self_attn_present_key_value - past_key_value.cross_attention_cache = cross_attn_present_key_value - else: - # if no cross_attention, still need to update self_attn cache - past_key_value = self_attn_present_key_value - 
# Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 14cfb65ce..abe383556 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -382,7 +382,7 @@ def get_padding_shape_from_config(config, batch_size, seq_len): ): # Check for num_key_value_heads (Llama/Mistral) n_heads = config.num_key_value_heads - if hasattr(config, "head_dim"): + if hasattr(config, "head_dim") and config.head_dim is not None: d_head = config.head_dim else: d_head = config.hidden_size // config.num_attention_heads @@ -404,10 +404,16 @@ def get_padding_shape_from_config(config, batch_size, seq_len): d_head = config.hidden_size // config.num_attention_heads else: raise ValueError("Invalid model configuration: n_head/d_heads or num_key_value_heads not found.") - padding_shape = [batch_size, n_heads, seq_len, d_head] + if hasattr(config, "architectures") and config.architectures is not None: # Check for Starcoder1 - 3D layout if "GPTBigCodeForCausalLM" in config.architectures: - padding_shape = [batch_size, seq_len, d_head] + if hasattr(config, "multi_query"): + multi_query_value = getattr(config, "multi_query") + if multi_query_value: + n_heads = 1 # MQA , multi query is true + else: + n_heads = config.n_head + padding_shape = [batch_size, n_heads, seq_len, d_head] return padding_shape diff --git a/QEfficient/utils/test_utils.py b/QEfficient/utils/test_utils.py index 3d70ac4f3..9fc977154 100644 --- a/QEfficient/utils/test_utils.py +++ b/QEfficient/utils/test_utils.py @@ -165,7 +165,60 @@ class ModelConfig: } EXTERNAL_MODELS = { - "hpcai-tech/grok-1", + "hpcai-tech/grok-1": { + "pytorch_hf_tokens_custom_case": [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + "pytorch_hf_tokens_normal_case": [ + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + 391, + ], + } } SWIFTKV_MODELS = { diff --git a/examples/gemma3_example/fp32_nodes_gemma3_27b.yaml b/examples/gemma3_example/fp32_nodes_gemma3_27b.yaml index c57e846e1..d2a4bf164 100755 --- a/examples/gemma3_example/fp32_nodes_gemma3_27b.yaml +++ b/examples/gemma3_example/fp32_nodes_gemma3_27b.yaml @@ -1,685 +1,685 @@ FP32NodeInstanceNames: - - /language_model/model/layers.0/Add_1_output_0 - - /language_model/model/layers.0/Add_output_0 - - /language_model/model/layers.0/Add_2_output_0 - - /language_model/model/layers.0/Add_3_output_0 - - /language_model/model/layers.1/Add_1_output_0 - - /language_model/model/layers.1/Add_2_output_0 - - /language_model/model/layers.1/Add_3_output_0 - - /language_model/model/layers.1/Add_output_0 - - /language_model/model/layers.2/Add_1_output_0 - - /language_model/model/layers.2/Add_2_output_0 - - /language_model/model/layers.2/Add_3_output_0 - - /language_model/model/layers.2/Add_output_0 - - /language_model/model/layers.3/Add_1_output_0 - - /language_model/model/layers.3/Add_2_output_0 - - /language_model/model/layers.3/Add_3_output_0 - - /language_model/model/layers.3/Add_output_0 - - /language_model/model/layers.4/Add_1_output_0 - - /language_model/model/layers.4/Add_2_output_0 - - /language_model/model/layers.4/Add_3_output_0 - - /language_model/model/layers.4/Add_output_0 - - /language_model/model/layers.5/Add_1_output_0 - - /language_model/model/layers.5/Add_2_output_0 - - /language_model/model/layers.5/Add_3_output_0 - - 
/language_model/model/layers.5/Add_output_0 - - /language_model/model/layers.6/Add_1_output_0 - - /language_model/model/layers.6/Add_2_output_0 - - /language_model/model/layers.6/Add_3_output_0 - - /language_model/model/layers.6/Add_output_0 - - /language_model/model/layers.7/Add_1_output_0 - - /language_model/model/layers.7/Add_2_output_0 - - /language_model/model/layers.7/Add_3_output_0 - - /language_model/model/layers.7/Add_output_0 - - /language_model/model/layers.8/Add_1_output_0 - - /language_model/model/layers.8/Add_2_output_0 - - /language_model/model/layers.8/Add_3_output_0 - - /language_model/model/layers.8/Add_output_0 - - /language_model/model/layers.9/Add_1_output_0 - - /language_model/model/layers.9/Add_2_output_0 - - /language_model/model/layers.9/Add_3_output_0 - - /language_model/model/layers.9/Add_output_0 - - /language_model/model/layers.10/Add_1_output_0 - - /language_model/model/layers.10/Add_2_output_0 - - /language_model/model/layers.10/Add_3_output_0 - - /language_model/model/layers.10/Add_output_0 - - /language_model/model/layers.11/Add_1_output_0 - - /language_model/model/layers.11/Add_2_output_0 - - /language_model/model/layers.11/Add_3_output_0 - - /language_model/model/layers.11/Add_output_0 - - /language_model/model/layers.12/Add_1_output_0 - - /language_model/model/layers.12/Add_2_output_0 - - /language_model/model/layers.12/Add_3_output_0 - - /language_model/model/layers.12/Add_output_0 - - /language_model/model/layers.13/Add_1_output_0 - - /language_model/model/layers.13/Add_2_output_0 - - /language_model/model/layers.13/Add_3_output_0 - - /language_model/model/layers.13/Add_output_0 - - /language_model/model/layers.14/Add_1_output_0 - - /language_model/model/layers.14/Add_2_output_0 - - /language_model/model/layers.14/Add_3_output_0 - - /language_model/model/layers.14/Add_output_0 - - /language_model/model/layers.15/Add_1_output_0 - - /language_model/model/layers.15/Add_2_output_0 - - /language_model/model/layers.15/Add_3_output_0 - - /language_model/model/layers.15/Add_output_0 - - /language_model/model/layers.16/Add_1_output_0 - - /language_model/model/layers.16/Add_2_output_0 - - /language_model/model/layers.16/Add_3_output_0 - - /language_model/model/layers.16/Add_output_0 - - /language_model/model/layers.17/Add_1_output_0 - - /language_model/model/layers.17/Add_2_output_0 - - /language_model/model/layers.17/Add_3_output_0 - - /language_model/model/layers.17/Add_output_0 - - /language_model/model/layers.18/Add_1_output_0 - - /language_model/model/layers.18/Add_2_output_0 - - /language_model/model/layers.18/Add_3_output_0 - - /language_model/model/layers.18/Add_output_0 - - /language_model/model/layers.19/Add_1_output_0 - - /language_model/model/layers.19/Add_2_output_0 - - /language_model/model/layers.19/Add_3_output_0 - - /language_model/model/layers.19/Add_output_0 - - /language_model/model/layers.20/Add_1_output_0 - - /language_model/model/layers.20/Add_2_output_0 - - /language_model/model/layers.20/Add_3_output_0 - - /language_model/model/layers.20/Add_output_0 - - /language_model/model/layers.21/Add_1_output_0 - - /language_model/model/layers.21/Add_2_output_0 - - /language_model/model/layers.21/Add_3_output_0 - - /language_model/model/layers.21/Add_output_0 - - /language_model/model/layers.22/Add_1_output_0 - - /language_model/model/layers.22/Add_2_output_0 - - /language_model/model/layers.22/Add_3_output_0 - - /language_model/model/layers.22/Add_output_0 - - /language_model/model/layers.23/Add_1_output_0 - - 
/language_model/model/layers.23/Add_2_output_0 - - /language_model/model/layers.23/Add_output_0 - - /language_model/model/layers.24/Add_1_output_0 - - /language_model/model/layers.24/Add_2_output_0 - - /language_model/model/layers.24/Add_3_output_0 - - /language_model/model/layers.24/Add_output_0 - - /language_model/model/layers.25/Add_1_output_0 - - /language_model/model/layers.25/Add_2_output_0 - - /language_model/model/layers.25/Add_3_output_0 - - /language_model/model/layers.25/Add_output_0 - - /language_model/model/layers.26/Add_1_output_0 - - /language_model/model/layers.26/Add_2_output_0 - - /language_model/model/layers.26/Add_3_output_0 - - /language_model/model/layers.26/Add_output_0 - - /language_model/model/layers.27/Add_1_output_0 - - /language_model/model/layers.27/Add_2_output_0 - - /language_model/model/layers.27/Add_3_output_0 - - /language_model/model/layers.27/Add_output_0 - - /language_model/model/layers.28/Add_1_output_0 - - /language_model/model/layers.28/Add_2_output_0 - - /language_model/model/layers.28/Add_3_output_0 - - /language_model/model/layers.28/Add_output_0 - - /language_model/model/layers.29/Add_1_output_0 - - /language_model/model/layers.29/Add_2_output_0 - - /language_model/model/layers.29/Add_3_output_0 - - /language_model/model/layers.29/Add_output_0 - - /language_model/model/layers.30/Add_1_output_0 - - /language_model/model/layers.30/Add_2_output_0 - - /language_model/model/layers.30/Add_3_output_0 - - /language_model/model/layers.30/Add_output_0 - - /language_model/model/layers.31/Add_1_output_0 - - /language_model/model/layers.31/Add_2_output_0 - - /language_model/model/layers.31/Add_3_output_0 - - /language_model/model/layers.31/Add_output_0 - - /language_model/model/layers.32/Add_1_output_0 - - /language_model/model/layers.32/Add_2_output_0 - - /language_model/model/layers.32/Add_3_output_0 - - /language_model/model/layers.32/Add_output_0 - - /language_model/model/layers.33/Add_1_output_0 - - /language_model/model/layers.33/Add_2_output_0 - - /language_model/model/layers.33/Add_3_output_0 - - /language_model/model/layers.33/Add_output_0 - - /language_model/model/layers.34/Add_1_output_0 - - /language_model/model/layers.34/Add_2_output_0 - - /language_model/model/layers.34/Add_3_output_0 - - /language_model/model/layers.34/Add_output_0 - - /language_model/model/layers.35/Add_1_output_0 - - /language_model/model/layers.35/Add_2_output_0 - - /language_model/model/layers.35/Add_3_output_0 - - /language_model/model/layers.35/Add_output_0 - - /language_model/model/layers.36/Add_1_output_0 - - /language_model/model/layers.36/Add_2_output_0 - - /language_model/model/layers.36/Add_3_output_0 - - /language_model/model/layers.36/Add_output_0 - - /language_model/model/layers.37/Add_1_output_0 - - /language_model/model/layers.37/Add_2_output_0 - - /language_model/model/layers.37/Add_3_output_0 - - /language_model/model/layers.37/Add_output_0 - - /language_model/model/layers.38/Add_1_output_0 - - /language_model/model/layers.38/Add_2_output_0 - - /language_model/model/layers.38/Add_3_output_0 - - /language_model/model/layers.38/Add_output_0 - - /language_model/model/layers.39/Add_1_output_0 - - /language_model/model/layers.39/Add_2_output_0 - - /language_model/model/layers.39/Add_3_output_0 - - /language_model/model/layers.39/Add_output_0 - - /language_model/model/layers.40/Add_1_output_0 - - /language_model/model/layers.40/Add_2_output_0 - - /language_model/model/layers.40/Add_3_output_0 - - /language_model/model/layers.40/Add_output_0 - - 
/language_model/model/layers.41/Add_1_output_0 - - /language_model/model/layers.41/Add_2_output_0 - - /language_model/model/layers.41/Add_3_output_0 - - /language_model/model/layers.41/Add_output_0 - - /language_model/model/layers.42/Add_1_output_0 - - /language_model/model/layers.42/Add_2_output_0 - - /language_model/model/layers.42/Add_3_output_0 - - /language_model/model/layers.42/Add_output_0 - - /language_model/model/layers.43/Add_1_output_0 - - /language_model/model/layers.43/Add_2_output_0 - - /language_model/model/layers.43/Add_3_output_0 - - /language_model/model/layers.43/Add_output_0 - - /language_model/model/layers.44/Add_1_output_0 - - /language_model/model/layers.44/Add_2_output_0 - - /language_model/model/layers.44/Add_3_output_0 - - /language_model/model/layers.44/Add_output_0 - - /language_model/model/layers.45/Add_1_output_0 - - /language_model/model/layers.45/Add_2_output_0 - - /language_model/model/layers.45/Add_3_output_0 - - /language_model/model/layers.45/Add_output_0 - - /language_model/model/layers.46/Add_1_output_0 - - /language_model/model/layers.46/Add_2_output_0 - - /language_model/model/layers.46/Add_3_output_0 - - /language_model/model/layers.46/Add_output_0 - - /language_model/model/layers.47/Add_1_output_0 - - /language_model/model/layers.47/Add_2_output_0 - - /language_model/model/layers.47/Add_3_output_0 - - /language_model/model/layers.47/Add_output_0 - - /language_model/model/layers.48/Add_1_output_0 - - /language_model/model/layers.48/Add_2_output_0 - - /language_model/model/layers.48/Add_3_output_0 - - /language_model/model/layers.48/Add_output_0 - - /language_model/model/layers.49/Add_1_output_0 - - /language_model/model/layers.49/Add_2_output_0 - - /language_model/model/layers.49/Add_3_output_0 - - /language_model/model/layers.49/Add_output_0 - - /language_model/model/layers.50/Add_1_output_0 - - /language_model/model/layers.50/Add_2_output_0 - - /language_model/model/layers.50/Add_3_output_0 - - /language_model/model/layers.50/Add_output_0 - - /language_model/model/layers.51/Add_1_output_0 - - /language_model/model/layers.51/Add_2_output_0 - - /language_model/model/layers.51/Add_3_output_0 - - /language_model/model/layers.51/Add_output_0 - - /language_model/model/layers.52/Add_1_output_0 - - /language_model/model/layers.52/Add_2_output_0 - - /language_model/model/layers.52/Add_3_output_0 - - /language_model/model/layers.52/Add_output_0 - - /language_model/model/layers.53/Add_1_output_0 - - /language_model/model/layers.53/Add_2_output_0 - - /language_model/model/layers.53/Add_3_output_0 - - /language_model/model/layers.53/Add_output_0 - - /language_model/model/layers.54/Add_1_output_0 - - /language_model/model/layers.54/Add_2_output_0 - - /language_model/model/layers.54/Add_3_output_0 - - /language_model/model/layers.54/Add_output_0 - - /language_model/model/layers.55/Add_1_output_0 - - /language_model/model/layers.55/Add_2_output_0 - - /language_model/model/layers.55/Add_3_output_0 - - /language_model/model/layers.55/Add_output_0 - - /language_model/model/layers.56/Add_1_output_0 - - /language_model/model/layers.56/Add_2_output_0 - - /language_model/model/layers.56/Add_3_output_0 - - /language_model/model/layers.56/Add_output_0 - - /language_model/model/layers.57/Add_1_output_0 - - /language_model/model/layers.57/Add_2_output_0 - - /language_model/model/layers.57/Add_3_output_0 - - /language_model/model/layers.57/Add_output_0 - - /language_model/model/layers.58/Add_1_output_0 - - /language_model/model/layers.58/Add_2_output_0 - - 
/language_model/model/layers.58/Add_3_output_0 - - /language_model/model/layers.58/Add_output_0 - - /language_model/model/layers.59/Add_1_output_0 - - /language_model/model/layers.59/Add_2_output_0 - - /language_model/model/layers.59/Add_3_output_0 - - /language_model/model/layers.59/Add_output_0 - - /language_model/model/layers.60/Add_1_output_0 - - /language_model/model/layers.60/Add_2_output_0 - - /language_model/model/layers.60/Add_3_output_0 - - /language_model/model/layers.60/Add_output_0 - - /language_model/model/layers.61/Add_1_output_0 - - /language_model/model/layers.61/Add_2_output_0 - - /language_model/model/layers.61/Add_3_output_0 - - /language_model/model/layers.61/Add_output_0 - - /language_model/model/norm/Add_output_0 - - /language_model/model/layers.0/self_attn/Mul_output_0 - - /language_model/model/layers.2/self_attn/Mul_output_0 - - /language_model/model/layers.3/self_attn/Mul_output_0 - - /language_model/model/layers.4/self_attn/Mul_output_0 - - /language_model/model/layers.5/self_attn/Mul_output_0 - - /language_model/model/layers.6/self_attn/Mul_output_0 - - /language_model/model/layers.7/self_attn/Mul_output_0 - - /language_model/model/layers.8/self_attn/Mul_output_0 - - /language_model/model/layers.9/self_attn/Mul_output_0 - - /language_model/model/layers.10/self_attn/Mul_output_0 - - /language_model/model/layers.11/self_attn/Mul_output_0 - - /language_model/model/layers.12/self_attn/Mul_output_0 - - /language_model/model/layers.13/self_attn/Mul_output_0 - - /language_model/model/layers.14/self_attn/Mul_output_0 - - /language_model/model/layers.15/self_attn/Mul_output_0 - - /language_model/model/layers.16/self_attn/Mul_output_0 - - /language_model/model/layers.17/self_attn/Mul_output_0 - - /language_model/model/layers.18/self_attn/Mul_output_0 - - /language_model/model/layers.19/self_attn/Mul_output_0 - - /language_model/model/layers.20/self_attn/Mul_output_0 - - /language_model/model/layers.21/self_attn/Mul_output_0 - - /language_model/model/layers.22/self_attn/Mul_output_0 - - /language_model/model/layers.23/self_attn/Mul_output_0 - - /language_model/model/layers.24/self_attn/Mul_output_0 - - /language_model/model/layers.25/self_attn/Mul_output_0 - - /language_model/model/layers.26/self_attn/Mul_output_0 - - /language_model/model/layers.27/self_attn/Mul_output_0 - - /language_model/model/layers.28/self_attn/Mul_output_0 - - /language_model/model/layers.29/self_attn/Mul_output_0 - - /language_model/model/layers.30/self_attn/Mul_output_0 - - /language_model/model/layers.31/self_attn/Mul_output_0 - - /language_model/model/layers.32/self_attn/Mul_output_0 - - /language_model/model/layers.33/self_attn/Mul_output_0 - - /language_model/model/layers.34/self_attn/Mul_output_0 - - /language_model/model/layers.35/self_attn/Mul_output_0 - - /language_model/model/layers.36/self_attn/Mul_output_0 - - /language_model/model/layers.37/self_attn/Mul_output_0 - - /language_model/model/layers.38/self_attn/Mul_output_0 - - /language_model/model/layers.39/self_attn/Mul_output_0 - - /language_model/model/layers.40/self_attn/Mul_output_0 - - /language_model/model/layers.41/self_attn/Mul_output_0 - - /language_model/model/layers.42/self_attn/Mul_output_0 - - /language_model/model/layers.43/self_attn/Mul_output_0 - - /language_model/model/layers.44/self_attn/Mul_output_0 - - /language_model/model/layers.45/self_attn/Mul_output_0 - - /language_model/model/layers.46/self_attn/Mul_output_0 - - /language_model/model/layers.47/self_attn/Mul_output_0 - - 
/language_model/model/layers.48/self_attn/Mul_output_0 - - /language_model/model/layers.49/self_attn/Mul_output_0 - - /language_model/model/layers.50/self_attn/Mul_output_0 - - /language_model/model/layers.51/self_attn/Mul_output_0 - - /language_model/model/layers.52/self_attn/Mul_output_0 - - /language_model/model/layers.53/self_attn/Mul_output_0 - - /language_model/model/layers.54/self_attn/Mul_output_0 - - /language_model/model/layers.55/self_attn/Mul_output_0 - - /language_model/model/layers.56/self_attn/Mul_output_0 - - /language_model/model/layers.57/self_attn/Mul_output_0 - - /language_model/model/layers.58/self_attn/Mul_output_0 - - /language_model/model/layers.59/self_attn/Mul_output_0 - - /language_model/model/layers.60/self_attn/Mul_output_0 - - /language_model/model/layers.61/self_attn/Mul_output_0 - - /language_model/model/layers.0/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.0/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.0/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.0/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.0/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.1/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.1/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.1/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.1/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.1/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.2/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.2/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.2/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.2/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.2/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.3/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.3/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.3/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.3/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.3/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.4/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.4/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.4/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.4/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.4/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.5/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.5/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.5/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - 
/language_model/model/layers.5/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.5/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.6/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.6/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.6/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.6/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.6/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.7/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.7/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.7/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.7/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.7/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.8/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.8/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.8/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.8/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.8/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.9/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.9/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.9/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.9/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.9/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.10/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.10/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.10/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.10/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.10/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.11/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.11/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.11/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.11/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.11/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.12/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.12/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.12/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.12/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.12/self_attn/q_norm/CustomRMSNorm_output_0 - - 
/language_model/model/layers.13/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.13/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.13/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.13/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.13/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.14/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.14/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.14/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.14/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.14/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.15/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.15/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.15/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.15/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.15/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.16/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.16/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.16/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.16/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.16/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.17/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.17/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.17/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.17/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.17/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.18/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.18/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.18/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.18/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.18/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.19/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.19/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.19/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.19/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.19/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.20/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 - - 
/language_model/model/layers.20/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.20/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.20/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.20/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.21/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.21/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.21/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.21/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.21/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.22/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.22/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.22/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.22/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.22/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.23/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.23/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.23/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.23/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.23/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.24/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.24/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.24/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.24/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.24/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.25/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.25/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.25/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.25/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.25/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.26/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.26/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.26/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.26/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.26/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.27/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.27/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.27/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - 
/language_model/model/layers.27/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.27/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.28/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.28/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.28/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.28/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.28/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.29/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.29/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.29/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.29/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.29/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.30/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.30/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.30/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.30/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.30/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.31/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.31/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.31/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.31/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.31/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.32/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.32/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.32/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.32/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.32/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.33/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.33/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.33/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.33/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.33/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.34/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.34/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.34/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.34/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.34/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.34/self_attn/q_norm/CustomRMSNorm_output_0 - - 
/language_model/model/layers.35/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.35/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.35/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.35/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.35/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.35/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.36/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.36/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.36/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.36/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.36/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.36/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.37/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.37/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.37/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.37/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.37/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.37/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.38/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.38/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.38/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.38/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.38/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.38/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.39/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.39/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.39/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.39/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.39/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.39/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.40/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.40/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.40/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.40/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.40/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.40/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.41/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.41/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.41/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.41/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.41/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.41/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.42/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.42/post_attention_layernorm/CustomRMSNorm_output_0 - - 
/language_model/model/layers.42/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.42/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.42/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.42/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.43/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.43/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.43/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.43/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.43/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.43/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.44/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.44/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.44/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.44/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.44/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.44/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.45/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.45/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.45/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.45/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.45/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.45/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.46/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.46/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.46/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.46/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.46/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.46/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.47/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.47/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.47/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.47/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.47/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.47/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.48/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.48/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.48/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.48/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.48/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.48/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.49/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.49/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.49/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.49/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - 
/language_model/model/layers.49/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.49/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.50/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.50/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.50/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.50/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.50/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.50/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.51/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.51/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.51/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.51/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.51/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.51/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.52/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.52/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.52/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.52/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.52/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.52/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.53/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.53/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.53/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.53/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.53/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.53/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.54/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.54/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.54/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.54/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.54/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.54/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.55/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.55/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.55/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.55/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.55/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.55/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.56/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.56/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.56/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.56/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.56/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.56/self_attn/q_norm/CustomRMSNorm_output_0 - - 
/language_model/model/layers.57/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.57/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.57/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.57/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.57/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.57/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.58/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.58/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.58/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.58/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.58/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.58/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.59/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.59/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.59/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.59/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.59/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.59/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.60/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.60/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.60/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.60/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.60/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.60/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.61/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.61/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.61/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.61/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.61/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.61/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/norm/CustomRMSNorm_output_0 + - /language_model/layers.0/Add_1_output_0 + - /language_model/layers.0/Add_2_output_0 + - /language_model/layers.0/Add_3_output_0 + - /language_model/layers.0/Add_output_0 + - /language_model/layers.1/Add_1_output_0 + - /language_model/layers.1/Add_2_output_0 + - /language_model/layers.1/Add_3_output_0 + - /language_model/layers.1/Add_output_0 + - /language_model/layers.2/Add_1_output_0 + - /language_model/layers.2/Add_2_output_0 + - /language_model/layers.2/Add_3_output_0 + - /language_model/layers.2/Add_output_0 + - /language_model/layers.3/Add_1_output_0 + - /language_model/layers.3/Add_2_output_0 + - /language_model/layers.3/Add_3_output_0 + - /language_model/layers.3/Add_output_0 + - /language_model/layers.4/Add_1_output_0 + - /language_model/layers.4/Add_2_output_0 + - /language_model/layers.4/Add_3_output_0 + - /language_model/layers.4/Add_output_0 + - /language_model/layers.5/Add_1_output_0 + - /language_model/layers.5/Add_2_output_0 + - /language_model/layers.5/Add_3_output_0 + - /language_model/layers.5/Add_output_0 + - /language_model/layers.6/Add_1_output_0 + - 
/language_model/layers.6/Add_2_output_0 + - /language_model/layers.6/Add_3_output_0 + - /language_model/layers.6/Add_output_0 + - /language_model/layers.7/Add_1_output_0 + - /language_model/layers.7/Add_2_output_0 + - /language_model/layers.7/Add_3_output_0 + - /language_model/layers.7/Add_output_0 + - /language_model/layers.8/Add_1_output_0 + - /language_model/layers.8/Add_2_output_0 + - /language_model/layers.8/Add_3_output_0 + - /language_model/layers.8/Add_output_0 + - /language_model/layers.9/Add_1_output_0 + - /language_model/layers.9/Add_2_output_0 + - /language_model/layers.9/Add_3_output_0 + - /language_model/layers.9/Add_output_0 + - /language_model/layers.10/Add_1_output_0 + - /language_model/layers.10/Add_2_output_0 + - /language_model/layers.10/Add_3_output_0 + - /language_model/layers.10/Add_output_0 + - /language_model/layers.11/Add_1_output_0 + - /language_model/layers.11/Add_2_output_0 + - /language_model/layers.11/Add_3_output_0 + - /language_model/layers.11/Add_output_0 + - /language_model/layers.12/Add_1_output_0 + - /language_model/layers.12/Add_2_output_0 + - /language_model/layers.12/Add_3_output_0 + - /language_model/layers.12/Add_output_0 + - /language_model/layers.13/Add_1_output_0 + - /language_model/layers.13/Add_2_output_0 + - /language_model/layers.13/Add_3_output_0 + - /language_model/layers.13/Add_output_0 + - /language_model/layers.14/Add_1_output_0 + - /language_model/layers.14/Add_2_output_0 + - /language_model/layers.14/Add_3_output_0 + - /language_model/layers.14/Add_output_0 + - /language_model/layers.15/Add_1_output_0 + - /language_model/layers.15/Add_2_output_0 + - /language_model/layers.15/Add_3_output_0 + - /language_model/layers.15/Add_output_0 + - /language_model/layers.16/Add_1_output_0 + - /language_model/layers.16/Add_2_output_0 + - /language_model/layers.16/Add_3_output_0 + - /language_model/layers.16/Add_output_0 + - /language_model/layers.17/Add_1_output_0 + - /language_model/layers.17/Add_2_output_0 + - /language_model/layers.17/Add_3_output_0 + - /language_model/layers.17/Add_output_0 + - /language_model/layers.18/Add_1_output_0 + - /language_model/layers.18/Add_2_output_0 + - /language_model/layers.18/Add_3_output_0 + - /language_model/layers.18/Add_output_0 + - /language_model/layers.19/Add_1_output_0 + - /language_model/layers.19/Add_2_output_0 + - /language_model/layers.19/Add_3_output_0 + - /language_model/layers.19/Add_output_0 + - /language_model/layers.20/Add_1_output_0 + - /language_model/layers.20/Add_2_output_0 + - /language_model/layers.20/Add_3_output_0 + - /language_model/layers.20/Add_output_0 + - /language_model/layers.21/Add_1_output_0 + - /language_model/layers.21/Add_2_output_0 + - /language_model/layers.21/Add_3_output_0 + - /language_model/layers.21/Add_output_0 + - /language_model/layers.22/Add_1_output_0 + - /language_model/layers.22/Add_2_output_0 + - /language_model/layers.22/Add_3_output_0 + - /language_model/layers.22/Add_output_0 + - /language_model/layers.23/Add_1_output_0 + - /language_model/layers.23/Add_2_output_0 + - /language_model/layers.23/Add_output_0 + - /language_model/layers.24/Add_1_output_0 + - /language_model/layers.24/Add_2_output_0 + - /language_model/layers.24/Add_3_output_0 + - /language_model/layers.24/Add_output_0 + - /language_model/layers.25/Add_1_output_0 + - /language_model/layers.25/Add_2_output_0 + - /language_model/layers.25/Add_3_output_0 + - /language_model/layers.25/Add_output_0 + - /language_model/layers.26/Add_1_output_0 + - /language_model/layers.26/Add_2_output_0 + - 
/language_model/layers.26/Add_3_output_0 + - /language_model/layers.26/Add_output_0 + - /language_model/layers.27/Add_1_output_0 + - /language_model/layers.27/Add_2_output_0 + - /language_model/layers.27/Add_3_output_0 + - /language_model/layers.27/Add_output_0 + - /language_model/layers.28/Add_1_output_0 + - /language_model/layers.28/Add_2_output_0 + - /language_model/layers.28/Add_3_output_0 + - /language_model/layers.28/Add_output_0 + - /language_model/layers.29/Add_1_output_0 + - /language_model/layers.29/Add_2_output_0 + - /language_model/layers.29/Add_3_output_0 + - /language_model/layers.29/Add_output_0 + - /language_model/layers.30/Add_1_output_0 + - /language_model/layers.30/Add_2_output_0 + - /language_model/layers.30/Add_3_output_0 + - /language_model/layers.30/Add_output_0 + - /language_model/layers.31/Add_1_output_0 + - /language_model/layers.31/Add_2_output_0 + - /language_model/layers.31/Add_3_output_0 + - /language_model/layers.31/Add_output_0 + - /language_model/layers.32/Add_1_output_0 + - /language_model/layers.32/Add_2_output_0 + - /language_model/layers.32/Add_3_output_0 + - /language_model/layers.32/Add_output_0 + - /language_model/layers.33/Add_1_output_0 + - /language_model/layers.33/Add_2_output_0 + - /language_model/layers.33/Add_3_output_0 + - /language_model/layers.33/Add_output_0 + - /language_model/layers.34/Add_1_output_0 + - /language_model/layers.34/Add_2_output_0 + - /language_model/layers.34/Add_3_output_0 + - /language_model/layers.34/Add_output_0 + - /language_model/layers.35/Add_1_output_0 + - /language_model/layers.35/Add_2_output_0 + - /language_model/layers.35/Add_3_output_0 + - /language_model/layers.35/Add_output_0 + - /language_model/layers.36/Add_1_output_0 + - /language_model/layers.36/Add_2_output_0 + - /language_model/layers.36/Add_3_output_0 + - /language_model/layers.36/Add_output_0 + - /language_model/layers.37/Add_1_output_0 + - /language_model/layers.37/Add_2_output_0 + - /language_model/layers.37/Add_3_output_0 + - /language_model/layers.37/Add_output_0 + - /language_model/layers.38/Add_1_output_0 + - /language_model/layers.38/Add_2_output_0 + - /language_model/layers.38/Add_3_output_0 + - /language_model/layers.38/Add_output_0 + - /language_model/layers.39/Add_1_output_0 + - /language_model/layers.39/Add_2_output_0 + - /language_model/layers.39/Add_3_output_0 + - /language_model/layers.39/Add_output_0 + - /language_model/layers.40/Add_1_output_0 + - /language_model/layers.40/Add_2_output_0 + - /language_model/layers.40/Add_3_output_0 + - /language_model/layers.40/Add_output_0 + - /language_model/layers.41/Add_1_output_0 + - /language_model/layers.41/Add_2_output_0 + - /language_model/layers.41/Add_3_output_0 + - /language_model/layers.41/Add_output_0 + - /language_model/layers.42/Add_1_output_0 + - /language_model/layers.42/Add_2_output_0 + - /language_model/layers.42/Add_3_output_0 + - /language_model/layers.42/Add_output_0 + - /language_model/layers.43/Add_1_output_0 + - /language_model/layers.43/Add_2_output_0 + - /language_model/layers.43/Add_3_output_0 + - /language_model/layers.43/Add_output_0 + - /language_model/layers.44/Add_1_output_0 + - /language_model/layers.44/Add_2_output_0 + - /language_model/layers.44/Add_3_output_0 + - /language_model/layers.44/Add_output_0 + - /language_model/layers.45/Add_1_output_0 + - /language_model/layers.45/Add_2_output_0 + - /language_model/layers.45/Add_3_output_0 + - /language_model/layers.45/Add_output_0 + - /language_model/layers.46/Add_1_output_0 + - 
/language_model/layers.46/Add_2_output_0 + - /language_model/layers.46/Add_3_output_0 + - /language_model/layers.46/Add_output_0 + - /language_model/layers.47/Add_1_output_0 + - /language_model/layers.47/Add_2_output_0 + - /language_model/layers.47/Add_3_output_0 + - /language_model/layers.47/Add_output_0 + - /language_model/layers.48/Add_1_output_0 + - /language_model/layers.48/Add_2_output_0 + - /language_model/layers.48/Add_3_output_0 + - /language_model/layers.48/Add_output_0 + - /language_model/layers.49/Add_1_output_0 + - /language_model/layers.49/Add_2_output_0 + - /language_model/layers.49/Add_3_output_0 + - /language_model/layers.49/Add_output_0 + - /language_model/layers.50/Add_1_output_0 + - /language_model/layers.50/Add_2_output_0 + - /language_model/layers.50/Add_3_output_0 + - /language_model/layers.50/Add_output_0 + - /language_model/layers.51/Add_1_output_0 + - /language_model/layers.51/Add_2_output_0 + - /language_model/layers.51/Add_3_output_0 + - /language_model/layers.51/Add_output_0 + - /language_model/layers.52/Add_1_output_0 + - /language_model/layers.52/Add_2_output_0 + - /language_model/layers.52/Add_3_output_0 + - /language_model/layers.52/Add_output_0 + - /language_model/layers.53/Add_1_output_0 + - /language_model/layers.53/Add_2_output_0 + - /language_model/layers.53/Add_3_output_0 + - /language_model/layers.53/Add_output_0 + - /language_model/layers.54/Add_1_output_0 + - /language_model/layers.54/Add_2_output_0 + - /language_model/layers.54/Add_3_output_0 + - /language_model/layers.54/Add_output_0 + - /language_model/layers.55/Add_1_output_0 + - /language_model/layers.55/Add_2_output_0 + - /language_model/layers.55/Add_3_output_0 + - /language_model/layers.55/Add_output_0 + - /language_model/layers.56/Add_1_output_0 + - /language_model/layers.56/Add_2_output_0 + - /language_model/layers.56/Add_3_output_0 + - /language_model/layers.56/Add_output_0 + - /language_model/layers.57/Add_1_output_0 + - /language_model/layers.57/Add_2_output_0 + - /language_model/layers.57/Add_3_output_0 + - /language_model/layers.57/Add_output_0 + - /language_model/layers.58/Add_1_output_0 + - /language_model/layers.58/Add_2_output_0 + - /language_model/layers.58/Add_3_output_0 + - /language_model/layers.58/Add_output_0 + - /language_model/layers.59/Add_1_output_0 + - /language_model/layers.59/Add_2_output_0 + - /language_model/layers.59/Add_3_output_0 + - /language_model/layers.59/Add_output_0 + - /language_model/layers.60/Add_1_output_0 + - /language_model/layers.60/Add_2_output_0 + - /language_model/layers.60/Add_3_output_0 + - /language_model/layers.60/Add_output_0 + - /language_model/layers.61/Add_1_output_0 + - /language_model/layers.61/Add_2_output_0 + - /language_model/layers.61/Add_3_output_0 + - /language_model/layers.61/Add_output_0 + - /language_model/norm/Add_output_0 + - /language_model/layers.0/self_attn/Mul_output_0 + - /language_model/layers.2/self_attn/Mul_output_0 + - /language_model/layers.3/self_attn/Mul_output_0 + - /language_model/layers.4/self_attn/Mul_output_0 + - /language_model/layers.5/self_attn/Mul_output_0 + - /language_model/layers.6/self_attn/Mul_output_0 + - /language_model/layers.7/self_attn/Mul_output_0 + - /language_model/layers.8/self_attn/Mul_output_0 + - /language_model/layers.9/self_attn/Mul_output_0 + - /language_model/layers.10/self_attn/Mul_output_0 + - /language_model/layers.11/self_attn/Mul_output_0 + - /language_model/layers.12/self_attn/Mul_output_0 + - /language_model/layers.13/self_attn/Mul_output_0 + - 
/language_model/layers.14/self_attn/Mul_output_0 + - /language_model/layers.15/self_attn/Mul_output_0 + - /language_model/layers.16/self_attn/Mul_output_0 + - /language_model/layers.17/self_attn/Mul_output_0 + - /language_model/layers.18/self_attn/Mul_output_0 + - /language_model/layers.19/self_attn/Mul_output_0 + - /language_model/layers.20/self_attn/Mul_output_0 + - /language_model/layers.21/self_attn/Mul_output_0 + - /language_model/layers.22/self_attn/Mul_output_0 + - /language_model/layers.23/self_attn/Mul_output_0 + - /language_model/layers.24/self_attn/Mul_output_0 + - /language_model/layers.25/self_attn/Mul_output_0 + - /language_model/layers.26/self_attn/Mul_output_0 + - /language_model/layers.27/self_attn/Mul_output_0 + - /language_model/layers.28/self_attn/Mul_output_0 + - /language_model/layers.29/self_attn/Mul_output_0 + - /language_model/layers.30/self_attn/Mul_output_0 + - /language_model/layers.31/self_attn/Mul_output_0 + - /language_model/layers.32/self_attn/Mul_output_0 + - /language_model/layers.33/self_attn/Mul_output_0 + - /language_model/layers.34/self_attn/Mul_output_0 + - /language_model/layers.35/self_attn/Mul_output_0 + - /language_model/layers.36/self_attn/Mul_output_0 + - /language_model/layers.37/self_attn/Mul_output_0 + - /language_model/layers.38/self_attn/Mul_output_0 + - /language_model/layers.39/self_attn/Mul_output_0 + - /language_model/layers.40/self_attn/Mul_output_0 + - /language_model/layers.41/self_attn/Mul_output_0 + - /language_model/layers.42/self_attn/Mul_output_0 + - /language_model/layers.43/self_attn/Mul_output_0 + - /language_model/layers.44/self_attn/Mul_output_0 + - /language_model/layers.45/self_attn/Mul_output_0 + - /language_model/layers.46/self_attn/Mul_output_0 + - /language_model/layers.47/self_attn/Mul_output_0 + - /language_model/layers.48/self_attn/Mul_output_0 + - /language_model/layers.49/self_attn/Mul_output_0 + - /language_model/layers.50/self_attn/Mul_output_0 + - /language_model/layers.51/self_attn/Mul_output_0 + - /language_model/layers.52/self_attn/Mul_output_0 + - /language_model/layers.53/self_attn/Mul_output_0 + - /language_model/layers.54/self_attn/Mul_output_0 + - /language_model/layers.55/self_attn/Mul_output_0 + - /language_model/layers.56/self_attn/Mul_output_0 + - /language_model/layers.57/self_attn/Mul_output_0 + - /language_model/layers.58/self_attn/Mul_output_0 + - /language_model/layers.59/self_attn/Mul_output_0 + - /language_model/layers.60/self_attn/Mul_output_0 + - /language_model/layers.61/self_attn/Mul_output_0 + - /language_model/layers.0/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/input_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.3/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.3/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.5/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.5/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.7/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/input_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.15/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.16/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/q_norm/CustomRMSNorm_output_0 + - 
/language_model/layers.18/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.22/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.22/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.25/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/self_attn/k_norm/CustomRMSNorm_output_0 + - 
/language_model/layers.25/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.28/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.32/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.32/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.33/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.34/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.34/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.35/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.35/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.36/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.36/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.37/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.37/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.38/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.38/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.39/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.39/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.40/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.40/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.41/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/post_feedforward_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.41/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.41/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.42/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.42/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.43/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.43/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.44/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.44/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.45/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.45/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.46/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.46/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.47/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.47/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.48/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.48/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.49/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/post_attention_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.49/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.49/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.50/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.50/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.51/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.51/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.52/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.52/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.53/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.53/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.54/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.54/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.55/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.55/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.56/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.56/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.57/input_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.57/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.57/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.58/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.58/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.59/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.59/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.60/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.60/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.61/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.61/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/norm/CustomRMSNorm_output_0 diff --git a/examples/gemma3_example/fp32_nodes_gemma3_4b.yaml b/examples/gemma3_example/fp32_nodes_gemma3_4b.yaml index 28e7485fa..1c8aa1c41 100755 --- a/examples/gemma3_example/fp32_nodes_gemma3_4b.yaml +++ b/examples/gemma3_example/fp32_nodes_gemma3_4b.yaml @@ -1,697 +1,698 @@ FP32NodeInstanceNames: - - /language_model/model/layers.0/Add_output_0 - - /language_model/model/layers.0/Add_1_output_0 - - /language_model/model/layers.0/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.0/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.0/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.0/Add_2_output_0 - - /language_model/model/layers.0/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.0/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.0/Add_3_output_0 - - /language_model/model/layers.1/Add_output_0 - - /language_model/model/layers.1/Add_1_output_0 - - /language_model/model/layers.1/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.1/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.1/self_attn/k_norm/CustomRMSNorm_output_0 - - 
/language_model/model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.1/Add_2_output_0 - - /language_model/model/layers.1/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.1/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.1/Add_3_output_0 - - /language_model/model/layers.2/Add_output_0 - - /language_model/model/layers.2/Add_1_output_0 - - /language_model/model/layers.2/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.2/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.2/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.2/Add_2_output_0 - - /language_model/model/layers.2/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.2/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.2/Add_3_output_0 - - /language_model/model/layers.3/Add_output_0 - - /language_model/model/layers.3/Add_1_output_0 - - /language_model/model/layers.3/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.3/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.3/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.3/Add_2_output_0 - - /language_model/model/layers.3/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.3/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.3/Add_3_output_0 - - /language_model/model/layers.4/Add_output_0 - - /language_model/model/layers.4/Add_1_output_0 - - /language_model/model/layers.4/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.4/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.4/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.4/Add_2_output_0 - - /language_model/model/layers.4/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.4/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.4/Add_3_output_0 - - /language_model/model/layers.5/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.5/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.5/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.5/Add_output_0 - - /language_model/model/layers.5/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.5/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.5/Add_1_output_0 - - /language_model/model/layers.6/Add_output_0 - - /language_model/model/layers.6/Add_1_output_0 - - /language_model/model/layers.6/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.6/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.6/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.6/Add_2_output_0 - - /language_model/model/layers.6/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - 
/language_model/model/layers.6/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.6/Add_3_output_0 - - /language_model/model/layers.7/Add_output_0 - - /language_model/model/layers.7/Add_1_output_0 - - /language_model/model/layers.7/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.7/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.7/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.7/Add_2_output_0 - - /language_model/model/layers.7/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.7/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.7/Add_3_output_0 - - /language_model/model/layers.8/Add_output_0 - - /language_model/model/layers.8/Add_1_output_0 - - /language_model/model/layers.8/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.8/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.8/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.8/Add_2_output_0 - - /language_model/model/layers.8/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.8/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.8/Add_3_output_0 - - /language_model/model/layers.9/Add_output_0 - - /language_model/model/layers.9/Add_1_output_0 - - /language_model/model/layers.9/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.9/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.9/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.9/Add_2_output_0 - - /language_model/model/layers.9/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.9/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.9/Add_3_output_0 - - /language_model/model/layers.10/Add_output_0 - - /language_model/model/layers.10/Add_1_output_0 - - /language_model/model/layers.10/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.10/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.10/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.10/Add_2_output_0 - - /language_model/model/layers.10/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.10/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.10/Add_3_output_0 - - /language_model/model/layers.11/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.11/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.11/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.11/Add_output_0 - - /language_model/model/layers.11/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.11/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.11/Add_1_output_0 - - /language_model/model/layers.12/Add_output_0 - - /language_model/model/layers.12/Add_1_output_0 - - 
/language_model/model/layers.12/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.12/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.12/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.12/Add_2_output_0 - - /language_model/model/layers.12/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.12/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.12/Add_3_output_0 - - /language_model/model/layers.13/Add_output_0 - - /language_model/model/layers.13/Add_1_output_0 - - /language_model/model/layers.13/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.13/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.13/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.13/Add_2_output_0 - - /language_model/model/layers.13/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.13/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.13/Add_3_output_0 - - /language_model/model/layers.14/Add_output_0 - - /language_model/model/layers.14/Add_1_output_0 - - /language_model/model/layers.14/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.14/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.14/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.14/Add_2_output_0 - - /language_model/model/layers.14/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.14/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.14/Add_3_output_0 - - /language_model/model/layers.15/Add_output_0 - - /language_model/model/layers.15/Add_1_output_0 - - /language_model/model/layers.15/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.15/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.15/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.15/Add_2_output_0 - - /language_model/model/layers.15/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.15/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.15/Add_3_output_0 - - /language_model/model/layers.16/Add_output_0 - - /language_model/model/layers.16/Add_1_output_0 - - /language_model/model/layers.16/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.16/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.16/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.16/Add_2_output_0 - - /language_model/model/layers.16/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.16/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.16/Add_3_output_0 - - /language_model/model/layers.17/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.17/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.17/self_attn/k_norm/CustomRMSNorm_output_0 - - 
/language_model/model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.17/Add_output_0 - - /language_model/model/layers.17/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.17/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.17/Add_1_output_0 - - /language_model/model/layers.18/Add_output_0 - - /language_model/model/layers.18/Add_1_output_0 - - /language_model/model/layers.18/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.18/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.18/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.18/Add_2_output_0 - - /language_model/model/layers.18/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.18/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.18/Add_3_output_0 - - /language_model/model/layers.19/Add_output_0 - - /language_model/model/layers.19/Add_1_output_0 - - /language_model/model/layers.19/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.19/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.19/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.19/Add_2_output_0 - - /language_model/model/layers.19/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.19/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.19/Add_3_output_0 - - /language_model/model/layers.20/Add_output_0 - - /language_model/model/layers.20/Add_1_output_0 - - /language_model/model/layers.20/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.20/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.20/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.20/Add_2_output_0 - - /language_model/model/layers.20/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.20/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.20/Add_3_output_0 - - /language_model/model/layers.21/Add_output_0 - - /language_model/model/layers.21/Add_1_output_0 - - /language_model/model/layers.21/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.21/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.21/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.21/Add_2_output_0 - - /language_model/model/layers.21/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.21/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.21/Add_3_output_0 - - /language_model/model/layers.22/Add_output_0 - - /language_model/model/layers.22/Add_1_output_0 - - /language_model/model/layers.22/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.22/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.22/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.22/Add_2_output_0 - - 
/language_model/model/layers.22/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.22/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.22/Add_3_output_0 - - /language_model/model/layers.23/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.23/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.23/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.23/Add_output_0 - - /language_model/model/layers.23/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.23/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.23/Add_1_output_0 - - /language_model/model/layers.24/Add_output_0 - - /language_model/model/layers.24/Add_1_output_0 - - /language_model/model/layers.24/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.24/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.24/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.24/Add_2_output_0 - - /language_model/model/layers.24/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.24/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.24/Add_3_output_0 - - /language_model/model/layers.25/Add_output_0 - - /language_model/model/layers.25/Add_1_output_0 - - /language_model/model/layers.25/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.25/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.25/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.25/Add_2_output_0 - - /language_model/model/layers.25/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.25/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.25/Add_3_output_0 - - /language_model/model/layers.26/Add_output_0 - - /language_model/model/layers.26/Add_1_output_0 - - /language_model/model/layers.26/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.26/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.26/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.26/Add_2_output_0 - - /language_model/model/layers.26/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.26/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.26/Add_3_output_0 - - /language_model/model/layers.27/Add_output_0 - - /language_model/model/layers.27/Add_1_output_0 - - /language_model/model/layers.27/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.27/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.27/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.27/Add_2_output_0 - - /language_model/model/layers.27/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.27/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.27/Add_3_output_0 - - 
/language_model/model/layers.28/Add_output_0 - - /language_model/model/layers.28/Add_1_output_0 - - /language_model/model/layers.28/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.28/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.28/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.28/Add_2_output_0 - - /language_model/model/layers.28/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.28/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.28/Add_3_output_0 - - /language_model/model/layers.29/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.29/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.29/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.29/Add_output_0 - - /language_model/model/layers.29/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.29/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.29/Add_1_output_0 - - /language_model/model/layers.30/Add_output_0 - - /language_model/model/layers.30/Add_1_output_0 - - /language_model/model/layers.30/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.30/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.30/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.30/Add_2_output_0 - - /language_model/model/layers.30/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.30/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.30/Add_3_output_0 - - /language_model/model/layers.31/Add_output_0 - - /language_model/model/layers.31/Add_1_output_0 - - /language_model/model/layers.31/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.31/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.31/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.31/Add_2_output_0 - - /language_model/model/layers.31/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.31/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.31/Add_3_output_0 - - /language_model/model/layers.32/Add_output_0 - - /language_model/model/layers.32/Add_1_output_0 - - /language_model/model/layers.32/input_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.32/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.32/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.32/Add_2_output_0 - - /language_model/model/layers.32/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.32/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.32/Add_3_output_0 - - /language_model/model/layers.33/Add_output_0 - - /language_model/model/layers.33/Add_1_output_0 - - /language_model/model/layers.33/input_layernorm/CustomRMSNorm_output_0 - - 
/language_model/model/layers.33/self_attn/q_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.33/self_attn/k_norm/CustomRMSNorm_output_0 - - /language_model/model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.33/Add_2_output_0 - - /language_model/model/layers.33/pre_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.33/post_feedforward_layernorm/CustomRMSNorm_output_0 - - /language_model/model/layers.33/Add_3_output_0 - - /language_model/model/norm/CustomRMSNorm_output_0 - - /language_model/model/layers.0/self_attn/Mul_output_0 - - /language_model/model/layers.0/self_attn/Mul_1_output_0 - - /language_model/model/layers.0/self_attn/Mul_2_output_0 - - /language_model/model/layers.0/self_attn/Mul_3_output_0 - - /language_model/model/layers.0/self_attn/Mul_4_output_0 - - /language_model/model/layers.0/self_attn/Mul_5_output_0 - - /language_model/model/layers.0/self_attn/Mul_6_output_0 - - /language_model/model/layers.0/self_attn/Mul_7_output_0 - - /language_model/model/layers.0/self_attn/Mul_8_output_0 - - /language_model/model/layers.1/self_attn/Mul_9_output_0 - - /language_model/model/layers.2/self_attn/Mul_output_0 - - /language_model/model/layers.2/self_attn/Mul_1_output_0 - - /language_model/model/layers.2/self_attn/Mul_2_output_0 - - /language_model/model/layers.2/self_attn/Mul_3_output_0 - - /language_model/model/layers.2/self_attn/Mul_4_output_0 - - /language_model/model/layers.2/self_attn/Mul_5_output_0 - - /language_model/model/layers.2/self_attn/Mul_6_output_0 - - /language_model/model/layers.2/self_attn/Mul_7_output_0 - - /language_model/model/layers.2/self_attn/Mul_8_output_0 - - /language_model/model/layers.2/self_attn/Mul_9_output_0 - - /language_model/model/layers.3/self_attn/Mul_output_0 - - /language_model/model/layers.3/self_attn/Mul_1_output_0 - - /language_model/model/layers.3/self_attn/Mul_2_output_0 - - /language_model/model/layers.3/self_attn/Mul_3_output_0 - - /language_model/model/layers.3/self_attn/Mul_4_output_0 - - /language_model/model/layers.3/self_attn/Mul_5_output_0 - - /language_model/model/layers.3/self_attn/Mul_6_output_0 - - /language_model/model/layers.3/self_attn/Mul_7_output_0 - - /language_model/model/layers.3/self_attn/Mul_8_output_0 - - /language_model/model/layers.3/self_attn/Mul_9_output_0 - - /language_model/model/layers.4/self_attn/Mul_output_0 - - /language_model/model/layers.4/self_attn/Mul_1_output_0 - - /language_model/model/layers.4/self_attn/Mul_2_output_0 - - /language_model/model/layers.4/self_attn/Mul_3_output_0 - - /language_model/model/layers.4/self_attn/Mul_4_output_0 - - /language_model/model/layers.4/self_attn/Mul_5_output_0 - - /language_model/model/layers.4/self_attn/Mul_6_output_0 - - /language_model/model/layers.4/self_attn/Mul_7_output_0 - - /language_model/model/layers.4/self_attn/Mul_8_output_0 - - /language_model/model/layers.4/self_attn/Mul_9_output_0 - - /language_model/model/layers.5/self_attn/Mul_output_0 - - /language_model/model/layers.5/self_attn/Mul_1_output_0 - - /language_model/model/layers.5/self_attn/Mul_2_output_0 - - /language_model/model/layers.5/self_attn/Mul_3_output_0 - - /language_model/model/layers.5/self_attn/Mul_4_output_0 - - /language_model/model/layers.5/self_attn/Mul_5_output_0 - - /language_model/model/layers.5/self_attn/Mul_6_output_0 - - /language_model/model/layers.5/self_attn/Mul_7_output_0 - - /language_model/model/layers.5/self_attn/Mul_8_output_0 - - /language_model/model/layers.5/self_attn/Mul_9_output_0 - - 
/language_model/model/layers.6/self_attn/Mul_output_0 - - /language_model/model/layers.6/self_attn/Mul_1_output_0 - - /language_model/model/layers.6/self_attn/Mul_2_output_0 - - /language_model/model/layers.6/self_attn/Mul_3_output_0 - - /language_model/model/layers.6/self_attn/Mul_4_output_0 - - /language_model/model/layers.6/self_attn/Mul_5_output_0 - - /language_model/model/layers.6/self_attn/Mul_6_output_0 - - /language_model/model/layers.6/self_attn/Mul_7_output_0 - - /language_model/model/layers.6/self_attn/Mul_8_output_0 - - /language_model/model/layers.6/self_attn/Mul_9_output_0 - - /language_model/model/layers.7/self_attn/Mul_output_0 - - /language_model/model/layers.7/self_attn/Mul_1_output_0 - - /language_model/model/layers.7/self_attn/Mul_2_output_0 - - /language_model/model/layers.7/self_attn/Mul_3_output_0 - - /language_model/model/layers.7/self_attn/Mul_4_output_0 - - /language_model/model/layers.7/self_attn/Mul_5_output_0 - - /language_model/model/layers.7/self_attn/Mul_6_output_0 - - /language_model/model/layers.7/self_attn/Mul_7_output_0 - - /language_model/model/layers.7/self_attn/Mul_8_output_0 - - /language_model/model/layers.7/self_attn/Mul_9_output_0 - - /language_model/model/layers.8/self_attn/Mul_output_0 - - /language_model/model/layers.8/self_attn/Mul_1_output_0 - - /language_model/model/layers.8/self_attn/Mul_2_output_0 - - /language_model/model/layers.8/self_attn/Mul_3_output_0 - - /language_model/model/layers.8/self_attn/Mul_4_output_0 - - /language_model/model/layers.8/self_attn/Mul_5_output_0 - - /language_model/model/layers.8/self_attn/Mul_6_output_0 - - /language_model/model/layers.8/self_attn/Mul_7_output_0 - - /language_model/model/layers.8/self_attn/Mul_8_output_0 - - /language_model/model/layers.8/self_attn/Mul_9_output_0 - - /language_model/model/layers.9/self_attn/Mul_output_0 - - /language_model/model/layers.9/self_attn/Mul_1_output_0 - - /language_model/model/layers.9/self_attn/Mul_2_output_0 - - /language_model/model/layers.9/self_attn/Mul_3_output_0 - - /language_model/model/layers.9/self_attn/Mul_4_output_0 - - /language_model/model/layers.9/self_attn/Mul_5_output_0 - - /language_model/model/layers.9/self_attn/Mul_6_output_0 - - /language_model/model/layers.9/self_attn/Mul_7_output_0 - - /language_model/model/layers.9/self_attn/Mul_8_output_0 - - /language_model/model/layers.9/self_attn/Mul_9_output_0 - - /language_model/model/layers.10/self_attn/Mul_output_0 - - /language_model/model/layers.10/self_attn/Mul_1_output_0 - - /language_model/model/layers.10/self_attn/Mul_2_output_0 - - /language_model/model/layers.10/self_attn/Mul_3_output_0 - - /language_model/model/layers.10/self_attn/Mul_4_output_0 - - /language_model/model/layers.10/self_attn/Mul_5_output_0 - - /language_model/model/layers.10/self_attn/Mul_6_output_0 - - /language_model/model/layers.10/self_attn/Mul_7_output_0 - - /language_model/model/layers.10/self_attn/Mul_8_output_0 - - /language_model/model/layers.10/self_attn/Mul_9_output_0 - - /language_model/model/layers.11/self_attn/Mul_output_0 - - /language_model/model/layers.11/self_attn/Mul_1_output_0 - - /language_model/model/layers.11/self_attn/Mul_2_output_0 - - /language_model/model/layers.11/self_attn/Mul_3_output_0 - - /language_model/model/layers.11/self_attn/Mul_4_output_0 - - /language_model/model/layers.11/self_attn/Mul_5_output_0 - - /language_model/model/layers.11/self_attn/Mul_6_output_0 - - /language_model/model/layers.11/self_attn/Mul_7_output_0 - - /language_model/model/layers.11/self_attn/Mul_8_output_0 - - 
/language_model/model/layers.11/self_attn/Mul_9_output_0 - - /language_model/model/layers.12/self_attn/Mul_output_0 - - /language_model/model/layers.12/self_attn/Mul_1_output_0 - - /language_model/model/layers.12/self_attn/Mul_2_output_0 - - /language_model/model/layers.12/self_attn/Mul_3_output_0 - - /language_model/model/layers.12/self_attn/Mul_4_output_0 - - /language_model/model/layers.12/self_attn/Mul_5_output_0 - - /language_model/model/layers.12/self_attn/Mul_6_output_0 - - /language_model/model/layers.12/self_attn/Mul_7_output_0 - - /language_model/model/layers.12/self_attn/Mul_8_output_0 - - /language_model/model/layers.12/self_attn/Mul_9_output_0 - - /language_model/model/layers.13/self_attn/Mul_output_0 - - /language_model/model/layers.13/self_attn/Mul_1_output_0 - - /language_model/model/layers.13/self_attn/Mul_2_output_0 - - /language_model/model/layers.13/self_attn/Mul_3_output_0 - - /language_model/model/layers.13/self_attn/Mul_4_output_0 - - /language_model/model/layers.13/self_attn/Mul_5_output_0 - - /language_model/model/layers.13/self_attn/Mul_6_output_0 - - /language_model/model/layers.13/self_attn/Mul_7_output_0 - - /language_model/model/layers.13/self_attn/Mul_8_output_0 - - /language_model/model/layers.13/self_attn/Mul_9_output_0 - - /language_model/model/layers.14/self_attn/Mul_output_0 - - /language_model/model/layers.14/self_attn/Mul_1_output_0 - - /language_model/model/layers.14/self_attn/Mul_2_output_0 - - /language_model/model/layers.14/self_attn/Mul_3_output_0 - - /language_model/model/layers.14/self_attn/Mul_4_output_0 - - /language_model/model/layers.14/self_attn/Mul_5_output_0 - - /language_model/model/layers.14/self_attn/Mul_6_output_0 - - /language_model/model/layers.14/self_attn/Mul_7_output_0 - - /language_model/model/layers.14/self_attn/Mul_8_output_0 - - /language_model/model/layers.14/self_attn/Mul_9_output_0 - - /language_model/model/layers.15/self_attn/Mul_output_0 - - /language_model/model/layers.15/self_attn/Mul_1_output_0 - - /language_model/model/layers.15/self_attn/Mul_2_output_0 - - /language_model/model/layers.15/self_attn/Mul_3_output_0 - - /language_model/model/layers.15/self_attn/Mul_4_output_0 - - /language_model/model/layers.15/self_attn/Mul_5_output_0 - - /language_model/model/layers.15/self_attn/Mul_6_output_0 - - /language_model/model/layers.15/self_attn/Mul_7_output_0 - - /language_model/model/layers.15/self_attn/Mul_8_output_0 - - /language_model/model/layers.15/self_attn/Mul_9_output_0 - - /language_model/model/layers.16/self_attn/Mul_output_0 - - /language_model/model/layers.16/self_attn/Mul_1_output_0 - - /language_model/model/layers.16/self_attn/Mul_2_output_0 - - /language_model/model/layers.16/self_attn/Mul_3_output_0 - - /language_model/model/layers.16/self_attn/Mul_4_output_0 - - /language_model/model/layers.16/self_attn/Mul_5_output_0 - - /language_model/model/layers.16/self_attn/Mul_6_output_0 - - /language_model/model/layers.16/self_attn/Mul_7_output_0 - - /language_model/model/layers.16/self_attn/Mul_8_output_0 - - /language_model/model/layers.16/self_attn/Mul_9_output_0 - - /language_model/model/layers.17/self_attn/Mul_output_0 - - /language_model/model/layers.17/self_attn/Mul_1_output_0 - - /language_model/model/layers.17/self_attn/Mul_2_output_0 - - /language_model/model/layers.17/self_attn/Mul_3_output_0 - - /language_model/model/layers.17/self_attn/Mul_4_output_0 - - /language_model/model/layers.17/self_attn/Mul_5_output_0 - - /language_model/model/layers.17/self_attn/Mul_6_output_0 - - 
/language_model/model/layers.17/self_attn/Mul_7_output_0 - - /language_model/model/layers.17/self_attn/Mul_8_output_0 - - /language_model/model/layers.17/self_attn/Mul_9_output_0 - - /language_model/model/layers.18/self_attn/Mul_output_0 - - /language_model/model/layers.18/self_attn/Mul_1_output_0 - - /language_model/model/layers.18/self_attn/Mul_2_output_0 - - /language_model/model/layers.18/self_attn/Mul_3_output_0 - - /language_model/model/layers.18/self_attn/Mul_4_output_0 - - /language_model/model/layers.18/self_attn/Mul_5_output_0 - - /language_model/model/layers.18/self_attn/Mul_6_output_0 - - /language_model/model/layers.18/self_attn/Mul_7_output_0 - - /language_model/model/layers.18/self_attn/Mul_8_output_0 - - /language_model/model/layers.18/self_attn/Mul_9_output_0 - - /language_model/model/layers.19/self_attn/Mul_output_0 - - /language_model/model/layers.19/self_attn/Mul_1_output_0 - - /language_model/model/layers.19/self_attn/Mul_2_output_0 - - /language_model/model/layers.19/self_attn/Mul_3_output_0 - - /language_model/model/layers.19/self_attn/Mul_4_output_0 - - /language_model/model/layers.19/self_attn/Mul_5_output_0 - - /language_model/model/layers.19/self_attn/Mul_6_output_0 - - /language_model/model/layers.19/self_attn/Mul_7_output_0 - - /language_model/model/layers.19/self_attn/Mul_8_output_0 - - /language_model/model/layers.19/self_attn/Mul_9_output_0 - - /language_model/model/layers.20/self_attn/Mul_output_0 - - /language_model/model/layers.20/self_attn/Mul_1_output_0 - - /language_model/model/layers.20/self_attn/Mul_2_output_0 - - /language_model/model/layers.20/self_attn/Mul_3_output_0 - - /language_model/model/layers.20/self_attn/Mul_4_output_0 - - /language_model/model/layers.20/self_attn/Mul_5_output_0 - - /language_model/model/layers.20/self_attn/Mul_6_output_0 - - /language_model/model/layers.20/self_attn/Mul_7_output_0 - - /language_model/model/layers.20/self_attn/Mul_8_output_0 - - /language_model/model/layers.20/self_attn/Mul_9_output_0 - - /language_model/model/layers.21/self_attn/Mul_output_0 - - /language_model/model/layers.21/self_attn/Mul_1_output_0 - - /language_model/model/layers.21/self_attn/Mul_2_output_0 - - /language_model/model/layers.21/self_attn/Mul_3_output_0 - - /language_model/model/layers.21/self_attn/Mul_4_output_0 - - /language_model/model/layers.21/self_attn/Mul_5_output_0 - - /language_model/model/layers.21/self_attn/Mul_6_output_0 - - /language_model/model/layers.21/self_attn/Mul_7_output_0 - - /language_model/model/layers.21/self_attn/Mul_8_output_0 - - /language_model/model/layers.21/self_attn/Mul_9_output_0 - - /language_model/model/layers.22/self_attn/Mul_output_0 - - /language_model/model/layers.22/self_attn/Mul_1_output_0 - - /language_model/model/layers.22/self_attn/Mul_2_output_0 - - /language_model/model/layers.22/self_attn/Mul_3_output_0 - - /language_model/model/layers.22/self_attn/Mul_4_output_0 - - /language_model/model/layers.22/self_attn/Mul_5_output_0 - - /language_model/model/layers.22/self_attn/Mul_6_output_0 - - /language_model/model/layers.22/self_attn/Mul_7_output_0 - - /language_model/model/layers.22/self_attn/Mul_8_output_0 - - /language_model/model/layers.22/self_attn/Mul_9_output_0 - - /language_model/model/layers.23/self_attn/Mul_output_0 - - /language_model/model/layers.23/self_attn/Mul_1_output_0 - - /language_model/model/layers.23/self_attn/Mul_2_output_0 - - /language_model/model/layers.23/self_attn/Mul_3_output_0 - - /language_model/model/layers.23/self_attn/Mul_4_output_0 - - 
/language_model/model/layers.23/self_attn/Mul_5_output_0 - - /language_model/model/layers.23/self_attn/Mul_6_output_0 - - /language_model/model/layers.23/self_attn/Mul_7_output_0 - - /language_model/model/layers.23/self_attn/Mul_8_output_0 - - /language_model/model/layers.23/self_attn/Mul_9_output_0 - - /language_model/model/layers.24/self_attn/Mul_output_0 - - /language_model/model/layers.24/self_attn/Mul_1_output_0 - - /language_model/model/layers.24/self_attn/Mul_2_output_0 - - /language_model/model/layers.24/self_attn/Mul_3_output_0 - - /language_model/model/layers.24/self_attn/Mul_4_output_0 - - /language_model/model/layers.24/self_attn/Mul_5_output_0 - - /language_model/model/layers.24/self_attn/Mul_6_output_0 - - /language_model/model/layers.24/self_attn/Mul_7_output_0 - - /language_model/model/layers.24/self_attn/Mul_8_output_0 - - /language_model/model/layers.24/self_attn/Mul_9_output_0 - - /language_model/model/layers.25/self_attn/Mul_output_0 - - /language_model/model/layers.25/self_attn/Mul_1_output_0 - - /language_model/model/layers.25/self_attn/Mul_2_output_0 - - /language_model/model/layers.25/self_attn/Mul_3_output_0 - - /language_model/model/layers.25/self_attn/Mul_4_output_0 - - /language_model/model/layers.25/self_attn/Mul_5_output_0 - - /language_model/model/layers.25/self_attn/Mul_6_output_0 - - /language_model/model/layers.25/self_attn/Mul_7_output_0 - - /language_model/model/layers.25/self_attn/Mul_8_output_0 - - /language_model/model/layers.25/self_attn/Mul_9_output_0 - - /language_model/model/layers.26/self_attn/Mul_output_0 - - /language_model/model/layers.26/self_attn/Mul_1_output_0 - - /language_model/model/layers.26/self_attn/Mul_2_output_0 - - /language_model/model/layers.26/self_attn/Mul_3_output_0 - - /language_model/model/layers.26/self_attn/Mul_4_output_0 - - /language_model/model/layers.26/self_attn/Mul_5_output_0 - - /language_model/model/layers.26/self_attn/Mul_6_output_0 - - /language_model/model/layers.26/self_attn/Mul_7_output_0 - - /language_model/model/layers.26/self_attn/Mul_8_output_0 - - /language_model/model/layers.26/self_attn/Mul_9_output_0 - - /language_model/model/layers.27/self_attn/Mul_output_0 - - /language_model/model/layers.27/self_attn/Mul_1_output_0 - - /language_model/model/layers.27/self_attn/Mul_2_output_0 - - /language_model/model/layers.27/self_attn/Mul_3_output_0 - - /language_model/model/layers.27/self_attn/Mul_4_output_0 - - /language_model/model/layers.27/self_attn/Mul_5_output_0 - - /language_model/model/layers.27/self_attn/Mul_6_output_0 - - /language_model/model/layers.27/self_attn/Mul_7_output_0 - - /language_model/model/layers.27/self_attn/Mul_8_output_0 - - /language_model/model/layers.27/self_attn/Mul_9_output_0 - - /language_model/model/layers.28/self_attn/Mul_output_0 - - /language_model/model/layers.28/self_attn/Mul_1_output_0 - - /language_model/model/layers.28/self_attn/Mul_2_output_0 - - /language_model/model/layers.28/self_attn/Mul_3_output_0 - - /language_model/model/layers.28/self_attn/Mul_4_output_0 - - /language_model/model/layers.28/self_attn/Mul_5_output_0 - - /language_model/model/layers.28/self_attn/Mul_6_output_0 - - /language_model/model/layers.28/self_attn/Mul_7_output_0 - - /language_model/model/layers.28/self_attn/Mul_8_output_0 - - /language_model/model/layers.28/self_attn/Mul_9_output_0 - - /language_model/model/layers.29/self_attn/Mul_output_0 - - /language_model/model/layers.29/self_attn/Mul_1_output_0 - - /language_model/model/layers.29/self_attn/Mul_2_output_0 - - 
/language_model/model/layers.29/self_attn/Mul_3_output_0 - - /language_model/model/layers.29/self_attn/Mul_4_output_0 - - /language_model/model/layers.29/self_attn/Mul_5_output_0 - - /language_model/model/layers.29/self_attn/Mul_6_output_0 - - /language_model/model/layers.29/self_attn/Mul_7_output_0 - - /language_model/model/layers.29/self_attn/Mul_8_output_0 - - /language_model/model/layers.29/self_attn/Mul_9_output_0 - - /language_model/model/layers.30/self_attn/Mul_output_0 - - /language_model/model/layers.30/self_attn/Mul_1_output_0 - - /language_model/model/layers.30/self_attn/Mul_2_output_0 - - /language_model/model/layers.30/self_attn/Mul_3_output_0 - - /language_model/model/layers.30/self_attn/Mul_4_output_0 - - /language_model/model/layers.30/self_attn/Mul_5_output_0 - - /language_model/model/layers.30/self_attn/Mul_6_output_0 - - /language_model/model/layers.30/self_attn/Mul_7_output_0 - - /language_model/model/layers.30/self_attn/Mul_8_output_0 - - /language_model/model/layers.30/self_attn/Mul_9_output_0 - - /language_model/model/layers.31/self_attn/Mul_output_0 - - /language_model/model/layers.31/self_attn/Mul_1_output_0 - - /language_model/model/layers.31/self_attn/Mul_2_output_0 - - /language_model/model/layers.31/self_attn/Mul_3_output_0 - - /language_model/model/layers.31/self_attn/Mul_4_output_0 - - /language_model/model/layers.31/self_attn/Mul_5_output_0 - - /language_model/model/layers.31/self_attn/Mul_6_output_0 - - /language_model/model/layers.31/self_attn/Mul_7_output_0 - - /language_model/model/layers.31/self_attn/Mul_8_output_0 - - /language_model/model/layers.31/self_attn/Mul_9_output_0 - - /language_model/model/layers.32/self_attn/Mul_output_0 - - /language_model/model/layers.32/self_attn/Mul_1_output_0 - - /language_model/model/layers.32/self_attn/Mul_2_output_0 - - /language_model/model/layers.32/self_attn/Mul_3_output_0 - - /language_model/model/layers.32/self_attn/Mul_4_output_0 - - /language_model/model/layers.32/self_attn/Mul_5_output_0 - - /language_model/model/layers.32/self_attn/Mul_6_output_0 - - /language_model/model/layers.32/self_attn/Mul_7_output_0 - - /language_model/model/layers.32/self_attn/Mul_8_output_0 - - /language_model/model/layers.32/self_attn/Mul_9_output_0 - - /language_model/model/layers.33/self_attn/Mul_output_0 - - /language_model/model/layers.33/self_attn/Mul_1_output_0 - - /language_model/model/layers.33/self_attn/Mul_2_output_0 - - /language_model/model/layers.33/self_attn/Mul_3_output_0 - - /language_model/model/layers.33/self_attn/Mul_4_output_0 - - /language_model/model/layers.33/self_attn/Mul_5_output_0 - - /language_model/model/layers.33/self_attn/Mul_6_output_0 - - /language_model/model/layers.33/self_attn/Mul_7_output_0 - - /language_model/model/layers.33/self_attn/Mul_8_output_0 - - /language_model/model/layers.33/self_attn/Mul_9_output_0 - - /language_model/model/layers.0/self_attn/Softmax_output_0 - - /language_model/model/layers.1/self_attn/Softmax_output_0 - - /language_model/model/layers.2/self_attn/Softmax_output_0 - - /language_model/model/layers.3/self_attn/Softmax_output_0 - - /language_model/model/layers.4/self_attn/Softmax_output_0 - - /language_model/model/layers.5/self_attn/Softmax_output_0 - - /language_model/model/layers.6/self_attn/Softmax_output_0 - - /language_model/model/layers.7/self_attn/Softmax_output_0 - - /language_model/model/layers.8/self_attn/Softmax_output_0 - - /language_model/model/layers.9/self_attn/Softmax_output_0 - - /language_model/model/layers.10/self_attn/Softmax_output_0 - - 
/language_model/model/layers.11/self_attn/Softmax_output_0 - - /language_model/model/layers.12/self_attn/Softmax_output_0 - - /language_model/model/layers.13/self_attn/Softmax_output_0 - - /language_model/model/layers.14/self_attn/Softmax_output_0 - - /language_model/model/layers.15/self_attn/Softmax_output_0 - - /language_model/model/layers.16/self_attn/Softmax_output_0 - - /language_model/model/layers.17/self_attn/Softmax_output_0 - - /language_model/model/layers.18/self_attn/Softmax_output_0 - - /language_model/model/layers.19/self_attn/Softmax_output_0 - - /language_model/model/layers.20/self_attn/Softmax_output_0 - - /language_model/model/layers.21/self_attn/Softmax_output_0 - - /language_model/model/layers.22/self_attn/Softmax_output_0 - - /language_model/model/layers.23/self_attn/Softmax_output_0 - - /language_model/model/layers.24/self_attn/Softmax_output_0 - - /language_model/model/layers.25/self_attn/Softmax_output_0 - - /language_model/model/layers.26/self_attn/Softmax_output_0 - - /language_model/model/layers.27/self_attn/Softmax_output_0 - - /language_model/model/layers.28/self_attn/Softmax_output_0 - - /language_model/model/layers.29/self_attn/Softmax_output_0 - - /language_model/model/layers.30/self_attn/Softmax_output_0 - - /language_model/model/layers.31/self_attn/Softmax_output_0 - - /language_model/model/layers.32/self_attn/Softmax_output_0 - - /language_model/model/layers.33/self_attn/Softmax_output_0 \ No newline at end of file + - /language_model/layers.0/Add_output_0 + - /language_model/layers.0/Add_1_output_0 + - /language_model/layers.0/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/Add_2_output_0 + - /language_model/layers.0/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/Add_3_output_0 + - /language_model/layers.1/Add_output_0 + - /language_model/layers.1/Add_1_output_0 + - /language_model/layers.1/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/Add_2_output_0 + - /language_model/layers.1/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/Add_3_output_0 + - /language_model/layers.2/Add_output_0 + - /language_model/layers.2/Add_1_output_0 + - /language_model/layers.2/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/Add_2_output_0 + - /language_model/layers.2/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/Add_3_output_0 + - /language_model/layers.3/Add_output_0 + - /language_model/layers.3/Add_1_output_0 + - /language_model/layers.3/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/self_attn/q_norm/CustomRMSNorm_output_0 + - 
/language_model/layers.3/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/Add_2_output_0 + - /language_model/layers.3/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/Add_3_output_0 + - /language_model/layers.4/Add_output_0 + - /language_model/layers.4/Add_1_output_0 + - /language_model/layers.4/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/Add_2_output_0 + - /language_model/layers.4/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/Add_3_output_0 + - /language_model/layers.5/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.5/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/Add_output_0 + - /language_model/layers.5/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/Add_1_output_0 + - /language_model/layers.6/Add_output_0 + - /language_model/layers.6/Add_1_output_0 + - /language_model/layers.6/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/Add_2_output_0 + - /language_model/layers.6/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/Add_3_output_0 + - /language_model/layers.7/Add_output_0 + - /language_model/layers.7/Add_1_output_0 + - /language_model/layers.7/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/Add_2_output_0 + - /language_model/layers.7/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/Add_3_output_0 + - /language_model/layers.8/Add_output_0 + - /language_model/layers.8/Add_1_output_0 + - /language_model/layers.8/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/Add_2_output_0 + - /language_model/layers.8/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/Add_3_output_0 + - /language_model/layers.9/Add_output_0 + - /language_model/layers.9/Add_1_output_0 + - /language_model/layers.9/input_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.9/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/Add_2_output_0 + - /language_model/layers.9/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/Add_3_output_0 + - /language_model/layers.10/Add_output_0 + - /language_model/layers.10/Add_1_output_0 + - /language_model/layers.10/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/Add_2_output_0 + - /language_model/layers.10/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/Add_3_output_0 + - /language_model/layers.11/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/Add_output_0 + - /language_model/layers.11/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/Add_1_output_0 + - /language_model/layers.12/Add_output_0 + - /language_model/layers.12/Add_1_output_0 + - /language_model/layers.12/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/Add_2_output_0 + - /language_model/layers.12/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/Add_3_output_0 + - /language_model/layers.13/Add_output_0 + - /language_model/layers.13/Add_1_output_0 + - /language_model/layers.13/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/Add_2_output_0 + - /language_model/layers.13/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/Add_3_output_0 + - /language_model/layers.14/Add_output_0 + - /language_model/layers.14/Add_1_output_0 + - /language_model/layers.14/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/Add_2_output_0 + - /language_model/layers.14/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/Add_3_output_0 + - /language_model/layers.15/Add_output_0 + - 
/language_model/layers.15/Add_1_output_0 + - /language_model/layers.15/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/Add_2_output_0 + - /language_model/layers.15/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/Add_3_output_0 + - /language_model/layers.16/Add_output_0 + - /language_model/layers.16/Add_1_output_0 + - /language_model/layers.16/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/Add_2_output_0 + - /language_model/layers.16/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/Add_3_output_0 + - /language_model/layers.17/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/Add_output_0 + - /language_model/layers.17/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/Add_1_output_0 + - /language_model/layers.18/Add_output_0 + - /language_model/layers.18/Add_1_output_0 + - /language_model/layers.18/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/Add_2_output_0 + - /language_model/layers.18/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/Add_3_output_0 + - /language_model/layers.19/Add_output_0 + - /language_model/layers.19/Add_1_output_0 + - /language_model/layers.19/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/Add_2_output_0 + - /language_model/layers.19/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/Add_3_output_0 + - /language_model/layers.20/Add_output_0 + - /language_model/layers.20/Add_1_output_0 + - /language_model/layers.20/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/Add_2_output_0 + - /language_model/layers.20/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.20/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/Add_3_output_0 + - /language_model/layers.21/Add_output_0 + - /language_model/layers.21/Add_1_output_0 + - /language_model/layers.21/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/Add_2_output_0 + - /language_model/layers.21/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/Add_3_output_0 + - /language_model/layers.22/Add_output_0 + - /language_model/layers.22/Add_1_output_0 + - /language_model/layers.22/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.22/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/Add_2_output_0 + - /language_model/layers.22/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/Add_3_output_0 + - /language_model/layers.23/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/Add_output_0 + - /language_model/layers.23/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/Add_1_output_0 + - /language_model/layers.24/Add_output_0 + - /language_model/layers.24/Add_1_output_0 + - /language_model/layers.24/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/Add_2_output_0 + - /language_model/layers.24/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/Add_3_output_0 + - /language_model/layers.25/Add_output_0 + - /language_model/layers.25/Add_1_output_0 + - /language_model/layers.25/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.25/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/Add_2_output_0 + - /language_model/layers.25/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/Add_3_output_0 + - /language_model/layers.26/Add_output_0 + - /language_model/layers.26/Add_1_output_0 + - /language_model/layers.26/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.26/Add_2_output_0 + - /language_model/layers.26/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/Add_3_output_0 + - /language_model/layers.27/Add_output_0 + - /language_model/layers.27/Add_1_output_0 + - /language_model/layers.27/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/Add_2_output_0 + - /language_model/layers.27/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/Add_3_output_0 + - /language_model/layers.28/Add_output_0 + - /language_model/layers.28/Add_1_output_0 + - /language_model/layers.28/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/Add_2_output_0 + - /language_model/layers.28/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/Add_3_output_0 + - /language_model/layers.29/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/Add_output_0 + - /language_model/layers.29/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/Add_1_output_0 + - /language_model/layers.30/Add_output_0 + - /language_model/layers.30/Add_1_output_0 + - /language_model/layers.30/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/Add_2_output_0 + - /language_model/layers.30/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/Add_3_output_0 + - /language_model/layers.31/Add_output_0 + - /language_model/layers.31/Add_1_output_0 + - /language_model/layers.31/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/Add_2_output_0 + - /language_model/layers.31/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/Add_3_output_0 + - /language_model/layers.32/Add_output_0 + - /language_model/layers.32/Add_1_output_0 + - /language_model/layers.32/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/self_attn/q_norm/CustomRMSNorm_output_0 + - 
/language_model/layers.32/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/Add_2_output_0 + - /language_model/layers.32/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/Add_3_output_0 + - /language_model/layers.33/Add_output_0 + - /language_model/layers.33/Add_1_output_0 + - /language_model/layers.33/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/Add_2_output_0 + - /language_model/layers.33/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/Add_3_output_0 + - /language_model/norm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/Mul_output_0 + - /language_model/layers.0/self_attn/Mul_1_output_0 + - /language_model/layers.0/self_attn/Mul_2_output_0 + - /language_model/layers.0/self_attn/Mul_3_output_0 + - /language_model/layers.0/self_attn/Mul_4_output_0 + - /language_model/layers.0/self_attn/Mul_5_output_0 + - /language_model/layers.0/self_attn/Mul_6_output_0 + - /language_model/layers.0/self_attn/Mul_7_output_0 + - /language_model/layers.0/self_attn/Mul_8_output_0 + - /language_model/layers.1/self_attn/Mul_9_output_0 + - /language_model/layers.2/self_attn/Mul_output_0 + - /language_model/layers.2/self_attn/Mul_1_output_0 + - /language_model/layers.2/self_attn/Mul_2_output_0 + - /language_model/layers.2/self_attn/Mul_3_output_0 + - /language_model/layers.2/self_attn/Mul_4_output_0 + - /language_model/layers.2/self_attn/Mul_5_output_0 + - /language_model/layers.2/self_attn/Mul_6_output_0 + - /language_model/layers.2/self_attn/Mul_7_output_0 + - /language_model/layers.2/self_attn/Mul_8_output_0 + - /language_model/layers.2/self_attn/Mul_9_output_0 + - /language_model/layers.3/self_attn/Mul_output_0 + - /language_model/layers.3/self_attn/Mul_1_output_0 + - /language_model/layers.3/self_attn/Mul_2_output_0 + - /language_model/layers.3/self_attn/Mul_3_output_0 + - /language_model/layers.3/self_attn/Mul_4_output_0 + - /language_model/layers.3/self_attn/Mul_5_output_0 + - /language_model/layers.3/self_attn/Mul_6_output_0 + - /language_model/layers.3/self_attn/Mul_7_output_0 + - /language_model/layers.3/self_attn/Mul_8_output_0 + - /language_model/layers.3/self_attn/Mul_9_output_0 + - /language_model/layers.4/self_attn/Mul_output_0 + - /language_model/layers.4/self_attn/Mul_1_output_0 + - /language_model/layers.4/self_attn/Mul_2_output_0 + - /language_model/layers.4/self_attn/Mul_3_output_0 + - /language_model/layers.4/self_attn/Mul_4_output_0 + - /language_model/layers.4/self_attn/Mul_5_output_0 + - /language_model/layers.4/self_attn/Mul_6_output_0 + - /language_model/layers.4/self_attn/Mul_7_output_0 + - /language_model/layers.4/self_attn/Mul_8_output_0 + - /language_model/layers.4/self_attn/Mul_9_output_0 + - /language_model/layers.5/self_attn/Mul_output_0 + - /language_model/layers.5/self_attn/Mul_1_output_0 + - /language_model/layers.5/self_attn/Mul_2_output_0 + - /language_model/layers.5/self_attn/Mul_3_output_0 + - /language_model/layers.5/self_attn/Mul_4_output_0 + - /language_model/layers.5/self_attn/Mul_5_output_0 + - 
/language_model/layers.5/self_attn/Mul_6_output_0 + - /language_model/layers.5/self_attn/Mul_7_output_0 + - /language_model/layers.5/self_attn/Mul_8_output_0 + - /language_model/layers.5/self_attn/Mul_9_output_0 + - /language_model/layers.6/self_attn/Mul_output_0 + - /language_model/layers.6/self_attn/Mul_1_output_0 + - /language_model/layers.6/self_attn/Mul_2_output_0 + - /language_model/layers.6/self_attn/Mul_3_output_0 + - /language_model/layers.6/self_attn/Mul_4_output_0 + - /language_model/layers.6/self_attn/Mul_5_output_0 + - /language_model/layers.6/self_attn/Mul_6_output_0 + - /language_model/layers.6/self_attn/Mul_7_output_0 + - /language_model/layers.6/self_attn/Mul_8_output_0 + - /language_model/layers.6/self_attn/Mul_9_output_0 + - /language_model/layers.7/self_attn/Mul_output_0 + - /language_model/layers.7/self_attn/Mul_1_output_0 + - /language_model/layers.7/self_attn/Mul_2_output_0 + - /language_model/layers.7/self_attn/Mul_3_output_0 + - /language_model/layers.7/self_attn/Mul_4_output_0 + - /language_model/layers.7/self_attn/Mul_5_output_0 + - /language_model/layers.7/self_attn/Mul_6_output_0 + - /language_model/layers.7/self_attn/Mul_7_output_0 + - /language_model/layers.7/self_attn/Mul_8_output_0 + - /language_model/layers.7/self_attn/Mul_9_output_0 + - /language_model/layers.8/self_attn/Mul_output_0 + - /language_model/layers.8/self_attn/Mul_1_output_0 + - /language_model/layers.8/self_attn/Mul_2_output_0 + - /language_model/layers.8/self_attn/Mul_3_output_0 + - /language_model/layers.8/self_attn/Mul_4_output_0 + - /language_model/layers.8/self_attn/Mul_5_output_0 + - /language_model/layers.8/self_attn/Mul_6_output_0 + - /language_model/layers.8/self_attn/Mul_7_output_0 + - /language_model/layers.8/self_attn/Mul_8_output_0 + - /language_model/layers.8/self_attn/Mul_9_output_0 + - /language_model/layers.9/self_attn/Mul_output_0 + - /language_model/layers.9/self_attn/Mul_1_output_0 + - /language_model/layers.9/self_attn/Mul_2_output_0 + - /language_model/layers.9/self_attn/Mul_3_output_0 + - /language_model/layers.9/self_attn/Mul_4_output_0 + - /language_model/layers.9/self_attn/Mul_5_output_0 + - /language_model/layers.9/self_attn/Mul_6_output_0 + - /language_model/layers.9/self_attn/Mul_7_output_0 + - /language_model/layers.9/self_attn/Mul_8_output_0 + - /language_model/layers.9/self_attn/Mul_9_output_0 + - /language_model/layers.10/self_attn/Mul_output_0 + - /language_model/layers.10/self_attn/Mul_1_output_0 + - /language_model/layers.10/self_attn/Mul_2_output_0 + - /language_model/layers.10/self_attn/Mul_3_output_0 + - /language_model/layers.10/self_attn/Mul_4_output_0 + - /language_model/layers.10/self_attn/Mul_5_output_0 + - /language_model/layers.10/self_attn/Mul_6_output_0 + - /language_model/layers.10/self_attn/Mul_7_output_0 + - /language_model/layers.10/self_attn/Mul_8_output_0 + - /language_model/layers.10/self_attn/Mul_9_output_0 + - /language_model/layers.11/self_attn/Mul_output_0 + - /language_model/layers.11/self_attn/Mul_1_output_0 + - /language_model/layers.11/self_attn/Mul_2_output_0 + - /language_model/layers.11/self_attn/Mul_3_output_0 + - /language_model/layers.11/self_attn/Mul_4_output_0 + - /language_model/layers.11/self_attn/Mul_5_output_0 + - /language_model/layers.11/self_attn/Mul_6_output_0 + - /language_model/layers.11/self_attn/Mul_7_output_0 + - /language_model/layers.11/self_attn/Mul_8_output_0 + - /language_model/layers.11/self_attn/Mul_9_output_0 + - /language_model/layers.12/self_attn/Mul_output_0 + - 
/language_model/layers.12/self_attn/Mul_1_output_0 + - /language_model/layers.12/self_attn/Mul_2_output_0 + - /language_model/layers.12/self_attn/Mul_3_output_0 + - /language_model/layers.12/self_attn/Mul_4_output_0 + - /language_model/layers.12/self_attn/Mul_5_output_0 + - /language_model/layers.12/self_attn/Mul_6_output_0 + - /language_model/layers.12/self_attn/Mul_7_output_0 + - /language_model/layers.12/self_attn/Mul_8_output_0 + - /language_model/layers.12/self_attn/Mul_9_output_0 + - /language_model/layers.13/self_attn/Mul_output_0 + - /language_model/layers.13/self_attn/Mul_1_output_0 + - /language_model/layers.13/self_attn/Mul_2_output_0 + - /language_model/layers.13/self_attn/Mul_3_output_0 + - /language_model/layers.13/self_attn/Mul_4_output_0 + - /language_model/layers.13/self_attn/Mul_5_output_0 + - /language_model/layers.13/self_attn/Mul_6_output_0 + - /language_model/layers.13/self_attn/Mul_7_output_0 + - /language_model/layers.13/self_attn/Mul_8_output_0 + - /language_model/layers.13/self_attn/Mul_9_output_0 + - /language_model/layers.14/self_attn/Mul_output_0 + - /language_model/layers.14/self_attn/Mul_1_output_0 + - /language_model/layers.14/self_attn/Mul_2_output_0 + - /language_model/layers.14/self_attn/Mul_3_output_0 + - /language_model/layers.14/self_attn/Mul_4_output_0 + - /language_model/layers.14/self_attn/Mul_5_output_0 + - /language_model/layers.14/self_attn/Mul_6_output_0 + - /language_model/layers.14/self_attn/Mul_7_output_0 + - /language_model/layers.14/self_attn/Mul_8_output_0 + - /language_model/layers.14/self_attn/Mul_9_output_0 + - /language_model/layers.15/self_attn/Mul_output_0 + - /language_model/layers.15/self_attn/Mul_1_output_0 + - /language_model/layers.15/self_attn/Mul_2_output_0 + - /language_model/layers.15/self_attn/Mul_3_output_0 + - /language_model/layers.15/self_attn/Mul_4_output_0 + - /language_model/layers.15/self_attn/Mul_5_output_0 + - /language_model/layers.15/self_attn/Mul_6_output_0 + - /language_model/layers.15/self_attn/Mul_7_output_0 + - /language_model/layers.15/self_attn/Mul_8_output_0 + - /language_model/layers.15/self_attn/Mul_9_output_0 + - /language_model/layers.16/self_attn/Mul_output_0 + - /language_model/layers.16/self_attn/Mul_1_output_0 + - /language_model/layers.16/self_attn/Mul_2_output_0 + - /language_model/layers.16/self_attn/Mul_3_output_0 + - /language_model/layers.16/self_attn/Mul_4_output_0 + - /language_model/layers.16/self_attn/Mul_5_output_0 + - /language_model/layers.16/self_attn/Mul_6_output_0 + - /language_model/layers.16/self_attn/Mul_7_output_0 + - /language_model/layers.16/self_attn/Mul_8_output_0 + - /language_model/layers.16/self_attn/Mul_9_output_0 + - /language_model/layers.17/self_attn/Mul_output_0 + - /language_model/layers.17/self_attn/Mul_1_output_0 + - /language_model/layers.17/self_attn/Mul_2_output_0 + - /language_model/layers.17/self_attn/Mul_3_output_0 + - /language_model/layers.17/self_attn/Mul_4_output_0 + - /language_model/layers.17/self_attn/Mul_5_output_0 + - /language_model/layers.17/self_attn/Mul_6_output_0 + - /language_model/layers.17/self_attn/Mul_7_output_0 + - /language_model/layers.17/self_attn/Mul_8_output_0 + - /language_model/layers.17/self_attn/Mul_9_output_0 + - /language_model/layers.18/self_attn/Mul_output_0 + - /language_model/layers.18/self_attn/Mul_1_output_0 + - /language_model/layers.18/self_attn/Mul_2_output_0 + - /language_model/layers.18/self_attn/Mul_3_output_0 + - /language_model/layers.18/self_attn/Mul_4_output_0 + - 
/language_model/layers.18/self_attn/Mul_5_output_0 + - /language_model/layers.18/self_attn/Mul_6_output_0 + - /language_model/layers.18/self_attn/Mul_7_output_0 + - /language_model/layers.18/self_attn/Mul_8_output_0 + - /language_model/layers.18/self_attn/Mul_9_output_0 + - /language_model/layers.19/self_attn/Mul_output_0 + - /language_model/layers.19/self_attn/Mul_1_output_0 + - /language_model/layers.19/self_attn/Mul_2_output_0 + - /language_model/layers.19/self_attn/Mul_3_output_0 + - /language_model/layers.19/self_attn/Mul_4_output_0 + - /language_model/layers.19/self_attn/Mul_5_output_0 + - /language_model/layers.19/self_attn/Mul_6_output_0 + - /language_model/layers.19/self_attn/Mul_7_output_0 + - /language_model/layers.19/self_attn/Mul_8_output_0 + - /language_model/layers.19/self_attn/Mul_9_output_0 + - /language_model/layers.20/self_attn/Mul_output_0 + - /language_model/layers.20/self_attn/Mul_1_output_0 + - /language_model/layers.20/self_attn/Mul_2_output_0 + - /language_model/layers.20/self_attn/Mul_3_output_0 + - /language_model/layers.20/self_attn/Mul_4_output_0 + - /language_model/layers.20/self_attn/Mul_5_output_0 + - /language_model/layers.20/self_attn/Mul_6_output_0 + - /language_model/layers.20/self_attn/Mul_7_output_0 + - /language_model/layers.20/self_attn/Mul_8_output_0 + - /language_model/layers.20/self_attn/Mul_9_output_0 + - /language_model/layers.21/self_attn/Mul_output_0 + - /language_model/layers.21/self_attn/Mul_1_output_0 + - /language_model/layers.21/self_attn/Mul_2_output_0 + - /language_model/layers.21/self_attn/Mul_3_output_0 + - /language_model/layers.21/self_attn/Mul_4_output_0 + - /language_model/layers.21/self_attn/Mul_5_output_0 + - /language_model/layers.21/self_attn/Mul_6_output_0 + - /language_model/layers.21/self_attn/Mul_7_output_0 + - /language_model/layers.21/self_attn/Mul_8_output_0 + - /language_model/layers.21/self_attn/Mul_9_output_0 + - /language_model/layers.22/self_attn/Mul_output_0 + - /language_model/layers.22/self_attn/Mul_1_output_0 + - /language_model/layers.22/self_attn/Mul_2_output_0 + - /language_model/layers.22/self_attn/Mul_3_output_0 + - /language_model/layers.22/self_attn/Mul_4_output_0 + - /language_model/layers.22/self_attn/Mul_5_output_0 + - /language_model/layers.22/self_attn/Mul_6_output_0 + - /language_model/layers.22/self_attn/Mul_7_output_0 + - /language_model/layers.22/self_attn/Mul_8_output_0 + - /language_model/layers.22/self_attn/Mul_9_output_0 + - /language_model/layers.23/self_attn/Mul_output_0 + - /language_model/layers.23/self_attn/Mul_1_output_0 + - /language_model/layers.23/self_attn/Mul_2_output_0 + - /language_model/layers.23/self_attn/Mul_3_output_0 + - /language_model/layers.23/self_attn/Mul_4_output_0 + - /language_model/layers.23/self_attn/Mul_5_output_0 + - /language_model/layers.23/self_attn/Mul_6_output_0 + - /language_model/layers.23/self_attn/Mul_7_output_0 + - /language_model/layers.23/self_attn/Mul_8_output_0 + - /language_model/layers.23/self_attn/Mul_9_output_0 + - /language_model/layers.24/self_attn/Mul_output_0 + - /language_model/layers.24/self_attn/Mul_1_output_0 + - /language_model/layers.24/self_attn/Mul_2_output_0 + - /language_model/layers.24/self_attn/Mul_3_output_0 + - /language_model/layers.24/self_attn/Mul_4_output_0 + - /language_model/layers.24/self_attn/Mul_5_output_0 + - /language_model/layers.24/self_attn/Mul_6_output_0 + - /language_model/layers.24/self_attn/Mul_7_output_0 + - /language_model/layers.24/self_attn/Mul_8_output_0 + - 
/language_model/layers.24/self_attn/Mul_9_output_0 + - /language_model/layers.25/self_attn/Mul_output_0 + - /language_model/layers.25/self_attn/Mul_1_output_0 + - /language_model/layers.25/self_attn/Mul_2_output_0 + - /language_model/layers.25/self_attn/Mul_3_output_0 + - /language_model/layers.25/self_attn/Mul_4_output_0 + - /language_model/layers.25/self_attn/Mul_5_output_0 + - /language_model/layers.25/self_attn/Mul_6_output_0 + - /language_model/layers.25/self_attn/Mul_7_output_0 + - /language_model/layers.25/self_attn/Mul_8_output_0 + - /language_model/layers.25/self_attn/Mul_9_output_0 + - /language_model/layers.26/self_attn/Mul_output_0 + - /language_model/layers.26/self_attn/Mul_1_output_0 + - /language_model/layers.26/self_attn/Mul_2_output_0 + - /language_model/layers.26/self_attn/Mul_3_output_0 + - /language_model/layers.26/self_attn/Mul_4_output_0 + - /language_model/layers.26/self_attn/Mul_5_output_0 + - /language_model/layers.26/self_attn/Mul_6_output_0 + - /language_model/layers.26/self_attn/Mul_7_output_0 + - /language_model/layers.26/self_attn/Mul_8_output_0 + - /language_model/layers.26/self_attn/Mul_9_output_0 + - /language_model/layers.27/self_attn/Mul_output_0 + - /language_model/layers.27/self_attn/Mul_1_output_0 + - /language_model/layers.27/self_attn/Mul_2_output_0 + - /language_model/layers.27/self_attn/Mul_3_output_0 + - /language_model/layers.27/self_attn/Mul_4_output_0 + - /language_model/layers.27/self_attn/Mul_5_output_0 + - /language_model/layers.27/self_attn/Mul_6_output_0 + - /language_model/layers.27/self_attn/Mul_7_output_0 + - /language_model/layers.27/self_attn/Mul_8_output_0 + - /language_model/layers.27/self_attn/Mul_9_output_0 + - /language_model/layers.28/self_attn/Mul_output_0 + - /language_model/layers.28/self_attn/Mul_1_output_0 + - /language_model/layers.28/self_attn/Mul_2_output_0 + - /language_model/layers.28/self_attn/Mul_3_output_0 + - /language_model/layers.28/self_attn/Mul_4_output_0 + - /language_model/layers.28/self_attn/Mul_5_output_0 + - /language_model/layers.28/self_attn/Mul_6_output_0 + - /language_model/layers.28/self_attn/Mul_7_output_0 + - /language_model/layers.28/self_attn/Mul_8_output_0 + - /language_model/layers.28/self_attn/Mul_9_output_0 + - /language_model/layers.29/self_attn/Mul_output_0 + - /language_model/layers.29/self_attn/Mul_1_output_0 + - /language_model/layers.29/self_attn/Mul_2_output_0 + - /language_model/layers.29/self_attn/Mul_3_output_0 + - /language_model/layers.29/self_attn/Mul_4_output_0 + - /language_model/layers.29/self_attn/Mul_5_output_0 + - /language_model/layers.29/self_attn/Mul_6_output_0 + - /language_model/layers.29/self_attn/Mul_7_output_0 + - /language_model/layers.29/self_attn/Mul_8_output_0 + - /language_model/layers.29/self_attn/Mul_9_output_0 + - /language_model/layers.30/self_attn/Mul_output_0 + - /language_model/layers.30/self_attn/Mul_1_output_0 + - /language_model/layers.30/self_attn/Mul_2_output_0 + - /language_model/layers.30/self_attn/Mul_3_output_0 + - /language_model/layers.30/self_attn/Mul_4_output_0 + - /language_model/layers.30/self_attn/Mul_5_output_0 + - /language_model/layers.30/self_attn/Mul_6_output_0 + - /language_model/layers.30/self_attn/Mul_7_output_0 + - /language_model/layers.30/self_attn/Mul_8_output_0 + - /language_model/layers.30/self_attn/Mul_9_output_0 + - /language_model/layers.31/self_attn/Mul_output_0 + - /language_model/layers.31/self_attn/Mul_1_output_0 + - /language_model/layers.31/self_attn/Mul_2_output_0 + - 
/language_model/layers.31/self_attn/Mul_3_output_0 + - /language_model/layers.31/self_attn/Mul_4_output_0 + - /language_model/layers.31/self_attn/Mul_5_output_0 + - /language_model/layers.31/self_attn/Mul_6_output_0 + - /language_model/layers.31/self_attn/Mul_7_output_0 + - /language_model/layers.31/self_attn/Mul_8_output_0 + - /language_model/layers.31/self_attn/Mul_9_output_0 + - /language_model/layers.32/self_attn/Mul_output_0 + - /language_model/layers.32/self_attn/Mul_1_output_0 + - /language_model/layers.32/self_attn/Mul_2_output_0 + - /language_model/layers.32/self_attn/Mul_3_output_0 + - /language_model/layers.32/self_attn/Mul_4_output_0 + - /language_model/layers.32/self_attn/Mul_5_output_0 + - /language_model/layers.32/self_attn/Mul_6_output_0 + - /language_model/layers.32/self_attn/Mul_7_output_0 + - /language_model/layers.32/self_attn/Mul_8_output_0 + - /language_model/layers.32/self_attn/Mul_9_output_0 + - /language_model/layers.33/self_attn/Mul_output_0 + - /language_model/layers.33/self_attn/Mul_1_output_0 + - /language_model/layers.33/self_attn/Mul_2_output_0 + - /language_model/layers.33/self_attn/Mul_3_output_0 + - /language_model/layers.33/self_attn/Mul_4_output_0 + - /language_model/layers.33/self_attn/Mul_5_output_0 + - /language_model/layers.33/self_attn/Mul_6_output_0 + - /language_model/layers.33/self_attn/Mul_7_output_0 + - /language_model/layers.33/self_attn/Mul_8_output_0 + - /language_model/layers.33/self_attn/Mul_9_output_0 + - /language_model/layers.0/self_attn/Softmax_output_0 + - /language_model/layers.1/self_attn/Softmax_output_0 + - /language_model/layers.2/self_attn/Softmax_output_0 + - /language_model/layers.3/self_attn/Softmax_output_0 + - /language_model/layers.4/self_attn/Softmax_output_0 + - /language_model/layers.5/self_attn/Softmax_output_0 + - /language_model/layers.6/self_attn/Softmax_output_0 + - /language_model/layers.7/self_attn/Softmax_output_0 + - /language_model/layers.8/self_attn/Softmax_output_0 + - /language_model/layers.9/self_attn/Softmax_output_0 + - /language_model/layers.10/self_attn/Softmax_output_0 + - /language_model/layers.11/self_attn/Softmax_output_0 + - /language_model/layers.12/self_attn/Softmax_output_0 + - /language_model/layers.13/self_attn/Softmax_output_0 + - /language_model/layers.14/self_attn/Softmax_output_0 + - /language_model/layers.15/self_attn/Softmax_output_0 + - /language_model/layers.16/self_attn/Softmax_output_0 + - /language_model/layers.17/self_attn/Softmax_output_0 + - /language_model/layers.18/self_attn/Softmax_output_0 + - /language_model/layers.19/self_attn/Softmax_output_0 + - /language_model/layers.20/self_attn/Softmax_output_0 + - /language_model/layers.21/self_attn/Softmax_output_0 + - /language_model/layers.22/self_attn/Softmax_output_0 + - /language_model/layers.23/self_attn/Softmax_output_0 + - /language_model/layers.24/self_attn/Softmax_output_0 + - /language_model/layers.25/self_attn/Softmax_output_0 + - /language_model/layers.26/self_attn/Softmax_output_0 + - /language_model/layers.27/self_attn/Softmax_output_0 + - /language_model/layers.28/self_attn/Softmax_output_0 + - /language_model/layers.29/self_attn/Softmax_output_0 + - /language_model/layers.30/self_attn/Softmax_output_0 + - /language_model/layers.31/self_attn/Softmax_output_0 + - /language_model/layers.32/self_attn/Softmax_output_0 + - /language_model/layers.33/self_attn/Softmax_output_0 + diff --git a/examples/intern_example/readme.md b/examples/intern_example/readme.md index 1e58482a0..6b0b674c9 100644 --- 
a/examples/intern_example/readme.md +++ b/examples/intern_example/readme.md @@ -2,14 +2,14 @@ This directory contains an example script of how to run inference on InternVL-1B model via QEFFAutoModelForCausalLM class. ## Required packages: -- `torch==2.4.1+cpu` -- `torchvision==0.19.1+cpu` +- `torch==2.7.0+cpu` +- `torchvision==0.22.0+cpu` - `timm==1.0.14` - `einops==0.8.1` You can install them using pip: ```sh -pip install torch==2.4.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu timm==1.0.14 torchvision==0.19.1+cpu einops==0.8.1 +pip install torch==2.7.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu timm==1.0.14 torchvision==0.22.0+cpu einops==0.8.1 ``` To run example script after package installations: diff --git a/pyproject.toml b/pyproject.toml index 479736c22..ea3c3405d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,8 +19,8 @@ classifiers = [ ] requires-python = ">=3.8,<3.11" dependencies = [ - "transformers==4.51.3", - "huggingface-hub==0.30.0", + "transformers==4.55.0", + "huggingface-hub==0.34.0", "hf_transfer==0.1.9", "peft==0.13.2", "datasets==2.20.0", @@ -39,11 +39,11 @@ dependencies = [ "fire", "py7zr", "torchmetrics==1.7.0", - "torch==2.4.1; platform_machine=='aarch64'", + "torch==2.7.0; platform_machine=='aarch64'", # Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11 "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'", - "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp39-cp39-linux_x86_64.whl ; python_version=='3.9' and platform_machine=='x86_64'", - "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp310-cp310-linux_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'", + "torch@https://download.pytorch.org/whl/cpu/torch-2.7.0%2Bcpu-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_machine=='x86_64'", + "torch@https://download.pytorch.org/whl/cpu/torch-2.7.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'", ] [project.optional-dependencies] diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index 2b69824bb..ec8dff25a 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -24,7 +24,7 @@ pipeline { pip install .[test] && pip install junitparser pytest-xdist && pip install librosa==0.10.2 soundfile==0.13.1 && #packages needed to load example for whisper testing - pip install --extra-index-url https://download.pytorch.org/whl/cpu timm==1.0.14 torchvision==0.19.1+cpu einops==0.8.1 && #packages to load VLMs + pip install --extra-index-url https://download.pytorch.org/whl/cpu timm==1.0.14 torchvision==0.22.0+cpu einops==0.8.1 && #packages to load VLMs pip install /opt/qti-aic/integrations/torch_qaic/py310/torch_qaic-0.1.0-cp310-cp310-linux_x86_64.whl && # For finetuning tests rm -rf QEfficient" ''' diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index dc6308531..ddf3a68c7 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -143,6 +143,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, config: Optional[AutoConfig] = None, + pytorch_hf_tokens: Optional[list] = None, ): """ Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the 
diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py
index dc6308531..ddf3a68c7 100644
--- a/tests/transformers/models/test_causal_lm_models.py
+++ b/tests/transformers/models/test_causal_lm_models.py
@@ -143,6 +143,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     enable_qnn: Optional[bool] = False,
     qnn_config: Optional[str] = None,
     config: Optional[AutoConfig] = None,
+    pytorch_hf_tokens: Optional[list] = None,
 ):
     """
     Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
@@ -169,8 +170,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         Constants.PROMPT_LEN,
         Constants.CTX_LEN,
     )
-
-    if model_name not in ModelConfig.SWIFTKV_MODELS:
+    if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS:
         pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf)
 
     is_tlm = False if num_speculative_tokens is None else True
@@ -232,10 +232,13 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         full_batch_size,
     )
 
-    if model_name not in ModelConfig.SWIFTKV_MODELS:
+    if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS:
         pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf)
         pytorch_hf_tokens = np.vstack(pytorch_hf_tokens)
 
+    if model_name in ModelConfig.EXTERNAL_MODELS:
+        pytorch_hf_tokens = [pytorch_hf_tokens for _ in range(full_batch_size)]
+
     qeff_model = QEFFAutoModelForCausalLM(
         model_hf, continuous_batching=True, is_tlm=is_tlm, pretrained_model_name_or_path=model_name
     )
@@ -319,11 +322,17 @@ def test_custom_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, custom_causa
    """
    config = custom_causal_model_config_dict.get(model_name)
 
+    # Using fixed reference tokens for external models for specific test cases.
+    # These tokens are hardcoded, therefore will not match if the model config changes.
+    pytorch_hf_tokens = None
+    if model_name in ModelConfig.EXTERNAL_MODELS:
+        pytorch_hf_tokens = ModelConfig.EXTERNAL_MODELS[model_name]["pytorch_hf_tokens_custom_case"]
+
    if model_name in ModelConfig.QUANTIZED_MODELS:
        n_layer = get_custom_n_layers(model_name)
-        check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, n_layer=n_layer)
+        check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, n_layer=n_layer, pytorch_hf_tokens=pytorch_hf_tokens)
    else:
-        check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, config=config)
+        check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, config=config, pytorch_hf_tokens=pytorch_hf_tokens)
 
 
 @pytest.mark.nightly
@@ -337,7 +346,15 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
     """
     n_layer = get_custom_n_layers(model_name)
 
-    check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer)
+    # Using fixed reference tokens for external models for specific test cases.
+    # These tokens are hardcoded, therefore will not match if the model config changes.
+    pytorch_hf_tokens = None
+    if model_name in ModelConfig.EXTERNAL_MODELS:
+        pytorch_hf_tokens = ModelConfig.EXTERNAL_MODELS[model_name]["pytorch_hf_tokens_normal_case"]
+
+    check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
+        model_name=model_name, n_layer=n_layer, pytorch_hf_tokens=pytorch_hf_tokens
+    )
 
 
 @pytest.mark.on_qaic
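The hunks above thread an optional `pytorch_hf_tokens` argument through the causal-LM checks so that models listed in `ModelConfig.EXTERNAL_MODELS` are validated against fixed reference tokens instead of a fresh PyTorch run. The actual structure of `EXTERNAL_MODELS` is not shown in this patch; the sketch below only illustrates the access pattern the tests rely on, using a hypothetical model id and placeholder token values.

```python
# Hypothetical illustration of the EXTERNAL_MODELS lookup used above; the real
# entries live in ModelConfig and are not part of this patch.
EXTERNAL_MODELS = {
    "example-org/example-external-model": {  # hypothetical model id
        # Reference tokens captured once from a known-good PyTorch run (placeholder values).
        "pytorch_hf_tokens_normal_case": [1, 2, 3, 4],
        "pytorch_hf_tokens_custom_case": [5, 6, 7, 8],
    },
}


def reference_tokens(model_name, case="normal", full_batch_size=None):
    """Return fixed reference tokens for an external model, or None otherwise."""
    entry = EXTERNAL_MODELS.get(model_name)
    if entry is None:
        return None
    tokens = entry[f"pytorch_hf_tokens_{case}_case"]
    # For continuous batching the same reference sequence is reused per batch slot,
    # mirroring `[pytorch_hf_tokens for _ in range(full_batch_size)]` above.
    if full_batch_size is not None:
        return [tokens for _ in range(full_batch_size)]
    return tokens
```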
diff --git a/tests/transformers/models/test_image_text_to_text_models.py b/tests/transformers/models/test_image_text_to_text_models.py
index be0e84d23..e4f546de6 100644
--- a/tests/transformers/models/test_image_text_to_text_models.py
+++ b/tests/transformers/models/test_image_text_to_text_models.py
@@ -66,28 +66,29 @@
         "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
         1,
     ),
-    (
-        "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-        True,
-        1,
-        128,
-        3072,
-        336,
-        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
-        "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
-        4,
-    ),
-    (
-        "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-        False,
-        1,
-        128,
-        3072,
-        336,
-        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
-        "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
-        4,
-    ),
+    # Disabled in CI due to performance issues
+    # (
+    #     "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+    #     True,
+    #     1,
+    #     128,
+    #     3072,
+    #     336,
+    #     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
+    #     "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
+    #     4,
+    # ),
+    # (
+    #     "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+    #     False,
+    #     1,
+    #     128,
+    #     3072,
+    #     336,
+    #     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
+    #     "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
+    #     4,
+    # ),
     (
         "google/gemma-3-4b-it",
         True,
@@ -295,7 +296,12 @@ def check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
     config = AutoConfig.from_pretrained(model_config["model_name"], trust_remote_code=True)
     config._attn_implementation = "eager"
     config = set_num_layers(config, n_layer=n_layer)
-    model_hf, _ = load_image_text_to_text_model(config)
+    model_hf = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        low_cpu_mem_usage=False,
+        trust_remote_code=True,
+        config=config,
+    )
     n_layer = get_num_layers_vlm(config)
     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
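For the InternVL test, the patch replaces the `load_image_text_to_text_model` helper with a direct `AutoModelForCausalLM.from_pretrained` call built from the trimmed config. A standalone sketch of that loading path, outside the test harness, might look like the following; the model id is only an example, `set_num_layers` is the test suite's own helper, and everything else is standard `transformers` usage.

```python
# Sketch of the loading path introduced above, assuming a remote-code InternVL
# checkpoint and an eager attention implementation.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_name = "OpenGVLab/InternVL2_5-1B"  # example id; the test reads it from its model config dict

config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config._attn_implementation = "eager"
# The test additionally trims the layer count via set_num_layers(config, n_layer=n_layer).

model_hf = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=False,
    trust_remote_code=True,
    config=config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
model_hf.eval()
```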