🚨 Fix float16 overflow in Gemma4 vision pooler#46277
Conversation
|
[For maintainers] Suggested jobs to run (before merge) run-slow: gemma4 |
`Gemma4VisionPooler` scales hidden states by `sqrt(hidden_size)`. For the high-magnitude vision activations produced by the larger Gemma-4 checkpoints, this product exceeds the float16 maximum (65504) and saturates to `inf`, which then turns the downstream logits into `NaN`. Compute the pooler scaling and the following standardize in float32 and cast back to the working dtype once the magnitude is small again. This is a no-op in float32, keeps float16 finite, and computes the (cancellation-prone) standardize more accurately in low precision. The persistent `std_bias`/`std_scale` buffers are left untouched, so checkpoint round-trips are unchanged. Add a regression test that reproduces the float16 scaling overflow on a small config and checks that the output stays finite and matches the float32 result.
87c8065 to
7251d0b
Compare
|
cc @Cyrilvallez @zucchini-nlp — small float16 numerical fix for the Gemma4 vision pooler (modular). Would appreciate a review when you have a moment. Thanks! |
| # Scale in float32 and return float32: the sqrt(hidden_size) scaling can push the | ||
| # activations past the float16 range (max 65504), so the magnitude is kept in float32 | ||
| # until the caller standardizes it. | ||
| hidden_states = hidden_states.float() * self.root_hidden_size | ||
| return hidden_states, padding_positions |
There was a problem hiding this comment.
hm, could we instead cast back the hidden states to original_dtype before returning?
There was a problem hiding this comment.
Casting back to the input dtype inside the pooler would re-introduce the overflow: the * sqrt(hidden_size) product already exceeds the float16 range (> 65504) on the larger checkpoints, and it's the following standardize — which runs in Gemma4VisionModel.forward, after the pooler returns — that brings the magnitude back down. So a float16 cast here is exactly where it saturates to inf.
To keep the pooler returning the input dtype as you suggest, I think the cleanest option is to drop the scaling from the pooler entirely and fold the factor into the standardize, so it stays in the input dtype and the large intermediate is never formed.
1. Gemma4VisionPooler.forward — remove the scaling, return the pooled features in the input dtype:
if hidden_states.shape[1] != output_length:
hidden_states, padding_positions = self._avg_pool_by_positions(
hidden_states, pixel_position_ids, output_length
)
return hidden_states, padding_positions2. Gemma4VisionModel.forward — fold root_hidden_size into the standardize (replacing the current (hidden_states - self.std_bias) * self.std_scale):
# Strip padding tokens. pooler_mask is True = valid, False = padding.
hidden_states = hidden_states[pooler_mask]
if self.config.standardize:
# Fold the pooler's sqrt(hidden_size) scaling into the standardize. This is
# mathematically equal to `(hidden_states * root_hidden_size - std_bias) * std_scale`,
# but the largest intermediate is `hidden_states - std_bias / root_hidden_size`, so it
# never overflows float16 and needs no float32 upcast.
root_hidden_size = self.pooler.root_hidden_size
hidden_states = (hidden_states - self.std_bias / root_hidden_size) * self.std_scale * root_hidden_size
else:
hidden_states = hidden_states * self.pooler.root_hidden_size
return BaseModelOutputWithPast(last_hidden_state=hidden_states)(Edited in modular_gemma4.py and regenerated.) std_bias / std_scale are left untouched, so checkpoints are unaffected. Does that look better to you?
zucchini-nlp
left a comment
There was a problem hiding this comment.
Ah right, indeed! Lets add 🚨 in pr title since it is breaking if we change return dtype, even for an internal module
|
Done, added the 🚨. Though if the fold I proposed above works for you, the pooler keeps returning the input dtype (the scaling just moves into the standardize), so it wouldn't be a breaking change and we could drop the marker. Happy to go either way, just let me know which you prefer. |
What does this PR do?
Gemma4VisionModelscales the pooled vision features bysqrt(hidden_size)before the standardize step. For the high-magnitude vision activations produced by the larger Gemma-4 checkpoints, this product exceeds the float16 maximum (65504) and saturates toinf, which then propagates asNaNthrough the language-model logits, so the model is unusable in float16.The large value is only a transient — the immediately following
(hidden_states - std_bias) * std_scalestandardize brings the magnitude back down — but in float16 it overflows before the standardize can recover it.Fix
Gemma4VisionPoolerapplies itssqrt(hidden_size)scaling in float32 and returns float32;Gemma4VisionModel.forwardthen standardizes in float32 and casts back to the working dtype once the magnitude is small again:std_biassubtraction is a cancellation of large values (std_biasis on the order of 1e4–1e5), so the result is not bit-identical to before, but the difference is within low-precision noise.The persistent
std_bias/std_scalebuffers are left untouched, so existing checkpoints round-trip unchanged.Reproduction
google/gemma-4-31b-itin float16 returnsNaNlogits for any image input. Evaluating MMBench (dev EN, 200 samples):mainmainOn
main, every float16 sample overflows and yieldsNaNlogits (the 42% is the degenerate all-NaNfallback, not real accuracy). With the fix, float16 matches bfloat16, and bfloat16 is unchanged.Tests
Added
Gemma4VisionPoolerFloat16Test, which reproduces the float16 scaling overflow on a small config (forcing the high-activation regime the large checkpoints reach) and asserts the output stays finite and matches the float32 result. It fails onmainand passes with this PR.Before submitting
modeling_gemma4.pywas regenerated frommodular_gemma4.py.