@@ -473,6 +473,7 @@ def __init__(
     self.chunk_attn_window_size = chunk_attn_window_size
     self.use_ragged_attention = use_ragged_attention
     self.ragged_block_size = ragged_block_size
+    self.rngs = rngs

     def maybe_create_nnx(einsum, *args):
       if isinstance(einsum, nn.Module):
@@ -1360,11 +1361,16 @@ def cudnn_flash_attention(
     if self.attention_type == AttentionType.LOCAL_SLIDING or using_context_parallelism:
       mask_type = "causal"  # SWA and Context Parallelism only work with causal masking
       attn_mask = None
+      dummy_attn_mask = None
     else:
       # generate attn_mask
       mask_type = "padding_causal"  # only padding_causal mask type can take a created mask
+      dummy_attn_mask = jnp.zeros((1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8)
       attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)

+    if attn_mask is not None:
+      attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8)
+
     dpa_layer = DotProductAttention(
         head_dim=head_dim,
         num_attention_heads=self.num_query_heads,
@@ -1382,6 +1388,17 @@ def cudnn_flash_attention(
         context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
         context_parallel_axis="context",
     )
+
+    dpa_layer = nnx_wrappers.ToNNX(dpa_layer, rngs=self.rngs)
+    dummy_query_prefill = jnp.zeros(
+        (1, self.max_target_length, self.num_query_heads, self.config.head_dim), dtype=self.dtype
+    )
+    dummy_key_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, self.config.head_dim), dtype=self.dtype)
+    dummy_value_prefill = jnp.zeros(
+        (1, self.max_target_length, self.num_kv_heads, self.config.head_dim), dtype=self.dtype
+    )
+
+    dpa_layer.lazy_init(dummy_query_prefill, dummy_key_prefill, dummy_value_prefill, mask=dummy_attn_mask)
     return dpa_layer(query, key, value, mask=attn_mask)

   def cudnn_jax_flash_attention(
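
The added lines wrap the Transformer Engine `DotProductAttention` (a Linen module) into NNX via `nnx_wrappers.ToNNX` and eagerly create its variables with `lazy_init` on dummy prefill-shaped inputs. Below is a minimal, self-contained sketch of that Linen-to-NNX bridge pattern, assuming `nnx_wrappers` aliases `flax.nnx.bridge`; the `LinenDot` module and shapes are illustrative stand-ins, not MaxText code:

```python
# Hedged sketch only: bridge.ToNNX / lazy_init mirror the pattern in the diff,
# but LinenDot and the shapes below are illustrative, not MaxText code.
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx
from flax.nnx import bridge


class LinenDot(nn.Module):
  """A tiny Linen module standing in for the Transformer Engine attention layer."""
  features: int

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)


# Wrap the Linen module so it can live inside an NNX model, then create its
# variables by tracing one call on dummy inputs with the real shapes/dtypes.
dummy_x = jnp.zeros((1, 32), dtype=jnp.float32)
layer = bridge.ToNNX(LinenDot(features=64), rngs=nnx.Rngs(0)).lazy_init(dummy_x)

# After lazy_init the wrapper is a fully initialized NNX module and can be called directly.
y = layer(dummy_x)
print(y.shape)  # (1, 64)
```

The dummy tensors only need to match the shapes and dtypes the layer will see at prefill time; `lazy_init` traces one call so variable creation happens up front rather than inside the first real call.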