Changes from all 43 commits
a27367a
Added tensor parallel for keras (Part 1/3)
buildwithsuhana Sep 26, 2025
488cd8f
Removed unnecessary lines
buildwithsuhana Sep 26, 2025
71ddd1a
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
bc4e4e2
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
d4200b5
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
21f89a2
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
299bd45
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
da625e1
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
c233b8c
Fixing the failing test
buildwithsuhana Sep 26, 2025
7b8d733
Fixing the failing test
buildwithsuhana Sep 26, 2025
f825cd3
Fixing test
buildwithsuhana Sep 26, 2025
3725180
Adding tests for distributed_backends
buildwithsuhana Sep 29, 2025
a6c8a96
Modifications for failing tests
buildwithsuhana Sep 29, 2025
3fabfde
Modified for failing test
buildwithsuhana Sep 29, 2025
b133752
Modified for failing test
buildwithsuhana Sep 29, 2025
83c2e3f
Modified for failing test
buildwithsuhana Sep 29, 2025
3f3be6b
added debuggers
buildwithsuhana Sep 29, 2025
be325ab
removed debuggers
buildwithsuhana Sep 29, 2025
e1282ac
Merge branch 'keras-team:master' into Tensor_parallel_keras
buildwithsuhana Sep 29, 2025
fc11aaa
Removed the tensorflow, numpy and torch backends
buildwithsuhana Sep 30, 2025
ef6e2a0
Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
buildwithsuhana Sep 30, 2025
bea6ffa
Refactoring the code
buildwithsuhana Sep 30, 2025
4e00245
Refactoring the code
buildwithsuhana Sep 30, 2025
2f973b0
refactoring
buildwithsuhana Sep 30, 2025
bdb2b84
Adding necessary docstrings
buildwithsuhana Sep 30, 2025
d77fa71
Merge branch 'keras-team:master' into Tensor_parallel_keras
buildwithsuhana Oct 1, 2025
b9990b0
Removing redundancies
buildwithsuhana Oct 3, 2025
0aeee6f
Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
buildwithsuhana Oct 3, 2025
f784956
Modifying tests
buildwithsuhana Oct 3, 2025
8895a78
Reformatting
buildwithsuhana Oct 3, 2025
fe97f3b
Reformatting the code
buildwithsuhana Oct 3, 2025
77f01aa
Fixing failing tests
buildwithsuhana Oct 3, 2025
7080328
fixes
buildwithsuhana Oct 3, 2025
af711fd
Fixing tests
buildwithsuhana Oct 3, 2025
97dde17
formatting
buildwithsuhana Oct 3, 2025
f322a97
fixing test
buildwithsuhana Oct 3, 2025
5269ac9
fixing test
buildwithsuhana Oct 3, 2025
b9f36e9
Removing redundant lines
buildwithsuhana Oct 6, 2025
555e5c9
Refactoring to remove communications.py and state_action_keras.py
buildwithsuhana Oct 12, 2025
b80d264
formatting the files
buildwithsuhana Oct 12, 2025
93b1738
fixing skip issues
buildwithsuhana Oct 12, 2025
b7b2b9b
fixing test
buildwithsuhana Oct 12, 2025
f6c1142
fixing test
buildwithsuhana Oct 12, 2025
5 changes: 5 additions & 0 deletions keras/src/backend/__init__.py
@@ -37,24 +37,29 @@
if backend() == "tensorflow":
    from keras.src.backend.tensorflow import *  # noqa: F403
    from keras.src.backend.tensorflow.core import Variable as BackendVariable

    distributed_backend = None
elif backend() == "jax":
    from keras.src.backend.jax import *  # noqa: F403
    from keras.src.backend.jax.core import Variable as BackendVariable
elif backend() == "torch":
    from keras.src.backend.torch import *  # noqa: F403
    from keras.src.backend.torch.core import Variable as BackendVariable

    distributed_backend = None
    distribution_lib = None
elif backend() == "numpy":
    from keras.src.backend.numpy import *  # noqa: F403
    from keras.src.backend.numpy.core import Variable as BackendVariable

    distribution_lib = None
    distributed_backend = None
elif backend() == "openvino":
    from keras.src.backend.openvino import *  # noqa: F403
    from keras.src.backend.openvino.core import Variable as BackendVariable

    distribution_lib = None
    distributed_backend = None
else:
    raise ValueError(f"Unable to import backend : {backend()}")
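With this change, only the JAX backend exposes a `distributed_backend` module; the other backends bind the name to `None`. A minimal caller sketch (hypothetical, not part of this PR) showing how downstream code might guard on that:

```python
# Hypothetical caller sketch, not part of this PR: treat `distributed_backend`
# as optional, since only the JAX backend provides an implementation here.
from keras.src import backend

if getattr(backend, "distributed_backend", None) is not None:
    info = backend.distributed_backend.get_device_info()
    print(f"{info['device_count']} device(s) visible to the distributed backend.")
else:
    print(f"No distributed backend for '{backend.backend()}'.")
```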
1 change: 1 addition & 0 deletions keras/src/backend/jax/__init__.py
@@ -1,5 +1,6 @@
from keras.src.backend.config import is_nnx_enabled
from keras.src.backend.jax import core
from keras.src.backend.jax import distributed_backend
from keras.src.backend.jax import distribution_lib
from keras.src.backend.jax import image
from keras.src.backend.jax import linalg
146 changes: 146 additions & 0 deletions keras/src/backend/jax/distributed_backend.py
@@ -0,0 +1,146 @@
from typing import Any
from typing import Callable
from typing import Dict
from typing import Literal

import jax
import jax.lax as lax
import jax.numpy as jnp


def get_device_info() -> Dict[str, Any]:
    """Retrieves information about the available JAX devices."""
    available_devices = jax.devices()
    return {
        "backend": "jax",
        "devices": [str(d) for d in available_devices],
        "device_count": len(available_devices),
    }


def is_multi_device_capable() -> bool:
    """Checks if more than one JAX device is available."""
    return jax.local_device_count() > 1


def get_communication_ops() -> Dict[str, Callable]:
    """Provides a dictionary of JAX collective communication operations.

    Note: These operations are thin wrappers around `jax.lax` primitives
    and are intended to be used exclusively within a `jax.pmap` context.
    Calling them outside of `pmap` will result in an error.

    Returns:
        Dict[str, Callable]: A dictionary mapping operation names to their
            JAX implementations.
    """

    def all_reduce(
        x: jnp.ndarray,
        op: Literal["sum", "mean"] = "sum",
        axis_name: str = "data",
    ) -> jnp.ndarray:
        """Reduces a tensor across all devices in a `pmap`.

        Args:
            x (jnp.ndarray): The tensor to reduce.
            op (Literal["sum", "mean"], optional): The reduction operation.
                Defaults to "sum".
            axis_name (str, optional): The name of the `pmap` axis.
                Defaults to "data".

        Returns:
            jnp.ndarray: The reduced tensor.
        """
        reduce_ops = {
            "sum": lax.psum,
            "mean": lax.pmean,
        }
        reduce_fn = reduce_ops.get(op)

        if reduce_fn is None:
            raise ValueError(f"Unsupported all_reduce op: {op}")
        return reduce_fn(x, axis_name=axis_name)

    def all_gather(
        x: jnp.ndarray, axis: int = 0, axis_name: str = "data"
    ) -> jnp.ndarray:
        """Gathers tensors from all devices and stacks them along a new axis.

        With the default `tiled=False`, `lax.all_gather` inserts a new
        device axis of size `num_devices` at position `axis` rather than
        concatenating along an existing axis.

        Args:
            x (jnp.ndarray): The local tensor to gather.
            axis (int, optional): Position of the inserted device axis.
                Defaults to 0.
            axis_name (str, optional): The name of the `pmap` axis.
                Defaults to "data".

        Returns:
            jnp.ndarray: The gathered tensor with a new device axis.
        """
        return lax.all_gather(x, axis_name=axis_name, axis=axis)

    def broadcast(
        x: jnp.ndarray, root: int = 0, axis_name: str = "data"
    ) -> jnp.ndarray:
        """Broadcasts a tensor from a root device to all other devices.

        This is implemented by gathering the tensor from all devices and
        then having each device select the tensor from the `root` device.
        It assumes the value of `x` on the `root` device is the one to be
        broadcast.

        Args:
            x (jnp.ndarray): The tensor to broadcast.
            root (int, optional): The rank of the source device.
                Defaults to 0.
            axis_name (str, optional): The name of the `pmap` axis.
                Defaults to "data".

        Returns:
            jnp.ndarray: The tensor received from the root device.
        """
        return lax.all_gather(x, axis_name=axis_name, axis=0)[root]

    def scatter(
        x: jnp.ndarray,
        root: int = 0,
        axis: int = 0,
        axis_name: str = "data",
    ) -> jnp.ndarray:
        """Scatters a tensor from a root device to all devices.

        Args:
            x (jnp.ndarray): On the root device, the full tensor to scatter.
            root (int, optional): The rank of the source device.
                Defaults to 0.
            axis (int, optional): The axis along which to split the tensor.
                Defaults to 0.
            axis_name (str, optional): The name of the `pmap` axis.
                Defaults to "data".

        Returns:
            jnp.ndarray: The chunk of the tensor for the local device.
        """
        full_tensor = broadcast(x, root=root, axis_name=axis_name)

        device_id = lax.axis_index(axis_name=axis_name)
        num_devices = lax.psum(1, axis_name=axis_name)

        if full_tensor.shape[axis] % num_devices != 0:
            raise ValueError(
                f"Tensor with shape {x.shape} cannot be scattered along "
                f"axis {axis} across {num_devices} devices."
            )

        chunk_size = full_tensor.shape[axis] // num_devices
        start_index = device_id * chunk_size
        return lax.dynamic_slice_in_dim(
            operand=full_tensor,
            start_index=start_index,
            slice_size=chunk_size,
            axis=axis,
        )

    return {
        "all_reduce": all_reduce,
        "all_gather": all_gather,
        "broadcast": broadcast,
        "scatter": scatter,
    }
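For context, a minimal usage sketch (assuming 2 or more local JAX devices, e.g. via the `XLA_FLAGS` trick used in the test below): the returned callables only work inside a `jax.pmap` whose `axis_name` matches.

```python
# Usage sketch, assuming 2+ local JAX devices: the ops must run inside
# `jax.pmap` with a matching axis_name.
import jax
import jax.numpy as jnp

from keras.src.backend.jax.distributed_backend import get_communication_ops

comm_ops = get_communication_ops()
n = jax.local_device_count()

# Give every device an identical replica of the same tensor.
x = jnp.stack([jnp.array([1.0, 2.0])] * n)

# Sum-reduce across the "data" axis; every device ends up with n * [1, 2].
summed = jax.pmap(
    lambda v: comm_ops["all_reduce"](v, op="sum"), axis_name="data"
)(x)
print(summed[0])  # [n * 1.0, n * 2.0]
```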
86 changes: 86 additions & 0 deletions keras/src/backend/jax/distributed_backend_test.py
@@ -0,0 +1,86 @@
import os

os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import pytest

from keras.src import backend
from keras.src import ops
from keras.src import testing
from keras.src.backend import distributed_backend


@pytest.mark.skipif(
    backend.backend() != "jax" or jax.device_count() < 2,
    reason="Test requires JAX backend and at least 2 devices",
)
class TestJaxDistributedFunctions(testing.TestCase):
    """Unit tests for the JAX distributed backend standalone functions."""

    def test_get_device_info(self):
        """Test retrieving device information from the JAX backend."""
        info = distributed_backend.get_device_info()
        self.assertEqual(info["backend"], "jax")
        self.assertIsInstance(info["devices"], list)
        self.assertEqual(info["device_count"], 8)

    def test_is_multi_device_capable(self):
        """Test the boolean check for multi-device capability."""
        self.assertTrue(distributed_backend.is_multi_device_capable())

    def test_ops_raise_error_outside_pmap(self):
        """Verify that communication ops fail when not in pmap."""
        comm_ops = distributed_backend.get_communication_ops()
        x = ops.array([1.0, 2.0])
        with self.assertRaisesRegex(NameError, "unbound axis name: data"):
            comm_ops["all_reduce"](x)

    def test_communication_ops_in_pmap(self):
        """Test that communication ops work correctly inside `jax.pmap`."""
        comm_ops = distributed_backend.get_communication_ops()
        world_size = distributed_backend.get_device_info()["device_count"]

        x_reduce = ops.array([[1.0, 2.0], [3.0, 4.0]])
        sharded_reduce_input = jnp.stack([x_reduce] * world_size)
        pmapped_reduce = jax.pmap(
            lambda x: comm_ops["all_reduce"](x, op="sum"), axis_name="data"
        )
        reduced_result = pmapped_reduce(sharded_reduce_input)
        expected_reduce = ops.multiply(x_reduce, float(world_size))
        self.assertAllClose(reduced_result[0], expected_reduce)

        x_gather = jnp.arange(world_size * 2, dtype="float32").reshape(
            (world_size, 2)
        )
        pmapped_gather = jax.pmap(
            lambda x: comm_ops["all_gather"](x, axis=0), axis_name="data"
        )
        gathered_result = pmapped_gather(x_gather)
        self.assertAllClose(gathered_result[0], x_gather)

        x_broadcast = ops.array([5.0, 6.0])
        sharded_broadcast_input = jnp.stack(
            [x_broadcast] + [jnp.zeros_like(x_broadcast)] * (world_size - 1)
        )
        pmapped_broadcast = jax.pmap(
            lambda x: comm_ops["broadcast"](x, root=0), axis_name="data"
        )
        broadcasted_result = pmapped_broadcast(sharded_broadcast_input)
        self.assertAllClose(broadcasted_result[0], x_broadcast)

        x_scatter = jnp.arange(world_size * 2, dtype="float32").reshape(
            (world_size, 2)
        )
        sharded_scatter_input = jnp.stack(
            [x_scatter] + [jnp.zeros_like(x_scatter)] * (world_size - 1)
        )
        pmapped_scatter = jax.pmap(
            lambda x: comm_ops["scatter"](x, root=0, axis=0), axis_name="data"
        )
        scattered_result = pmapped_scatter(sharded_scatter_input)

        # Each device returns a chunk of shape (1, 2); drop the chunk axis.
        fixed_scattered_result = jnp.squeeze(scattered_result, axis=1)
        self.assertAllClose(fixed_scattered_result, x_scatter)
39 changes: 39 additions & 0 deletions keras/src/distribution/distributed_backend.py
@@ -0,0 +1,39 @@
from keras.src.backend import distributed_backend


def get_device_info() -> dict:
    """Gets information about available computational devices.

    Retrieves details about the devices (e.g., CPU, GPU) that are visible
    to the current backend.

    Returns:
        dict: A dictionary containing information about the available
            devices.
    """
    return distributed_backend.get_device_info()


def is_multi_device_capable() -> bool:
    """Checks if the backend supports multi-device operations.

    This function determines if the underlying backend is configured and
    capable of running computations across multiple devices.

    Returns:
        bool: `True` if the backend supports multi-device training,
            `False` otherwise.
    """
    return distributed_backend.is_multi_device_capable()


def get_communication_ops() -> dict:
    """Gets collective communication operations for the backend.

    This function returns a dictionary of collective ops (e.g., `all_reduce`,
    `all_gather`) that can be used for distributed communication.

    Returns:
        dict: A dictionary mapping the names of communication operations
            (str) to their callable implementations.
    """
    return distributed_backend.get_communication_ops()
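A minimal sketch of how these backend-agnostic entry points might be used. This assumes the JAX backend; on the other backends `keras.src.backend.distributed_backend` is currently `None`, so these calls would fail.

```python
# Sketch assuming the JAX backend; other backends currently bind
# `keras.src.backend.distributed_backend` to None.
from keras.src.distribution import distributed_backend

info = distributed_backend.get_device_info()
print(info["backend"], info["device_count"])

if distributed_backend.is_multi_device_capable():
    comm_ops = distributed_backend.get_communication_ops()
    print(sorted(comm_ops))  # ['all_gather', 'all_reduce', 'broadcast', 'scatter']
```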