
Commit db88b59

Fix TPU tests - for splash attention (#21891)
* fix TPU tests
* code reformat
* address review comments
* code reformat
* fix test
* code reformat
* fix splash attention bug
* fix tpu tests
1 parent 9f7bd40 commit db88b59

2 files changed: +64 −18 lines changed

keras/src/backend/jax/nn.py

Lines changed: 30 additions & 18 deletions
@@ -1340,25 +1340,32 @@ def dot_product_attention(
         if custom_mask is None and is_causal:
             custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))
 
-        try:
-            output = wrap_flash_attention(
-                query_tpu_layout,
-                key_tpu_layout,
-                value_tpu_layout,
-                decoder_segment_ids=decoder_segment_ids,
-                custom_mask=custom_mask,
-                attn_logits_soft_cap=attn_logits_soft_cap,
-                head_shards=head_shards,
-                q_seq_shards=q_seq_shards,
-            )
-            # Transpose output back to Keras layout
-            return jnp.transpose(output, axes=(0, 2, 1, 3))
-        except Exception:
-            logging.exception(
-                "Failed to apply Splash kernel for flash attention. "
-                "Falling back to JAX native dot_product_attention."
-            )
+        # Splash attention kernel requires concrete mask values for hashing.
+        # If the mask is a tracer (e.g. inside a scan/loop), we must fall back.
+        if isinstance(mask, jax.core.Tracer) or isinstance(
+            custom_mask, jax.core.Tracer
+        ):
             flash_attention = False
+        else:
+            try:
+                output = wrap_flash_attention(
+                    query_tpu_layout,
+                    key_tpu_layout,
+                    value_tpu_layout,
+                    decoder_segment_ids=decoder_segment_ids,
+                    custom_mask=custom_mask,
+                    attn_logits_soft_cap=attn_logits_soft_cap,
+                    head_shards=head_shards,
+                    q_seq_shards=q_seq_shards,
+                )
+                # Transpose output back to Keras layout
+                return jnp.transpose(output, axes=(0, 2, 1, 3))
+            except Exception:
+                logging.exception(
+                    "Failed to apply Splash kernel for flash attention. "
+                    "Falling back to JAX native dot_product_attention."
+                )
+                flash_attention = False
 
     # JAX native dot_product_attention for GPU or fallback for TPU
     if hasattr(jax.nn, "dot_product_attention"):
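The guard in this hunk relies on a basic JAX fact: values flowing through jax.lax.scan or jax.jit are abstract Tracer objects rather than concrete arrays, and the Splash kernel's mask construction needs concrete values it can hash, so a traced mask forces the native fallback. A minimal, standalone sketch of that distinction (the is_concrete helper and the shapes are illustrative only, not part of the commit):

import jax
import jax.numpy as jnp

def is_concrete(x):
    # Concrete arrays are plain jax.Array values; anything flowing through
    # a jit/scan trace shows up as a jax.core.Tracer instead.
    return not isinstance(x, jax.core.Tracer)

mask = jnp.ones((4, 4), dtype=bool)
assert is_concrete(mask)  # eager value outside any trace

def scan_body(carry, m):
    # During tracing, `m` is a Tracer, so a kernel that must hash the
    # concrete mask cannot be used here and the native path is taken.
    assert not is_concrete(m)
    return carry, m

jax.lax.scan(scan_body, None, jnp.ones((2, 4, 4), dtype=bool))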
@@ -1404,6 +1411,11 @@ def dot_product_attention(
 
     def _reshape_to_grouped(t):
         if t is not None:
+            while t.ndim < 4:
+                if t.ndim == 3 and t.shape[1] == N:
+                    t = jnp.expand_dims(t, axis=2)
+                else:
+                    t = jnp.expand_dims(t, axis=1)
             tB, tN, tT, tS = t.shape
             if tN == 1:
                 t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS))
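The second hunk makes _reshape_to_grouped tolerant of masks with fewer than four dimensions by promoting them to a 4D (batch, heads-or-1, query, key) layout before the four-way unpack. A standalone sketch of that promotion logic (the helper name and the value of N, the head count, are illustrative only):

import jax.numpy as jnp

N = 4  # assumed number of attention heads for this example

def promote_mask_to_4d(t):
    # Keep adding singleton axes until the mask is 4D, treating a 3D mask
    # whose second axis equals the head count as per-head (B, N, S) and
    # anything else as (B, T, S).
    while t.ndim < 4:
        if t.ndim == 3 and t.shape[1] == N:
            t = jnp.expand_dims(t, axis=2)  # (B, N, S) -> (B, N, 1, S)
        else:
            t = jnp.expand_dims(t, axis=1)  # (B, T, S) -> (B, 1, T, S)
    return t

print(promote_mask_to_4d(jnp.ones((2, 6, 8), dtype=bool)).shape)  # (2, 1, 6, 8)
print(promote_mask_to_4d(jnp.ones((2, 4, 8), dtype=bool)).shape)  # (2, 4, 1, 8)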

keras/src/ops/nn_test.py

Lines changed: 34 additions & 0 deletions
@@ -1324,6 +1324,40 @@ def test_polar(self):
 
 
 class NNOpsCorrectnessTest(testing.TestCase):
+    @pytest.mark.skipif(backend.backend() != "jax", reason="JAX only")
+    def test_dot_product_attention_inside_scan(self):
+        import jax
+
+        try:
+            if jax.devices()[0].platform != "tpu":
+                self.skipTest("TPU-specific test")
+        except:
+            self.skipTest("TPU-specific test")
+
+        import jax.numpy as jnp
+
+        def attention_scan_body(carry, x):
+            query, key, value = x
+            # dot_product_attention expects 4D inputs (B, H, S, D)
+            query = jnp.expand_dims(query, axis=0)
+            key = jnp.expand_dims(key, axis=0)
+            value = jnp.expand_dims(value, axis=0)
+
+            # Use a mask to trigger the issue
+            mask = jnp.ones((1, 4, 8), dtype="bool")
+            out = knn.dot_product_attention(query, key, value, mask=mask)
+
+            out = jnp.squeeze(out, axis=0)
+            return carry, out
+
+        query = jnp.ones((2, 1, 4, 8))
+        key = jnp.ones((2, 1, 4, 8))
+        value = jnp.ones((2, 1, 4, 8))
+
+        # Scan over the first dimension
+        _, out = jax.lax.scan(attention_scan_body, None, (query, key, value))
+        self.assertEqual(out.shape, (2, 1, 4, 8))
+
     def test_relu(self):
         x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
         self.assertAllClose(knn.relu(x), [0, 0, 1, 2, 3])
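On a TPU host, the new regression test can be selected on its own with pytest's standard -k filter, e.g. pytest keras/src/ops/nn_test.py -k test_dot_product_attention_inside_scan; on other backends or hardware it is skipped by the guards at the top of the test.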
