Skip to content

Commit 7a7cbdb

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 84a67b3 commit 7a7cbdb

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

transformer_engine/common/fused_attn/utils.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,7 @@ struct FADescriptor_v1 {
122122
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
123123
attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
124124
window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
125-
o_tensor_type, do_tensor_type, dqkv_tensor_type,
126-
generate_max_sum_exp) <
125+
o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) <
127126
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
128127
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
129128
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,

transformer_engine/pytorch/cpp_extensions/fused_attn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def fused_attn_fwd(
323323
)
324324

325325
if return_max_score:
326-
qkv_format = qkv_layout.replace("3","").replace("2","").split("_")[0]
326+
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
327327
# thd: output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1]
328328
# bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
329329
# sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]

0 commit comments

Comments (0)