diff --git a/README.md b/README.md
index 95c4f2d20d..88554b151b 100644
--- a/README.md
+++ b/README.md
@@ -285,6 +285,7 @@ The following model architectures, tasks and device distributions have been vali
| Qwen2-VL | | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
| VideoLLaVA | | Single card | [Video comprehension](https://github.com/huggingface/optimum-habana/tree/main/examples/video-comprehension) |
| GLM-4V | | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
+| Mamba | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index 7c0246dc0f..01b7264f62 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -113,6 +113,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| ChatGLM | DeepSpeed | Single card | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)<br>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Qwen2-VL | | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
| GLM-4V | | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
+| Mamba | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
- Diffusers
diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index 13fd355288..b2300f81e1 100755
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -209,6 +209,17 @@ PT_HPU_LAZY_MODE=1 python3 ./run_generation.py \
> --sdp_on_bf16
> ```
+To run Mamba-130m inference on a single Gaudi2 card, use the following command. By default, the custom TPC kernel is expected at `/root/.cache/huggingface/hub/models--Habana--mamba/blobs/libcustom_tpc_perf_lib.so`; if `libcustom_tpc_perf_lib.so` is located in a different folder, set the path accordingly:
+```bash
+PT_HPU_LAZY_MODE=1 python3 ./run_generation.py \
+--model_name_or_path state-spaces/mamba-130m-hf \
+--max_input_tokens 128 \
+--max_new_tokens 128 \
+--bf16 \
+--use_hpu_graphs \
+--use_kv_cache \
+--batch_size 1024
+```
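+
+If you want to call the model from Python rather than through `run_generation.py`, the Gaudi-specific Mamba code paths are installed by `adapt_transformers_to_gaudi()` (the same adaptation the example script applies). The snippet below is a minimal sketch, assuming a working Habana PyTorch environment; it is not part of the example script:
+```python
+import torch
+import habana_frameworks.torch.core as htcore  # noqa: F401 - loads the HPU backend
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+# Patch transformers with the Gaudi-optimized Mamba mixer and cache update
+adapt_transformers_to_gaudi()
+
+tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
+model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.bfloat16).to("hpu")
+
+inputs = tokenizer("Selective state space models", return_tensors="pt").to("hpu")
+outputs = model.generate(**inputs, max_new_tokens=32)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```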
+
### Use any dataset from the Hugging Face Hub
You can also provide the name of a dataset from the Hugging Face Hub to perform generation on it with the argument `--dataset_name`.
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index dfdf11769c..af08621588 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -265,8 +265,10 @@
gaudi_gpt_neox_model_forward,
gaudi_invert_attention_mask,
gaudi_llama_rmsnorm_forward,
+ gaudi_MambaCache_update_conv_state,
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
+ gaudi_MambaMixer,
gaudi_mistral_rmsnorm_forward,
gaudi_mixtral_rmsnorm_forward,
gaudi_opt_attention_forward,
@@ -763,6 +765,9 @@ def adapt_transformers_to_gaudi():
transformers.models.mamba.modeling_mamba.MambaForCausalLM._update_model_kwargs_for_generation = (
gaudi_MambaForCausalLM_update_model_kwargs_for_generation
)
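+    # Swap in the Gaudi-optimized Mamba mixer and the MambaCache conv-state update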
+ transformers.models.mamba.modeling_mamba.MambaMixer = gaudi_MambaMixer
+ transformers.cache_utils.MambaCache.update_conv_state = gaudi_MambaCache_update_conv_state
+
transformers.models.falcon_mamba.modeling_falcon_mamba.FalconMambaForCausalLM.prepare_inputs_for_generation = (
gaudi_FalconMambaForCausalLM_prepare_inputs_for_generation
)
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index 1bc8207356..65b5031b89 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -176,8 +176,10 @@
from .llava_next import GaudiLlavaNextForConditionalGeneration
from .llava_onevision import GaudiLlavaOnevisionForConditionalGeneration
from .mamba import (
+ gaudi_MambaCache_update_conv_state,
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
+ gaudi_MambaMixer,
)
from .minicpm import MiniCPM3Config, MiniCPM3ForCausalLM
from .mistral import (
diff --git a/optimum/habana/transformers/models/mamba/__init__.py b/optimum/habana/transformers/models/mamba/__init__.py
index c22d12877c..6bd4566df3 100644
--- a/optimum/habana/transformers/models/mamba/__init__.py
+++ b/optimum/habana/transformers/models/mamba/__init__.py
@@ -1,4 +1,6 @@
from .modeling_mamba import (
+ gaudi_MambaCache_update_conv_state,
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
+ gaudi_MambaMixer,
)
diff --git a/optimum/habana/transformers/models/mamba/modeling_mamba.py b/optimum/habana/transformers/models/mamba/modeling_mamba.py
index db07b00abd..615c93eb98 100644
--- a/optimum/habana/transformers/models/mamba/modeling_mamba.py
+++ b/optimum/habana/transformers/models/mamba/modeling_mamba.py
@@ -1,6 +1,9 @@
from typing import Any, Optional
import torch
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.models.mamba.configuration_mamba import MambaConfig
from transformers.models.mamba.modeling_mamba import (
MambaCache,
)
@@ -11,6 +14,42 @@
logger = logging.get_logger(__name__)
+# When True, slow_forward uses the fused HPU pscan custom ops instead of the reference scan loop.
+use_pscan_kernel = True
+
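+# Run_Mamba_Forward_Gaudi below wraps the Habana fused selective-scan ("pscan") custom ops.
+# It reshapes the MambaMixer tensors into the layout the kernels expect (adding singleton
+# dimensions and swapping the channel/state axes), runs the scan, and maps the results back
+# to the [batch, intermediate_size, seq_len] layout used by the rest of the mixer.
+# Conceptually, the fused ops compute the usual selective-scan recurrence (a sketch, assuming
+# the kernels apply the same softplus discretization as the reference branch in slow_forward):
+#   dA_t = exp(softplus(dt_t) * A),   dB_t = softplus(dt_t) * B_t
+#   h_t  = dA_t * h_{t-1} + dB_t * x_t
+#   y_t  = C_t . h_t + D * x_t,       out_t = y_t * SiLU(z_t)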
+def Run_Mamba_Forward_Gaudi(in_state, in_x, in_dt, in_A, in_B, in_C, in_D, in_z):
+ in_state_h = in_state.unsqueeze(1).transpose(2, 3)
+ in_x_h = in_x.transpose(1, 2).unsqueeze(2)
+ in_dt_h = in_dt.unsqueeze(2)
+ in_A_h = in_A.unsqueeze(0).unsqueeze(1).transpose(2, 3)
+ in_B_h = in_B.unsqueeze(3)
+ in_C_h = in_C.unsqueeze(3)
+ in_D_h = in_D.unsqueeze(0).unsqueeze(1).unsqueeze(2)
+ in_z_h = in_z.transpose(1, 2).unsqueeze(2)
+
+ state_out_h = torch.ops.hpu.mamba_pscan(in_state_h, in_x_h, in_dt_h, in_A_h, in_B_h)
+ output_h = torch.ops.hpu.mamba_pscan_update(state_out_h, in_x_h, in_C_h, in_D_h, in_z_h)
+
+ output_hpu = output_h.squeeze(2).transpose(1, 2)
+ state_hpu = state_out_h.transpose(2, 3)
+ state_out = torch.select(state_hpu, 1, output_hpu.shape[2] - 1)
+
+ return output_hpu, state_out
+
+
+def gaudi_MambaCache_update_conv_state(
+ self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
+) -> torch.Tensor:
+ conv_state = self.conv_states[layer_idx]
+    cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
+
+    # Shift the existing window left by one along the time axis, then write the new values
+    # at the clamped positions.
+    conv_state = conv_state.roll(shifts=-1, dims=-1)
+    # Equivalent to `conv_state[:, :, cache_position] = new_conv_state`, written as an explicit loop.
+    for c, i in enumerate(cache_position):
+        conv_state[:, :, i] = new_conv_state[:, :, c].to(conv_state.device)
+
+    # Write the result back into the existing cache buffer in place rather than rebinding it.
+    self.conv_states[layer_idx].zero_()
+    self.conv_states[layer_idx] += conv_state
+    return self.conv_states[layer_idx]
def gaudi_MambaForCausalLM_update_model_kwargs_for_generation(
@@ -94,3 +133,152 @@ def gaudi_MambaForCausalLM_prepare_inputs_for_generation(
}
)
return model_inputs
+
+class gaudi_MambaMixer(nn.Module):
+ """
+ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
+ ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
+ and is why Mamba is called **selective** state spaces)
+    Compared to the upstream MambaMixer, only the slow path is replaced; it now calls the Gaudi custom op.
+ """
+
+ def __init__(self, config: MambaConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.ssm_state_size = config.state_size
+ self.conv_kernel_size = config.conv_kernel
+ self.intermediate_size = config.intermediate_size
+ self.time_step_rank = int(config.time_step_rank)
+ self.layer_idx = layer_idx
+ self.use_conv_bias = config.use_conv_bias
+ self.conv1d = nn.Conv1d(
+ in_channels=self.intermediate_size,
+ out_channels=self.intermediate_size,
+ bias=config.use_conv_bias,
+ kernel_size=config.conv_kernel,
+ groups=self.intermediate_size,
+ padding=config.conv_kernel - 1,
+ )
+
+ self.activation = config.hidden_act
+ self.act = ACT2FN[config.hidden_act]
+
+ self.use_mambapy = config.use_mambapy
+
+ # projection of the input hidden states
+ self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
+ # selective projection used to make dt, B and C input dependant
+ self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
+ # time step projection (discretization)
+ self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
+
+ # S4D real initialization. These are not discretized!
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+ A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
+ A = A.expand(self.intermediate_size, -1).contiguous()
+
+ self.A_log = nn.Parameter(torch.log(A))
+ self.D = nn.Parameter(torch.ones(self.intermediate_size))
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
+ self.use_bias = config.use_bias
+
+ # fmt: off
+ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None):
+ """
+        Steps 3.c and 3.d are replaced with the custom op `Run_Mamba_Forward_Gaudi`, which removes the sequence-length loop and improves performance.
+ """
+ batch_size, seq_len, _ = input_states.shape
+ dtype = input_states.dtype
+ # 1. Gated MLP's linear projection
+ projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
+ hidden_states, gate = projected_states.chunk(2, dim=1)
+
+ if attention_mask is not None:
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+ # 2. Convolution sequence transformation
+ if cache_params is not None:
+ ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+ ssm_state = ssm_state.to(hidden_states.device)
+ # use `cache_position.shape[0]` to check whether we are in prefill
+ # stage, it's equivalent to check `cache_position[0] == 0`, which
+ # breaks dynamo fullgraph constraints
+ if cache_position.shape[0] == self.conv_kernel_size:
+ conv_state = nn.functional.pad(
+ hidden_states,
+ (self.conv_kernel_size - hidden_states.shape[-1], 0)
+ )
+
+ cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
+ else:
+ conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
+ hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
+ if self.use_conv_bias:
+ hidden_states += self.conv1d.bias
+ hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
+ else:
+ ssm_state = torch.zeros(
+ (batch_size, self.intermediate_size, self.ssm_state_size),
+ device=hidden_states.device, dtype=dtype
+ )
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
+
+ if attention_mask is not None:
+ hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+ # 3. State Space Model sequence transformation
+ # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
+ ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+ time_step, B, C = torch.split(
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+ )
+ discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
+
+ # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
+ A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
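+        # Note: on the kernel path the raw dt projection is handed to the fused op; the softplus
+        # discretization applied explicitly in the reference branch below is assumed to happen
+        # inside the custom kernel.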
+ if use_pscan_kernel:
+ scan_output, ssm_state = Run_Mamba_Forward_Gaudi(
+ ssm_state,
+ hidden_states,
+ discrete_time_step,
+ A,
+ B,
+ C,
+ self.D,
+ gate
+ )
+ else:
+ discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
+ discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
+ discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediate_size, seq_len, ssm_state_size]
+ deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
+
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
+ scan_outputs = []
+ for i in range(seq_len):
+                ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # [batch, intermediate_size, ssm_state_size]
+                scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
+                scan_outputs.append(scan_output[:, :, 0])
+            scan_output = torch.stack(scan_outputs, dim=-1)                                # [batch, intermediate_size, seq_len]
+ scan_output = scan_output + (hidden_states * self.D[None, :, None])
+ scan_output = (scan_output * self.act(gate))
+
+ if cache_params is not None:
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+
+ # 4. Final linear projection
+ contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
+ return contextualized_states
+ # fmt: on
+
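+    # There is no CUDA fast path on HPU; forward always dispatches to slow_forward,
+    # which itself uses the fused pscan kernel when use_pscan_kernel is True.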
+ def forward(
+ self,
+ hidden_states,
+ cache_params: Optional[MambaCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
\ No newline at end of file
diff --git a/tests/baselines/fixture/tests/test_text_generation_example.json b/tests/baselines/fixture/tests/test_text_generation_example.json
index 0e7b7c6d65..b196aa9786 100644
--- a/tests/baselines/fixture/tests/test_text_generation_example.json
+++ b/tests/baselines/fixture/tests/test_text_generation_example.json
@@ -391,7 +391,7 @@
},
"tests/test_text_generation_example.py::test_text_generation_bf16_1x[state-spaces/mamba-130m-hf-1536-False-False]": {
"gaudi2": {
- "throughput": 3100.9825044466907
+ "throughput": 20208.867657545277
},
"gaudi3": {
"throughput": 1948.1615848330302
@@ -417,7 +417,7 @@
},
"tests/test_text_generation_example.py::test_text_generation_bf16_1x[tiiuae/falcon-mamba-7b-1-False-False]": {
"gaudi2": {
- "throughput": 47.1464839567739
+ "throughput": 73.36018500761314
},
"gaudi3": {
"throughput": 45.90538768350833
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index b53cdcb365..dda619742b 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -51,9 +51,7 @@
("google/gemma-7b", 1, False, True, False),
("google/gemma-2-9b", 1, False, True, False),
("google/gemma-2-27b", 1, False, True, False),
- pytest.param(
- "state-spaces/mamba-130m-hf", 1536, False, False, False, marks=pytest.mark.skip("Deprecated")
- ),
+ ("state-spaces/mamba-130m-hf", 1536, False, False),
# ("Deci/DeciLM-7B", 1, False, False, False),
("Qwen/Qwen2-7B", 256, False, True, False),
("Qwen/Qwen1.5-MoE-A2.7B", 1, True, False, False),