33 commits
1bd7fd6  fix  (pggPL, Oct 1, 2025)
5cbfb7d  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Oct 1, 2025)
a1e0c51  removed packed versions  (pggPL, Oct 2, 2025)
dcb23b3  fix  (pggPL, Oct 2, 2025)
90e2d1a  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Oct 2, 2025)
58c443a  fix  (pggPL, Oct 6, 2025)
426ac6a  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Oct 6, 2025)
06cac3f  fix  (pggPL, Oct 6, 2025)
df623e5  fix  (pggPL, Oct 6, 2025)
5efb88c  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Oct 6, 2025)
b5ffcda  fix  (pggPL, Oct 7, 2025)
bd53e36  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Oct 7, 2025)
438953e  jax  (pggPL, Oct 14, 2025)
970cf80  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Oct 14, 2025)
216b46b  fix  (pggPL, Oct 14, 2025)
f04be9d  fix  (pggPL, Oct 14, 2025)
e0c7f6b  fix  (pggPL, Oct 14, 2025)
176e2e9  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Oct 14, 2025)
cab9659  fix  (pggPL, Oct 14, 2025)
2cd1727  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Oct 14, 2025)
99d67c1  Merge remote-tracking branch 'upstream/main' into jax_attention  (pggPL, Oct 15, 2025)
4c79c34  fix  (pggPL, Oct 16, 2025)
1845710  fix  (pggPL, Oct 17, 2025)
7ca2377  Merge branch 'main' into jax_attention  (pggPL, Oct 17, 2025)
6a45be7  fixes  (pggPL, Oct 17, 2025)
b05e8ef  fix  (pggPL, Oct 17, 2025)
3abec0b  fix:  (pggPL, Oct 17, 2025)
a5730df  fixes  (pggPL, Oct 21, 2025)
952d558  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Oct 21, 2025)
c5332a4  fixes  (pggPL, Oct 21, 2025)
1a7d61b  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Oct 21, 2025)
32ca266  fix  (pggPL, Oct 21, 2025)
e9a3841  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Oct 21, 2025)
56 changes: 54 additions & 2 deletions tests/jax/test_distributed_fused_attn.py
@@ -18,6 +18,7 @@
is_fused_attn_kernel_available,
AttnBiasType,
AttnMaskType,
AttnSoftmaxType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
@@ -66,6 +67,7 @@ def impl_test_self_attn(
bias_shape,
attn_mask_type,
dtype,
softmax_type,
use_shardy,
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
@@ -80,6 +82,7 @@
QKVLayout.BS3HD,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
num_head,
num_head,
@@ -109,6 +112,7 @@ def impl_test_self_attn(
hidden,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
@@ -142,6 +146,14 @@ def impl_test_self_attn(
],
)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
def test_self_attn(
self,
device_count,
@@ -153,6 +165,7 @@ def test_self_attn(
bias_shape,
attn_mask_type,
dtype,
softmax_type,
):
self.impl_test_self_attn(
device_count,
@@ -164,6 +177,7 @@ def test_self_attn(
bias_shape,
attn_mask_type,
dtype,
softmax_type,
use_shardy=False,
)

@@ -175,8 +189,23 @@ def test_self_attn(
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
],
)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
def test_self_attn_shardy(
self, device_count, mesh_shape, mesh_axes, mesh_resource, attn_bias_type, bias_shape
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
attn_bias_type,
bias_shape,
softmax_type,
):
data_shape = (32, 512, 12, 64)
self.impl_test_self_attn(
@@ -189,6 +218,7 @@ def test_self_attn_shardy(
bias_shape,
AttnMaskType.PADDING_MASK,
jnp.bfloat16,
softmax_type,
use_shardy=True,
)

@@ -213,8 +243,24 @@ def generate_collectives_count_ref(self):
"attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
def test_cross_attn(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_mask_type,
dtype,
softmax_type,
):
attn_bias_type = AttnBiasType.NO_BIAS
bias_shape = None
@@ -230,6 +276,7 @@ def test_cross_attn(
QKVLayout.BSHD_BS2HD,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
num_head,
num_head,
@@ -252,6 +299,7 @@ def test_cross_attn(
hidden,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
@@ -322,6 +370,8 @@ def impl_test_context_parallel_attn(
bias_shape = None
dropout_prob = 0.0
is_training = True
# Context parallel does not support softmax_offset
softmax_type = AttnSoftmaxType.VANILLA_SOFTMAX
dp_size, cp_size, tp_size = mesh_shape

batch, seqlen, num_head, hidden = data_shape
@@ -343,6 +393,7 @@ def impl_test_context_parallel_attn(
hidden,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
@@ -366,6 +417,7 @@ def check_has_backend_for_mask(mask_type):
qkv_layout,
attn_bias_type,
mask_type,
softmax_type,
dropout_prob,
num_head,
num_kv_heads,
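Taken together, the softmax_type parametrization added throughout this file follows the pattern sketched below. This is a minimal, hypothetical sketch: only the AttnSoftmaxType member names and the pytest.param ids come from the diff; the import path and the placeholder test body are assumptions.

import pytest
from transformer_engine.jax.attention import AttnSoftmaxType  # import path is an assumption

@pytest.mark.parametrize(
    "softmax_type",
    [
        pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
        pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
        pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
    ],
)
def test_softmax_type_parametrization(softmax_type):
    # In the real suite the value is forwarded to impl_test_self_attn /
    # test_cross_attn; this placeholder only checks that the member is one of
    # the three variants exercised above.
    assert softmax_type in (
        AttnSoftmaxType.VANILLA_SOFTMAX,
        AttnSoftmaxType.OFF_BY_ONE_SOFTMAX,
        AttnSoftmaxType.LEARNABLE_SOFTMAX,
    )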
36 changes: 20 additions & 16 deletions tests/jax/test_distributed_softmax.py
@@ -16,7 +16,7 @@
from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import autocast
from transformer_engine.jax.softmax import SoftmaxType, softmax
from transformer_engine.jax.softmax import SoftmaxFusion, softmax

DTYPES = [jnp.float16, jnp.bfloat16]

@@ -29,12 +29,12 @@ def generate_collectives_count_ref(self):
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

def generate_inputs(
self, shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
self, shape, mesh_resource, softmax_fusion, dtype, bad_sharding, broadcast_batch_mask
):
batch, _, sqelen, _ = shape

x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
if softmax_fusion == SoftmaxFusion.SCALED_UPPER_TRIANG_MASKED:
mask = make_causal_mask(batch, sqelen)
else:
mask = make_self_mask(1 if broadcast_batch_mask else batch, sqelen)
@@ -56,8 +56,8 @@ def generate_inputs(
return (x, mask), (x_pspec, mask_pspec)

@staticmethod
def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED):
return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
def target_func(x, mask, scale_factor=1.0, softmax_fusion=SoftmaxFusion.SCALED):
return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_fusion=softmax_fusion))

@staticmethod
def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
@@ -80,24 +80,24 @@ def impl_test_softmax(
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
softmax_fusion,
scale_factor,
dtype,
bad_sharding,
broadcast_batch_mask,
use_shardy,
):
if broadcast_batch_mask and softmax_type != SoftmaxType.SCALED_MASKED:
if broadcast_batch_mask and softmax_fusion != SoftmaxFusion.SCALED_MASKED:
pytest.skip("Softmax type has no mask.")

jax.config.update("jax_use_shardy_partitioner", use_shardy)
target_func = partial(
self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
self.target_func, scale_factor=scale_factor, softmax_fusion=softmax_fusion
)
ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)

(x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
data_shape, mesh_resource, softmax_fusion, dtype, bad_sharding, broadcast_batch_mask
)
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
@@ -139,8 +139,12 @@ def impl_test_softmax(
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]])
@pytest.mark.parametrize(
"softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
"softmax_fusion",
[
SoftmaxFusion.SCALED,
SoftmaxFusion.SCALED_MASKED,
SoftmaxFusion.SCALED_UPPER_TRIANG_MASKED,
],
)
@pytest.mark.parametrize("scale_factor", [1.0, 3.0])
@pytest.mark.parametrize("dtype", DTYPES)
@@ -153,7 +157,7 @@ def test_softmax(
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
softmax_fusion,
scale_factor,
dtype,
bad_sharding,
@@ -165,7 +169,7 @@ def test_softmax(
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
softmax_fusion,
scale_factor,
dtype,
bad_sharding,
@@ -174,7 +178,7 @@ def test_softmax(
)

@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
@pytest.mark.parametrize("softmax_fusion", [SoftmaxFusion.SCALED, SoftmaxFusion.SCALED_MASKED])
@pytest.mark.parametrize("bad_sharding", [False, True])
@pytest.mark.parametrize("broadcast_batch_mask", [False, True])
def test_softmax_gspmd(
Expand All @@ -183,7 +187,7 @@ def test_softmax_gspmd(
mesh_shape,
mesh_axes,
mesh_resource,
softmax_type,
softmax_fusion,
bad_sharding,
broadcast_batch_mask,
):
Expand All @@ -193,7 +197,7 @@ def test_softmax_gspmd(
mesh_axes,
mesh_resource,
data_shape=[32, 12, 128, 128],
softmax_type=softmax_type,
softmax_fusion=softmax_fusion,
scale_factor=1.0,
dtype=DTYPES[0],
bad_sharding=bad_sharding,
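For reference, a minimal sketch of the renamed softmax call as exercised by target_func above: the import line and the keyword change from softmax_type to softmax_fusion are taken from the diff, while the wrapper name and default arguments here are illustrative assumptions.

import jax.numpy as jnp
from transformer_engine.jax.softmax import SoftmaxFusion, softmax

def scaled_softmax_mean(x, mask, scale_factor=1.0, softmax_fusion=SoftmaxFusion.SCALED):
    # Same call shape as target_func in this file: the fusion variant is now
    # selected with `softmax_fusion` (formerly `softmax_type`).
    return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_fusion=softmax_fusion))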