Skip to content

Commit 3b22b3d

Browse files
Merge pull request #2198 from CIeNET-International:feat/Migrate-DotProductAttention-to-NNX
PiperOrigin-RevId: 826110976
2 parents 278d2f9 + b09e255 commit 3b22b3d

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

src/MaxText/layers/attention_op.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,7 @@ def __init__(
473473
self.chunk_attn_window_size = chunk_attn_window_size
474474
self.use_ragged_attention = use_ragged_attention
475475
self.ragged_block_size = ragged_block_size
476+
self.rngs = rngs
476477

477478
def maybe_create_nnx(einsum, *args):
478479
if isinstance(einsum, nn.Module):
@@ -1360,11 +1361,16 @@ def cudnn_flash_attention(
13601361
if self.attention_type == AttentionType.LOCAL_SLIDING or using_context_parallelism:
13611362
mask_type = "causal" # SWA and Context Parallelism only work with causal masking
13621363
attn_mask = None
1364+
dummy_attn_mask = None
13631365
else:
13641366
# generate attn_mask
13651367
mask_type = "padding_causal" # only padding_causal mask type can take a created mask
1368+
dummy_attn_mask = jnp.zeros((1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8)
13661369
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
13671370

1371+
if attn_mask is not None:
1372+
attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8)
1373+
13681374
dpa_layer = DotProductAttention(
13691375
head_dim=head_dim,
13701376
num_attention_heads=self.num_query_heads,
@@ -1382,6 +1388,17 @@ def cudnn_flash_attention(
13821388
context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
13831389
context_parallel_axis="context",
13841390
)
1391+
1392+
dpa_layer = nnx_wrappers.ToNNX(dpa_layer, rngs=self.rngs)
1393+
dummy_query_prefill = jnp.zeros(
1394+
(1, self.max_target_length, self.num_query_heads, self.config.head_dim), dtype=self.dtype
1395+
)
1396+
dummy_key_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, self.config.head_dim), dtype=self.dtype)
1397+
dummy_value_prefill = jnp.zeros(
1398+
(1, self.max_target_length, self.num_kv_heads, self.config.head_dim), dtype=self.dtype
1399+
)
1400+
1401+
dpa_layer.lazy_init(dummy_query_prefill, dummy_key_prefill, dummy_value_prefill, mask=dummy_attn_mask)
13851402
return dpa_layer(query, key, value, mask=attn_mask)
13861403

13871404
def cudnn_jax_flash_attention(

0 commit comments

Comments (0)