Diff summary: 1 file changed, +2 −2 lines.

@@ -579,7 +579,7 @@ def unpermute(
     if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
       # For Llama4, combine using weights of 1 for selected experts
       reshaped_weights = jnp.ones_like(reshaped_weights)
-    if self.config.weight_sum_fp32:
+    if self.config.float32_weight_sum:
       reshaped_intermediate = reshaped_intermediate.astype(jnp.float32)
       reshaped_weights = reshaped_weights.astype(jnp.float32)
     output = jnp.einsum(
@@ -1676,7 +1676,7 @@ def dense_matmul(
     with jax.named_scope("w_sum"):
       if is_llama4_decoder_layer:
         weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)
-      if self.config.weight_sum_fp32:
+      if self.config.float32_weight_sum:
         intermediate_layer = intermediate_layer.astype(jnp.float32)
         weights = weights.astype(jnp.float32)
         # cast to f32 for sum up in einsum op
0 commit comments.