1 change: 1 addition & 0 deletions src/MaxText/configs/base.yml

@@ -149,6 +149,7 @@ logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embed
 cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.
 float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product
 float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax
+weight_sum_fp32: True # whether to use full fp32 precision for weight_sum during final unpermute in moe

 # Multi-Token Prediction Configs
 # The number of auxiliary prediction layers to use for MTP.
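The new `weight_sum_fp32` flag gates whether the top-k expert outputs and routing weights are up-cast to fp32 before the weighted sum in MoE unpermute. A minimal sketch of the precision gap the flag is guarding against (not part of the PR; shapes and values are illustrative):

import jax.numpy as jnp

# Hypothetical shapes: 4 tokens (B), top-k = 2 experts (K), hidden size 8 (E).
intermediate = jnp.ones((4, 2, 8), dtype=jnp.bfloat16) * 0.1007
weights = jnp.array([[0.7, 0.3]] * 4, dtype=jnp.bfloat16)

# Weighted combine kept in bf16 (weight_sum_fp32: False).
out_bf16 = jnp.einsum("BKE,BK -> BE", intermediate, weights)

# Weighted combine with an fp32 up-cast first (weight_sum_fp32: True).
out_fp32 = jnp.einsum(
    "BKE,BK -> BE",
    intermediate.astype(jnp.float32),
    weights.astype(jnp.float32),
)

# Any difference here is precision lost by doing the combine in bf16.
print(jnp.max(jnp.abs(out_fp32 - out_bf16.astype(jnp.float32))))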
7 changes: 5 additions & 2 deletions src/MaxText/layers/moe.py

@@ -579,10 +579,13 @@ def unpermute(
 if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
   # For Llama4, combine using weights of 1 for selected experts
   reshaped_weights = jnp.ones_like(reshaped_weights)
+if self.config.weight_sum_fp32:
+  reshaped_intermediate = reshaped_intermediate.astype(jnp.float32)
+  reshaped_weights = reshaped_weights.astype(jnp.float32)
 output = jnp.einsum(
     "BKE,BK -> BE",
-    reshaped_intermediate.astype(jnp.float32),
-    reshaped_weights.astype(jnp.float32),
+    reshaped_intermediate,
+    reshaped_weights,
     precision=matmul_precision,
 )
 return output.reshape(batch_size, sequence_length, -1).astype(self.dtype)
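The change makes the previously unconditional fp32 casts conditional on the new flag; the einsum itself is untouched. A self-contained sketch of the resulting combine step (names follow the diff where possible; `weighted_expert_sum`, the default output dtype, and the fixed `lax.Precision.HIGHEST` are illustrative stand-ins for the surrounding MaxText code):

import jax.numpy as jnp
from jax import lax


def weighted_expert_sum(reshaped_intermediate, reshaped_weights,
                        weight_sum_fp32=True, out_dtype=jnp.bfloat16):
  """Combines top-k expert outputs (BKE) with routing weights (BK) into (BE)."""
  if weight_sum_fp32:
    # Up-cast both operands so the weighted sum accumulates in full fp32.
    reshaped_intermediate = reshaped_intermediate.astype(jnp.float32)
    reshaped_weights = reshaped_weights.astype(jnp.float32)
  output = jnp.einsum(
      "BKE,BK -> BE",
      reshaped_intermediate,
      reshaped_weights,
      precision=lax.Precision.HIGHEST,  # stand-in for matmul_precision in the diff
  )
  return output.astype(out_dtype)

With the flag off, low-precision operands pass through the einsum unconverted, skipping the up-cast at the cost of lower-precision accumulation.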