Skip to content

Conversation

amitsrivastava78
Copy link
Collaborator

  • Modified load_own_variables() to use _direct_assign() for sharded variables
  • Prevents loading full weight tensors on single device before distribution
  • Resolves RESOURCE_EXHAUSTED errors when loading large models with ModelParallel
  • Maintains backward compatibility for non-sharded variables
  • Enables loading of models like Gemma2 2B/7B without OOM errors

@github-actions github-actions bot added the Gemma Gemma model specific issues label Oct 3, 2025
Copy link
Contributor

Summary of Changes

Hello @amitsrivastava78, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces crucial improvements to Keras's JAX backend, specifically targeting Out-Of-Memory (OOM) issues that arise when loading large models in a ModelParallel distributed setup. By modifying how sharded variables are initialized and assigned, the system now avoids loading entire weight tensors onto a single device, instead distributing them directly across available devices. This change, encapsulated in a new load_variable_with_sharded_support utility, ensures efficient memory usage and enables the successful loading of very large models like Gemma2 2B/7B. Additionally, it enhances the robustness of dynamic shape handling and prevents premature garbage collection of sharded JAX arrays, solidifying Keras's distributed capabilities on JAX.

Highlights

  • OOM Fix for ModelParallel: Modified weight loading logic to prevent Out-Of-Memory (OOM) errors when using ModelParallel with large models, especially on the JAX backend.
  • Sharded Variable Loading: Implemented a new utility function, load_variable_with_sharded_support, to ensure sharded variables are loaded directly onto their respective devices, avoiding full tensor materialization on a single device.
  • JAX Backend Enhancements: Updated JaxVariable and NnxVariable to manage strong references to sharded JAX arrays, preventing premature garbage collection and ensuring data availability during inference.
  • Dynamic Shape Handling: Improved compute_output_spec in the JAX backend to more robustly handle dynamic shapes during symbolic tracing.
  • Broad Layer Support: Integrated the new sharded loading mechanism across various Keras layers (e.g., Dense, Conv, Embedding, EinsumDense, BatchNormalization) and optimizers.
  • Comprehensive Testing: Added new tests to validate the correct behavior of sharded array protection, strong references, and the end-to-end sharded variable loading process for ModelParallel.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses an out-of-memory (OOM) issue during weight loading for models using ModelParallel by introducing sharded variable assignment. The changes primarily involve modifying load_own_variables methods to use a new _direct_assign approach for sharded variables, preventing the full weight tensor from being loaded onto a single device. A new helper function, load_variable_with_sharded_support, centralizes this logic, and it has been integrated into various layers and the base optimizer.

My review has identified a few areas for improvement, including a potential memory leak, duplicated code, and some inconsistencies. Addressing these points will enhance the robustness and maintainability of the solution. Overall, the changes are well-structured and include thorough testing, which is excellent.

@codecov-commenter
Copy link

codecov-commenter commented Oct 3, 2025

Codecov Report

❌ Patch coverage is 77.13004% with 51 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.56%. Comparing base (3fac66f) to head (8beda65).
⚠️ Report is 7 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/backend/jax/core.py 72.78% 31 Missing and 12 partials ⚠️
keras/src/layers/core/dense.py 91.66% 1 Missing and 1 partial ⚠️
keras/src/layers/core/einsum_dense.py 91.66% 1 Missing and 1 partial ⚠️
keras/src/layers/core/embedding.py 75.00% 1 Missing and 1 partial ⚠️
keras/src/layers/preprocessing/index_lookup.py 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21712      +/-   ##
==========================================
- Coverage   82.60%   82.56%   -0.04%     
==========================================
  Files         572      572              
  Lines       58326    58710     +384     
  Branches     9134     9195      +61     
