Skip to content

🚨 Fix float16 overflow in Gemma4 vision pooler#46277

Open
Bluear7878 wants to merge 2 commits into
huggingface:mainfrom
Bluear7878:fix-gemma4-vision-pooler-fp16-overflow
Open

🚨 Fix float16 overflow in Gemma4 vision pooler#46277
Bluear7878 wants to merge 2 commits into
huggingface:mainfrom
Bluear7878:fix-gemma4-vision-pooler-fp16-overflow

Conversation

@Bluear7878
Copy link
Copy Markdown

@Bluear7878 Bluear7878 commented May 29, 2026

What does this PR do?

Gemma4VisionModel scales the pooled vision features by sqrt(hidden_size) before the standardize step. For the high-magnitude vision activations produced by the larger Gemma-4 checkpoints, this product exceeds the float16 maximum (65504) and saturates to inf, which then propagates as NaN through the language-model logits, so the model is unusable in float16.

The large value is only a transient — the immediately following (hidden_states - std_bias) * std_scale standardize brings the magnitude back down — but in float16 it overflows before the standardize can recover it.

Fix

Gemma4VisionPooler applies its sqrt(hidden_size) scaling in float32 and returns float32; Gemma4VisionModel.forward then standardizes in float32 and casts back to the working dtype once the magnitude is small again:

  • float32: unchanged (the upcast is a no-op).
  • float16: no longer overflows.
  • bfloat16: computed slightly more accurately. The std_bias subtraction is a cancellation of large values (std_bias is on the order of 1e4–1e5), so the result is not bit-identical to before, but the difference is within low-precision noise.

The persistent std_bias / std_scale buffers are left untouched, so existing checkpoints round-trip unchanged.

Reproduction

google/gemma-4-31b-it in float16 returns NaN logits for any image input. Evaluating MMBench (dev EN, 200 samples):

dtype branch accuracy overflow NaN logits
float16 main 42.0% 200/200 200/200
float16 this PR 91.0% 0/200 0/200
bfloat16 main 91.5% 0/200 0/200
bfloat16 this PR 91.5% 0/200 0/200

On main, every float16 sample overflows and yields NaN logits (the 42% is the degenerate all-NaN fallback, not real accuracy). With the fix, float16 matches bfloat16, and bfloat16 is unchanged.

Note: standardize=False in float16 is inherently unsupported (the pooled output is intentionally large); all released Gemma-4 checkpoints use standardize=True.

Tests

Added Gemma4VisionPoolerFloat16Test, which reproduces the float16 scaling overflow on a small config (forcing the high-activation regime the large checkpoints reach) and asserts the output stays finite and matches the float32 result. It fails on main and passes with this PR.

Before submitting

  • This PR fixes a bug (float16 inference of large Gemma-4 vision models).
  • Did you write any new necessary tests? Yes.
  • modeling_gemma4.py was regenerated from modular_gemma4.py.

@github-actions
Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma4

`Gemma4VisionPooler` scales hidden states by `sqrt(hidden_size)`. For the
high-magnitude vision activations produced by the larger Gemma-4 checkpoints,
this product exceeds the float16 maximum (65504) and saturates to `inf`, which
then turns the downstream logits into `NaN`.

Compute the pooler scaling and the following standardize in float32 and cast
back to the working dtype once the magnitude is small again. This is a no-op in
float32, keeps float16 finite, and computes the (cancellation-prone) standardize
more accurately in low precision. The persistent `std_bias`/`std_scale` buffers
are left untouched, so checkpoint round-trips are unchanged.

Add a regression test that reproduces the float16 scaling overflow on a small
config and checks that the output stays finite and matches the float32 result.
@Bluear7878 Bluear7878 force-pushed the fix-gemma4-vision-pooler-fp16-overflow branch from 87c8065 to 7251d0b Compare May 29, 2026 04:46
@Bluear7878
Copy link
Copy Markdown
Author

cc @Cyrilvallez @zucchini-nlp — small float16 numerical fix for the Gemma4 vision pooler (modular). Would appreciate a review when you have a moment. Thanks!

Comment on lines +652 to 656
# Scale in float32 and return float32: the sqrt(hidden_size) scaling can push the
# activations past the float16 range (max 65504), so the magnitude is kept in float32
# until the caller standardizes it.
hidden_states = hidden_states.float() * self.root_hidden_size
return hidden_states, padding_positions
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm, could we instead cast back the hidden states to original_dtype before returning?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Casting back to the input dtype inside the pooler would re-introduce the overflow: the * sqrt(hidden_size) product already exceeds the float16 range (> 65504) on the larger checkpoints, and it's the following standardize — which runs in Gemma4VisionModel.forward, after the pooler returns — that brings the magnitude back down. So a float16 cast here is exactly where it saturates to inf.

To keep the pooler returning the input dtype as you suggest, I think the cleanest option is to drop the scaling from the pooler entirely and fold the factor into the standardize, so it stays in the input dtype and the large intermediate is never formed.

1. Gemma4VisionPooler.forward — remove the scaling, return the pooled features in the input dtype:

        if hidden_states.shape[1] != output_length:
            hidden_states, padding_positions = self._avg_pool_by_positions(
                hidden_states, pixel_position_ids, output_length
            )

        return hidden_states, padding_positions

2. Gemma4VisionModel.forward — fold root_hidden_size into the standardize (replacing the current (hidden_states - self.std_bias) * self.std_scale):

        # Strip padding tokens. pooler_mask is True = valid, False = padding.
        hidden_states = hidden_states[pooler_mask]

        if self.config.standardize:
            # Fold the pooler's sqrt(hidden_size) scaling into the standardize. This is
            # mathematically equal to `(hidden_states * root_hidden_size - std_bias) * std_scale`,
            # but the largest intermediate is `hidden_states - std_bias / root_hidden_size`, so it
            # never overflows float16 and needs no float32 upcast.
            root_hidden_size = self.pooler.root_hidden_size
            hidden_states = (hidden_states - self.std_bias / root_hidden_size) * self.std_scale * root_hidden_size
        else:
            hidden_states = hidden_states * self.pooler.root_hidden_size

        return BaseModelOutputWithPast(last_hidden_state=hidden_states)

(Edited in modular_gemma4.py and regenerated.) std_bias / std_scale are left untouched, so checkpoints are unaffected. Does that look better to you?

Copy link
Copy Markdown
Member

@zucchini-nlp zucchini-nlp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah right, indeed! Lets add 🚨 in pr title since it is breaking if we change return dtype, even for an internal module

@Bluear7878 Bluear7878 changed the title Fix float16 overflow in Gemma4 vision pooler 🚨 Fix float16 overflow in Gemma4 vision pooler May 29, 2026
@Bluear7878
Copy link
Copy Markdown
Author

Done, added the 🚨.

Though if the fold I proposed above works for you, the pooler keeps returning the input dtype (the scaling just moves into the standardize), so it wouldn't be a breaking change and we could drop the marker. Happy to go either way, just let me know which you prefer.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants