Describe the bug
On Xeon 6, when Gemma3 is executed in fp16, it does not generate any output tokens. The same issue is also observed on an NVIDIA GPU when run in Colab.
To Reproduce
The issue can be reproduced with this Colab notebook: https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/core/keras_inference.ipynb by enabling mixed precision before loading the model:
import keras
import keras_hub

keras.mixed_precision.set_global_policy("mixed_float16")

gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(
    "gemma3_instruct_4b"
)
gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)

The generate call returns without producing any new output tokens beyond the prompt.
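For comparison, here is a minimal sanity check (a sketch, assuming nothing else in the notebook is changed): running the same call under the default float32 policy, where generation is expected to work.

import keras
import keras_hub

# Default full-precision policy; generation works here,
# which isolates the failure to mixed_float16.
keras.mixed_precision.set_global_policy("float32")

gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(
    "gemma3_instruct_4b"
)
print(gemma_lm.generate("what is keras in 3 bullet points?", max_length=64))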
Expected behavior
It should generate relevant output in response to the prompt.
Additional context
When changes similar to those in the HF PR huggingface/transformers#36832 are applied to the keras-hub Gemma3 model, it does generate new output. Failing that, fp16 generation needs to be fixed some other way, so that relevant output is produced on hardware that supports fp16; a sketch of the general pattern follows.
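As an illustration of the kind of fix involved, one common pattern for this class of fp16 bug is to upcast numerically sensitive ops to float32 and cast the result back, similar in spirit to what the linked transformers PR does. The sketch below applies the idea to a generic Gemma-style RMSNorm layer; the class name, signature, and epsilon value are illustrative assumptions, not the actual keras-hub implementation:

import keras
from keras import ops

class Float32RMSNorm(keras.layers.Layer):
    """Hypothetical RMSNorm that computes in float32 under mixed_float16."""

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # Gemma-style scale, applied as (1 + scale) in call().
        self.scale = self.add_weight(
            name="scale", shape=(input_shape[-1],), initializer="zeros"
        )

    def call(self, x):
        input_dtype = x.dtype
        # The mean of squares can overflow or lose precision in float16,
        # so do the normalization math in float32.
        x = ops.cast(x, "float32")
        variance = ops.mean(ops.square(x), axis=-1, keepdims=True)
        x = x * ops.rsqrt(variance + self.epsilon)
        x = x * (1.0 + ops.cast(self.scale, "float32"))
        # Cast back so the surrounding float16 compute is unaffected.
        return ops.cast(x, input_dtype)

Where exactly the upcast is needed (norm layers, attention logits, or the final logit computation) is the part that has to be ported from the transformers fix.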
Would you like to help us fix it?
I ported the changes from HF to keras-hub and created a PR for reference.