
Commit 3e34552

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent c4170ad commit 3e34552

File tree

6 files changed: 63 additions & 22 deletions


tests/pytorch/attention/run_attention_with_cp.py

Lines changed: 14 additions & 7 deletions
@@ -90,17 +90,22 @@ def generate_input_shapes(
         cu_seqlens_kv_padded = None
     elif qkv_format == "thd":
         # seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
-        seqlens_q = torch.ones([config.batch_size], dtype=torch.int32).to(torch.int32) * config.max_seqlen_q
+        seqlens_q = (
+            torch.ones([config.batch_size], dtype=torch.int32).to(torch.int32) * config.max_seqlen_q
+        )
         seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
         cu_seqlens_q_padded = torch.cat(
             [
                 torch.zeros([1], dtype=torch.int32),
                 seqlens_q_padded.cumsum(0, dtype=torch.int32),
-                #torch.tensor([q_input_shape[0]], dtype=torch.int32),
+                # torch.tensor([q_input_shape[0]], dtype=torch.int32),
             ]
         ).cuda()
         cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
-        print(f"dev {torch.cuda.current_device()} cu_seqlens_q: {cu_seqlens_q}, cu_seqlens_q_padded: {cu_seqlens_q_padded}")
+        print(
+            f"dev {torch.cuda.current_device()} cu_seqlens_q: {cu_seqlens_q}, cu_seqlens_q_padded:"
+            f" {cu_seqlens_q_padded}"
+        )
         q_input_shape = (
             cu_seqlens_q_padded[-1],
             config.num_heads,
@@ -266,9 +271,9 @@ def run_dpa_with_cp(
         cu_seqlens_q_padded,
         cu_seqlens_kv_padded,
     ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
-    q_orig = torch.clamp(2*torch.ones(q_input_shape, dtype=dtypes[dtype]), min=-1, max=2).cuda()
-    k_orig = torch.clamp(2*torch.ones(k_input_shape, dtype=dtypes[dtype]), min=-1, max=2).cuda()
-    v_orig = torch.clamp(2*torch.ones(v_input_shape, dtype=dtypes[dtype]), min=-1, max=2).cuda()
+    q_orig = torch.clamp(2 * torch.ones(q_input_shape, dtype=dtypes[dtype]), min=-1, max=2).cuda()
+    k_orig = torch.clamp(2 * torch.ones(k_input_shape, dtype=dtypes[dtype]), min=-1, max=2).cuda()
+    v_orig = torch.clamp(2 * torch.ones(v_input_shape, dtype=dtypes[dtype]), min=-1, max=2).cuda()
     dout_orig = torch.clamp(
         torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1
     ).cuda()
@@ -448,7 +453,9 @@ def run_dpa_with_cp(
     if not fp8_bwd:
         tensors[0], tensors[4] = tensors_to_deq
     for i, tensor in enumerate(tensors):
-        print(f"dev {torch.cuda.current_device()} tensor {i} is nan: {torch.isnan(tensor).nonzero()}")
+        print(
+            f"dev {torch.cuda.current_device()} tensor {i} is nan: {torch.isnan(tensor).nonzero()}"
+        )
         # print(f"dev {torch.cuda.current_device()} tensor {i} is inf: {torch.isinf(tensor).non_zero()}")
         assert torch.all(~torch.isnan(tensor))
         assert torch.all(~torch.isinf(tensor))
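The first hunk above is only a line-length reflow, but the padding arithmetic it touches is easy to sanity-check in isolation. The sketch below uses hypothetical values for world_size, batch_size, and max_seqlen_q (none of these numbers come from the commit) and shows how each sequence length is rounded up to a multiple of 2 * world_size and turned into the cumulative offsets stored in cu_seqlens_q_padded.

import torch

world_size = 4      # hypothetical context-parallel group size
batch_size = 3      # hypothetical
max_seqlen_q = 100  # hypothetical

seqlens_q = torch.ones([batch_size], dtype=torch.int32) * max_seqlen_q
# Round each length up to the next multiple of 2 * world_size: 100 -> 104 here.
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
# Prepend a zero and take the running sum to get cumulative sequence offsets.
cu_seqlens_q_padded = torch.cat(
    [torch.zeros([1], dtype=torch.int32), seqlens_q_padded.cumsum(0, dtype=torch.int32)]
)
print(cu_seqlens_q_padded)  # tensor([  0, 104, 208, 312], dtype=torch.int32)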

tests/pytorch/attention/test_attention.py

Lines changed: 4 additions & 2 deletions
@@ -240,7 +240,7 @@ def test_dot_product_attention(
         if config.return_max_score:
             torch.testing.assert_close(fused_max_score, unfused_max_score, **tols)
         for i, _ in enumerate(unfused_attn_bwd):
-            print(f'iiiiii {i}')
+            print(f"iiiiii {i}")
             torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
     if fused_attn_supported and flash_attn_supported:
         logging.info("[test_dot_product_attention]: fused attn vs flash attn")
@@ -282,7 +282,9 @@ def test_dpa_max_score(dtype, model_configs, model):
     """Test DotProductAttention module with checkpointing"""
     config = model_configs[model]
     config.return_max_score = True
-    test_dot_product_attention(dtype, model_configs, model, False, True, "thd_thd_thd", False, False)
+    test_dot_product_attention(
+        dtype, model_configs, model, False, True, "thd_thd_thd", False, False
+    )


 model_configs_softmax = {

tests/pytorch/attention/test_attention_with_cp.py

Lines changed: 1 addition & 1 deletion
@@ -183,7 +183,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
     qkv_formats = ["bshd", "sbhd", "thd"]
     cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
     if test_essential:
-        configs = ["cp_1_0", "cp_1_1"] #, "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
+        configs = ["cp_1_0", "cp_1_1"]  # , "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
         model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
         dtypes = ["bf16", "fp8"]
         qkv_formats = ["sbhd", "thd"]

transformer_engine/common/fused_attn/utils.h

Lines changed: 1 addition & 2 deletions
@@ -122,8 +122,7 @@ struct FADescriptor_v1 {
                     page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
                     attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
                     window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
-                    o_tensor_type, do_tensor_type, dqkv_tensor_type,
-                    generate_max_sum_exp) <
+                    o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) <
            std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                     rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
                     rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
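The C++ hunk above only rewraps the argument list, but the pattern it relies on is worth noting: operator< is built by packing the struct's fields into std::tie tuples and comparing them lexicographically, so appending generate_max_sum_exp to both tuples extends the ordering with no new comparison logic. A rough Python analogue of that tuple-comparison idea, with a hypothetical and heavily trimmed field list that is not the real FADescriptor_v1:

from dataclasses import astuple, dataclass

@dataclass
class FADescriptorSketch:
    # Hypothetical stand-in for FADescriptor_v1's much longer field list.
    b: int
    h: int
    generate_max_sum_exp: bool

    def __lt__(self, rhs: "FADescriptorSketch") -> bool:
        # Tuples compare lexicographically, mirroring std::tie(...) < std::tie(rhs. ...).
        return astuple(self) < astuple(rhs)

assert FADescriptorSketch(1, 8, False) < FADescriptorSketch(1, 8, True)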

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py

Lines changed: 4 additions & 1 deletion
@@ -1617,7 +1617,10 @@ def forward(
                     softmax_lse_per_step[i - 1],
                 )
                 if return_max_score:
-                    print(f"dev={torch.cuda.current_device()} i={i}, max_score_per_step={max_score_per_step[i - 1]}")
+                    print(
+                        f"dev={torch.cuda.current_device()} i={i},"
+                        f" max_score_per_step={max_score_per_step[i - 1]}"
+                    )
                     if i == 1:
                         max_score = torch.clone(max_score_per_step[0])
                     else:
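The hunk stops at the else:, so the combine step for later iterations is not shown in this diff. Purely as an illustration, the sketch below assumes the per-step maxima are folded together with an elementwise running maximum (torch.maximum); the values and list are hypothetical.

import torch

# Hypothetical per-step maxima, e.g. one entry per attention head.
max_score_per_step = [torch.tensor([0.5, 2.0]), torch.tensor([1.5, 1.0])]

max_score = torch.clone(max_score_per_step[0])
for step_max in max_score_per_step[1:]:
    # Assumed combine: keep the elementwise running maximum across steps.
    max_score = torch.maximum(max_score, step_max)

print(max_score)  # tensor([1.5000, 2.0000])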

transformer_engine/pytorch/cpp_extensions/fused_attn.py

Lines changed: 39 additions & 9 deletions
@@ -324,8 +324,12 @@ def fused_attn_fwd(
     )

     if return_max_score:
-        qkv_format = qkv_layout.replace("3","").replace("2","").split("_")[0]
-        print(f"dev {torch.cuda.current_device()} qkv_format: {qkv_format}, cu_seqlens_q: {cu_seqlens_q}, cu_seqlens_kv: {cu_seqlens_kv}, cu_seqlens_q_padded: {cu_seqlens_q_padded}, cu_seqlens_kv_padded: {cu_seqlens_kv_padded}")
+        qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
+        print(
+            f"dev {torch.cuda.current_device()} qkv_format: {qkv_format}, cu_seqlens_q:"
+            f" {cu_seqlens_q}, cu_seqlens_kv: {cu_seqlens_kv}, cu_seqlens_q_padded:"
+            f" {cu_seqlens_q_padded}, cu_seqlens_kv_padded: {cu_seqlens_kv_padded}"
+        )
         # print(f"dev {torch.cuda.current_device()} q: {q.shape}, k: {k.shape}, v: {v.shape}")
         # print(f"dev {torch.cuda.current_device()} output_tensors[0] is nan: {torch.isnan(output_tensors[0]).sum()}, output_tensors[0]: {output_tensors[0].shape}, output_tensors[0].min(): {output_tensors[0].min()}, output_tensors[0].max(): {output_tensors[0].max()}")
         # print(f"dev {torch.cuda.current_device()} output_tensors[1] is nan: {torch.isnan(output_tensors[1]).sum()}, output_tensors[1]: {output_tensors[1].shape}, output_tensors[1].min(): {output_tensors[1].min()}, output_tensors[1].max(): {output_tensors[1].max()}")
@@ -335,7 +339,10 @@ def fused_attn_fwd(
             stats = output_tensors[1] + torch.log(output_tensors[2])
             zero_indices_1 = (output_tensors[1] == 0).nonzero()
             zero_indices_2 = (output_tensors[2] == 0).nonzero()
-            print(f"dev {torch.cuda.current_device()} zero_indices_1: {zero_indices_1}, zero_indices_2: {zero_indices_2}")
+            print(
+                f"dev {torch.cuda.current_device()} zero_indices_1: {zero_indices_1},"
+                f" zero_indices_2: {zero_indices_2}"
+            )
             if torch.cuda.current_device() == 0 and not os.path.exists("output_tensors1.pt"):
                 torch.save(output_tensors[1], "output_tensors1.pt")
                 torch.save(output_tensors[2], "output_tensors2.pt")
@@ -344,16 +351,39 @@ def fused_attn_fwd(
             # Max [tq, h, 1] -> max_score [h]
             max_score = torch.amax(output_tensors[1], dim=(0, 2)).to(dtype=output_tensors[0].dtype)
             print(f"dev {torch.cuda.current_device()} max_score: {max_score}")
-            print(f"dev {torch.cuda.current_device()} output_tensors[0] is nan: {torch.isnan(output_tensors[0]).sum()}, output_tensors[0]: {output_tensors[0].shape}, output_tensors[0].min(): {output_tensors[0].min()}, output_tensors[0].max(): {output_tensors[0].max()}")
-            print(f"dev {torch.cuda.current_device()} output_tensors[1] is nan: {torch.isnan(output_tensors[1]).sum()}, output_tensors[1]: {output_tensors[1].shape}, output_tensors[1].min(): {output_tensors[1].min()}, output_tensors[1].max(): {output_tensors[1].max()}")
-            print(f"dev {torch.cuda.current_device()} output_tensors[2] is nan: {torch.isnan(output_tensors[2]).sum()}, output_tensors[2]: {output_tensors[2].shape}, output_tensors[2].min(): {output_tensors[2].min()}, output_tensors[2].max(): {output_tensors[2].max()}")
-            print(f"dev {torch.cuda.current_device()} stats is nan: {torch.isnan(stats).sum()}, stats: {stats.shape}, stats.min(): {stats.min()}, stats.max(): {stats.max()}")
-            print(f"dev {torch.cuda.current_device()} max_score is nan: {torch.isnan(max_score).sum()}, max_score: {max_score.shape} ")
+            print(
+                f"dev {torch.cuda.current_device()} output_tensors[0] is nan:"
+                f" {torch.isnan(output_tensors[0]).sum()}, output_tensors[0]:"
+                f" {output_tensors[0].shape}, output_tensors[0].min(): {output_tensors[0].min()},"
+                f" output_tensors[0].max(): {output_tensors[0].max()}"
+            )
+            print(
+                f"dev {torch.cuda.current_device()} output_tensors[1] is nan:"
+                f" {torch.isnan(output_tensors[1]).sum()}, output_tensors[1]:"
+                f" {output_tensors[1].shape}, output_tensors[1].min(): {output_tensors[1].min()},"
+                f" output_tensors[1].max(): {output_tensors[1].max()}"
+            )
+            print(
+                f"dev {torch.cuda.current_device()} output_tensors[2] is nan:"
+                f" {torch.isnan(output_tensors[2]).sum()}, output_tensors[2]:"
+                f" {output_tensors[2].shape}, output_tensors[2].min(): {output_tensors[2].min()},"
+                f" output_tensors[2].max(): {output_tensors[2].max()}"
+            )
+            print(
+                f"dev {torch.cuda.current_device()} stats is nan: {torch.isnan(stats).sum()},"
+                f" stats: {stats.shape}, stats.min(): {stats.min()}, stats.max(): {stats.max()}"
+            )
+            print(
+                f"dev {torch.cuda.current_device()} max_score is nan:"
+                f" {torch.isnan(max_score).sum()}, max_score: {max_score.shape} "
+            )
         else:
             # output_tensors: out [b, sq, h, d] or [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
             stats = output_tensors[1] + torch.log(output_tensors[2])
             # Max [b, h, sq, 1] -> max_score [h]
-            max_score = torch.amax(output_tensors[1], dim=(0, 2, 3)).to(dtype=output_tensors[0].dtype)
+            max_score = torch.amax(output_tensors[1], dim=(0, 2, 3)).to(
+                dtype=output_tensors[0].dtype
+            )
         aux_ctx_tensors = [stats]
         aux_ctx_tensors.extend(output_tensors[3:])
         return output_tensors[0], aux_ctx_tensors, max_score
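The stats tensor built in this file relies on the standard softmax-statistics identity: if Max holds the row-wise maximum of the attention logits and Sum_Exp holds the sum of exp(logits - Max) (as the names suggest; the kernel's exact convention is not visible in this diff), then Max + log(Sum_Exp) is exactly logsumexp over the logits. A small self-contained check with random, hypothetical logits:

import torch

logits = torch.randn(4, 8)                         # hypothetical attention logits
row_max = logits.max(dim=-1, keepdim=True).values  # plays the role of "Max"
sum_exp = torch.exp(logits - row_max).sum(dim=-1, keepdim=True)  # plays the role of "Sum_Exp"

stats = row_max + torch.log(sum_exp)
# Matches logsumexp computed directly over the logits.
torch.testing.assert_close(stats, torch.logsumexp(logits, dim=-1, keepdim=True))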
