@@ -295,7 +295,10 @@ def abstract(
295295 elif backend == NVTE_Fused_Attn_Backend .NVTE_F16_arbitrary_seqlen :
296296 # cuDNN 9.6 reduces the required softmax shape
297297 if get_cudnn_version () >= (9 , 6 , 0 ):
298- softmax_shape = (* batch_shape , attn_heads , q_max_seqlen , 1 )
298+ if config .qkv_layout .is_thd ():
299+ softmax_shape = (* batch_shape , q_max_seqlen , attn_heads , 1 )
300+ else :
301+ softmax_shape = (* batch_shape , attn_heads , q_max_seqlen , 1 )
299302 else :
300303 softmax_shape = (
301304 * batch_shape ,
@@ -607,28 +610,49 @@ def batcher(batched_args, batch_dims, *, config):
607610 def infer_sharding_from_operands (config , mesh , arg_infos , result_infos ):
608611 del result_infos
609612 q_spec = get_padded_spec (arg_infos [0 ])
613+
614+ # On cuDNN 9.6+ with a THD layout, softmax_aux has the packed shape (b, s, h, 1);
615+ # otherwise softmax_aux has shape (b, h, s, 1) or (b, h, s, max_segments).
616+ is_packed_softmax = get_cudnn_version () >= (9 , 6 , 0 ) and config .qkv_layout .is_thd ()
617+
610618 if config .qkv_layout .is_qkvpacked ():
611619 # q_spec = (...batch, q_seqlen, 3, head, hidden)
612620 out_sharding = NamedSharding (mesh , PartitionSpec (* q_spec [:- 3 ], * q_spec [- 2 :]))
613- softmax_aux_sharding = NamedSharding (
614- mesh , PartitionSpec (* q_spec [:- 4 ], q_spec [- 2 ], q_spec [- 4 ], None )
615- )
621+ if not is_packed_softmax :
622+ softmax_aux_sharding = NamedSharding (
623+ mesh , PartitionSpec (* q_spec [:- 4 ], q_spec [- 2 ], q_spec [- 4 ], None )
624+ )
625+ else :
626+ softmax_aux_sharding = NamedSharding (
627+ mesh , PartitionSpec (* q_spec [:- 4 ], q_spec [- 4 ], q_spec [- 2 ], None )
628+ )
616629 elif config .qkv_layout .is_kvpacked ():
617630 # q_spec = (...batch, q_seqlen, head, hidden)
618631 # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
619632 out_sharding = NamedSharding (mesh , PartitionSpec (* q_spec ))
620- softmax_aux_sharding = NamedSharding (
621- mesh , PartitionSpec (* q_spec [:- 3 ], q_spec [- 2 ], q_spec [- 3 ], None )
622- )
633+ if not is_packed_softmax :
634+ softmax_aux_sharding = NamedSharding (
635+ mesh , PartitionSpec (* q_spec [:- 3 ], q_spec [- 2 ], q_spec [- 3 ], None )
636+ )
637+ else :
638+ softmax_aux_sharding = NamedSharding (
639+ mesh , PartitionSpec (* q_spec [:- 3 ], q_spec [- 3 ], q_spec [- 2 ], None )
640+ )
623641 elif config .qkv_layout .is_separate ():
624642 # q_spec = (...batch, q_seqlen, head, hidden)
625643 # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
626644 out_sharding = NamedSharding (mesh , PartitionSpec (* q_spec ))
627- softmax_aux_sharding = NamedSharding (
628- mesh , PartitionSpec (* q_spec [:- 3 ], q_spec [- 2 ], q_spec [- 3 ], None )
629- )
645+ if not is_packed_softmax :
646+ softmax_aux_sharding = NamedSharding (
647+ mesh , PartitionSpec (* q_spec [:- 3 ], q_spec [- 2 ], q_spec [- 3 ], None )
648+ )
649+ else :
650+ softmax_aux_sharding = NamedSharding (
651+ mesh , PartitionSpec (* q_spec [:- 3 ], q_spec [- 3 ], q_spec [- 2 ], None )
652+ )
630653 else :
631654 raise ValueError (f"Unsupported { config .qkv_layout = } " )
655+
632656 rng_state_sharding = NamedSharding (mesh , PartitionSpec (get_all_mesh_axes (), None ))
633657 return (out_sharding , softmax_aux_sharding , rng_state_sharding )
634658
@@ -2236,7 +2260,6 @@ def scan_kv_block(idx, carry):
22362260 subblock_config ,
22372261 )
22382262
2239- # TODO(rewang): THD softmax_aux layout is acutally [B, S, H]
22402263 softmax_aux_per_step = softmax_aux_per_step .reshape ((batch , q_max_seqlen , head , 1 ))
22412264
22422265 def skip_correction (_output , _softmax_aux , output_per_step , softmax_aux_per_step ):
@@ -2272,8 +2295,6 @@ def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
22722295 carry = scan_kv_block (i , carry )
22732296 (_ , _ , _ , output , softmax_aux ) = carry
22742297
2275- softmax_aux = softmax_aux .reshape ((batch , head , q_max_seqlen , 1 ))
2276-
22772298 return output .astype (q .dtype ), softmax_aux , rng_state
22782299
22792300 return mesh , fwd_impl , out_shardings , arg_shardings
0 commit comments