 from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
     get_cu_seqlens_on_cp_rank,
 )
+from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize
 import transformer_engine_torch as tex
 from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
 from transformer_engine.pytorch.fp8 import fp8_autocast
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
-from transformer_engine.common.recipe import DelayedScaling
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Tensor,
+    Float8Quantizer,
+    Float8CurrentScalingQuantizer,
+)
+from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling
 from utils import ModelConfig, compare_and_assert
 
-
 dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
 
 
@@ -151,7 +155,7 @@ def get_tols(config, dtype):
     elif dtype == "fp8":
         atol = 5e-1
         rtol = 5e-1
-        rmse_tol = 0.1
+        rmse_tol = 0.15
     else:
         assert False, f"{dtype=} is not supported!"
 
@@ -164,14 +168,23 @@ def run_dpa_with_cp(
     qkv_format="bshd",
     kernel_backend="FlashAttention",
     cp_comm_type="p2p",
-    fp8_mha=False,
+    fp8_bwd="True",
+    fp8_dpa="False",
+    fp8_mha="False",
+    scaling_mode="delayed",
+    f16_O="False",
     log_level=logging.WARNING,
 ):
     """Test DotProductAttention module with context parallelism"""
     logging.root.setLevel(log_level)
 
     # set up environment variables and config
-    fp8_mha = fp8_mha == "True"
+    fp8_bwd = fp8_bwd == "True" and dtype == "fp8"
+    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0"
+    fp8_dpa = fp8_dpa == "True" and dtype == "fp8"
+    fp8_mha = fp8_mha == "True" and dtype == "fp8"
+    f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True"
+    os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0"
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
     if kernel_backend == "FlashAttention":
@@ -219,8 +232,12 @@ def run_dpa_with_cp(
             sub_group = dist.new_group(sub_ranks, backend="nccl")
             if rank in sub_ranks:
                 cp_comm_sub_groups.append(sub_group)
+
     if dtype == "fp8":
-        fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)
+        if scaling_mode == "delayed":
+            fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
+        if scaling_mode == "current":
+            fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
 
     # instantiate attention module
     core_attn = DotProductAttention(
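
Illustration, not part of the diff: the recipe selected here is what later enters fp8_autocast in both the no-CP and CP passes. A minimal sketch of that selection under the same assumptions, using only the constructor arguments and the fp8_autocast call that appear elsewhere in this file; make_fp8_context is a hypothetical helper name.

from contextlib import nullcontext

from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling
from transformer_engine.pytorch.fp8 import fp8_autocast


def make_fp8_context(dtype, scaling_mode, fp8_dpa, fp8_mha, fp8_group=None):
    # Non-FP8 runs get a no-op context, mirroring the nullcontext() branches in the test.
    if dtype != "fp8":
        return nullcontext()
    if scaling_mode == "delayed":
        recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
    else:  # scaling_mode == "current"
        recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
    return fp8_autocast(enabled=True, fp8_recipe=recipe, fp8_group=fp8_group)
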
@@ -247,19 +264,38 @@ def run_dpa_with_cp(
         cu_seqlens_q_padded,
         cu_seqlens_kv_padded,
     ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
-    q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
-    k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
-    v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
-    for x in [q, k, v]:
-        x.requires_grad = True
-
-    dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
-    if fp8_mha:
+    q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
+    k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
+    v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
+    dout_orig = torch.clamp(
+        torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1
+    ).cuda()
+    if scaling_mode == "delayed":
+        qkv_quantizer = Float8Quantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            scale=torch.tensor([1], dtype=torch.float32).cuda(),
+            amax=torch.tensor([0], dtype=torch.float32).cuda(),
+        )
         dout_quantizer = Float8Quantizer(
             fp8_dtype=tex.DType.kFloat8E5M2,
             scale=torch.tensor([1], dtype=torch.float32).cuda(),
             amax=torch.tensor([0], dtype=torch.float32).cuda(),
         )
+    if scaling_mode == "current":
+        qkv_quantizer = Float8CurrentScalingQuantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            device="cuda",
+        )
+        dout_quantizer = Float8CurrentScalingQuantizer(
+            fp8_dtype=tex.DType.kFloat8E5M2,
+            device="cuda",
+        )
+    qkv_layout = "_".join([qkv_format] * 3)
+    q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]]
+    if fp8_mha:
+        q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer)
+    for x in [q, k, v]:
+        x.requires_grad = True
 
     if config.attn_bias_type not in ["no_bias", "alibi"]:
         attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
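
Illustration, not part of the diff: a standalone quantize/dequantize round trip with the two quantizer flavors constructed above. As in the test, E4M3 is used for activations and E5M2 for gradients; with delayed scaling the caller owns the scale/amax buffers, while current scaling derives the scale from the tensor being quantized. Shapes are arbitrary and a CUDA device is assumed.

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
)

x = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda")

# Delayed scaling: scale/amax live in caller-provided buffers, initialized as in the test.
delayed_quantizer = Float8Quantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    scale=torch.tensor([1], dtype=torch.float32).cuda(),
    amax=torch.tensor([0], dtype=torch.float32).cuda(),
)
# Current scaling: no caller-managed scale/amax; the scale is computed from the data.
current_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")

x_fp8_delayed = delayed_quantizer(x)   # quantize, as dout_quantizer(dout) does in the test
x_fp8_current = current_quantizer(x)
x_back = x_fp8_current.dequantize()    # back to high precision, as done before comparing results
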
@@ -274,6 +310,7 @@ def run_dpa_with_cp(
     else:
         fp8_context = nullcontext()
     with fp8_context:
+        # q, k, v, out in FP8; dout in F16
         out = core_attn(
             q,
             k,
@@ -284,8 +321,9 @@ def run_dpa_with_cp(
             cu_seqlens_kv=cu_seqlens_kv,
             cu_seqlens_q_padded=cu_seqlens_q_padded,
             cu_seqlens_kv_padded=cu_seqlens_kv_padded,
+            fp8_output=fp8_mha,
         )
-        if fp8_mha:
+        if fp8_bwd and fp8_mha:
             dout_fp8 = dout_quantizer(dout)
             out.backward(dout_fp8)
         else:
@@ -298,24 +336,10 @@ def run_dpa_with_cp(
     ############ run with CP ############
     logging.info(f"[Rank {rank}] Run with context parallelism")
 
-    # set up environment
-    core_attn.set_context_parallel_group(
-        cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
-        cp_comm_ranks,
-        torch.cuda.Stream(),
-        cp_comm_type,
-    )
-    if config.softmax_type != "vanilla":
-        core_attn.softmax_offset.grad.zero_()
-    if dtype == "fp8":
-        core_attn.reset_fp8_meta_tensors()
-        fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
-    else:
-        fp8_context = nullcontext()
-
     # set up inputs
     q_, k_, v_, dout_, *rest = [
-        x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])
+        x.clone().detach()
+        for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias])
     ]
     bias_ = rest[0] if len(rest) else None
     if qkv_format == "bshd" or qkv_format == "sbhd":
@@ -343,16 +367,42 @@ def run_dpa_with_cp(
         )
         q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
         k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
+    else:
+        assert False, f"{qkv_format} is an unsupported qkv_format!"
+    q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]]
+    if scaling_mode == "delayed":
+        qkv_quantizer.scale.fill_(1.0)
+        qkv_quantizer.amax.fill_(0.0)
+        dout_quantizer.scale.fill_(1.0)
+        dout_quantizer.amax.fill_(0.0)
+    if fp8_mha:
+        q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer)
     q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
     if bias_ is not None:
         bias_ = bias_.view(
             *bias_.shape[:-2], 2 * world_size, bias_.shape[-2] // (2 * world_size), bias_.shape[-1]
         )
         bias_ = bias_.index_select(2, seq_idx)
         bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
+    # set up environment
+    core_attn.set_context_parallel_group(
+        cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
+        cp_comm_ranks,
+        torch.cuda.Stream(),
+        cp_comm_type,
+    )
+    if config.softmax_type != "vanilla":
+        core_attn.softmax_offset.grad.zero_()
+    if dtype == "fp8":
+        core_attn.fp8_initialized = False
+        core_attn.fp8_meta_tensors_initialized = False
+        fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
+    else:
+        fp8_context = nullcontext()
 
     # run attention
     with fp8_context:
+        # q, k, v, out in FP8; dout in F16
         out_ = core_attn(
             q_,
             k_,
@@ -363,27 +413,30 @@ def run_dpa_with_cp(
             cu_seqlens_kv=cu_seqlens_kv,
             cu_seqlens_q_padded=cu_seqlens_q_padded,
             cu_seqlens_kv_padded=cu_seqlens_kv_padded,
+            fp8_output=fp8_mha,
         )
-        if fp8_mha:
+        if fp8_bwd and fp8_mha:
             dout_fp8_ = dout_quantizer(dout_)
             out_.backward(dout_fp8_)
         else:
             out_.backward(dout_)
-    if fp8_mha:
-        assert isinstance(out, Float8Tensor)
-        assert isinstance(out_, Float8Tensor)
-        out = out.dequantize()
-        out_ = out_.dequantize()
-
-    # get outputs
     dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
     d_softmax_offset_ = None
     if config.softmax_type != "vanilla":
         d_softmax_offset_ = core_attn.softmax_offset.grad.clone()
-    for x in [out_, dq_, dk_, dv_, d_softmax_offset_]:
-        if x is not None:
-            assert torch.all(~torch.isnan(x))
-            assert torch.all(~torch.isinf(x))
+
+    # get outputs
+    tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_]
+    if fp8_mha:
+        tensors_to_deq = [out, out_] if not fp8_bwd else tensors
+        for i, tensor in enumerate(tensors_to_deq):
+            tensors_to_deq[i] = tensor.dequantize()
+        if not fp8_bwd:
+            tensors[0], tensors[4] = tensors_to_deq
+    for tensor in tensors:
+        assert torch.all(~torch.isnan(tensor))
+        assert torch.all(~torch.isinf(tensor))
+    out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors
 
     ############ compare results between CP and no-CP ############
     if qkv_format == "bshd" or qkv_format == "sbhd":
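
A compact restatement of the dequantize-before-compare rule introduced above (a sketch under the same assumptions, not code from the PR): with fp8_mha but no FP8 backward, only the two attention outputs are Float8Tensors and only those are dequantized; with FP8 backward enabled, the gradients are FP8 as well and all tensors are dequantized before the NaN/Inf checks and the CP vs. no-CP comparison. dequantize_for_compare is a hypothetical helper name.

def dequantize_for_compare(out, dq, dk, dv, out_, dq_, dk_, dv_, fp8_mha, fp8_bwd):
    """Mirror the test's dequantization rule and return plain high-precision tensors."""
    tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_]
    if not fp8_mha:
        return tensors
    if fp8_bwd:
        # Forward outputs and gradients are all FP8.
        return [t.dequantize() for t in tensors]
    # Only the forward outputs (positions 0 and 4) are FP8.
    tensors[0] = tensors[0].dequantize()
    tensors[4] = tensors[4].dequantize()
    return tensors
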
@@ -394,17 +447,17 @@ def run_dpa_with_cp(
                 x.shape[seq_dim] // (2 * world_size),
                 *x.shape[(seq_dim + 1) :],
             )
-            for x in [q.grad, k.grad, v.grad, out]
+            for x in [dq, dk, dv, out]
         ]
         dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
         dq_, dk_, dv_, out_ = [
             x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
-            for x in [q_.grad, k_.grad, v_.grad, out_]
+            for x in [dq_, dk_, dv_, out_]
         ]
     elif qkv_format == "thd":
-        dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]]
-        dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]]
-        dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_]
+        dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
+        dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
+        dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_]
         cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
         cu_seqlens_q = get_cu_seqlens_on_cp_rank(
             cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True