From f91b79d53dc5796808f2a5057fb50691a3ba9d33 Mon Sep 17 00:00:00 2001
From: Qinwen Xu
Date: Fri, 17 Oct 2025 18:07:55 +0000
Subject: [PATCH 1/3] add weight_sum_fp32 config

---
 src/MaxText/configs/base.yml | 1 +
 src/MaxText/layers/moe.py    | 7 +++++--
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index 72244d47e9..9d1929ba9c 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -149,6 +149,7 @@ logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embed
 cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.
 float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product
 float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax
+weight_sum_fp32: True # whether to use full fp32 precision for weight_sum during final unpermute in moe
 
 # Multi-Token Prediction Configs
 # The number of auxiliary prediction layers to use for MTP.
diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py
index 04744c6960..973d38d420 100644
--- a/src/MaxText/layers/moe.py
+++ b/src/MaxText/layers/moe.py
@@ -579,10 +579,13 @@ def unpermute(
     if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
       # For Llama4, combine using weights of 1 for selected experts
       reshaped_weights = jnp.ones_like(reshaped_weights)
+    if config.weight_sum_fp32:
+      reshaped_intermediate = reshaped_intermediate.astype(jnp.float32)
+      reshaped_weights = reshaped_weights.astype(jnp.float32)
     output = jnp.einsum(
         "BKE,BK -> BE",
-        reshaped_intermediate.astype(jnp.float32),
-        reshaped_weights.astype(jnp.float32),
+        reshaped_intermediate,
+        reshaped_weights,
         precision=matmul_precision,
     )
     return output.reshape(batch_size, sequence_length, -1).astype(self.dtype)

From 0ed2597623010faff51a63b91adefb89b5d8b20a Mon Sep 17 00:00:00 2001
From: Qinwen Xu
Date: Fri, 17 Oct 2025 18:14:53 +0000
Subject: [PATCH 2/3] update

---
 src/MaxText/layers/moe.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py
index 973d38d420..7afe0379ee 100644
--- a/src/MaxText/layers/moe.py
+++ b/src/MaxText/layers/moe.py
@@ -579,8 +579,8 @@ def unpermute(
     if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
       # For Llama4, combine using weights of 1 for selected experts
       reshaped_weights = jnp.ones_like(reshaped_weights)
-    if config.weight_sum_fp32:
-      reshaped_intermediate = reshaped_intermediate.astype(jnp.float32)
+    if self.config.weight_sum_fp32:
+      reshaped_intermediate = reshaped_intermediate.astype(jnp.float32)
       reshaped_weights = reshaped_weights.astype(jnp.float32)
     output = jnp.einsum(
         "BKE,BK -> BE",

From ce0d0ed90548594ad409b42dddde1dddd50668ff Mon Sep 17 00:00:00 2001
From: Qinwen Xu
Date: Fri, 17 Oct 2025 18:36:23 +0000
Subject: [PATCH 3/3] update

---
 src/MaxText/layers/moe.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py
index 7afe0379ee..74af377ebd 100644
--- a/src/MaxText/layers/moe.py
+++ b/src/MaxText/layers/moe.py
@@ -1676,11 +1676,14 @@ def dense_matmul(
       with jax.named_scope("w_sum"):
         if is_llama4_decoder_layer:
           weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)
+        if self.config.weight_sum_fp32:
+          intermediate_layer = intermediate_layer.astype(jnp.float32)
+          weights = weights.astype(jnp.float32)  # cast to f32 for sum up in einsum op
 
         output = jnp.einsum(
             "BSEM,BSE -> BSM",
-            intermediate_layer.astype(jnp.float32),
-            weights.astype(jnp.float32),  # pylint: disable=undefined-variable,possibly-used-before-assignment
+            intermediate_layer,
+            weights,
             precision=matmul_precision,
         ).astype(self.dtype)
       return output, None
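
Note (illustrative, not part of the patches): a minimal standalone sketch of the precision difference the new weight_sum_fp32 flag toggles in the MoE combine einsum. Shapes and variable names below are toy values made up for the example, not MaxText variables; in MaxText the result is additionally cast back to self.dtype and the einsum takes precision=matmul_precision.

    # Standalone sketch: fp32 vs low-precision inputs for the expert-weighted sum.
    import jax.numpy as jnp

    tokens, top_k, hidden = 4, 2, 8  # toy sizes
    intermediate = jnp.full((tokens, top_k, hidden), 0.1, dtype=jnp.bfloat16)  # per-expert outputs
    weights = jnp.full((tokens, top_k), 0.5, dtype=jnp.bfloat16)               # router weights

    # weight_sum_fp32=True path: cast both einsum inputs to fp32 before combining.
    out_fp32 = jnp.einsum("BKE,BK->BE", intermediate.astype(jnp.float32), weights.astype(jnp.float32))

    # weight_sum_fp32=False path: combine directly in the lower-precision dtype.
    out_low = jnp.einsum("BKE,BK->BE", intermediate, weights)

    print(out_fp32.dtype, out_low.dtype)  # float32 bfloat16

Per base.yml the flag defaults to True, so existing behavior (fp32 accumulation) is preserved; assuming the usual MaxText key=value config overrides, it could presumably be disabled with weight_sum_fp32=False.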