Fix ModelParallel OOM issue during weight loading #21712
Conversation
amitsrivastava78 commented on Oct 3, 2025
- Modified load_own_variables() to use _direct_assign() for sharded variables (see the sketch after this list)
- Prevents loading full weight tensors on single device before distribution
- Resolves RESOURCE_EXHAUSTED errors when loading large models with ModelParallel
- Maintains backward compatibility for non-sharded variables
- Enables loading of models like Gemma2 2B/7B without OOM errors
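Illustrative sketch only (not the PR's exact code): the bullets above amount to something like the following in a layer's load_own_variables, where _layout and _direct_assign are the attribute/method names discussed in this thread and may differ in the final implementation.

def load_own_variables(self, store):
    """Load checkpoint weights, assigning shard-by-shard when possible."""
    for i, variable in enumerate(self.weights):
        weight_data = store[str(i)]  # host-side numpy array from the checkpoint
        if getattr(variable, "_layout", None) is not None:
            # Sharded variable: hand the host array straight to the backend so
            # each device materializes only its own shard, instead of staging
            # the full tensor on device:0 (the source of RESOURCE_EXHAUSTED).
            variable._direct_assign(weight_data)
        else:
            # Non-sharded variable: the regular assign path is unchanged.
            variable.assign(weight_data)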
Summary of Changes
Hello @amitsrivastava78, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces crucial improvements to Keras's JAX backend, specifically targeting out-of-memory (OOM) issues that arise when loading large models in a ModelParallel distributed setup. By modifying how sharded variables are initialized and assigned, the system now avoids loading entire weight tensors onto a single device, instead distributing them directly across available devices.
Code Review
This pull request addresses an out-of-memory (OOM) issue during weight loading for models using ModelParallel by introducing sharded variable assignment. The changes primarily involve modifying load_own_variables methods to use a new _direct_assign approach for sharded variables, preventing the full weight tensor from being loaded onto a single device. A new helper function, load_variable_with_sharded_support, centralizes this logic, and it has been integrated into various layers and the base optimizer.

My review has identified a few areas for improvement, including a potential memory leak, duplicated code, and some inconsistencies. Addressing these points will enhance the robustness and maintainability of the solution. Overall, the changes are well-structured and include thorough testing, which is excellent.
Codecov Report ❌

Additional details and impacted files:

@@ Coverage Diff @@
## master #21712 +/- ##
==========================================
- Coverage 82.60% 82.56% -0.04%
==========================================
Files 572 572
Lines 58326 58710 +384
Branches 9134 9195 +61
==========================================
+ Hits 48179 48474 +295
- Misses 7817 7887 +70
- Partials 2330 2349 +19
Force-pushed from 7816f0c to af6c766
Force-pushed from 3a86f1e to 303f241
Thanks for the PR!
It's unfortunate that this is combining 3 different things (which you describe in your doc):
- initialization of sharded variables
- reloading of sharded variables
- reference counting of variables
In particular, I don't think we should do 3, because both Python and JAX already track usage of arrays (see the small illustration after this comment). Therefore:
- I believe it's hiding some other bug.
- Adding our own tracking system on top is error-prone and will probably add memory leaks, because it's very easy to forget to clear references.
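To make the point about reference tracking concrete, here is a small illustration (not part of the PR) of how JAX releases a device buffer once its last Python reference is dropped; it assumes a JAX version that provides jax.live_arrays().

import gc

import jax
import jax.numpy as jnp

before = len(jax.live_arrays())   # device arrays currently alive on the backend
x = jnp.ones((1024, 1024))        # allocates a device buffer
during = len(jax.live_arrays())

del x                             # drop the last Python reference
gc.collect()                      # usually unnecessary under CPython refcounting
after = len(jax.live_arrays())

print(before, during, after)      # 'after' drops back toward 'before': the buffer
                                  # is freed without any Keras-side bookkeeping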
keras/src/backend/jax/core.py
Outdated
import ml_dtypes
import numpy as np
from jax import export as jax_export
from absl import logging
This PR is undoing a lot of changes that were made in this file. It wasn't rebased correctly.
ok, rebased again and ensured no fixes are missing; also restored jax_export and the other fixes that had been overwritten
keras/src/backend/jax/core.py
Outdated
IS_THREAD_SAFE = True


def _is_jax_tracer(x):
Can't you use (and potentially change) the one from jax_utils?
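For context, one backend-agnostic way to write such a shared helper (a sketch only; not necessarily what jax_utils already contains) is to walk the MRO instead of importing private JAX internals:

def _is_jax_tracer(x):
    # Treat any object whose class hierarchy includes a "Tracer" class defined
    # in a jax module as a tracer; avoids pinning to jax.core's exact layout.
    for cls in type(x).__mro__:
        if cls.__name__ == "Tracer" and cls.__module__.startswith("jax"):
            return True
    return False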
ok
keras/src/utils/variable_loading.py
Outdated
# Check if variable has a layout (is sharded)
if hasattr(variable, "_layout") and variable._layout is not None:
    # Use _direct_assign for sharded variables to avoid OOM
    logging.info(
This logging will be very noisy, let's remove it.
ok
keras/src/utils/variable_loading.py
Outdated
from absl import logging


def load_variable_with_sharded_support(variable, weight_data):
I don't think we need this; you can just do variable._direct_assign(weight_data) everywhere you used load_variable_with_sharded_support. That check for if hasattr(variable, "_layout") and variable._layout is not None: is already done within _direct_assign.
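For readers unfamiliar with _direct_assign, the sharded branch conceptually boils down to placing the host array according to the variable's layout rather than copying the whole tensor to one device. A rough sketch, assuming _layout resolves to a jax.sharding.Sharding and using a hypothetical _value attribute to stand in for the real assignment:

import jax
import numpy as np

def direct_assign_sketch(variable, weight_data):
    layout = getattr(variable, "_layout", None)
    host_value = np.asarray(weight_data)
    if layout is not None:
        # Each device receives only its shard of the host array.
        value = jax.device_put(host_value, layout)
    else:
        value = jax.device_put(host_value)
    variable._value = value  # hypothetical; the real method updates backend state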
ok
keras/src/backend/jax/core.py
Outdated
)

# Ensure value is on host (numpy array)
if not isinstance(value, np.ndarray):
But that's too late. The initializer has already been called on device:0. For this to work the way you intend it to, you need to run the initializer on CPU using a with device(...) scope.
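A sketch of the suggestion (not the PR's code), assuming an initializer callable that takes (shape, dtype) and a jax.sharding.Sharding describing the variable's layout:

import jax

def initialize_on_cpu_then_shard(initializer, shape, dtype, sharding):
    # Run the initializer on the host CPU so the full tensor is never
    # materialized on accelerator device:0.
    with jax.default_device(jax.devices("cpu")[0]):
        value = initializer(shape, dtype)
    # Distribute the result according to the variable's sharding/layout.
    return jax.device_put(value, sharding)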
ok
    and loaded_var._shard_references
)

logging.debug(
Remove all logging.
ok
keras/src/layers/core/dense.py
Outdated
if self.quantization_mode == "gptq": | ||
# GPTQ: bias first, then quantized_kernel | ||
target_variables = [self.bias] if self.use_bias else [] | ||
target_variables.append(self.quantized_kernel) | ||
else: | ||
target_variables = [self._kernel] | ||
if self.use_bias and self.quantization_mode != "gptq": | ||
target_variables.append(self.bias) | ||
if self.quantization_mode is not None: | ||
if self.quantization_mode in ("int8", "int4"): | ||
target_variables.append(self.kernel_scale) | ||
elif self.quantization_mode == "float8": | ||
target_variables.append(self.inputs_scale) | ||
target_variables.append(self.inputs_amax_history) | ||
target_variables.append(self.kernel_scale) | ||
target_variables.append(self.kernel_amax_history) | ||
target_variables.append(self.outputs_grad_scale) | ||
target_variables.append(self.outputs_grad_amax_history) | ||
elif self.quantization_mode == "gptq": | ||
target_variables.append(self.kernel_scale) | ||
target_variables.append(self.kernel_zero) | ||
target_variables.append(self.g_idx) | ||
else: | ||
raise self._quantization_mode_error(self.quantization_mode) | ||
for i, variable in enumerate(target_variables): | ||
weight_data = store[str(i)] | ||
load_variable_with_sharded_support(variable, weight_data) |
I believe this code is being changed. But why couldn't you do a 1-line change:
< variable.assign(store[str(i)])
---
> load_variable_with_sharded_support(variable, store[str(i)])
ok
    raise self._quantization_mode_error(self.quantization_mode)
for i, variable in enumerate(target_variables):
    weight_data = store[str(i)]
    load_variable_with_sharded_support(variable, weight_data)
Same comment about this code being changed and a 1-line change.
ok
keras/src/layers/core/embedding.py
Outdated
    raise self._quantization_mode_error(self.quantization_mode)
for i, variable in enumerate(target_variables):
    weight_data = store[str(i)]
    load_variable_with_sharded_support(variable, weight_data)
Same comment about this code being changed and a 1-line change.
f"{has_shard_refs_loaded}" | ||
) | ||
|
||
self.assertTrue( |
Using assertTrue forces you to put a specific error message to provide context.
I would replace all of this:

has_shard_refs_orig = (
    hasattr(orig_var, "_shard_references")
    and orig_var._shard_references
)
logging.debug(
    f" Original has shard references: "
    f"{has_shard_refs_orig}"
)
self.assertTrue(
    has_shard_refs_orig,
    f"Original {var_name} should have shard references",
)
self.assertGreater(
    len(orig_var._shard_references),
    0,
    f"Original {var_name} has empty shard references",
)

With the line:

self.assertLen(orig_var._shard_references, 1)

Not only is it a lot less code, it actually gives you more information in case of error:
- if orig_var doesn't have _shard_references as an attribute, it will raise an error telling you exactly that
- if orig_var._shard_references is None, len will fail telling you it's None, so you'll know (has_shard_refs_orig won't tell you directly whether the attribute is missing or None; you'd have to look at the debug logging)
- assertLen will tell you what you're taking the len of, whereas assertGreater will tell you 0 < 0, which is not super helpful.
ok
Force-pushed from 8beda65 to 0ecb55d