Commit 87cb26c
[PyTorch] Add max_logit support for MuonClip (#2195)
* add max_score for fused/unfused F16 non-CP Signed-off-by: Charlene Yang <[email protected]>
* calculate max per head instead of max over all heads Signed-off-by: Charlene Yang <[email protected]>
* fix fused attn max_score shape Signed-off-by: Charlene Yang <[email protected]>
* revert FE to github Signed-off-by: Charlene Yang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* update FE to 1.15.0-rc Signed-off-by: Charlene Yang <[email protected]>
* fix merge Signed-off-by: Charlene Yang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* reduce ew kernels; fix causal masks; add more tests Signed-off-by: Charlene Yang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* minor fix to tests Signed-off-by: Charlene Yang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* remove logic for flash-attn Signed-off-by: Charlene Yang <[email protected]>
* WIP: add CP support for p2p/a2a/all_gather Signed-off-by: Charlene Yang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* minor improvements of implementation/tests Signed-off-by: Charlene Yang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* WIP: add thd support Signed-off-by: Charlene Yang <[email protected]>
* add thd to UnfusedDPA Signed-off-by: Charlene Yang <[email protected]>
* fix lint Signed-off-by: Charlene Yang <[email protected]>
* more fixes for lint Signed-off-by: Charlene Yang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* update to FE 1.15 Signed-off-by: Charlene Yang <[email protected]>
* remove unneeded changes Signed-off-by: Charlene Yang <[email protected]>
* disable unfused for thd + pad_between_seqs Signed-off-by: Charlene Yang <[email protected]>
* minor fixes Signed-off-by: Charlene Yang <[email protected]>
* disable thd for unfused until bug is fixed Signed-off-by: Charlene Yang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix all_gather Signed-off-by: Charlene Yang <[email protected]>
* fix all gather Signed-off-by: Charlene Yang <[email protected]>
* rename max_score to max_logit Signed-off-by: Charlene Yang <[email protected]>
* fix all_gather Signed-off-by: Charlene Yang <[email protected]>
* fix all_gather Signed-off-by: Charlene Yang <[email protected]>
* disable fused attn + thd Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 060811c commit 87cb26c
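For orientation before the diffs: the user-facing pattern these test changes exercise looks roughly like the sketch below. This is a minimal sketch, not TE documentation. It assumes DotProductAttention accepts the return_max_logit flag introduced by this commit and, when the flag is set, returns an (output, max_logit) tuple, which is how the tests below unpack it; the tensor sizes, the threshold tau, and the MuonClip-style rescaling at the end are purely illustrative.

import torch
from transformer_engine.pytorch import DotProductAttention

# Illustrative sizes: seq 2048, batch 2, 24 heads, head dim 128 (default "sbhd" layout).
s, b, h, d = 2048, 2, 24, 128

attn = DotProductAttention(
    num_attention_heads=h,
    kv_channels=d,
    attn_mask_type="causal",
    return_max_logit=True,  # new flag added by this commit
).cuda()

q = torch.randn(s, b, h, d, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# With return_max_logit=True the forward pass also returns the per-head
# maximum of the pre-softmax attention logits.
out, max_logit = attn(q, k, v)

# MuonClip-style QK-clip (illustrative only, not a TE API): shrink the
# query/key projections of heads whose logits exceed a threshold tau.
tau = 100.0
clip = torch.clamp(tau / max_logit.float(), max=1.0)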

File tree: 19 files changed (+748, -305 lines)


3rdparty/cudnn-frontend

Submodule cudnn-frontend updated 108 files

tests/pytorch/attention/run_attention_with_cp.py

Lines changed: 11 additions & 4 deletions
@@ -248,6 +248,7 @@ def run_dpa_with_cp(
         attn_mask_type=config.attn_mask_type,
         window_size=config.window_size,
         softmax_type=config.softmax_type,
+        return_max_logit=config.return_max_logit,
     ).cuda()
     if config.softmax_type != "vanilla":
         core_attn.softmax_offset.requires_grad = True
@@ -308,6 +309,7 @@ def run_dpa_with_cp(
         fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group)
     else:
         fp8_context = nullcontext()
+    max_logit = None
     with fp8_context:
         # q, k, v, out in FP8; dout in F16
         out = core_attn(
@@ -322,6 +324,8 @@
             cu_seqlens_kv_padded=cu_seqlens_kv_padded,
             fp8_output=fp8_mha,
         )
+        if config.return_max_logit:
+            out, max_logit = out
         if fp8_bwd and fp8_mha:
             dout_fp8 = dout_quantizer(dout)
             out.backward(dout_fp8)
@@ -400,6 +404,7 @@ def run_dpa_with_cp(
         fp8_context = nullcontext()

     # run attention
+    max_logit_ = None
     with fp8_context:
         # q, k, v, out in FP8; dout in F16
         out_ = core_attn(
@@ -414,6 +419,8 @@
             cu_seqlens_kv_padded=cu_seqlens_kv_padded,
             fp8_output=fp8_mha,
         )
+        if config.return_max_logit:
+            out_, max_logit_ = out_
         if fp8_bwd and fp8_mha:
             dout_fp8_ = dout_quantizer(dout_)
             out_.backward(dout_fp8_)
@@ -495,15 +502,15 @@ def run_dpa_with_cp(
     )

     atol, rtol, rmse_tol = get_tols(config, dtype)
-    tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_]
-    tensors_no_cp = [out, dq, dk, dv, d_softmax_offset]
-    names = ["out", "dq", "dk", "dv", "d_softmax_offset"]
+    tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_]
+    tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit]
+    names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"]
     names_cp = [x + "_cp" for x in names]
     names_no_cp = [x + "_no_cp" for x in names]
     is_fp8 = dtype == "fp8"
     for i, t in enumerate(tensors_no_cp):
         if t is not None:
-            if "softmax_offset" not in names[i]:
+            if "softmax_offset" not in names[i] and "max_logit" not in names[i]:
                 if qkv_format == "bshd":
                     compare_and_assert(
                         t[:, 0],
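The CP harness can compare max_logit between the context-parallel and single-device runs with the same tolerance machinery as the other outputs because a sequence-wide maximum is exactly recoverable from per-chunk maxima: the max over the whole sequence equals the elementwise max of the maxima taken on each chunk. A small self-contained sanity check of that identity follows (plain PyTorch, single device; the shapes and the choice of reduction axes are illustrative, the commit only specifies that the max is taken per head).

import torch

torch.manual_seed(0)
b, h, sq, skv = 2, 12, 512, 512
num_chunks = 4  # stands in for the CP world size

# Pre-softmax logits for one full, non-CP forward pass.
logits = torch.randn(b, h, sq, skv)

# Per-head max over the full sequence, as the single-device run computes it.
max_logit_full = logits.amax(dim=(0, 2, 3))  # shape [h]

# Per-head max computed chunk by chunk along the KV sequence, then reduced,
# as context-parallel ranks can do before a MAX reduction.
chunk_maxima = torch.stack(
    [c.amax(dim=(0, 2, 3)) for c in logits.chunk(num_chunks, dim=3)]
)
max_logit_chunked = chunk_maxima.amax(dim=0)

torch.testing.assert_close(max_logit_full, max_logit_chunked)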

tests/pytorch/attention/test_attention.py

Lines changed: 57 additions & 11 deletions
@@ -131,6 +131,11 @@ def test_dot_product_attention(
     if config.window_size == (-1, -1) and swa:
         config.window_size = [2, 2]
     config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
+    qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
+    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
+        config.attn_mask_type = (
+            "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
+        )

     # Get backends
     is_training = True
@@ -172,7 +177,7 @@ def test_dot_product_attention(

     # UnfusedDotProductAttention backend
     if unfused_attn_supported:
-        unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
+        unfused_attn_fwd, unfused_max_logit, unfused_attn_bwd = _run_dot_product_attention(
             dtype,
             config,
             "UnfusedDotProductAttention",
@@ -186,7 +191,7 @@ def test_dot_product_attention(
     # FusedAttention backend
     if fused_attn_supported:
         if len(fused_attn_backends) == 1:
-            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
+            fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention(
                 dtype,
                 config,
                 "FusedAttention",
@@ -198,7 +203,7 @@ def test_dot_product_attention(
             )
         if len(fused_attn_backends) == 2:
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
-            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
+            fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention(
                 dtype,
                 config,
                 "FusedAttention",
@@ -209,7 +214,7 @@ def test_dot_product_attention(
                 is_training,
             )
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
-            fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
+            fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention(
                 dtype,
                 config,
                 "FusedAttention",
@@ -222,7 +227,7 @@ def test_dot_product_attention(

     # FlashAttention backend
     if flash_attn_supported:
-        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
+        flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention(
             dtype,
             config,
             "FlashAttention",
@@ -243,6 +248,8 @@ def test_dot_product_attention(
     if unfused_attn_supported and fused_attn_supported:
         logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
         torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
+        if config.return_max_logit:
+            torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols)
         for i, _ in enumerate(unfused_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
     if fused_attn_supported and flash_attn_supported:
@@ -266,6 +273,33 @@ def test_dpa_checkpoint(dtype, model_configs, model):
     test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


+model_configs_max_logit = {
+    # test: ModelConfig(b, sq, hq, dqk)
+    "max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
+    "max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
+    "max_logit_3": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
+    "max_logit_4": ModelConfig(
+        8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"
+    ),
+    "max_logit_5": ModelConfig(
+        8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0)
+    ),
+    "max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
+}
+
+
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("model_configs", [model_configs_max_logit])
+@pytest.mark.parametrize("model", model_configs_max_logit.keys())
+@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
+def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
+    """Test DotProductAttention module with checkpointing"""
+    config = model_configs[model]
+    config.return_max_logit = True
+    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
+
+
 model_configs_softmax = {
     # test: ModelConfig(b, sq, hq, dqk)
     "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
@@ -962,6 +996,8 @@ def _run_dot_product_attention(
         layout = layout.replace("d", "dqk")
         tensor_shape = [dim_to_num[j] for j in layout.split("_")]
         tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
+        # tensor: with padding tokens
+        # tensor_orig: without padding tokens
         tensor_orig = tensor
         if qkv_format == "thd" and pad_between_seqs:
             tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
@@ -1071,6 +1107,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
         layer_number=1,
         attention_type=config.attn_type,
         softmax_type=config.softmax_type,
+        return_max_logit=config.return_max_logit,
     ).to(dtype=dtype, device="cuda")
     if not is_training:
         block = block.eval()
@@ -1108,16 +1145,21 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
         alibi_slopes=alibi_slopes,
         fast_zero_fill=True,
     )
+    max_logit = None
+    if config.return_max_logit:
+        out, max_logit = out
     if is_training:
         out.backward(d_out)
+
     d_softmax_offset = None
     if is_training and config.softmax_type != "vanilla":
         d_softmax_offset = block.softmax_offset.grad
+
     if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
         if is_training:
-            return out, (q.grad, k.grad, v.grad, d_softmax_offset)
+            return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
         else:
-            return out, (None, None, None, d_softmax_offset)
+            return out, max_logit, (None, None, None, d_softmax_offset)
     if backend == "FusedAttention":
         if qkv_format == "thd" and pad_between_seqs:
             out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
@@ -1146,14 +1188,18 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
                    [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                )
            if is_training:
-                return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset)
+                return (
+                    out_orig,
+                    max_logit,
+                    (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset),
+                )
            else:
-                return out_orig, (None, None, None, d_softmax_offset)
+                return out_orig, max_logit, (None, None, None, d_softmax_offset)
        else:
            if is_training:
-                return out, (q.grad, k.grad, v.grad, d_softmax_offset)
+                return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
            else:
-                return out, (None, None, None, d_softmax_offset)
+                return out, max_logit, (None, None, None, d_softmax_offset)


 model_configs_te_layer = {
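Conceptually, the quantity the new max_logit test compares across backends is the per-head maximum of the scaled pre-softmax logits. A reference computation might look like the sketch below; it assumes bshd inputs with equal Q and KV lengths, the default 1/sqrt(d) softmax scale, a causal mask, and a reduction over batch and both sequence dimensions, which matches the commit's "max per head" description but is not necessarily TE's exact reduction layout.

import math
import torch

def reference_max_logit(q: torch.Tensor, k: torch.Tensor, causal: bool = True) -> torch.Tensor:
    """Per-head max of scaled QK^T logits; q and k are [b, s, h, d] (bshd)."""
    b, sq, h, d = q.shape
    skv = k.shape[1]
    scale = 1.0 / math.sqrt(d)
    # [b, h, sq, skv] pre-softmax logits
    logits = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * scale
    if causal:
        mask = torch.ones(sq, skv, dtype=torch.bool, device=q.device).tril()
        logits = logits.masked_fill(~mask, float("-inf"))
    return logits.amax(dim=(0, 2, 3))  # one value per head

q = 0.1 * torch.randn(2, 512, 16, 128)
k = 0.1 * torch.randn(2, 512, 16, 128)
max_logit_ref = reference_max_logit(q, k)  # shape [16]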

tests/pytorch/attention/test_attention_with_cp.py

Lines changed: 3 additions & 3 deletions
@@ -137,8 +137,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):

 model_configs_fused_attn = {
     # test: ModelConfig(b, sq, hq, dqk)
-    "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
-    "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
+    "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=True),  # MHA
+    "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=True),  # MHA
     "cp_1_2": ModelConfig(
         2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
     ),  # MHA
@@ -183,7 +183,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
     qkv_formats = ["bshd", "sbhd", "thd"]
     cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
     if test_essential:
-        configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
+        configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
         model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
         dtypes = ["bf16", "fp8"]
         qkv_formats = ["sbhd", "thd"]

tests/pytorch/utils.py

Lines changed: 3 additions & 0 deletions
@@ -205,6 +205,7 @@ def __init__(
         window_size: Tuple[int, int] = (-1, -1),
         context_parallel: bool = False,
         cp_comm_type: str = "p2p",
+        return_max_logit=False,
         total_requests: int = None,
         max_ctx_len: int = None,
         num_layers: int = 1,
@@ -233,6 +234,7 @@ def __init__(
         self.window_size = check_set_window_size(self.attn_mask_type, window_size)
         self.context_parallel = context_parallel
         self.cp_comm_type = cp_comm_type
+        self.return_max_logit = return_max_logit
         self.total_requests = total_requests
         self.max_ctx_len = max_ctx_len
         self.num_layers = num_layers
@@ -318,6 +320,7 @@ def test():
         is_training=is_training,
         inference_params=inference_params,
         softmax_type=config.softmax_type,
+        return_max_logit=config.return_max_logit,
     )
     (
         use_flash_attention,
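With the flag threaded through ModelConfig, a test configuration opts in per entry. A hypothetical stand-alone construction, following the ModelConfig(b, sq, hq, dqk) convention used in the dictionaries above (the import path is an assumption about this repository's test tree):

# Hypothetical: request max_logit in a config, mirroring the "max_logit_2"
# entry added in test_attention.py above.
from tests.pytorch.utils import ModelConfig  # import path assumed

config = ModelConfig(
    2, 2048, 24, 128,  # b, sq, hq, dqk
    attn_mask_type="causal",
    return_max_logit=True,  # new field added by this commit
)
assert config.return_max_logit is True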
