diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 07dadbc22..710b2668f 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -43,7 +43,7 @@ class NormalizationConfig(BaseModelConfig): pass @abc.abstractmethod - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": + def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None) -> "torch.nn.Module": pass @classmethod @@ -63,7 +63,7 @@ def _from_dict( class NoNormalizationConfig(NormalizationConfig): _abstract = False - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": + def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None) -> "torch.nn.Module": return torch.nn.Identity() diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index a1f357de9..fb178e7d5 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,29 +1,14 @@ import enum +import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig from fast_llm.utils import Assert - -class SSMDimNames: - model_dim = "model_dim" # Model dimension (D) - state_dim = "state_dim" # State dimension (N) - conv_dim = "conv_dim" # Dimension of the conv1d input in mamba layers - inner_dim = "inner_dim" # Inner dimension after expansion - dt_rank = "dt_rank" # Rank of Δ - inner_proj_mamba = "inner_proj_mamba" # Inner projection dimension for mamba - inner_proj_discrete_mamba2 = "inner_proj_discrete_mamba2" # Inner projection dimension for discrete mamba2 - inner_proj_mamba2 = "inner_proj_mamba2" # Inner projection dimension for mamba2 - x_proj_dim = "x_proj_dim" # X projection dimension - head_dim = "head_dim" # Dimension of the mamba2 head (P) - conv_kernel_size = "conv_kernel_size" # Kernel size of the conv1d in mamba layers - qk_heads = "qk_heads" # Number of QK heads - v_heads = "v_heads" # Number of V heads - - # Mamba 2 - x_proj_dim_2 = "x_proj_dim_2" # d_xb - c_heads = "c_heads" +if typing.TYPE_CHECKING: + from fast_llm.tensor import Initializer class SSMBlockType(enum.StrEnum): @@ -53,6 +38,16 @@ def get_mixer_class(self): raise NotImplementedError(self) +class DTInitType(enum.StrEnum): + constant = "constant" + random = "random" + + def get_init_method(self, scale: float) -> "Initializer": + from fast_llm.tensor import init_fill_, init_uniform_centered_ + + return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) + + @config_class() class SSMConfig(LLMBlockConfig): _abstract = False @@ -62,74 +57,87 @@ class SSMConfig(LLMBlockConfig): desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) + + # Model dimensions + # TODO: Remove (redundant default) expansion_factor: int = Field( default=2, desc="Expansion factor for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # head_size [MambaLayer, Mamba2, DiscreteMamba2] state_size: int = Field( default=16, desc="State size for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # [MambaLayer, Mamba2, DiscreteMamba2] conv_kernel_dimension: int = Field( default=4, desc="Conv kernel dimension for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - # Layer parameters - add_bias_linear: 
bool = Field( - default=False, - desc="Whether to use bias in SSM layers", - hint=FieldHint.architecture, - ) - + # [MambaLayer, Mamba2] dt_rank: None | int = Field( default=None, desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)", hint=FieldHint.architecture, ) - chunk_size: int = Field( - default=256, - desc="Chunk size for Mamba2 blocks.", - hint=FieldHint.architecture, - ) + # head_groups [DiscreteMamba2] n_qk_heads: int = Field( default=32, desc="Number of QK heads for Mamba2 blocks.", hint=FieldHint.architecture, ) + # heads [DiscreteMamba2]# TODO: Remove? (redundant) n_v_heads: int = Field( default=32, desc="Number of V heads for Mamba2 blocks.", hint=FieldHint.architecture, ) - activation_type: ActivationType = Field( + # c_size [MambaLayer, Mamba2, DiscreteMamba2]? + d_inner: None | int = Field( default=None, - desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", - hint=FieldHint.architecture, - ) - dt_min: float = Field( - default=0.001, - desc="Minimum step size for discretization", + desc="Inner dimension for Mamba2 blocks.", hint=FieldHint.core, - valid=check_field(Assert.gt, 0), ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + # xb_size [Mamba2] + d_xb: int = Field( + default=None, + desc="Dimension of the xB in Mamba2 blocks.", + hint=FieldHint.architecture, ) - d_inner: None | int = Field( + # Model options + # add_bias_linear [Mamba2, DiscreteMamba2] [hard-coded to False in MambaLayer] + add_bias_linear: bool = Field( + default=False, + desc="Whether to use bias in SSM layers", + hint=FieldHint.architecture, + ) + # activation_type [DiscreteMamba2] [hard-coded to silu in MambaLayer, Mamba2] + activation_type: ActivationType = Field( default=None, - desc="Inner dimension for Mamba2 blocks.", - hint=FieldHint.core, + hint=FieldHint.architecture, ) + # repeat_xb_before_conv [Mamba2] + repeat_kv_before_conv: bool = Field( + default=True, + desc="Whether to repeat x and B before (True) or after (False) the conv1d in Mamba2 blocks.", + hint=FieldHint.architecture, + ) + # chunk_size [DiscreteMamba2] + chunk_size: int = Field( + default=256, + desc="Chunk size for Mamba2 blocks.", + hint=FieldHint.architecture, + ) + + # Learning rate + # lr_scale [MambaLayer, Mamba2, DiscreteMamba2] mamba_lr_scale: float | None = Field( default=None, desc="Learning rate scale for Mamba blocks.", @@ -137,31 +145,38 @@ class SSMConfig(LLMBlockConfig): valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - # Mamba 2 - repeat_kv_before_conv: bool = Field( - default=True, - desc="Whether to repeat the KV before the conv1d in Mamba2 blocks.", - hint=FieldHint.architecture, + # Initialization + # dt_weight_initialization_method [Mamba2] + dt_init: DTInitType = Field( + default=DTInitType.random, + desc="Initialization method for dt", + hint=FieldHint.core, ) - d_xb: int = Field( - default=None, - desc="Dimension of the xB in Mamba2 blocks.", - hint=FieldHint.architecture, + # dt_weight_initialization_scale [Mamba2] + dt_scale: float = Field( + default=1.0, + desc="Scale for dt", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), ) - dt_init: str = Field( - default="random", - desc="Initialization method for dt", + # dt_bias_initialization_min [MambaLayer, Mamba2] + dt_min: float = Field( + default=0.001, + desc="Minimum step size for discretization", hint=FieldHint.core, + valid=check_field(Assert.gt, 0), ) + # 
dt_bias_initialization_max [MambaLayer, Mamba2] dt_max: float = Field( default=0.1, desc="Maximum step size for discretization", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_scale: float = Field( - default=1.0, - desc="Scale for dt", + # dt_bias_initialization_floor [MambaLayer, Mamba2] + dt_init_floor: float = Field( + default=1e-4, + desc="Minimum value for initializing dt", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) @@ -172,3 +187,7 @@ def _validate(self) -> None: self.activation_type = ActivationType.silu super()._validate() Assert.geq(self.dt_max, self.dt_min) + + def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: + # Handled in the model. + pass diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 6012f74a7..47a94214a 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,17 +1,24 @@ import logging -import math import typing import einops import torch -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace -from fast_llm.layers.common.linear import Linear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs +from fast_llm.engine.config_utils.tensor_space import ( + CompositeTensorDim, + ConcatenatedTensorDim, + DefaultDimNames, + TensorDim, + TensorSpace, +) +from fast_llm.engine.distributed.config import DistributedDimNames +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear +from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_, init_zeros_ -from fast_llm.utils import get_lr_scale +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ +from fast_llm.utils import div, get_lr_scale logger = logging.getLogger(__name__) @@ -32,12 +39,6 @@ _causal_conv1d_available = False -def bias_init_method(conv_weight): - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_(-bound, bound) - - class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" @@ -50,208 +51,194 @@ def __init__( tensor_space: TensorSpace, transformer_config: TransformerConfig, ): - """ - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. - Other options are all experimental and should not need to be configured. 
- """ - # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) - self.config: SSMConfig = config - bias = config.add_bias_linear + self._config: SSMConfig = config layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") - - td_inner = tensor_space[SSMDimNames.inner_dim] - td_state = tensor_space[SSMDimNames.state_dim] - td_model = tensor_space[SSMDimNames.model_dim] - td_conv = tensor_space[SSMDimNames.conv_dim] - td_n_qk_heads = tensor_space[SSMDimNames.qk_heads] - td_n_v_heads = tensor_space[SSMDimNames.v_heads] - td_conv_kernel = tensor_space[SSMDimNames.conv_kernel_size] - td_inner_proj = tensor_space[SSMDimNames.inner_proj_discrete_mamba2] - - self.d_model = td_model.size - self.d_inner = td_inner.size - self.d_state = td_state.size - self.chunk_size = config.chunk_size - self.n_qk_heads = td_n_qk_heads.size - self.n_v_heads = td_n_v_heads.size - self.conv_kernel_size = td_conv_kernel.size - - self.act = config.activation_type.activation_fn - self.activation_name = config.activation_type.name + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + + hidden_dim = tensor_space[TransformerDimNames.hidden] + state_dim = TensorDim("state", self._config.state_size) + v_head_size_dim = TensorDim("v_head_size", div(self._config.d_inner, self._config.n_v_heads)) + + head_groups_dim = TensorDim( + "head_groups", + self._config.n_qk_heads, + self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor), + ) + group_heads_dim = TensorDim("group_heads", div(self._config.n_v_heads, self._config.n_qk_heads)) + heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) + inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, v_head_size_dim)) + bc_dim = CompositeTensorDim("bc", (head_groups_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) + + inner_projection_dim = ConcatenatedTensorDim( + "inner_projection", + (inner_dim, bc_dim, bc_dim, inner_dim, heads_dim), + ) + convolution_dim = ConcatenatedTensorDim("convolution", (inner_dim, bc_dim, bc_dim)) + + # local_head_groups = head_groups / TP + self._local_head_groups = head_groups_dim.size + # local_heads = local_head_groups * group_heads + self._local_heads = heads_dim.size + # local_inner_size = local_heads * head_size + self._local_inner_size = inner_dim.size + # local_bc_size = local_head_groups * state + self._local_bc_size = bc_dim.size # TODO: double check initializations # Projections - self.in_proj = Linear( - td_model, - td_inner_proj, - bias=bias, - weight_init_method=init_kaiming_(td_model.size), - lr_scale=mamba_layer_lr_scale, + self.in_proj = OutputParallelLinear( + hidden_dim, + inner_projection_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(transformer_config.hidden_size), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - self.z_bias = ( - ParameterMeta.from_dims( - (td_inner,), + if not config.add_bias_linear: + self.z_bias = ParameterMeta.from_dims( + (inner_dim,), weight_decay=False, init_method=init_zeros_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - if not bias - else 0.0 - ) - self.conv1d_weight = 
ParameterMeta.from_dims( - (td_conv, tensor_space[DefaultDimNames.scalar], td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) - ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 - lr_scale=mamba_layer_lr_scale, + ( + convolution_dim, + tensor_space[DefaultDimNames.scalar], + convolution_kernel_dim, + ), + init_method=init_uniform_centered_( + (convolution_dim.global_size * self._config.conv_kernel_dimension) ** -0.5 + ), + lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale + (convolution_dim,), + init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + lr_scale=lr_scale, ) - # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_n_qk_heads,), + (heads_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - - # out_proj - self.out_proj = Linear( - td_inner, - td_model, - bias=bias, - weight_init_method=init_kaiming_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + self.out_proj = InputParallelLinear( + inner_dim, + hidden_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(self._config.d_inner), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - def forward(self, hidden_states, kwargs): - """ - ON variable names and pep8: keeping some variable names as in the original code for clarity. - - Args: - u: (B, L, D), - - Returns: - outputs: dict. - outputs["hidden_states"]: (B, L, D). - outputs["state"]: inference cache. - """ - if kwargs[TransformerKwargs.sequence_first]: - raise NotImplementedError(f"Sequence-first not supported for SSMs.") - + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - input_ = hidden_states - outputs = {} - # assert state is None - batch, seqlen, dim = input_.shape - state = None - - # Hacky way to initialize state during inference - chunk_size = self.chunk_size if state is None else seqlen + sequence_length = kwargs[TransformerKwargs.sequence_q_dim].global_size # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = torch.nn.functional.pad(input_, (0, 0, 0, padded_len - seqlen)) - - # Project input - xBCzA_log = self.in_proj(u) + padded_length = (1 + (sequence_length - 1) // self._config.chunk_size) * self._config.chunk_size + if padded_length != sequence_length: + assert not kwargs[TransformerKwargs.sequence_first] and input_.size(1) == sequence_length + input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) + + # inner_projection : (batch/local_or_padded_sequence, local_sequence/batch, hidden) + # -> (batch/padded_sequence, sequence/batch, local_inner_projection) + inner_projection = self.in_proj(input_) + # Standardize to (batch, padded_sequence, local_inner_projection) + if kwargs[TransformerKwargs.sequence_first]: + inner_projection = inner_projection.transpose(0, 1) - ( - xBC, - z, - A_log, - ) = torch.split( - xBCzA_log, + xBC, z, A_log = torch.split( + inner_projection, [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, + self._local_inner_size + 2 * self._local_bc_size, + self._local_inner_size, + self._local_heads, ], dim=-1, ) - if state is not 
None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead torch.nn.functional.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - xBC_t = einops.rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - state["conv"].copy_( - torch.nn.functional.pad(xBC_t, (self.conv_kernel_size - xBC_t.shape[-1], 0)) - ) # Update state (B D W) - # Convolutional layer - xBC = self.convolutional_forward(xBC, padded_len) + # xbc: (batch, padded_sequence, local_heads * head_size + 2 * local_head_groups * state) + xBC = self.convolutional_forward(xBC, padded_length) x, B, C = torch.split( xBC, [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, + self._local_inner_size, + self._local_bc_size, + self._local_bc_size, ], dim=-1, ) - x = einops.rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = einops.rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = einops.rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + # x: (batch, padded_sequence, local_heads * head_size) -> (batch, padded_sequence, local_heads, head_size) + x = einops.rearrange(x, "b l (h n) -> b l h n", h=self._local_heads) + + # b,c: (batch, padded_sequence, local_head_groups * state) -> (batch, padded_sequence, local_head_groups, state) + B = einops.rearrange(B, "b l (h n) -> b l h n", h=self._local_head_groups) + C = einops.rearrange(C, "b l (h n) -> b l h n", h=self._local_head_groups) # SSM forward - result = _mamba_chunk_scan_combined( - x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1), + y = _mamba_chunk_scan_combined( + x=self._apply_a_log(x, A_log), dt=A_log, dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), + A=-torch.ones(self._local_heads, device=A_log.device), B=B, C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(state is not None), + chunk_size=self._config.chunk_size, + return_final_states=False, ) - - if state is not None: - y, ssm_state = result - state["ssm"].copy_(ssm_state) - else: - y = result - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = einops.rearrange(y + Du, "b l h p -> b l (h p)") # Norm and gate - out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :].contiguous() + if not self._config.add_bias_linear: + z = z + self.z_bias + + # y: (batch, padded_sequence, local_heads, head_size) -> (batch, sequence, local_heads * head_size) + y = ((y + Du).flatten(2, 3) * torch.nn.functional.silu(z))[:, :sequence_length] + if kwargs[TransformerKwargs.sequence_first]: + # TODO: Is contiguous needed? + y = y.transpose(0, 1).contiguous() + # out_proj: (batch/sequence, sequence/batch, local_heads * head_size) + # -> (batch/local_sequence, local_sequence/batch, hidden) + a, b = self.out_proj(y) + return self.out_proj(y) - # TODO: since we do not support inference for now, we only return the hidden states for now. 
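As a reference for the projection layout in this forward pass, the following is a minimal plain-PyTorch sketch of the chunk-size padding and the xBC / z / A_log split, using purely illustrative sizes (none of these numbers come from a real config):

import torch

# Illustrative sizes only.
batch, sequence_length, chunk_size = 2, 100, 64
d_inner, n_qk_heads, n_v_heads, state_size = 512, 8, 16, 32

# Pad the sequence up to the next multiple of chunk_size, as done above.
padded_length = (1 + (sequence_length - 1) // chunk_size) * chunk_size  # 128

# The inner projection concatenates (x, B, C), the gate z, and the per-head A_log.
inner_projection = torch.randn(batch, sequence_length, 2 * d_inner + 2 * n_qk_heads * state_size + n_v_heads)
inner_projection = torch.nn.functional.pad(inner_projection, (0, 0, 0, padded_length - sequence_length))

xBC, z, A_log = torch.split(
    inner_projection,
    [d_inner + 2 * n_qk_heads * state_size, d_inner, n_v_heads],
    dim=-1,
)
assert xBC.shape == (batch, padded_length, d_inner + 2 * n_qk_heads * state_size)
assert z.shape == (batch, padded_length, d_inner)
assert A_log.shape == (batch, padded_length, n_v_heads)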
- return outputs["hidden_states"], None + @torch.compile + def _apply_a_log(self, x: torch.Tensor, A_log: torch.Tensor) -> torch.Tensor: + return x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1) def convolutional_forward(self, xBC, padded_len): """Convolutional layer forward pass for the full sequence.""" - if _causal_conv1d_available and self.activation_name in ( - "silu", - "swish", - "identity", + if _causal_conv1d_available and self._config.activation_type in ( + ActivationType.silu, + ActivationType.identity, ): xBC = _causal_conv1d_fn( xBC.transpose(1, 2), - einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + self.conv1d_weight.squeeze(1), self.conv1d_bias, - activation=None if self.activation_name == "identity" else self.activation_name, + activation=( + None + if self._config.activation_type == ActivationType.identity + else self._config.activation_type.value + ), ).transpose(1, 2) else: - xBC = self.act( + xBC = self._config.activation_type.activation_fn( torch.nn.functional.conv1d( xBC.transpose(1, 2), self.conv1d_weight, bias=self.conv1d_bias, groups=self.conv1d_weight.shape[0], - padding=self.conv_kernel_size - 1, + padding=self._config.conv_kernel_dimension - 1, )[..., :padded_len].transpose(1, 2) ) return xBC diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 9dfad8462..7151da394 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,17 +1,24 @@ -import math +import logging import typing -import einops import torch -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.common.linear import Linear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.engine.config_utils.tensor_space import ( + CompositeTensorDim, + ConcatenatedTensorDim, + DefaultDimNames, + TensorDim, + TensorSpace, +) +from fast_llm.engine.distributed.config import DistributedDimNames +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear +from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_fill_, init_kaiming_, init_ones_, init_uniform_ -from fast_llm.utils import get_lr_scale +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ +from fast_llm.utils import Assert, div, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -27,23 +34,7 @@ except (ImportError, RuntimeError): _causal_conv1d_available = False - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def bias_init_method(conv_weight): - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_(-bound, bound) +logger = logging.getLogger(__name__) class Mamba2(Mixer): @@ -53,18 +44,6 @@ class Mamba2(Mixer): _mixer_name: typing.ClassVar[str] = "mamba_2" - _XZ_DIMS = ( - TransformerDimNames.batch, - SSMDimNames.inner_dim, - TransformerDimNames.sequence_q, - ) - _BC_DIMS = ( - TransformerDimNames.batch, - SSMDimNames.c_heads, - SSMDimNames.state_dim, - TransformerDimNames.sequence_q, - ) - def __init__( self, config: SSMConfig, @@ -73,198 +52,206 @@ def __init__( transformer_config: TransformerConfig, ): super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) - self.config: SSMConfig = config - bias: bool = config.add_bias_linear + self._config: SSMConfig = config + Assert.eq(self._config.activation_type, ActivationType.silu) layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - mamba_layer_lr_scale: float | tuple[float | None, ...] | None = get_lr_scale( - self.config.mamba_lr_scale, layer_lr_scale + lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + + num_heads = div(self._config.d_inner, self._config.state_size) + num_head_groups = div(self._config.d_xb, self._config.state_size) + + hidden_dim: TensorDim = tensor_space[TransformerDimNames.hidden] + state_dim = TensorDim("state", self._config.state_size) + + head_groups_dim = TensorDim( + "head_groups", + num_head_groups, + self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor), ) + group_heads_dim = TensorDim("group_heads", div(num_heads, num_head_groups)) - td_inner: TensorDim = tensor_space[SSMDimNames.inner_dim] - td_state: TensorDim = tensor_space[SSMDimNames.state_dim] - td_model: TensorDim = tensor_space[SSMDimNames.model_dim] - tdt_rank: TensorDim = tensor_space[SSMDimNames.dt_rank] - td_xb: TensorDim = tensor_space[SSMDimNames.x_proj_dim_2] - td_inner_proj: TensorDim = tensor_space[SSMDimNames.inner_proj_mamba2] - td_conv_kernel: TensorDim = tensor_space[SSMDimNames.conv_kernel_size] - - self.repeat_kv_before_conv = config.repeat_kv_before_conv - - self.d_state = td_state.size - self.d_model = td_model.size - self.d_xb = td_xb.size - self.d_inner = td_inner.size - self.dt_rank = tdt_rank.size - - if self.repeat_kv_before_conv: - self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, tensor_space[DefaultDimNames.scalar], td_conv_kernel), - init_method=init_uniform_( - -1 / math.sqrt(td_inner.size * td_conv_kernel.size), - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), - ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 - lr_scale=mamba_layer_lr_scale, - ) + heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) - self.conv1d_bias = ParameterMeta.from_dims( - (td_inner,), init_method=bias_init_method(self.conv1d_weight), 
lr_scale=mamba_layer_lr_scale - ) - else: - self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, tensor_space[DefaultDimNames.scalar], td_conv_kernel), - init_method=init_uniform_( - -1 / math.sqrt(td_xb.size * td_conv_kernel.size), - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), - ), - ) - self.conv1d_bias = ParameterMeta.from_dims( - (td_xb,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale - ) + inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, state_dim)) + xb_dim = CompositeTensorDim("xb", (head_groups_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) - self.activation = "silu" + # DT projection + dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) - self.num_xb_head = td_xb.size // td_state.size - self.num_C_head = td_inner.size // td_state.size - self.repeat_group = self.num_C_head // self.num_xb_head + inner_projection_dim = ConcatenatedTensorDim( + "inner_projection", + (inner_dim, xb_dim, xb_dim, inner_dim), + ) - self.in_proj = Linear( - td_model, - td_inner_proj, - bias=bias, - weight_init_method=init_kaiming_(td_model.size), - lr_scale=mamba_layer_lr_scale, + self._local_heads = heads_dim.size + self._local_head_groups = head_groups_dim.size + self._group_heads = div(self._local_heads, self._local_head_groups) + self._local_inner_size = inner_dim.size + self._local_xb_size = xb_dim.size + + conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim + self.conv1d_weight = ParameterMeta.from_dims( + ( + conv1d_dim, + tensor_space[DefaultDimNames.scalar], + convolution_kernel_dim, + ), + init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), + lr_scale=lr_scale, ) - self.dt_in_proj = Linear( - td_model, - tdt_rank, + self.conv1d_bias = ParameterMeta.from_dims( + (conv1d_dim,), + init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + lr_scale=lr_scale, + ) + self.in_proj = OutputParallelLinear( + hidden_dim, + inner_projection_dim, bias=config.add_bias_linear, weight_init_method=init_kaiming_(transformer_config.hidden_size), - lr_scale=mamba_layer_lr_scale, + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - # Initialize special dt projection to preserve variance at initialization - dt_scale = config.dt_scale # 1.0 - dt_init_std = self.dt_rank**-0.5 * dt_scale - if config.dt_init == "constant": - dt_init = init_fill_(dt_init_std) - elif config.dt_init == "random": - dt_init = init_uniform_(-dt_init_std, dt_init_std) - else: - raise NotImplementedError - self.dt_proj = Linear( - tdt_rank, - td_inner, + self.dt_in_proj = Linear( + hidden_dim, + dt_rank_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(transformer_config.hidden_size), + lr_scale=lr_scale, + ) + self.dt_proj = OutputParallelLinear( + dt_rank_dim, + inner_dim, bias=False, - weight_init_method=dt_init, - lr_scale=mamba_layer_lr_scale, + # Initialize special dt projection to preserve variance at initialization + weight_init_method=self._config.dt_init.get_init_method( + self._config.dt_rank**-0.5 * self._config.dt_scale + ), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - # define bias outside the linear layer since its also used in the selective_scan_fn + # define bias outside the linear layer since it's also used in the selective_scan_fn self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), - init_method=init_dtprojbias( - self.d_inner, 
self.config.dt_max, self.config.dt_min, self.config.dt_init_floor - ), - lr_scale=mamba_layer_lr_scale, + (inner_dim,), + init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + lr_scale=lr_scale, ) - self.A_log = ParameterMeta.from_dims( - (td_inner, td_state), - init_method=init_A(self.config.state_size, self.config.d_inner), - lr_scale=mamba_layer_lr_scale, + (inner_dim, state_dim), + init_method=init_A(self._config.state_size, self._config.d_inner), + lr_scale=lr_scale, weight_decay=False, ) - self.D = ParameterMeta.from_dims( - (td_inner,), + (inner_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - - self.out_proj = Linear( - td_inner, - td_model, - bias=bias, - weight_init_method=init_kaiming_(td_inner.size), + self.out_proj = InputParallelLinear( + inner_dim, + hidden_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(self._config.d_inner), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) + if self._debug_level: + self._xz_dims = ( + TransformerDimNames.batch, + inner_dim, + TransformerDimNames.sequence_q, + ) + self._bc_dims = ( + TransformerDimNames.batch, + heads_dim, + state_dim, + TransformerDimNames.sequence_q, + ) - def forward(self, hidden_states, kwargs): - """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states - """ + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - batch, seqlen, dim = hidden_states.shape - outputs = {} - - conv_state, ssm_state = None, None - - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - zxbc = self.in_proj(hidden_states) - z, x, B, C = torch.split(zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1) - - x = einops.rearrange(x, "b l d -> b d l") - z = einops.rearrange(z, "b l d -> b d l") + assert _causal_conv1d_available + + # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) + # -> (batch/sequence, sequence/batch, local_inner_projection) + inner_projection = self.in_proj(input_) + dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias + # Standardize to (batch, sequence, local_inner_projection) + if kwargs[TransformerKwargs.sequence_first]: + inner_projection = inner_projection.transpose(0, 1) + dt = dt.transpose(0, 1) + + sequence_length = inner_projection.size(1) + + z, x, b, c = torch.split( + inner_projection, + [self._local_inner_size, self._local_xb_size, self._local_xb_size, self._local_inner_size], + dim=2, + ) - B = einops.rearrange(B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state) - B = repeat_kv(B, self.repeat_group) # B, n_group, L, H - B = einops.rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() - C = einops.rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() + # z: (batch, sequence, local_heads * state) -> (batch, local_heads * state, sequence) + z = z.transpose(1, 2) - dt = self.dt_proj(self.dt_in_proj(hidden_states)) + self.dt_proj_bias # B, L, d_inner - dt = einops.rearrange(dt, "b l d -> b d l") # B, d_inner, L + # x: (batch, sequence, local_head_groups * state) -> (batch, local_heads * state, sequence) + x = x.transpose(1, 2) + if self._config.repeat_kv_before_conv: + x = ( + x.unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) + .flatten(1, 2) + ) + x = 
_causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") + else: + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") + x = ( + x.unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) + .flatten(1, 2) + ) - if self.repeat_kv_before_conv: - assert self.repeat_group > 0 - x = einops.rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + # b: (batch, sequence, local_head_groups * state) -> (batch, local_heads, state, sequence) + b = ( + b.transpose(1, 2) + .unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) + ) - assert self.activation in ["silu", "swish"] - if _causal_conv1d_available: - x = _causal_conv1d_fn( - x=x, - weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), - bias=self.conv1d_bias, - activation=self.activation, - ) # B, L, D - else: - raise RuntimeError("Causal conv1d is not available. Please install causal_conv1d.") + # c: (batch, sequence, heads * state) -> (batch, heads, state, sequence) + c = c.transpose(1, 2).unflatten(1, (self._local_heads, self._config.state_size)) - if not self.repeat_kv_before_conv: - x = einops.rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + # dt: (batch, sequence, heads * state) -> (batch, heads * state, sequence) + dt = dt.transpose(1, 2) if self._debug_level: - self._debug_log(z, "z", self._XZ_DIMS, kwargs) - self._debug_log(x, "x", self._XZ_DIMS, kwargs) - self._debug_log(B, "b", self._BC_DIMS, kwargs) - self._debug_log(C, "c", self._BC_DIMS, kwargs) - self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) + self._debug_log(z, "z", self._xz_dims, kwargs) + self._debug_log(x, "x", self._xz_dims, kwargs) + self._debug_log(b, "b", self._bc_dims, kwargs) + self._debug_log(c, "c", self._bc_dims, kwargs) + self._debug_log(dt, "dt", self._xz_dims, kwargs) y = selective_scan_fn( x, dt, - A, - B, - C, + -torch.exp(self.A_log.float()), + b, + c, self.D.float(), - z=z, - delta_bias=self.dt_proj_bias.float(), # self.dt_proj.bias.float(), + z, + delta_bias=self.dt_proj_bias.float(), delta_softplus=True, - return_last_state=False, ) if self._debug_level: - self._debug_log(y, "y", self._XZ_DIMS, kwargs) - - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(einops.rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) - - y = einops.rearrange(y, "b d l -> b l d") - out = self.out_proj(y) - outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - return outputs["hidden_states"], None + self._debug_log(y, "y", self._xz_dims, kwargs) + + # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) + y = y.transpose(1, 2)[:, :sequence_length] + if kwargs[TransformerKwargs.sequence_first]: + # TODO: Is contiguous needed? 
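The unflatten / repeat_interleave / flatten pattern used for x and b above replaces the repeat_kv helper removed at the top of this file; a self-contained sketch of the equivalence, with made-up sizes:

import torch

# Made-up sizes for illustration.
batch, sequence, head_groups, group_heads, state = 2, 6, 4, 3, 16
heads = head_groups * group_heads

x = torch.randn(batch, head_groups * state, sequence)

# Expand each head group to its group_heads heads along the channel dimension.
expanded = (
    x.unflatten(1, (head_groups, state))
    .repeat_interleave(group_heads, dim=1, output_size=heads)
    .flatten(1, 2)
)
assert expanded.shape == (batch, heads * state, sequence)

# Same result as the removed repeat_kv helper (expand + reshape over the group dimension).
reference = (
    x.unflatten(1, (head_groups, state))
    .unsqueeze(2)
    .expand(batch, head_groups, group_heads, state, sequence)
    .reshape(batch, heads * state, sequence)
)
assert torch.equal(expanded, reference)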
+ y = y.transpose(0, 1).contiguous() + # (batch/sequence, sequence/batch, local_heads * state) + # -> (batch/local_sequence, local_sequence/batch, hidden) + return self.out_proj(y) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 5e0ae786e..061921b3d 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,17 +1,23 @@ +import logging import math import typing -from typing import Callable -import einops import torch -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.engine.config_utils.tensor_space import ( + CompositeTensorDim, + ConcatenatedTensorDim, + DefaultDimNames, + TensorDim, + TensorSpace, +) +from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ -from fast_llm.utils import get_lr_scale +from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ +from fast_llm.utils import Assert, div, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -20,6 +26,8 @@ except (ImportError, RuntimeError): _mamba_available = False +logger = logging.getLogger(__name__) + """ Note: this is mostly adapted from https://github.com/Zyphra/Zamba2, similar code is also in https://github.com/state-spaces/mamba. For now it only supports training and not inference. @@ -27,38 +35,27 @@ """ -def init_A(d_state, d_inner) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - # S4D real initialization - # TODO: adopt this initialization to work for tensor parallel setting! 
- A = einops.repeat(torch.arange(1, d_state + 1, dtype=torch.float32), "n -> d n", d=d_inner).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - if tensor.shape != A_log.shape: - if tensor.numel() == A_log.numel(): - tensor_view = tensor.view(d_inner, d_state) - tensor_view.copy_(A_log) - else: - raise ValueError(f"Tensor size {tensor.numel()} doesn't match expected size {A_log.numel()}") - else: - tensor.copy_(A_log) - return tensor - - return init_ - - -def init_dtprojbias( - d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float -) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - dt = torch.exp(torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( - min=dt_init_floor +def init_A(d_state, d_inner) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + if tensor.numel() != d_state * d_inner: + raise ValueError("_init_A requires not supported for tensor slices.") + torch.log( + torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device) + .unsqueeze(0) + .expand(d_inner, d_state), + out=tensor, ) + + return LambdaInitializer(init_, requires_global_initialization=True) + + +def init_dtprojbias(dt_max: float, dt_min: float, dt_init_floor: float) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min_(dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - tensor.copy_(inv_dt) - return tensor + tensor.add_(torch.log(-torch.expm1(-tensor))) - return init_ + return LambdaInitializer(init_) class MambaLayer(Mixer): @@ -72,115 +69,109 @@ def __init__( transformer_config: TransformerConfig, ): super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) - self.config: SSMConfig = config - - # Tensor dims: - td_inner = tensor_space[SSMDimNames.inner_dim] - td_inner_proj = tensor_space[SSMDimNames.inner_proj_mamba] # TensorDim("D_inner_2", self.d_inner * 2) - tdt_rank = tensor_space[SSMDimNames.dt_rank] - td_x_proj = tensor_space[SSMDimNames.x_proj_dim] - td_state = tensor_space[SSMDimNames.state_dim] - td_model = tensor_space[SSMDimNames.model_dim] - td_conv_kernel = tensor_space[SSMDimNames.conv_kernel_size] - self.d_conv = td_conv_kernel.size - self.d_inner = td_inner.size - self.d_state = td_state.size - self.d_model = td_model.size - self.dt_rank = tdt_rank.size + assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" + self._config = config + # TODO: It's not silu? 
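For reference, the two initializers defined above reduce to the following standalone math (sizes are illustrative; the config defaults above are dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4):

import math
import torch

d_inner, d_state = 32, 16  # illustrative sizes
dt_min, dt_max, dt_init_floor = 0.001, 0.1, 1e-4

# init_A: S4D real initialization, log of [1, ..., d_state] repeated over the d_inner rows.
A_log = torch.log(torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).expand(d_inner, d_state))

# init_dtprojbias: sample dt log-uniformly in [dt_min, dt_max], clamp, then store the inverse
# softplus so that softplus(bias) recovers dt at initialization.
dt = torch.empty(d_inner).uniform_(math.log(dt_min), math.log(dt_max)).exp_().clamp_min_(dt_init_floor)
bias = dt + torch.log(-torch.expm1(-dt))
assert torch.allclose(torch.nn.functional.softplus(bias), dt, atol=1e-6)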
+ Assert.eq(self._config.activation_type, ActivationType.silu) layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - self.in_proj_weight = ParameterMeta.from_dims( - (td_inner_proj, td_model), - init_method=init_kaiming_(td_model.size), + # Tensor dims: + hidden_dim = tensor_space[TransformerDimNames.hidden] + heads_dim = TensorDim("heads", div(self._config.d_inner, self._config.state_size)) + state_dim = TensorDim("state", self._config.state_size) + inner_dim = CompositeTensorDim("inner", (heads_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) + dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) + inner_projection_dim = ConcatenatedTensorDim("inner_projection", (inner_dim, inner_dim)) + x_projection_dim = ConcatenatedTensorDim("x_projection", (dt_rank_dim, state_dim, state_dim)) + + # TODO: Backward compatibility? + self.in_proj = Linear( + hidden_dim, + inner_projection_dim, + bias=False, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, tensor_space[DefaultDimNames.scalar], td_conv_kernel), - init_method=init_kaiming_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + ( + inner_dim, + tensor_space[DefaultDimNames.scalar], + convolution_kernel_dim, + ), + init_method=init_kaiming_(inner_dim.size), + lr_scale=lr_scale, ) - self.conv1d_bias = None - - self.activation = "silu" - self.act = torch.nn.SiLU() - self.x_proj = Linear( - td_inner, - td_x_proj, - weight_init_method=init_kaiming_(td_inner.size), + inner_dim, + x_projection_dim, + weight_init_method=init_kaiming_(inner_dim.size), bias=False, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) self.x_proj.weight.auto_grad_accumulation = True # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (td_inner, tdt_rank), - init_method=init_kaiming_(tdt_rank.size), - lr_scale=mamba_layer_lr_scale, + (inner_dim, dt_rank_dim), + init_method=init_kaiming_(self._config.dt_rank), + lr_scale=lr_scale, ) self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), - init_method=init_dtprojbias( - self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor - ), - lr_scale=mamba_layer_lr_scale, + (inner_dim,), + init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (td_inner, td_state), + (inner_dim, state_dim), weight_decay=False, - init_method=init_A(self.d_state, self.d_inner), - lr_scale=mamba_layer_lr_scale, + init_method=init_A(self._config.state_size, inner_dim.size), + lr_scale=lr_scale, ) # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_inner,), + (inner_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) self.out_proj = Linear( - td_inner, - td_model, + inner_dim, + hidden_dim, bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. 
- weight_init_method=init_kaiming_(td_model.size), - lr_scale=mamba_layer_lr_scale, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) self.out_proj.weight.auto_grad_accumulation = True - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - batch, seqlen, dim = hidden_states.shape - - # We do matmul and transpose BLH -> HBL at the same time - xz = einops.rearrange( - self.in_proj_weight @ einops.rearrange(hidden_states, "b l d -> d (b l)"), - "d (b l) -> b d l", - l=seqlen, - ) + in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[TransformerKwargs.sequence_first] else (0, 2, 1)) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s out = _mamba_inner_fn( - xz, + in_proj, self.conv1d_weight, - self.conv1d_bias, + None, self.x_proj.weight, self.dt_proj_weight, self.out_proj.weight, self.out_proj.bias, # is None here - A, + -torch.exp(self.A_log.float()), None, # input-dependent B None, # input-dependent C self.D.float(), delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) + if kwargs[TransformerKwargs.sequence_first]: + out = out.transpose(0, 1) return out, None diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 75d06f268..c7becd948 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -98,10 +98,11 @@ def __init__( self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space[TransformerDimNames.hidden] - # Note, layer_lr_scale does not impact the norms + # TODO: add a separate norm_lr_scale - self.norm_1 = self._config.normalization.get_layer(hidden_dim) - self.norm_2 = self._config.normalization.get_layer(hidden_dim) + lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + self.norm_1 = self._config.normalization.get_layer(hidden_dim, lr_scale) + self.norm_2 = self._config.normalization.get_layer(hidden_dim, lr_scale) # The mixer needs to be created here for backward-compatible weight ordering. 
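Stepping back to MambaLayer.forward above: the permute there converts either input layout to the (batch, channels, sequence) layout expected by the fused Mamba kernel. A minimal sketch with made-up sizes:

import torch

batch, sequence, channels = 2, 5, 8
sequence_first = True  # illustrative; the layer reads this from kwargs[TransformerKwargs.sequence_first]

x = torch.randn(sequence, batch, channels) if sequence_first else torch.randn(batch, sequence, channels)
# (sequence, batch, channels) or (batch, sequence, channels) -> (batch, channels, sequence)
x_bdl = x.permute((1, 2, 0) if sequence_first else (0, 2, 1))
assert x_bdl.shape == (batch, channels, sequence)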
setattr(self, self._mixer_module_name, self._create_mixer()) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index e3f964aee..175cb07f1 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -23,7 +23,6 @@ class GPTHuggingfaceCheckpointFormat(CheckpointFormat): support_optimizer: typing.ClassVar[bool] = False - trust_remote_code: typing.ClassVar[bool] = False @classmethod def get_handler_class(cls) -> type[CheckpointHandler]: @@ -58,17 +57,14 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mtp_llama" - trust_remote_code: typing.ClassVar[bool] = True class DiffusionDreamGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "dream" - trust_remote_code: typing.ClassVar[bool] = True class DiffusionLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "diffusion_llama" - trust_remote_code: typing.ClassVar[bool] = True @config_class() diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index f97539c04..f632ab6c7 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -4,13 +4,18 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.checkpoint.config import CheckpointHandler from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig +from fast_llm.models.gpt.config import ( + GPTBaseModelConfig, + GPTBatchConfig, + GPTHuggingfaceCheckpointFormat, + PretrainedGPTModelConfig, +) from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -45,42 +50,10 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): def setup_tensor_space(self, tensor_space: TensorSpace) -> None: """ Setup the tensor space for the model. - Some of these can be setup directly in the layer config, but keeping them here for clarity. 
""" super().setup_tensor_space(tensor_space) - d_inner: int = self.ssm.d_inner - - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.model_dim, self.transformer.hidden_size)) - # Mamba-specific dimensions - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_dim, d_inner)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.state_dim, self.ssm.state_size)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, self.ssm.dt_rank)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.ssm.dt_rank + self.ssm.state_size * 2)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel_size, self.ssm.conv_kernel_dimension)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba, d_inner * 2)) - - if SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout: - # Mamba2 specific dimensions - # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 - headdim = d_inner // self.ssm.n_v_heads - Assert.eq(self.ssm.n_v_heads, d_inner // headdim) - Assert.eq(d_inner % headdim, 0) - Assert.eq(self.ssm.n_v_heads % self.ssm.n_qk_heads, 0) - - conv_dim = d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size - inner_proj_dim = 2 * d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size + self.ssm.n_v_heads - - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.head_dim, headdim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.qk_heads, self.ssm.n_qk_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.v_heads, self.ssm.n_v_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_discrete_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) - elif SSMBlockType.mamba2.value in self.hybrid_block_layout: - inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner # + self.ssm.dt_rank - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.c_heads, d_inner // self.ssm.state_size)) + if self.ssm_block_type is not None: + self.ssm.setup_tensor_space(tensor_space, self.ssm_block_type) def _validate(self): with self._set_implicit_default(None): @@ -111,8 +84,7 @@ def _validate(self): self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None -class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class LLambaHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "llamba" @classmethod @@ -122,8 +94,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return LLambaHuggingfaceCheckpointHandler -class AprielSSMHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielSSMHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm" @classmethod @@ -133,8 +104,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielSSMHuggingfaceCheckpointHandler -class AprielSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm_hybrid" @classmethod @@ -144,8 +114,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return 
AprielSSMHHybridHuggingfaceCheckpointHandler -class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" @classmethod @@ -193,12 +162,6 @@ def _validate(self): logger.warning( "HybridSSMModelConfig is being instantiated. This model is experimental and may not work as expected." ) - if ( - self.base_model.sequence_first - or self.distributed.sequence_data_parallel > 1 - or self.distributed.sequence_tensor_parallel - ): - raise NotImplementedError(f"Sequence-first not supported for SSMs.") super()._validate() diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 43e3c67e5..b5e77e0f0 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,6 +3,8 @@ import pathlib import typing +from transformers import PretrainedConfig + from fast_llm.config import MISSING from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( @@ -16,7 +18,7 @@ SplitWeightConverter, WeightConverter, ) -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig @@ -29,12 +31,14 @@ HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat, ) +from fast_llm.models.ssm.external.apriel_15b_hybrid import ( + configuration_ssm_hybrid_apriel15b, + modeling_ssm_hybrid_apriel15b, +) +from fast_llm.models.ssm.external.apriel_hybrid import configuration_ssm_hybrid_apriel, modeling_ssm_hybrid_apriel from fast_llm.models.ssm.model import HybridSSMModel from fast_llm.utils import Assert -if typing.TYPE_CHECKING: - pass - class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: HybridSSMModel @@ -523,6 +527,11 @@ class AprielSSMHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandle _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat architecture: typing.ClassVar[str] = "AprielSSMForCausalLM" + modeling_file = modeling_ssm_hybrid_apriel15b.__file__ + configuration_file = configuration_ssm_hybrid_apriel15b.__file__ + configuration_cls: typing.ClassVar[type["PretrainedConfig"]] = ( + configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig + ) @classmethod def _create_config_converters(cls) -> list[ParamConverter]: @@ -635,6 +644,7 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class AprielSSMHHybridHuggingfaceCheckpointHandler( + CustomModelingExportMixin, HybridModelCheckpointHandler, # handles the block structure parameter CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers @@ -648,10 +658,21 @@ class AprielSSMHHybridHuggingfaceCheckpointHandler( format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHHybridHuggingfaceCheckpointFormat _default_block_type: str = SSMBlockType.mamba2_discrete.value architecture: typing.ClassVar[str] = "AprielSSMHybridForCausalLM" + modeling_file = modeling_ssm_hybrid_apriel.__file__ + 
configuration_file = configuration_ssm_hybrid_apriel.__file__ + configuration_cls: typing.ClassVar[type["PretrainedConfig"]] = modeling_ssm_hybrid_apriel.AprielSSMHybridConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_ssm_hybrid_apriel.AprielSSMHybridConfig", + "AutoModel": "modeling_ssm_hybrid_apriel.AprielSSMHybridModel", + "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel.AprielSSMHybridForCausalLM", + }, + ), RenameParamConverter( fast_llm_names=(("ssm", "d_inner"),), export_names=(("ssm_cfg", "d_inner"),), @@ -693,6 +714,7 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( + CustomModelingExportMixin, HybridModelCheckpointHandler, # handles the block structure parameter CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers @@ -707,28 +729,23 @@ class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( _default_block_type: str = SSMBlockType.mamba2_discrete.value _hf_prefix: str = "model" architecture: typing.ClassVar[str] = "AprielThinkerSSMHybridForCausalLM" - - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() - # num_layers = self._model.config.base_model.transformer.num_layers - # # Embedding and output - # if self._model.config.base_model.tie_word_embeddings: - # converters.append( - # WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") - # ) - # converters.append(IgnoreImportWeightConverter((), f"{self._hf_prefix}.lm_head.weight")) - # else: - # converters.append( - # WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") - # ) - # converters.append( - # WeightConverter(f"layers.{num_layers + 1}.output_weights", f"{self._hf_prefix}.lm_head.weight") - # ) - return converters + modeling_file = modeling_ssm_hybrid_apriel15b.__file__ + configuration_file = configuration_ssm_hybrid_apriel15b.__file__ + configuration_cls: typing.ClassVar[type["PretrainedConfig"]] = ( + configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig + ) @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig", + "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel", + "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM", + }, + ), RenameParamConverter( fast_llm_names=(("ssm", "d_inner"),), export_names=(("ssm_cfg", "d_inner"),), diff --git a/setup.cfg b/setup.cfg index c086af7d0..843aa15ca 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,9 +48,14 @@ HUGGINGFACE = # Required to run SSMs # To install on cpu environment (ex. 
for IDE support): -# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation +# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 + cartesia_pytorch>=0.0.2 + +GENERATION = + lm_eval>=0.4.9 + DEV = # Pre-commit git hook diff --git a/tests/conftest.py b/tests/conftest.py index 19bdfe5d9..86937326c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,27 +8,15 @@ import pytest import xdist.scheduler -from fast_llm.utils import get_and_reset_memory_usage_mib, set_global_variables +from fast_llm.utils import get_and_reset_memory_usage_mib from tests.utils.depends import DependencyManager +from tests.utils.global_variables import TEST_RESULTS_PATH, set_testing_global_variables # TODO: Is this early enough? -set_global_variables() # isort: skip - - -if worker_name := os.environ.get("PYTEST_XDIST_WORKER"): - if gpus := os.environ.get("CUDA_VISIBLE_DEVICES"): - # We set the device through "CUDA_VISIBLE_DEVICES", and this needs to happen before importing torch. - assert worker_name.startswith("gw") - worker_id = int(worker_name[2:]) - gpus = [int(i) for i in gpus.split(",")] - num_gpus = len(gpus) - gpus = [gpus[(i + worker_id) % num_gpus] for i in range(num_gpus)] - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpus) - +set_testing_global_variables() # isort: skip import torch # isort: skip - from tests.utils.save_load_configs import ( # isort: skip distributed_save_load_config, distributed_save_load_config_non_pp, @@ -44,7 +32,7 @@ ) from tests.utils.model_configs import model_testing_config, ModelTestingConfig, testing_group_enabled # isort: skip -from tests.utils.utils import result_path, TEST_RESULTS_PATH, format_resource_report, report_subtest # isort: skip +from tests.utils.utils import result_path, format_resource_report, report_subtest # isort: skip logger = logging.getLogger(__name__) diff --git a/tests/data/common.py b/tests/data/common.py index 2bb90a6b4..6614accce 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -23,7 +23,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert, div -from tests.utils.dataset import TEST_VOCAB_SIZE +from tests.utils.global_variables import TEST_VOCAB_SIZE def get_sampling_data( diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 3e6c37632..312807aad 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -11,7 +11,8 @@ get_sampling_data, get_test_data_and_compare_samples, ) -from tests.utils.dataset import DATASET_CACHE, DATASET_PREFIX, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_CACHE, DATASET_PREFIX _DATASET_PREFIX_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset" diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 4f36cdf89..6cc5d639a 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -7,7 +7,8 @@ get_test_data_and_compare_samples, ) from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS -from tests.utils.dataset import DATASET_PREFIX, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_PREFIX 
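A note on the conftest.py change above: set_testing_global_variables() is called before torch is imported because, as the comment in the moved code says, the per-worker CUDA_VISIBLE_DEVICES rotation must happen before importing torch. A minimal sketch of the ordering the new conftest relies on (test scaffolding elided; module names as in the diff):

from tests.utils.global_variables import set_testing_global_variables

set_testing_global_variables()  # rotates CUDA_VISIBLE_DEVICES for this xdist worker

import torch  # noqa: E402  # imported only after the environment is finalized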
GPT_CONCATENATED_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], diff --git a/tests/data/test_concatenated_memmap.py b/tests/data/test_concatenated_memmap.py index 1cc22250d..35d93d9d5 100644 --- a/tests/data/test_concatenated_memmap.py +++ b/tests/data/test_concatenated_memmap.py @@ -9,7 +9,8 @@ validate_indexed_dataset_sampling, ) from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES -from tests.utils.dataset import DATASET_CACHE, get_test_concatenated_memmap_dataset +from tests.utils.dataset import get_test_concatenated_memmap_dataset +from tests.utils.global_variables import DATASET_CACHE _DATASET_PREFIX_MIX_CONCATENATED_MEMMAP = DATASET_CACHE / "concatenated_memmap" diff --git a/tests/data/test_dataset_from_file.py b/tests/data/test_dataset_from_file.py index 3f7d1a139..c149e1395 100644 --- a/tests/data/test_dataset_from_file.py +++ b/tests/data/test_dataset_from_file.py @@ -1,7 +1,8 @@ from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from tests.data.common import compare_indexed_dataset, get_dataset_config from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS -from tests.utils.dataset import DATASET_PREFIX, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_PREFIX def test_dataset_from_file(): diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 004b96289..551134fd2 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -7,7 +7,8 @@ get_sampling_data, get_test_data_and_compare_samples, ) -from tests.utils.dataset import DATASET_PREFIX, TOKENIZER_PATH, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_PREFIX, TOKENIZER_PATH GPT_FIM_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index fcd7756db..1286bddd7 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -4,7 +4,8 @@ from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig from tests.data.common import compare_indexed_dataset, get_dataset_config -from tests.utils.dataset import DATASET_CACHE, DATASET_PREFIX, DATASET_SAMPLING_CACHE, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_CACHE, DATASET_PREFIX, DATASET_SAMPLING_CACHE MEMMAP_DATASET_LENGTH = 6153 MEMMAP_DATASET_TOKENS = 508327 diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 32d76fa4c..a2996aa1c 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -13,7 +13,8 @@ get_test_data_and_compare_samples, validate_indexed_dataset_sampling, ) -from tests.utils.dataset import DATASET_PREFIX, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_PREFIX try: from fast_llm.csrc.data import build_padded_token_cumsum # noqa diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index f8eedc5bc..1440614cb 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -7,7 +7,8 @@ validate_indexed_dataset_sampling, ) from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES -from tests.utils.dataset import DATASET_PREFIX, get_test_dataset +from tests.utils.dataset import get_test_dataset +from tests.utils.global_variables import DATASET_PREFIX GPT_SLICE_TRAINING_SAMPLES = [ [80, 268, 79, 260, 207, 3086], diff --git a/tests/models/test_checkpoint.py 
b/tests/models/test_checkpoint.py index 05acf23dc..031ec6f97 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -284,10 +284,15 @@ def test_load_pretrained( @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_huggingface_model(model_testing_config, get_convert_path): # Test that Fast-LLM's Hugging Face wrapper produces the same results as the converted Hugging Face model. + # TODO: Stress the importance of this test as the main correctness test for most models. # TODO: Review test. Move to test_generate? fast_llm_path = get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) hf_path = get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) - model_ref = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( + try: + hf_class = model_testing_config.huggingface_model_for_causal_lm_class + except NotImplementedError: + pytest.skip(f"Hugging Face wrapper not implemented for {model_testing_config.name}.") + model_ref = hf_class.from_pretrained( CheckpointLoadConfig( path=get_convert_path(), format=DistributedCheckpointFormat, @@ -298,8 +303,8 @@ def test_huggingface_model(model_testing_config, get_convert_path): 0, model_ref.config.fast_llm_config.base_model.vocab_size, size=(4, 100), dtype=torch.int64, device="cuda" ) output_ref = model_ref(test_input) - model_from_fast_llm = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained(fast_llm_path) - model_from_hf = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( + model_from_fast_llm = hf_class.from_pretrained(fast_llm_path) + model_from_hf = hf_class.from_pretrained( CheckpointLoadConfig( path=hf_path, format=model_testing_config.checkpoint_format, @@ -312,9 +317,7 @@ def test_huggingface_model(model_testing_config, get_convert_path): if model_testing_config.name in ("diffusion_llama", "dream") else transformers.AutoModelForCausalLM ) - model_as_hf = auto_model.from_pretrained( - hf_path, trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code - ).cuda() + model_as_hf = auto_model.from_pretrained(hf_path, trust_remote_code=True).cuda() for name, model in zip( ("From state dict", "From Huggingface", "Native Huggingface"), (model_from_fast_llm, model_from_hf, model_as_hf), diff --git a/tests/models/test_lm_eval.py b/tests/models/test_lm_eval.py index b9e2aa8c3..8011b5bbc 100644 --- a/tests/models/test_lm_eval.py +++ b/tests/models/test_lm_eval.py @@ -3,8 +3,9 @@ import pytest -from tests.utils.dataset import TOKENIZER_PATH, download_santacoder_tokenizer +from tests.utils.dataset import download_santacoder_tokenizer from tests.utils.distributed_configs import DistributedTestingConfig +from tests.utils.global_variables import TOKENIZER_PATH from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 30667cd17..5ff998bfa 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -3,8 +3,9 @@ import pytest from tests.utils.compare_tensor_logs import CompareConfig -from tests.utils.dataset import MODEL_DATASET_PREFIX, get_model_test_dataset +from tests.utils.dataset import get_model_test_dataset from tests.utils.distributed_configs import DistributedTestingConfig +from tests.utils.global_variables import MODEL_DATASET_PREFIX from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils 
import requires_cuda
diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py
index b770675d4..e4cce2935 100644
--- a/tests/utils/dataset.py
+++ b/tests/utils/dataset.py
@@ -1,27 +1,21 @@
 import pathlib
 import random
-import string
 import numpy as np
 import yaml
 from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
 from fast_llm.data.dataset.gpt.sampled import GPTSample
-from tests.utils.utils import SHARED_RESULT_PATH, TEST_RESULTS_PATH
-
-# TODO: Fixtures
-TOKENIZER_PATH = SHARED_RESULT_PATH / "tokenizer"
-TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json"
-DATASET_CACHE = SHARED_RESULT_PATH / "dataset"
-DATASET_PREFIX = DATASET_CACHE / "common_dataset"
-DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache"
-TEST_VOCAB_SIZE = 8192
-# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6%
-TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n"
-TEST_DATASET_TOKENS = 1000000
-
-MODEL_DATASET_PREFIX = DATASET_CACHE / "model_dataset"
-MODEL_TEST_VOCAB_SIZE = 384
+from tests.utils.global_variables import (
+    DATASET_PREFIX,
+    MODEL_DATASET_PREFIX,
+    MODEL_TEST_VOCAB_SIZE,
+    TEST_CHARACTERS,
+    TEST_DATASET_TOKENS,
+    TEST_VOCAB_SIZE,
+    TOKENIZER_FILE,
+    TOKENIZER_PATH,
+)
 def download_santacoder_tokenizer():
diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py
new file mode 100644
index 000000000..836b6b79d
--- /dev/null
+++ b/tests/utils/global_variables.py
@@ -0,0 +1,49 @@
+"""
+This file holds global variables and settings that need to be defined before importing any third-party package.
+They are kept in a separate file to prevent circular imports.
+"""
+
+import os
+import pathlib
+import string
+
+from fast_llm.utils import set_global_variables
+
+# Directory for all test data and results.
+# Cannot be a fixture because it's used outside the testing environment (ex. distributed scripts).
+TEST_RESULTS_PATH = pathlib.Path("/tmp/fast_llm_tests")
+
+WORKER_NAME = os.environ.get("PYTEST_XDIST_WORKER")
+GPUS = os.environ.get("CUDA_VISIBLE_DEVICES")
+SHARED_RESULT_PATH = TEST_RESULTS_PATH / (f"common_{WORKER_NAME}" if WORKER_NAME else "common")
+
+
+def set_testing_global_variables():
+    set_global_variables()  # isort: skip
+    if WORKER_NAME:
+        if gpus := os.environ.get("CUDA_VISIBLE_DEVICES"):
+            # We set the device through "CUDA_VISIBLE_DEVICES", and this needs to happen before importing torch.
+            assert WORKER_NAME.startswith("gw")
+            worker_id = int(WORKER_NAME[2:])
+            gpus = [int(i) for i in gpus.split(",")]
+            num_gpus = len(gpus)
+            gpus = [gpus[(i + worker_id) % num_gpus] for i in range(num_gpus)]
+            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpus)
+    # TODO: This might help with some issues, but slows down testing significantly.
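For clarity, a worked example (hypothetical values) of the rotation performed in set_testing_global_variables above: every xdist worker still sees all GPUs, but in a different order, so each worker defaults to a different primary device.

# Illustrative only: reproduce the rotation for worker "gw1" with four visible GPUs.
gpus = [0, 1, 2, 3]  # CUDA_VISIBLE_DEVICES="0,1,2,3"
worker_id = 1        # PYTEST_XDIST_WORKER="gw1"
num_gpus = len(gpus)
rotated = [gpus[(i + worker_id) % num_gpus] for i in range(num_gpus)]
assert rotated == [1, 2, 3, 0]  # worker gw1 defaults to GPU 1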
+ # os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(SHARED_RESULT_PATH / "torchinductor_cache") + # os.environ["TRITON_CACHE_DIR"] = str(SHARED_RESULT_PATH / "triton_cache") + + +# TODO: Fixtures +TOKENIZER_PATH = SHARED_RESULT_PATH / "tokenizer" +TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" +DATASET_CACHE = SHARED_RESULT_PATH / "dataset" +DATASET_PREFIX = DATASET_CACHE / "common_dataset" +DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" +TEST_VOCAB_SIZE = 8192 +# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% +TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" +TEST_DATASET_TOKENS = 1000000 + +MODEL_DATASET_PREFIX = DATASET_CACHE / "model_dataset" +MODEL_TEST_VOCAB_SIZE = 384 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 1eee3675d..e9bdeba97 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -19,9 +19,13 @@ Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat -from tests.utils.dataset import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE +from fast_llm.models.ssm.config import ( + AprielSSMHHybridHuggingfaceCheckpointFormat, + AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, + LLambaHuggingfaceCheckpointFormat, +) from tests.utils.distributed_configs import DistributedTestingConfig +from tests.utils.global_variables import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE from fast_llm.engine.evaluation.evaluators import ( # isort:skip # needed for dynamic type registration EvaluatorsConfig, @@ -466,16 +470,14 @@ def _update_and_add_testing_config( ) _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. + # Tests hybrid Mamba, llamba converter. "llama", "llamba", model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m']", - "model.base_model.ssm.state_size=8", - "model.base_model.ssm.chunk_size=32", - "model.base_model.ssm.n_qk_heads=8", - "model.base_model.ssm.n_v_heads=8", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=16", ], megatron_args=None, checkpoint_format=LLambaHuggingfaceCheckpointFormat, @@ -483,57 +485,75 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.broken, # TODO: Fix and bring back to `testing_groups` + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, + ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, compare_factor=2.0, - # SSMs don't support sequence-first configurations. - skip_tests=("sf", "sdp", "stp", "ms"), + # Micro-sequence split not supported. + skip_tests=("sdp", "ms"), ) - _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. - "llamba", - "hybrid_discrete_mamba2", + # Tests hybrid Mamba 2. 
+ "llama", + "hybrid_mamba2", model_type="hybrid_ssm", extra_args=[ - "model.base_model.hybrid_block_layout=['t','m2d']", + "model.base_model.hybrid_block_layout=['t','m2']", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=8", + "model.base_model.ssm.d_xb=256", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" ], megatron_args=None, - checkpoint_format=None, + checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, + compare_factor=2.0, + # Micro-sequence split not supported. + skip_tests=( + "sdp", + "ms", + ), # "pp","dp", "ce","16", "bf", "df", "stp"), ) + _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. - "llamba", - "hybrid_mamba2", + # Tests hybrid discrete Mamba 2. + "llama", + "hybrid_discrete_mamba2", model_type="hybrid_ssm", extra_args=[ - "model.base_model.hybrid_block_layout=['t','m2']", + "model.base_model.hybrid_block_layout=['t','m2d']", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=8", + "model.base_model.ssm.n_qk_heads=8", + "model.base_model.ssm.n_v_heads=16", + "model.base_model.ssm.chunk_size=32", ], megatron_args=None, - checkpoint_format=None, + checkpoint_format=AprielSSMHHybridHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + # TODO: Implement + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, + compare_factor=2.0, + # Micro-sequence split and sequence-first not supported. + skip_tests=("sdp", "ms"), ) diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 25d5221d8..88303a0f4 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -1,7 +1,6 @@ import json import logging import math -import os import pathlib import sys import time @@ -19,22 +18,12 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig from fast_llm.engine.multi_stage.stage import Stage from fast_llm.utils import get_and_reset_memory_usage_mib, header +from tests.utils.global_variables import TEST_RESULTS_PATH logger = logging.getLogger(__name__) requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") -# Directory for all test data and results. -# Cannot be a fixture because it's used outside testing environment (ex. distributed scripts). -TEST_RESULTS_PATH = pathlib.Path("/tmp/fast_llm_tests") - -# Directory for data that is shared between independent tests and may not be parallel-safe, -# ex. generated dataset and downloaded files. 
-if worker_name := os.environ.get("PYTEST_XDIST_WORKER"):
-    SHARED_RESULT_PATH = TEST_RESULTS_PATH / f"common_{worker_name}"
-else:
-    SHARED_RESULT_PATH = TEST_RESULTS_PATH / "common"
-
 @pytest.fixture(scope="session")
 def result_path():
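Finally, the auto_map entries added by the conversion handlers, together with the custom modeling/configuration files referenced via CustomModelingExportMixin (assumed here to be copied into the export directory), are what let test_checkpoint.py load the export with plain trust_remote_code=True. A hypothetical usage sketch (the checkpoint path is illustrative):

from transformers import AutoConfig, AutoModelForCausalLM

# Directory written by the Fast-LLM Hugging Face export (illustrative path).
export_dir = "/tmp/exports/apriel_ssm_thinker_hybrid"

# auto_map in config.json points at the shipped configuration/modeling modules,
# so trust_remote_code=True resolves AprielSSMHybridConfig and the causal-LM class.
config = AutoConfig.from_pretrained(export_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(export_dir, trust_remote_code=True)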