==========================================
+ Hits        48179    48474     +295     
- Misses       7817     7887      +70     
- Partials     2330     2349      +19     
Flag Coverage Δ
keras 82.36% <77.13%> (-0.04%) ⬇️
keras-jax 63.22% <76.68%> (-0.10%) ⬇️
keras-numpy 57.47% <32.73%> (-0.19%) ⬇️
keras-openvino 34.25% <7.62%> (-0.06%) ⬇️
keras-tensorflow 63.83% <32.73%> (-0.22%) ⬇️
keras-torch 63.38% <33.18%> (-0.26%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

It's unfortunate that this is combining 3 different things (which you describe in your doc):

  1. initialization of sharded variables
  2. reloading of sharded variables
  3. reference counting of variables

In particular, I don't think we should do 3 because both Python and JAX already track usage of arrays. Therefore:

  • I believe it's hiding some other bug
  • Adding our own tracking system on top is error prone and is probably going to add memory leaks because it's very easy to forget to clear references.

import ml_dtypes
import numpy as np
from jax import export as jax_export
from absl import logging
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is undoing a lot of changes that were made in this file. It wasn't rebased correctly.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok rebased again and ensured no fixes are missing also restored jax_export and other fixes overwritten

IS_THREAD_SAFE = True


def _is_jax_tracer(x):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't you use (and potentially change) the one from jax_utils?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

# Check if variable has a layout (is sharded)
if hasattr(variable, "_layout") and variable._layout is not None:
# Use _direct_assign for sharded variables to avoid OOM
logging.info(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logging will be very noisy, let's remove it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

from absl import logging


def load_variable_with_sharded_support(variable, weight_data):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need this, you can just do variable._direct_assign(weight_data) everywhere you used load_variable_with_sharded_support.

That check for if hasattr(variable, "_layout") and variable._layout is not None: is already done within _direct_assign.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

)

# Ensure value is on host (numpy array)
if not isinstance(value, np.ndarray):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But that's too late. The initializer has already been called on device:0. For this to work the way you intend it to, you need to run the initializer on CPU using a with device(...) scope.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

and loaded_var._shard_references
)

logging.debug(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove all logging.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

Comment on lines 310 to 336
if self.quantization_mode == "gptq":
# GPTQ: bias first, then quantized_kernel
target_variables = [self.bias] if self.use_bias else []
target_variables.append(self.quantized_kernel)
else:
target_variables = [self._kernel]
if self.use_bias and self.quantization_mode != "gptq":
target_variables.append(self.bias)
if self.quantization_mode is not None:
if self.quantization_mode in ("int8", "int4"):
target_variables.append(self.kernel_scale)
elif self.quantization_mode == "float8":
target_variables.append(self.inputs_scale)
target_variables.append(self.inputs_amax_history)
target_variables.append(self.kernel_scale)
target_variables.append(self.kernel_amax_history)
target_variables.append(self.outputs_grad_scale)
target_variables.append(self.outputs_grad_amax_history)
elif self.quantization_mode == "gptq":
target_variables.append(self.kernel_scale)
target_variables.append(self.kernel_zero)
target_variables.append(self.g_idx)
else:
raise self._quantization_mode_error(self.quantization_mode)
for i, variable in enumerate(target_variables):
weight_data = store[str(i)]
load_variable_with_sharded_support(variable, weight_data)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is being changed I believe.

But why couldn't you do a 1-line change:

< variable.assign(store[str(i)])
---
> load_variable_with_sharded_support(variable, store[str(i)])

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

raise self._quantization_mode_error(self.quantization_mode)
for i, variable in enumerate(target_variables):
weight_data = store[str(i)]
load_variable_with_sharded_support(variable, weight_data)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment about this code being changed and a 1-line change.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

raise self._quantization_mode_error(self.quantization_mode)
for i, variable in enumerate(target_variables):
weight_data = store[str(i)]
load_variable_with_sharded_support(variable, weight_data)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment about this code being changed and a 1-line change.

f"{has_shard_refs_loaded}"
)

self.assertTrue(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using assertTrue forces you to put a specific error message to provide context.

I would replace all of this:

                    has_shard_refs_orig = (
                        hasattr(orig_var, "_shard_references")
                        and orig_var._shard_references
                    )
                    logging.debug(
                        f"  Original has shard references: "
                        f"{has_shard_refs_orig}"
                    )
                    self.assertTrue(
                        has_shard_refs_orig,
                        f"Original {var_name} should have shard references",
                    )
                    self.assertGreater(
                        len(orig_var._shard_references),
                        0,
                        f"Original {var_name} has empty shard references",
                    )

With line:

self.assertLen(orig_var._shard_references, 1)

Not only it's a lot less code, it will actually gives you more information in case of error:

  • if orig_var doesn't have _shard_references as an attribute, it will raise an error telling you exactly that
  • if orig_var._shard_references is None, len will fail telling you it's None, so you'll know (has_shard_refs_orig won't tell you directly if the attribute is missing or None, you'll have to look at the debug loggin)
  • assertLen will tell you what you're taking the len of whereas assertGreater will tell you 0 < 0 which is not super helpful.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Gemma Gemma model specific issues size:XL
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants