
Commit 5e8a9a9

[JAX] Fix: Skip determinism tests for bprop for all sm >=100 (#2315)
* Fix: Skip determinism tests for bprop for all sm >= 100
* Add username to TODO
* Assert in fused attn bwd pass for sm100+
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Signed-off-by: Kshitij Lakhani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent: f0295f9 · Commit: 5e8a9a9

File tree: 2 files changed (+8, -5 lines)

tests/jax/test_fused_attn.py

Lines changed: 3 additions & 3 deletions
@@ -378,14 +378,14 @@ def _check_configs(self):
             pytest.skip(
                 "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
             )
-
+        # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
         if (
-            get_device_compute_capability(0) == 100
+            get_device_compute_capability(0) >= 100
             and self.dropout_prob == 0.1
             and self.attn_bias_type is not AttnBiasType.NO_BIAS
         ):
             pytest.skip(
-                "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
+                "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
             )
         # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
         # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
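
The substance of this hunk is the comparison operator: == 100 matched only devices reporting compute capability exactly 100, so later sm100+ parts (e.g. sm120) would run the unsupported dropout + bias bprop configuration instead of skipping it. A minimal sketch of the old vs. new guard, with get_device_compute_capability stubbed out purely for illustration (the real helper comes from Transformer Engine, not this stub):

# Hypothetical stub standing in for Transformer Engine's
# get_device_compute_capability(); used only to contrast the two guards.
def get_device_compute_capability(device_id: int, _cap: int = 120) -> int:
    return _cap  # pretend device 0 is an sm120 GPU

old_skip = get_device_compute_capability(0) == 100  # False on sm120: test runs, then fails
new_skip = get_device_compute_capability(0) >= 100  # True on sm120: test is skipped
print(old_skip, new_skip)  # False True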

transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 5 additions & 2 deletions
@@ -2739,10 +2739,13 @@ def fused_attn_bwd(
         assert bias is None
         bias = jnp.zeros(0, dtype=qkv[0].dtype)
 
-    if 100 in get_all_device_compute_capability():
+    # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
+    # sm100+
+    compute_capabilities = get_all_device_compute_capability()
+    if any(x >= 100 for x in compute_capabilities):
         assert not (
             attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
-        ), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
+        ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
 
     fused_config = _FusedAttnConfig(
         attn_bias_type=attn_bias_type,
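
The backward-pass assertion gets the analogous fix: 100 in get_all_device_compute_capability() is an exact-membership test, so a machine whose devices all report a capability above 100 would slip past the guard entirely. Below is a self-contained sketch contrasting the two predicates; the capability lists are made up for illustration (90 = Hopper, 100 and above = Blackwell-class):

# Illustrative capability lists only; not queried from real devices.
for caps in ([90], [100], [120], [90, 103]):
    old_guard = 100 in caps                  # pre-fix: exact membership of 100
    new_guard = any(x >= 100 for x in caps)  # post-fix: any sm100+ device present
    print(caps, old_guard, new_guard)
# [90] False False
# [100] True True
# [120] False True   <- the old guard missed sm120
# [90, 103] False True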
