2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 108 files
15 changes: 11 additions & 4 deletions tests/pytorch/attention/run_attention_with_cp.py
@@ -249,6 +249,7 @@ def run_dpa_with_cp(
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
softmax_type=config.softmax_type,
return_max_score=config.return_max_score,
).cuda()
if config.softmax_type != "vanilla":
core_attn.softmax_offset.requires_grad = True
@@ -309,6 +310,7 @@ def run_dpa_with_cp(
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
max_score = None
with fp8_context:
# q, k, v, out in FP8; dout in F16
out = core_attn(
@@ -323,6 +325,8 @@
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
fp8_output=fp8_mha,
)
if config.return_max_score:
out, max_score = out
if fp8_bwd and fp8_mha:
dout_fp8 = dout_quantizer(dout)
out.backward(dout_fp8)
@@ -401,6 +405,7 @@ def run_dpa_with_cp(
fp8_context = nullcontext()

# run attention
max_score_ = None
with fp8_context:
# q, k, v, out in FP8; dout in F16
out_ = core_attn(
@@ -415,6 +420,8 @@
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
fp8_output=fp8_mha,
)
if config.return_max_score:
out_, max_score_ = out_
if fp8_bwd and fp8_mha:
dout_fp8_ = dout_quantizer(dout_)
out_.backward(dout_fp8_)
@@ -496,15 +503,15 @@ def run_dpa_with_cp(
)

atol, rtol, rmse_tol = get_tols(config, dtype)
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset]
names = ["out", "dq", "dk", "dv", "d_softmax_offset"]
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_score_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_score]
names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_score"]
names_cp = [x + "_cp" for x in names]
names_no_cp = [x + "_no_cp" for x in names]
is_fp8 = dtype == "fp8"
for i, t in enumerate(tensors_no_cp):
if t is not None:
if "softmax_offset" not in names[i]:
if "softmax_offset" not in names[i] and "max_score" not in names[i]:
if qkv_format == "bshd":
compare_and_assert(
t[:, 0],
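The hunks above follow one pattern: `max_score` is initialized to `None`, and the forward output is unpacked into `(out, max_score)` only when `config.return_max_score` is set, so the same comparison code works with and without the flag. A minimal sketch of that convention, assuming the constructor flag and tuple return shown in this diff (shapes and sizes below are illustrative, not taken from the tests):

```python
# Sketch of the (out, max_score) unpacking convention used in run_dpa_with_cp.
# Assumes, as the hunks above suggest, that DotProductAttention accepts a
# return_max_score flag and returns an (out, max_score) tuple when it is set.
import torch
from transformer_engine.pytorch import DotProductAttention

def run_forward(core_attn, q, k, v, return_max_score: bool):
    """Normalize the forward output to an (out, max_score) pair."""
    max_score = None
    out = core_attn(q, k, v)
    if return_max_score:
        out, max_score = out  # tuple is only returned when the flag is enabled
    return out, max_score

# Illustrative sizes; the default qkv_format is "sbhd" -> (sq, b, h, d).
q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
core_attn = DotProductAttention(16, 64, return_max_score=True).cuda()

out, max_score = run_forward(core_attn, q, k, v, return_max_score=True)
out.backward(torch.randn_like(out))  # backward goes through out; max_score is assumed to be a reported statistic
```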
71 changes: 60 additions & 11 deletions tests/pytorch/attention/test_attention.py
@@ -166,7 +166,7 @@ def test_dot_product_attention(

# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
unfused_attn_fwd, unfused_max_score, unfused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"UnfusedDotProductAttention",
@@ -180,7 +180,7 @@ def test_dot_product_attention(
# FusedAttention backend
if fused_attn_supported:
if len(fused_attn_backends) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
fused_attn_fwd, fused_max_score, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
@@ -192,7 +192,7 @@ def test_dot_product_attention(
)
if len(fused_attn_backends) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
fused_attn_fwd, fused_max_score, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
@@ -203,7 +203,7 @@ def test_dot_product_attention(
is_training,
)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
fused_attn_fwd_1, fused_max_score_1, fused_attn_bwd_1 = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
@@ -216,7 +216,7 @@ def test_dot_product_attention(

# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
flash_attn_fwd, flash_max_score, flash_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FlashAttention",
@@ -232,16 +232,22 @@ def test_dot_product_attention(
if unfused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
if config.return_max_score:
torch.testing.assert_close(flash_max_score, unfused_max_score, **tols)
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
if unfused_attn_supported and fused_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
if config.return_max_score:
torch.testing.assert_close(fused_max_score, unfused_max_score, **tols)
for i, _ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
if config.return_max_score:
torch.testing.assert_close(fused_max_score, flash_max_score, **tols)
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and len(fused_attn_backends) == 2:
@@ -260,6 +266,40 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_max_score = {
# test: ModelConfig(b, sq, hq, dqk)
"max_score_1_0": ModelConfig(8, 128, 16, 64),
"max_score_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
"max_score_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
"max_score_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
"max_score_3_0": ModelConfig(
8, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"
),
"max_score_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048),
"max_score_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048),
"max_score_4_1": ModelConfig(
8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"
),
"max_score_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048),
"max_score_5_1": ModelConfig(
8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0)
),
"max_score_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
"max_score_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_max_score])
@pytest.mark.parametrize("model", model_configs_max_score.keys())
def test_dpa_max_score(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing"""
config = model_configs[model]
config.return_max_score = True
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
@@ -1065,6 +1105,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
layer_number=1,
attention_type=config.attn_type,
softmax_type=config.softmax_type,
return_max_score=config.return_max_score,
).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()
@@ -1102,16 +1143,20 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
)
max_score = None
if config.return_max_score:
out, max_score = out
if is_training:
out.backward(d_out)

d_softmax_offset = None
if is_training and config.softmax_type != "vanilla":
d_softmax_offset = block.softmax_offset.grad
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if is_training:
return out, (q.grad, k.grad, v.grad, d_softmax_offset)
return out, max_score, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, (None, None, None, d_softmax_offset)
return out, max_score, (None, None, None, d_softmax_offset)
if backend == "FusedAttention":
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
@@ -1140,14 +1185,18 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset)
return (
out_orig,
max_score,
(q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset),
)
else:
return out_orig, (None, None, None, d_softmax_offset)
return out_orig, max_score, (None, None, None, d_softmax_offset)
else:
if is_training:
return out, (q.grad, k.grad, v.grad, d_softmax_offset)
return out, max_score, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, (None, None, None, d_softmax_offset)
return out, max_score, (None, None, None, d_softmax_offset)


model_configs_te_layer = {
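The `max_score` assertions above only check that the unfused, fused, and flash backends agree with each other; the tests never define the quantity themselves. As a point of intuition only, here is a plain-PyTorch reference under the assumption (not taken from this diff) that `max_score` is the per-head maximum of the scaled pre-softmax logits:

```python
# Illustrative reference only. The definition below -- per-head max of the
# scaled QK^T logits -- and the (h,) output shape are assumptions made for
# intuition; the authoritative semantics of max_score live in Transformer
# Engine itself, which the tests above treat as a black box.
import math
import torch

def reference_max_score(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """q, k: (b, h, s, d) tensors. Returns the assumed per-head statistic, shape (h,)."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    logits = torch.einsum("bhqd,bhkd->bhqk", q.float(), k.float()) * scale
    return logits.amax(dim=(0, 2, 3))  # reduce over batch and both sequence axes

b, h, s, d = 2, 16, 128, 64
q, k = torch.randn(b, h, s, d), torch.randn(b, h, s, d)
print(reference_max_score(q, k).shape)  # torch.Size([16])
```

Whatever the exact definition, the cross-backend `assert_close` calls above are the contract these tests actually enforce.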
6 changes: 3 additions & 3 deletions tests/pytorch/attention/test_attention_with_cp.py
@@ -137,8 +137,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):

model_configs_fused_attn = {
# test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_score=True), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_score=True), # MHA
"cp_1_2": ModelConfig(
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
), # MHA
@@ -183,7 +183,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
dtypes = ["bf16", "fp8"]
qkv_formats = ["sbhd", "thd"]
3 changes: 3 additions & 0 deletions tests/pytorch/utils.py
@@ -205,6 +205,7 @@ def __init__(
window_size: Tuple[int, int] = (-1, -1),
context_parallel: bool = False,
cp_comm_type: str = "p2p",
return_max_score: bool = False,
total_requests: int = None,
max_ctx_len: int = None,
num_layers: int = 1,
@@ -233,6 +234,7 @@ def __init__(
self.window_size = check_set_window_size(self.attn_mask_type, window_size)
self.context_parallel = context_parallel
self.cp_comm_type = cp_comm_type
self.return_max_score = return_max_score
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
@@ -318,6 +320,7 @@ def test():
is_training=is_training,
inference_params=inference_params,
softmax_type=config.softmax_type,
return_max_score=config.return_max_score,
)
(
use_flash_attention,
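For reference, the new `ModelConfig` keyword is consumed the same way as the existing flags: a test-table entry opts in once and the value is threaded through backend selection and the forward call. A small sketch follows; the positional order comes from the `ModelConfig(b, sq, hq, dqk)` comments in the test files, the sizes are made up, and the import path assumes pytest is run from `tests/pytorch`:

```python
# Sketch: opting a config into max-score return via the keyword added above.
from utils import ModelConfig  # tests/pytorch/utils.py; import path is an assumption

cfg = ModelConfig(
    2,      # b: batch size
    2048,   # sq: query sequence length
    16,     # hq: number of query heads
    64,     # dqk: head dimension
    attn_mask_type="causal",
    return_max_score=True,  # new keyword, defaults to False
)
assert cfg.return_max_score
```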