diff --git a/README.md b/README.md
index 95c4f2d20d..88554b151b 100644
--- a/README.md
+++ b/README.md
@@ -285,6 +285,7 @@ The following model architectures, tasks and device distributions have been validated
 | Qwen2-VL | | <li>Single card</li> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
 | VideoLLaVA | | <li>Single card</li> | <li>[Video comprehension](https://github.com/huggingface/optimum-habana/tree/main/examples/video-comprehension)</li> |
 | GLM-4V | | <li>Single card</li> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
+| Mamba | | <li>Single card</li> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index 7c0246dc0f..01b7264f62 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -113,6 +113,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all been validated.
 | ChatGLM | <li>DeepSpeed</li> | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
 | Qwen2-VL | | <li>Single card</li> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
 | GLM-4V | | <li>Single card</li> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
+| Mamba | | <li>Single card</li> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
 - Diffusers
diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index 13fd355288..b2300f81e1 100755
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -209,6 +209,18 @@ PT_HPU_LAZY_MODE=1 python3 ./run_generation.py \
 > --sdp_on_bf16
 > ```
 
+To run Mamba-130m inference on a single Gaudi2 card, use the following command. The custom kernel is loaded from the default path `/root/.cache/huggingface/hub/models--Habana--mamba/blobs/libcustom_tpc_perf_lib.so`; if `libcustom_tpc_perf_lib.so` is located in a different folder, set the path accordingly:
+```bash
+python3 ./run_generation.py \
+--model_name_or_path state-spaces/mamba-130m-hf \
+--max_input_tokens 128 \
+--max_new_tokens 128 \
+--bf16 \
+--use_hpu_graphs \
+--use_kv_cache \
+--batch_size 1024
+```
+
 ### Use any dataset from the Hugging Face Hub
 
 You can also provide the name of a dataset from the Hugging Face Hub to perform generation on it with the argument `--dataset_name`.
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index dfdf11769c..af08621588 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -265,8 +265,10 @@
     gaudi_gpt_neox_model_forward,
     gaudi_invert_attention_mask,
     gaudi_llama_rmsnorm_forward,
+    gaudi_MambaCache_update_conv_state,
     gaudi_MambaForCausalLM_prepare_inputs_for_generation,
     gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
+    gaudi_MambaMixer,
     gaudi_mistral_rmsnorm_forward,
     gaudi_mixtral_rmsnorm_forward,
     gaudi_opt_attention_forward,
@@ -763,6 +765,9 @@ def adapt_transformers_to_gaudi():
     transformers.models.mamba.modeling_mamba.MambaForCausalLM._update_model_kwargs_for_generation = (
         gaudi_MambaForCausalLM_update_model_kwargs_for_generation
     )
+    transformers.models.mamba.modeling_mamba.MambaMixer = gaudi_MambaMixer
+    transformers.cache_utils.MambaCache.update_conv_state = gaudi_MambaCache_update_conv_state
+
     transformers.models.falcon_mamba.modeling_falcon_mamba.FalconMambaForCausalLM.prepare_inputs_for_generation = (
         gaudi_FalconMambaForCausalLM_prepare_inputs_for_generation
     )
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index 1bc8207356..65b5031b89 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -176,8 +176,10 @@
 from .llava_next import GaudiLlavaNextForConditionalGeneration
 from .llava_onevision import GaudiLlavaOnevisionForConditionalGeneration
 from .mamba import (
+    gaudi_MambaCache_update_conv_state,
     gaudi_MambaForCausalLM_prepare_inputs_for_generation,
     gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
+    gaudi_MambaMixer,
 )
 from .minicpm import MiniCPM3Config, MiniCPM3ForCausalLM
 from .mistral import (
diff --git a/optimum/habana/transformers/models/mamba/__init__.py b/optimum/habana/transformers/models/mamba/__init__.py
index c22d12877c..6bd4566df3 100644
--- a/optimum/habana/transformers/models/mamba/__init__.py
+++ b/optimum/habana/transformers/models/mamba/__init__.py
@@ -1,4 +1,6 @@
 from .modeling_mamba import (
+    gaudi_MambaCache_update_conv_state,
     gaudi_MambaForCausalLM_prepare_inputs_for_generation,
     gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
+    gaudi_MambaMixer,
 )
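With the patches above, `adapt_transformers_to_gaudi()` swaps in the Gaudi Mamba mixer and conv-state cache update before a model is instantiated. A minimal sketch of how the patched path could be exercised outside `run_generation.py` follows; the `"hpu"` device string, the `habana_frameworks` import, and the prompt are illustrative assumptions, not part of this diff:

```python
# Illustrative only: a rough smoke test for the patched Mamba path (not part of this PR).
import torch
import habana_frameworks.torch.core as htcore  # noqa: F401  # assumed: registers the "hpu" device
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Apply the monkey patches, including gaudi_MambaMixer and gaudi_MambaCache_update_conv_state.
adapt_transformers_to_gaudi()

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.bfloat16)
model = model.eval().to("hpu")

inputs = tokenizer("Mamba is a selective state space model", return_tensors="pt").to("hpu")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=16, use_cache=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```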
diff --git a/optimum/habana/transformers/models/mamba/modeling_mamba.py b/optimum/habana/transformers/models/mamba/modeling_mamba.py
index db07b00abd..615c93eb98 100644
--- a/optimum/habana/transformers/models/mamba/modeling_mamba.py
+++ b/optimum/habana/transformers/models/mamba/modeling_mamba.py
@@ -1,6 +1,9 @@
 from typing import Any, Optional
 
 import torch
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.models.mamba.configuration_mamba import MambaConfig
 from transformers.models.mamba.modeling_mamba import (
     MambaCache,
 )
@@ -11,6 +14,42 @@
 logger = logging.get_logger(__name__)
 
+use_pscan_kernel = True
+
+
+def Run_Mamba_Forward_Gaudi(in_state, in_x, in_dt, in_A, in_B, in_C, in_D, in_z):
+    in_state_h = in_state.unsqueeze(1).transpose(2, 3)
+    in_x_h = in_x.transpose(1, 2).unsqueeze(2)
+    in_dt_h = in_dt.unsqueeze(2)
+    in_A_h = in_A.unsqueeze(0).unsqueeze(1).transpose(2, 3)
+    in_B_h = in_B.unsqueeze(3)
+    in_C_h = in_C.unsqueeze(3)
+    in_D_h = in_D.unsqueeze(0).unsqueeze(1).unsqueeze(2)
+    in_z_h = in_z.transpose(1, 2).unsqueeze(2)
+
+    state_out_h = torch.ops.hpu.mamba_pscan(in_state_h, in_x_h, in_dt_h, in_A_h, in_B_h)
+    output_h = torch.ops.hpu.mamba_pscan_update(state_out_h, in_x_h, in_C_h, in_D_h, in_z_h)
+
+    output_hpu = output_h.squeeze(2).transpose(1, 2)
+    state_hpu = state_out_h.transpose(2, 3)
+    state_out = torch.select(state_hpu, 1, output_hpu.shape[2] - 1)
+
+    return output_hpu, state_out
+
+
+def gaudi_MambaCache_update_conv_state(
+    self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
+) -> torch.Tensor:
+    conv_state = self.conv_states[layer_idx]
+    cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
+
+    conv_state = conv_state.roll(shifts=-1, dims=-1)
+    # conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
+    for c, i in enumerate(cache_position):
+        conv_state[:, :, i] = new_conv_state[:, :, c].to(conv_state.device)
+
+    self.conv_states[layer_idx].zero_()
+    self.conv_states[layer_idx] += conv_state
+    return self.conv_states[layer_idx]
+
+
 def gaudi_MambaForCausalLM_update_model_kwargs_for_generation(
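`gaudi_MambaCache_update_conv_state` above maintains the causal-conv window by rolling the cached tensor one slot to the left and writing the newest column(s) in place. A standalone CPU sketch of that behavior, with made-up batch/channel sizes and a hypothetical `update_conv_state` helper name, purely for illustration:

```python
# Standalone sketch of the decode-time conv-state update (CPU tensors, illustrative only).
import torch

batch, channels, kernel_size = 1, 2, 4
conv_state = torch.zeros(batch, channels, kernel_size)


def update_conv_state(conv_state, new_column, cache_position):
    # Same idea as gaudi_MambaCache_update_conv_state: shift the window left by one,
    # then write the newest column(s) at the clamped cache position(s).
    cache_position = cache_position.clamp(0, kernel_size - 1)
    conv_state = conv_state.roll(shifts=-1, dims=-1)
    for c, i in enumerate(cache_position):
        conv_state[:, :, i] = new_column[:, :, c]
    return conv_state


# Feed one token per step; the state always holds the last `kernel_size` inputs.
for step in range(6):
    new_column = torch.full((batch, channels, 1), float(step))
    conv_state = update_conv_state(conv_state, new_column, torch.tensor([kernel_size - 1]))

print(conv_state[0, 0])  # tensor([2., 3., 4., 5.])
```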
@@ -94,3 +133,152 @@ def gaudi_MambaForCausalLM_prepare_inputs_for_generation(
         }
     )
     return model_inputs
+
+
+class gaudi_MambaMixer(nn.Module):
+    """
+    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
+    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
+    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
+    and is why Mamba is called **selective** state spaces)
+    Only the slow path is replaced with a custom op.
+    """
+
+    def __init__(self, config: MambaConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.ssm_state_size = config.state_size
+        self.conv_kernel_size = config.conv_kernel
+        self.intermediate_size = config.intermediate_size
+        self.time_step_rank = int(config.time_step_rank)
+        self.layer_idx = layer_idx
+        self.use_conv_bias = config.use_conv_bias
+        self.conv1d = nn.Conv1d(
+            in_channels=self.intermediate_size,
+            out_channels=self.intermediate_size,
+            bias=config.use_conv_bias,
+            kernel_size=config.conv_kernel,
+            groups=self.intermediate_size,
+            padding=config.conv_kernel - 1,
+        )
+
+        self.activation = config.hidden_act
+        self.act = ACT2FN[config.hidden_act]
+
+        self.use_mambapy = config.use_mambapy
+
+        # projection of the input hidden states
+        self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
+        # selective projection used to make dt, B and C input dependent
+        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
+        # time step projection (discretization)
+        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
+
+        # S4D real initialization. These are not discretized!
+        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+        A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
+        A = A.expand(self.intermediate_size, -1).contiguous()
+
+        self.A_log = nn.Parameter(torch.log(A))
+        self.D = nn.Parameter(torch.ones(self.intermediate_size))
+        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
+        self.use_bias = config.use_bias
+
+    # fmt: off
+    def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None):
+        """
+        Steps 3.c and 3.d are replaced with the custom op `Run_Mamba_Forward_Gaudi`, which removes the sequence-length loop and improves performance.
+        """
+        batch_size, seq_len, _ = input_states.shape
+        dtype = input_states.dtype
+        # 1. Gated MLP's linear projection
+        projected_states = self.in_proj(input_states).transpose(1, 2)                   # [batch, 2 * intermediate_size, seq_len]
+        hidden_states, gate = projected_states.chunk(2, dim=1)
+
+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+        # 2. Convolution sequence transformation
+        if cache_params is not None:
+            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+            ssm_state = ssm_state.to(hidden_states.device)
+            # use `cache_position.shape[0]` to check whether we are in prefill
+            # stage, it's equivalent to check `cache_position[0] == 0`, which
+            # breaks dynamo fullgraph constraints
+            if cache_position.shape[0] == self.conv_kernel_size:
+                conv_state = nn.functional.pad(
+                    hidden_states,
+                    (self.conv_kernel_size - hidden_states.shape[-1], 0)
+                )
+
+                cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
+                hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])     # [batch, intermediate_size, seq_len]
+            else:
+                conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
+                hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
+                if self.use_conv_bias:
+                    hidden_states += self.conv1d.bias
+                hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)         # [batch, intermediate_size, 1] : decoding
+        else:
+            ssm_state = torch.zeros(
+                (batch_size, self.intermediate_size, self.ssm_state_size),
+                device=hidden_states.device, dtype=dtype
+            )
+            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])         # [batch, intermediate_size, seq_len]
+
+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+        # 3. State Space Model sequence transformation
+        # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
+        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+        time_step, B, C = torch.split(
+            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+        )
+        discrete_time_step = self.dt_proj(time_step)                                    # [batch, seq_len, intermediate_size]
+
+        # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
+        A = -torch.exp(self.A_log.float())                                              # [intermediate_size, ssm_state_size]
+        if use_pscan_kernel:
+            scan_output, ssm_state = Run_Mamba_Forward_Gaudi(
+                ssm_state,
+                hidden_states,
+                discrete_time_step,
+                A,
+                B,
+                C,
+                self.D,
+                gate
+            )
+        else:
+            discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2)          # [batch, intermediate_size, seq_len]
+            discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])          # [batch, intermediate_size, seq_len, ssm_state_size]
+            discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()                # [batch, intermediate_size, seq_len, ssm_state_size]
+            deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
+
+            # 3.c perform the recurrence y ← SSM(A, B, C)(x)
+            scan_outputs = []
+            for i in range(seq_len):
+                ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # [batch, intermediate_size, ssm_state]
+                scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
+                scan_outputs.append(scan_output[:, :, 0])
+            scan_output = torch.stack(scan_outputs, dim=-1)                                # [batch, seq_len, intermediate_size]
+            scan_output = scan_output + (hidden_states * self.D[None, :, None])
+            scan_output = (scan_output * self.act(gate))
+
+        if cache_params is not None:
+            cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+
+        # 4. Final linear projection
+        contextualized_states = self.out_proj(scan_output.transpose(1, 2))              # [batch, seq_len, hidden_size]
+        return contextualized_states
+    # fmt: on
+
+    def forward(
+        self,
+        hidden_states,
+        cache_params: Optional[MambaCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
\ No newline at end of file
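For reference, the pure PyTorch fallback above implements the standard discretized selective-SSM recurrence, which the fused `mamba_pscan` / `mamba_pscan_update` ops are intended to reproduce without the per-timestep Python loop. Writing $W_\Delta$ as shorthand for the composition of `x_proj` and `dt_proj` (notation introduced here, not in the code), per channel and time step $t$:

```latex
\begin{aligned}
\Delta_t &= \mathrm{softplus}(W_\Delta\, x_t), \\
\bar{A}_t &= \exp(\Delta_t A), \qquad \bar{B}_t = \Delta_t B_t, \\
h_t &= \bar{A}_t \odot h_{t-1} + \bar{B}_t\, x_t, \\
y_t &= C_t\, h_t + D\, x_t, \qquad \mathrm{output}_t = y_t \cdot \mathrm{act}(z_t).
\end{aligned}
```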
diff --git a/tests/baselines/fixture/tests/test_text_generation_example.json b/tests/baselines/fixture/tests/test_text_generation_example.json
index 0e7b7c6d65..b196aa9786 100644
--- a/tests/baselines/fixture/tests/test_text_generation_example.json
+++ b/tests/baselines/fixture/tests/test_text_generation_example.json
@@ -391,7 +391,7 @@
     },
     "tests/test_text_generation_example.py::test_text_generation_bf16_1x[state-spaces/mamba-130m-hf-1536-False-False]": {
         "gaudi2": {
-            "throughput": 3100.9825044466907
+            "throughput": 20208.867657545277
         },
         "gaudi3": {
             "throughput": 1948.1615848330302
@@ -417,7 +417,7 @@
     },
     "tests/test_text_generation_example.py::test_text_generation_bf16_1x[tiiuae/falcon-mamba-7b-1-False-False]": {
         "gaudi2": {
-            "throughput": 47.1464839567739
+            "throughput": 73.36018500761314
         },
         "gaudi3": {
             "throughput": 45.90538768350833
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index b53cdcb365..dda619742b 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -51,9 +51,7 @@
         ("google/gemma-7b", 1, False, True, False),
         ("google/gemma-2-9b", 1, False, True, False),
         ("google/gemma-2-27b", 1, False, True, False),
-        pytest.param(
-            "state-spaces/mamba-130m-hf", 1536, False, False, False, marks=pytest.mark.skip("Deprecated")
-        ),
+        ("state-spaces/mamba-130m-hf", 1536, False, False, False),
         # ("Deci/DeciLM-7B", 1, False, False, False),
         ("Qwen/Qwen2-7B", 256, False, True, False),
         ("Qwen/Qwen1.5-MoE-A2.7B", 1, True, False, False),
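For scale, the updated Gaudi2 reference values amount to roughly a 6.5x throughput gain for mamba-130m (20208.87 / 3100.98 ≈ 6.5) and about a 1.56x gain for falcon-mamba-7b (73.36 / 47.15 ≈ 1.56) relative to the previous baselines.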