Commit eadf8d6

Merge branch 'main' of github.com:NVIDIA/TransformerEngine into enable_thd_cp_swa
2 parents 9601605 + e30c36a commit eadf8d6

File tree

86 files changed (+4757 −2712 lines)


3rdparty/cudnn-frontend

tests/pytorch/attention/run_attention_with_cp.py

Lines changed: 102 additions & 49 deletions
@@ -12,14 +12,18 @@
 from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
     get_cu_seqlens_on_cp_rank,
 )
+from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize
 import transformer_engine_torch as tex
 from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
 from transformer_engine.pytorch.fp8 import fp8_autocast
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
-from transformer_engine.common.recipe import DelayedScaling
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Tensor,
+    Float8Quantizer,
+    Float8CurrentScalingQuantizer,
+)
+from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling
 from utils import ModelConfig, compare_and_assert

-
 dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}


@@ -151,7 +155,7 @@ def get_tols(config, dtype):
     elif dtype == "fp8":
         atol = 5e-1
         rtol = 5e-1
-        rmse_tol = 0.1
+        rmse_tol = 0.15
     else:
         assert False, f"{dtype=} is not supported!"

@@ -164,14 +168,23 @@ def run_dpa_with_cp(
     qkv_format="bshd",
     kernel_backend="FlashAttention",
     cp_comm_type="p2p",
-    fp8_mha=False,
+    fp8_bwd="True",
+    fp8_dpa="False",
+    fp8_mha="False",
+    scaling_mode="delayed",
+    f16_O="False",
     log_level=logging.WARNING,
 ):
     """Test DotProductAttention module with context parallelism"""
     logging.root.setLevel(log_level)

     # set up environment variables and config
-    fp8_mha = fp8_mha == "True"
+    fp8_bwd = fp8_bwd == "True" and dtype == "fp8"
+    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0"
+    fp8_dpa = fp8_dpa == "True" and dtype == "fp8"
+    fp8_mha = fp8_mha == "True" and dtype == "fp8"
+    f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True"
+    os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0"
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
     if kernel_backend == "FlashAttention":
@@ -219,8 +232,12 @@ def run_dpa_with_cp(
             sub_group = dist.new_group(sub_ranks, backend="nccl")
             if rank in sub_ranks:
                 cp_comm_sub_groups.append(sub_group)
+
     if dtype == "fp8":
-        fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)
+        if scaling_mode == "delayed":
+            fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
+        if scaling_mode == "current":
+            fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)

     # instantiate attention module
     core_attn = DotProductAttention(
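As a minimal sketch (not part of this commit) of the recipe selection in the hunk above: the import paths, recipe keyword arguments, and the fp8_autocast call follow this diff, while the helper name make_fp8_context and the fp8_group=None default are illustrative assumptions.

```python
# Sketch only: mirrors the scaling_mode -> recipe selection used by this test.
from contextlib import nullcontext

from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling
from transformer_engine.pytorch.fp8 import fp8_autocast


def make_fp8_context(dtype, scaling_mode, fp8_dpa, fp8_mha, fp8_group=None):
    """Return the autocast context the test would enter for a given config."""
    if dtype != "fp8":
        return nullcontext()
    if scaling_mode == "delayed":
        recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
    elif scaling_mode == "current":
        recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
    else:
        raise ValueError(f"unsupported scaling_mode: {scaling_mode}")
    return fp8_autocast(enabled=True, fp8_recipe=recipe, fp8_group=fp8_group)
```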
@@ -247,19 +264,38 @@ def run_dpa_with_cp(
         cu_seqlens_q_padded,
         cu_seqlens_kv_padded,
     ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
-    q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
-    k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
-    v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
-    for x in [q, k, v]:
-        x.requires_grad = True
-
-    dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
-    if fp8_mha:
+    q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
+    k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
+    v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
+    dout_orig = torch.clamp(
+        torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1
+    ).cuda()
+    if scaling_mode == "delayed":
+        qkv_quantizer = Float8Quantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            scale=torch.tensor([1], dtype=torch.float32).cuda(),
+            amax=torch.tensor([0], dtype=torch.float32).cuda(),
+        )
         dout_quantizer = Float8Quantizer(
             fp8_dtype=tex.DType.kFloat8E5M2,
             scale=torch.tensor([1], dtype=torch.float32).cuda(),
             amax=torch.tensor([0], dtype=torch.float32).cuda(),
         )
+    if scaling_mode == "current":
+        qkv_quantizer = Float8CurrentScalingQuantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            device="cuda",
+        )
+        dout_quantizer = Float8CurrentScalingQuantizer(
+            fp8_dtype=tex.DType.kFloat8E5M2,
+            device="cuda",
+        )
+    qkv_layout = "_".join([qkv_format] * 3)
+    q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]]
+    if fp8_mha:
+        q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer)
+    for x in [q, k, v]:
+        x.requires_grad = True

     if config.attn_bias_type not in ["no_bias", "alibi"]:
         attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
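To contrast the two quantizer flavors set up in the hunk above, here is a hedged side-by-side sketch with a quantize/dequantize round trip. The constructor arguments and the quantizer(x) / .dequantize() calls are the ones visible in this diff; the toy tensor shape is an arbitrary placeholder, not something the test uses.

```python
# Sketch only: delayed-scaling vs. current-scaling quantizer construction.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
)

x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")

# Delayed scaling: scale/amax buffers are created (and reset) explicitly,
# as the test does before each run.
delayed_q = Float8Quantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    scale=torch.tensor([1], dtype=torch.float32).cuda(),
    amax=torch.tensor([0], dtype=torch.float32).cuda(),
)

# Current scaling: the scale is derived from the tensor being quantized.
current_q = Float8CurrentScalingQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    device="cuda",
)

x_fp8_delayed = delayed_q(x)       # uses the pre-set scale (delayed scaling)
x_fp8_current = current_q(x)       # computes the scale from x (current scaling)
x_hp = x_fp8_current.dequantize()  # back to high precision for comparisons
```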
@@ -274,6 +310,7 @@ def run_dpa_with_cp(
     else:
         fp8_context = nullcontext()
     with fp8_context:
+        # q, k, v, out in FP8; dout in F16
         out = core_attn(
             q,
             k,
@@ -284,8 +321,9 @@
             cu_seqlens_kv=cu_seqlens_kv,
             cu_seqlens_q_padded=cu_seqlens_q_padded,
             cu_seqlens_kv_padded=cu_seqlens_kv_padded,
+            fp8_output=fp8_mha,
         )
-        if fp8_mha:
+        if fp8_bwd and fp8_mha:
             dout_fp8 = dout_quantizer(dout)
             out.backward(dout_fp8)
         else:
@@ -298,24 +336,10 @@
     ############ run with CP ############
     logging.info(f"[Rank {rank}] Run with context parallelism")

-    # set up environment
-    core_attn.set_context_parallel_group(
-        cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
-        cp_comm_ranks,
-        torch.cuda.Stream(),
-        cp_comm_type,
-    )
-    if config.softmax_type != "vanilla":
-        core_attn.softmax_offset.grad.zero_()
-    if dtype == "fp8":
-        core_attn.reset_fp8_meta_tensors()
-        fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
-    else:
-        fp8_context = nullcontext()
-
     # set up inputs
     q_, k_, v_, dout_, *rest = [
-        x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])
+        x.clone().detach()
+        for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias])
     ]
     bias_ = rest[0] if len(rest) else None
     if qkv_format == "bshd" or qkv_format == "sbhd":
@@ -343,16 +367,42 @@
         )
         q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
        k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
+    else:
+        assert False, f"{qkv_format} is an unsupported qkv_format!"
+    q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]]
+    if scaling_mode == "delayed":
+        qkv_quantizer.scale.fill_(1.0)
+        qkv_quantizer.amax.fill_(0.0)
+        dout_quantizer.scale.fill_(1.0)
+        dout_quantizer.amax.fill_(0.0)
+    if fp8_mha:
+        q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer)
     q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
     if bias_ is not None:
         bias_ = bias_.view(
             *bias_.shape[:-2], 2 * world_size, bias_.shape[-2] // (2 * world_size), bias_.shape[-1]
         )
         bias_ = bias_.index_select(2, seq_idx)
         bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
+    # set up environment
+    core_attn.set_context_parallel_group(
+        cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
+        cp_comm_ranks,
+        torch.cuda.Stream(),
+        cp_comm_type,
+    )
+    if config.softmax_type != "vanilla":
+        core_attn.softmax_offset.grad.zero_()
+    if dtype == "fp8":
+        core_attn.fp8_initialized = False
+        core_attn.fp8_meta_tensors_initialized = False
+        fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
+    else:
+        fp8_context = nullcontext()

     # run attention
     with fp8_context:
+        # q, k, v, out in FP8; dout in F16
         out_ = core_attn(
             q_,
             k_,
@@ -363,27 +413,30 @@
             cu_seqlens_kv=cu_seqlens_kv,
             cu_seqlens_q_padded=cu_seqlens_q_padded,
             cu_seqlens_kv_padded=cu_seqlens_kv_padded,
+            fp8_output=fp8_mha,
         )
-        if fp8_mha:
+        if fp8_bwd and fp8_mha:
             dout_fp8_ = dout_quantizer(dout_)
             out_.backward(dout_fp8_)
         else:
             out_.backward(dout_)
-    if fp8_mha:
-        assert isinstance(out, Float8Tensor)
-        assert isinstance(out_, Float8Tensor)
-        out = out.dequantize()
-        out_ = out_.dequantize()
-
-    # get outputs
     dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
     d_softmax_offset_ = None
     if config.softmax_type != "vanilla":
         d_softmax_offset_ = core_attn.softmax_offset.grad.clone()
-    for x in [out_, dq_, dk_, dv_, d_softmax_offset_]:
-        if x is not None:
-            assert torch.all(~torch.isnan(x))
-            assert torch.all(~torch.isinf(x))
+
+    # get outputs
+    tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_]
+    if fp8_mha:
+        tensors_to_deq = [out, out_] if not fp8_bwd else tensors
+        for i, tensor in enumerate(tensors_to_deq):
+            tensors_to_deq[i] = tensor.dequantize()
+        if not fp8_bwd:
+            tensors[0], tensors[4] = tensors_to_deq
+    for tensor in tensors:
+        assert torch.all(~torch.isnan(tensor))
+        assert torch.all(~torch.isinf(tensor))
+    out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors

     ############ compare results between CP and no-CP ############
     if qkv_format == "bshd" or qkv_format == "sbhd":
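The dequantize-then-assert check rewritten in the hunk above can be factored into a small helper; the sketch below is not code from the commit, and the name check_finite plus the hasattr-based FP8 detection are illustrative assumptions.

```python
# Hedged sketch: post-backward sanity check mirroring the test's handling of
# FP8 outputs (dequantize first, then assert there are no NaN/Inf values).
import torch


def check_finite(tensors):
    """Assert every tensor is NaN/Inf free, dequantizing FP8 tensors first."""
    for t in tensors:
        if t is None:
            continue
        if hasattr(t, "dequantize"):  # e.g. a Float8Tensor when fp8_mha=True
            t = t.dequantize()
        assert torch.all(~torch.isnan(t))
        assert torch.all(~torch.isinf(t))
```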
@@ -394,17 +447,17 @@
                 x.shape[seq_dim] // (2 * world_size),
                 *x.shape[(seq_dim + 1) :],
             )
-            for x in [q.grad, k.grad, v.grad, out]
+            for x in [dq, dk, dv, out]
         ]
         dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
         dq_, dk_, dv_, out_ = [
             x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
-            for x in [q_.grad, k_.grad, v_.grad, out_]
+            for x in [dq_, dk_, dv_, out_]
         ]
     elif qkv_format == "thd":
-        dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]]
-        dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]]
-        dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_]
+        dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
+        dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
+        dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_]
         cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
         cu_seqlens_q = get_cu_seqlens_on_cp_rank(
             cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
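The comparison code in the hunk above reshapes tensors into 2 * world_size chunks along the sequence dimension and selects two chunks per rank. The sketch below illustrates that load-balanced sharding pattern; the specific chunk pairing (rank r taking chunks r and 2 * world_size - 1 - r) is an assumption about how seq_idx is built elsewhere in this test, and the helper name is hypothetical.

```python
# Hedged sketch of load-balanced context-parallel sequence sharding.
import torch


def shard_along_seq(x, seq_dim, world_size, rank):
    """Return the two chunks of x (out of 2 * world_size) assumed to belong to `rank`."""
    x = x.view(
        *x.shape[:seq_dim], 2 * world_size, x.shape[seq_dim] // (2 * world_size),
        *x.shape[(seq_dim + 1):],
    )
    # Assumed pairing: an early chunk plus a late chunk, so causal-attention
    # work per rank stays roughly balanced.
    seq_idx = torch.tensor([rank, 2 * world_size - 1 - rank], device=x.device)
    x = x.index_select(seq_dim, seq_idx)
    # Collapse the (2, chunk) dims back into one contiguous local sequence.
    return x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2):])


# Toy usage: sequence length 16 with world_size 2 -> each rank holds 8 tokens.
full = torch.arange(16).view(1, 16, 1)
local = shard_along_seq(full, seq_dim=1, world_size=2, rank=0)
```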
