Merged
Commits
41 commits
66c6627
add max_score for fused/unfused F16 non-CP
cyanguwa Sep 16, 2025
8f9155f
calculate max per head instead of max over all heads
cyanguwa Sep 30, 2025
efaf827
fix fused attn max_score shape
cyanguwa Oct 8, 2025
290dfb9
revert FE to github
cyanguwa Oct 10, 2025
b1f300e
Merge branch 'main' into add_muon
cyanguwa Oct 10, 2025
c93518b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 10, 2025
63a7f79
update FE to 1.15.0-rc
cyanguwa Oct 10, 2025
78d5426
fix merge
cyanguwa Oct 10, 2025
437219b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 10, 2025
2adb1f2
reduce ew kernels; fix causal masks; add more tests
cyanguwa Oct 12, 2025
bc7d6b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 12, 2025
7946127
minor fix to tests
cyanguwa Oct 12, 2025
f984602
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 12, 2025
cb01843
remove logic for flash-attn
cyanguwa Oct 12, 2025
966b657
WIP: add CP support for p2p/a2a/all_gather
cyanguwa Oct 12, 2025
9d614a8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 12, 2025
8e62fa6
minor improvements of implementation/tests
cyanguwa Oct 15, 2025
69b7ae8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2025
6e40f17
Merge branch 'main' into add_muon
cyanguwa Oct 15, 2025
326a54c
Merge branch 'main' into add_muon
cyanguwa Oct 19, 2025
84a67b3
WIP: add thd support
cyanguwa Oct 20, 2025
1b64526
add thd to UnfusedDPA
cyanguwa Oct 21, 2025
c8b3bea
fix lint
cyanguwa Oct 21, 2025
6f1c9f0
more fixes for lint
cyanguwa Oct 21, 2025
494d048
Merge branch 'main' into add_muon
cyanguwa Oct 21, 2025
442d1fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2025
d6055ac
Merge branch 'main' into add_muon
cyanguwa Oct 22, 2025
5bc407a
update to FE 1.15
cyanguwa Oct 22, 2025
aee07fa
remove unneeded changes
cyanguwa Oct 22, 2025
5703783
disable unfused for thd + pad_between_seqs
cyanguwa Oct 22, 2025
e84a843
minor fixes
cyanguwa Oct 22, 2025
f92ed68
disable thd for unfused until bug is fixed
cyanguwa Oct 22, 2025
cde5f1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 22, 2025
795c4ac
Merge branch 'main' into add_muon
cyanguwa Oct 23, 2025
9b0af47
fix all_gather
cyanguwa Oct 23, 2025
a0ccabb
fix all gather
cyanguwa Oct 23, 2025
2872dd3
rename max_score to max_logit
cyanguwa Oct 23, 2025
200d98f
fix all_gather
cyanguwa Oct 23, 2025
4aadbc6
fix all_gather
cyanguwa Oct 23, 2025
03f1734
Merge branch 'main' into add_muon
cyanguwa Oct 24, 2025
eb580f3
disable fused attn + thd
cyanguwa Oct 24, 2025
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 108 files
15 changes: 11 additions & 4 deletions tests/pytorch/attention/run_attention_with_cp.py
@@ -248,6 +248,7 @@ def run_dpa_with_cp(
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
softmax_type=config.softmax_type,
return_max_score=config.return_max_score,
).cuda()
if config.softmax_type != "vanilla":
core_attn.softmax_offset.requires_grad = True
@@ -308,6 +309,7 @@ def run_dpa_with_cp(
fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group)
else:
fp8_context = nullcontext()
max_score = None
with fp8_context:
# q, k, v, out in FP8; dout in F16
out = core_attn(
@@ -322,6 +324,8 @@
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
fp8_output=fp8_mha,
)
if config.return_max_score:
out, max_score = out
if fp8_bwd and fp8_mha:
dout_fp8 = dout_quantizer(dout)
out.backward(dout_fp8)
@@ -400,6 +404,7 @@ def run_dpa_with_cp(
fp8_context = nullcontext()

# run attention
max_score_ = None
with fp8_context:
# q, k, v, out in FP8; dout in F16
out_ = core_attn(
@@ -414,6 +419,8 @@
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
fp8_output=fp8_mha,
)
if config.return_max_score:
out_, max_score_ = out_
if fp8_bwd and fp8_mha:
dout_fp8_ = dout_quantizer(dout_)
out_.backward(dout_fp8_)
@@ -495,15 +502,15 @@ def run_dpa_with_cp(
)

atol, rtol, rmse_tol = get_tols(config, dtype)
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset]
names = ["out", "dq", "dk", "dv", "d_softmax_offset"]
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_score_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_score]
names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_score"]
names_cp = [x + "_cp" for x in names]
names_no_cp = [x + "_no_cp" for x in names]
is_fp8 = dtype == "fp8"
for i, t in enumerate(tensors_no_cp):
if t is not None:
if "softmax_offset" not in names[i]:
if "softmax_offset" not in names[i] and "max_score" not in names[i]:
if qkv_format == "bshd":
compare_and_assert(
t[:, 0],
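
The max_score compared above is, going by the commit history ("calculate max per head instead of max over all heads", "rename max_score to max_logit"), a per-head maximum of the pre-softmax attention logits. A minimal reference sketch of that quantity for bshd tensors — an illustration of the semantics, not code from this PR — could look like:

import math
import torch

def unfused_attention_with_max_logit(q, k, v, causal=False):
    """q, k, v: [batch, seqlen, heads, head_dim] (bshd)."""
    b, sq, h, d = q.shape
    sk = k.shape[1]
    # [b, h, sq, sk] pre-softmax attention logits
    logits = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) / math.sqrt(d)
    if causal:
        # simple causal mask, assuming sq == sk
        mask = torch.ones(sq, sk, dtype=torch.bool, device=q.device).tril()
        logits = logits.masked_fill(~mask, float("-inf"))
    # per-head maximum logit, reduced over batch and both sequence dims -> [h]
    max_logit = logits.amax(dim=(0, 2, 3))
    probs = torch.softmax(logits, dim=-1).to(v.dtype)
    out = torch.einsum("bhqk,bkhd->bqhd", probs, v)
    return out, max_logit
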
68 changes: 57 additions & 11 deletions tests/pytorch/attention/test_attention.py
@@ -131,6 +131,11 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
config.attn_mask_type = (
"padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
)

# Get backends
is_training = True
@@ -172,7 +177,7 @@

# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
unfused_attn_fwd, unfused_max_score, unfused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"UnfusedDotProductAttention",
@@ -186,7 +191,7 @@
# FusedAttention backend
if fused_attn_supported:
if len(fused_attn_backends) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
fused_attn_fwd, fused_max_score, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
@@ -198,7 +203,7 @@
)
if len(fused_attn_backends) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
@@ -209,7 +214,7 @@
is_training,
)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
@@ -222,7 +227,7 @@

# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FlashAttention",
@@ -243,6 +248,8 @@ def test_dot_product_attention(
if unfused_attn_supported and fused_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
if config.return_max_score:
torch.testing.assert_close(fused_max_score, unfused_max_score, **tols)
for i, _ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported:
@@ -266,6 +273,33 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_max_score = {
# test: ModelConfig(b, sq, hq, dqk)
"max_score_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
"max_score_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
"max_score_3": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
"max_score_4": ModelConfig(
8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"
),
"max_score_5": ModelConfig(
8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0)
),
"max_score_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_max_score])
@pytest.mark.parametrize("model", model_configs_max_score.keys())
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
def test_dpa_max_score(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with checkpointing"""
config = model_configs[model]
config.return_max_score = True
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)


model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
@@ -962,6 +996,8 @@ def _run_dot_product_attention(
layout = layout.replace("d", "dqk")
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
# tensor: with padding tokens
# tensor_orig: without padding tokens
tensor_orig = tensor
if qkv_format == "thd" and pad_between_seqs:
tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
@@ -1071,6 +1107,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
layer_number=1,
attention_type=config.attn_type,
softmax_type=config.softmax_type,
return_max_score=config.return_max_score,
).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()
@@ -1108,16 +1145,21 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
)
max_score = None
if config.return_max_score:
out, max_score = out
if is_training:
out.backward(d_out)

d_softmax_offset = None
if is_training and config.softmax_type != "vanilla":
d_softmax_offset = block.softmax_offset.grad

if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if is_training:
return out, (q.grad, k.grad, v.grad, d_softmax_offset)
return out, max_score, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, (None, None, None, d_softmax_offset)
return out, max_score, (None, None, None, d_softmax_offset)
if backend == "FusedAttention":
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
@@ -1146,14 +1188,18 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset)
return (
out_orig,
max_score,
(q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset),
)
else:
return out_orig, (None, None, None, d_softmax_offset)
return out_orig, max_score, (None, None, None, d_softmax_offset)
else:
if is_training:
return out, (q.grad, k.grad, v.grad, d_softmax_offset)
return out, max_score, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, (None, None, None, d_softmax_offset)
return out, max_score, (None, None, None, d_softmax_offset)


model_configs_te_layer = {
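
test_dpa_max_score above drives this path through DotProductAttention. A minimal usage sketch of that interface follows; the module and its other constructor arguments already exist in TE, but return_max_score and the (out, max_score) return tuple are the behavior added by this PR, so treat the exact call as an assumption modeled on the tests above:

import torch
from transformer_engine.pytorch import DotProductAttention

attn = DotProductAttention(
    num_attention_heads=24,
    kv_channels=128,
    attn_mask_type="causal",
    return_max_score=True,  # flag added in this PR
).cuda()

# Default qkv_format is "sbhd": [seq, batch, heads, head_dim]
q = torch.randn(2048, 2, 24, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
k = torch.randn(2048, 2, 24, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
v = torch.randn(2048, 2, 24, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# With return_max_score=True the forward returns (out, max_score),
# mirroring the `out, max_score = out` unpacking in the tests above.
out, max_score = attn(q, k, v)
out.backward(torch.ones_like(out))
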
6 changes: 3 additions & 3 deletions tests/pytorch/attention/test_attention_with_cp.py
@@ -137,8 +137,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):

model_configs_fused_attn = {
# test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_score=True), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_score=True), # MHA
"cp_1_2": ModelConfig(
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
), # MHA
@@ -183,7 +183,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
dtypes = ["bf16", "fp8"]
qkv_formats = ["sbhd", "thd"]
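
With return_max_score=True now exercised under context parallelism (cp_1_0, cp_1_1), each CP rank only sees a chunk of the KV sequence. As a conceptual sketch — an assumption about the semantics, not the PR's p2p/a2a/all_gather implementation — the per-head max logit from each rank can be combined with a MAX all-reduce over the CP group:

import torch
import torch.distributed as dist

def reduce_max_logit_across_cp(local_max_logit: torch.Tensor, cp_group) -> torch.Tensor:
    """local_max_logit: [num_heads], max attention logit over this rank's KV chunk."""
    global_max = local_max_logit.clone()
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=cp_group)
    return global_max
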
3 changes: 3 additions & 0 deletions tests/pytorch/utils.py
@@ -205,6 +205,7 @@ def __init__(
window_size: Tuple[int, int] = (-1, -1),
context_parallel: bool = False,
cp_comm_type: str = "p2p",
return_max_score: bool = False,
total_requests: int = None,
max_ctx_len: int = None,
num_layers: int = 1,
@@ -233,6 +234,7 @@ def __init__(
self.window_size = check_set_window_size(self.attn_mask_type, window_size)
self.context_parallel = context_parallel
self.cp_comm_type = cp_comm_type
self.return_max_score = return_max_score
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
@@ -318,6 +320,7 @@ def test():
is_training=is_training,
inference_params=inference_params,
softmax_type=config.softmax_type,
return_max_score=config.return_max_score,
)
(
use_flash_attention,