diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index 15f1af2145d5..b22ea22547bb 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -37,6 +37,8 @@ if backend() == "tensorflow": from keras.src.backend.tensorflow import * # noqa: F403 from keras.src.backend.tensorflow.core import Variable as BackendVariable + + distributed_backend = None elif backend() == "jax": from keras.src.backend.jax import * # noqa: F403 from keras.src.backend.jax.core import Variable as BackendVariable @@ -44,17 +46,20 @@ from keras.src.backend.torch import * # noqa: F403 from keras.src.backend.torch.core import Variable as BackendVariable + distributed_backend = None distribution_lib = None elif backend() == "numpy": from keras.src.backend.numpy import * # noqa: F403 from keras.src.backend.numpy.core import Variable as BackendVariable distribution_lib = None + distributed_backend = None elif backend() == "openvino": from keras.src.backend.openvino import * # noqa: F403 from keras.src.backend.openvino.core import Variable as BackendVariable distribution_lib = None + distributed_backend = None else: raise ValueError(f"Unable to import backend : {backend()}") diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py index 89ac0fa71c8c..0a275fb70cf1 100644 --- a/keras/src/backend/jax/__init__.py +++ b/keras/src/backend/jax/__init__.py @@ -1,5 +1,6 @@ from keras.src.backend.config import is_nnx_enabled from keras.src.backend.jax import core +from keras.src.backend.jax import distributed_backend from keras.src.backend.jax import distribution_lib from keras.src.backend.jax import image from keras.src.backend.jax import linalg diff --git a/keras/src/backend/jax/distributed_backend.py b/keras/src/backend/jax/distributed_backend.py new file mode 100644 index 000000000000..8bb6e0de1f66 --- /dev/null +++ b/keras/src/backend/jax/distributed_backend.py @@ -0,0 +1,146 @@ +from typing import Any +from typing import Callable +from typing import Dict +from typing import Literal + +import jax +import jax.lax as lax +import jax.numpy as jnp + + +def get_device_info() -> Dict[str, Any]: + """Retrieves information about the available JAX devices.""" + available_devices = jax.devices() + return { + "backend": "jax", + "devices": [str(d) for d in available_devices], + "device_count": len(available_devices), + } + + +def is_multi_device_capable() -> bool: + """Checks if more than one JAX device is available.""" + return jax.local_device_count() > 1 + + +def get_communication_ops() -> Dict[str, Callable]: + """ + Provides a dictionary of JAX collective communication operations. + + Note: These operations are thin wrappers around `jax.lax` primitives + and are intended to be used exclusively within a `jax.pmap` context. + Calling them outside of `pmap` will result in an error. + + Returns: + Dict[str, Callable]: A dictionary mapping operation names to their + JAX implementations. + """ + + def all_reduce( + x: jnp.ndarray, + op: Literal["sum", "mean"] = "sum", + axis_name: str = "data", + ) -> jnp.ndarray: + """Reduces a tensor across all devices in a `pmap`. + + Args: + x (jnp.ndarray): The tensor to reduce. + op (Literal["sum", "mean"], optional): The reduction operation. + Defaults to "sum". + axis_name (str, optional): The name of the `pmap` axis. + Defaults to "data". + + Returns: + jnp.ndarray: The reduced tensor. 
+ """ + reduce_ops = { + "sum": lax.psum, + "mean": lax.pmean, + } + reduce_fn = reduce_ops.get(op) + + if reduce_fn is None: + raise ValueError(f"Unsupported all_reduce op: {op}") + return reduce_fn(x, axis_name=axis_name) + + def all_gather( + x: jnp.ndarray, axis: int = 0, axis_name: str = "data" + ) -> jnp.ndarray: + """Gathers tensors from all devices and concatenates them. + + Args: + x (jnp.ndarray): The local tensor to gather. + axis (int, optional): The axis to concatenate along. Defaults to 0. + axis_name (str, optional): The name of the `pmap` axis. + Defaults to "data". + + Returns: + jnp.ndarray: The concatenated tensor from all devices. + """ + return lax.all_gather(x, axis_name=axis_name, axis=axis) + + def broadcast( + x: jnp.ndarray, root: int = 0, axis_name: str = "data" + ) -> jnp.ndarray: + """Broadcasts a tensor from a root device to all other devices. + + This is implemented by gathering the tensor from all devices and then + having each device select the tensor from the `root` device. It assumes + the value of `x` on the `root` device is the one to be broadcast. + + Args: + x (jnp.ndarray): The tensor to broadcast. + root (int, optional): The rank of the source device. Defaults to 0. + axis_name (str, optional): The name of the `pmap` axis. + Defaults to "data". + + Returns: + jnp.ndarray: The tensor received from the root device. + """ + return lax.all_gather(x, axis_name=axis_name, axis=0)[root] + + def scatter( + x: jnp.ndarray, + root: int = 0, + axis: int = 0, + axis_name: str = "data", + ) -> jnp.ndarray: + """Scatters a tensor from a root device to all devices. + + Args: + x (jnp.ndarray): On the root device, the full tensor to scatter. + root (int, optional): The rank of the source device. Defaults to 0. + axis (int, optional): The axis along which to split the tensor. + Defaults to 0. + axis_name (str, optional): The name of the `pmap` axis. + Defaults to "data". + + Returns: + jnp.ndarray: The chunk of the tensor for the local device. + """ + full_tensor = broadcast(x, root=root, axis_name=axis_name) + + device_id = lax.axis_index(axis_name=axis_name) + num_devices = lax.psum(1, axis_name=axis_name) + + if full_tensor.shape[axis] % num_devices != 0: + raise ValueError( + f"Tensor with shape {x.shape} cannot be scattered along " + f"axis {axis} across {num_devices} devices." 
+ ) + + chunk_size = full_tensor.shape[axis] // num_devices + start_index = device_id * chunk_size + return lax.dynamic_slice_in_dim( + operand=full_tensor, + start_index=start_index, + slice_size=chunk_size, + axis=axis, + ) + + return { + "all_reduce": all_reduce, + "all_gather": all_gather, + "broadcast": broadcast, + "scatter": scatter, + } diff --git a/keras/src/backend/jax/distributed_backend_test.py b/keras/src/backend/jax/distributed_backend_test.py new file mode 100644 index 000000000000..bd2fb20a9766 --- /dev/null +++ b/keras/src/backend/jax/distributed_backend_test.py @@ -0,0 +1,86 @@ +import os + +os.environ["JAX_PLATFORM_NAME"] = "cpu" +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + +import jax +import jax.numpy as jnp +import pytest + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.backend import distributed_backend + + +@pytest.mark.skipif( + backend.backend() != "jax" or jax.device_count() < 2, + reason="Test requires JAX backend and at least 2 devices", +) +class TestJaxDistributedFunctions(testing.TestCase): + """Unit tests for the JAX distributed backend standalone functions.""" + + def test_get_device_info(self): + """Test retrieving device information from the JAX backend.""" + info = distributed_backend.get_device_info() + self.assertEqual(info["backend"], "jax") + self.assertIsInstance(info["devices"], list) + self.assertEqual(info["device_count"], 8) + + def test_is_multi_device_capable(self): + """Test the boolean check for multi-device capability.""" + self.assertTrue(distributed_backend.is_multi_device_capable()) + + def test_ops_raise_error_outside_pmap(self): + """Verify that communication ops fail when not in pmap.""" + comm_ops = distributed_backend.get_communication_ops() + x = ops.array([1.0, 2.0]) + with self.assertRaisesRegex(NameError, "unbound axis name: data"): + comm_ops["all_reduce"](x) + + def test_communication_ops_in_pmap(self): + """Test the communication ops work correctly inside jax.pmap context.""" + comm_ops = distributed_backend.get_communication_ops() + world_size = distributed_backend.get_device_info()["device_count"] + + x_reduce = ops.array([[1.0, 2.0], [3.0, 4.0]]) + sharded_reduce_input = jnp.stack([x_reduce] * world_size) + pmapped_reduce = jax.pmap( + lambda x: comm_ops["all_reduce"](x, op="sum"), axis_name="data" + ) + reduced_result = pmapped_reduce(sharded_reduce_input) + expected_reduce = ops.multiply(x_reduce, float(world_size)) + self.assertAllClose(reduced_result[0], expected_reduce) + + x_gather = jnp.arange(world_size * 2, dtype="float32").reshape( + (world_size, 2) + ) + pmapped_gather = jax.pmap( + lambda x: comm_ops["all_gather"](x, axis=0), axis_name="data" + ) + gathered_result = pmapped_gather(x_gather) + self.assertAllClose(gathered_result[0], x_gather) + + x_broadcast = ops.array([5.0, 6.0]) + sharded_broadcast_input = jnp.stack( + [x_broadcast] + [jnp.zeros_like(x_broadcast)] * (world_size - 1) + ) + pmapped_broadcast = jax.pmap( + lambda x: comm_ops["broadcast"](x, root=0), axis_name="data" + ) + broadcasted_result = pmapped_broadcast(sharded_broadcast_input) + self.assertAllClose(broadcasted_result[0], x_broadcast) + + x_scatter = jnp.arange(world_size * 2, dtype="float32").reshape( + (world_size, 2) + ) + sharded_scatter_input = jnp.stack( + [x_scatter] + [jnp.zeros_like(x_scatter)] * (world_size - 1) + ) + pmapped_scatter = jax.pmap( + lambda x: comm_ops["scatter"](x, root=0, axis=0), axis_name="data" + ) + scattered_result = 
pmapped_scatter(sharded_scatter_input) + + fixed_scattered_result = jnp.squeeze(scattered_result, axis=1) + self.assertAllClose(fixed_scattered_result, x_scatter) diff --git a/keras/src/distribution/distributed_backend.py b/keras/src/distribution/distributed_backend.py new file mode 100644 index 000000000000..80ad9ccdad98 --- /dev/null +++ b/keras/src/distribution/distributed_backend.py @@ -0,0 +1,39 @@ +from keras.src.backend import distributed_backend + + +def get_device_info() -> dict: + """Gets information about available computational devices. + + Retrieves details about the devices (e.g., CPU, GPU) that are visible + to the current backend. + + Returns: + dict: A dictionary containing information about the available devices. + """ + return distributed_backend.get_device_info() + + +def is_multi_device_capable() -> bool: + """Checks if the backend supports multi-device operations. + + This function determines if the underlying backend is configured and + capable of running computations across multiple devices. + + Returns: + bool: `True` if the backend supports multi-device training, + `False` otherwise. + """ + return distributed_backend.is_multi_device_capable() + + +def get_communication_ops() -> dict: + """Gets collective communication operations for the backend. + + This function returns a dictionary of collective ops (e.g., `all_reduce`, + `all_gather`) that can be used for distributed communication. + + Returns: + dict: A dictionary mapping the names of communication operations + (str) to their callable implementations. + """ + return distributed_backend.get_communication_ops() diff --git a/keras/src/distribution/tensor_parallel/tensor_layout.py b/keras/src/distribution/tensor_parallel/tensor_layout.py new file mode 100644 index 000000000000..ff9bd854743b --- /dev/null +++ b/keras/src/distribution/tensor_parallel/tensor_layout.py @@ -0,0 +1,171 @@ +import keras + + +class LayoutAction: + """Abstract base class for actions that transform tensors for distribution. + + A LayoutAction defines a rule for how a single tensor should be physically + represented across multiple devices. It includes forward operation + (`__call__`) to shard the tensor and a reverse operation (`undo`) + to reconstruct it.""" + + def __call__(self, tensor, rank): + """Applies the distribution action to a tensor for a specific worker. + + Args: + tensor: The input tensor to be distributed. + rank: The integer rank of the current worker/device. + + Raises: + NotImplementedError: This is an abstract method and must be + implemented by subclasses. + + Returns: + A shard or transformation of the input tensor specific to the given + rank. + """ + raise NotImplementedError + + def undo(self, tensors): + """Reverses the distribution action, reconstructing the original tensor. + + Args: + tensors: A sequence of tensor shards from all workers. + + Raises: + NotImplementedError: This is an abstract method and must be + implemented by subclasses. + + Returns: + The reconstructed, single tensor. + """ + raise NotImplementedError + + +class _ConcatenateMixin: + """A mixin class providing a common `undo` method via concatenation. + + This class is intended to be used as a mixin for `LayoutAction` subclasses + that can be undone by simple concatenation. + """ + + def undo(self, tensors): + """Concatenates sequence of tensors to reconstruct the original tensor. + + Args: + tensors: A sequence of tensor shards, one from each worker. + + Returns: + The single tensor reconstructed by concatenating the shards. 
+ """ + if self.dim == -1: + dim = keras.ops.ndim(tensors[0]) - 1 + else: + dim = self.dim + return keras.ops.concatenate(tensors, axis=dim) + + +class Split(_ConcatenateMixin, LayoutAction): + """Splits a tensor into shards along a specified dimension for each worker. + + This action implements sharding by slicing a tensor along one of its axes. + It handles cases where the dimension size is not perfectly divisible by the + number of workers by distributing the remainder elements one by one to the + first few workers. + + The `undo` operation is handled by the `_ConcatenateMixin`, which + concatenates the shards back together. + + Args: + world_size (int): The total number of workers/shards. + dim (int): The dimension along which to split the tensor. If -1, the + last dimension is used. + sharding_type (str): If `dim` is -1, this can be 'row' (dim=0) or + 'column' (dim=1) to infer the split axis for 2D tensors. + Defaults to "auto". + """ + + def __init__(self, world_size, dim, sharding_type="auto"): + """Initializes the Split action. + + Args: + world_size (int): The total number of workers/shards. + dim (int): The dimension along which to split the tensor. + sharding_type (str): A hint for inferring the dimension if `dim` + is -1. + """ + super().__init__() + self.world_size = world_size + self.dim = dim + self.sharding_type = sharding_type + + if dim == -1 and sharding_type != "auto": + if sharding_type == "row": + self.dim = 0 + elif sharding_type == "column": + self.dim = 1 + + def __call__(self, tensor, rank): + """Splits the tensor and returns the shard corresponding to the rank. + + This method calculates the correct slice of the tensor for a given + worker rank, handling uneven distributions gracefully. + + Args: + tensor: The full tensor to be sharded. + rank (int): The rank of the worker for which to get the shard. + + Returns: + A tensor shard corresponding to the given rank. + """ + if self.dim == -1: + dim = keras.ops.ndim(tensor) - 1 + else: + dim = self.dim + + total_size = tensor.shape[dim] + split_size = total_size // self.world_size + remainder = total_size % self.world_size + + start_idx = rank * split_size + min(rank, remainder) + end_idx = start_idx + split_size + (1 if rank < remainder else 0) + + slices = [slice(None)] * keras.ops.ndim(tensor) + slices[dim] = slice(start_idx, end_idx) + return tensor[tuple(slices)] + + +class LayoutMap: + """A mapping that defines layout rules for model states and outputs. + + This class acts as a configuration object that holds dictionaries of + `LayoutAction` instances. These rules specify how model variables (states) + and layer outputs should be distributed across a set of devices. + + Attributes: + state_rules (dict): A dictionary mapping variable names or patterns to + `LayoutAction` instances. + output_rules (dict): A dictionary mapping layer output names or + patterns to `LayoutAction` instances. + """ + + def __init__(self, state_rules, output_rules): + """Initializes the LayoutMap. + + Args: + state_rules (dict): A dictionary of rules for model states. + output_rules (dict): A dictionary of rules for model outputs. + """ + self.state_rules = state_rules + self.output_rules = output_rules + + def create_collective_ops(self, devices): + """Creates the necessary collective communication operations. + + Args: + devices: A sequence of device identifiers. + + Returns: + The `LayoutMap` instance itself. 
+ """ + return self diff --git a/keras/src/distribution/tensor_parallel/tensor_layout_test.py b/keras/src/distribution/tensor_parallel/tensor_layout_test.py new file mode 100644 index 000000000000..42000f36f82e --- /dev/null +++ b/keras/src/distribution/tensor_parallel/tensor_layout_test.py @@ -0,0 +1,164 @@ +import pytest + +import keras +from keras.src import backend +from keras.src import testing +from keras.src.distribution.tensor_parallel.tensor_layout import LayoutAction +from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap +from keras.src.distribution.tensor_parallel.tensor_layout import Split + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="Test requires JAX backend", +) +class LayoutTest(testing.TestCase): + """Test suite for tensor layout actions and mappings.""" + + def test_layout_action_abstract_methods_raise_error(self): + """Ensures the base class methods raise NotImplementedError.""" + action = LayoutAction() + with self.assertRaises(NotImplementedError): + action(tensor=None, rank=0) + with self.assertRaises(NotImplementedError): + action.undo(tensors=None) + + # --- Split Action Tests --- + + def test_split_with_even_division(self): + """Tests splitting a tensor that divides evenly among workers.""" + world_size = 4 + # Create a tensor of shape (8, 2) + tensor = keras.ops.reshape( + keras.ops.arange(16, dtype="float32"), (8, 2) + ) + action = Split(world_size=world_size, dim=0) + + # Expected shard for rank 0 has shape (2, 2) + expected_shard_0 = keras.ops.array([[0.0, 1.0], [2.0, 3.0]]) + # Expected shard for rank 2 has shape (2, 2) + expected_shard_2 = keras.ops.array([[8.0, 9.0], [10.0, 11.0]]) + + shard_0 = action(tensor, rank=0) + shard_2 = action(tensor, rank=2) + + self.assertAllClose(shard_0, expected_shard_0) + self.assertAllClose(shard_2, expected_shard_2) + self.assertEqual(shard_0.shape, (2, 2)) + + def test_split_with_uneven_division(self): + """Tests splitting where the remainder is distributed correctly.""" + world_size = 3 + # Create a tensor of shape (10, 1). 10 / 3 = 3 with remainder 1. + tensor = keras.ops.reshape( + keras.ops.arange(10, dtype="float32"), (10, 1) + ) + action = Split(world_size=world_size, dim=0) + + # Rank 0 should get 3 + 1 = 4 rows. + shard_0 = action(tensor, rank=0) + self.assertEqual(shard_0.shape, (4, 1)) + self.assertAllClose( + shard_0, keras.ops.array([[0.0], [1.0], [2.0], [3.0]]) + ) + + # Rank 1 should get 3 rows. + shard_1 = action(tensor, rank=1) + self.assertEqual(shard_1.shape, (3, 1)) + self.assertAllClose(shard_1, keras.ops.array([[4.0], [5.0], [6.0]])) + + # Rank 2 should get 3 rows. + shard_2 = action(tensor, rank=2) + self.assertEqual(shard_2.shape, (3, 1)) + self.assertAllClose(shard_2, keras.ops.array([[7.0], [8.0], [9.0]])) + + def test_split_and_undo_cycle_even(self): + """Tests splitting and reconstructing evenly divisible tensor.""" + world_size = 2 + original_tensor = keras.ops.reshape( + keras.ops.arange(12, dtype="float32"), (6, 2) + ) + action = Split(world_size=world_size, dim=0) + + # Create all shards + shards = [action(original_tensor, rank=i) for i in range(world_size)] + + # Reconstruct the tensor + reconstructed_tensor = action.undo(shards) + + self.assertAllClose(original_tensor, reconstructed_tensor) + + def test_split_and_undo_cycle_uneven(self): + """Tests the full cycle for an unevenly distributed tensor.""" + world_size = 4 + # 11 / 4 = 2 with a remainder of 3. 
+ original_tensor = keras.ops.reshape( + keras.ops.arange(22, dtype="float32"), (11, 2) + ) + action = Split(world_size=world_size, dim=0) + + shards = [action(original_tensor, rank=i) for i in range(world_size)] + + # Verify shard shapes: first 3 get 2+1=3 rows, last one gets 2. + self.assertEqual(shards[0].shape, (3, 2)) + self.assertEqual(shards[1].shape, (3, 2)) + self.assertEqual(shards[2].shape, (3, 2)) + self.assertEqual(shards[3].shape, (2, 2)) + + reconstructed_tensor = action.undo(shards) + self.assertAllClose(original_tensor, reconstructed_tensor) + + def test_split_last_dimension_with_undo(self): + """Tests splitting on the last dimension using dim=-1.""" + world_size = 3 + original_tensor = keras.ops.reshape( + keras.ops.arange(30, dtype="float32"), (2, 5, 3) + ) + action = Split(world_size=world_size, dim=-1) + + shards = [action(original_tensor, rank=i) for i in range(world_size)] + + # Each shard should have the last dimension split. + self.assertEqual(shards[0].shape, (2, 5, 1)) + self.assertEqual(shards[1].shape, (2, 5, 1)) + self.assertEqual(shards[2].shape, (2, 5, 1)) + + reconstructed_tensor = action.undo(shards) + self.assertAllClose(original_tensor, reconstructed_tensor) + + def test_split_with_sharding_type_hint(self): + """Tests using 'row' and 'column' sharding hints for 2D tensors.""" + world_size = 2 + tensor = keras.ops.reshape( + keras.ops.arange(16, dtype="float32"), (4, 4) + ) + + # Row sharding should split along axis 0 + action_row = Split(world_size=world_size, dim=-1, sharding_type="row") + shard_row_0 = action_row(tensor, rank=0) + self.assertAllClose(shard_row_0, tensor[:2, :]) + self.assertEqual(action_row.dim, 0) + + # Column sharding should split along axis 1 + action_col = Split( + world_size=world_size, dim=-1, sharding_type="column" + ) + shard_col_0 = action_col(tensor, rank=0) + self.assertAllClose(shard_col_0, tensor[:, :2]) + self.assertEqual(action_col.dim, 1) + + # --- LayoutMap Tests --- + + def test_layout_map_initialization_and_methods(self): + """Tests basic initialization and method behavior of LayoutMap class.""" + state_rules = {"kernel": Split(world_size=2, dim=0)} + output_rules = {"output": Split(world_size=2, dim=-1)} + + layout_map = LayoutMap(state_rules, output_rules) + + self.assertIs(layout_map.state_rules["kernel"], state_rules["kernel"]) + self.assertIs(layout_map.output_rules["output"], output_rules["output"]) + + self.assertIs( + layout_map.create_collective_ops(devices=["cpu:0"]), layout_map + )
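For reviewers, a minimal usage sketch of the collective-op helpers added in this diff, assuming the modules land at the paths shown here and the JAX backend is active. The 8-device CPU setup mirrors `distributed_backend_test.py`; the variable names (`comm_ops`, `per_device`, `summed`) are illustrative only, and, as the docstrings above note, the returned ops must run inside `jax.pmap` with the matching axis name.

```python
import os

# Mirror the test setup: force 8 virtual CPU devices before JAX is imported.
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

from keras.src.distribution import distributed_backend

# Query the backend through the framework-agnostic wrapper.
info = distributed_backend.get_device_info()
print(info["backend"], info["device_count"])  # "jax", 8 with the flag above

if distributed_backend.is_multi_device_capable():
    comm_ops = distributed_backend.get_communication_ops()
    n = info["device_count"]

    # The ops are thin lax wrappers, so they only work under pmap with the
    # matching axis name ("data" by default).
    per_device = jnp.ones((n, 4))  # leading axis is sharded by pmap
    summed = jax.pmap(
        lambda x: comm_ops["all_reduce"](x, op="sum"), axis_name="data"
    )(per_device)
    # Every replica now holds the sum over all replicas: an array of n's.
    print(summed[0])
```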
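And a sketch of how `Split` and `LayoutMap` from `tensor_layout.py` behave, again assuming the module path in this diff; the shapes follow the same remainder-distribution rule exercised by `tensor_layout_test.py`, and the names (`kernel`, `row_split`) are illustrative.

```python
import keras

from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap
from keras.src.distribution.tensor_parallel.tensor_layout import Split

world_size = 4

# A (10, 6) kernel split row-wise across 4 workers: 10 = 4 * 2 + 2, so the
# two remainder rows go to ranks 0 and 1.
kernel = keras.ops.reshape(keras.ops.arange(60, dtype="float32"), (10, 6))
row_split = Split(world_size=world_size, dim=0)
shards = [row_split(kernel, rank=r) for r in range(world_size)]
print([tuple(s.shape) for s in shards])  # [(3, 6), (3, 6), (2, 6), (2, 6)]

# undo() concatenates the shards back into the original tensor.
restored = row_split.undo(shards)
print(keras.ops.all(keras.ops.equal(restored, kernel)))  # True

# LayoutMap simply bundles per-variable and per-output sharding rules.
layout_map = LayoutMap(
    state_rules={"kernel": row_split},
    output_rules={"output": Split(world_size=world_size, dim=-1)},
)
```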