@@ -466,6 +466,7 @@ def __init__(
     self.chunk_attn_window_size = chunk_attn_window_size
     self.use_ragged_attention = use_ragged_attention
     self.ragged_block_size = ragged_block_size
+    self.rngs = rngs

     def maybe_create_nnx(einsum, *args):
       if isinstance(einsum, nn.Module):
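The `self.rngs = rngs` addition stores the constructor's `nnx.Rngs` so that `cudnn_flash_attention` can later hand it to the Linen-to-NNX bridge (the `ToNNX` call further down). A minimal sketch of the pattern, with a hypothetical holder class standing in for the attention op:

    from flax import nnx

    class AttentionState:
      # Hypothetical stand-in: keep the nnx.Rngs received at construction
      # so a Linen submodule can be bridged to NNX later.
      def __init__(self, rngs: nnx.Rngs):
        self.rngs = rngs

    state = AttentionState(nnx.Rngs(params=0))  # 'params' stream seeds Linen init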
@@ -1266,11 +1267,16 @@ def cudnn_flash_attention(
     if self.attention_type == AttentionType.LOCAL_SLIDING or using_context_parallelism:
       mask_type = "causal"  # SWA and Context Parallelism only work with causal masking
       attn_mask = None
+      dummy_attn_mask = None
     else:
       # generate attn_mask
       mask_type = "padding_causal"  # only padding_causal mask type can take a created mask
+      dummy_attn_mask = jnp.zeros((1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8)
       attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)

+    if attn_mask is not None:
+      attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8)
+
     dpa_layer = DotProductAttention(
         head_dim=head_dim,
         num_attention_heads=self.num_query_heads,
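The added `jnp.where` converts MaxText's additive float mask, where 0 marks an attended position and a large negative `DEFAULT_MASK_VALUE` marks a disallowed one, into the uint8 convention that Transformer Engine's `DotProductAttention` expects, where 1 means "masked out". A small sketch, assuming `DEFAULT_MASK_VALUE` follows MaxText's definition:

    import jax.numpy as jnp
    import numpy as np

    # MaxText's additive-mask fill value (a large negative float).
    DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)

    # Additive mask: 0.0 where attention is allowed, DEFAULT_MASK_VALUE where not.
    additive = jnp.array([[0.0, DEFAULT_MASK_VALUE],
                          [0.0, 0.0]])

    # Allowed entries (0.0) clear the negative threshold and map to 0; masked
    # entries fall below it and map to 1, giving a uint8 "1 == exclude" mask.
    boolean = jnp.where(additive >= DEFAULT_MASK_VALUE * 0.5, 0, 1).astype(jnp.uint8)
    # boolean -> [[0 1]
    #             [0 0]]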
@@ -1288,6 +1294,17 @@ def cudnn_flash_attention(
         context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
         context_parallel_axis="context",
     )
+
+    dpa_layer = nnx_wrappers.ToNNX(dpa_layer, rngs=self.rngs)
+    dummy_query_prefill = jnp.zeros(
+        (1, self.max_target_length, self.num_query_heads, self.config.head_dim), dtype=self.dtype
+    )
+    dummy_key_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, self.config.head_dim), dtype=self.dtype)
+    dummy_value_prefill = jnp.zeros(
+        (1, self.max_target_length, self.num_kv_heads, self.config.head_dim), dtype=self.dtype
+    )
+
+    dpa_layer.lazy_init(dummy_query_prefill, dummy_key_prefill, dummy_value_prefill, mask=dummy_attn_mask)
     return dpa_layer(query, key, value, mask=attn_mask)

   def cudnn_jax_flash_attention(
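The wrapping block above follows the Flax Linen-to-NNX bridge pattern: `ToNNX` wraps the Linen `DotProductAttention`, and `lazy_init` runs one pass with dummy prefill-shaped inputs so the wrapped module's variables are materialized before the real call. A self-contained sketch of the same pattern on a toy Linen module (names here are illustrative, not from the PR):

    import jax.numpy as jnp
    from flax import linen as nn
    from flax import nnx
    from flax.nnx import bridge  # provides ToNNX / lazy_init

    class TinyLinen(nn.Module):  # stand-in for the Transformer Engine layer
      features: int = 8

      @nn.compact
      def __call__(self, x):
        return nn.Dense(self.features)(x)

    layer = bridge.ToNNX(TinyLinen(), rngs=nnx.Rngs(0))
    dummy_x = jnp.zeros((1, 4))       # dummy input pins down shapes/dtypes
    bridge.lazy_init(layer, dummy_x)  # eagerly initializes the Linen params
    y = layer(dummy_x)                # now callable like any NNX module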