Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 33 additions & 24 deletions src/transformers/models/nemotron_h/modeling_nemotron_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,34 +292,43 @@ def __init__(self, config: NemotronHConfig, layer_idx: int | None = None):
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)

global causal_conv1d_update, causal_conv1d_fn
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
)
mamba_chunk_scan_combined = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
)
mamba_split_conv1d_scan_combined = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
)

global is_fast_path_available
is_fast_path_available = all(
(
selective_state_update,
mamba_chunk_scan_combined,
mamba_split_conv1d_scan_combined,
causal_conv1d_fn,
causal_conv1d_update,

if getattr(config, "use_mamba_kernels", True):
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
)
mamba_chunk_scan_combined = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
)
mamba_split_conv1d_scan_combined = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
)
)

if not is_fast_path_available:
is_fast_path_available = all(
(
selective_state_update,
mamba_chunk_scan_combined,
mamba_split_conv1d_scan_combined,
causal_conv1d_fn,
causal_conv1d_update,
)
)
else:
causal_conv1d_update = None
causal_conv1d_fn = None
selective_state_update = None
mamba_chunk_scan_combined = None
mamba_split_conv1d_scan_combined = None
is_fast_path_available = False

if getattr(config, "use_mamba_kernels", True) and not is_fast_path_available:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
" is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
Expand Down
57 changes: 33 additions & 24 deletions src/transformers/models/zamba2/modeling_zamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,34 +567,43 @@ def __init__(self, config: Zamba2Config, layer_idx: int | None = None):
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear)

global causal_conv1d_update, causal_conv1d_fn
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
)
mamba_chunk_scan_combined = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
)
mamba_split_conv1d_scan_combined = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
)

global is_fast_path_available
is_fast_path_available = all(
(
selective_state_update,
mamba_chunk_scan_combined,
mamba_split_conv1d_scan_combined,
causal_conv1d_fn,
causal_conv1d_update,

if getattr(config, "use_mamba_kernels", True):
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
)
mamba_chunk_scan_combined = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
)
mamba_split_conv1d_scan_combined = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
)
)

if not is_fast_path_available:
is_fast_path_available = all(
(
selective_state_update,
mamba_chunk_scan_combined,
mamba_split_conv1d_scan_combined,
causal_conv1d_fn,
causal_conv1d_update,
)
)
else:
causal_conv1d_update = None
causal_conv1d_fn = None
selective_state_update = None
mamba_chunk_scan_combined = None
mamba_split_conv1d_scan_combined = None
is_fast_path_available = False

if getattr(config, "use_mamba_kernels", True) and not is_fast_path_available:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
" is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
Expand Down
57 changes: 33 additions & 24 deletions src/transformers/models/zamba2/modular_zamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,34 +320,43 @@ def __init__(self, config: Zamba2Config, layer_idx: int | None = None):
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear)

global causal_conv1d_update, causal_conv1d_fn
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
)
mamba_chunk_scan_combined = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
)
mamba_split_conv1d_scan_combined = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
)

global is_fast_path_available
is_fast_path_available = all(
(
selective_state_update,
mamba_chunk_scan_combined,
mamba_split_conv1d_scan_combined,
causal_conv1d_fn,
causal_conv1d_update,

if getattr(config, "use_mamba_kernels", True):
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
)
mamba_chunk_scan_combined = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
)
mamba_split_conv1d_scan_combined = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
)
)

if not is_fast_path_available:
is_fast_path_available = all(
(
selective_state_update,
mamba_chunk_scan_combined,
mamba_split_conv1d_scan_combined,
causal_conv1d_fn,
causal_conv1d_update,
)
)
else:
causal_conv1d_update = None
causal_conv1d_fn = None
selective_state_update = None
mamba_chunk_scan_combined = None
mamba_split_conv1d_scan_combined = None
is_fast_path_available = False

if getattr(config, "use_mamba_kernels", True) and not is_fast_path_available:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
" is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
Expand Down
Loading