Skip to content

Commit 575b01e

Browse files
fix tpu tests
1 parent f9bcdfc commit 575b01e

File tree

1 file changed

+5
-0
lines changed
  • keras/src/backend/jax

1 file changed

+5
-0
lines changed

keras/src/backend/jax/nn.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,6 +1411,11 @@ def dot_product_attention(
14111411

14121412
def _reshape_to_grouped(t):
14131413
if t is not None:
1414+
while t.ndim < 4:
1415+
if t.ndim == 3 and t.shape[1] == N:
1416+
t = jnp.expand_dims(t, axis=2)
1417+
else:
1418+
t = jnp.expand_dims(t, axis=1)
14141419
tB, tN, tT, tS = t.shape
14151420
if tN == 1:
14161421
t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS))

0 commit comments

Comments
 (0)