
Commit c4c185d

cyanguwa and pre-commit-ci[bot] authored and committed
[PyTorch] Add max_logit support for MuonClip (#2195)
* add max_score for fused/unfused F16 non-CP
* calculate max per head instead of max over all heads
* fix fused attn max_score shape
* revert FE to github
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update FE to 1.15.0-rc
* fix merge
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* reduce ew kernels; fix causal masks; add more tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor fix to tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* remove logic for flash-attn
* WIP: add CP support for p2p/a2a/all_gather
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor improvements of implementation/tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* WIP: add thd support
* add thd to UnfusedDPA
* fix lint
* more fixes for lint
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* update to FE 1.15
* remove unneeded changes
* disable unfused for thd + pad_between_seqs
* minor fixes
* disable thd for unfused until bug is fixed
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix all_gather
* fix all gather
* rename max_score to max_logit
* fix all_gather
* fix all_gather
* disable fused attn + thd

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8b9849a commit c4c185d
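
What the change looks like from the caller's side, pieced together from the tests in this commit: return_max_logit=True is passed to the DotProductAttention constructor, and the forward output is then unpacked as an (out, max_logit) pair; per the commit message the maximum is computed per attention head, which is the quantity MuonClip-style QK clipping monitors. A minimal sketch under those assumptions (shapes, dtype, and mask type are illustrative; requires a GPU build of Transformer Engine):

# Minimal sketch of the caller-side API, as exercised by the tests in this
# commit; shapes, dtype, and mask type are illustrative only.
import torch
from transformer_engine.pytorch import DotProductAttention

s, b, h, d = 128, 2, 16, 64  # "sbhd" layout: [seq, batch, heads, head_dim]
q, k, v = [torch.randn(s, b, h, d, dtype=torch.bfloat16, device="cuda") for _ in range(3)]

attn = DotProductAttention(
    num_attention_heads=h,
    kv_channels=d,
    qkv_format="sbhd",
    attn_mask_type="causal",
    return_max_logit=True,  # the flag added by this commit
).cuda()

# Per the tests, forward returns a tuple when the flag is set.
out, max_logit = attn(q, k, v)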

File tree

19 files changed: +748 / -305 lines changed


3rdparty/cudnn-frontend

Submodule cudnn-frontend updated 108 files

tests/pytorch/attention/run_attention_with_cp.py

Lines changed: 11 additions & 4 deletions
@@ -248,6 +248,7 @@ def run_dpa_with_cp(
         attn_mask_type=config.attn_mask_type,
         window_size=config.window_size,
         softmax_type=config.softmax_type,
+        return_max_logit=config.return_max_logit,
     ).cuda()
     if config.softmax_type != "vanilla":
         core_attn.softmax_offset.requires_grad = True
@@ -308,6 +309,7 @@ def run_dpa_with_cp(
         fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group)
     else:
         fp8_context = nullcontext()
+    max_logit = None
     with fp8_context:
         # q, k, v, out in FP8; dout in F16
         out = core_attn(
@@ -322,6 +324,8 @@ def run_dpa_with_cp(
             cu_seqlens_kv_padded=cu_seqlens_kv_padded,
             fp8_output=fp8_mha,
         )
+        if config.return_max_logit:
+            out, max_logit = out
         if fp8_bwd and fp8_mha:
             dout_fp8 = dout_quantizer(dout)
             out.backward(dout_fp8)
@@ -400,6 +404,7 @@ def run_dpa_with_cp(
         fp8_context = nullcontext()

     # run attention
+    max_logit_ = None
     with fp8_context:
         # q, k, v, out in FP8; dout in F16
         out_ = core_attn(
@@ -414,6 +419,8 @@ def run_dpa_with_cp(
             cu_seqlens_kv_padded=cu_seqlens_kv_padded,
             fp8_output=fp8_mha,
         )
+        if config.return_max_logit:
+            out_, max_logit_ = out_
         if fp8_bwd and fp8_mha:
             dout_fp8_ = dout_quantizer(dout_)
             out_.backward(dout_fp8_)
@@ -495,15 +502,15 @@ def run_dpa_with_cp(
     )

     atol, rtol, rmse_tol = get_tols(config, dtype)
-    tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_]
-    tensors_no_cp = [out, dq, dk, dv, d_softmax_offset]
-    names = ["out", "dq", "dk", "dv", "d_softmax_offset"]
+    tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_]
+    tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit]
+    names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"]
     names_cp = [x + "_cp" for x in names]
     names_no_cp = [x + "_no_cp" for x in names]
     is_fp8 = dtype == "fp8"
     for i, t in enumerate(tensors_no_cp):
         if t is not None:
-            if "softmax_offset" not in names[i]:
+            if "softmax_offset" not in names[i] and "max_logit" not in names[i]:
                 if qkv_format == "bshd":
                     compare_and_assert(
                         t[:, 0],
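
A context-parallel rank only sees its shard of the sequence, so it can only form a local maximum; the test above just checks that the CP run's max_logit matches the single-GPU run's. How the per-rank maxima are combined is not visible in this test diff; one natural way to do it, shown purely as a hypothetical sketch, is a MAX all-reduce over the CP process group:

# Hypothetical sketch (not taken from this commit): combine per-rank,
# per-head maxima into a global per-head maximum across the CP group.
import torch
import torch.distributed as dist

def global_max_logit(local_max_logit: torch.Tensor, cp_group: dist.ProcessGroup) -> torch.Tensor:
    # local_max_logit: [num_heads], max over this rank's sequence shard
    global_max = local_max_logit.clone()
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=cp_group)
    return global_max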

tests/pytorch/attention/test_attention.py

Lines changed: 57 additions & 11 deletions
@@ -130,6 +130,11 @@ def test_dot_product_attention(
     if config.window_size == (-1, -1) and swa:
         config.window_size = [2, 2]
     config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
+    qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
+    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
+        config.attn_mask_type = (
+            "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
+        )

     # Get backends
     is_training = True
@@ -171,7 +176,7 @@ def test_dot_product_attention(

     # UnfusedDotProductAttention backend
     if unfused_attn_supported:
-        unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
+        unfused_attn_fwd, unfused_max_logit, unfused_attn_bwd = _run_dot_product_attention(
             dtype,
             config,
             "UnfusedDotProductAttention",
@@ -185,7 +190,7 @@ def test_dot_product_attention(
     # FusedAttention backend
     if fused_attn_supported:
         if len(fused_attn_backends) == 1:
-            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
+            fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention(
                 dtype,
                 config,
                 "FusedAttention",
@@ -197,7 +202,7 @@ def test_dot_product_attention(
             )
         if len(fused_attn_backends) == 2:
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
-            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
+            fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention(
                 dtype,
                 config,
                 "FusedAttention",
@@ -208,7 +213,7 @@ def test_dot_product_attention(
                 is_training,
             )
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
-            fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
+            fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention(
                 dtype,
                 config,
                 "FusedAttention",
@@ -221,7 +226,7 @@ def test_dot_product_attention(

     # FlashAttention backend
     if flash_attn_supported:
-        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
+        flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention(
             dtype,
             config,
             "FlashAttention",
@@ -242,6 +247,8 @@ def test_dot_product_attention(
     if unfused_attn_supported and fused_attn_supported:
         logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
         torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
+        if config.return_max_logit:
+            torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols)
         for i, _ in enumerate(unfused_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
     if fused_attn_supported and flash_attn_supported:
@@ -265,6 +272,33 @@ def test_dpa_checkpoint(dtype, model_configs, model):
     test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


+model_configs_max_logit = {
+    # test: ModelConfig(b, sq, hq, dqk)
+    "max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
+    "max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
+    "max_logit_3": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
+    "max_logit_4": ModelConfig(
+        8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"
+    ),
+    "max_logit_5": ModelConfig(
+        8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0)
+    ),
+    "max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
+}
+
+
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("model_configs", [model_configs_max_logit])
+@pytest.mark.parametrize("model", model_configs_max_logit.keys())
+@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
+def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
+    """Test DotProductAttention module with checkpointing"""
+    config = model_configs[model]
+    config.return_max_logit = True
+    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
+
+
 model_configs_softmax = {
     # test: ModelConfig(b, sq, hq, dqk)
     "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
@@ -961,6 +995,8 @@ def _run_dot_product_attention(
         layout = layout.replace("d", "dqk")
         tensor_shape = [dim_to_num[j] for j in layout.split("_")]
         tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
+        # tensor: with padding tokens
+        # tensor_orig: without padding tokens
         tensor_orig = tensor
         if qkv_format == "thd" and pad_between_seqs:
             tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
@@ -1070,6 +1106,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
         layer_number=1,
         attention_type=config.attn_type,
         softmax_type=config.softmax_type,
+        return_max_logit=config.return_max_logit,
     ).to(dtype=dtype, device="cuda")
     if not is_training:
         block = block.eval()
@@ -1107,16 +1144,21 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
         alibi_slopes=alibi_slopes,
         fast_zero_fill=True,
     )
+    max_logit = None
+    if config.return_max_logit:
+        out, max_logit = out
     if is_training:
         out.backward(d_out)
+
     d_softmax_offset = None
     if is_training and config.softmax_type != "vanilla":
         d_softmax_offset = block.softmax_offset.grad
+
     if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
         if is_training:
-            return out, (q.grad, k.grad, v.grad, d_softmax_offset)
+            return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
         else:
-            return out, (None, None, None, d_softmax_offset)
+            return out, max_logit, (None, None, None, d_softmax_offset)
     if backend == "FusedAttention":
         if qkv_format == "thd" and pad_between_seqs:
             out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
@@ -1145,14 +1187,18 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
                    [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                )
            if is_training:
-                return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset)
+                return (
+                    out_orig,
+                    max_logit,
+                    (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset),
+                )
            else:
-                return out_orig, (None, None, None, d_softmax_offset)
+                return out_orig, max_logit, (None, None, None, d_softmax_offset)
        else:
            if is_training:
-                return out, (q.grad, k.grad, v.grad, d_softmax_offset)
+                return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
            else:
-                return out, (None, None, None, d_softmax_offset)
+                return out, max_logit, (None, None, None, d_softmax_offset)


 model_configs_te_layer = {
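
To make the new quantity concrete: the commit message states the maximum is calculated per head rather than over all heads. The standalone sketch below computes such a reference value for plain scaled-dot-product logits; the exact reduction axes and the treatment of masks and bias inside Transformer Engine's backends are assumptions here, not something this diff shows.

# Standalone reference sketch: per-head maximum pre-softmax logit for a
# "bshd" Q/K pair. Assumes the max is taken over batch and both sequence
# axes and ignores masking/bias; both are assumptions, not taken from the diff.
import math
import torch

def reference_max_logit(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # q, k: [batch, seq, heads, head_dim]
    head_dim = q.shape[-1]
    # logits: [batch, heads, seq_q, seq_kv], with standard 1/sqrt(d) scaling
    logits = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(head_dim)
    return logits.amax(dim=(0, 2, 3))  # reduce everything except heads -> [heads]

q = torch.randn(2, 128, 16, 64)
k = torch.randn(2, 128, 16, 64)
print(reference_max_logit(q, k).shape)  # torch.Size([16])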

tests/pytorch/attention/test_attention_with_cp.py

Lines changed: 3 additions & 3 deletions
@@ -137,8 +137,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):

 model_configs_fused_attn = {
     # test: ModelConfig(b, sq, hq, dqk)
-    "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
-    "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
+    "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=True),  # MHA
+    "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=True),  # MHA
     "cp_1_2": ModelConfig(
         2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
     ),  # MHA
@@ -183,7 +183,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
     qkv_formats = ["bshd", "sbhd", "thd"]
     cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
     if test_essential:
-        configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
+        configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
         model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
         dtypes = ["bf16", "fp8"]
         qkv_formats = ["sbhd", "thd"]

tests/pytorch/utils.py

Lines changed: 3 additions & 0 deletions
@@ -205,6 +205,7 @@ def __init__(
         window_size: Tuple[int, int] = (-1, -1),
         context_parallel: bool = False,
         cp_comm_type: str = "p2p",
+        return_max_logit=False,
         total_requests: int = None,
         max_ctx_len: int = None,
         num_layers: int = 1,
@@ -233,6 +234,7 @@ def __init__(
         self.window_size = check_set_window_size(self.attn_mask_type, window_size)
         self.context_parallel = context_parallel
         self.cp_comm_type = cp_comm_type
+        self.return_max_logit = return_max_logit
         self.total_requests = total_requests
         self.max_ctx_len = max_ctx_len
         self.num_layers = num_layers
@@ -318,6 +320,7 @@ def test():
         is_training=is_training,
         inference_params=inference_params,
         softmax_type=config.softmax_type,
+        return_max_logit=config.return_max_logit,
     )
     (
         use_flash_attention,
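
For completeness, a short usage sketch of the new ModelConfig field; the bare "from utils import ..." path assumes tests/pytorch is on sys.path, which is how the test modules themselves import this helper.

# Usage sketch for the new field; positional arguments follow the
# "# test: ModelConfig(b, sq, hq, dqk)" convention used in the test configs.
from utils import ModelConfig  # assumes tests/pytorch is on sys.path

cfg = ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=True)
assert cfg.return_max_logit is True

# The flag defaults to False, so existing configs are unaffected.
assert ModelConfig(2, 4096, 12, 128).return_max_logit is False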
