Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 53 additions & 15 deletions megatron/core/ssm/ops/causal_conv1d_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,33 +139,71 @@ def causal_conv1d_update_kernel(
conv_state_ptrs + (state_len - WIDTH + 3) * conv_state_l_stride, mask=mask
).to(tl.float32)

# Shift the linear state buffer left by 1
i = 0
while i < state_len - 1:
val = tl.load(conv_state_ptrs + (i + 1) * conv_state_l_stride, mask=mask)
tl.store(conv_state_ptrs + i * conv_state_l_stride, val, mask=mask)
i += 1

# Process the single token for the current sequence step
x_val = tl.load(x_ptrs, mask=mask)

# Shift the linear state buffer left by 1. When state_len == WIDTH (the
# common case: conv_state dim == conv kernel width) the shifted values
# are already resident in the x_val_* registers from the loads above, so
# we can write them back without a second HBM read per position. For
# state_len > WIDTH the leading positions are untouched by compute; fall
# through to the explicit load+store shift.
if state_len == WIDTH:
Comment thread
wdykas marked this conversation as resolved.
out_dtype = conv_state_ptrs.dtype.element_ty
if WIDTH >= 2:
tl.store(
conv_state_ptrs + 0 * conv_state_l_stride, x_val_0.to(out_dtype), mask=mask
)
if WIDTH >= 3:
tl.store(
conv_state_ptrs + 1 * conv_state_l_stride, x_val_1.to(out_dtype), mask=mask
)
if WIDTH >= 4:
tl.store(
conv_state_ptrs + 2 * conv_state_l_stride, x_val_2.to(out_dtype), mask=mask
)
else:
i = 0
while i < state_len - 1:
val = tl.load(conv_state_ptrs + (i + 1) * conv_state_l_stride, mask=mask)
tl.store(conv_state_ptrs + i * conv_state_l_stride, val, mask=mask)
i += 1

# Store the new token at the end of the linear state buffer
tl.store(conv_state_ptrs + (state_len - 1) * conv_state_l_stride, x_val, mask=mask)

# Write out to the intermediate state buffer if requested
# Write out to the intermediate state buffer if requested. Reuse the
# register values from the shift above (and x_val for the new tail
# position) when state_len == WIDTH, instead of re-reading from HBM.
if HAS_INT_STATE:
i = 0
while i < state_len:
val = tl.load(conv_state_ptrs + i * conv_state_l_stride, mask=mask)
int_ptr = (
if state_len == WIDTH:
int_base = (
int_state_ptr
+ state_batch_coord * int_state_b_stride
+ s * int_state_s_stride
+ channel_offsets * int_state_c_stride
+ i * int_state_l_stride
)
tl.store(int_ptr, val, mask=mask)
i += 1
out_dtype = int_base.dtype.element_ty
if WIDTH >= 2:
tl.store(int_base + 0 * int_state_l_stride, x_val_0.to(out_dtype), mask=mask)
if WIDTH >= 3:
tl.store(int_base + 1 * int_state_l_stride, x_val_1.to(out_dtype), mask=mask)
if WIDTH >= 4:
tl.store(int_base + 2 * int_state_l_stride, x_val_2.to(out_dtype), mask=mask)
tl.store(int_base + (state_len - 1) * int_state_l_stride, x_val, mask=mask)
else:
i = 0
while i < state_len:
val = tl.load(conv_state_ptrs + i * conv_state_l_stride, mask=mask)
int_ptr = (
int_state_ptr
+ state_batch_coord * int_state_b_stride
+ s * int_state_s_stride
+ channel_offsets * int_state_c_stride
+ i * int_state_l_stride
)
tl.store(int_ptr, val, mask=mask)
i += 1

# Advance registers for calculation
x_val_f32 = x_val.to(tl.float32)
Expand Down
44 changes: 44 additions & 0 deletions tests/unit_tests/ssm/test_causal_conv1d_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,50 @@ def test_intermediate_state(self, width):
conv_state_ref[:, :, -1] = x[:, s, :]
torch.testing.assert_close(int_states[:, s, :, :], conv_state_ref, atol=1e-5, rtol=1e-5)

@pytest.mark.parametrize("width", [2, 3, 4])
def test_state_len_eq_width_fast_path(self, width):
    """Cover the ``state_len == WIDTH`` fast path (the common Mamba
    configuration where d_conv == width).

    The other tests use ``state_len = 8`` so they always fall through to
    the explicit shift loop. Here ``state_len = width`` exercises the
    register-resident shift and the matching ``HAS_INT_STATE`` branch.
    """
    torch.manual_seed(42)
    batch, seq_len, channels = 2, 4, 64
    # state_len == width is exactly the condition that selects the fast path.
    state_len = width

    x = torch.randn(batch, seq_len, channels, device="cuda", dtype=torch.float32)
    conv_state_initial = torch.randn(
        batch, channels, state_len, device="cuda", dtype=torch.float32
    )
    weight = torch.randn(channels, width, device="cuda", dtype=torch.float32)
    int_states = torch.zeros(
        batch, seq_len, channels, state_len, device="cuda", dtype=torch.float32
    )

    # Independent copies so the kernel and the reference each evolve their own state.
    state_for_triton = conv_state_initial.clone()
    state_for_ref = conv_state_initial.clone()

    result = causal_conv1d_update(
        x,
        state_for_triton,
        weight,
        bias=None,
        silu_activation=False,
        conv_state_indices=None,
        intermediate_conv_states=int_states,
    )
    expected = causal_conv1d_update_ref(
        x, state_for_ref, weight, bias=None, silu_activation=False
    )

    # Output and final state match the reference.
    torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)
    torch.testing.assert_close(state_for_triton, state_for_ref, atol=1e-5, rtol=1e-5)

    # Replay the shift-and-append update one token at a time and check the
    # recorded intermediate state at every step.
    replay = conv_state_initial.clone()
    for step in range(seq_len):
        replay[:, :, :-1] = replay[:, :, 1:].clone()
        replay[:, :, -1] = x[:, step, :]
        torch.testing.assert_close(int_states[:, step, :, :], replay, atol=1e-5, rtol=1e-5)

def test_intermediate_state_with_indices(self):
"""Test intermediate states work correctly with conv_state_indices mapping."""
torch.manual_seed(42)
Expand Down
Loading