Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,12 @@ def _flash_attn_bwd(
current_stream,
)

# NB: num_threads applies to 3 kernels.
# There are pre-, main-, and post-processing kernels. Currently num_threads is only actually
# used for the pre-processing kernel; we hard-code 384 for the main and post-processing
# kernels, and we set it here, before cache-key generation.
num_threads = 384

# Backward kernel: compute dk, dv, dq_accum.
score_mod_hash = utils.hash_callable(score_mod) if score_mod else False
score_mod_bwd_hash = utils.hash_callable(score_mod_bwd) if score_mod_bwd else False
Expand Down Expand Up @@ -936,7 +942,6 @@ def _flash_attn_bwd(
seqused_q is None,
seqused_k is None,
)
num_threads = 384
if compile_key not in _flash_attn_bwd.compile_cache:
q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [
to_cute_tensor(t) for t in (q, k, v, dout, dq, dk, dv)
Expand Down