From 31fa060b2099ac491fb3d4718f7d111f2c6f64e8 Mon Sep 17 00:00:00 2001 From: William Dykas Date: Fri, 24 Apr 2026 08:32:44 -0700 Subject: [PATCH 1/3] mamba: avoid redundant HBM reloads in causal_conv1d_update shift loop The original shift loop re-reads conv_state[1..WIDTH-1] from HBM on every sequence step, even though those same values are already in registers as x_val_0/1/2 from the earlier load. When state_len == WIDTH (the common Mamba configuration where the conv state depth equals the kernel width), skip the re-reads and store from the existing registers. The HAS_INT_STATE snapshot path benefits from the same reuse. state_len > WIDTH falls through to the original loop. Numerically bit-exact on conv_state; measured ~1.5% decode throughput improvement on nano-v3 at BS=1, OSL=256 (p50 245.21 -> 248.79 tok/s). Co-Authored-By: Claude Opus 4.7 (1M context) --- megatron/core/ssm/ops/causal_conv1d_triton.py | 88 +++++++++++++++---- 1 file changed, 73 insertions(+), 15 deletions(-) diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py index f57f5d94cea..f34f98b633d 100644 --- a/megatron/core/ssm/ops/causal_conv1d_triton.py +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -139,33 +139,91 @@ def causal_conv1d_update_kernel( conv_state_ptrs + (state_len - WIDTH + 3) * conv_state_l_stride, mask=mask ).to(tl.float32) - # Shift the linear state buffer left by 1 - i = 0 - while i < state_len - 1: - val = tl.load(conv_state_ptrs + (i + 1) * conv_state_l_stride, mask=mask) - tl.store(conv_state_ptrs + i * conv_state_l_stride, val, mask=mask) - i += 1 - # Process the single token for the current sequence step x_val = tl.load(x_ptrs, mask=mask) + # Shift the linear state buffer left by 1. When state_len == WIDTH (the + # common case: conv_state dim == conv kernel width) the shifted values + # are already resident in the x_val_* registers from the loads above, so + # we can write them back without a second HBM read per position. For + # state_len > WIDTH the leading positions are untouched by compute; fall + # through to the explicit load+store shift. + if state_len == WIDTH: + out_dtype = conv_state_ptrs.dtype.element_ty + if WIDTH >= 2: + tl.store( + conv_state_ptrs + 0 * conv_state_l_stride, + x_val_0.to(out_dtype), + mask=mask, + ) + if WIDTH >= 3: + tl.store( + conv_state_ptrs + 1 * conv_state_l_stride, + x_val_1.to(out_dtype), + mask=mask, + ) + if WIDTH >= 4: + tl.store( + conv_state_ptrs + 2 * conv_state_l_stride, + x_val_2.to(out_dtype), + mask=mask, + ) + else: + i = 0 + while i < state_len - 1: + val = tl.load(conv_state_ptrs + (i + 1) * conv_state_l_stride, mask=mask) + tl.store(conv_state_ptrs + i * conv_state_l_stride, val, mask=mask) + i += 1 + # Store the new token at the end of the linear state buffer tl.store(conv_state_ptrs + (state_len - 1) * conv_state_l_stride, x_val, mask=mask) - # Write out to the intermediate state buffer if requested + # Write out to the intermediate state buffer if requested. Reuse the + # register values from the shift above (and x_val for the new tail + # position) when state_len == WIDTH, instead of re-reading from HBM. if HAS_INT_STATE: - i = 0 - while i < state_len: - val = tl.load(conv_state_ptrs + i * conv_state_l_stride, mask=mask) - int_ptr = ( + if state_len == WIDTH: + int_base = ( int_state_ptr + state_batch_coord * int_state_b_stride + s * int_state_s_stride + channel_offsets * int_state_c_stride - + i * int_state_l_stride ) - tl.store(int_ptr, val, mask=mask) - i += 1 + out_dtype = int_base.dtype.element_ty + if WIDTH >= 2: + tl.store( + int_base + 0 * int_state_l_stride, + x_val_0.to(out_dtype), + mask=mask, + ) + if WIDTH >= 3: + tl.store( + int_base + 1 * int_state_l_stride, + x_val_1.to(out_dtype), + mask=mask, + ) + if WIDTH >= 4: + tl.store( + int_base + 2 * int_state_l_stride, + x_val_2.to(out_dtype), + mask=mask, + ) + tl.store( + int_base + (state_len - 1) * int_state_l_stride, x_val, mask=mask + ) + else: + i = 0 + while i < state_len: + val = tl.load(conv_state_ptrs + i * conv_state_l_stride, mask=mask) + int_ptr = ( + int_state_ptr + + state_batch_coord * int_state_b_stride + + s * int_state_s_stride + + channel_offsets * int_state_c_stride + + i * int_state_l_stride + ) + tl.store(int_ptr, val, mask=mask) + i += 1 # Advance registers for calculation x_val_f32 = x_val.to(tl.float32) From e5c3fdc01b22f722a818884ee69a8f9445d9e37c Mon Sep 17 00:00:00 2001 From: root Date: Fri, 24 Apr 2026 09:37:02 -0700 Subject: [PATCH 2/3] lint --- megatron/core/ssm/ops/causal_conv1d_triton.py | 34 ++++--------------- 1 file changed, 7 insertions(+), 27 deletions(-) diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py index f34f98b633d..4ff7ddcf933 100644 --- a/megatron/core/ssm/ops/causal_conv1d_triton.py +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -152,21 +152,15 @@ def causal_conv1d_update_kernel( out_dtype = conv_state_ptrs.dtype.element_ty if WIDTH >= 2: tl.store( - conv_state_ptrs + 0 * conv_state_l_stride, - x_val_0.to(out_dtype), - mask=mask, + conv_state_ptrs + 0 * conv_state_l_stride, x_val_0.to(out_dtype), mask=mask ) if WIDTH >= 3: tl.store( - conv_state_ptrs + 1 * conv_state_l_stride, - x_val_1.to(out_dtype), - mask=mask, + conv_state_ptrs + 1 * conv_state_l_stride, x_val_1.to(out_dtype), mask=mask ) if WIDTH >= 4: tl.store( - conv_state_ptrs + 2 * conv_state_l_stride, - x_val_2.to(out_dtype), - mask=mask, + conv_state_ptrs + 2 * conv_state_l_stride, x_val_2.to(out_dtype), mask=mask ) else: i = 0 @@ -191,26 +185,12 @@ def causal_conv1d_update_kernel( ) out_dtype = int_base.dtype.element_ty if WIDTH >= 2: - tl.store( - int_base + 0 * int_state_l_stride, - x_val_0.to(out_dtype), - mask=mask, - ) + tl.store(int_base + 0 * int_state_l_stride, x_val_0.to(out_dtype), mask=mask) if WIDTH >= 3: - tl.store( - int_base + 1 * int_state_l_stride, - x_val_1.to(out_dtype), - mask=mask, - ) + tl.store(int_base + 1 * int_state_l_stride, x_val_1.to(out_dtype), mask=mask) if WIDTH >= 4: - tl.store( - int_base + 2 * int_state_l_stride, - x_val_2.to(out_dtype), - mask=mask, - ) - tl.store( - int_base + (state_len - 1) * int_state_l_stride, x_val, mask=mask - ) + tl.store(int_base + 2 * int_state_l_stride, x_val_2.to(out_dtype), mask=mask) + tl.store(int_base + (state_len - 1) * int_state_l_stride, x_val, mask=mask) else: i = 0 while i < state_len: From 2082ea7fb90e6630bea05c870c7ff09f0dcb69c5 Mon Sep 17 00:00:00 2001 From: William Dykas Date: Tue, 28 Apr 2026 06:58:01 -0700 Subject: [PATCH 3/3] ADD TEST --- megatron/core/ssm/ops/causal_conv1d_triton.py | 34 +++----------- .../ssm/test_causal_conv1d_triton.py | 44 +++++++++++++++++++ 2 files changed, 51 insertions(+), 27 deletions(-) diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py index f34f98b633d..4ff7ddcf933 100644 --- a/megatron/core/ssm/ops/causal_conv1d_triton.py +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -152,21 +152,15 @@ def causal_conv1d_update_kernel( out_dtype = conv_state_ptrs.dtype.element_ty if WIDTH >= 2: tl.store( - conv_state_ptrs + 0 * conv_state_l_stride, - x_val_0.to(out_dtype), - mask=mask, + conv_state_ptrs + 0 * conv_state_l_stride, x_val_0.to(out_dtype), mask=mask ) if WIDTH >= 3: tl.store( - conv_state_ptrs + 1 * conv_state_l_stride, - x_val_1.to(out_dtype), - mask=mask, + conv_state_ptrs + 1 * conv_state_l_stride, x_val_1.to(out_dtype), mask=mask ) if WIDTH >= 4: tl.store( - conv_state_ptrs + 2 * conv_state_l_stride, - x_val_2.to(out_dtype), - mask=mask, + conv_state_ptrs + 2 * conv_state_l_stride, x_val_2.to(out_dtype), mask=mask ) else: i = 0 @@ -191,26 +185,12 @@ def causal_conv1d_update_kernel( ) out_dtype = int_base.dtype.element_ty if WIDTH >= 2: - tl.store( - int_base + 0 * int_state_l_stride, - x_val_0.to(out_dtype), - mask=mask, - ) + tl.store(int_base + 0 * int_state_l_stride, x_val_0.to(out_dtype), mask=mask) if WIDTH >= 3: - tl.store( - int_base + 1 * int_state_l_stride, - x_val_1.to(out_dtype), - mask=mask, - ) + tl.store(int_base + 1 * int_state_l_stride, x_val_1.to(out_dtype), mask=mask) if WIDTH >= 4: - tl.store( - int_base + 2 * int_state_l_stride, - x_val_2.to(out_dtype), - mask=mask, - ) - tl.store( - int_base + (state_len - 1) * int_state_l_stride, x_val, mask=mask - ) + tl.store(int_base + 2 * int_state_l_stride, x_val_2.to(out_dtype), mask=mask) + tl.store(int_base + (state_len - 1) * int_state_l_stride, x_val, mask=mask) else: i = 0 while i < state_len: diff --git a/tests/unit_tests/ssm/test_causal_conv1d_triton.py b/tests/unit_tests/ssm/test_causal_conv1d_triton.py index 3015f5ed989..624cd8c048b 100644 --- a/tests/unit_tests/ssm/test_causal_conv1d_triton.py +++ b/tests/unit_tests/ssm/test_causal_conv1d_triton.py @@ -221,6 +221,50 @@ def test_intermediate_state(self, width): conv_state_ref[:, :, -1] = x[:, s, :] torch.testing.assert_close(int_states[:, s, :, :], conv_state_ref, atol=1e-5, rtol=1e-5) + @pytest.mark.parametrize("width", [2, 3, 4]) + def test_state_len_eq_width_fast_path(self, width): + """Cover the ``state_len == WIDTH`` fast path (the common Mamba + configuration where d_conv == width). + + The other tests use ``state_len = 8`` so they always fall through to + the explicit shift loop. Here ``state_len = width`` exercises the + register-resident shift and the matching ``HAS_INT_STATE`` branch. + """ + torch.manual_seed(42) + B, seq_len, D = 2, 4, 64 + state_len = width + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state_initial = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + int_states = torch.zeros(B, seq_len, D, state_len, device="cuda", dtype=torch.float32) + + conv_state_triton = conv_state_initial.clone() + conv_state_ref = conv_state_initial.clone() + + result = causal_conv1d_update( + x, + conv_state_triton, + weight, + bias=None, + silu_activation=False, + conv_state_indices=None, + intermediate_conv_states=int_states, + ) + expected = causal_conv1d_update_ref( + x, conv_state_ref, weight, bias=None, silu_activation=False + ) + + # Output and final state match the reference. + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(conv_state_triton, conv_state_ref, atol=1e-5, rtol=1e-5) + + # Per-step intermediate states match a manual replay. + replay_state = conv_state_initial.clone() + for s in range(seq_len): + replay_state[:, :, :-1] = replay_state[:, :, 1:].clone() + replay_state[:, :, -1] = x[:, s, :] + torch.testing.assert_close(int_states[:, s, :, :], replay_state, atol=1e-5, rtol=1e-5) + def test_intermediate_state_with_indices(self): """Test intermediate states work correctly with conv_state_indices mapping.""" torch.manual_seed(42)