diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 37cbf42fdd..925adf9a19 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -868,6 +868,12 @@ def _flash_attn_bwd( current_stream, ) + # NB: num_threads application for the 3 kernels + # There are pre-, main, and post-processing kernels; currently num_threads is only actually + # used for the pre-processing kernel, and then we hard-code it to 384 for the main and post-processing kernels, and we do so + # before cache key generation + num_threads = 384 + # Backward kernel: compute dk, dv, dq_accum. score_mod_hash = utils.hash_callable(score_mod) if score_mod else False score_mod_bwd_hash = utils.hash_callable(score_mod_bwd) if score_mod_bwd else False @@ -936,7 +942,6 @@ def _flash_attn_bwd( seqused_q is None, seqused_k is None, ) - num_threads = 384 if compile_key not in _flash_attn_bwd.compile_cache: q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ to_cute_tensor(t) for t in (q, k, v, dout, dq, dk, dv)