
Commit caebcfe

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent eadf8d6 commit caebcfe

File tree: 2 files changed, 67 additions and 31 deletions

tests/pytorch/attention/test_attention_with_cp.py

Lines changed: 4 additions & 1 deletion
@@ -228,7 +228,10 @@ def test_cp_with_fused_attention(
     if cp_comm_type == "all_gather":
         pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
     if cp_comm_type == "a2a+p2p":
-        pytest.skip("CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format yet!")
+        pytest.skip(
+            "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
+            " yet!"
+        )
     if dtype == "fp8" and cp_comm_type == "all_gather":
         pytest.skip(
             "CP implementation with KV all-gather does not support FP8 + context parallelism yet!"

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py

Lines changed: 63 additions & 30 deletions
@@ -259,7 +259,7 @@ def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size
     return x
 
 
-def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim = 0):
+def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim=0):
     """
     Reorder sequence chunks for A2A communication that happens after attention
     compute.
@@ -298,22 +298,19 @@ def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim
     ]
     """
     total_slices_of_any_sequence = 2 * cp_size
-    slice_sizes = (
-        cu_seqlens[1:] - cu_seqlens[:-1]
-    ) // total_slices_of_any_sequence
+    slice_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]) // total_slices_of_any_sequence
 
     indices = [
         (
             # 1st segment
             torch.arange(
-                seq_start + (cp_rank * slice_size),
-                seq_start + ((cp_rank + 1) * slice_size)
+                seq_start + (cp_rank * slice_size), seq_start + ((cp_rank + 1) * slice_size)
             ),
             # 2nd segment
             torch.arange(
                 seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
-                seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size)
-            )
+                seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
+            ),
         )
         for cp_rank in range(cp_size)
         for slice_size, seq_start in zip(slice_sizes, cu_seqlens[:-1])
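To see what the reformatted comprehension builds, here is a toy walkthrough in the same spirit; the cp_size and cu_seqlens values and the final torch.cat flattening are illustrative assumptions, while the index arithmetic is copied from the hunk above:

import torch

cp_size = 2
cu_seqlens = torch.tensor([0, 8, 12])  # assumed: two sequences of lengths 8 and 4

total_slices_of_any_sequence = 2 * cp_size  # every sequence is cut into 2*cp_size chunks
slice_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]) // total_slices_of_any_sequence  # [2, 1]

pairs = [
    (
        # chunk held by cp_rank in the "low" half of the sequence
        torch.arange(seq_start + cp_rank * slice_size, seq_start + (cp_rank + 1) * slice_size),
        # mirrored chunk held by cp_rank in the "high" half of the sequence
        torch.arange(
            seq_start + (total_slices_of_any_sequence - cp_rank - 1) * slice_size,
            seq_start + (total_slices_of_any_sequence - cp_rank) * slice_size,
        ),
    )
    for cp_rank in range(cp_size)
    for slice_size, seq_start in zip(slice_sizes, cu_seqlens[:-1])
]
gather_order = torch.cat([seg for pair in pairs for seg in pair])  # flattening assumed here
print(gather_order.tolist())  # [0, 1, 6, 7, 8, 11, 2, 3, 4, 5, 9, 10]
# Tokens owned by rank 0 (chunks 0 and 3 of each sequence) now come first, then rank 1's
# tokens (chunks 1 and 2), so the A2A can split the t dimension evenly across ranks.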
@@ -325,9 +322,7 @@ def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim
     return x.index_select(seq_dim, indices)
 
 
-def reorder_seq_chunks_after_a2a_before_attn_thd(
-    x, cu_seqlens, seq_chunk_ids, cp_size, seq_dim = 0
-):
+def reorder_seq_chunks_after_a2a_before_attn_thd(x, cu_seqlens, seq_chunk_ids, cp_size, seq_dim=0):
     """
     Reorder sequence chunks for A2A communication that happens before attention
     compute.
@@ -374,17 +369,23 @@ def reorder_seq_chunks_after_a2a_before_attn_thd(
     """
 
     max_cum_seqlen_per_cp_rank = cu_seqlens[-1] // cp_size
-    cu_seqlens_on_any_cp_rank = cu_seqlens//cp_size
+    cu_seqlens_on_any_cp_rank = cu_seqlens // cp_size
 
     # Go through all the sequence segments (the sizes should be the same from all the ranks)
     indices = [
         torch.arange(
             # Calculate 'left' boundary
-            start + max_cum_seqlen_per_cp_rank * (chunk_id // 2) if loc < cp_size
-            else (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2),
+            (
+                start + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
+                if loc < cp_size
+                else (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
+            ),
             # Calculate 'right' boundary
-            (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2) if loc < cp_size
-            else end + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
+            (
+                (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
+                if loc < cp_size
+                else end + max_cum_seqlen_per_cp_rank * (chunk_id // 2)
+            ),
         )
         for start, end in zip(cu_seqlens_on_any_cp_rank[:-1], cu_seqlens_on_any_cp_rank[1:])
         for loc, chunk_id in enumerate(seq_chunk_ids)
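A quick sanity check of the two per-rank quantities above, reusing the same assumed toy values (this presumes every sequence length in cu_seqlens divides evenly by 2 * cp_size, which the chunking above already implies):

import torch

cp_size = 2
cu_seqlens = torch.tensor([0, 8, 12])  # same assumed toy values as above

max_cum_seqlen_per_cp_rank = cu_seqlens[-1] // cp_size  # 6 tokens held per rank
cu_seqlens_on_any_cp_rank = cu_seqlens // cp_size  # tensor([0, 4, 6])
# Per rank, sequence 0 contributes 4 tokens (two chunks of 2) and sequence 1 contributes
# 2 tokens (two chunks of 1), matching the gather order printed in the earlier sketch.
print(max_cum_seqlen_per_cp_rank.item(), cu_seqlens_on_any_cp_rank.tolist())  # 6 [0, 4, 6]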
@@ -428,8 +429,10 @@ def flash_attn_a2a_communicate(
                         )
                         # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]
                         # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
-                        a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
-                    else: # qkv_format == "thd"
+                        a2a_outputs[i - 2] = x.view(
+                            *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]
+                        )
+                    else:  # qkv_format == "thd"
                         # [cp, t, np//cp, hn] -> [cp*t, np//cp, hn]
                         x = x.view(-1, *x.shape[2:])
                         # reorder the sequence chunks
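The wrapped view() merges the two leading sequence dimensions exactly as the shape comments describe. A minimal shape check with assumed toy sizes (b, cp, s, np//cp, hn below are illustrative):

import torch

b, cp, s, np_cp, hn = 2, 2, 8, 4, 16  # assumed toy sizes
seq_dim = 1  # "bshd" layout: the sequence dimension sits right after batch
x = torch.randn(b, cp * 2, s // 2, np_cp, hn)
y = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
assert y.shape == (b, cp * s, np_cp, hn)  # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]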
@@ -464,7 +467,7 @@
                         a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn(
                             x, chunk_ids_for_a2a, seq_dim, cp_size
                         )
-                    else: # qkv_format == "thd"
+                    else:  # qkv_format == "thd"
                         # reorder the sequence chunks
                         x = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size)
                         # [cp*t, np//cp, hn] -> [cp, t, np//cp, hn]
@@ -3244,7 +3247,9 @@ def forward(
 
         causal = "causal" in attn_mask_type
         padding = "padding" in attn_mask_type
-        assert not padding or qkv_format == "thd", f"{attn_mask_type} mask type is not supported for BSHD and SBHD!"
+        assert (
+            not padding or qkv_format == "thd"
+        ), f"{attn_mask_type} mask type is not supported for BSHD and SBHD!"
         assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
         assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
         assert (
@@ -3300,7 +3305,7 @@
         if qkv_format in ["bshd", "sbhd"]:
             batch_dim = qkv_format.index("b")
             seq_dim = qkv_format.index("s")
-        else: # qkv_format == "thd"
+        else:  # qkv_format == "thd"
             batch_dim = seq_dim = qkv_format.index("t")
 
         assert (
@@ -3347,8 +3352,15 @@
 
         chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device)
         q, k, v = flash_attn_a2a_communicate(
-            [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream,
-            before_attn=True, qkv_format=qkv_format, cu_seqlens=cu_seqlens_q_padded
+            [q, k, v],
+            chunk_ids_for_a2a,
+            seq_dim,
+            cp_size,
+            cp_group,
+            cp_stream,
+            before_attn=True,
+            qkv_format=qkv_format,
+            cu_seqlens=cu_seqlens_q_padded,
         )
         if softmax_type != "vanilla":
             softmax_offset = flash_attn_a2a_communicate_softmax_offset(
@@ -3437,8 +3449,15 @@
 
         chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device)
         out = flash_attn_a2a_communicate(
-            out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream,
-            before_attn=False, qkv_format=qkv_format, cu_seqlens=cu_seqlens_q_padded
+            out,
+            chunk_ids_for_a2a,
+            seq_dim,
+            cp_size,
+            cp_group,
+            cp_stream,
+            before_attn=False,
+            qkv_format=qkv_format,
+            cu_seqlens=cu_seqlens_q_padded,
         )
 
         if use_fused_attention:
@@ -3556,7 +3575,7 @@ def backward(ctx, dout):
 
         if qkv_format in ["bshd", "sbhd"]:
             seq_dim = qkv_format.index("s")
-        else: # qkv_format == "thd"
+        else:  # qkv_format == "thd"
             seq_dim = qkv_format.index("t")
 
         bwd_nominal_dtype = ctx.fwd_nominal_dtype
@@ -3593,8 +3612,15 @@
 
         chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, out.device)
         out, dout = flash_attn_a2a_communicate(
-            [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream,
-            before_attn=True, qkv_format=qkv_format, cu_seqlens=cu_seqlens_q_padded
+            [out, dout],
+            chunk_ids_for_a2a,
+            seq_dim,
+            cp_size,
+            ctx.cp_group,
+            ctx.cp_stream,
+            before_attn=True,
+            qkv_format=qkv_format,
+            cu_seqlens=cu_seqlens_q_padded,
         )
 
         flash_attn_bwd = None
@@ -3704,8 +3730,15 @@
 
         chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device)
         dq, dk, dv = flash_attn_a2a_communicate(
-            [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream,
-            before_attn=False, qkv_format=qkv_format, cu_seqlens=cu_seqlens_q_padded
+            [dq, dk, dv],
+            chunk_ids_for_a2a,
+            seq_dim,
+            cp_size,
+            ctx.cp_group,
+            ctx.cp_stream,
+            before_attn=False,
+            qkv_format=qkv_format,
+            cu_seqlens=cu_seqlens_q_padded,
         )
 
         if qkv_format == "bshd":
