1 change: 1 addition & 0 deletions src/MaxText/configs/base.yml
@@ -149,6 +149,7 @@ logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embed
 cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.
 float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product
 float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax
+weight_sum_fp32: True # whether to use full fp32 precision for weight_sum during final unpermute in moe
 
 # Multi-Token Prediction Configs
 # The number of auxiliary prediction layers to use for MTP.
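In effect, the new flag chooses whether the MoE combine einsum runs on fp32 operands or stays in the model dtype. A minimal, self-contained sketch of the two settings (shapes, values, and variable names here are illustrative, not taken from MaxText):

import jax.numpy as jnp

# Hypothetical sizes: B = tokens, K = top-k experts, E = embedding dim.
B, K, E = 4, 2, 8
expert_outputs = jnp.ones((B, K, E), dtype=jnp.bfloat16)
router_weights = jnp.full((B, K), 0.5, dtype=jnp.bfloat16)

# weight_sum_fp32: True -- both operands are upcast before the weighted sum.
out_fp32 = jnp.einsum(
    "BKE,BK -> BE",
    expert_outputs.astype(jnp.float32),
    router_weights.astype(jnp.float32),
)

# weight_sum_fp32: False -- the weighted sum stays in the model dtype.
out_low = jnp.einsum("BKE,BK -> BE", expert_outputs, router_weights)

print(out_fp32.dtype, out_low.dtype)  # float32 bfloat16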
14 changes: 10 additions & 4 deletions src/MaxText/layers/moe.py
@@ -579,10 +579,13 @@ def unpermute(
     if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
       # For Llama4, combine using weights of 1 for selected experts
       reshaped_weights = jnp.ones_like(reshaped_weights)
+    if self.config.weight_sum_fp32:
+      reshaped_intermediate = reshaped_intermediate.astype(jnp.float32)
+      reshaped_weights = reshaped_weights.astype(jnp.float32)
     output = jnp.einsum(
         "BKE,BK -> BE",
-        reshaped_intermediate.astype(jnp.float32),
-        reshaped_weights.astype(jnp.float32),
+        reshaped_intermediate,
+        reshaped_weights,
         precision=matmul_precision,
     )
     return output.reshape(batch_size, sequence_length, -1).astype(self.dtype)
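For context, the step this hunk makes precision-configurable is a per-token weighted sum over the top-k expert outputs, reshaped back to (batch, sequence, emb). A runnable sketch under assumed shapes (all names and values hypothetical):

import jax
import jax.numpy as jnp

batch, seq, k, emb = 2, 3, 2, 4
reshaped_intermediate = jax.random.normal(
    jax.random.PRNGKey(0), (batch * seq, k, emb), jnp.bfloat16)
reshaped_weights = jnp.tile(jnp.array([0.7, 0.3], jnp.bfloat16), (batch * seq, 1))

# Llama4 path: selected experts would instead contribute with weight 1, i.e.
# reshaped_weights = jnp.ones_like(reshaped_weights)

output = jnp.einsum(
    "BKE,BK -> BE",
    reshaped_intermediate.astype(jnp.float32),  # as when weight_sum_fp32 is True
    reshaped_weights.astype(jnp.float32),
)
print(output.reshape(batch, seq, -1).astype(jnp.bfloat16).shape)  # (2, 3, 4)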
@@ -1673,11 +1676,14 @@ def dense_matmul(
     with jax.named_scope("w_sum"):
       if is_llama4_decoder_layer:
         weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)
+      if self.config.weight_sum_fp32:
+        intermediate_layer = intermediate_layer.astype(jnp.float32)
+        weights = weights.astype(jnp.float32)
       # cast to f32 for sum up in einsum op
       output = jnp.einsum(
           "BSEM,BSE -> BSM",
-          intermediate_layer.astype(jnp.float32),
-          weights.astype(jnp.float32),  # pylint: disable=undefined-variable,possibly-used-before-assignment
+          intermediate_layer,
+          weights,
           precision=matmul_precision,
       ).astype(self.dtype)
     return output, None
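The dense path performs the same combine, but over a dense (batch, sequence, experts) weight tensor that is zero for unselected experts. A sketch with made-up shapes showing the fp32-upcast variant (the names and the 0.5 router weights are assumptions):

import jax
import jax.numpy as jnp

batch, seq, experts, model = 2, 3, 4, 8
intermediate_layer = jax.random.normal(
    jax.random.PRNGKey(0), (batch, seq, experts, model), jnp.bfloat16)
# Dense-path weights: a router probability per expert, zero where not selected.
weights = jnp.zeros((batch, seq, experts), jnp.bfloat16).at[..., :2].set(0.5)

output = jnp.einsum(
    "BSEM,BSE -> BSM",
    intermediate_layer.astype(jnp.float32),  # as when weight_sum_fp32 is True
    weights.astype(jnp.float32),
    precision=jax.lax.Precision.HIGHEST,
).astype(jnp.bfloat16)
print(output.shape)  # (2, 3, 8)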