From 02ca4de5862cf3c4da496abc2bb41ba36b2379e6 Mon Sep 17 00:00:00 2001 From: Daniel Huang Date: Thu, 23 Jan 2025 17:33:48 -0800 Subject: [PATCH] Integrate KV cache Signed-off-by: Daniel Huang --- .../models/snowflake/modeling_arctic.py | 511 +++++++++--------- 1 file changed, 269 insertions(+), 242 deletions(-) diff --git a/optimum/habana/transformers/models/snowflake/modeling_arctic.py b/optimum/habana/transformers/models/snowflake/modeling_arctic.py index feed57f60e..e326a039e8 100644 --- a/optimum/habana/transformers/models/snowflake/modeling_arctic.py +++ b/optimum/habana/transformers/models/snowflake/modeling_arctic.py @@ -25,21 +25,21 @@ - Added mark steps """ -import copy +import contextlib import inspect import math import re import warnings from typing import List, Optional, Tuple, Union +import habana_frameworks.torch.core as htcore import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers import GenerationMixin from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache +from transformers.cache_utils import Cache from transformers.integrations.deepspeed import is_deepspeed_available from transformers.modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, @@ -56,7 +56,6 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -67,10 +66,10 @@ GaudiLlamaLinearScalingRotaryEmbedding, GaudiLlamaRotaryEmbedding, ) -from .configuration_arctic import ArcticConfig +from ..mixtral.modeling_mixtral import GaudiMixtralAttentionLongSequence from ..modeling_all_models import KVCache, apply_customized_rope_module +from .configuration_arctic import ArcticConfig -import habana_frameworks.torch.core as htcore try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE @@ -90,20 +89,26 @@ print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None +try: + from habana_frameworks.torch.hpu import sdp_kernel + + SDPContext = True +except ImportError: + SDPContext = False if is_deepspeed_available(): from deepspeed.moe.layer import MoE # Note that below will crash if there is an available deepspeed that does not have ds_linear. 
try: - import deepspeed.linear as ds_linear + pass except Exception: pass else: MoE = None if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn import flash_attn_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) @@ -239,10 +244,7 @@ def __init__(self, hidden_size, eps=1e-6): def forward(self, hidden_states): """ - Modified from original ArcticRMS implementation: - - Use Habana fused RMSNorm - - Modifications copied from ../llama/modeling_llama.py:gaudi_llama_rmsnorm_forward() + Copied from optimum/habana/transformers/models/llama/modeling_llama.py gaudi_llama_rmsnorm_forward """ if hidden_states.device.type == "hpu" and FusedRMSNorm: # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype @@ -286,8 +288,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] @@ -295,8 +297,8 @@ def forward(self, x, seq_len=None): self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), + self._cos_cached[:seq_len].to(dtype=x.dtype), + self._sin_cached[:seq_len].to(dtype=x.dtype), ) @@ -337,7 +339,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): return q_embed, k_embed -# Copied from ../llama/modeling_llama.py gaudi_llama_repeat_kv() +# Copied from optimum/habana/transformers/models/llama/modeling_llama.py gaudi_llama_repeat_kv() def repeat_kv( query_states: torch.Tensor, key_states: torch.Tensor, @@ -391,6 +393,11 @@ def __init__(self, config: ArcticConfig, layer_idx: Optional[int] = None, **kwar "when creating this class." ) + self.k_cache = KVCache() + self.v_cache = KVCache() + self.inp_seq_len = -1 + self.block_size = 1024 + self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads @@ -400,54 +407,34 @@ def __init__(self, config: ArcticConfig, layer_idx: Optional[int] = None, **kwar self.rope_theta = config.rope_theta self.is_causal = True self.attention_dropout = config.attention_dropout - self.use_deepspeed_implementation = USE_DEEPSPEED_MOE_ARG in kwargs and kwargs[USE_DEEPSPEED_MOE_ARG] if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." 
) - deepspeed_lora_config = kwargs.get(DEEPSPEED_LORA_CONFIG) - quantization_config = kwargs.get(QUANTIZATION_CONFIG, None) - - self.q_proj = get_arctic_linear( + self.q_proj = nn.Linear( self.hidden_size, self.num_heads * self.head_dim, bias=False, - use_deepspeed_implementation=self.use_deepspeed_implementation, - ds_optimized_lora_config=deepspeed_lora_config, - ds_optimized_quantization_config=quantization_config, - ds_optimized_base_weight_sharding=True, dtype=torch.bfloat16, ) - self.k_proj = get_arctic_linear( + self.k_proj = nn.Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, - use_deepspeed_implementation=self.use_deepspeed_implementation, - ds_optimized_lora_config=deepspeed_lora_config, - ds_optimized_quantization_config=quantization_config, - ds_optimized_base_weight_sharding=True, dtype=torch.bfloat16, ) - self.v_proj = get_arctic_linear( + self.v_proj = nn.Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, - use_deepspeed_implementation=self.use_deepspeed_implementation, - ds_optimized_lora_config=deepspeed_lora_config, - ds_optimized_quantization_config=quantization_config, - ds_optimized_base_weight_sharding=True, dtype=torch.bfloat16, ) - self.o_proj = get_arctic_linear( + self.o_proj = nn.Linear( self.hidden_size, self.hidden_size, bias=False, - use_deepspeed_implementation=self.use_deepspeed_implementation, - ds_optimized_lora_config=deepspeed_lora_config, - ds_optimized_quantization_config=quantization_config, - ds_optimized_base_weight_sharding=True, dtype=torch.bfloat16, ) @@ -495,6 +482,7 @@ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): dtype = self.config.torch_dtype self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) + def forward( self, hidden_states: torch.Tensor, @@ -503,12 +491,28 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: Optional[int] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Adapted from ArcticAttention.forward: https://huggingface.co/Snowflake/snowflake-arctic-instruct/tree/be318cae5aba5291208f27d30991a5150500887d + + Reference Gaudi implementation from ../mixtral/modeling_mixtral.py GaudiMixtralAttention + + Changes made: + - Added new args + - token_idx + - cache_position + - reuse_cache + - flash_attention_recompute + - cache_idx + - Optimize KV cache + - Use FusedSDPA attention + """ if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37.
Please make sure use `attention_mask` instead.`" @@ -542,43 +546,91 @@ def forward( else: kv_seq_len = past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids, self.training) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - query_states, key_states, value_states, attention_mask = repeat_kv( - query_states, key_states, value_states, attention_mask, self.num_key_value_groups + query_states, key_states = apply_customized_rope( + query_states, key_states, cos, sin, position_ids, self.training ) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if use_cache: + if reuse_cache: + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros( + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + ) + past_key_value = (past_key, past_value) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if token_idx is None: + past_key_value = (key_states, value_states) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] + else: + past_key_value = None - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" + if FusedSDPA: + if query_states.dtype != key_states.dtype: + key_states = key_states.type(query_states.dtype) + value_states = value_states.type(query_states.dtype) + # support long sequences exceeding 8192 + if not self.training and q_len == key_states.size(-2) and q_len > 8192: + htcore.mark_step() + attn_output = GaudiMixtralAttentionLongSequence.forward( + query_states, + key_states, + value_states, + attention_mask, + False, + self.block_size, + ) + htcore.mark_step() + else: + with ( + sdp_kernel(enable_recompute=flash_attention_recompute) if SDPContext else contextlib.nullcontext() + ): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + # repeat k/v heads if n_kv_heads < n_heads + query_states, key_states, value_states, attention_mask = repeat_kv( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups ) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + f"Attention weights 
should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" ) - attn_weights = attn_weights + attention_mask + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + attn_weights = attn_weights + attention_mask - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -591,44 +643,10 @@ def forward( return attn_output, attn_weights, past_key_value -def get_arctic_linear( - input_dim, - output_dim, - bias=False, - use_deepspeed_implementation=False, - ds_optimized_lora_config=None, - ds_optimized_quantization_config=None, - ds_optimized_base_weight_sharding=False, - dtype=torch.bfloat16, -): - """Can return deepspeed optimized linear if available. - Args: - input_dim, output_dim, bias, dtype: self explanatory (same as from nn.Linear) - ds_optimized_lora_config: config of type ds_linear.LoRAConfig that contains lora specific parameter if we want to add lora to this layer. - ds_optimized_quantization_config: config of type ds_linear.QuantizationConfig. - ds_optimized_base_weight_sharding: bool. If true, the base weight for lora (provided ds_optimized_lora_config is not None) will be sharded across all available gpus - in a tensor parallel way. - """ - if is_deepspeed_available(): - if ds_optimized_lora_config is not None: - ds_optimized_lora_config: ds_linear.LoRAConfig = copy.deepcopy(ds_optimized_lora_config) - ds_optimized_lora_config.base_weight_sharding = ( - torch.distributed.get_world_size() if ds_optimized_base_weight_sharding else 1 - ) - return ds_linear.OptimizedLinear( - input_dim, output_dim, bias, ds_optimized_lora_config, ds_optimized_quantization_config, dtype=dtype - ) - return nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype) - - class ArcticMLP(nn.Module): def __init__( self, config: ArcticConfig, - use_deepspeed_implementation=False, - ds_optimized_lora_config=None, - ds_optimized_quantization_config=None, - shard_base_weights_if_doing_lora=False, is_residual_mlp=False, ): """MLP class for Arctic supporting vanilla linear layers as well as some deepspeed optimizations. 
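# ---------------------------------------------------------------------------
# For reference: a minimal, self-contained sketch of the pre-allocated KV-cache
# write pattern this patch wires into ArcticAttention (the reuse_cache /
# token_idx path). This is NOT the optimum-habana KVCache implementation; the
# real class is imported from ..modeling_all_models and is driven via
# self.k_cache(key_states, 2, token_idx) and self.k_cache.update(...).
# The helper name `write_kv_step` below is illustrative only.
import torch

def write_kv_step(cache: torch.Tensor, cur: torch.Tensor, dim: int, token_idx: torch.Tensor) -> torch.Tensor:
    if cur.shape[dim] > 1:
        # Prefill: copy the whole prompt into the head of the static buffer.
        cache.index_copy_(dim, torch.arange(cur.shape[dim], device=cur.device), cur)
    else:
        # Decode: write the single new token in place at slot token_idx - 1.
        cache.index_copy_(dim, token_idx - 1, cur)
    return cache

# Usage sketch: batch=1, 1 KV head, head_dim=4, buffer pre-allocated for 8 positions.
cache = torch.zeros(1, 1, 8, 4)
cache = write_kv_step(cache, torch.randn(1, 1, 3, 4), 2, torch.tensor([3]))  # 3 prompt tokens
cache = write_kv_step(cache, torch.randn(1, 1, 1, 4), 2, torch.tensor([4]))  # next token -> slot 3
# ---------------------------------------------------------------------------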
@@ -642,34 +660,22 @@ def __init__( super(ArcticMLP, self).__init__() self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size if not is_residual_mlp else self.hidden_dim - self.w1 = get_arctic_linear( + self.w1 = nn.Linear( self.hidden_dim, self.ffn_dim, - False, - use_deepspeed_implementation=use_deepspeed_implementation, - ds_optimized_lora_config=ds_optimized_lora_config, - ds_optimized_quantization_config=ds_optimized_quantization_config, - ds_optimized_base_weight_sharding=shard_base_weights_if_doing_lora, + bias=False, dtype=torch.bfloat16, ) - self.w2 = get_arctic_linear( + self.w2 = nn.Linear( self.ffn_dim, self.hidden_dim, - False, - use_deepspeed_implementation=use_deepspeed_implementation, - ds_optimized_lora_config=ds_optimized_lora_config, - ds_optimized_quantization_config=ds_optimized_quantization_config, - ds_optimized_base_weight_sharding=shard_base_weights_if_doing_lora, + bias=False, dtype=torch.bfloat16, ) - self.w3 = get_arctic_linear( + self.w3 = nn.Linear( self.hidden_dim, self.ffn_dim, - False, - use_deepspeed_implementation=use_deepspeed_implementation, - ds_optimized_lora_config=ds_optimized_lora_config, - ds_optimized_quantization_config=ds_optimized_quantization_config, - ds_optimized_base_weight_sharding=shard_base_weights_if_doing_lora, + bias=False, dtype=torch.bfloat16, ) self.act_fn = ACT2FN[config.hidden_act] @@ -690,57 +696,12 @@ def __init__(self, config: ArcticConfig, layer_id: int, **kwargs): self.top_k = config.num_experts_per_tok self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0 - self.use_deepspeed_implementation = USE_DEEPSPEED_MOE_ARG in kwargs and kwargs[USE_DEEPSPEED_MOE_ARG] - if self.use_deepspeed_implementation and MoE is None: - raise ValueError("Deepspeed is not installed") - quantization_config = kwargs.get(QUANTIZATION_CONFIG, None) - deepspeed_lora = kwargs.get(DEEPSPEED_LORA_CONFIG) if not self.is_moe_layer: # dense, not MoE - self.mlp = ArcticMLP( - config, - use_deepspeed_implementation=self.use_deepspeed_implementation, - ds_optimized_quantization_config=quantization_config, - ds_optimized_lora_config=deepspeed_lora, - shard_base_weights_if_doing_lora=True, - ) + self.mlp = ArcticMLP(config) else: - if self.use_deepspeed_implementation: # DeepSpeed's MoE - moe_expert_parallel_size = kwargs.get(MOE_EXPERT_PARALLEL_SIZE_ARG, 1) - self.mlp = MoE( - self.hidden_dim, - # base weight sharding false for all deepspeed moe calls because it is already sharded - ArcticMLP( - config, - use_deepspeed_implementation=True, - ds_optimized_quantization_config=quantization_config, - ds_optimized_lora_config=deepspeed_lora, - shard_base_weights_if_doing_lora=False, - ), - num_experts=config.num_local_experts, - ep_size=moe_expert_parallel_size, - k=config.num_experts_per_tok, - use_residual=False, - capacity_factor=config.moe_train_capacity_factor, - eval_capacity_factor=config.moe_eval_capacity_factor, - enable_expert_tensor_parallelism=config.enable_expert_tensor_parallelism, - min_capacity=config.moe_min_capacity, - drop_tokens=config.moe_token_dropping, - ) - else: - # "local" MoE implementation - self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - self.experts = nn.ModuleList( - [ - ArcticMLP( - config, - use_deepspeed_implementation=self.use_deepspeed_implementation, - ds_optimized_quantization_config=quantization_config, - ds_optimized_lora_config=deepspeed_lora, - shard_base_weights_if_doing_lora=True, - ) - for i in range(self.num_experts) - ] - ) + # "local" MoE implementation + 
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + self.experts = nn.ModuleList([ArcticMLP(config) for i in range(self.num_experts)]) # if torch.distributed.get_rank() == 0: # deepspeed.runtime.utils.see_memory_usage("", force=True) @@ -788,12 +749,7 @@ def _moe_foreward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor): if self.is_moe_layer: - if self.use_deepspeed_implementation: - # deepspeed returns a tuple including output, gate loss, and expert count. - hidden_states, moe_loss, _ = self.mlp(hidden_states) - return hidden_states, moe_loss - else: - return self._moe_foreward(hidden_states) + return self._moe_foreward(hidden_states) else: return self.mlp(hidden_states), torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) @@ -807,22 +763,15 @@ def __init__(self, config: ArcticConfig, layer_idx: int, **kwargs): self.block_sparse_moe = ArcticMoE(config, layer_id=layer_idx, **kwargs) self.input_layernorm = ArcticRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = ArcticRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.use_deepspeed_implementation = USE_DEEPSPEED_MOE_ARG in kwargs and kwargs[USE_DEEPSPEED_MOE_ARG] self.parallel_attn_mlp_res = ( config.parallel_attn_mlp_res and self.block_sparse_moe.is_moe_layer ) # add residual only when it is moe layer - deepspeed_quantization = kwargs.get(DEEPSPEED_QUANTIZATION_CONFIG) - deepspeed_lora = kwargs.get(DEEPSPEED_LORA_CONFIG) if self.parallel_attn_mlp_res: self.residual_layernorm = ArcticRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.residual_mlp = ArcticMLP( config, - use_deepspeed_implementation=self.use_deepspeed_implementation, is_residual_mlp=True, - ds_optimized_quantization_config=deepspeed_quantization, - ds_optimized_lora_config=deepspeed_lora, - shard_base_weights_if_doing_lora=True, ) # for the residual layer. always shard the base weight if doing deepspeed lora. def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): @@ -836,13 +785,22 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + cache_idx: Optional[int] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) """ + Modified from original Arctic forward + Changes: + - Add new arg cache_position + - Add new arg token_idx + - Add new arg reuse_cache + - Add new arg flash_attention_recompute + - Add new arg cache_idx + Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size @@ -856,6 +814,11 @@ def forward( (see `past_key_values`). """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + residual_input = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -868,6 +831,11 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, + token_idx=token_idx, + reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, ) hidden_states = residual_input + hidden_states @@ -1076,7 +1044,22 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, ) -> Union[Tuple, MoeModelOutputWithPast]: + """ + Modified from original Arctic forward + Changes: + - Add new arg cache_position + - Add new arg token_idx + - Add new arg reuse_cache + - Add new arg flash_attention_recompute + - Add new arg cache_idx + - Force legacy KV cache + """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1104,11 +1087,12 @@ def forward( ) use_cache = False - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + # NOTE: Forcing legacy cache for HPU + if past_key_values is not None and use_cache: + if reuse_cache: + past_key_values_length = past_key_values[0][0][2] + else: + past_key_values_length = past_key_values[0][0].shape[2] if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -1122,6 +1106,21 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if cache_position is None: + past_seen_tokens = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_seen_tokens = past_key_values.get_seq_length() + else: + past_seen_tokens = past_key_values[0][0].shape[2] + + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: @@ -1159,7 +1158,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_losses = () - next_decoder_cache = None + next_decoder_cache = () if use_cache else None for i, decoder_layer in enumerate(self.layers): if output_hidden_states: @@ -1174,27 +1173,27 @@ def forward( past_key_values, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_value=None if past_key_values is None else past_key_values[i], output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, + token_idx=token_idx, + reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + 
cache_idx=cache_idx, ) hidden_states = layer_outputs[0] if use_cache: - if hasattr(layer_outputs[2 if output_attentions else 1], "to_legacy_cache"): - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - else: - if next_decoder_cache is None: - next_decoder_cache = [layer_outputs[2 if output_attentions else 1]] - else: - next_decoder_cache.append(layer_outputs[2 if output_attentions else 1]) + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1210,9 +1209,7 @@ def forward( next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() - if use_legacy_cache and hasattr(next_decoder_cache, "to_legacy_cache") - else next_decoder_cache + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache ) if not return_dict: @@ -1230,7 +1227,7 @@ def forward( ) -class ArcticForCausalLM(ArcticPreTrainedModel, GenerationMixin): +class ArcticForCausalLM(ArcticPreTrainedModel): # TODO(jeffra): update _keys_to_ignore_on_load_unexpected with expert keys not relevant for this rank _keys_to_ignore_on_load_unexpected = [ r"model\.layers\.\d+\.block_sparse_moe\.experts\.\d+\.w\d+\.weight" @@ -1374,9 +1371,9 @@ def _load_from_state_dict( world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 # TODO(jeffra): currently assumes fine-tuning only on one node, fix for world_size != ep size if self.moe_expert_parallel_size > 1: - assert ( - self.moe_expert_parallel_size == world_size - ), f"currently only support expert parallel size equal to world size but {self.moe_expert_parallel_size=} and {world_size=}" + assert self.moe_expert_parallel_size == world_size, ( + f"currently only support expert parallel size equal to world size but {self.moe_expert_parallel_size=} and {world_size=}" + ) rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 num_local_experts = self.num_experts // self.moe_expert_parallel_size @@ -1451,9 +1448,9 @@ def _load_from_state_dict( if "deepspeed_moe" in incoming_param_name: assert shape_local == shape_incoming, "deepspeed moe weights are never sharded" else: - assert ( - shape_incoming[1] == shape_local[1] * world_size - ), "weights should be sharded equally across world size" + assert shape_incoming[1] == shape_local[1] * world_size, ( + "weights should be sharded equally across world size" + ) incoming_param = incoming_param[:, rank * shape_local[1] : (rank + 1) * shape_local[1]] print(f"Deepspeed lora: {rank=}, renaming {incoming_param_name} -> {param_name}") state_dict[param_name] = incoming_param @@ -1478,8 +1475,20 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = None, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" + Modified from original. Only differences are: + - Add new arg cache_position + - Add new arg token_idx + - Add new arg reuse_cache + - Add new arg flash_attention_recompute + - Add new arg cache_idx + Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., @@ -1523,6 +1532,11 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, + token_idx=token_idx, + reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) @@ -1561,58 +1575,70 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, ): + """ + Copied from GaudiMixtralForCausalLM in optimum/habana/transformers/models/mixtral/modeling_mixtral.py + """ + reuse_cache = kwargs.get("reuse_cache") + token_idx = kwargs.get("token_idx", None) + # Omit tokens covered by past_key_values if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() + if token_idx is not None: + idx = token_idx + kwargs.get("inputs_embeds_offset", 0) - 1 + input_ids = torch.index_select(input_ids, 1, idx) else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
- if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + elif reuse_cache and token_idx is not None: + # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] + if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + if token_idx is not None: + position_ids = torch.index_select(position_ids, 1, token_idx - 1) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, + "token_idx": token_idx, + "reuse_cache": reuse_cache, + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "cache_idx": kwargs.get("cache_idx"), } ) return model_inputs @@ -1750,6 +1776,7 @@ def forward( attentions=transformer_outputs.attentions, ) + # Copied from optimum.habana.transformers.models.llama.modeling_llama:apply_customized_rope() def apply_customized_rope(q, k, cos, sin, position_ids, training=True): if q.device.type == "hpu" and FusedRoPE: