Skip to content

Add T5Gemma to KerasHub #2339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: master
Choose a base branch
from

Conversation

harshaljanjani
Copy link
Collaborator

@harshaljanjani harshaljanjani commented Jul 19, 2025

Description of the change

T5Gemma models integrate advanced features from Gemma 2, including GQA attention, RoPE embeddings, GeGLU activation, RMSNorm, and interleaved local/global attention, into the T5 transformer architecture. They deliver significant performance improvements over decoder-only models in tasks such as reasoning and summarization, striking an optimal balance between quality and inference efficiency.

Closes the issue #2321

Numerics Consistency (Absolute Tolerance @ 1e-4) and Example Output

from keras_hub.src.models.t5gemma.t5gemma_causal_lm import T5GemmaCausalLM

t5gemma_lm = T5GemmaCausalLM.from_preset("t5gemma_b_b_prefixlm_it")
output = t5gemma_lm.generate("What is the fastest land animal?")
print(output)
image

Reference

Colab Notebook

Numerics Validation and Usage Example

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and works with all backends (TensorFlow, JAX, and PyTorch).
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have followed the Keras Hub Model contribution guidelines in making these changes.
  • I have followed the Keras Hub API design guidelines in making these changes.
  • I have signed the Contributor License Agreement.

@harshaljanjani harshaljanjani self-assigned this Jul 19, 2025
@github-actions github-actions bot added the Gemma Gemma model specific issues label Jul 19, 2025
@harshaljanjani harshaljanjani added the WIP Pull requests which are work in progress and not ready yet for review. label Jul 19, 2025
Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @harshaljanjani, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the T5Gemma model into KerasHub, providing a comprehensive implementation of its encoder-decoder architecture, attention mechanisms, and supporting components. It enables causal language modeling with efficient text generation capabilities and includes dedicated preprocessing and tokenization utilities.

Highlights

  • New Model Integration: This pull request introduces the complete T5Gemma model architecture into KerasHub, enabling its use for various natural language processing tasks.
  • Advanced Attention Mechanisms: New T5GemmaSelfAttention and T5GemmaCrossAttention layers are added, featuring support for Grouped Query Attention (GQA) and Rotary Positional Embeddings (RoPE) for enhanced performance and positional encoding.
  • Encoder-Decoder Backbone: The T5GemmaBackbone is implemented, providing the core encoder-decoder structure. It supports both full attention and sliding window attention within its layers.
  • Causal Language Modeling: An end-to-end T5GemmaCausalLM is included, designed for efficient text generation through optimized call_with_cache and generate_step methods for autoregressive inference.
  • Dedicated Preprocessing and Tokenization: Custom T5GemmaCausalLMPreprocessor and T5GemmaTokenizer classes are added to handle input data preparation, including tokenization with SentencePiece and management of special tokens.
  • Core Layer Components: Fundamental building blocks like T5GemmaMLP (Multi-Layer Perceptron) and a specific t5gemma_kernel_initializer are introduced to support the T5Gemma architecture.
  • Comprehensive Testing: New unit tests are provided for the T5GemmaBackbone and T5GemmaCausalLM to ensure the correctness of the implementation and proper model saving functionality.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

The code introduces the T5Gemma model to KerasHub. The implementation is comprehensive and well-structured. The review focuses on a performance optimization for the generation process and a point of code consistency. Addressing these will enhance the model's efficiency and maintainability.

@harshaljanjani harshaljanjani marked this pull request as ready for review July 23, 2025 16:13
@harshaljanjani harshaljanjani removed the WIP Pull requests which are work in progress and not ready yet for review. label Jul 23, 2025
@laxmareddyp
Copy link
Collaborator

/gemini review

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the T5Gemma model to KerasHub, including its backbone, causal language model, preprocessor, tokenizer, and a checkpoint conversion script. The implementation is comprehensive and well-structured, with good test coverage. There's a critical correctness issue concerning shape inference in the attention mechanism, which affects the encoder and decoder layers. The compute_output_shape method in T5GemmaAttention does not correctly report the shape of the cache it returns, which can lead to problems with model serialization and analysis. I've provided suggestions to fix this and the corresponding call sites. Additionally, I've found a couple of medium-severity issues in the documentation examples, where a required argument was missing, which would cause the examples to fail.

@harshaljanjani
Copy link
Collaborator Author

harshaljanjani commented Jul 24, 2025

I believe the aforementioned cache shapes in the reviews can be inferred automatically, but the compute_output_shape methods can definitely be made more precise by including them. I’ll cover these after verifying!
I did, however, miss out on documenting head_dim⁣; I'll do that in the backbone and the underlying layers as well, thanks!

