diff --git a/src/transformers/models/nemotron_h/modeling_nemotron_h.py b/src/transformers/models/nemotron_h/modeling_nemotron_h.py index 9e264e5cfdcc..a32c349ce18f 100644 --- a/src/transformers/models/nemotron_h/modeling_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modeling_nemotron_h.py @@ -179,34 +179,43 @@ def __init__(self, config: NemotronHConfig, layer_idx: int | None = None): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) global causal_conv1d_update, causal_conv1d_fn - causal_conv1d = lazy_load_kernel("causal-conv1d") - causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) - causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) - global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined - mamba_ssm = lazy_load_kernel("mamba-ssm") - selective_state_update = resolve_internal_import( - mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" - ) - mamba_chunk_scan_combined = resolve_internal_import( - mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined" - ) - mamba_split_conv1d_scan_combined = resolve_internal_import( - mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined" - ) - global is_fast_path_available - is_fast_path_available = all( - ( - selective_state_update, - mamba_chunk_scan_combined, - mamba_split_conv1d_scan_combined, - causal_conv1d_fn, - causal_conv1d_update, + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, is_fast_path_available + if config.use_mamba_kernels: + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" + ) + mamba_chunk_scan_combined = resolve_internal_import( + mamba_ssm, 
chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined" + ) + mamba_split_conv1d_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined" ) - ) - if not is_fast_path_available: + is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) + ) + else: + causal_conv1d_update = None + causal_conv1d_fn = None + selective_state_update = None + mamba_chunk_scan_combined = None + mamba_split_conv1d_scan_combined = None + is_fast_path_available = False + + if config.use_mamba_kernels and not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index 6d1af578e087..f64cd5968552 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -33,6 +33,8 @@ class Zamba2Config(PreTrainedConfig): Whether or not to use bias in the convolution layer of the mixer block. chunk_size (`int`, *optional*, defaults to 256): Size of the chunks that will comprise the sequence. + use_mamba_kernels (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use the fast mamba kernels. use_mem_eff_path (`bool`, *optional*, defaults to `False`): Whether or not to use the fused conv1d and scan in mamba2 layers. add_bias_linear (`bool`, *optional*, defaults to `False`): @@ -83,6 +85,7 @@ class Zamba2Config(PreTrainedConfig): time_step_floor: float = 1e-4 time_step_limit: list[float] | tuple[float, ...] 
| None = None n_mamba_heads: int = 8 + use_mamba_kernels: bool = True use_conv_bias: bool = True chunk_size: int = 256 use_mem_eff_path: bool = False diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 6e4ea7dcf2d8..54f20adebcdf 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -467,34 +467,43 @@ def __init__(self, config: Zamba2Config, layer_idx: int | None = None): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) global causal_conv1d_update, causal_conv1d_fn - causal_conv1d = lazy_load_kernel("causal-conv1d") - causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) - causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) - global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined - mamba_ssm = lazy_load_kernel("mamba-ssm") - selective_state_update = resolve_internal_import( - mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" - ) - mamba_chunk_scan_combined = resolve_internal_import( - mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined" - ) - mamba_split_conv1d_scan_combined = resolve_internal_import( - mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined" - ) - global is_fast_path_available - is_fast_path_available = all( - ( - selective_state_update, - mamba_chunk_scan_combined, - mamba_split_conv1d_scan_combined, - causal_conv1d_fn, - causal_conv1d_update, + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, is_fast_path_available + if config.use_mamba_kernels: + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update = resolve_internal_import( + mamba_ssm, 
chained_path="ops.triton.selective_state_update.selective_state_update" + ) + mamba_chunk_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined" + ) + mamba_split_conv1d_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined" ) - ) - if not is_fast_path_available: + is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) + ) + else: + causal_conv1d_update = None + causal_conv1d_fn = None + selective_state_update = None + mamba_chunk_scan_combined = None + mamba_split_conv1d_scan_combined = None + is_fast_path_available = False + + if config.use_mamba_kernels and not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" " is None. Falling back to the naive implementation. 
To install follow https://github.com/state-spaces/mamba/#installation and" diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index d7716301ad4a..4864c4dbff78 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -255,34 +255,43 @@ def __init__(self, config: Zamba2Config, layer_idx: int | None = None): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) global causal_conv1d_update, causal_conv1d_fn - causal_conv1d = lazy_load_kernel("causal-conv1d") - causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) - causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) - global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined - mamba_ssm = lazy_load_kernel("mamba-ssm") - selective_state_update = resolve_internal_import( - mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" - ) - mamba_chunk_scan_combined = resolve_internal_import( - mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined" - ) - mamba_split_conv1d_scan_combined = resolve_internal_import( - mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined" - ) - global is_fast_path_available - is_fast_path_available = all( - ( - selective_state_update, - mamba_chunk_scan_combined, - mamba_split_conv1d_scan_combined, - causal_conv1d_fn, - causal_conv1d_update, + global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, is_fast_path_available + if config.use_mamba_kernels: + causal_conv1d = lazy_load_kernel("causal-conv1d") + causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None) + causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None) + + mamba_ssm = lazy_load_kernel("mamba-ssm") + selective_state_update = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update" + ) + mamba_chunk_scan_combined = 
resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined" + ) + mamba_split_conv1d_scan_combined = resolve_internal_import( + mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined" ) - ) - if not is_fast_path_available: + is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) + ) + else: + causal_conv1d_update = None + causal_conv1d_fn = None + selective_state_update = None + mamba_chunk_scan_combined = None + mamba_split_conv1d_scan_combined = None + is_fast_path_available = False + + if config.use_mamba_kernels and not is_fast_path_available: logger.warning_once( "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"