Core Data Structures & Communication Primitives for Tensor Parallel for Keras #21697
Conversation
Summary of Changes

Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request lays the fundamental groundwork for introducing backend-agnostic auto-sharding and Tensor Parallelism into Keras 3.0. It establishes a modular and extensible architecture by defining core data structures, abstracting distributed backend functionalities, and providing high-level communication primitives. This initial set of changes is crucial for enabling future capabilities that will allow users to train very large models across multiple devices with significantly simplified code.
Code Review
This pull request lays a solid foundation for tensor parallelism in Keras by introducing backend-agnostic abstractions for distributed operations and core data structures for sharding. The overall design is well-structured, separating concerns between backend-specific implementations, communication primitives, and configuration. However, there are several areas that need attention, particularly regarding the correctness of some backend implementations (especially JAX), placeholder logic, API clarity, and code consistency. Addressing these points will strengthen the foundation and prevent issues in future development.
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #21697 +/- ##
==========================================
- Coverage 82.59% 82.52% -0.07%
==========================================
Files 572 577 +5
Lines 58327 58829 +502
Branches 9131 9187 +56
==========================================
+ Hits 48177 48551 +374
- Misses 7818 7927 +109
- Partials 2332 2351 +19
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
I've added a few initial comments and questions during my first look.
To make the review more manageable, I propose we split this change up. At almost 1,800 lines, the current change is quite difficult to review properly. What do you think about limiting this PR to just the JAX backend, and introducing the others in subsequent, smaller PRs?
Thank you for the PR!
Some high level comments:
- Out of context, it's really hard for me to understand why these abstractions are needed for Tensor Parallel.
- Why do we need all these primitives?
- Why do we need 3 layers of abstraction for the same concepts: the `communications` layer, the `state_actions` layer, and the `keras.distributed.get_communication_ops` layer? Can we just have one?
- These abstractions look Torch-like rather than JAX-like. On JAX you never have to manually split and do an all-gather; you simply shard. You never have to explicitly do a "collective sum". You just do a sum, and if the tensors are sharded, it will magically do all the needed collectives for you (see the JAX sketch after this list). So it's unclear to me why any of these are needed for JAX.
- I wouldn't export these symbols that you added to `keras.distributed`; I don't think they are needed. What we'll expose is the "Tensor Parallel" API.
- For better or worse, we don't do type annotations in Keras. And unfortunately, mixing code that has type annotations with code that doesn't does not work well. It's better to not have any type annotations at all.
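For context, a minimal JAX sketch of the "just shard" approach described above. The mesh axis name, shapes, and values are illustrative assumptions, not code from this PR:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over all local devices; the axis name "model" is illustrative.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

# Shard the weight matrix column-wise across the "model" axis.
w = jax.device_put(jnp.ones((128, 256)), NamedSharding(mesh, P(None, "model")))
x = jnp.ones((8, 128))


@jax.jit
def loss(x, w):
    # Plain ops only: XLA inserts any required collectives (all-gather,
    # reduce, etc.) automatically because `w` is sharded.
    return jnp.sum(x @ w)


print(loss(x, w))
```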
```python
def compute_gradients(
    _loss: jnp.ndarray, trainable_vars: List[jnp.ndarray]
```
This signature doesn't work for JAX. You cannot take the gradient of a tensor. You can only transform a function so that you can take its gradient.
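To illustrate the point, a minimal sketch (the toy loss function and shapes are assumptions): in JAX, `jax.grad` transforms a function, and gradients are taken with respect to that function's inputs, not a loss tensor that has already been computed.

```python
import jax
import jax.numpy as jnp


def loss_fn(trainable_vars, x, y):
    # A toy linear model; gradients are taken w.r.t. the first argument.
    w, b = trainable_vars
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)


# jax.grad transforms the *function*; there is no gradient of a tensor value.
grad_fn = jax.grad(loss_fn)
grads = grad_fn(
    (jnp.ones((3, 2)), jnp.zeros((2,))),  # (w, b)
    jnp.ones((4, 3)),                     # x
    jnp.ones((4, 2)),                     # y
)
```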
```python
    Note: This is a placeholder implementation that returns zeros. A real
    implementation would use `jax.grad`.
```
So why are we doing this if it's not a real implementation?
```python
def apply_gradients(
    gradients: List[jnp.ndarray],
    trainable_vars: List[jnp.ndarray],
    learning_rate: float = 0.001,
) -> List[jnp.ndarray]:
    """Applies gradients and returns the updated variables."""
    updated_vars = []
    for grad, var in zip(gradients, trainable_vars):
        if grad is not None:
            new_var = var - (learning_rate * grad)
            updated_vars.append(new_var)
        else:
            updated_vars.append(var)
    return updated_vars
```
This is an inline implementation of SGD. Why is this needed?
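For comparison, the same update expressed with a stock Keras 3 optimizer, which is presumably what would be expected here (the variable shapes and values are illustrative):

```python
import keras

optimizer = keras.optimizers.SGD(learning_rate=0.001)
trainable_vars = [keras.Variable(keras.ops.ones((4, 4)))]
gradients = [keras.ops.ones((4, 4))]

# Keras 3 optimizers update the passed variables in place.
optimizer.apply(gradients, trainable_vars)
```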
```python
def create_optimizer(optimizer_class: str, **kwargs) -> Dict[str, Any]:
```
Why do we need this? Does this mean that TensorParallel won't work with Keras optimizers?
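For reference, Keras already resolves optimizers from identifiers, so a custom factory may be redundant; a small sketch of the existing API:

```python
import keras

# Resolve an optimizer from a string identifier...
opt = keras.optimizers.get("adam")

# ...or construct one directly and pass the instance around.
opt = keras.optimizers.Adam(learning_rate=0.001)
```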
This pull request marks the first foundational step in introducing a powerful, backend-agnostic auto-sharding capability for Tensor Parallelism in Keras 3.0. The ultimate goal is to democratize distributed training by allowing users to train models larger than a single device's memory with minimal code changes, ideally just two lines to define and use the distribution strategy.
This initial PR does not introduce the full end-to-end logic but lays the critical groundwork by establishing the core data structures and communication abstractions that the rest of the system will be built upon.
Key Components Introduced
This PR introduces three fundamental building blocks for the auto-sharding framework:
State actions: this module defines how tensor transformations for distribution should be represented.
StateActionKeras: An abstract base class that represents an action to be performed on a tensor, such as splitting it for a specific worker.
SplitKeras: A concrete implementation that defines the logic for splitting a tensor along a given dimension.
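A minimal sketch of the pattern described above. The class names come from this PR's description, but the method names, constructor arguments, and the use of `keras.ops.split` are assumptions, not the actual implementation:

```python
import keras


class StateActionKeras:
    """Base class for an action applied to a tensor for a given worker."""

    def __call__(self, tensor, rank):
        raise NotImplementedError


class SplitKeras(StateActionKeras):
    """Splits a tensor along `dim` and returns the shard for `rank`."""

    def __init__(self, world_size, dim):
        self.world_size = world_size
        self.dim = dim

    def __call__(self, tensor, rank):
        # Even split along the chosen dimension; uneven shapes would need
        # extra handling in a real implementation.
        shards = keras.ops.split(tensor, self.world_size, axis=self.dim)
        return shards[rank]
```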
Communication primitives: this module provides high-level wrappers for the collective communication operations essential for tensor parallelism.
AllReduceKeras, AllGatherKeras, BroadcastKeras, ScatterKeras: These classes encapsulate the logic for cross-device communication.
Purpose: They serve as a backend-agnostic interface for collective ops.
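On the JAX backend, for example, a wrapper like `AllReduceKeras` would ultimately bottom out in a collective such as `jax.lax.psum`. A self-contained illustration of that underlying collective (not this PR's code):

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.pmap, axis_name="model")
def all_reduce_sum(x):
    # Collective sum across all devices participating in the pmap.
    return jax.lax.psum(x, axis_name="model")


n = jax.local_device_count()
# Each device holds one row; after the all-reduce every device sees the total.
out = all_reduce_sum(jnp.arange(float(n)).reshape(n, 1))
```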
Sharding configuration: this module introduces the data class that will hold the complete, model-wide sharding configuration.
ConfigKeras: A dataclass designed to store the rules and actions for sharding both the model's weights (state_rules) and its outputs (output_rules).
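A minimal sketch of such a configuration object. The field names come from the description above; the rule-pattern format and example values are assumptions:

```python
from dataclasses import dataclass, field


@dataclass
class ConfigKeras:
    """Model-wide sharding configuration.

    `state_rules` maps weight-name patterns to state actions (e.g. a
    SplitKeras instance); `output_rules` maps output patterns to the
    collective used to reassemble layer outputs.
    """

    state_rules: dict = field(default_factory=dict)
    output_rules: dict = field(default_factory=dict)


# Hypothetical usage, reusing the SplitKeras sketch from earlier in this thread:
# config = ConfigKeras(
#     state_rules={r".*dense.*kernel": SplitKeras(world_size=2, dim=1)},
#     output_rules={r".*dense.*": "allgather"},
# )
```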
Design Document: Autosharding for Keras