
Commit bee4649

[JAX] Fix softmax aux shapes for packed/THD format (#1575)

* Fix softmax shape for THD format.

Signed-off-by: Michael Goldfarb <[email protected]>

1 parent 4f33ece · commit bee4649

1 file changed: +34 −13 lines changed

transformer_engine/jax/cpp_extensions/attention.py (+34 −13)
@@ -295,7 +295,10 @@ def abstract(
         elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
             # cuDNN 9.6 reduces the required softmax shape
             if get_cudnn_version() >= (9, 6, 0):
-                softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
+                if config.qkv_layout.is_thd():
+                    softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1)
+                else:
+                    softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
             else:
                 softmax_shape = (
                     *batch_shape,
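
The hunk above makes the advertised softmax_aux shape layout-dependent: on cuDNN 9.6+ the packed/THD layout stores the auxiliary tensor as (...batch, seqlen, heads, 1), while other layouts keep (...batch, heads, seqlen, 1). A minimal standalone sketch of that selection, where `is_thd` and `cudnn_version` are hypothetical stand-ins for `config.qkv_layout.is_thd()` and `get_cudnn_version()`:

    def softmax_aux_shape(batch_shape, attn_heads, q_max_seqlen, is_thd, cudnn_version):
        if cudnn_version >= (9, 6, 0):
            if is_thd:
                # Packed/THD layout: aux tensor is (...batch, seqlen, heads, 1).
                return (*batch_shape, q_max_seqlen, attn_heads, 1)
            # Other layouts keep (...batch, heads, seqlen, 1).
            return (*batch_shape, attn_heads, q_max_seqlen, 1)
        # The pre-9.6 shape is cut off in this hunk's trailing context.
        raise NotImplementedError("pre-cuDNN-9.6 path not shown here")

    assert softmax_aux_shape((2,), 16, 512, True, (9, 6, 0)) == (2, 512, 16, 1)
    assert softmax_aux_shape((2,), 16, 512, False, (9, 6, 0)) == (2, 16, 512, 1)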
@@ -607,28 +610,49 @@ def batcher(batched_args, batch_dims, *, config):
     def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
         del result_infos
         q_spec = get_padded_spec(arg_infos[0])
+
+        # when supported softmax_aux shape is (b, s, h, 1) for thd on cudnn 9.6+
+        # otherwise softmax_aux shape is (b, h, s, 1) or (b, h, s, max_segments)
+        is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()
+
         if config.qkv_layout.is_qkvpacked():
             # q_spec = (...batch, q_seqlen, 3, head, hidden)
             out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
-            softmax_aux_sharding = NamedSharding(
-                mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)
-            )
+            if not is_packed_softmax:
+                softmax_aux_sharding = NamedSharding(
+                    mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)
+                )
+            else:
+                softmax_aux_sharding = NamedSharding(
+                    mesh, PartitionSpec(*q_spec[:-4], q_spec[-4], q_spec[-2], None)
+                )
         elif config.qkv_layout.is_kvpacked():
             # q_spec = (...batch, q_seqlen, head, hidden)
             # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
             out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
-            softmax_aux_sharding = NamedSharding(
-                mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
-            )
+            if not is_packed_softmax:
+                softmax_aux_sharding = NamedSharding(
+                    mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
+                )
+            else:
+                softmax_aux_sharding = NamedSharding(
+                    mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None)
+                )
         elif config.qkv_layout.is_separate():
             # q_spec = (...batch, q_seqlen, head, hidden)
             # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
             out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
-            softmax_aux_sharding = NamedSharding(
-                mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
-            )
+            if not is_packed_softmax:
+                softmax_aux_sharding = NamedSharding(
+                    mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
+                )
+            else:
+                softmax_aux_sharding = NamedSharding(
+                    mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None)
+                )
         else:
             raise ValueError(f"Unsupported {config.qkv_layout=}")
+
         rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
         return (out_sharding, softmax_aux_sharding, rng_state_sharding)
 
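The sharding rule mirrors the same axis swap: `is_packed_softmax` flips the order in which the seqlen and heads entries of `q_spec` are placed into the aux tensor's PartitionSpec. A small sketch of just that PartitionSpec arithmetic for the separate-QKV case, assuming an illustrative `q_spec` of (batch, seqlen, head, hidden) mapped to mesh axes ('dp', None, 'tp', None); the mesh-axis names are not from the diff:

    from jax.sharding import PartitionSpec

    q_spec = ("dp", None, "tp", None)  # assumed (batch, seqlen, head, hidden) mapping

    # Legacy aux layout (b, h, s, 1): heads entry placed before seqlen entry.
    legacy = PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
    # Packed/THD aux layout (b, s, h, 1): seqlen entry placed before heads entry.
    packed = PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None)

    print(legacy)  # PartitionSpec('dp', 'tp', None, None)
    print(packed)  # PartitionSpec('dp', None, 'tp', None)

Either way the same mesh axis ('tp' here) lands on the heads dimension of the aux tensor; only that dimension's position changes with the layout.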

@@ -2236,7 +2260,6 @@ def scan_kv_block(idx, carry):
                 subblock_config,
             )
 
-            # TODO(rewang): THD softmax_aux layout is acutally [B, S, H]
             softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1))
 
             def skip_correction(_output, _softmax_aux, output_per_step, softmax_aux_per_step):
@@ -2272,8 +2295,6 @@ def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
             carry = scan_kv_block(i, carry)
         (_, _, _, output, softmax_aux) = carry
 
-        softmax_aux = softmax_aux.reshape((batch, head, q_max_seqlen, 1))
-
         return output.astype(q.dtype), softmax_aux, rng_state
 
     return mesh, fwd_impl, out_shardings, arg_shardings
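
These last two hunks drop the trailing reshape that forced the accumulated ring-attention aux back into the old (batch, head, q_max_seqlen, 1) shape, which the removed TODO had flagged as mislabeled: the THD aux is really laid out as [B, S, H]. A plain reshape only relabels dimensions without moving data, so it cannot stand in for a transpose; a small illustrative check (shapes chosen arbitrarily):

    import jax.numpy as jnp

    batch, q_max_seqlen, head = 2, 4, 3
    aux = jnp.arange(batch * q_max_seqlen * head, dtype=jnp.float32).reshape(
        (batch, q_max_seqlen, head, 1)
    )

    relabelled = aux.reshape((batch, head, q_max_seqlen, 1))  # reshape: same ravel order
    transposed = jnp.transpose(aux, (0, 2, 1, 3))             # a true (b, h, s, 1) view

    print(bool(jnp.array_equal(relabelled, transposed)))  # False

With the abstract shapes and shardings fixed above, the aux tensor can simply stay in its native (batch, q_max_seqlen, head, 1) layout.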