@laxmareddyp laxmareddyp added the kokoro:force-run Runs Tests on GPU label Jul 24, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Jul 24, 2025
@harshaljanjani
Copy link
Collaborator Author

harshaljanjani commented Jul 24, 2025

The errors in the JAX and TensorFlow GPU tests are not related to the T5Gemma implementation; also, the recent commit shows the exact same errors in CI. Please do check, thanks!
Also, I was going through the code and noticed I didn't propagate the dtypes. I'll do that quickly and re-run the normal + GPU tests.

@harshaljanjani harshaljanjani added the kokoro:force-run Runs Tests on GPU label Jul 24, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Jul 24, 2025
@laxmareddyp
Copy link
Collaborator

/gemini review

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This is an excellent and comprehensive pull request that adds the T5Gemma model to KerasHub. The code is well-structured and includes all the necessary components: model implementation, preprocessor, tokenizer, tests, and a conversion script.

I have a few suggestions to improve the code quality and maintainability. I've identified a potential issue with how head_dim is calculated in a couple of places, which could lead to bugs if not addressed. Additionally, there's an opportunity to refactor some duplicated code in the attention layer.

Once these points are addressed, this PR will be in great shape. Thank you for the great work!

@harshaljanjani
Copy link
Collaborator Author

harshaljanjani commented Jul 24, 2025

I just ran into the same head_dim bug as described while double-checking the presets. I'll fix that and incorporate the refactor suggestion I've noted from above. Also, hopefully we can have human reviews now @divyashreepathihalli @mattdangerw!

@laxmareddyp
Copy link
Collaborator

/gemini review

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the T5Gemma model to KerasHub, including its backbone, causal language model, preprocessor, and tokenizer. The changes are comprehensive, adding the core model implementation, associated tests, and a script for converting Hugging Face checkpoints.

My review has identified a few areas for improvement, primarily related to ensuring correct model serialization and a potential performance optimization in the attention mechanism. Specifically, some get_config methods are missing parameters required for proper model saving and loading. Additionally, the attention implementation could be optimized by leveraging fused kernels.

Overall, this is a solid contribution. Addressing the identified issues will enhance the robustness and performance of the new T5Gemma model implementation.

@laxmareddyp
Copy link
Collaborator

/gemini review

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the T5Gemma model to KerasHub, a significant and well-structured contribution. The implementation includes the backbone, causal language model, preprocessor, tokenizer, and comprehensive tests, along with a script for converting Hugging Face checkpoints. The overall code quality is high. My main feedback focuses on improving the documentation. The examples in the docstrings for the new components are not self-contained, missing necessary imports or variable definitions. Addressing these issues will make the examples runnable out-of-the-box and significantly improve the developer experience.

@laxmareddyp laxmareddyp added the kokoro:force-run Runs Tests on GPU label Jul 25, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Jul 25, 2025
@harshaljanjani
Copy link
Collaborator Author

harshaljanjani commented Jul 25, 2025

These are copied-over docstrings edited from Gemma to suit T5Gemma. I think we've reached a point where it's picking up unnecessary nits. I'll be resolving them without changes, thanks!

@harshaljanjani harshaljanjani added the kokoro:force-run Runs Tests on GPU label Jul 26, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Jul 26, 2025
Copy link
Collaborator

@divyashreepathihalli divyashreepathihalli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this contribution Harshal! left a NIT comment

@@ -0,0 +1,2 @@
# Metadata for loading pretrained model weights.
backbone_presets = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe you can upload to presets on Kaggle on your account temporarily, fill this out. we can run the preset tests and make sure everything is fine

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In total, the 28 presets amount to ~250 GB's of content. If you'd still like me to upload them to Kaggle right away, please let me know, but we’ll have to reupload them to the model page again afterward; thanks!

@divyashreepathihalli
Copy link
Collaborator

can you add a colab demo showing generate outputs matching

@harshaljanjani
Copy link
Collaborator Author

harshaljanjani commented Jul 30, 2025

Good day, @divyashreepathihalli! Thanks for the reviews!

can you add a colab demo showing generate outputs matching

I'll work on it shortly! In the meantime, I’ve enabled the mixed precision and quantization tests and left a reply regarding the Kaggle preset upload. Please feel free to have a look at your convenience, thanks!

Additionally, I'll be adding another commit shortly after this, after which all 32 presets from the original model (including the asymmetrical configurations) should be supported. I've been meticulous to only separate the arguments, which are different across the encoder and decoder in the presets, not the invariants (apologies for the mishap in the commit message!).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Gemma Gemma model specific issues
Projects
Status: In Progress
Development

Successfully merging this pull request may close these issues.

4 participants