
Commit b09e255

Migrate DotProductAttention to NNX
1 parent a15fc00 commit b09e255

File tree: 1 file changed (+17, -0 lines)


src/MaxText/layers/attention_op.py

Lines changed: 17 additions & 0 deletions
@@ -466,6 +466,7 @@ def __init__(
     self.chunk_attn_window_size = chunk_attn_window_size
     self.use_ragged_attention = use_ragged_attention
     self.ragged_block_size = ragged_block_size
+    self.rngs = rngs
 
     def maybe_create_nnx(einsum, *args):
       if isinstance(einsum, nn.Module):
@@ -1266,11 +1267,16 @@ def cudnn_flash_attention(
     if self.attention_type == AttentionType.LOCAL_SLIDING or using_context_parallelism:
       mask_type = "causal"  # SWA and Context Parallelism only work with causal masking
       attn_mask = None
+      dummy_attn_mask = None
     else:
       # generate attn_mask
       mask_type = "padding_causal"  # only padding_causal mask type can take a created mask
+      dummy_attn_mask = jnp.zeros((1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8)
       attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
 
+    if attn_mask is not None:
+      attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8)
+
     dpa_layer = DotProductAttention(
         head_dim=head_dim,
         num_attention_heads=self.num_query_heads,
@@ -1288,6 +1294,17 @@ def cudnn_flash_attention(
         context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
         context_parallel_axis="context",
     )
+
+    dpa_layer = nnx_wrappers.ToNNX(dpa_layer, rngs=self.rngs)
+    dummy_query_prefill = jnp.zeros(
+        (1, self.max_target_length, self.num_query_heads, self.config.head_dim), dtype=self.dtype
+    )
+    dummy_key_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, self.config.head_dim), dtype=self.dtype)
+    dummy_value_prefill = jnp.zeros(
+        (1, self.max_target_length, self.num_kv_heads, self.config.head_dim), dtype=self.dtype
+    )
+
+    dpa_layer.lazy_init(dummy_query_prefill, dummy_key_prefill, dummy_value_prefill, mask=dummy_attn_mask)
     return dpa_layer(query, key, value, mask=attn_mask)
 
   def cudnn_jax_flash_attention(
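
Note on the pattern used in this commit: the Linen-based Transformer Engine DotProductAttention is wrapped in an NNX adapter and its variables are materialized eagerly by calling lazy_init with dummy prefill-shaped inputs, since the NNX side needs concrete state up front rather than deferring shape inference to the first real call. Below is a minimal sketch of that bridge pattern using the public flax.nnx.bridge API; the Dense stand-in layer, input shapes, and rng seed are illustrative assumptions and not MaxText code, and MaxText's nnx_wrappers module is assumed to expose an equivalent ToNNX wrapper (the diff calls lazy_init as a method, while the public bridge also provides it as a free function).

# Minimal sketch of the Linen -> NNX bridge pattern applied in the diff above.
# The Dense layer, shapes, and seed are illustrative assumptions, not MaxText code.
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx
from flax.nnx import bridge

# A plain Linen module standing in for the Transformer Engine DotProductAttention layer.
linen_layer = nn.Dense(features=64)

# Wrap the Linen module so it can live inside an NNX module tree; rngs supplies the
# initialization keys, mirroring the self.rngs stored in __init__ in the diff.
nnx_layer = bridge.ToNNX(linen_layer, rngs=nnx.Rngs(0))

# The bridge cannot infer Linen variable shapes lazily, so run a one-off initialization
# pass with dummy inputs shaped like the real (prefill) activations.
dummy_x = jnp.zeros((1, 128, 32), dtype=jnp.bfloat16)
bridge.lazy_init(nnx_layer, dummy_x)

# After initialization the wrapped layer is called like any other NNX module.
y = nnx_layer(dummy_x)
print(y.shape)  # (1, 128, 64)

The diff applies the same idea with dummy query/key/value tensors of shape (1, max_target_length, num_heads, head_dim) and, on the padding-causal branch, a dummy uint8 mask; the additive float attention mask is converted to a 0/1 uint8 mask before the wrapped layer is called.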
