We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent be4470c commit 6f0349dCopy full SHA for 6f0349d
src/MaxText/layers/moe.py
@@ -1673,7 +1673,7 @@ def dense_matmul(
1673
if self.config.activations_in_float32:
1674
intermediate_layer = intermediate_layer.astype(jnp.float32)
1675
intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
1676
- with jax.named_scope("w_sum"):
+ with jax.named_scope("weight_sum"):
1677
if is_llama4_decoder_layer:
1678
weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)
1679
if self.config.float32_weight_sum:
0 commit comments