
Conversation

buildwithsuhana
Contributor

This pull request marks the first foundational step in introducing a powerful, backend-agnostic auto-sharding capability for Tensor Parallelism in Keras 3.0. The ultimate goal is to democratize distributed training by allowing users to train models larger than a single device's memory with minimal code changes, ideally just two lines to define and use the distribution strategy.

This initial PR does not introduce the full end-to-end logic but lays the critical groundwork by establishing the core data structures and communication abstractions that the rest of the system will be built upon.

Key Components Introduced
This PR introduces three fundamental building blocks for the auto-sharding framework (a minimal sketch of how they fit together follows this list):

  1. state_actions_keras.py: Abstracting Sharding Operations
    This module defines how tensor transformations for distribution should be represented.

StateActionKeras: An abstract base class that represents an action to be performed on a tensor, such as splitting it for a specific worker.

SplitKeras: A concrete implementation that defines the logic for splitting a tensor along a given dimension.

  2. communications_keras.py: High-Level Communication Primitives
    This module provides high-level wrappers for the collective communication operations essential for tensor parallelism.

AllReduceKeras, AllGatherKeras, BroadcastKeras, ScatterKeras: These classes encapsulate the logic for cross-device communication.

Purpose: They serve as a backend-agnostic interface for collective ops.

  3. config_keras.py: The Sharding Plan Data Structure
    This module introduces the data class that will hold the complete, model-wide sharding configuration.

ConfigKeras: A dataclass designed to store the rules and actions for sharding both the model's weights (state_rules) and its outputs (output_rules).
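
To make the relationship between these pieces concrete, here is a minimal sketch of how a sharding plan might be assembled. The constructor signatures and rule keys shown here are illustrative assumptions, not the exact API introduced in this PR.

from dataclasses import dataclass, field

@dataclass
class SplitKeras:
    # Hypothetical signature: shard a tensor along `dim` into `world_size` pieces.
    world_size: int
    dim: int

@dataclass
class ConfigKeras:
    # Model-wide sharding plan: regex pattern -> action for weights and outputs.
    state_rules: dict = field(default_factory=dict)
    output_rules: dict = field(default_factory=dict)

# Example plan for a two-layer MLP: column-parallel first kernel,
# row-parallel second kernel, all-reduce the second layer's output.
config = ConfigKeras(
    state_rules={
        r".*dense_1.*kernel": SplitKeras(world_size=4, dim=1),
        r".*dense_2.*kernel": SplitKeras(world_size=4, dim=0),
    },
    output_rules={
        r".*dense_2.*": "allreduce",
    },
)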

Design Document: Autosharding for Keras

gemini-code-assist bot (Contributor)

Summary of Changes

Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request lays the fundamental groundwork for introducing backend-agnostic auto-sharding and Tensor Parallelism into Keras 3.0. It establishes a modular and extensible architecture by defining core data structures, abstracting distributed backend functionalities, and providing high-level communication primitives. This initial set of changes is crucial for enabling future capabilities that will allow users to train very large models across multiple devices with significantly simplified code.

Highlights

  • Core Distributed Backend Abstraction: Introduced BaseDistributedBackend as an abstract interface for distributed operations and a get_distributed_backend factory function to provide a unified, backend-agnostic way to interact with JAX, TensorFlow, PyTorch, and NumPy distributed environments.
  • High-Level Communication Primitives: Defined AllReduceKeras, AllGatherKeras, BroadcastKeras, and ScatterKeras classes, which serve as high-level wrappers for essential collective communication operations required for tensor parallelism.
  • Tensor Sharding Actions: Implemented StateActionKeras as an abstract base class for defining how tensors are transformed for distribution. Concrete implementations like SplitKeras handle tensor sharding, while GatherKeras and SumKeras define how to reconstruct original tensors from their distributed parts.
  • Sharding Plan Configuration: Introduced the ConfigKeras dataclass to store and manage model-wide sharding rules and output configurations, including a mechanism to dynamically create collective operations based on these rules.
  • Tensor Parallel Communicator: Added TensorParallelCommunicator to orchestrate complex communication patterns for tensor parallelism, including specific methods for handling forward and backward passes in column-parallel and row-parallel operations, along with gradient slicing logic (a rough sketch of the column/row-parallel pattern follows).
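
As a rough illustration of the column-parallel / row-parallel pattern the communicator is meant to orchestrate, the sketch below uses plain functions with the collective passed in as a callable; the names and signatures are assumptions for illustration, not the API added in this PR.

# Column-parallel layer: each device holds a slice of the kernel's output
# dimension, so the forward pass gathers the per-device outputs.
def column_parallel_forward(x, local_kernel, all_gather):
    local_out = x @ local_kernel            # shape: (batch, out_dim / world_size)
    return all_gather(local_out, axis=-1)   # shape: (batch, out_dim)

# Row-parallel layer: each device holds a slice of the kernel's input
# dimension, so each partial product must be summed across devices.
def row_parallel_forward(x_shard, local_kernel, all_reduce):
    partial_out = x_shard @ local_kernel    # partial result on each device
    return all_reduce(partial_out, op="sum")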

gemini-code-assist bot left a comment

Code Review

This pull request lays a solid foundation for tensor parallelism in Keras by introducing backend-agnostic abstractions for distributed operations and core data structures for sharding. The overall design is well-structured, separating concerns between backend-specific implementations, communication primitives, and configuration. However, there are several areas that need attention, particularly regarding the correctness of some backend implementations (especially JAX), placeholder logic, API clarity, and code consistency. Addressing these points will strengthen the foundation and prevent issues in future development.


codecov-commenter commented Sep 26, 2025

Codecov Report

❌ Patch coverage is 70.06803% with 88 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.52%. Comparing base (5ae5503) to head (b9f36e9).
⚠️ Report is 9 commits behind head on master.

Files with missing lines Patch % Lines
...src/distribution/tensor_parallel/communications.py 56.52% 40 Missing ⚠️
keras/src/backend/jax/distributed_backend.py 58.97% 22 Missing and 10 partials ⚠️
keras/api/_tf_keras/keras/distribution/__init__.py 0.00% 5 Missing ⚠️
...distribution/tensor_parallel/state_action_keras.py 89.36% 2 Missing and 3 partials ⚠️
keras/src/distribution/distributed_backend.py 78.94% 4 Missing ⚠️
keras/src/distribution/tensor_parallel/config.py 94.73% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21697      +/-   ##
==========================================
- Coverage   82.59%   82.52%   -0.07%     
==========================================
  Files         572      577       +5     
  Lines       58327    58829     +502     
  Branches     9131     9187      +56     
==========================================
+ Hits        48177    48551     +374     
- Misses       7818     7927     +109     
- Partials     2332     2351      +19     
Flag Coverage Δ
keras 82.33% <70.06%> (-0.07%) ⬇️
keras-jax 63.22% <68.70%> (-0.09%) ⬇️
keras-numpy 57.44% <34.01%> (-0.21%) ⬇️
keras-openvino 34.33% <34.01%> (+0.01%) ⬆️
keras-tensorflow 63.79% <34.01%> (-0.25%) ⬇️
keras-torch 63.34% <34.01%> (-0.29%) ⬇️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

JyotinderSingh (Collaborator) left a comment

I've added a few initial comments and questions during my first look.

To make the review more manageable, I propose we split this change up. At almost 1,800 lines, the current change is quite difficult to review properly. What do you think about limiting this PR to just the JAX backend, and introducing the others in subsequent, smaller PRs?

hertschuh (Collaborator) left a comment

Thank you for the PR!

Some high level comments:

  • Out of context, it's really hard for me to understand why these abstractions are needed for Tensor Parallel.
    • Why do we need all these primitives?
    • Why do we need 3 layers of abstraction for the same concepts: the communications layer, the state_actions layer and the keras.distributed.get_communication_ops layer? Can we just have one?
  • These abstractions look Torch-like and not JAX-like. In JAX you never have to manually split and do an all-gather; you simply shard. You never have to explicitly do a "collective sum": you just do a sum, and if the tensors are sharded, it will magically do all the needed collectives for you (see the sketch after this list). So it's unclear to me why any of these are needed for JAX.
  • I wouldn't export these symbols that you added to keras.distributed, I don't think they are needed. What we'll expose is the "Tensor Parallel" API.
  • For better or worse, we don't do type annotations in Keras. Unfortunately, mixing annotated code with unannotated code doesn't work well, so it's better not to have any type annotations at all.
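
To illustrate the JAX point above: sharding is declared on the arrays, and the compiler inserts whatever collectives the computation needs. This example uses only the public jax.sharding API and is independent of the code in this PR.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

x = jnp.ones((8, 1024))
w = jnp.ones((1024, 4096))

# Shard the contraction dimension across the "model" axis. No manual split,
# all-gather, or "collective sum" appears anywhere in user code.
x = jax.device_put(x, NamedSharding(mesh, P(None, "model")))
w = jax.device_put(w, NamedSharding(mesh, P("model", None)))

@jax.jit
def forward(x, w):
    return x @ w  # the compiler inserts the cross-device reduction itself

y = forward(x, w)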

Comment on lines +14 to +15
def compute_gradients(
    _loss: jnp.ndarray, trainable_vars: List[jnp.ndarray]
Collaborator

This signature doesn't work for JAX. You cannot take the gradient of a tensor. You can only transform a function so that you can take its gradient.
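
For reference, the JAX idiom this comment refers to: gradients are obtained by transforming a function with jax.grad, not by passing in an already-computed loss tensor. This standalone example is not taken from the PR.

import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}
x = jnp.ones((4, 3))
y = jnp.zeros((4, 1))

# jax.grad transforms `loss_fn` into a new function returning d(loss)/d(params).
# There is no API for taking the gradient of a loss value that was already computed.
grads = jax.grad(loss_fn)(params, x, y)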

Comment on lines +19 to +20
Note: This is a placeholder implementation that returns zeros. A real
implementation would use `jax.grad`.
Collaborator

So why are we doing this if it's not a real implementation?

Comment on lines +34 to +47
def apply_gradients(
    gradients: List[jnp.ndarray],
    trainable_vars: List[jnp.ndarray],
    learning_rate: float = 0.001,
) -> List[jnp.ndarray]:
    """Applies gradients and returns the updated variables."""
    updated_vars = []
    for grad, var in zip(gradients, trainable_vars):
        if grad is not None:
            new_var = var - (learning_rate * grad)
            updated_vars.append(new_var)
        else:
            updated_vars.append(var)
    return updated_vars
Collaborator

This is an inline implementation of SGD. Why is this needed?
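
For comparison, the usual Keras path would hand the update to an optimizer object rather than re-implement SGD inline; a rough sketch assuming Keras variables (not code from this PR):

import numpy as np
import keras

optimizer = keras.optimizers.SGD(learning_rate=0.001)
w = keras.Variable(np.ones((3, 1)), name="w")
grad = keras.ops.full((3, 1), 0.5)

# The optimizer owns the update rule (momentum, weight decay, schedules, ...),
# so no inline `var - lr * grad` is needed.
optimizer.apply_gradients([(grad, w)])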

return updated_vars


def create_optimizer(optimizer_class: str, **kwargs) -> Dict[str, Any]:
Collaborator

Why do we need this? Does this mean that TensorParallel won't work with Keras optimizers?
