1 change: 1 addition & 0 deletions README.md
@@ -285,6 +285,7 @@ The following model architectures, tasks and device distributions have been validated
| Qwen2-VL | | <div style="text-align:left"><li>Single card</li></div> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| VideoLLaVA | | <div style="text-align:left"><li>Single card</li></div> | <li>[Video comprehension](https://github.com/huggingface/optimum-habana/tree/main/examples/video-comprehension)</li> |
| GLM-4V | | <div style="text-align:left"><li>Single card</li></div> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li>
| Mamba | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |

</div>

1 change: 1 addition & 0 deletions docs/source/index.mdx
@@ -113,6 +113,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all been validated.
| ChatGLM | <div style="text-align:left"><li>DeepSpeed</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Qwen2-VL | | <div style="text-align:left"><li>Single card</li></div> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| GLM-4V | | <div style="text-align:left"><li>Single card</li></div> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| Mamba | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |

- Diffusers

11 changes: 11 additions & 0 deletions examples/text-generation/README.md
@@ -209,6 +209,17 @@ PT_HPU_LAZY_MODE=1 python3 ./run_generation.py \
> --sdp_on_bf16
> ```

To run Mamba-130m inference on a single Gaudi2 card, run `run_generation.py` as in the earlier examples, for instance with the arguments below. The custom TPC kernel is expected by default at `/root/.cache/huggingface/hub/models--Habana--mamba/blobs/libcustom_tpc_perf_lib.so`; if `libcustom_tpc_perf_lib.so` is located in a different folder, set the path accordingly:
```bash
PT_HPU_LAZY_MODE=1 python3 ./run_generation.py \
--model_name_or_path state-spaces/mamba-130m-hf \
--max_input_tokens 128 \
--max_new_tokens 128 \
--bf16 \
--use_hpu_graphs \
--use_kv_cache \
--batch_size 1024
```

### Use any dataset from the Hugging Face Hub

You can also provide the name of a dataset from the Hugging Face Hub to perform generation on it with the argument `--dataset_name`.
5 changes: 5 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -265,8 +265,10 @@
gaudi_gpt_neox_model_forward,
gaudi_invert_attention_mask,
gaudi_llama_rmsnorm_forward,
gaudi_MambaCache_update_conv_state,
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
gaudi_MambaMixer,
gaudi_mistral_rmsnorm_forward,
gaudi_mixtral_rmsnorm_forward,
gaudi_opt_attention_forward,
@@ -763,6 +765,9 @@ def adapt_transformers_to_gaudi():
transformers.models.mamba.modeling_mamba.MambaForCausalLM._update_model_kwargs_for_generation = (
gaudi_MambaForCausalLM_update_model_kwargs_for_generation
)
transformers.models.mamba.modeling_mamba.MambaMixer = gaudi_MambaMixer
transformers.cache_utils.MambaCache.update_conv_state = gaudi_MambaCache_update_conv_state

transformers.models.falcon_mamba.modeling_falcon_mamba.FalconMambaForCausalLM.prepare_inputs_for_generation = (
gaudi_FalconMambaForCausalLM_prepare_inputs_for_generation
)
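For context, these registrations only take effect once `adapt_transformers_to_gaudi()` has been called before the model is instantiated. Below is a minimal sketch of driving the patched Mamba model directly; the checkpoint, prompt, and generation settings are illustrative and not part of this change.

```python
# Sketch only: assumes a Gaudi machine with the Habana PyTorch bridge installed.
import habana_frameworks.torch  # registers the "hpu" device  # noqa: F401
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()  # swaps in the Gaudi MambaMixer, cache update, and generation hooks above

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf").to("hpu").eval()

inputs = tokenizer("Selective state space models", return_tensors="pt").to("hpu")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```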
2 changes: 2 additions & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -176,8 +176,10 @@
from .llava_next import GaudiLlavaNextForConditionalGeneration
from .llava_onevision import GaudiLlavaOnevisionForConditionalGeneration
from .mamba import (
gaudi_MambaCache_update_conv_state,
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
gaudi_MambaMixer,
)
from .minicpm import MiniCPM3Config, MiniCPM3ForCausalLM
from .mistral import (
2 changes: 2 additions & 0 deletions optimum/habana/transformers/models/mamba/__init__.py
@@ -1,4 +1,6 @@
from .modeling_mamba import (
gaudi_MambaCache_update_conv_state,
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
gaudi_MambaForCausalLM_update_model_kwargs_for_generation,
gaudi_MambaMixer,
)
188 changes: 188 additions & 0 deletions optimum/habana/transformers/models/mamba/modeling_mamba.py
@@ -1,6 +1,9 @@
from typing import Any, Optional

import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.models.mamba.configuration_mamba import MambaConfig
from transformers.models.mamba.modeling_mamba import (
MambaCache,
)
@@ -11,6 +14,42 @@


logger = logging.get_logger(__name__)
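# If True, the slow path uses the custom HPU parallel-scan ops (torch.ops.hpu.mamba_pscan and
# torch.ops.hpu.mamba_pscan_update) instead of the reference sequential scan.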
use_pscan_kernel = True

def Run_Mamba_Forward_Gaudi(in_state, in_x, in_dt, in_A, in_B, in_C, in_D, in_z):
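"""
Reshape the SSM tensors into the layout expected by the HPU custom ops, run the parallel scan
(`torch.ops.hpu.mamba_pscan`) followed by the output update (`torch.ops.hpu.mamba_pscan_update`),
and return the per-token outputs together with the final SSM state.
"""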
in_state_h = in_state.unsqueeze(1).transpose(2, 3)
in_x_h = in_x.transpose(1, 2).unsqueeze(2)
in_dt_h = in_dt.unsqueeze(2)
in_A_h = in_A.unsqueeze(0).unsqueeze(1).transpose(2, 3)
in_B_h = in_B.unsqueeze(3)
in_C_h = in_C.unsqueeze(3)
in_D_h = in_D.unsqueeze(0).unsqueeze(1).unsqueeze(2)
in_z_h = in_z.transpose(1, 2).unsqueeze(2)

state_out_h = torch.ops.hpu.mamba_pscan(in_state_h, in_x_h, in_dt_h, in_A_h, in_B_h)
output_h = torch.ops.hpu.mamba_pscan_update(state_out_h, in_x_h, in_C_h, in_D_h, in_z_h)

output_hpu = output_h.squeeze(2).transpose(1, 2)
state_hpu = state_out_h.transpose(2, 3)
state_out = torch.select(state_hpu, 1, output_hpu.shape[2] - 1)

return output_hpu, state_out


def gaudi_MambaCache_update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

conv_state = conv_state.roll(shifts=-1, dims=-1)
# Equivalent to `conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)`,
# implemented as an explicit per-position loop below.
for c, i in enumerate(cache_position):
conv_state[:, :, i] = new_conv_state[:, :, c].to(conv_state.device)

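# Write the result back into the cached tensor in place (zero, then accumulate) rather than rebinding it.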
self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]


def gaudi_MambaForCausalLM_update_model_kwargs_for_generation(
@@ -94,3 +133,152 @@ def gaudi_MambaForCausalLM_prepare_inputs_for_generation(
}
)
return model_inputs

class gaudi_MambaMixer(nn.Module):
"""
Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
and is why Mamba is called **selective** state spaces)
Compared to the upstream MambaMixer, only the slow path is replaced, using a custom HPU op.
"""

def __init__(self, config: MambaConfig, layer_idx: int):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = config.intermediate_size
self.time_step_rank = int(config.time_step_rank)
self.layer_idx = layer_idx
self.use_conv_bias = config.use_conv_bias
self.conv1d = nn.Conv1d(
in_channels=self.intermediate_size,
out_channels=self.intermediate_size,
bias=config.use_conv_bias,
kernel_size=config.conv_kernel,
groups=self.intermediate_size,
padding=config.conv_kernel - 1,
)

self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]

self.use_mambapy = config.use_mambapy

# projection of the input hidden states
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
# selective projection used to make dt, B and C input dependant
self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
# time step projection (discretization)
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

# S4D real initialization. These are not discretized!
# The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
A = A.expand(self.intermediate_size, -1).contiguous()

self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.intermediate_size))
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
self.use_bias = config.use_bias

# fmt: off
def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None):
"""
We replaced parts 3.c and 3.d with the custom op `Run_Mamba_Forward_Gaudi`, which removes the loop over the sequence length and improves performance.
"""
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
hidden_states, gate = projected_states.chunk(2, dim=1)

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 2. Convolution sequence transformation
if cache_params is not None:
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
ssm_state = ssm_state.to(hidden_states.device)
# use `cache_position.shape[0]` to check whether we are in prefill
# stage, it's equivalent to check `cache_position[0] == 0`, which
# breaks dynamo fullgraph constraints
if cache_position.shape[0] == self.conv_kernel_size:
conv_state = nn.functional.pad(
hidden_states,
(self.conv_kernel_size - hidden_states.shape[-1], 0)
)

cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
else:
conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
if self.use_conv_bias:
hidden_states += self.conv1d.bias
hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
else:
ssm_state = torch.zeros(
(batch_size, self.intermediate_size, self.ssm_state_size),
device=hidden_states.device, dtype=dtype
)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]

if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
time_step, B, C = torch.split(
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)
discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]

# 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
if use_pscan_kernel:
scan_output, ssm_state = Run_Mamba_Forward_Gaudi(
ssm_state,
hidden_states,
discrete_time_step,
A,
B,
C,
self.D,
gate
)
else:
discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediate_size, seq_len, ssm_state_size]
deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

# 3.c perform the recurrence y ← SSM(A, B, C)(x)
scan_outputs = []
for i in range(seq_len):
ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state_size]
scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1]
scan_outputs.append(scan_output[:, :, 0])
scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len]
scan_output = scan_output + (hidden_states * self.D[None, :, None])
scan_output = (scan_output * self.act(gate))

if cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

# 4. Final linear projection
contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
return contextualized_states
# fmt: on

def forward(
self,
hidden_states,
cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
):
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
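For reference, the recurrence that the sequential 3.c loop above computes, and that the `mamba_pscan` / `mamba_pscan_update` ops evaluate without the Python loop over the sequence, can be summarized as follows, where Δ_t is the softplus-discretized time step, σ is the configured gating activation (SiLU for Mamba), and ⊙ is elementwise multiplication:

```latex
\bar{A}_t = \exp(\Delta_t A), \qquad \bar{B}_t = \Delta_t B_t, \qquad
h_t = \bar{A}_t \odot h_{t-1} + \bar{B}_t \odot x_t, \qquad
y_t = C_t h_t + D \odot x_t, \qquad \mathrm{out}_t = y_t \cdot \sigma(z_t)
```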
@@ -391,7 +391,7 @@
},
"tests/test_text_generation_example.py::test_text_generation_bf16_1x[state-spaces/mamba-130m-hf-1536-False-False]": {
"gaudi2": {
"throughput": 3100.9825044466907
"throughput": 20208.867657545277
},
"gaudi3": {
Collaborator: Gaudi3 baselines are unchanged?

"throughput": 1948.1615848330302
@@ -417,7 +417,7 @@
},
"tests/test_text_generation_example.py::test_text_generation_bf16_1x[tiiuae/falcon-mamba-7b-1-False-False]": {
"gaudi2": {
"throughput": 47.1464839567739
"throughput": 73.36018500761314
},
"gaudi3": {
"throughput": 45.90538768350833
4 changes: 1 addition & 3 deletions tests/test_text_generation_example.py
@@ -51,9 +51,7 @@
("google/gemma-7b", 1, False, True, False),
("google/gemma-2-9b", 1, False, True, False),
("google/gemma-2-27b", 1, False, True, False),
pytest.param(
"state-spaces/mamba-130m-hf", 1536, False, False, False, marks=pytest.mark.skip("Deprecated")
),
("state-spaces/mamba-130m-hf", 1536, False, False),
# ("Deci/DeciLM-7B", 1, False, False, False),
("Qwen/Qwen2-7B", 256, False, True, False),
("Qwen/Qwen1.5-MoE-A2.7B", 1, True, False, False),