@@ -259,7 +259,7 @@ def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size
     return x
 
 
-def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim = 0):
+def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim=0):
     """
     Reorder sequence chunks for A2A communication that happens after attention
     compute.
@@ -298,22 +298,19 @@ def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim
     ]
     """
     total_slices_of_any_sequence = 2 * cp_size
-    slice_sizes = (
-        cu_seqlens[1:] - cu_seqlens[:-1]
-    ) // total_slices_of_any_sequence
+    slice_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]) // total_slices_of_any_sequence
 
     indices = [
         (
             # 1st segment
             torch.arange(
-                seq_start + (cp_rank * slice_size),
-                seq_start + ((cp_rank + 1) * slice_size)
+                seq_start + (cp_rank * slice_size), seq_start + ((cp_rank + 1) * slice_size)
             ),
             # 2nd segment
             torch.arange(
                 seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
-                seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size)
-            )
+                seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
+            ),
         )
         for cp_rank in range(cp_size)
         for slice_size, seq_start in zip(slice_sizes, cu_seqlens[:-1])
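To make the gather order above concrete, here is a minimal toy run (illustrative only, not part of this diff; it assumes torch and the function defined above are importable). With cp_size = 2 and a single 8-token sequence, each sequence is cut into 2 * cp_size = 4 slices, and rank r holds slices r and 2 * cp_size - 1 - r for load balancing, so each rank's two slices come back adjacent:

    import torch

    cu_seqlens = torch.tensor([0, 8])
    cp_size = 2
    x = torch.arange(8)  # stand-in for the token dimension of a thd tensor

    # rank 0 contributed slices 0 and 3, rank 1 contributed slices 1 and 2
    out = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size)
    print(out)  # tensor([0, 1, 6, 7, 2, 3, 4, 5])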
@@ -325,9 +322,7 @@ def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim
     return x.index_select(seq_dim, indices)
 
 
-def reorder_seq_chunks_after_a2a_before_attn_thd(
-    x, cu_seqlens, seq_chunk_ids, cp_size, seq_dim=0
-):
+def reorder_seq_chunks_after_a2a_before_attn_thd(x, cu_seqlens, seq_chunk_ids, cp_size, seq_dim=0):
     """
     Reorder sequence chunks for A2A communication that happens before attention
     compute.
@@ -374,17 +369,23 @@ def reorder_seq_chunks_after_a2a_before_attn_thd(
     """
 
     max_cum_seqlen_per_cp_rank = cu_seqlens[-1] // cp_size
-    cu_seqlens_on_any_cp_rank  = cu_seqlens // cp_size
+    cu_seqlens_on_any_cp_rank = cu_seqlens // cp_size
 
     # Go through all the sequence segments (the sizes should be the same from all the ranks)
     indices = [
         torch.arange(
             # Calculate 'left' boundary
-            start + max_cum_seqlen_per_cp_rank * (chunk_id // 2) if loc < cp_size
-            else (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2),
+            (
+                start + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
+                if loc < cp_size
+                else (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
+            ),
             # Calculate 'right' boundary
-            (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2) if loc < cp_size
-            else end + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
+            (
+                (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
+                if loc < cp_size
+                else end + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
+            ),
         )
         for start, end in zip(cu_seqlens_on_any_cp_rank[:-1], cu_seqlens_on_any_cp_rank[1:])
        for loc, chunk_id in enumerate(seq_chunk_ids)
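As a worked sketch of the left/right boundary arithmetic above (toy numbers only; the seq_chunk_ids value is assumed for this sketch, in the real code it comes from get_seq_chunk_ids_for_reordering_before_attn): with cp_size = 2 and global cu_seqlens = [0, 8], each rank holds max_cum_seqlen_per_cp_rank = 4 tokens and cu_seqlens_on_any_cp_rank = [0, 4]:

    cp_size = 2
    max_cum = 4                   # cu_seqlens[-1] // cp_size
    cu_per_rank = [0, 4]          # cu_seqlens // cp_size
    seq_chunk_ids = [0, 2, 3, 1]  # assumed chunk order for this sketch

    for start, end in zip(cu_per_rank[:-1], cu_per_rank[1:]):
        for loc, chunk_id in enumerate(seq_chunk_ids):
            half = (start + end) // 2
            off = max_cum * (chunk_id // 2)
            left = start + off if loc < cp_size else half + off
            right = half + off if loc < cp_size else end + off
            print(f"chunk {chunk_id}: tokens [{left}, {right})")
    # chunk 0: tokens [0, 2)
    # chunk 2: tokens [4, 6)
    # chunk 3: tokens [6, 8)
    # chunk 1: tokens [2, 4)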
@@ -428,8 +429,10 @@ def flash_attn_a2a_communicate(
                 )
                 # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]
                 # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
-                a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
-            else: # qkv_format == "thd"
+                a2a_outputs[i - 2] = x.view(
+                    *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]
+                )
+            else:  # qkv_format == "thd"
                 # [cp, t, np//cp, hn] -> [cp*t, np//cp, hn]
                 x = x.view(-1, *x.shape[2:])
                 # reorder the sequence chunks
@@ -464,7 +467,7 @@ def flash_attn_a2a_communicate(
                 a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn(
                     x, chunk_ids_for_a2a, seq_dim, cp_size
                 )
-            else: # qkv_format == "thd"
+            else:  # qkv_format == "thd"
                 # reorder the sequence chunks
                 x = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size)
                 # [cp*t, np//cp, hn] -> [cp, t, np//cp, hn]
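The thd branch relies on simple shape bookkeeping: the cp dimension is folded into, or split back out of, the token dimension around each reorder, exactly as the shape comments above describe. A minimal sketch with assumed toy sizes:

    import torch

    cp_size, t, np_per_cp, hn = 2, 8, 4, 16
    x = torch.randn(cp_size, t, np_per_cp, hn)  # [cp, t, np//cp, hn] after A2A
    x = x.view(-1, *x.shape[2:])                # [cp*t, np//cp, hn] for attention
    x = x.view(cp_size, -1, *x.shape[1:])       # back to [cp, t, np//cp, hn] for A2A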
@@ -3244,7 +3247,9 @@ def forward(
 
         causal = "causal" in attn_mask_type
         padding = "padding" in attn_mask_type
-        assert not padding or qkv_format == "thd", f"{attn_mask_type} mask type is not supported for BSHD and SBHD!"
+        assert (
+            not padding or qkv_format == "thd"
+        ), f"{attn_mask_type} mask type is not supported for BSHD and SBHD!"
         assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
         assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
         assert (
@@ -3300,7 +3305,7 @@ def forward(
         if qkv_format in ["bshd", "sbhd"]:
             batch_dim = qkv_format.index("b")
             seq_dim = qkv_format.index("s")
-        else: # qkv_format == "thd"
+        else:  # qkv_format == "thd"
             batch_dim = seq_dim = qkv_format.index("t")
 
         assert (
@@ -3347,8 +3352,15 @@ def forward(
 
         chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device)
         q, k, v = flash_attn_a2a_communicate(
-            [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream,
-            before_attn=True, qkv_format=qkv_format, cu_seqlens=cu_seqlens_q_padded
+            [q, k, v],
+            chunk_ids_for_a2a,
+            seq_dim,
+            cp_size,
+            cp_group,
+            cp_stream,
+            before_attn=True,
+            qkv_format=qkv_format,
+            cu_seqlens=cu_seqlens_q_padded,
         )
         if softmax_type != "vanilla":
             softmax_offset = flash_attn_a2a_communicate_softmax_offset(
@@ -3437,8 +3449,15 @@ def forward(
 
         chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device)
         out = flash_attn_a2a_communicate(
-            out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream,
-            before_attn=False, qkv_format=qkv_format, cu_seqlens=cu_seqlens_q_padded
+            out,
+            chunk_ids_for_a2a,
+            seq_dim,
+            cp_size,
+            cp_group,
+            cp_stream,
+            before_attn=False,
+            qkv_format=qkv_format,
+            cu_seqlens=cu_seqlens_q_padded,
         )
 
         if use_fused_attention:
@@ -3556,7 +3575,7 @@ def backward(ctx, dout):
 
         if qkv_format in ["bshd", "sbhd"]:
             seq_dim = qkv_format.index("s")
-        else: # qkv_format == "thd"
+        else:  # qkv_format == "thd"
             seq_dim = qkv_format.index("t")
 
         bwd_nominal_dtype = ctx.fwd_nominal_dtype
@@ -3593,8 +3612,15 @@ def backward(ctx, dout):
 
         chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, out.device)
         out, dout = flash_attn_a2a_communicate(
-            [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream,
-            before_attn=True, qkv_format=qkv_format, cu_seqlens=cu_seqlens_q_padded
+            [out, dout],
+            chunk_ids_for_a2a,
+            seq_dim,
+            cp_size,
+            ctx.cp_group,
+            ctx.cp_stream,
+            before_attn=True,
+            qkv_format=qkv_format,
+            cu_seqlens=cu_seqlens_q_padded,
         )
 
         flash_attn_bwd = None
@@ -3704,8 +3730,15 @@ def backward(ctx, dout):
 
         chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device)
         dq, dk, dv = flash_attn_a2a_communicate(
-            [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream,
-            before_attn=False, qkv_format=qkv_format, cu_seqlens=cu_seqlens_q_padded
+            [dq, dk, dv],
+            chunk_ids_for_a2a,
+            seq_dim,
+            cp_size,
+            ctx.cp_group,
+            ctx.cp_stream,
+            before_attn=False,
+            qkv_format=qkv_format,
+            cu_seqlens=cu_seqlens_q_padded,
         )
 
         if qkv_format == "bshd":