
Commit 84a67b3

WIP: add thd support
Signed-off-by: Charlene Yang <[email protected]>
1 parent: 326a54c

7 files changed: +31 -24 lines


transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

Lines changed: 16 additions & 9 deletions
@@ -324,18 +324,31 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   }

   std::shared_ptr<fe::graph::Tensor_attributes> Max, Sum_Exp;
+  if (is_ragged_q && cudnn_runtime_version >= 90600) {
+    offset_stats =
+        mha_graph->tensor(fe::graph::Tensor_attributes()
+                              .set_name("offset_stats")
+                              .set_dim({b + 1, 1, 1, 1})
+                              .set_stride({1, 1, 1, 1})
+                              .set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
+  }
   if (return_max_score) {
     Max = mha_graph->tensor(fe::graph::Tensor_attributes()
                                 .set_name("Max")
                                 .set_dim({b, h, s_q, 1})
-                                .set_stride({h * s_q, s_q, 1, 1})
                                 .set_data_type(fe::DataType_t::FLOAT));
-    sdpa_options.set_logit_max(Max);
     Sum_Exp = mha_graph->tensor(fe::graph::Tensor_attributes()
                                     .set_name("Sum_Exp")
                                     .set_dim({b, h, s_q, 1})
-                                    .set_stride({h * s_q, s_q, 1, 1})
                                     .set_data_type(fe::DataType_t::FLOAT));
+    if (is_ragged_q && cudnn_runtime_version >= 90600) {
+      Max->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
+      Sum_Exp->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
+    } else {
+      Max->set_stride({h * s_q, s_q, 1, 1});
+      Sum_Exp->set_stride({h * s_q, s_q, 1, 1});
+    }
+    sdpa_options.set_logit_max(Max);
     sdpa_options.set_score_sum_exp(Sum_Exp);
   }

@@ -357,12 +370,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   if (!return_max_score) {
     Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
     if (is_ragged_q && cudnn_runtime_version >= 90600) {
-      offset_stats =
-          mha_graph->tensor(fe::graph::Tensor_attributes()
-                                .set_name("offset_stats")
-                                .set_dim({b + 1, 1, 1, 1})
-                                .set_stride({1, 1, 1, 1})
-                                .set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
       Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
     } else {
       Stats->set_stride({h * s_q, s_q, 1, 1});
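Note on the layouts above (not part of the commit): the dense path keeps the softmax-stats tensors (Max, Sum_Exp, Stats) contiguous as [b, h, s_q, 1] with stride {h * s_q, s_q, 1, 1}, while the ragged (thd) path uses stride {h * s_q, 1, h, 1} plus the per-sequence offset_stats tensor. A minimal Python sketch of the two addressing schemes follows; the helper names are invented, and treating the ragged offsets as element offsets of each sequence's start is an assumption.

# Illustrative only: how element (b_i, h_i, s_i, 0) of the stats tensors is located
# under the two layouts configured above. Names and offset units are assumptions.
def dense_stats_index(b_i, h_i, s_i, h, s_q):
    # stride {h * s_q, s_q, 1, 1} over a dense [b, h, s_q, 1] tensor
    return b_i * h * s_q + h_i * s_q + s_i

def ragged_stats_index(b_i, h_i, s_i, h, seq_offsets):
    # stride {h * s_q, 1, h, 1} with a per-sequence ragged offset ({b + 1, 1, 1, 1}):
    # tokens of each sequence are packed back to back with heads interleaved,
    # i.e. a thd-style [t, h, 1] stats layout.
    return seq_offsets[b_i] + s_i * h + h_i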

transformer_engine/common/fused_attn/utils.h

Lines changed: 2 additions & 3 deletions
@@ -115,23 +115,22 @@ struct FADescriptor_v1 {
   cudnn_frontend::DataType_t o_tensor_type;
   cudnn_frontend::DataType_t do_tensor_type;
   cudnn_frontend::DataType_t dqkv_tensor_type;
-  bool generate_stats;
   bool generate_max_sum_exp;

   bool operator<(const FADescriptor_v1 &rhs) const {
     return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
                     page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
                     attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
                     window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
-                    o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_stats,
+                    o_tensor_type, do_tensor_type, dqkv_tensor_type,
                     generate_max_sum_exp) <
            std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                     rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
                     rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
                     rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
                     rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
                     rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
-                    rhs.dqkv_tensor_type, rhs.generate_stats, rhs.generate_max_sum_exp);
+                    rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
   }
 };

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 1 addition & 1 deletion
@@ -1350,7 +1350,7 @@ def forward(
         return out_ret

     @staticmethod
-    def backward(ctx, d_out, *args):
+    def backward(ctx, d_out, *_args):
         # pylint: disable=missing-function-docstring

         # d_out is expected to be in FP8 if is_output_fp8=True,

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py

Lines changed: 3 additions & 3 deletions
@@ -1845,7 +1845,7 @@ def forward(
         return out_ret

     @staticmethod
-    def backward(ctx, dout, *args):
+    def backward(ctx, dout, *_args):
         # pylint: disable=missing-function-docstring

         # add NVTX range
@@ -2864,7 +2864,7 @@ def forward(
         return out

     @staticmethod
-    def backward(ctx, dout, *args):
+    def backward(ctx, dout, *_args):
         # pylint: disable=missing-function-docstring
         nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")
         cp_size = get_distributed_world_size(ctx.cp_group)
@@ -3425,7 +3425,7 @@ def forward(
         return out_ret

     @staticmethod
-    def backward(ctx, dout, *args):
+    def backward(ctx, dout, *_args):
         # pylint: disable=missing-function-docstring
         nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")
         cp_size = get_distributed_world_size(ctx.cp_group)
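The renames above (and the one in backends.py) only mark the extra incoming gradients as intentionally unused so linters stop flagging them; behavior is unchanged. A self-contained sketch of the pattern, using a made-up autograd.Function rather than any of the classes touched here:

import torch

class _Sketch(torch.autograd.Function):  # hypothetical example, not from the codebase
    @staticmethod
    def forward(ctx, inp):
        # two outputs -> backward receives two incoming gradients
        return inp * 2, torch.zeros(1)

    @staticmethod
    def backward(ctx, dout, *_args):
        # the leading-underscore name signals the extra grads are deliberately ignored
        return dout * 2

out, aux = _Sketch.apply(torch.ones(3, requires_grad=True))
out.sum().backward()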

transformer_engine/pytorch/attention/dot_product_attention/utils.py

Lines changed: 0 additions & 3 deletions
@@ -486,9 +486,6 @@ def get_attention_backend(
         if use_flash_attention:
             use_flash_attention = False
             logger.debug("Disabling FlashAttention for max_score")
-        if use_fused_attention and qkv_format == "thd":
-            logger.debug("Disabling FusedAttention for max_score and qkv_format = thd")
-            use_fused_attention = False
     if fp8 and fp8_meta["recipe"].fp8_dpa:
         use_flash_attention = False
         use_fused_attention = False

transformer_engine/pytorch/cpp_extensions/fused_attn.py

Lines changed: 7 additions & 3 deletions
@@ -323,10 +323,14 @@ def fused_attn_fwd(
     )

     if return_max_score:
-        # output_tensors: out [b, sq, h, d] or [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
+        qkv_format = qkv_layout.replace("3","").replace("2","").split("_")[0]
+        # thd: output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1]
+        # bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
+        # sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
         stats = output_tensors[1] + torch.log(output_tensors[2])
-        # Max [b, h, sq, 1] -> max_score [h]
-        max_score = torch.amax(output_tensors[1], dim=(0, 2, 3)).to(dtype=output_tensors[0].dtype)
+        amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3)
+        # Max -> max_score [h]
+        max_score = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
         aux_ctx_tensors = [stats]
         aux_ctx_tensors.extend(output_tensors[3:])
         return output_tensors[0], aux_ctx_tensors, max_score
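A standalone sketch of the reduction logic added above, with made-up shapes; it mirrors the diff only and is not the real fused_attn_fwd interface:

import torch

def reduce_max_score(qkv_layout: str, max_tensor: torch.Tensor) -> torch.Tensor:
    # "sbh3d" -> "sbhd", "thd_thd_thd" -> "thd"
    qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
    # thd packs all tokens: Max is [tq, h, 1]; bshd/sbhd keep Max as [b, h, sq, 1]
    amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3)
    return torch.amax(max_tensor, dim=amax_dims)  # -> [h]

# hypothetical sizes: b=2, h=16, sq=128, tq=256
assert reduce_max_score("bshd_bshd_bshd", torch.randn(2, 16, 128, 1)).shape == (16,)
assert reduce_max_score("thd_thd_thd", torch.randn(256, 16, 1)).shape == (16,)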

transformer_engine/pytorch/csrc/extensions/attention.cpp

Lines changed: 2 additions & 2 deletions
@@ -253,7 +253,7 @@ std::vector<py::object> fused_attn_fwd(
   // f16_max512 : S [b, h, sq, skv]
   // f16_arbitrary:
   //   return_max_score=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
-  //   return_max_score=true: Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv]
+  //   return_max_score=true: Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
   // fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2]
   size_t i = 0;
   at::Tensor output_tensor;
@@ -262,7 +262,7 @@ std::vector<py::object> fused_attn_fwd(
         allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
                       static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
     set_tensor_param(i++, output_tensor);
-    // fp8 has an additional softmax stats tensor, ZInv
+    // fp8 has an additional softmax stats tensor, ZInv; return_max_score=true has an additional Sum_Exp tensor
     if (return_max_score || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
       output_tensor =
           allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
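For readability, a hedged Python sketch of the aux-tensor ordering described in the updated comment; the function and its flags are invented for illustration and do not exist in the extension:

def expected_aux_order(backend, return_max_score=False, has_bias=False, has_softmax_offset=False):
    # Mirrors the comment in fused_attn_fwd(); illustration only.
    if backend == "f16_max512":
        return ["S"]
    if backend == "fp8":
        return ["M", "ZInv", "rng_state"]
    # backend == "f16_arbitrary"
    aux = (["Max", "Sum_Exp"] if return_max_score else ["S"]) + ["rng_state"]
    if has_bias:
        aux.append("Bias")
    if has_softmax_offset:
        aux.append("SoftmaxOffset")
    return aux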
