@@ -295,7 +295,10 @@ def abstract(
295
295
elif backend == NVTE_Fused_Attn_Backend .NVTE_F16_arbitrary_seqlen :
296
296
# cuDNN 9.6 reduces the required softmax shape
297
297
if get_cudnn_version () >= (9 , 6 , 0 ):
298
- softmax_shape = (* batch_shape , attn_heads , q_max_seqlen , 1 )
298
+ if config .qkv_layout .is_thd ():
299
+ softmax_shape = (* batch_shape , q_max_seqlen , attn_heads , 1 )
300
+ else :
301
+ softmax_shape = (* batch_shape , attn_heads , q_max_seqlen , 1 )
299
302
else :
300
303
softmax_shape = (
301
304
* batch_shape ,
@@ -607,28 +610,49 @@ def batcher(batched_args, batch_dims, *, config):
607
610
def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
    """Derive the output shardings from the sharding of the first operand.

    Only ``arg_infos[0]`` (the query operand) is consulted; its padded
    partition spec determines how the attention output and the softmax
    auxiliary tensor are sharded.

    Args:
        config: Attention configuration; only ``config.qkv_layout`` is read.
        mesh: Mesh that every returned ``NamedSharding`` is bound to.
        arg_infos: Operand infos; element 0 is passed to ``get_padded_spec``.
        result_infos: Unused.

    Returns:
        The tuple ``(out_sharding, softmax_aux_sharding, rng_state_sharding)``.

    Raises:
        ValueError: If the layout is neither qkv-packed, kv-packed nor separate.
    """
    del result_infos
    q_spec = get_padded_spec(arg_infos[0])

    # When supported, softmax_aux shape is (b, s, h, 1) for THD on cuDNN 9.6+;
    # otherwise softmax_aux shape is (b, h, s, 1) or (b, h, s, max_segments).
    is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()

    # The layout only decides where the batch/seq axes sit inside q_spec and
    # what the output spec looks like; the head axis is always q_spec[-2].
    if config.qkv_layout.is_qkvpacked():
        # q_spec = (...batch, q_seqlen, 3, head, hidden)
        out_spec = PartitionSpec(*q_spec[:-3], *q_spec[-2:])
        batch_axes, seq_axis = q_spec[:-4], q_spec[-4]
    elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate():
        # q_spec = (...batch, q_seqlen, head, hidden)
        # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) when kv-packed
        # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) when separate
        out_spec = PartitionSpec(*q_spec)
        batch_axes, seq_axis = q_spec[:-3], q_spec[-3]
    else:
        raise ValueError(f"Unsupported {config.qkv_layout=}")
    head_axis = q_spec[-2]

    # Packed softmax keeps the sequence axis ahead of the head axis; the
    # regular layout puts heads first. The trailing dim is never sharded.
    if is_packed_softmax:
        aux_axes = (seq_axis, head_axis)
    else:
        aux_axes = (head_axis, seq_axis)

    out_sharding = NamedSharding(mesh, out_spec)
    softmax_aux_sharding = NamedSharding(mesh, PartitionSpec(*batch_axes, *aux_axes, None))
    rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
    return (out_sharding, softmax_aux_sharding, rng_state_sharding)
634
658
@@ -2236,7 +2260,6 @@ def scan_kv_block(idx, carry):
2236
2260
subblock_config ,
2237
2261
)
2238
2262
2239
- # TODO(rewang): THD softmax_aux layout is actually [B, S, H]
2240
2263
softmax_aux_per_step = softmax_aux_per_step .reshape ((batch , q_max_seqlen , head , 1 ))
2241
2264
2242
2265
def skip_correction (_output , _softmax_aux , output_per_step , softmax_aux_per_step ):
@@ -2272,8 +2295,6 @@ def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
2272
2295
carry = scan_kv_block (i , carry )
2273
2296
(_ , _ , _ , output , softmax_aux ) = carry
2274
2297
2275
- softmax_aux = softmax_aux .reshape ((batch , head , q_max_seqlen , 1 ))
2276
-
2277
2298
return output .astype (q .dtype ), softmax_aux , rng_state
2278
2299
2279
2300
return mesh , fwd_impl , out_shardings , arg_shardings
0 commit comments