@@ -1340,25 +1340,32 @@ def dot_product_attention(
         if custom_mask is None and is_causal:
             custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))

-        try:
-            output = wrap_flash_attention(
-                query_tpu_layout,
-                key_tpu_layout,
-                value_tpu_layout,
-                decoder_segment_ids=decoder_segment_ids,
-                custom_mask=custom_mask,
-                attn_logits_soft_cap=attn_logits_soft_cap,
-                head_shards=head_shards,
-                q_seq_shards=q_seq_shards,
-            )
-            # Transpose output back to Keras layout
-            return jnp.transpose(output, axes=(0, 2, 1, 3))
-        except Exception:
-            logging.exception(
-                "Failed to apply Splash kernel for flash attention. "
-                "Falling back to JAX native dot_product_attention."
-            )
+        # Splash attention kernel requires concrete mask values for hashing.
+        # If the mask is a tracer (e.g. inside a scan/loop), we must fall back.
+        if isinstance(mask, jax.core.Tracer) or isinstance(
+            custom_mask, jax.core.Tracer
+        ):
             flash_attention = False
+        else:
+            try:
+                output = wrap_flash_attention(
+                    query_tpu_layout,
+                    key_tpu_layout,
+                    value_tpu_layout,
+                    decoder_segment_ids=decoder_segment_ids,
+                    custom_mask=custom_mask,
+                    attn_logits_soft_cap=attn_logits_soft_cap,
+                    head_shards=head_shards,
+                    q_seq_shards=q_seq_shards,
+                )
+                # Transpose output back to Keras layout
+                return jnp.transpose(output, axes=(0, 2, 1, 3))
+            except Exception:
+                logging.exception(
+                    "Failed to apply Splash kernel for flash attention. "
+                    "Falling back to JAX native dot_product_attention."
+                )
+                flash_attention = False

     # JAX native dot_product_attention for GPU or fallback for TPU
     if hasattr(jax.nn, "dot_product_attention"):
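Note on the tracer check introduced above: under jax.jit, lax.scan, or vmap, array arguments are abstract jax.core.Tracer objects with no concrete values, so the Splash kernel, which needs concrete mask values for hashing, cannot run and the code falls back to the native path. Below is a minimal standalone sketch of that same check; can_use_splash is an illustrative helper name, not part of this change.

import jax
import jax.numpy as jnp

def can_use_splash(m):
    # Splash attention needs concrete mask values; a Tracer means we are
    # being traced (jit/scan/vmap) and no concrete values are available.
    return m is not None and not isinstance(m, jax.core.Tracer)

mask = jnp.tril(jnp.ones((8, 8), dtype=jnp.bool_))
print(can_use_splash(mask))  # True: concrete array outside tracing

@jax.jit
def traced(m):
    return can_use_splash(m)

print(traced(mask))  # False: inside jit the mask is a Tracer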
@@ -1404,6 +1411,11 @@ def dot_product_attention(

         def _reshape_to_grouped(t):
             if t is not None:
+                while t.ndim < 4:
+                    if t.ndim == 3 and t.shape[1] == N:
+                        t = jnp.expand_dims(t, axis=2)
+                    else:
+                        t = jnp.expand_dims(t, axis=1)
                 tB, tN, tT, tS = t.shape
                 if tN == 1:
                     t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS))
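Note on the rank-padding loop added to _reshape_to_grouped: it promotes a lower-rank mask or bias to the 4D (batch, heads, q_len, kv_len) layout before the existing broadcast across the G query groups. A self-contained sketch of the same logic follows; the shapes (N key/value heads, G query groups per head) and the name pad_mask_to_4d are illustrative assumptions, not part of this change.

import jax.numpy as jnp

B, N, G, T, S = 2, 4, 2, 8, 8  # batch, kv heads, query groups per kv head, q/kv lengths

def pad_mask_to_4d(t, num_heads):
    # Mirror of the rank-padding loop above: keep adding singleton axes
    # until the mask is 4D. A 3D input whose second axis matches the head
    # count gets the new axis after the heads; otherwise a broadcastable
    # head axis is inserted at position 1.
    while t.ndim < 4:
        if t.ndim == 3 and t.shape[1] == num_heads:
            t = jnp.expand_dims(t, axis=2)
        else:
            t = jnp.expand_dims(t, axis=1)
    return t

mask = jnp.ones((B, T, S), dtype=jnp.bool_)  # (batch, q_len, kv_len), no head axis
mask = pad_mask_to_4d(mask, N)               # -> (B, 1, T, S)
# The existing tN == 1 branch then broadcasts across the G query groups:
grouped = jnp.broadcast_to(mask[:, :, None, :, :], (B, 1, G, T, S))
print(mask.shape, grouped.shape)             # (2, 1, 8, 8) (2, 1, 2, 8, 8)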