diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py
index f57f5d94cea..4ff7ddcf933 100644
--- a/megatron/core/ssm/ops/causal_conv1d_triton.py
+++ b/megatron/core/ssm/ops/causal_conv1d_triton.py
@@ -139,33 +139,71 @@ def causal_conv1d_update_kernel(
         conv_state_ptrs + (state_len - WIDTH + 3) * conv_state_l_stride, mask=mask
     ).to(tl.float32)
 
-    # Shift the linear state buffer left by 1
-    i = 0
-    while i < state_len - 1:
-        val = tl.load(conv_state_ptrs + (i + 1) * conv_state_l_stride, mask=mask)
-        tl.store(conv_state_ptrs + i * conv_state_l_stride, val, mask=mask)
-        i += 1
-
     # Process the single token for the current sequence step
     x_val = tl.load(x_ptrs, mask=mask)
 
+    # Shift the linear state buffer left by 1. When state_len == WIDTH (the
+    # common case: conv_state dim == conv kernel width) the shifted values
+    # are already resident in the x_val_* registers from the loads above, so
+    # we can write them back without a second HBM read per position. For
+    # state_len > WIDTH the leading positions are untouched by compute; fall
+    # through to the explicit load+store shift.
+    if state_len == WIDTH:
+        out_dtype = conv_state_ptrs.dtype.element_ty
+        if WIDTH >= 2:
+            tl.store(
+                conv_state_ptrs + 0 * conv_state_l_stride, x_val_0.to(out_dtype), mask=mask
+            )
+        if WIDTH >= 3:
+            tl.store(
+                conv_state_ptrs + 1 * conv_state_l_stride, x_val_1.to(out_dtype), mask=mask
+            )
+        if WIDTH >= 4:
+            tl.store(
+                conv_state_ptrs + 2 * conv_state_l_stride, x_val_2.to(out_dtype), mask=mask
+            )
+    else:
+        i = 0
+        while i < state_len - 1:
+            val = tl.load(conv_state_ptrs + (i + 1) * conv_state_l_stride, mask=mask)
+            tl.store(conv_state_ptrs + i * conv_state_l_stride, val, mask=mask)
+            i += 1
+
     # Store the new token at the end of the linear state buffer
     tl.store(conv_state_ptrs + (state_len - 1) * conv_state_l_stride, x_val, mask=mask)
 
-    # Write out to the intermediate state buffer if requested
+    # Write out to the intermediate state buffer if requested. Reuse the
+    # register values from the shift above (and x_val for the new tail
+    # position) when state_len == WIDTH, instead of re-reading from HBM.
     if HAS_INT_STATE:
-        i = 0
-        while i < state_len:
-            val = tl.load(conv_state_ptrs + i * conv_state_l_stride, mask=mask)
-            int_ptr = (
+        if state_len == WIDTH:
+            int_base = (
                 int_state_ptr
                 + state_batch_coord * int_state_b_stride
                 + s * int_state_s_stride
                 + channel_offsets * int_state_c_stride
-                + i * int_state_l_stride
             )
-            tl.store(int_ptr, val, mask=mask)
-            i += 1
+            out_dtype = int_base.dtype.element_ty
+            if WIDTH >= 2:
+                tl.store(int_base + 0 * int_state_l_stride, x_val_0.to(out_dtype), mask=mask)
+            if WIDTH >= 3:
+                tl.store(int_base + 1 * int_state_l_stride, x_val_1.to(out_dtype), mask=mask)
+            if WIDTH >= 4:
+                tl.store(int_base + 2 * int_state_l_stride, x_val_2.to(out_dtype), mask=mask)
+            tl.store(int_base + (state_len - 1) * int_state_l_stride, x_val, mask=mask)
+        else:
+            i = 0
+            while i < state_len:
+                val = tl.load(conv_state_ptrs + i * conv_state_l_stride, mask=mask)
+                int_ptr = (
+                    int_state_ptr
+                    + state_batch_coord * int_state_b_stride
+                    + s * int_state_s_stride
+                    + channel_offsets * int_state_c_stride
+                    + i * int_state_l_stride
+                )
+                tl.store(int_ptr, val, mask=mask)
+                i += 1
 
     # Advance registers for calculation
     x_val_f32 = x_val.to(tl.float32)
diff --git a/tests/unit_tests/ssm/test_causal_conv1d_triton.py b/tests/unit_tests/ssm/test_causal_conv1d_triton.py
index 3015f5ed989..624cd8c048b 100644
--- a/tests/unit_tests/ssm/test_causal_conv1d_triton.py
+++ b/tests/unit_tests/ssm/test_causal_conv1d_triton.py
@@ -221,6 +221,50 @@ def test_intermediate_state(self, width):
             conv_state_ref[:, :, -1] = x[:, s, :]
             torch.testing.assert_close(int_states[:, s, :, :], conv_state_ref, atol=1e-5, rtol=1e-5)
 
+    @pytest.mark.parametrize("width", [2, 3, 4])
+    def test_state_len_eq_width_fast_path(self, width):
+        """Cover the ``state_len == WIDTH`` fast path (the common Mamba
+        configuration where d_conv == width).
+
+        The other tests use ``state_len = 8`` so they always fall through to
+        the explicit shift loop. Here ``state_len = width`` exercises the
+        register-resident shift and the matching ``HAS_INT_STATE`` branch.
+        """
+        torch.manual_seed(42)
+        B, seq_len, D = 2, 4, 64
+        state_len = width
+        x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32)
+        conv_state_initial = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32)
+        weight = torch.randn(D, width, device="cuda", dtype=torch.float32)
+        int_states = torch.zeros(B, seq_len, D, state_len, device="cuda", dtype=torch.float32)
+
+        conv_state_triton = conv_state_initial.clone()
+        conv_state_ref = conv_state_initial.clone()
+
+        result = causal_conv1d_update(
+            x,
+            conv_state_triton,
+            weight,
+            bias=None,
+            silu_activation=False,
+            conv_state_indices=None,
+            intermediate_conv_states=int_states,
+        )
+        expected = causal_conv1d_update_ref(
+            x, conv_state_ref, weight, bias=None, silu_activation=False
+        )
+
+        # Output and final state match the reference.
+        torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)
+        torch.testing.assert_close(conv_state_triton, conv_state_ref, atol=1e-5, rtol=1e-5)
+
+        # Per-step intermediate states match a manual replay.
+        replay_state = conv_state_initial.clone()
+        for s in range(seq_len):
+            replay_state[:, :, :-1] = replay_state[:, :, 1:].clone()
+            replay_state[:, :, -1] = x[:, s, :]
+            torch.testing.assert_close(int_states[:, s, :, :], replay_state, atol=1e-5, rtol=1e-5)
+
     def test_intermediate_state_with_indices(self):
         """Test intermediate states work correctly with conv_state_indices mapping."""
         torch.manual_seed(42)
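
For orientation, below is a minimal PyTorch sketch of the single-step semantics both branches of the kernel implement (and that the test's manual replay checks): shift the conv state left by one, append the incoming token, and reduce the trailing `width` positions against the conv weights. The helper name `conv1d_update_step` and its signature are illustrative assumptions, not the repo's actual `causal_conv1d_update_ref`; the patch's fast path computes the same thing but, when `state_len == WIDTH`, writes the shifted values back from registers instead of re-reading the state from HBM.

```python
import torch
import torch.nn.functional as F


def conv1d_update_step(x, conv_state, weight, bias=None, silu_activation=False):
    """One decode step. x: (B, D); conv_state: (B, D, state_len); weight: (D, width)."""
    width = weight.shape[1]
    # Shift the rolling buffer left by one and append the new token at the
    # tail -- the same update the kernel performs in-place on conv_state.
    conv_state[:, :, :-1] = conv_state[:, :, 1:].clone()
    conv_state[:, :, -1] = x
    # Only the trailing `width` positions of the state contribute to the output.
    out = (conv_state[:, :, -width:] * weight.unsqueeze(0)).sum(dim=-1)
    if bias is not None:
        out = out + bias
    return F.silu(out) if silu_activation else out
```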