
Commit be4470c

Update moe.py
1 parent ad1d270 commit be4470c

File tree

1 file changed: +2 -2 lines changed

src/MaxText/layers/moe.py

Lines changed: 2 additions & 2 deletions
@@ -579,7 +579,7 @@ def unpermute(
     if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
       # For Llama4, combine using weights of 1 for selected experts
       reshaped_weights = jnp.ones_like(reshaped_weights)
-    if self.config.weight_sum_fp32:
+    if self.config.float32_weight_sum:
       reshaped_intermediate = reshaped_intermediate.astype(jnp.float32)
       reshaped_weights = reshaped_weights.astype(jnp.float32)
     output = jnp.einsum(
@@ -1676,7 +1676,7 @@ def dense_matmul(
     with jax.named_scope("w_sum"):
       if is_llama4_decoder_layer:
         weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)
-      if self.config.weight_sum_fp32:
+      if self.config.float32_weight_sum:
         intermediate_layer = intermediate_layer.astype(jnp.float32)
         weights = weights.astype(jnp.float32)
         # cast to f32 for sum up in einsum op
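Both hunks rename the config flag weight_sum_fp32 to float32_weight_sum; the guarded behavior is unchanged. The sketch below is a minimal, self-contained illustration of the pattern the flag controls, not MaxText's actual code: the function name, shapes, and default flag value are assumptions for illustration. It upcasts expert outputs and routing weights to float32 so the einsum that sums over the top-k experts does not accumulate in bfloat16.

import jax.numpy as jnp

def combine_expert_outputs(intermediate, weights, float32_weight_sum=True):
  # Hypothetical shapes: intermediate is (tokens, top_k, model_dim) expert
  # outputs (e.g. bfloat16); weights is (tokens, top_k) routing weights.
  out_dtype = intermediate.dtype
  if float32_weight_sum:
    # Cast both operands so the einsum reduction runs in float32,
    # mirroring the astype(jnp.float32) calls guarded in the diff above.
    intermediate = intermediate.astype(jnp.float32)
    weights = weights.astype(jnp.float32)
  # Weighted sum over the top_k expert axis.
  output = jnp.einsum("tkd,tk->td", intermediate, weights)
  # Return in the original activation dtype for downstream layers.
  return output.astype(out_dtype)

x = jnp.ones((4, 2, 8), dtype=jnp.bfloat16)    # (tokens, top_k, dim)
w = jnp.full((4, 2), 0.5, dtype=jnp.bfloat16)  # (tokens, top_k)
print(combine_expert_outputs(x, w).shape)      # (4, 8)

Accumulating the weighted sum in float32 and casting back afterward trades a little extra compute for a numerically stabler combine, which matters when routing weights are small and activations are kept in a low-precision dtype.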
