From dd3181eceb093827b088fd017932d71ff85d06bf Mon Sep 17 00:00:00 2001 From: Suhana Date: Wed, 1 Oct 2025 15:59:38 +0530 Subject: [PATCH 01/12] adding autoconfig and coordinated_optimizer --- .../tensor_parallel/autoconfig.py | 222 ++++++ .../tensor_parallel/autoconfig_test.py | 146 ++++ .../tensor_parallel/coordinated_optimizer.py | 646 ++++++++++++++++++ .../coordinated_optimizer_test.py | 154 +++++ 4 files changed, 1168 insertions(+) create mode 100644 keras/src/distribution/tensor_parallel/autoconfig.py create mode 100644 keras/src/distribution/tensor_parallel/autoconfig_test.py create mode 100644 keras/src/distribution/tensor_parallel/coordinated_optimizer.py create mode 100644 keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py new file mode 100644 index 000000000000..b1c0bb9d5e19 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -0,0 +1,222 @@ +from typing import Sequence + +from keras.src import layers +from keras.src.distribution.tensor_parallel.config import ConfigKeras +from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras +from keras.src.models import Model + + +def analyze_dense_layer_directly( + layer: layers.Dense, module: Model, prefix: str +) -> str: + """Analyzes a Dense layer to classify it for tensor parallelism sharding. + + This function inspects the layer's weight shapes to determine if it's an + "up-projection" (expanding feature dimensions), a "down-projection" + (contracting feature dimensions), or a generic layer. This classification + helps in deciding whether to apply column-wise or row-wise parallelism. + + Args: + layer: The keras.layers.Dense instance to analyze. + module: The parent Keras model containing the layer. + prefix: The hierarchical name prefix for the layer. + + Returns: + A string indicating the layer's classification: 'up_projection', + 'down_projection', or 'generic_dense'. + """ + if not isinstance(layer, layers.Dense): + return "generic_dense" + + input_dim = None + output_dim = None + + if hasattr(layer, "kernel"): + kernel_shape = layer.kernel.shape + if len(kernel_shape) == 2: + input_dim = kernel_shape[0] + output_dim = kernel_shape[1] + else: + if hasattr(layer, "units"): + output_dim = layer.units + + if ( + hasattr(layer, "input_shape") + and layer.input_shape + and len(layer.input_shape) > 1 + ): + input_dim = layer.input_shape[-1] + + if not input_dim or not output_dim: + return "generic_dense" + + expansion_threshold = 1.5 + is_expansion = output_dim > input_dim * expansion_threshold + is_contraction = input_dim > output_dim * expansion_threshold + + if is_expansion: + return "up_projection" + elif is_contraction: + return "down_projection" + else: + return "generic_dense" + + +def _traverse_and_shard_layer( + current_layer: layers.Layer, + module: Model, + world_size: int, + state_rules: dict, + output_rules: dict, + processed_layers: set, + prefix: str = "", +): + """Traverses a layer and its sub-layers to apply sharding rules. + + This function navigates through the model's layer hierarchy. For each + layer, it identifies its type and applies appropriate sharding logic, + populating the `state_rules` and `output_rules` dictionaries. + + Args: + current_layer: The current keras.Layer object to be processed. + module: The top-level Keras Model, used for context analysis. + world_size: The total number of devices for sharding. 
+ state_rules: The dictionary of state sharding rules to populate. + output_rules: The dictionary of output sharding rules to populate. + processed_layers: A set of layer IDs that have already been processed + to avoid redundant computation and infinite loops. + prefix: The hierarchical name prefix from parent layers, used to + construct the full unique name for the current layer. + """ + if id(current_layer) in processed_layers: + return + processed_layers.add(id(current_layer)) + + name = current_layer.name + full_name = f"{prefix}.{name}" if prefix else name + + if isinstance(current_layer, layers.Dense): + mlp_type = analyze_dense_layer_directly( + current_layer, module, full_name + ) + + if mlp_type == "down_projection": + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 0, "row" + ) + output_rules[f"^{full_name}$"] = {0: "allreduce"} + + else: + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 1, "column" + ) + if current_layer.use_bias: + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, 0, "column" + ) + output_rules[f"^{full_name}$"] = {0: "no_comm"} + return + + elif isinstance(current_layer, layers.EinsumDense): + is_row_parallel = False + if "->" in current_layer.equation: + equation_parts = current_layer.equation.split("->") + if len(equation_parts) == 2: + input_spec = equation_parts[0].split(",")[0].strip() + output_spec = equation_parts[1].strip() + if ( + input_spec + and output_spec + and len(output_spec) < len(input_spec) + ): + is_row_parallel = True + + if is_row_parallel: + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 0, "row" + ) + output_rules[f"^{full_name}$"] = {0: "allreduce"} + else: + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 1, "column" + ) + if ( + hasattr(current_layer, "bias") + and current_layer.bias is not None + ): + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, 0, "column" + ) + output_rules[f"^{full_name}$"] = {0: "no_comm"} + return + + elif isinstance(current_layer, layers.Embedding): + weight_name = ( + "embeddings" if hasattr(current_layer, "embeddings") else None + ) + if weight_name: + state_rules[f"^{full_name}\.{weight_name}$"] = SplitKeras( + world_size, 1, "column" + ) + output_rules[f"^{full_name}$"] = {0: "no_comm"} + return + + elif isinstance( + current_layer, + ( + layers.LayerNormalization, + layers.BatchNormalization, + layers.GroupNormalization, + ), + ): + return + else: + if hasattr(current_layer, "layers"): + for sub_layer in current_layer.layers: + _traverse_and_shard_layer( + sub_layer, + module, + world_size, + state_rules, + output_rules, + processed_layers, + full_name, + ) + + +def get_default_config_keras( + module: Model, device_ids: Sequence[str] +) -> ConfigKeras: + """Generates a smart, recursive sharding configuration for a Keras model. + + This function traverses the layers of a given Keras model and applies a + set of heuristics to automatically determine how each layer's weights + and outputs should be sharded for tensor parallelism. It uses a helper + function to perform the recursive traversal. + + Args: + module: The Keras Model to generate a sharding configuration for. + device_ids: A sequence of device identifiers, used to determine the + world size (number of devices) for sharding. + + Returns: + A ConfigKeras object containing the generated 'state_rules' (for model + parameters) and 'output_rules' (for layer outputs). 
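+
+    Example (an illustrative sketch; the layer names and device ids below
+    are hypothetical):
+
+    ```python
+    import keras
+
+    inputs = keras.Input(shape=(64,))
+    x = keras.layers.Dense(256, name="up_proj")(inputs)
+    outputs = keras.layers.Dense(64, name="down_proj")(x)
+    model = keras.Model(inputs, outputs)
+
+    config = get_default_config_keras(model, device_ids=["gpu:0", "gpu:1"])
+    # Expected heuristics: "up_proj" expands 64 -> 256, so its kernel is
+    # split column-wise; "down_proj" contracts 256 -> 64, so its kernel is
+    # split row-wise and its output is all-reduced.
+    ```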
+ """ + world_size = len(device_ids) + state_rules = {} + output_rules = {} + processed_layers = set() + + for layer in module.layers: + _traverse_and_shard_layer( + current_layer=layer, + module=module, + world_size=world_size, + state_rules=state_rules, + output_rules=output_rules, + processed_layers=processed_layers, + prefix="", + ) + + return ConfigKeras(state_rules=state_rules, output_rules=output_rules) \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py new file mode 100644 index 000000000000..9b955f00525b --- /dev/null +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -0,0 +1,146 @@ +import os + +if "WORLD_SIZE" not in os.environ: + os.environ["WORLD_SIZE"] = "4" + +from keras import Input +from keras import Model +from keras import layers +from keras.src import testing +from keras.src.distribution.tensor_parallel.autoconfig import ( + analyze_dense_layer_directly, +) +from keras.src.distribution.tensor_parallel.autoconfig import ( + get_default_config_keras, +) +from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras + + +class TestAutoConfigKeras(testing.TestCase): + def setUp(self): + """Set up the test case and common variables.""" + super().setUp() + self.world_size = int(os.environ["WORLD_SIZE"]) + self.device_ids = [f"device:{i}" for i in range(self.world_size)] + + def _assert_split_keras_equal(self, rule1, rule2): + """ + Helper to compare two SplitKeras objects by their attributes. + MODIFIED: Use vars() for robust comparison without knowing attr names. + """ + self.assertIsInstance(rule1, SplitKeras) + self.assertIsInstance(rule2, SplitKeras) + self.assertDictEqual(vars(rule1), vars(rule2)) + + def _assert_rules_equal(self, actual_rules, expected_rules): + """Helper to compare two dictionaries of sharding rules.""" + self.assertSetEqual( + set(actual_rules.keys()), set(expected_rules.keys()) + ) + for key in expected_rules: + actual_val = actual_rules[key] + expected_val = expected_rules[key] + if isinstance(expected_val, SplitKeras): + self._assert_split_keras_equal(actual_val, expected_val) + else: + self.assertEqual(actual_val, expected_val) + + def test_analyze_dense_layer(self): + """Tests the direct analysis and classification of Dense layers.""" + up_proj_layer = layers.Dense(32) + up_proj_layer.build(input_shape=(None, 16)) + self.assertEqual( + analyze_dense_layer_directly(up_proj_layer, None, ""), + "up_projection", + ) + + down_proj_layer = layers.Dense(16) + down_proj_layer.build(input_shape=(None, 32)) + self.assertEqual( + analyze_dense_layer_directly(down_proj_layer, None, ""), + "down_projection", + ) + + def test_simple_mlp_sharding(self): + """Tests a simple MLP with up and down projection layers.""" + inputs = Input(shape=(64,)) + x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs) + outputs = layers.Dense( + 64, name="down_projection_layer", use_bias=False + )(x) + model = Model(inputs=inputs, outputs=outputs, name="simple_mlp") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^up_projection_layer.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^up_projection_layer.bias$": SplitKeras( + self.world_size, 0, "column" + ), + r"^down_projection_layer.kernel$": SplitKeras( + self.world_size, 0, "row" + ), + } + expected_output_rules = { + r"^up_projection_layer$": {0: "no_comm"}, + r"^down_projection_layer$": {0: "allreduce"}, + 
} + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_embedding_sharding(self): + """Tests an Embedding layer.""" + inputs = Input(shape=(10,), dtype="int32") + outputs = layers.Embedding( + input_dim=1000, output_dim=128, name="token_embedding" + )(inputs) + model = Model(inputs=inputs, outputs=outputs, name="embed_model") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^token_embedding\.embeddings$": SplitKeras( + self.world_size, 1, "column" + ) + } + expected_output_rules = {r"^token_embedding$": {0: "no_comm"}} + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_nested_model_sharding(self): + """Tests that the traversal logic correctly handles nested models.""" + inner_inputs = Input(shape=(32,)) + inner_outputs = layers.Dense(128, name="inner_dense")(inner_inputs) + inner_model = Model( + inputs=inner_inputs, outputs=inner_outputs, name="inner_block" + ) + + outer_inputs = Input(shape=(32,)) + x = inner_model(outer_inputs) + outer_outputs = layers.Dense(32, name="outer_dense")(x) + outer_model = Model( + inputs=outer_inputs, outputs=outer_outputs, name="outer_model" + ) + + config = get_default_config_keras(outer_model, self.device_ids) + + expected_state_rules = { + r"^inner_block.inner_dense.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^inner_block.inner_dense.bias$": SplitKeras( + self.world_size, 0, "column" + ), + r"^outer_dense.kernel$": SplitKeras(self.world_size, 0, "row"), + } + expected_output_rules = { + r"^inner_block.inner_dense$": {0: "no_comm"}, + r"^outer_dense$": {0: "allreduce"}, + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py new file mode 100644 index 000000000000..77e5c13629b6 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -0,0 +1,646 @@ +import re +from typing import Any +from typing import Dict +from typing import List + +import numpy as np + +import keras +from keras.src import ops +from keras.src import optimizers +from keras.src.backend.distributed import backend_resolver + + +class CoordinatedOptimizer: + """Manages an optimizer's state for distributed training. + + This class is an internal coordinator that handles the complexities of + sharding optimizer states across multiple devices (shards) and + synchronizing gradients according to tensor parallelism rules. It is not + intended to be used directly by the end-user but is a core component of + the `TensorParallelOptimizer`. + + Args: + base_optimizer: The Keras optimizer instance + (e.g., `keras.optimizers.Adam`) whose state will be managed. + world_size: The total number of devices/processes in the distributed + setup. + distributed_backend: The distributed communication backend to use. + Defaults to "auto". + rank: The rank of the current process. Defaults to 0. + shard_optimizer_states: If `True`, the optimizer's state variables + (e.g., momentum, velocity) will be partitioned across `world_size` + devices. Defaults to `True`. 
+ tensor_parallel_config: An optional configuration object that defines + rules for tensor parallelism, such as which gradients to + all-reduce. Defaults to `None`. + """ + + def __init__( + self, + base_optimizer: optimizers.Optimizer, + world_size: int, + distributed_backend: str = "auto", + rank: int = 0, + shard_optimizer_states: bool = True, + tensor_parallel_config=None, + ): + self.base_optimizer = base_optimizer + self.world_size = world_size + self.rank = rank + self.shard_optimizer_states = shard_optimizer_states + self.tensor_parallel_config = tensor_parallel_config + self.sharded_states = {} + self._state_variable_to_parameter = {} + self.distributed_backend = ( + backend_resolver.get_distributed_backend(distributed_backend) + if distributed_backend is not None + else None + ) + self._variables = None # Will be set when optimizer is built + + # In class CoordinatedOptimizer: + +# In class CoordinatedOptimizer: + +# In class CoordinatedOptimizer: + +# In class CoordinatedOptimizer: + +# In class CoordinatedOptimizer: +# In class CoordinatedOptimizer: + + def _get_optimizer_slot_names(self) -> set: + """ + Deduces the slot names ('m', 'v', etc.) by inspecting the variables + created by the base optimizer. This is the most robust method. + """ + slot_names = set() + # The optimizer's variables have paths like 'Adam/m/dense/kernel'. + # We can extract the second part as the slot name. + for var in self.base_optimizer.variables: + # Skip the iteration counter + if "iteration" in var.path.lower(): + continue + path_parts = var.path.split('/') + if len(path_parts) > 1: + slot_names.add(path_parts[1]) + return slot_names + +# In class CoordinatedOptimizer: + +# In class CoordinatedOptimizer: + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + + def _initialize_sharded_states(self): + """ + Partitions the optimizer's state variables across shards by inspecting + the variables created by the base optimizer. This version correctly + parses variable paths like 'optimizer/param_name_slot_name'. 
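+
+        Example of the path parsing (an illustrative sketch; the optimizer
+        name, state-variable path and parameter path below are hypothetical):
+
+        ```python
+        opt_name = "adam"
+        state_path = "adam/dense_kernel_momentum"  # optimizer state variable
+        param_path = "dense/kernel"                # matching model parameter
+
+        prefix, suffix = state_path.split("/")
+        norm_param = param_path.replace("/", "_")  # "dense_kernel"
+        assert prefix == opt_name and suffix.startswith(norm_param)
+        slot_name = suffix[len(norm_param):].strip("_")  # "momentum"
+        ```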
+ """ + if not self.shard_optimizer_states or not self.base_optimizer.built: + return + + self.sharded_states = {} + self._state_variable_to_parameter = {} + opt_name = self.base_optimizer.name + + normalized_params = [ + (p.path.replace('/', '_'), p) for p in self._variables + ] + + for state_var in self.base_optimizer.variables: + if state_var is self.base_optimizer.iterations: + continue + + path_parts = state_var.path.split('/') + if len(path_parts) != 2 or path_parts[0] != opt_name: + continue + + state_suffix = path_parts[1] + + found_param = None + slot_name = None + for norm_param_path, param in normalized_params: + if state_suffix.startswith(norm_param_path): + found_param = param + slot_suffix = state_suffix[len(norm_param_path):] + slot_name = slot_suffix.strip('_') + break + + # THE FIX IS HERE: Explicitly check for 'is not None' + if found_param is not None and slot_name is not None: + self._state_variable_to_parameter[state_var.path] = found_param + + sharding_dim = 0 + if self.tensor_parallel_config: + norm_param_name = found_param.path.replace("/", ".") + for (p, a) in self.tensor_parallel_config.state_rules.items(): + if re.search(p, norm_param_name) and hasattr(a, "dim"): + sharding_dim = a.dim + break + + partitioned_state = self._partition_state(state_var, dim=sharding_dim) + self.sharded_states.setdefault(slot_name, {})[found_param.path] = partitioned_state + + if self.base_optimizer.iterations is not None: + self.sharded_states["iterations"] = self._partition_state( + self.base_optimizer.iterations, dim=0 + ) + def _partition_state( + self, state_variable: any, dim: int + ) -> List[np.ndarray]: + """Splits a single state variable numpy array into chunks. + + If the variable cannot be split along the given dimension, it is + replicated across all shards. + + Args: + state_variable: The optimizer state variable. + dim: The dimension along which to partition the variable. + + Returns: + A list of NumPy arrays, where each array is a partition of the + original state variable for a specific shard. + """ + state_array = keras.ops.convert_to_numpy(state_variable) + if state_array.ndim > dim and state_array.shape[dim] >= self.world_size: + return np.array_split(state_array, self.world_size, axis=dim) + else: + return [np.copy(state_array) for _ in range(self.world_size)] + + def get_config(self) -> Dict[str, Any]: + return { + "base_optimizer": self.base_optimizer.get_config(), + "world_size": self.world_size, + "shard_optimizer_states": self.shard_optimizer_states, + } + + def apply_gradients( + self, gradients_and_vars: List[List[tuple]], shard_models: List + ): + """Coordinates gradient synchronization and application. + + This method first synchronizes gradients across all shards based on + tensor parallelism rules. Then, it applies the gradients using either + sharded optimizer states or replicated states. + + Args: + gradients_and_vars: A list of lists, where each inner list contains + (gradient, variable) tuples for a specific model shard. + shard_models: A list of the sharded model instances. + + Raises: + ValueError: If the number of gradient sets does not match the + world size. 
+ """ + if len(gradients_and_vars) != self.world_size: + error_msg = ( + f"Expected {self.world_size} gradient sets, " + f"got {len(gradients_and_vars)}" + ) + raise ValueError(error_msg) + + synchronized_gradients = self._synchronize_gradients(gradients_and_vars) + + if self.shard_optimizer_states and self.sharded_states: + self._apply_gradients_with_sharded_states( + synchronized_gradients, shard_models + ) + else: + self._apply_gradients_with_replicated_states( + synchronized_gradients, shard_models + ) + + def _apply_gradients_with_sharded_states( + self, synchronized_gradients: List[List[tuple]], shard_models: List + ): + """Applies gradients to each shard using its local optimizer state. + + For each shard, this method loads the corresponding partition of the + optimizer state into the base optimizer and then applies the shard's + gradients. + + Args: + synchronized_gradients: The gradients after synchronization. + shard_models: The list of sharded models. + """ + for shard_idx, shard_grads in enumerate(synchronized_gradients): + local_states = self._get_local_optimizer_states(shard_idx) + self._update_optimizer_internal_state( + self.base_optimizer, local_states + ) + self.base_optimizer.apply_gradients(shard_grads) + + def _apply_gradients_with_replicated_states( + self, synchronized_gradients: List[List[tuple]], shard_models: List + ): + """Averages gradients across all shards and applies them once. + + This method is used when optimizer state sharding is disabled. It + calculates the average of the gradients for each variable across all + shards and applies the averaged gradients using the single, replicated + optimizer state. + + Args: + synchronized_gradients: The gradients after synchronization. + shard_models: The list of sharded models. + """ + num_vars = len(synchronized_gradients[0]) + averaged_grads_and_vars = [] + + for i in range(num_vars): + variable = synchronized_gradients[0][i][1] + grads_for_var = [ + shard_grads[i][0] + for shard_grads in synchronized_gradients + if shard_grads[i][0] is not None + ] + + if not grads_for_var: + continue + + summed_grad = grads_for_var[0] + for grad in grads_for_var[1:]: + summed_grad += grad + averaged_grad = summed_grad / len(grads_for_var) + averaged_grads_and_vars.append((averaged_grad, variable)) + + if averaged_grads_and_vars: + self.base_optimizer.apply_gradients(averaged_grads_and_vars) + + def _get_local_optimizer_states(self, shard_idx: int) -> Dict[str, Any]: + """Constructs the state dictionary for a single shard. + + Args: + shard_idx: The index of the shard for which to retrieve the state. + + Returns: + A dictionary containing the optimizer state variables specific to + the given shard index. 
+ """ + local_states = {} + for state_name, state_value in self.sharded_states.items(): + if isinstance(state_value, dict): + local_states[state_name] = {} + for param_name, param_states in state_value.items(): + local_states[state_name][param_name] = param_states[ + shard_idx + ] + else: + local_states[state_name] = state_value[shard_idx] + return local_states + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + + def _update_optimizer_internal_state(self, local_states: dict): + """Assigns local sharded state values to the optimizer's variables.""" + if not self.base_optimizer.built: + return + + for var in self.base_optimizer.variables: + if var is self.base_optimizer.iterations: + if "iterations" in local_states: + var.assign(local_states["iterations"]) + continue + + # THE FIX IS HERE: Use the variable's path for the lookup. + param = self._state_variable_to_parameter.get(var.path, None) + + if param: + # This internal method is the most reliable way to get the + # slot name (e.g., "momentum") from the variable object. + slot_name = ( + self.base_optimizer._get_slot_name_from_variable(var) + ) + if ( + slot_name in local_states + and param.path in local_states[slot_name] + ): + local_param_state = local_states[slot_name][param.path] + if var.shape == local_param_state.shape: + var.assign(local_param_state) + + def _synchronize_gradients( + self, gradients_and_vars: List[List[tuple]] + ) -> List[List[tuple]]: + """Synchronizes gradients across shards based on tensor parallel rules. + + Specifically, it performs an all-reduce operation on gradients of + weights that are split along a "column" dimension in tensor parallelism. + Other gradients are passed through unchanged. + + Args: + gradients_and_vars: The list of (gradient, variable) lists from + all shards. + + Returns: + The list of (gradient, variable) lists after synchronization. + """ + if not self.tensor_parallel_config: + return gradients_and_vars + + rules = self.tensor_parallel_config.state_rules.items() + column_parallel_patterns = { + pattern + for pattern, action in rules + if hasattr(action, "sharding_type") + and action.sharding_type == "column" + } + + if not column_parallel_patterns: + return gradients_and_vars + + num_weights = len(gradients_and_vars[0]) + for i in range(num_weights): + variable = gradients_and_vars[0][i][1] + var_name = getattr(variable, "path", getattr(variable, "name", "")) + + if any( + re.search(pattern, var_name) + for pattern in column_parallel_patterns + ): + grads_to_reduce = [ + g_and_v[i][0] + for g_and_v in gradients_and_vars + if g_and_v[i][0] is not None + ] + if grads_to_reduce: + synced_grad = self._allreduce_gradients(grads_to_reduce)[0] + for shard_idx in range(self.world_size): + gradients_and_vars[shard_idx][i] = ( + synced_grad, + variable, + ) + return gradients_and_vars + + def _allreduce_gradients(self, gradients: List[Any]) -> List[Any]: + """Performs a mean all-reduce operation on a list of gradients. + + If a distributed backend is available, it uses it. Otherwise, it + falls back to a local mean calculation. + + Args: + gradients: A list of gradients (one from each shard) to be averaged. + + Returns: + A list where each element is the mean of the input gradients. 
+ """ + if not gradients: + return [] + + if ( + self.distributed_backend is not None + and self.distributed_backend.is_initialized + ): + numpy_grad = keras.ops.convert_to_numpy(gradients[0]) + synced_numpy = self.distributed_backend.allreduce( + numpy_grad, op="mean" + ) + synced_tensor = keras.ops.convert_to_tensor(synced_numpy) + return [synced_tensor for _ in range(self.world_size)] + + stacked_grads = keras.ops.stack( + [keras.ops.convert_to_tensor(g) for g in gradients], axis=0 + ) + mean_grad = keras.ops.mean(stacked_grads, axis=0) + return [mean_grad for _ in range(len(gradients))] + + def get_weights(self) -> List[np.ndarray]: + """Returns the weights of the base optimizer.""" + return self.base_optimizer.get_weights() + + def set_weights(self, weights: List[np.ndarray]): + """Sets the weights of the base optimizer.""" + self.base_optimizer.set_weights(weights) + + def enable_optimizer_state_sharding(self, variables: List): + """Enables and initializes optimizer state sharding. + + This method is called from `build()`, which is guarded from running + multiple times. We can assume this should always execute. + """ + # The check 'if not self.shard_optimizer_states:' was here and was + # incorrectly preventing this code from running. It has been removed. + self.shard_optimizer_states = True + self._variables = variables + self._initialize_sharded_states() + + def disable_optimizer_state_sharding(self): + """Disables sharding and clears any sharded states. + + This reverts the optimizer to using a single, replicated state. + """ + if self.shard_optimizer_states: + self.shard_optimizer_states = False + self.sharded_states = {} + + +class TensorParallelOptimizer(optimizers.Optimizer): + """A Keras Optimizer wrapper for tensor-parallel distributed training. + + This optimizer wraps a standard Keras optimizer (e.g., Adam, SGD) and + delegates the complex tasks of state management and gradient synchronization + to a `CoordinatedOptimizer` instance. It is designed to work with models + that have been sharded for tensor parallelism. + + When `apply_gradients` is called with a list of gradient lists (one for each + model shard), it uses the `CoordinatedOptimizer` to handle synchronization + and state sharding. Otherwise, it behaves like the base optimizer. + + Args: + base_optimizer: A Keras optimizer instance or a string identifier + (e.g., 'adam', 'sgd'). + world_size: The total number of devices/processes in the distributed + setup. + distributed_backend: The distributed communication backend to use. + Defaults to "auto". + tensor_parallel_config: An optional configuration object that defines + rules for tensor parallelism. Defaults to `None`. + + Example: + + ```python + import keras + + # Assume model variables and gradients from 4 shards exist. + # The structure is: List[List[Tuple[gradient, variable]]] + trainable_vars = [keras.Variable(1.0), keras.Variable(2.0)] + sharded_grads_and_vars = [ + [(keras.ops.ones_like(v), v) for v in trainable_vars] + for _ in range(4) # 4 shards + ] + + # 1. Wrap a standard Keras optimizer. + base_optimizer = keras.optimizers.Adam() + optimizer = TensorParallelOptimizer(base_optimizer, world_size=4) + optimizer.build(trainable_vars) + + # 2. Apply the sharded gradients. + # The optimizer will handle synchronization (e.g., all-reduce) internally. 
+ optimizer.apply_gradients(sharded_grads_and_vars) + ``` + """ + + def __init__( + self, + base_optimizer: optimizers.Optimizer, + world_size: int, + distributed_backend: str = "auto", + tensor_parallel_config=None, + ): + if isinstance(base_optimizer, str): + resolved_base_optimizer = optimizers.get(base_optimizer) + else: + resolved_base_optimizer = base_optimizer + + if isinstance( + resolved_base_optimizer.learning_rate, + keras.optimizers.schedules.LearningRateSchedule, + ): + lr_value = float( + ops.convert_to_numpy( + resolved_base_optimizer.learning_rate.initial_learning_rate + ) + ) + else: + lr_value = float( + ops.convert_to_numpy(resolved_base_optimizer.learning_rate) + ) + + super().__init__( + learning_rate=lr_value, + name=f"TensorParallel_{resolved_base_optimizer.name}", + ) + + self.base_optimizer = resolved_base_optimizer + self.world_size = world_size + self.distributed_backend = distributed_backend + self.coordinated_optimizer = CoordinatedOptimizer( + self.base_optimizer, + world_size, + distributed_backend=distributed_backend, + tensor_parallel_config=tensor_parallel_config, + ) + + def apply_gradients(self, grads_and_vars: List, **kwargs): + """Applies gradients to the model variables. + + If `grads_and_vars` is a list of lists, it's assumed to be from + sharded models, and the `CoordinatedOptimizer` is used. Otherwise, + it calls the `base_optimizer`'s `apply_gradients` directly. + + Args: + grads_and_vars: A list of (gradient, variable) tuples, or a list + of such lists if running in a sharded context. + **kwargs: Additional arguments. `shard_models` can be passed to + provide the list of model shards. + """ + if ( + isinstance(grads_and_vars, list) + and grads_and_vars + and isinstance(grads_and_vars[0], list) + ): + shard_models = kwargs.get("shard_models", []) + self.coordinated_optimizer.apply_gradients( + grads_and_vars, shard_models + ) + else: + self.base_optimizer.apply_gradients(grads_and_vars) + + def get_config(self) -> Dict[str, Any]: + from keras.src import saving + + config = super().get_config() + config.pop("learning_rate", None) + config.pop("name", None) + + config.update( + { + "base_optimizer": saving.serialize_keras_object( + self.base_optimizer + ), + "world_size": self.world_size, + "distributed_backend": self.distributed_backend, + } + ) + return config + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "TensorParallelOptimizer": + from keras.src import saving + + base_optimizer_config = config.pop("base_optimizer") + base_optimizer = saving.deserialize_keras_object(base_optimizer_config) + + init_kwargs = { + "world_size": config.get("world_size"), + "distributed_backend": config.get("distributed_backend", "auto"), + "tensor_parallel_config": config.get("tensor_parallel_config"), + } + + return cls(base_optimizer=base_optimizer, **init_kwargs) + + def build(self, variables: List): + """Builds the optimizer and initializes sharded states. + + This method is called the first time the optimizer is used. It builds + the base optimizer and then triggers the `CoordinatedOptimizer` to + initialize its sharded states. + + Args: + variables: A list of model variables to be optimized. + """ + if self.built: + return + + # First, build the base optimizer with the variables. + self.base_optimizer.build(variables) + print(f"Variables after build: {[v.path for v in self.base_optimizer.variables]}") + + # THE FINAL FIX: Force slot variable creation by applying zero gradients. 
+ # This is necessary because optimizers create slots lazily on the first + # call to apply_gradients. + if variables: # Only run if there are variables to optimize. + zero_grads = [ops.zeros_like(v) for v in variables] + self.base_optimizer.apply_gradients(zip(zero_grads, variables)) + + # The dry run increments the iteration counter, so we reset it. + if self.base_optimizer.iterations is not None: + self.base_optimizer.iterations.assign(0) + + # Now that all state variables (m, v, etc.) are guaranteed to exist, + # we can safely initialize sharding. + self.coordinated_optimizer.enable_optimizer_state_sharding(variables) + super().build(variables) + + def get_weights(self) -> List[np.ndarray]: + """Returns the weights of the base optimizer.""" + return self.coordinated_optimizer.get_weights() + + def set_weights(self, weights: List[np.ndarray]): + """Sets the weights of the base optimizer.""" + self.coordinated_optimizer.set_weights(weights) + + @property + def variables(self) -> List: + """Returns the list of variables from the base optimizer.""" + return self.base_optimizer.variables + + @property + def learning_rate(self) -> Any: + """Provides access to the learning rate of the base optimizer.""" + return self.base_optimizer.learning_rate diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py new file mode 100644 index 000000000000..59bfa8118b04 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -0,0 +1,154 @@ +import numpy as np +from coordinated_optimizer import CoordinatedOptimizer +from coordinated_optimizer import TensorParallelOptimizer + +import keras +from keras.src import optimizers +from keras.src import testing + + +class CoordinatedOptimizerTest(testing.TestCase): + def _get_simple_model(self): + """Creates a simple, uncompiled Keras model.""" + inputs = keras.Input(shape=(10,)) + x = keras.layers.Dense(20, name="dense_1")(inputs) + outputs = keras.layers.Dense(5, name="dense_2")(x) + return keras.Model(inputs, outputs) + + def _get_mock_gradients_and_vars(self, model, world_size): + """Generates mock gradients and variables for N shards.""" + model.build(input_shape=(None, 10)) + variables = model.trainable_variables + grads_and_vars_per_shard = [] + for i in range(world_size): + multiplier = float(i + 1) + gradients = [ + keras.ops.convert_to_tensor( + np.ones_like(v.numpy()) * multiplier, dtype="float32" + ) + for v in variables + ] + grads_and_vars_per_shard.append(list(zip(gradients, variables))) + return grads_and_vars_per_shard + + def test_initialization(self): + """Tests that the optimizer initializes with the correct defaults.""" + base_optimizer = optimizers.Adam() + coord = CoordinatedOptimizer( + base_optimizer, world_size=4, distributed_backend=None + ) + self.assertEqual(coord.base_optimizer, base_optimizer) + self.assertTrue(coord.shard_optimizer_states) + self.assertEqual(coord.sharded_states, {}) + + def test_apply_gradients_with_replicated_states(self): + """Tests that replicated gradients are averaged and applied once.""" + class AdamWithCallCounter(optimizers.Adam): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.apply_gradients_call_count = 0 + self.received_grads = [] + + def apply_gradients(self, grads_and_vars, *args, **kwargs): + self.apply_gradients_call_count += 1 + self.received_grads = [g for g, v in grads_and_vars] + + world_size = 4 + model = self._get_simple_model() + 
optimizer = AdamWithCallCounter() + model.build((None, 10)) + mock_grads = self._get_mock_gradients_and_vars(model, world_size) + + coord = CoordinatedOptimizer( + optimizer, + world_size, + shard_optimizer_states=False, + distributed_backend=None, + ) + coord.apply_gradients(mock_grads, []) + + self.assertEqual(optimizer.apply_gradients_call_count, 1) + self.assertAllClose( + optimizer.received_grads[0], + np.ones_like(optimizer.received_grads[0]) * 2.5, + ) + + def test_init_from_string(self): + optimizer = TensorParallelOptimizer( + "adam", world_size=4, distributed_backend=None + ) + self.assertIsInstance(optimizer.base_optimizer, optimizers.Adam) + + def test_apply_gradients_delegation(self): + """Tests that apply_gradients correctly delegates.""" + world_size = 4 + base_opt = optimizers.Adam() + optimizer = TensorParallelOptimizer( + base_opt, world_size, distributed_backend=None + ) + model = self._get_simple_model() + mock_grads = self._get_mock_gradients_and_vars(model, world_size) + + coord_apply_tracker = {"called": False} + optimizer.coordinated_optimizer.apply_gradients = ( + lambda *a, **kw: coord_apply_tracker.update({"called": True}) + ) + base_apply_tracker = {"called": False} + optimizer.base_optimizer.apply_gradients = ( + lambda *a, **kw: base_apply_tracker.update({"called": True}) + ) + + optimizer.apply_gradients(mock_grads, shard_models=[]) + self.assertTrue(coord_apply_tracker["called"]) + self.assertFalse(base_apply_tracker["called"]) + + coord_apply_tracker["called"] = False + unsharded_grads = mock_grads[0] + optimizer.apply_gradients(unsharded_grads) + self.assertTrue(base_apply_tracker["called"]) + self.assertFalse(coord_apply_tracker["called"]) + +# In coordinated_optimizer_test.py + +# In coordinated_optimizer_test.py + + def test_build_and_state_sharding(self): + """Tests that the build method correctly initializes sharded states.""" + optimizer = TensorParallelOptimizer( + optimizers.Adam(), world_size=4, distributed_backend=None + ) + model = self._get_simple_model() + + # Build the model so its trainable_variables list is populated. + model.build(input_shape=(None, 10)) + + self.assertEqual(optimizer.coordinated_optimizer.sharded_states, {}) + optimizer.build(model.trainable_variables) + self.assertTrue(optimizer.built) + + sharded_states = optimizer.coordinated_optimizer.sharded_states + + # THE FIX IS HERE: + # Keras Adam uses 'momentum' and 'velocity' as its slot names, not 'm' and 'v'. 
+ self.assertIn("momentum", sharded_states) + self.assertIn("velocity", sharded_states) + self.assertIn("iterations", sharded_states) + + dense_1_kernel_path = model.get_layer("dense_1").kernel.path + self.assertIn(dense_1_kernel_path, sharded_states["momentum"]) + self.assertEqual(len(sharded_states["momentum"][dense_1_kernel_path]), 4) + + def test_serialization(self): + world_size = 4 + base_opt = optimizers.Adam(learning_rate=0.1) + optimizer = TensorParallelOptimizer( + base_opt, world_size, distributed_backend=None + ) + + config = optimizer.get_config() + recreated = TensorParallelOptimizer.from_config(config) + + self.assertEqual(recreated.world_size, world_size) + self.assertIsInstance(recreated.base_optimizer, optimizers.Adam) + self.assertIsNone(recreated.coordinated_optimizer.distributed_backend) + self.assertAllClose(recreated.base_optimizer.learning_rate, 0.1) \ No newline at end of file From bcae2f69ee2d2ee58ce82cebf88d66b8fe4fee89 Mon Sep 17 00:00:00 2001 From: Suhana Date: Wed, 1 Oct 2025 21:35:28 +0530 Subject: [PATCH 02/12] Reformatting --- .../tensor_parallel/coordinated_optimizer.py | 51 +------------------ .../coordinated_optimizer_test.py | 8 +-- 2 files changed, 3 insertions(+), 56 deletions(-) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index 77e5c13629b6..73ea557995e9 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -57,18 +57,7 @@ def __init__( if distributed_backend is not None else None ) - self._variables = None # Will be set when optimizer is built - - # In class CoordinatedOptimizer: - -# In class CoordinatedOptimizer: - -# In class CoordinatedOptimizer: - -# In class CoordinatedOptimizer: - -# In class CoordinatedOptimizer: -# In class CoordinatedOptimizer: + self._variables = None def _get_optimizer_slot_names(self) -> set: """ @@ -76,10 +65,7 @@ def _get_optimizer_slot_names(self) -> set: created by the base optimizer. This is the most robust method. """ slot_names = set() - # The optimizer's variables have paths like 'Adam/m/dense/kernel'. - # We can extract the second part as the slot name. 
for var in self.base_optimizer.variables: - # Skip the iteration counter if "iteration" in var.path.lower(): continue path_parts = var.path.split('/') @@ -87,24 +73,6 @@ def _get_optimizer_slot_names(self) -> set: slot_names.add(path_parts[1]) return slot_names -# In class CoordinatedOptimizer: - -# In class CoordinatedOptimizer: - -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - def _initialize_sharded_states(self): """ Partitions the optimizer's state variables across shards by inspecting @@ -141,7 +109,6 @@ def _initialize_sharded_states(self): slot_name = slot_suffix.strip('_') break - # THE FIX IS HERE: Explicitly check for 'is not None' if found_param is not None and slot_name is not None: self._state_variable_to_parameter[state_var.path] = found_param @@ -304,8 +271,6 @@ def _get_local_optimizer_states(self, shard_idx: int) -> Dict[str, Any]: local_states[state_name] = state_value[shard_idx] return local_states -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - def _update_optimizer_internal_state(self, local_states: dict): """Assigns local sharded state values to the optimizer's variables.""" if not self.base_optimizer.built: @@ -317,12 +282,9 @@ def _update_optimizer_internal_state(self, local_states: dict): var.assign(local_states["iterations"]) continue - # THE FIX IS HERE: Use the variable's path for the lookup. param = self._state_variable_to_parameter.get(var.path, None) if param: - # This internal method is the most reliable way to get the - # slot name (e.g., "momentum") from the variable object. slot_name = ( self.base_optimizer._get_slot_name_from_variable(var) ) @@ -433,8 +395,6 @@ def enable_optimizer_state_sharding(self, variables: List): This method is called from `build()`, which is guarded from running multiple times. We can assume this should always execute. """ - # The check 'if not self.shard_optimizer_states:' was here and was - # incorrectly preventing this code from running. It has been removed. self.shard_optimizer_states = True self._variables = variables self._initialize_sharded_states() @@ -607,23 +567,16 @@ def build(self, variables: List): if self.built: return - # First, build the base optimizer with the variables. self.base_optimizer.build(variables) print(f"Variables after build: {[v.path for v in self.base_optimizer.variables]}") - # THE FINAL FIX: Force slot variable creation by applying zero gradients. - # This is necessary because optimizers create slots lazily on the first - # call to apply_gradients. - if variables: # Only run if there are variables to optimize. + if variables: zero_grads = [ops.zeros_like(v) for v in variables] self.base_optimizer.apply_gradients(zip(zero_grads, variables)) - # The dry run increments the iteration counter, so we reset it. if self.base_optimizer.iterations is not None: self.base_optimizer.iterations.assign(0) - # Now that all state variables (m, v, etc.) are guaranteed to exist, - # we can safely initialize sharding. 
self.coordinated_optimizer.enable_optimizer_state_sharding(variables) super().build(variables) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py index 59bfa8118b04..ca69361fe383 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -108,9 +108,6 @@ def test_apply_gradients_delegation(self): self.assertTrue(base_apply_tracker["called"]) self.assertFalse(coord_apply_tracker["called"]) -# In coordinated_optimizer_test.py - -# In coordinated_optimizer_test.py def test_build_and_state_sharding(self): """Tests that the build method correctly initializes sharded states.""" @@ -119,7 +116,6 @@ def test_build_and_state_sharding(self): ) model = self._get_simple_model() - # Build the model so its trainable_variables list is populated. model.build(input_shape=(None, 10)) self.assertEqual(optimizer.coordinated_optimizer.sharded_states, {}) @@ -127,9 +123,7 @@ def test_build_and_state_sharding(self): self.assertTrue(optimizer.built) sharded_states = optimizer.coordinated_optimizer.sharded_states - - # THE FIX IS HERE: - # Keras Adam uses 'momentum' and 'velocity' as its slot names, not 'm' and 'v'. + self.assertIn("momentum", sharded_states) self.assertIn("velocity", sharded_states) self.assertIn("iterations", sharded_states) From 439643b33fd377041ac299a3ddb76baf15ca52b6 Mon Sep 17 00:00:00 2001 From: Suhana Date: Thu, 2 Oct 2025 15:45:00 +0530 Subject: [PATCH 03/12] Added sharding keras --- .../tensor_parallel/autoconfig.py | 22 ++- .../tensor_parallel/autoconfig_test.py | 15 +- .../tensor_parallel/coordinated_optimizer.py | 152 ++++++++---------- .../coordinated_optimizer_test.py | 47 ++++-- .../tensor_parallel/sharding_keras.py | 85 ++++++++++ 5 files changed, 207 insertions(+), 114 deletions(-) create mode 100644 keras/src/distribution/tensor_parallel/sharding_keras.py diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py index b1c0bb9d5e19..cf5966eb4670 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig.py +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -1,14 +1,12 @@ from typing import Sequence -from keras.src import layers from keras.src.distribution.tensor_parallel.config import ConfigKeras from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras -from keras.src.models import Model -def analyze_dense_layer_directly( - layer: layers.Dense, module: Model, prefix: str -) -> str: +def analyze_dense_layer_directly(layer, module, prefix: str) -> str: + from keras.src import layers + """Analyzes a Dense layer to classify it for tensor parallelism sharding. This function inspects the layer's weight shapes to determine if it's an @@ -63,14 +61,16 @@ def analyze_dense_layer_directly( def _traverse_and_shard_layer( - current_layer: layers.Layer, - module: Model, + current_layer, + module, world_size: int, state_rules: dict, output_rules: dict, processed_layers: set, prefix: str = "", ): + from keras.src import layers + """Traverses a layer and its sub-layers to apply sharding rules. This function navigates through the model's layer hierarchy. 
For each @@ -145,8 +145,8 @@ def _traverse_and_shard_layer( and current_layer.bias is not None ): state_rules[f"^{full_name}.bias$"] = SplitKeras( - world_size, 0, "column" - ) + world_size, 0, "column" + ) output_rules[f"^{full_name}$"] = {0: "no_comm"} return @@ -184,9 +184,7 @@ def _traverse_and_shard_layer( ) -def get_default_config_keras( - module: Model, device_ids: Sequence[str] -) -> ConfigKeras: +def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: """Generates a smart, recursive sharding configuration for a Keras model. This function traverses the layers of a given Keras model and applies a diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py index 9b955f00525b..ab9a1b4149c1 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig_test.py +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -1,12 +1,12 @@ import os -if "WORLD_SIZE" not in os.environ: - os.environ["WORLD_SIZE"] = "4" +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" from keras import Input from keras import Model from keras import layers from keras.src import testing +from keras.src.backend.distributed import backend_resolver from keras.src.distribution.tensor_parallel.autoconfig import ( analyze_dense_layer_directly, ) @@ -20,13 +20,18 @@ class TestAutoConfigKeras(testing.TestCase): def setUp(self): """Set up the test case and common variables.""" super().setUp() - self.world_size = int(os.environ["WORLD_SIZE"]) + backend = backend_resolver.get_distributed_backend() + device_info = backend.get_device_info() + self.world_size = device_info["device_count"] self.device_ids = [f"device:{i}" for i in range(self.world_size)] + self.assertGreater( + self.world_size, 1, "Distribution tests require more than 1 device." + ) + def _assert_split_keras_equal(self, rule1, rule2): """ Helper to compare two SplitKeras objects by their attributes. - MODIFIED: Use vars() for robust comparison without knowing attr names. 
""" self.assertIsInstance(rule1, SplitKeras) self.assertIsInstance(rule2, SplitKeras) @@ -143,4 +148,4 @@ def test_nested_model_sharding(self): } self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index 73ea557995e9..726747676c0b 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -58,6 +58,7 @@ def __init__( else None ) self._variables = None + self._variable_to_slot_name = {} def _get_optimizer_slot_names(self) -> set: """ @@ -68,7 +69,7 @@ def _get_optimizer_slot_names(self) -> set: for var in self.base_optimizer.variables: if "iteration" in var.path.lower(): continue - path_parts = var.path.split('/') + path_parts = var.path.split("/") if len(path_parts) > 1: slot_names.add(path_parts[1]) return slot_names @@ -84,20 +85,21 @@ def _initialize_sharded_states(self): self.sharded_states = {} self._state_variable_to_parameter = {} + self._variable_to_slot_name = {} # Reset the map opt_name = self.base_optimizer.name normalized_params = [ - (p.path.replace('/', '_'), p) for p in self._variables + (p.path.replace("/", "_"), p) for p in self._variables ] for state_var in self.base_optimizer.variables: if state_var is self.base_optimizer.iterations: continue - path_parts = state_var.path.split('/') + path_parts = state_var.path.split("/") if len(path_parts) != 2 or path_parts[0] != opt_name: continue - + state_suffix = path_parts[1] found_param = None @@ -105,28 +107,35 @@ def _initialize_sharded_states(self): for norm_param_path, param in normalized_params: if state_suffix.startswith(norm_param_path): found_param = param - slot_suffix = state_suffix[len(norm_param_path):] - slot_name = slot_suffix.strip('_') + slot_suffix = state_suffix[len(norm_param_path) :] + slot_name = slot_suffix.strip("_") break if found_param is not None and slot_name is not None: self._state_variable_to_parameter[state_var.path] = found_param + # MODIFIED: Store the mapping from variable path to slot name + self._variable_to_slot_name[state_var.path] = slot_name sharding_dim = 0 if self.tensor_parallel_config: norm_param_name = found_param.path.replace("/", ".") - for (p, a) in self.tensor_parallel_config.state_rules.items(): + for p, a in self.tensor_parallel_config.state_rules.items(): if re.search(p, norm_param_name) and hasattr(a, "dim"): sharding_dim = a.dim break - - partitioned_state = self._partition_state(state_var, dim=sharding_dim) - self.sharded_states.setdefault(slot_name, {})[found_param.path] = partitioned_state + + partitioned_state = self._partition_state( + state_var, dim=sharding_dim + ) + self.sharded_states.setdefault(slot_name, {})[ + found_param.path + ] = partitioned_state if self.base_optimizer.iterations is not None: self.sharded_states["iterations"] = self._partition_state( self.base_optimizer.iterations, dim=0 ) + def _partition_state( self, state_variable: any, dim: int ) -> List[np.ndarray]: @@ -143,7 +152,7 @@ def _partition_state( A list of NumPy arrays, where each array is a partition of the original state variable for a specific shard. 
""" - state_array = keras.ops.convert_to_numpy(state_variable) + state_array = ops.convert_to_numpy(state_variable) if state_array.ndim > dim and state_array.shape[dim] >= self.world_size: return np.array_split(state_array, self.world_size, axis=dim) else: @@ -192,26 +201,6 @@ def apply_gradients( synchronized_gradients, shard_models ) - def _apply_gradients_with_sharded_states( - self, synchronized_gradients: List[List[tuple]], shard_models: List - ): - """Applies gradients to each shard using its local optimizer state. - - For each shard, this method loads the corresponding partition of the - optimizer state into the base optimizer and then applies the shard's - gradients. - - Args: - synchronized_gradients: The gradients after synchronization. - shard_models: The list of sharded models. - """ - for shard_idx, shard_grads in enumerate(synchronized_gradients): - local_states = self._get_local_optimizer_states(shard_idx) - self._update_optimizer_internal_state( - self.base_optimizer, local_states - ) - self.base_optimizer.apply_gradients(shard_grads) - def _apply_gradients_with_replicated_states( self, synchronized_gradients: List[List[tuple]], shard_models: List ): @@ -240,10 +229,12 @@ def _apply_gradients_with_replicated_states( if not grads_for_var: continue - summed_grad = grads_for_var[0] - for grad in grads_for_var[1:]: - summed_grad += grad - averaged_grad = summed_grad / len(grads_for_var) + if len(grads_for_var) > 1: + stacked_grads = ops.stack(grads_for_var, axis=0) + averaged_grad = ops.mean(stacked_grads, axis=0) + else: + averaged_grad = grads_for_var[0] + averaged_grads_and_vars.append((averaged_grad, variable)) if averaged_grads_and_vars: @@ -271,30 +262,29 @@ def _get_local_optimizer_states(self, shard_idx: int) -> Dict[str, Any]: local_states[state_name] = state_value[shard_idx] return local_states - def _update_optimizer_internal_state(self, local_states: dict): + def _update_optimizer_internal_state(self, optimizer, local_states: dict): """Assigns local sharded state values to the optimizer's variables.""" - if not self.base_optimizer.built: + if not optimizer.built: return - for var in self.base_optimizer.variables: - if var is self.base_optimizer.iterations: + for var in optimizer.variables: + if var is optimizer.iterations: if "iterations" in local_states: - var.assign(local_states["iterations"]) + ops.assign(var, local_states["iterations"]) continue param = self._state_variable_to_parameter.get(var.path, None) - - if param: - slot_name = ( - self.base_optimizer._get_slot_name_from_variable(var) - ) - if ( - slot_name in local_states - and param.path in local_states[slot_name] - ): - local_param_state = local_states[slot_name][param.path] - if var.shape == local_param_state.shape: - var.assign(local_param_state) + slot_name = self._variable_to_slot_name.get(var.path) + + if ( + param + and slot_name + and slot_name in local_states + and param.path in local_states[slot_name] + ): + local_param_state = local_states[slot_name][param.path] + if var.shape == local_param_state.shape: + ops.assign(var, local_param_state) def _synchronize_gradients( self, gradients_and_vars: List[List[tuple]] @@ -364,26 +354,25 @@ def _allreduce_gradients(self, gradients: List[Any]) -> List[Any]: if not gradients: return [] - if ( - self.distributed_backend is not None - and self.distributed_backend.is_initialized - ): - numpy_grad = keras.ops.convert_to_numpy(gradients[0]) - synced_numpy = self.distributed_backend.allreduce( + if self.distributed_backend is not None: + numpy_grad = 
ops.convert_to_numpy(gradients[0]) + synced_numpy = self.distributed_backend.all_reduce( numpy_grad, op="mean" ) - synced_tensor = keras.ops.convert_to_tensor(synced_numpy) + synced_tensor = ops.convert_to_tensor(synced_numpy) return [synced_tensor for _ in range(self.world_size)] stacked_grads = keras.ops.stack( - [keras.ops.convert_to_tensor(g) for g in gradients], axis=0 + [ops.convert_to_tensor(g) for g in gradients], axis=0 ) - mean_grad = keras.ops.mean(stacked_grads, axis=0) + mean_grad = ops.mean(stacked_grads, axis=0) return [mean_grad for _ in range(len(gradients))] def get_weights(self) -> List[np.ndarray]: """Returns the weights of the base optimizer.""" - return self.base_optimizer.get_weights() + return [ + ops.convert_to_numpy(var) for var in self.base_optimizer.variables + ] def set_weights(self, weights: List[np.ndarray]): """Sets the weights of the base optimizer.""" @@ -463,30 +452,22 @@ def __init__( tensor_parallel_config=None, ): if isinstance(base_optimizer, str): - resolved_base_optimizer = optimizers.get(base_optimizer) + base_optimizer_instance = optimizers.get(base_optimizer) else: - resolved_base_optimizer = base_optimizer + base_optimizer_instance = base_optimizer - if isinstance( - resolved_base_optimizer.learning_rate, - keras.optimizers.schedules.LearningRateSchedule, - ): - lr_value = float( - ops.convert_to_numpy( - resolved_base_optimizer.learning_rate.initial_learning_rate - ) - ) + learning_rate = base_optimizer_instance.learning_rate + if callable(learning_rate): + lr_value = float(ops.convert_to_numpy(learning_rate(0))) else: - lr_value = float( - ops.convert_to_numpy(resolved_base_optimizer.learning_rate) - ) + lr_value = float(ops.convert_to_numpy(learning_rate)) super().__init__( learning_rate=lr_value, - name=f"TensorParallel_{resolved_base_optimizer.name}", + name=f"TensorParallel_{base_optimizer_instance.name}", ) - self.base_optimizer = resolved_base_optimizer + self.base_optimizer = base_optimizer_instance self.world_size = world_size self.distributed_backend = distributed_backend self.coordinated_optimizer = CoordinatedOptimizer( @@ -568,15 +549,10 @@ def build(self, variables: List): return self.base_optimizer.build(variables) - print(f"Variables after build: {[v.path for v in self.base_optimizer.variables]}") - if variables: zero_grads = [ops.zeros_like(v) for v in variables] self.base_optimizer.apply_gradients(zip(zero_grads, variables)) - if self.base_optimizer.iterations is not None: - self.base_optimizer.iterations.assign(0) - self.coordinated_optimizer.enable_optimizer_state_sharding(variables) super().build(variables) @@ -597,3 +573,13 @@ def variables(self) -> List: def learning_rate(self) -> Any: """Provides access to the learning rate of the base optimizer.""" return self.base_optimizer.learning_rate + + @property + def iterations(self): + """ + Returns the training iteration count, compensating for the initial + dummy step in the build method. 
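+
+        For example, immediately after `build()` (which applies one set of
+        zero gradients to force slot-variable creation) the base optimizer
+        reports `iterations == 1` while this property reports `0`.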
+ """ + if self.base_optimizer.iterations is None: + return None + return self.base_optimizer.iterations - 1 \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py index ca69361fe383..46579d4147aa 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -1,12 +1,22 @@ import numpy as np -from coordinated_optimizer import CoordinatedOptimizer -from coordinated_optimizer import TensorParallelOptimizer +import pytest import keras +from keras import ops from keras.src import optimizers from keras.src import testing - - +from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + CoordinatedOptimizer, +) +from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + TensorParallelOptimizer, +) + + +@pytest.mark.skipif( + keras.backend.backend() == "openvino", + reason="CoordinatedOptimizer is not yet supported on the OpenVINO backend.", +) class CoordinatedOptimizerTest(testing.TestCase): def _get_simple_model(self): """Creates a simple, uncompiled Keras model.""" @@ -23,7 +33,7 @@ def _get_mock_gradients_and_vars(self, model, world_size): for i in range(world_size): multiplier = float(i + 1) gradients = [ - keras.ops.convert_to_tensor( + ops.convert_to_tensor( np.ones_like(v.numpy()) * multiplier, dtype="float32" ) for v in variables @@ -43,6 +53,7 @@ def test_initialization(self): def test_apply_gradients_with_replicated_states(self): """Tests that replicated gradients are averaged and applied once.""" + class AdamWithCallCounter(optimizers.Adam): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -52,6 +63,8 @@ def __init__(self, *args, **kwargs): def apply_gradients(self, grads_and_vars, *args, **kwargs): self.apply_gradients_call_count += 1 self.received_grads = [g for g, v in grads_and_vars] + # Call the superclass method to ensure variables are updated + super().apply_gradients(grads_and_vars, *args, **kwargs) world_size = 4 model = self._get_simple_model() @@ -68,6 +81,7 @@ def apply_gradients(self, grads_and_vars, *args, **kwargs): coord.apply_gradients(mock_grads, []) self.assertEqual(optimizer.apply_gradients_call_count, 1) + # The average of multipliers 1, 2, 3, 4 is (1+2+3+4)/4 = 10/4 = 2.5 self.assertAllClose( optimizer.received_grads[0], np.ones_like(optimizer.received_grads[0]) * 2.5, @@ -90,13 +104,18 @@ def test_apply_gradients_delegation(self): mock_grads = self._get_mock_gradients_and_vars(model, world_size) coord_apply_tracker = {"called": False} - optimizer.coordinated_optimizer.apply_gradients = ( - lambda *a, **kw: coord_apply_tracker.update({"called": True}) - ) + + def coord_apply_mock(*args, **kwargs): + coord_apply_tracker["called"] = True + + optimizer.coordinated_optimizer.apply_gradients = coord_apply_mock + base_apply_tracker = {"called": False} - optimizer.base_optimizer.apply_gradients = ( - lambda *a, **kw: base_apply_tracker.update({"called": True}) - ) + + def base_apply_mock(*args, **kwargs): + base_apply_tracker["called"] = True + + optimizer.base_optimizer.apply_gradients = base_apply_mock optimizer.apply_gradients(mock_grads, shard_models=[]) self.assertTrue(coord_apply_tracker["called"]) @@ -108,7 +127,6 @@ def test_apply_gradients_delegation(self): self.assertTrue(base_apply_tracker["called"]) self.assertFalse(coord_apply_tracker["called"]) - def 
test_build_and_state_sharding(self): """Tests that the build method correctly initializes sharded states.""" optimizer = TensorParallelOptimizer( @@ -123,14 +141,15 @@ def test_build_and_state_sharding(self): self.assertTrue(optimizer.built) sharded_states = optimizer.coordinated_optimizer.sharded_states - self.assertIn("momentum", sharded_states) self.assertIn("velocity", sharded_states) self.assertIn("iterations", sharded_states) dense_1_kernel_path = model.get_layer("dense_1").kernel.path self.assertIn(dense_1_kernel_path, sharded_states["momentum"]) - self.assertEqual(len(sharded_states["momentum"][dense_1_kernel_path]), 4) + self.assertEqual( + len(sharded_states["momentum"][dense_1_kernel_path]), 4 + ) def test_serialization(self): world_size = 4 diff --git a/keras/src/distribution/tensor_parallel/sharding_keras.py b/keras/src/distribution/tensor_parallel/sharding_keras.py new file mode 100644 index 000000000000..d6d08a524f5b --- /dev/null +++ b/keras/src/distribution/tensor_parallel/sharding_keras.py @@ -0,0 +1,85 @@ +from typing import Any +from typing import Collection +from typing import Dict +from typing import List +from typing import Sequence + +from keras.src.distribution.tensor_parallel.config import ConfigKeras + + +class ShardedKeras: + """ + Manages sharded parameters for Keras models. + """ + + def __init__( + self, + model_shards, + replicated_param_names: Collection[str], + tensor_parallel_config: ConfigKeras, + devices: Sequence[str], + output_device_index: int, + ): + """ + Initialize the sharding manager. + + Args: + model_shards: List of model shards + replicated_param_names: Names of parameters that are replicated + tensor_parallel_config: Tensor parallel configuration + devices: List of device IDs + output_device_index: Index of the output device + """ + self.model_shards = model_shards + self.replicated_param_names = set(replicated_param_names) + self.tensor_parallel_config = tensor_parallel_config + self.devices = devices + self.output_device_index = output_device_index + + def get_shard_parameters(self, shard_index: int) -> Dict[str, Any]: + """ + Get parameters for a specific shard. + + Args: + shard_index: Index of the shard + + Returns: + Dictionary of parameter names to values + """ + if shard_index >= len(self.model_shards): + raise ValueError(f"Shard index {shard_index} out of range") + + shard = self.model_shards[shard_index] + params = {} + + for layer in shard.layers: + name = layer.name + if hasattr(layer, "weights") and layer.weights: + for i, weight in enumerate(layer.weights): + param_name = f"{name}.weight_{i}" + params[param_name] = weight + + return params + + def get_all_parameters(self) -> List[Dict[str, Any]]: + """ + Get parameters from all shards. + + Returns: + List of parameter dictionaries for each shard + """ + return [ + self.get_shard_parameters(i) for i in range(len(self.model_shards)) + ] + + def apply_sharding(self): + """ + Apply sharding to the model parameters. + """ + pass + + def unshard_parameters(self): + """ + Unshard parameters back to their original form. 
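Illustrative sketch, not part of the patch: the flat parameter naming that `get_shard_parameters` above produces for a single shard. The tiny Sequential model here is a hypothetical stand-in for one model shard.

import keras

shard = keras.Sequential([keras.layers.Dense(4, name="dense")])
shard.build((None, 8))

params = {}
for layer in shard.layers:
    if layer.weights:
        # Same naming scheme as get_shard_parameters: "<layer name>.weight_<i>".
        for i, weight in enumerate(layer.weights):
            params[f"{layer.name}.weight_{i}"] = weight

print(sorted(params))  # ['dense.weight_0', 'dense.weight_1']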
+ """ + pass \ No newline at end of file From b7862d9c3a592f636c435d2e7b2160349713d132 Mon Sep 17 00:00:00 2001 From: Suhana Date: Thu, 2 Oct 2025 15:46:52 +0530 Subject: [PATCH 04/12] Reformatting files --- keras/src/distribution/tensor_parallel/autoconfig.py | 2 +- keras/src/distribution/tensor_parallel/autoconfig_test.py | 2 +- keras/src/distribution/tensor_parallel/coordinated_optimizer.py | 2 +- .../distribution/tensor_parallel/coordinated_optimizer_test.py | 2 +- keras/src/distribution/tensor_parallel/sharding_keras.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py index cf5966eb4670..6e90b10a0bc4 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig.py +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -217,4 +217,4 @@ def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: prefix="", ) - return ConfigKeras(state_rules=state_rules, output_rules=output_rules) \ No newline at end of file + return ConfigKeras(state_rules=state_rules, output_rules=output_rules) diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py index ab9a1b4149c1..ee7519c607ef 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig_test.py +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -148,4 +148,4 @@ def test_nested_model_sharding(self): } self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) \ No newline at end of file + self._assert_rules_equal(config.output_rules, expected_output_rules) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index 726747676c0b..18661f860c6e 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -582,4 +582,4 @@ def iterations(self): """ if self.base_optimizer.iterations is None: return None - return self.base_optimizer.iterations - 1 \ No newline at end of file + return self.base_optimizer.iterations - 1 diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py index 46579d4147aa..c4249d147d73 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -164,4 +164,4 @@ def test_serialization(self): self.assertEqual(recreated.world_size, world_size) self.assertIsInstance(recreated.base_optimizer, optimizers.Adam) self.assertIsNone(recreated.coordinated_optimizer.distributed_backend) - self.assertAllClose(recreated.base_optimizer.learning_rate, 0.1) \ No newline at end of file + self.assertAllClose(recreated.base_optimizer.learning_rate, 0.1) diff --git a/keras/src/distribution/tensor_parallel/sharding_keras.py b/keras/src/distribution/tensor_parallel/sharding_keras.py index d6d08a524f5b..ace810adb024 100644 --- a/keras/src/distribution/tensor_parallel/sharding_keras.py +++ b/keras/src/distribution/tensor_parallel/sharding_keras.py @@ -82,4 +82,4 @@ def unshard_parameters(self): """ Unshard parameters back to their original form. 
""" - pass \ No newline at end of file + pass From 3383dec7736e8d34bde2263d324d256c9344a070 Mon Sep 17 00:00:00 2001 From: Suhana Date: Fri, 3 Oct 2025 11:55:32 +0530 Subject: [PATCH 05/12] Reformatting according to changes in distributed_backend --- .../tensor_parallel/autoconfig.py | 19 ++- .../tensor_parallel/autoconfig_test.py | 33 +++-- .../tensor_parallel/coordinated_optimizer.py | 140 ++++++++++-------- .../coordinated_optimizer_test.py | 63 ++++---- .../tensor_parallel/sharding_keras.py | 9 +- 5 files changed, 146 insertions(+), 118 deletions(-) diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py index 6e90b10a0bc4..9fa6db430c35 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig.py +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -206,15 +206,14 @@ def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: output_rules = {} processed_layers = set() - for layer in module.layers: - _traverse_and_shard_layer( - current_layer=layer, - module=module, - world_size=world_size, - state_rules=state_rules, - output_rules=output_rules, - processed_layers=processed_layers, - prefix="", - ) + _traverse_and_shard_layer( + current_layer=module, + module=module, + world_size=world_size, + state_rules=state_rules, + output_rules=output_rules, + processed_layers=processed_layers, + prefix="", + ) return ConfigKeras(state_rules=state_rules, output_rules=output_rules) diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py index ee7519c607ef..96467da847e0 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig_test.py +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -6,7 +6,7 @@ from keras import Model from keras import layers from keras.src import testing -from keras.src.backend.distributed import backend_resolver +from keras.src.distribution import distributed_backend from keras.src.distribution.tensor_parallel.autoconfig import ( analyze_dense_layer_directly, ) @@ -20,8 +20,7 @@ class TestAutoConfigKeras(testing.TestCase): def setUp(self): """Set up the test case and common variables.""" super().setUp() - backend = backend_resolver.get_distributed_backend() - device_info = backend.get_device_info() + device_info = distributed_backend.get_device_info() self.world_size = device_info["device_count"] self.device_ids = [f"device:{i}" for i in range(self.world_size)] @@ -78,19 +77,19 @@ def test_simple_mlp_sharding(self): config = get_default_config_keras(model, self.device_ids) expected_state_rules = { - r"^up_projection_layer.kernel$": SplitKeras( + r"^simple_mlp.up_projection_layer.kernel$": SplitKeras( self.world_size, 1, "column" ), - r"^up_projection_layer.bias$": SplitKeras( + r"^simple_mlp.up_projection_layer.bias$": SplitKeras( self.world_size, 0, "column" ), - r"^down_projection_layer.kernel$": SplitKeras( + r"^simple_mlp.down_projection_layer.kernel$": SplitKeras( self.world_size, 0, "row" ), } expected_output_rules = { - r"^up_projection_layer$": {0: "no_comm"}, - r"^down_projection_layer$": {0: "allreduce"}, + r"^simple_mlp.up_projection_layer$": {0: "no_comm"}, + r"^simple_mlp.down_projection_layer$": {0: "allreduce"}, } self._assert_rules_equal(config.state_rules, expected_state_rules) @@ -107,11 +106,13 @@ def test_embedding_sharding(self): config = get_default_config_keras(model, self.device_ids) expected_state_rules = { - r"^token_embedding\.embeddings$": SplitKeras( + 
r"^embed_model.token_embedding\.embeddings$": SplitKeras( self.world_size, 1, "column" ) } - expected_output_rules = {r"^token_embedding$": {0: "no_comm"}} + expected_output_rules = { + r"^embed_model.token_embedding$": {0: "no_comm"} + } self._assert_rules_equal(config.state_rules, expected_state_rules) self._assert_rules_equal(config.output_rules, expected_output_rules) @@ -134,17 +135,19 @@ def test_nested_model_sharding(self): config = get_default_config_keras(outer_model, self.device_ids) expected_state_rules = { - r"^inner_block.inner_dense.kernel$": SplitKeras( + r"^outer_model.inner_block.inner_dense.kernel$": SplitKeras( self.world_size, 1, "column" ), - r"^inner_block.inner_dense.bias$": SplitKeras( + r"^outer_model.inner_block.inner_dense.bias$": SplitKeras( self.world_size, 0, "column" ), - r"^outer_dense.kernel$": SplitKeras(self.world_size, 0, "row"), + r"^outer_model.outer_dense.kernel$": SplitKeras( + self.world_size, 0, "row" + ), } expected_output_rules = { - r"^inner_block.inner_dense$": {0: "no_comm"}, - r"^outer_dense$": {0: "allreduce"}, + r"^outer_model.inner_block.inner_dense$": {0: "no_comm"}, + r"^outer_model.outer_dense$": {0: "allreduce"}, } self._assert_rules_equal(config.state_rules, expected_state_rules) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index 18661f860c6e..260d719d3985 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -1,14 +1,12 @@ import re from typing import Any -from typing import Dict -from typing import List import numpy as np import keras from keras.src import ops from keras.src import optimizers -from keras.src.backend.distributed import backend_resolver +from keras.src.distribution import distributed_backend class CoordinatedOptimizer: @@ -47,50 +45,31 @@ def __init__( ): self.base_optimizer = base_optimizer self.world_size = world_size - self.rank = rank self.shard_optimizer_states = shard_optimizer_states self.tensor_parallel_config = tensor_parallel_config self.sharded_states = {} self._state_variable_to_parameter = {} - self.distributed_backend = ( - backend_resolver.get_distributed_backend(distributed_backend) - if distributed_backend is not None - else None - ) self._variables = None self._variable_to_slot_name = {} - def _get_optimizer_slot_names(self) -> set: - """ - Deduces the slot names ('m', 'v', etc.) by inspecting the variables - created by the base optimizer. This is the most robust method. - """ - slot_names = set() - for var in self.base_optimizer.variables: - if "iteration" in var.path.lower(): - continue - path_parts = var.path.split("/") - if len(path_parts) > 1: - slot_names.add(path_parts[1]) - return slot_names - def _initialize_sharded_states(self): """ Partitions the optimizer's state variables across shards by inspecting - the variables created by the base optimizer. This version correctly - parses variable paths like 'optimizer/param_name_slot_name'. + the variables created by the base optimizer. 
""" if not self.shard_optimizer_states or not self.base_optimizer.built: return self.sharded_states = {} self._state_variable_to_parameter = {} - self._variable_to_slot_name = {} # Reset the map + self._variable_to_slot_name = {} opt_name = self.base_optimizer.name - normalized_params = [ - (p.path.replace("/", "_"), p) for p in self._variables - ] + normalized_params = sorted( + [(p.path.replace("/", "_"), p) for p in self._variables], + key=lambda x: len(x[0]), + reverse=True, + ) for state_var in self.base_optimizer.variables: if state_var is self.base_optimizer.iterations: @@ -113,7 +92,6 @@ def _initialize_sharded_states(self): if found_param is not None and slot_name is not None: self._state_variable_to_parameter[state_var.path] = found_param - # MODIFIED: Store the mapping from variable path to slot name self._variable_to_slot_name[state_var.path] = slot_name sharding_dim = 0 @@ -138,7 +116,7 @@ def _initialize_sharded_states(self): def _partition_state( self, state_variable: any, dim: int - ) -> List[np.ndarray]: + ) -> list[np.ndarray]: """Splits a single state variable numpy array into chunks. If the variable cannot be split along the given dimension, it is @@ -158,7 +136,7 @@ def _partition_state( else: return [np.copy(state_array) for _ in range(self.world_size)] - def get_config(self) -> Dict[str, Any]: + def get_config(self) -> dict[str, Any]: return { "base_optimizer": self.base_optimizer.get_config(), "world_size": self.world_size, @@ -166,7 +144,7 @@ def get_config(self) -> Dict[str, Any]: } def apply_gradients( - self, gradients_and_vars: List[List[tuple]], shard_models: List + self, gradients_and_vars: list[list[tuple]], shard_models: list ): """Coordinates gradient synchronization and application. @@ -202,7 +180,7 @@ def apply_gradients( ) def _apply_gradients_with_replicated_states( - self, synchronized_gradients: List[List[tuple]], shard_models: List + self, synchronized_gradients: list[list[tuple]], shard_models: list ): """Averages gradients across all shards and applies them once. @@ -240,16 +218,23 @@ def _apply_gradients_with_replicated_states( if averaged_grads_and_vars: self.base_optimizer.apply_gradients(averaged_grads_and_vars) - def _get_local_optimizer_states(self, shard_idx: int) -> Dict[str, Any]: - """Constructs the state dictionary for a single shard. + def _apply_gradients_with_sharded_states( + self, synchronized_gradients: list[list[tuple]], shard_models: list + ): + """Applies gradients to each shard using its local optimizer state.""" + for shard_idx in range(self.world_size): + local_states = self._get_local_optimizer_states(shard_idx) + shard_optimizer = shard_models[shard_idx].optimizer + + self._update_optimizer_internal_state(shard_optimizer, local_states) - Args: - shard_idx: The index of the shard for which to retrieve the state. + shard_grads_and_vars = synchronized_gradients[shard_idx] + shard_optimizer.apply_gradients(shard_grads_and_vars) - Returns: - A dictionary containing the optimizer state variables specific to - the given shard index. 
- """ + self._update_global_sharded_states(shard_optimizer, shard_idx) + + def _get_local_optimizer_states(self, shard_idx: int) -> dict[str, Any]: + """Constructs the state dictionary for a single shard.""" local_states = {} for state_name, state_value in self.sharded_states.items(): if isinstance(state_value, dict): @@ -286,9 +271,34 @@ def _update_optimizer_internal_state(self, optimizer, local_states: dict): if var.shape == local_param_state.shape: ops.assign(var, local_param_state) + def _update_global_sharded_states(self, optimizer, shard_idx: int): + """Updates the main sharded_states dictionary after a gradient step.""" + if not optimizer.built: + return + + for var in optimizer.variables: + if var is optimizer.iterations: + self.sharded_states["iterations"][shard_idx] = ( + ops.convert_to_numpy(var) + ) + continue + + param = self._state_variable_to_parameter.get(var.path, None) + slot_name = self._variable_to_slot_name.get(var.path) + + if ( + param + and slot_name + and slot_name in self.sharded_states + and param.path in self.sharded_states[slot_name] + ): + self.sharded_states[slot_name][param.path][shard_idx] = ( + ops.convert_to_numpy(var) + ) + def _synchronize_gradients( - self, gradients_and_vars: List[List[tuple]] - ) -> List[List[tuple]]: + self, gradients_and_vars: list[list[tuple]] + ) -> list[list[tuple]]: """Synchronizes gradients across shards based on tensor parallel rules. Specifically, it performs an all-reduce operation on gradients of @@ -339,7 +349,7 @@ def _synchronize_gradients( ) return gradients_and_vars - def _allreduce_gradients(self, gradients: List[Any]) -> List[Any]: + def _allreduce_gradients(self, gradients: list[Any]) -> list[Any]: """Performs a mean all-reduce operation on a list of gradients. If a distributed backend is available, it uses it. Otherwise, it @@ -354,11 +364,12 @@ def _allreduce_gradients(self, gradients: List[Any]) -> List[Any]: if not gradients: return [] - if self.distributed_backend is not None: + if distributed_backend.is_multi_device_capable(): + all_reduce_fn = distributed_backend.get_communication_ops()[ + "all_reduce" + ] numpy_grad = ops.convert_to_numpy(gradients[0]) - synced_numpy = self.distributed_backend.all_reduce( - numpy_grad, op="mean" - ) + synced_numpy = all_reduce_fn(numpy_grad, op="mean") synced_tensor = ops.convert_to_tensor(synced_numpy) return [synced_tensor for _ in range(self.world_size)] @@ -368,17 +379,17 @@ def _allreduce_gradients(self, gradients: List[Any]) -> List[Any]: mean_grad = ops.mean(stacked_grads, axis=0) return [mean_grad for _ in range(len(gradients))] - def get_weights(self) -> List[np.ndarray]: + def get_weights(self) -> list[np.ndarray]: """Returns the weights of the base optimizer.""" return [ ops.convert_to_numpy(var) for var in self.base_optimizer.variables ] - def set_weights(self, weights: List[np.ndarray]): + def set_weights(self, weights: list[np.ndarray]): """Sets the weights of the base optimizer.""" self.base_optimizer.set_weights(weights) - def enable_optimizer_state_sharding(self, variables: List): + def enable_optimizer_state_sharding(self, variables: list): """Enables and initializes optimizer state sharding. This method is called from `build()`, which is guarded from running @@ -426,7 +437,7 @@ class TensorParallelOptimizer(optimizers.Optimizer): import keras # Assume model variables and gradients from 4 shards exist. 
- # The structure is: List[List[Tuple[gradient, variable]]] + # The structure is: list[list[tuple[gradient, variable]]] trainable_vars = [keras.Variable(1.0), keras.Variable(2.0)] sharded_grads_and_vars = [ [(keras.ops.ones_like(v), v) for v in trainable_vars] @@ -477,7 +488,7 @@ def __init__( tensor_parallel_config=tensor_parallel_config, ) - def apply_gradients(self, grads_and_vars: List, **kwargs): + def apply_gradients(self, grads_and_vars: list, **kwargs): """Applies gradients to the model variables. If `grads_and_vars` is a list of lists, it's assumed to be from @@ -490,11 +501,12 @@ def apply_gradients(self, grads_and_vars: List, **kwargs): **kwargs: Additional arguments. `shard_models` can be passed to provide the list of model shards. """ - if ( + is_sharded_grads = ( isinstance(grads_and_vars, list) and grads_and_vars and isinstance(grads_and_vars[0], list) - ): + ) + if is_sharded_grads: shard_models = kwargs.get("shard_models", []) self.coordinated_optimizer.apply_gradients( grads_and_vars, shard_models @@ -502,7 +514,7 @@ def apply_gradients(self, grads_and_vars: List, **kwargs): else: self.base_optimizer.apply_gradients(grads_and_vars) - def get_config(self) -> Dict[str, Any]: + def get_config(self) -> dict[str, Any]: from keras.src import saving config = super().get_config() @@ -521,7 +533,7 @@ def get_config(self) -> Dict[str, Any]: return config @classmethod - def from_config(cls, config: Dict[str, Any]) -> "TensorParallelOptimizer": + def from_config(cls, config: dict[str, Any]) -> "TensorParallelOptimizer": from keras.src import saving base_optimizer_config = config.pop("base_optimizer") @@ -535,7 +547,7 @@ def from_config(cls, config: Dict[str, Any]) -> "TensorParallelOptimizer": return cls(base_optimizer=base_optimizer, **init_kwargs) - def build(self, variables: List): + def build(self, variables: list): """Builds the optimizer and initializes sharded states. This method is called the first time the optimizer is used. 
It builds @@ -556,16 +568,16 @@ def build(self, variables: List): self.coordinated_optimizer.enable_optimizer_state_sharding(variables) super().build(variables) - def get_weights(self) -> List[np.ndarray]: + def get_weights(self) -> list[np.ndarray]: """Returns the weights of the base optimizer.""" return self.coordinated_optimizer.get_weights() - def set_weights(self, weights: List[np.ndarray]): + def set_weights(self, weights: list[np.ndarray]): """Sets the weights of the base optimizer.""" self.coordinated_optimizer.set_weights(weights) @property - def variables(self) -> List: + def variables(self) -> list: """Returns the list of variables from the base optimizer.""" return self.base_optimizer.variables @@ -574,6 +586,10 @@ def learning_rate(self) -> Any: """Provides access to the learning rate of the base optimizer.""" return self.base_optimizer.learning_rate + @learning_rate.setter + def learning_rate(self, value): + self.base_optimizer.learning_rate = value + @property def iterations(self): """ diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py index c4249d147d73..39cce46de72c 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -5,17 +5,19 @@ from keras import ops from keras.src import optimizers from keras.src import testing -from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( - CoordinatedOptimizer, -) -from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( - TensorParallelOptimizer, -) + +if keras.backend.backend() == "jax": + from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + CoordinatedOptimizer, + ) + from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + TensorParallelOptimizer, + ) @pytest.mark.skipif( - keras.backend.backend() == "openvino", - reason="CoordinatedOptimizer is not yet supported on the OpenVINO backend.", + keras.backend.backend() != "jax", + reason="This test is JAX-specific.", ) class CoordinatedOptimizerTest(testing.TestCase): def _get_simple_model(self): @@ -44,9 +46,7 @@ def _get_mock_gradients_and_vars(self, model, world_size): def test_initialization(self): """Tests that the optimizer initializes with the correct defaults.""" base_optimizer = optimizers.Adam() - coord = CoordinatedOptimizer( - base_optimizer, world_size=4, distributed_backend=None - ) + coord = CoordinatedOptimizer(base_optimizer, world_size=4) self.assertEqual(coord.base_optimizer, base_optimizer) self.assertTrue(coord.shard_optimizer_states) self.assertEqual(coord.sharded_states, {}) @@ -63,7 +63,6 @@ def __init__(self, *args, **kwargs): def apply_gradients(self, grads_and_vars, *args, **kwargs): self.apply_gradients_call_count += 1 self.received_grads = [g for g, v in grads_and_vars] - # Call the superclass method to ensure variables are updated super().apply_gradients(grads_and_vars, *args, **kwargs) world_size = 4 @@ -76,30 +75,24 @@ def apply_gradients(self, grads_and_vars, *args, **kwargs): optimizer, world_size, shard_optimizer_states=False, - distributed_backend=None, ) coord.apply_gradients(mock_grads, []) self.assertEqual(optimizer.apply_gradients_call_count, 1) - # The average of multipliers 1, 2, 3, 4 is (1+2+3+4)/4 = 10/4 = 2.5 self.assertAllClose( optimizer.received_grads[0], np.ones_like(optimizer.received_grads[0]) * 2.5, ) def test_init_from_string(self): - optimizer = 
TensorParallelOptimizer( - "adam", world_size=4, distributed_backend=None - ) + optimizer = TensorParallelOptimizer("adam", world_size=4) self.assertIsInstance(optimizer.base_optimizer, optimizers.Adam) def test_apply_gradients_delegation(self): """Tests that apply_gradients correctly delegates.""" world_size = 4 base_opt = optimizers.Adam() - optimizer = TensorParallelOptimizer( - base_opt, world_size, distributed_backend=None - ) + optimizer = TensorParallelOptimizer(base_opt, world_size) model = self._get_simple_model() mock_grads = self._get_mock_gradients_and_vars(model, world_size) @@ -129,11 +122,8 @@ def base_apply_mock(*args, **kwargs): def test_build_and_state_sharding(self): """Tests that the build method correctly initializes sharded states.""" - optimizer = TensorParallelOptimizer( - optimizers.Adam(), world_size=4, distributed_backend=None - ) + optimizer = TensorParallelOptimizer(optimizers.Adam(), world_size=4) model = self._get_simple_model() - model.build(input_shape=(None, 10)) self.assertEqual(optimizer.coordinated_optimizer.sharded_states, {}) @@ -163,5 +153,28 @@ def test_serialization(self): self.assertEqual(recreated.world_size, world_size) self.assertIsInstance(recreated.base_optimizer, optimizers.Adam) - self.assertIsNone(recreated.coordinated_optimizer.distributed_backend) + self.assertIsNone(recreated.distributed_backend) self.assertAllClose(recreated.base_optimizer.learning_rate, 0.1) + + def test_sharding_with_prefixed_variable_names(self): + """Tests that state is correctly mapped with prefixed variable names.""" + inputs = keras.Input(shape=(10,)) + x = keras.layers.Dense(4, name="dense")(inputs) + outputs = keras.layers.Dense(2, name="dense_output")(x) + model = keras.Model(inputs, outputs) + model.build(input_shape=(None, 10)) + + optimizer = TensorParallelOptimizer(optimizers.Adam(), world_size=2) + optimizer.build(model.trainable_variables) + + state_to_param = ( + optimizer.coordinated_optimizer._state_variable_to_parameter + ) + self.assertGreater(len(state_to_param), 0) + + dense_output_kernel = model.get_layer("dense_output").kernel + optimizer_name = optimizer.base_optimizer.name + kernel_path = dense_output_kernel.path.replace("/", "_") + momentum_path = f"{optimizer_name}/{kernel_path}_momentum" + + self.assertIs(state_to_param[momentum_path], dense_output_kernel) diff --git a/keras/src/distribution/tensor_parallel/sharding_keras.py b/keras/src/distribution/tensor_parallel/sharding_keras.py index ace810adb024..012234cb77f4 100644 --- a/keras/src/distribution/tensor_parallel/sharding_keras.py +++ b/keras/src/distribution/tensor_parallel/sharding_keras.py @@ -52,12 +52,9 @@ def get_shard_parameters(self, shard_index: int) -> Dict[str, Any]: shard = self.model_shards[shard_index] params = {} - for layer in shard.layers: - name = layer.name - if hasattr(layer, "weights") and layer.weights: - for i, weight in enumerate(layer.weights): - param_name = f"{name}.weight_{i}" - params[param_name] = weight + for weight in shard.weights: + param_name = weight.path.replace("/", ".") + params[param_name] = weight return params From 5824c66627617e25b087144060da34658562cb36 Mon Sep 17 00:00:00 2001 From: Suhana Date: Fri, 3 Oct 2025 11:57:46 +0530 Subject: [PATCH 06/12] Reformatting according to changes in distributed_backend --- keras/src/distribution/tensor_parallel/coordinated_optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py 
b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index 260d719d3985..94b6730b2f20 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -39,7 +39,6 @@ def __init__( base_optimizer: optimizers.Optimizer, world_size: int, distributed_backend: str = "auto", - rank: int = 0, shard_optimizer_states: bool = True, tensor_parallel_config=None, ): From 9cf5c7fe31ff7d7e8affeb4f8b061686b8c745ba Mon Sep 17 00:00:00 2001 From: Suhana Date: Mon, 6 Oct 2025 19:52:51 +0530 Subject: [PATCH 07/12] Refactoring the code --- .../tensor_parallel/autoconfig.py | 205 ++++++++++++------ .../tensor_parallel/autoconfig_test.py | 123 +++++++++-- .../tensor_parallel/coordinated_optimizer.py | 14 ++ 3 files changed, 260 insertions(+), 82 deletions(-) diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py index 9fa6db430c35..0100aeaf7a5e 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig.py +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -1,12 +1,13 @@ +from typing import Any +from typing import Dict from typing import Sequence +from typing import Set from keras.src.distribution.tensor_parallel.config import ConfigKeras from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras def analyze_dense_layer_directly(layer, module, prefix: str) -> str: - from keras.src import layers - """Analyzes a Dense layer to classify it for tensor parallelism sharding. This function inspects the layer's weight shapes to determine if it's an @@ -23,20 +24,24 @@ def analyze_dense_layer_directly(layer, module, prefix: str) -> str: A string indicating the layer's classification: 'up_projection', 'down_projection', or 'generic_dense'. """ + from keras.src import layers + if not isinstance(layer, layers.Dense): return "generic_dense" input_dim = None output_dim = None - if hasattr(layer, "kernel"): + if hasattr(layer, "kernel") and layer.kernel is not None: kernel_shape = layer.kernel.shape if len(kernel_shape) == 2: - input_dim = kernel_shape[0] - output_dim = kernel_shape[1] - else: + input_dim, output_dim = kernel_shape + + if input_dim is None or output_dim is None: if hasattr(layer, "units"): output_dim = layer.units + else: + return "generic_dense" if ( hasattr(layer, "input_shape") @@ -44,6 +49,8 @@ def analyze_dense_layer_directly(layer, module, prefix: str) -> str: and len(layer.input_shape) > 1 ): input_dim = layer.input_shape[-1] + else: + return "generic_dense" if not input_dim or not output_dim: return "generic_dense" @@ -60,34 +67,33 @@ def analyze_dense_layer_directly(layer, module, prefix: str) -> str: return "generic_dense" -def _traverse_and_shard_layer( +def _find_and_shard_layers( current_layer, + prefix: str, module, world_size: int, - state_rules: dict, - output_rules: dict, - processed_layers: set, - prefix: str = "", + state_rules: Dict[str, Any], + output_rules: Dict[str, Any], + processed_layers: Set[int], ): - from keras.src import layers + """Recursively traverses a Keras model to generate sharding rules. - """Traverses a layer and its sub-layers to apply sharding rules. - - This function navigates through the model's layer hierarchy. For each - layer, it identifies its type and applies appropriate sharding logic, - populating the `state_rules` and `output_rules` dictionaries. + This is an internal helper function that navigates through all layers of a + model, including nested ones. 
For each supported layer, it determines the + appropriate sharding strategy and populates the `state_rules` and + `output_rules` dictionaries. These dictionaries are modified in place. Args: - current_layer: The current keras.Layer object to be processed. - module: The top-level Keras Model, used for context analysis. - world_size: The total number of devices for sharding. - state_rules: The dictionary of state sharding rules to populate. - output_rules: The dictionary of output sharding rules to populate. - processed_layers: A set of layer IDs that have already been processed - to avoid redundant computation and infinite loops. - prefix: The hierarchical name prefix from parent layers, used to - construct the full unique name for the current layer. + current_layer: The Keras layer to be processed in the current step. + prefix: The hierarchical name prefix for the `current_layer`. + module: The top-level Keras model being analyzed. + world_size: The total number of devices to shard the model across. + state_rules: A dictionary with sharding rules for weights. + output_rules: A dictionary with communication rules for outputs. + processed_layers: A set of layer IDs to prevent infinite loops. """ + from keras.src import layers + if id(current_layer) in processed_layers: return processed_layers.add(id(current_layer)) @@ -100,10 +106,24 @@ def _traverse_and_shard_layer( current_layer, module, full_name ) - if mlp_type == "down_projection": + if mlp_type == "up_projection": + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 1, "column" + ) + if current_layer.use_bias: + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, 0, "column" + ) + output_rules[f"^{full_name}$"] = {0: "gather"} + + elif mlp_type == "down_projection": state_rules[f"^{full_name}.kernel$"] = SplitKeras( world_size, 0, "row" ) + if current_layer.use_bias: + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, -1, "replicated" + ) output_rules[f"^{full_name}$"] = {0: "allreduce"} else: @@ -114,27 +134,21 @@ def _traverse_and_shard_layer( state_rules[f"^{full_name}.bias$"] = SplitKeras( world_size, 0, "column" ) - output_rules[f"^{full_name}$"] = {0: "no_comm"} + output_rules[f"^{full_name}$"] = {0: "gather -1"} return elif isinstance(current_layer, layers.EinsumDense): - is_row_parallel = False - if "->" in current_layer.equation: - equation_parts = current_layer.equation.split("->") - if len(equation_parts) == 2: - input_spec = equation_parts[0].split(",")[0].strip() - output_spec = equation_parts[1].strip() - if ( - input_spec - and output_spec - and len(output_spec) < len(input_spec) - ): - is_row_parallel = True - - if is_row_parallel: + if "attention_output" in full_name or "out_proj" in full_name: state_rules[f"^{full_name}.kernel$"] = SplitKeras( world_size, 0, "row" ) + if ( + hasattr(current_layer, "bias") + and current_layer.bias is not None + ): + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, -1, "replicated" + ) output_rules[f"^{full_name}$"] = {0: "allreduce"} else: state_rules[f"^{full_name}.kernel$"] = SplitKeras( @@ -147,18 +161,45 @@ def _traverse_and_shard_layer( state_rules[f"^{full_name}.bias$"] = SplitKeras( world_size, 0, "column" ) - output_rules[f"^{full_name}$"] = {0: "no_comm"} + output_rules[f"^{full_name}$"] = {0: "gather -1"} return elif isinstance(current_layer, layers.Embedding): - weight_name = ( - "embeddings" if hasattr(current_layer, "embeddings") else None + state_rules[f"^{full_name}.embeddings$"] = SplitKeras( + world_size, 0, 
"vocab_parallel" ) - if weight_name: - state_rules[f"^{full_name}\.{weight_name}$"] = SplitKeras( - world_size, 1, "column" + output_rules[f"^{full_name}$"] = {0: "allreduce"} + return + + elif isinstance(current_layer, layers.MultiHeadAttention): + for proj in ["query", "key", "value"]: + proj_dense_name = f"_{proj}_dense" + if hasattr(current_layer, proj_dense_name): + state_rules[f"^{full_name}\.{proj_dense_name}\.kernel$"] = ( + SplitKeras(world_size, 1, "column") + ) + if getattr(current_layer, proj_dense_name).use_bias: + state_rules[f"^{full_name}\.{proj_dense_name}\.bias$"] = ( + SplitKeras(world_size, 0, "column") + ) + + output_dense_name = "_output_dense" + if hasattr(current_layer, output_dense_name): + state_rules[f"^{full_name}\.{output_dense_name}\.kernel$"] = ( + SplitKeras(world_size, 0, "row") ) - output_rules[f"^{full_name}$"] = {0: "no_comm"} + if getattr(current_layer, output_dense_name).use_bias: + state_rules[f"^{full_name}\.{output_dense_name}\.bias$"] = ( + SplitKeras(world_size, -1, "replicated") + ) + + output_rules[f"^{full_name}$"] = {0: "allreduce"} + return + + elif isinstance(current_layer, layers.Dropout): + if "rng_rules" not in state_rules: + state_rules["rng_rules"] = {} + state_rules["rng_rules"][full_name] = {"type": "parallel"} return elif isinstance( @@ -170,50 +211,80 @@ def _traverse_and_shard_layer( ), ): return - else: - if hasattr(current_layer, "layers"): - for sub_layer in current_layer.layers: - _traverse_and_shard_layer( - sub_layer, + + if hasattr(current_layer, "layers") and current_layer.layers: + for sub_layer in current_layer.layers: + _find_and_shard_layers( + sub_layer, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, + ) + + for attr_name in dir(current_layer): + if attr_name.startswith("_"): + continue + try: + attr = getattr(current_layer, attr_name) + if isinstance(attr, layers.Layer) and attr is not current_layer: + _find_and_shard_layers( + attr, + full_name, module, world_size, state_rules, output_rules, processed_layers, - full_name, ) + elif isinstance(attr, (list, tuple)): + for item in attr: + if isinstance(item, layers.Layer): + _find_and_shard_layers( + item, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, + ) + except Exception: + continue def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: - """Generates a smart, recursive sharding configuration for a Keras model. + """Generates a default sharding configuration for a Keras model. - This function traverses the layers of a given Keras model and applies a - set of heuristics to automatically determine how each layer's weights - and outputs should be sharded for tensor parallelism. It uses a helper - function to perform the recursive traversal. + This function serves as the main entry point for automatically creating a + tensor parallel sharding configuration. It traverses the model and applies + standard sharding patterns for common layer types like Dense, Embedding, and + MultiHeadAttention. Args: - module: The Keras Model to generate a sharding configuration for. - device_ids: A sequence of device identifiers, used to determine the - world size (number of devices) for sharding. + module: The Keras model or layer to be configured for sharding. + device_ids: A sequence of device IDs (e.g., `['gpu:0', 'gpu:1']`) + to shard across. The number of devices determines the `world_size`. 
Returns: - A ConfigKeras object containing the generated 'state_rules' (for model - parameters) and 'output_rules' (for layer outputs). + A `ConfigKeras` object containing the generated `state_rules` for + sharding weights and `output_rules` for handling communications. """ world_size = len(device_ids) state_rules = {} output_rules = {} processed_layers = set() - _traverse_and_shard_layer( + _find_and_shard_layers( current_layer=module, + prefix="", module=module, world_size=world_size, state_rules=state_rules, output_rules=output_rules, processed_layers=processed_layers, - prefix="", ) return ConfigKeras(state_rules=state_rules, output_rules=output_rules) diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py index 96467da847e0..6845f6000982 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig_test.py +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -1,10 +1,13 @@ import os +import pytest + os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" from keras import Input from keras import Model from keras import layers +from keras.src import backend from keras.src import testing from keras.src.distribution import distributed_backend from keras.src.distribution.tensor_parallel.autoconfig import ( @@ -16,22 +19,24 @@ from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras +@pytest.mark.skipif( + backend.backend() != "jax", + reason="Tensor Parallelism autoconfig tests are only for the JAX backend.", +) class TestAutoConfigKeras(testing.TestCase): def setUp(self): """Set up the test case and common variables.""" super().setUp() device_info = distributed_backend.get_device_info() self.world_size = device_info["device_count"] - self.device_ids = [f"device:{i}" for i in range(self.world_size)] + self.device_ids = [f"cpu:{i}" for i in range(self.world_size)] self.assertGreater( self.world_size, 1, "Distribution tests require more than 1 device." ) def _assert_split_keras_equal(self, rule1, rule2): - """ - Helper to compare two SplitKeras objects by their attributes. 
- """ + """Helper to compare two SplitKeras objects by their attributes.""" self.assertIsInstance(rule1, SplitKeras) self.assertIsInstance(rule2, SplitKeras) self.assertDictEqual(vars(rule1), vars(rule2)) @@ -65,13 +70,20 @@ def test_analyze_dense_layer(self): "down_projection", ) + generic_layer = layers.Dense(20) + generic_layer.build(input_shape=(None, 16)) + self.assertEqual( + analyze_dense_layer_directly(generic_layer, None, ""), + "generic_dense", + ) + def test_simple_mlp_sharding(self): """Tests a simple MLP with up and down projection layers.""" inputs = Input(shape=(64,)) x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs) - outputs = layers.Dense( - 64, name="down_projection_layer", use_bias=False - )(x) + outputs = layers.Dense(64, name="down_projection_layer", use_bias=True)( + x + ) model = Model(inputs=inputs, outputs=outputs, name="simple_mlp") config = get_default_config_keras(model, self.device_ids) @@ -86,17 +98,43 @@ def test_simple_mlp_sharding(self): r"^simple_mlp.down_projection_layer.kernel$": SplitKeras( self.world_size, 0, "row" ), + r"^simple_mlp.down_projection_layer.bias$": SplitKeras( + self.world_size, -1, "replicated" + ), } expected_output_rules = { - r"^simple_mlp.up_projection_layer$": {0: "no_comm"}, + r"^simple_mlp.up_projection_layer$": {0: "gather"}, r"^simple_mlp.down_projection_layer$": {0: "allreduce"}, } self._assert_rules_equal(config.state_rules, expected_state_rules) self._assert_rules_equal(config.output_rules, expected_output_rules) + def test_generic_dense_sharding(self): + """Tests a generic Dense layer that isn't an up/down projection.""" + inputs = Input(shape=(64,)) + outputs = layers.Dense(80, name="generic_layer", use_bias=True)(inputs) + model = Model(inputs=inputs, outputs=outputs, name="generic_model") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^generic_model.generic_layer.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^generic_model.generic_layer.bias$": SplitKeras( + self.world_size, 0, "column" + ), + } + expected_output_rules = { + r"^generic_model.generic_layer$": {0: "gather -1"} + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + def test_embedding_sharding(self): - """Tests an Embedding layer.""" + """Tests an Embedding layer for vocabulary parallelism.""" inputs = Input(shape=(10,), dtype="int32") outputs = layers.Embedding( input_dim=1000, output_dim=128, name="token_embedding" @@ -106,28 +144,79 @@ def test_embedding_sharding(self): config = get_default_config_keras(model, self.device_ids) expected_state_rules = { - r"^embed_model.token_embedding\.embeddings$": SplitKeras( - self.world_size, 1, "column" + r"^embed_model.token_embedding.embeddings$": SplitKeras( + self.world_size, 0, "vocab_parallel" ) } expected_output_rules = { - r"^embed_model.token_embedding$": {0: "no_comm"} + r"^embed_model.token_embedding$": {0: "allreduce"} } self._assert_rules_equal(config.state_rules, expected_state_rules) self._assert_rules_equal(config.output_rules, expected_output_rules) + def test_einsum_dense_sharding(self): + """Tests the special handling for EinsumDense layers.""" + inputs = Input(shape=(64,)) + x = layers.EinsumDense( + "bh,hd->bd", output_shape=128, name="query_proj" + )(inputs) + outputs = layers.EinsumDense( + "bd,dh->bh", output_shape=64, name="attention_output" + )(x) + model = Model(inputs=inputs, outputs=outputs, name="einsum_model") 
+ + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^einsum_model.query_proj.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^einsum_model.attention_output.kernel$": SplitKeras( + self.world_size, 0, "row" + ), + } + expected_output_rules = { + r"^einsum_model.query_proj$": {0: "gather -1"}, + r"^einsum_model.attention_output$": {0: "allreduce"}, + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_normalization_layers_ignored(self): + """Tests that normalization layers are correctly ignored.""" + inputs = Input(shape=(64,)) + x = layers.Dense(64, name="dense1", use_bias=True)(inputs) + x = layers.LayerNormalization(name="layernorm")(x) + outputs = layers.Dense(64, name="dense2", use_bias=True)(x) + model = Model(inputs=inputs, outputs=outputs, name="norm_model") + + config = get_default_config_keras(model, self.device_ids) + + for key in config.state_rules: + self.assertNotIn("layernorm", key) + for key in config.output_rules: + self.assertNotIn("layernorm", key) + + self.assertIn(r"^norm_model.dense1.kernel$", config.state_rules) + self.assertIn(r"^norm_model.dense2.kernel$", config.state_rules) + self.assertEqual(len(config.state_rules), 4) + self.assertEqual(len(config.output_rules), 2) + def test_nested_model_sharding(self): """Tests that the traversal logic correctly handles nested models.""" inner_inputs = Input(shape=(32,)) - inner_outputs = layers.Dense(128, name="inner_dense")(inner_inputs) + inner_outputs = layers.Dense(128, name="inner_dense", use_bias=True)( + inner_inputs + ) inner_model = Model( inputs=inner_inputs, outputs=inner_outputs, name="inner_block" ) outer_inputs = Input(shape=(32,)) x = inner_model(outer_inputs) - outer_outputs = layers.Dense(32, name="outer_dense")(x) + outer_outputs = layers.Dense(32, name="outer_dense", use_bias=True)(x) outer_model = Model( inputs=outer_inputs, outputs=outer_outputs, name="outer_model" ) @@ -144,11 +233,15 @@ def test_nested_model_sharding(self): r"^outer_model.outer_dense.kernel$": SplitKeras( self.world_size, 0, "row" ), + r"^outer_model.outer_dense.bias$": SplitKeras( + self.world_size, -1, "replicated" + ), } expected_output_rules = { - r"^outer_model.inner_block.inner_dense$": {0: "no_comm"}, + r"^outer_model.inner_block.inner_dense$": {0: "gather"}, r"^outer_model.outer_dense$": {0: "allreduce"}, } + self.maxDiff = None self._assert_rules_equal(config.state_rules, expected_state_rules) self._assert_rules_equal(config.output_rules, expected_output_rules) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index 94b6730b2f20..ca7f8e5d5fcc 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -39,6 +39,7 @@ def __init__( base_optimizer: optimizers.Optimizer, world_size: int, distributed_backend: str = "auto", + rank: int = 0, shard_optimizer_states: bool = True, tensor_parallel_config=None, ): @@ -531,6 +532,19 @@ def get_config(self) -> dict[str, Any]: ) return config + def update_step(self, gradient, variable, *args, **kwargs): + if hasattr(self.base_optimizer, "update_step"): + try: + return self.base_optimizer.update_step( + gradient, variable, *args, **kwargs + ) + except TypeError: + return self.base_optimizer.update_step(gradient, variable) + try: + return 
super().update_step(gradient, variable, *args, **kwargs) + except TypeError: + return super().update_step(gradient, variable) + @classmethod def from_config(cls, config: dict[str, Any]) -> "TensorParallelOptimizer": from keras.src import saving From 996a154df5c14da463d94267e3c73d4c4971ecb6 Mon Sep 17 00:00:00 2001 From: Suhana Date: Mon, 6 Oct 2025 20:49:36 +0530 Subject: [PATCH 08/12] refactoring --- .../tensor_parallel/autoconfig.py | 288 +++++++----------- .../tensor_parallel/autoconfig_test.py | 51 ++-- .../tensor_parallel/coordinated_optimizer.py | 10 +- .../coordinated_optimizer_test.py | 6 +- 4 files changed, 140 insertions(+), 215 deletions(-) diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py index 0100aeaf7a5e..9b3a80726b75 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig.py +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -1,70 +1,67 @@ -from typing import Any -from typing import Dict -from typing import Sequence -from typing import Set +from typing import Sequence, Dict, Any, Set from keras.src.distribution.tensor_parallel.config import ConfigKeras from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras def analyze_dense_layer_directly(layer, module, prefix: str) -> str: - """Analyzes a Dense layer to classify it for tensor parallelism sharding. + """Analyzes a Keras Dense layer to classify its sharding strategy. - This function inspects the layer's weight shapes to determine if it's an - "up-projection" (expanding feature dimensions), a "down-projection" - (contracting feature dimensions), or a generic layer. This classification - helps in deciding whether to apply column-wise or row-wise parallelism. + This function inspects the input and output dimensions of a Dense layer + to determine if it functions as an expansion layer ("up-projection"), a + contraction layer ("down-projection"), or neither ("generic_dense"). This + classification is a heuristic commonly used to apply tensor parallelism + in Transformer-based models, such as in an MLP block where an up-projection + is followed by a down-projection. Args: - layer: The keras.layers.Dense instance to analyze. - module: The parent Keras model containing the layer. - prefix: The hierarchical name prefix for the layer. + layer: The Keras `layers.Dense` instance to analyze. + module: The parent module containing the layer (currently unused). + prefix (str): The name prefix for the layer in the model hierarchy + (currently unused). Returns: - A string indicating the layer's classification: 'up_projection', + str: A string classifying the layer as 'up_projection', 'down_projection', or 'generic_dense'. 
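Illustrative sketch, not part of the patch: the classification heuristic on concrete shapes, using the 1.5x expansion threshold applied below; it mirrors the shapes exercised in the autoconfig tests.

import keras
from keras.src.distribution.tensor_parallel.autoconfig import (
    analyze_dense_layer_directly,
)

up = keras.layers.Dense(64)
up.build((None, 16))    # 16 -> 64 is an expansion
down = keras.layers.Dense(16)
down.build((None, 64))  # 64 -> 16 is a contraction
flat = keras.layers.Dense(20)
flat.build((None, 16))  # 16 -> 20 crosses neither threshold

print(analyze_dense_layer_directly(up, None, ""))    # up_projection
print(analyze_dense_layer_directly(down, None, ""))  # down_projection
print(analyze_dense_layer_directly(flat, None, ""))  # generic_dense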
""" from keras.src import layers if not isinstance(layer, layers.Dense): - return "generic_dense" + return 'generic_dense' input_dim = None output_dim = None - if hasattr(layer, "kernel") and layer.kernel is not None: + if hasattr(layer, 'kernel') and layer.kernel is not None: kernel_shape = layer.kernel.shape if len(kernel_shape) == 2: - input_dim, output_dim = kernel_shape + input_dim = kernel_shape[0] + output_dim = kernel_shape[1] if input_dim is None or output_dim is None: - if hasattr(layer, "units"): + if hasattr(layer, 'units'): output_dim = layer.units else: - return "generic_dense" + return 'generic_dense' - if ( - hasattr(layer, "input_shape") - and layer.input_shape - and len(layer.input_shape) > 1 - ): + if hasattr(layer, 'input_shape') and layer.input_shape and len(layer.input_shape) > 1: input_dim = layer.input_shape[-1] else: - return "generic_dense" + return 'generic_dense' if not input_dim or not output_dim: - return "generic_dense" + return 'generic_dense' expansion_threshold = 1.5 is_expansion = output_dim > input_dim * expansion_threshold is_contraction = input_dim > output_dim * expansion_threshold if is_expansion: - return "up_projection" + return 'up_projection' elif is_contraction: - return "down_projection" + return 'down_projection' else: - return "generic_dense" + return 'generic_dense' def _find_and_shard_layers( @@ -76,21 +73,35 @@ def _find_and_shard_layers( output_rules: Dict[str, Any], processed_layers: Set[int], ): - """Recursively traverses a Keras model to generate sharding rules. - - This is an internal helper function that navigates through all layers of a - model, including nested ones. For each supported layer, it determines the - appropriate sharding strategy and populates the `state_rules` and - `output_rules` dictionaries. These dictionaries are modified in place. + """Recursively traverses the model graph to apply sharding rules. + + This function walks through all nested layers of a given Keras model or + layer. For each encountered layer, it determines an appropriate tensor + parallelism strategy and populates the `state_rules` and `output_rules` + dictionaries with the corresponding sharding actions. It uses a set of + processed layer IDs to avoid redundant processing of shared layers. + + The sharding logic is as follows: + - `Dense` layers are sharded based on their classification (up/down proj). + - Up-projections are split along the column axis (output features). + - Down-projections are split along the row axis (input features). + - `EinsumDense` layers in attention blocks are sharded similarly. + - `Embedding` layers are sharded column-wise for vocabulary parallelism. + - Normalization layers are ignored (replicated on all devices). Args: - current_layer: The Keras layer to be processed in the current step. - prefix: The hierarchical name prefix for the `current_layer`. - module: The top-level Keras model being analyzed. - world_size: The total number of devices to shard the model across. - state_rules: A dictionary with sharding rules for weights. - output_rules: A dictionary with communication rules for outputs. - processed_layers: A set of layer IDs to prevent infinite loops. + current_layer: The Keras layer currently being processed. + prefix (str): The hierarchical name prefix for the `current_layer`. + module: The top-level Keras model or layer being configured. + world_size (int): The total number of devices for sharding. + state_rules (Dict[str, Any]): A dictionary to be populated with rules for + sharding layer weights (state). 
Keys are regex patterns matching + weight names, values are `SplitKeras` actions. + output_rules (Dict[str, Any]): A dictionary to be populated with rules + for handling layer outputs. Keys are regex patterns matching layer + names, values describe the communication op (e.g., 'allreduce'). + processed_layers (Set[int]): A set of `id()`s of layers that have + already been processed to prevent cycles and redundant work. """ from keras.src import layers @@ -102,175 +113,107 @@ def _find_and_shard_layers( full_name = f"{prefix}.{name}" if prefix else name if isinstance(current_layer, layers.Dense): - mlp_type = analyze_dense_layer_directly( - current_layer, module, full_name - ) + mlp_type = analyze_dense_layer_directly(current_layer, module, full_name) - if mlp_type == "up_projection": - state_rules[f"^{full_name}.kernel$"] = SplitKeras( - world_size, 1, "column" - ) + if mlp_type == 'up_projection': + state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 1, "column") if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = SplitKeras( - world_size, 0, "column" - ) + state_rules[f"^{full_name}.bias$"] = SplitKeras(world_size, 0, "column") output_rules[f"^{full_name}$"] = {0: "gather"} - elif mlp_type == "down_projection": - state_rules[f"^{full_name}.kernel$"] = SplitKeras( - world_size, 0, "row" - ) - if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = SplitKeras( - world_size, -1, "replicated" - ) + elif mlp_type == 'down_projection': + state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 0, "row") output_rules[f"^{full_name}$"] = {0: "allreduce"} else: - state_rules[f"^{full_name}.kernel$"] = SplitKeras( - world_size, 1, "column" - ) + state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 1, "column") if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = SplitKeras( - world_size, 0, "column" - ) + state_rules[f"^{full_name}.bias$"] = SplitKeras(world_size, 0, "column") output_rules[f"^{full_name}$"] = {0: "gather -1"} return elif isinstance(current_layer, layers.EinsumDense): - if "attention_output" in full_name or "out_proj" in full_name: - state_rules[f"^{full_name}.kernel$"] = SplitKeras( - world_size, 0, "row" - ) - if ( - hasattr(current_layer, "bias") - and current_layer.bias is not None - ): - state_rules[f"^{full_name}.bias$"] = SplitKeras( - world_size, -1, "replicated" - ) + if "attention_output" in full_name: + state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 0, "row") + if hasattr(current_layer, 'bias') and current_layer.bias is not None: + pass output_rules[f"^{full_name}$"] = {0: "allreduce"} else: - state_rules[f"^{full_name}.kernel$"] = SplitKeras( - world_size, 1, "column" - ) - if ( - hasattr(current_layer, "bias") - and current_layer.bias is not None - ): - state_rules[f"^{full_name}.bias$"] = SplitKeras( - world_size, 0, "column" - ) + state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 1, "column") + if hasattr(current_layer, 'bias') and current_layer.bias is not None: + state_rules[f"^{full_name}.bias$"] = SplitKeras(world_size, 0, "column") output_rules[f"^{full_name}$"] = {0: "gather -1"} return - elif isinstance(current_layer, layers.Embedding): - state_rules[f"^{full_name}.embeddings$"] = SplitKeras( - world_size, 0, "vocab_parallel" - ) - output_rules[f"^{full_name}$"] = {0: "allreduce"} - return - - elif isinstance(current_layer, layers.MultiHeadAttention): - for proj in ["query", "key", "value"]: - proj_dense_name = f"_{proj}_dense" - if hasattr(current_layer, proj_dense_name): - 
state_rules[f"^{full_name}\.{proj_dense_name}\.kernel$"] = ( - SplitKeras(world_size, 1, "column") - ) - if getattr(current_layer, proj_dense_name).use_bias: - state_rules[f"^{full_name}\.{proj_dense_name}\.bias$"] = ( - SplitKeras(world_size, 0, "column") - ) - - output_dense_name = "_output_dense" - if hasattr(current_layer, output_dense_name): - state_rules[f"^{full_name}\.{output_dense_name}\.kernel$"] = ( - SplitKeras(world_size, 0, "row") - ) - if getattr(current_layer, output_dense_name).use_bias: - state_rules[f"^{full_name}\.{output_dense_name}\.bias$"] = ( - SplitKeras(world_size, -1, "replicated") - ) - - output_rules[f"^{full_name}$"] = {0: "allreduce"} - return - - elif isinstance(current_layer, layers.Dropout): - if "rng_rules" not in state_rules: - state_rules["rng_rules"] = {} - state_rules["rng_rules"][full_name] = {"type": "parallel"} - return - - elif isinstance( - current_layer, - ( - layers.LayerNormalization, - layers.BatchNormalization, - layers.GroupNormalization, - ), - ): + elif isinstance(current_layer, (layers.Embedding,)): + if hasattr(current_layer, 'token_embedding') or hasattr(current_layer, 'position_embedding'): + pass + else: + weight_name = None + if hasattr(current_layer, 'embeddings'): + weight_name = 'embeddings' + elif hasattr(current_layer, 'position_embeddings'): + weight_name = 'position_embeddings' + + if weight_name: + state_rules[f"^{full_name}\\..*{weight_name}$"] = SplitKeras(world_size, 1, "column") + output_rules[f"^{full_name}$"] = {0: "no_comm"} + return + + elif isinstance(current_layer, (layers.LayerNormalization, layers.BatchNormalization, layers.GroupNormalization)): return - if hasattr(current_layer, "layers") and current_layer.layers: + if hasattr(current_layer, 'layers') and current_layer.layers: for sub_layer in current_layer.layers: _find_and_shard_layers( - sub_layer, - full_name, - module, - world_size, - state_rules, - output_rules, - processed_layers, + sub_layer, full_name, module, world_size, + state_rules, output_rules, processed_layers ) for attr_name in dir(current_layer): - if attr_name.startswith("_"): + if attr_name.startswith('__') and attr_name.endswith('__'): continue - try: + if hasattr(current_layer, attr_name): attr = getattr(current_layer, attr_name) + if isinstance(attr, layers.Layer) and attr is not current_layer: _find_and_shard_layers( - attr, - full_name, - module, - world_size, - state_rules, - output_rules, - processed_layers, + attr, full_name, module, world_size, + state_rules, output_rules, processed_layers ) elif isinstance(attr, (list, tuple)): for item in attr: if isinstance(item, layers.Layer): _find_and_shard_layers( - item, - full_name, - module, - world_size, - state_rules, - output_rules, - processed_layers, + item, full_name, module, world_size, + state_rules, output_rules, processed_layers ) - except Exception: - continue - def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: - """Generates a default sharding configuration for a Keras model. + """Generates a default tensor parallelism sharding configuration for a model. + + This function serves as the entry point for automatically creating a sharding + plan for a given Keras model or layer. It initializes the rule dictionaries + and starts the recursive layer traversal to populate them based on a default + set of heuristics for common architectures like Transformers. - This function serves as the main entry point for automatically creating a - tensor parallel sharding configuration. 
It traverses the model and applies - standard sharding patterns for common layer types like Dense, Embedding, and - MultiHeadAttention. + Example: + ```python + model = MyTransformerModel() + device_ids = ["gpu:0", "gpu:1"] + sharding_config = get_default_config_keras(model, device_ids) + # sharding_config can now be used to distribute the model + ``` Args: - module: The Keras model or layer to be configured for sharding. - device_ids: A sequence of device IDs (e.g., `['gpu:0', 'gpu:1']`) - to shard across. The number of devices determines the `world_size`. + module: The Keras `Model` or `Layer` to generate a config for. + device_ids (Sequence[str]): A sequence of device IDs (e.g., + ["gpu:0", "gpu:1"]) across which the model will be sharded. Returns: - A `ConfigKeras` object containing the generated `state_rules` for - sharding weights and `output_rules` for handling communications. + ConfigKeras: A configuration object containing the generated sharding + rules for model weights (`state_rules`) and layer outputs + (`output_rules`). """ world_size = len(device_ids) state_rules = {} @@ -284,7 +227,10 @@ def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: world_size=world_size, state_rules=state_rules, output_rules=output_rules, - processed_layers=processed_layers, + processed_layers=processed_layers ) - return ConfigKeras(state_rules=state_rules, output_rules=output_rules) + return ConfigKeras( + state_rules=state_rules, + output_rules=output_rules + ) \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py index 6845f6000982..8e549e11c74d 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig_test.py +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -1,5 +1,4 @@ import os - import pytest os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" @@ -10,10 +9,9 @@ from keras.src import backend from keras.src import testing from keras.src.distribution import distributed_backend + from keras.src.distribution.tensor_parallel.autoconfig import ( analyze_dense_layer_directly, -) -from keras.src.distribution.tensor_parallel.autoconfig import ( get_default_config_keras, ) from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras @@ -21,7 +19,7 @@ @pytest.mark.skipif( backend.backend() != "jax", - reason="Tensor Parallelism autoconfig tests are only for the JAX backend.", + reason="Tensor Parallelism autoconfig tests are only for the JAX backend." 
) class TestAutoConfigKeras(testing.TestCase): def setUp(self): @@ -43,9 +41,7 @@ def _assert_split_keras_equal(self, rule1, rule2): def _assert_rules_equal(self, actual_rules, expected_rules): """Helper to compare two dictionaries of sharding rules.""" - self.assertSetEqual( - set(actual_rules.keys()), set(expected_rules.keys()) - ) + self.assertSetEqual(set(actual_rules.keys()), set(expected_rules.keys())) for key in expected_rules: actual_val = actual_rules[key] expected_val = expected_rules[key] @@ -59,31 +55,26 @@ def test_analyze_dense_layer(self): up_proj_layer = layers.Dense(32) up_proj_layer.build(input_shape=(None, 16)) self.assertEqual( - analyze_dense_layer_directly(up_proj_layer, None, ""), - "up_projection", + analyze_dense_layer_directly(up_proj_layer, None, ""), "up_projection" ) down_proj_layer = layers.Dense(16) down_proj_layer.build(input_shape=(None, 32)) self.assertEqual( - analyze_dense_layer_directly(down_proj_layer, None, ""), - "down_projection", + analyze_dense_layer_directly(down_proj_layer, None, ""), "down_projection" ) generic_layer = layers.Dense(20) generic_layer.build(input_shape=(None, 16)) self.assertEqual( - analyze_dense_layer_directly(generic_layer, None, ""), - "generic_dense", + analyze_dense_layer_directly(generic_layer, None, ""), "generic_dense" ) def test_simple_mlp_sharding(self): """Tests a simple MLP with up and down projection layers.""" inputs = Input(shape=(64,)) x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs) - outputs = layers.Dense(64, name="down_projection_layer", use_bias=True)( - x - ) + outputs = layers.Dense(64, name="down_projection_layer", use_bias=True)(x) model = Model(inputs=inputs, outputs=outputs, name="simple_mlp") config = get_default_config_keras(model, self.device_ids) @@ -98,9 +89,6 @@ def test_simple_mlp_sharding(self): r"^simple_mlp.down_projection_layer.kernel$": SplitKeras( self.world_size, 0, "row" ), - r"^simple_mlp.down_projection_layer.bias$": SplitKeras( - self.world_size, -1, "replicated" - ), } expected_output_rules = { r"^simple_mlp.up_projection_layer$": {0: "gather"}, @@ -144,13 +132,11 @@ def test_embedding_sharding(self): config = get_default_config_keras(model, self.device_ids) expected_state_rules = { - r"^embed_model.token_embedding.embeddings$": SplitKeras( - self.world_size, 0, "vocab_parallel" + r"^embed_model.token_embedding\..*embeddings$": SplitKeras( + self.world_size, 1, "column" ) } - expected_output_rules = { - r"^embed_model.token_embedding$": {0: "allreduce"} - } + expected_output_rules = {r"^embed_model.token_embedding$": {0: "no_comm"}} self._assert_rules_equal(config.state_rules, expected_state_rules) self._assert_rules_equal(config.output_rules, expected_output_rules) @@ -190,7 +176,9 @@ def test_normalization_layers_ignored(self): x = layers.Dense(64, name="dense1", use_bias=True)(inputs) x = layers.LayerNormalization(name="layernorm")(x) outputs = layers.Dense(64, name="dense2", use_bias=True)(x) - model = Model(inputs=inputs, outputs=outputs, name="norm_model") + model = Model( + inputs=inputs, outputs=outputs, name="norm_model" + ) config = get_default_config_keras(model, self.device_ids) @@ -207,9 +195,7 @@ def test_normalization_layers_ignored(self): def test_nested_model_sharding(self): """Tests that the traversal logic correctly handles nested models.""" inner_inputs = Input(shape=(32,)) - inner_outputs = layers.Dense(128, name="inner_dense", use_bias=True)( - inner_inputs - ) + inner_outputs = layers.Dense(128, name="inner_dense", 
use_bias=True)(inner_inputs) inner_model = Model( inputs=inner_inputs, outputs=inner_outputs, name="inner_block" ) @@ -222,7 +208,7 @@ def test_nested_model_sharding(self): ) config = get_default_config_keras(outer_model, self.device_ids) - + expected_state_rules = { r"^outer_model.inner_block.inner_dense.kernel$": SplitKeras( self.world_size, 1, "column" @@ -233,15 +219,12 @@ def test_nested_model_sharding(self): r"^outer_model.outer_dense.kernel$": SplitKeras( self.world_size, 0, "row" ), - r"^outer_model.outer_dense.bias$": SplitKeras( - self.world_size, -1, "replicated" - ), } expected_output_rules = { r"^outer_model.inner_block.inner_dense$": {0: "gather"}, r"^outer_model.outer_dense$": {0: "allreduce"}, } - + self.maxDiff = None self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index ca7f8e5d5fcc..99fa58592076 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -531,13 +531,11 @@ def get_config(self) -> dict[str, Any]: } ) return config - + def update_step(self, gradient, variable, *args, **kwargs): - if hasattr(self.base_optimizer, "update_step"): + if hasattr(self.base_optimizer, 'update_step'): try: - return self.base_optimizer.update_step( - gradient, variable, *args, **kwargs - ) + return self.base_optimizer.update_step(gradient, variable, *args, **kwargs) except TypeError: return self.base_optimizer.update_step(gradient, variable) try: @@ -611,4 +609,4 @@ def iterations(self): """ if self.base_optimizer.iterations is None: return None - return self.base_optimizer.iterations - 1 + return self.base_optimizer.iterations - 1 \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py index 39cce46de72c..ba9438a5d9ab 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -173,8 +173,6 @@ def test_sharding_with_prefixed_variable_names(self): self.assertGreater(len(state_to_param), 0) dense_output_kernel = model.get_layer("dense_output").kernel - optimizer_name = optimizer.base_optimizer.name - kernel_path = dense_output_kernel.path.replace("/", "_") - momentum_path = f"{optimizer_name}/{kernel_path}_momentum" + momentum_path = f"{optimizer.base_optimizer.name}/{dense_output_kernel.path.replace('/', '_')}_momentum" - self.assertIs(state_to_param[momentum_path], dense_output_kernel) + self.assertIs(state_to_param[momentum_path], dense_output_kernel) \ No newline at end of file From 31994dab60d055e96dbebe2ae40297cca3334e00 Mon Sep 17 00:00:00 2001 From: Suhana Date: Mon, 6 Oct 2025 20:54:30 +0530 Subject: [PATCH 09/12] refactoring --- .../tensor_parallel/autoconfig.py | 151 ++++++++++++------ .../tensor_parallel/autoconfig_test.py | 41 +++-- .../tensor_parallel/coordinated_optimizer.py | 10 +- .../coordinated_optimizer_test.py | 6 +- 4 files changed, 139 insertions(+), 69 deletions(-) diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py index 9b3a80726b75..32d6734860cc 100644 
--- a/keras/src/distribution/tensor_parallel/autoconfig.py +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -1,4 +1,7 @@ -from typing import Sequence, Dict, Any, Set +from typing import Any +from typing import Dict +from typing import Sequence +from typing import Set from keras.src.distribution.tensor_parallel.config import ConfigKeras from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras @@ -27,41 +30,45 @@ def analyze_dense_layer_directly(layer, module, prefix: str) -> str: from keras.src import layers if not isinstance(layer, layers.Dense): - return 'generic_dense' + return "generic_dense" input_dim = None output_dim = None - if hasattr(layer, 'kernel') and layer.kernel is not None: + if hasattr(layer, "kernel") and layer.kernel is not None: kernel_shape = layer.kernel.shape if len(kernel_shape) == 2: input_dim = kernel_shape[0] output_dim = kernel_shape[1] if input_dim is None or output_dim is None: - if hasattr(layer, 'units'): + if hasattr(layer, "units"): output_dim = layer.units else: - return 'generic_dense' + return "generic_dense" - if hasattr(layer, 'input_shape') and layer.input_shape and len(layer.input_shape) > 1: + if ( + hasattr(layer, "input_shape") + and layer.input_shape + and len(layer.input_shape) > 1 + ): input_dim = layer.input_shape[-1] else: - return 'generic_dense' + return "generic_dense" if not input_dim or not output_dim: - return 'generic_dense' + return "generic_dense" expansion_threshold = 1.5 is_expansion = output_dim > input_dim * expansion_threshold is_contraction = input_dim > output_dim * expansion_threshold if is_expansion: - return 'up_projection' + return "up_projection" elif is_contraction: - return 'down_projection' + return "down_projection" else: - return 'generic_dense' + return "generic_dense" def _find_and_shard_layers( @@ -94,10 +101,10 @@ def _find_and_shard_layers( prefix (str): The hierarchical name prefix for the `current_layer`. module: The top-level Keras model or layer being configured. world_size (int): The total number of devices for sharding. - state_rules (Dict[str, Any]): A dictionary to be populated with rules for + state_rules (Dict[str, Any]): A dictionary with rules for sharding layer weights (state). Keys are regex patterns matching weight names, values are `SplitKeras` actions. - output_rules (Dict[str, Any]): A dictionary to be populated with rules + output_rules (Dict[str, Any]): A dictionary with rules for handling layer outputs. Keys are regex patterns matching layer names, values describe the communication op (e.g., 'allreduce'). 
processed_layers (Set[int]): A set of `id()`s of layers that have @@ -113,86 +120,137 @@ def _find_and_shard_layers( full_name = f"{prefix}.{name}" if prefix else name if isinstance(current_layer, layers.Dense): - mlp_type = analyze_dense_layer_directly(current_layer, module, full_name) + mlp_type = analyze_dense_layer_directly( + current_layer, module, full_name + ) - if mlp_type == 'up_projection': - state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 1, "column") + if mlp_type == "up_projection": + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 1, "column" + ) if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = SplitKeras(world_size, 0, "column") + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, 0, "column" + ) output_rules[f"^{full_name}$"] = {0: "gather"} - elif mlp_type == 'down_projection': - state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 0, "row") + elif mlp_type == "down_projection": + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 0, "row" + ) output_rules[f"^{full_name}$"] = {0: "allreduce"} else: - state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 1, "column") + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 1, "column" + ) if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = SplitKeras(world_size, 0, "column") + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, 0, "column" + ) output_rules[f"^{full_name}$"] = {0: "gather -1"} return elif isinstance(current_layer, layers.EinsumDense): if "attention_output" in full_name: - state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 0, "row") - if hasattr(current_layer, 'bias') and current_layer.bias is not None: + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 0, "row" + ) + if ( + hasattr(current_layer, "bias") + and current_layer.bias is not None + ): pass output_rules[f"^{full_name}$"] = {0: "allreduce"} else: - state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 1, "column") - if hasattr(current_layer, 'bias') and current_layer.bias is not None: - state_rules[f"^{full_name}.bias$"] = SplitKeras(world_size, 0, "column") + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 1, "column" + ) + if ( + hasattr(current_layer, "bias") + and current_layer.bias is not None + ): + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, 0, "column" + ) output_rules[f"^{full_name}$"] = {0: "gather -1"} return elif isinstance(current_layer, (layers.Embedding,)): - if hasattr(current_layer, 'token_embedding') or hasattr(current_layer, 'position_embedding'): + if hasattr(current_layer, "token_embedding") or hasattr( + current_layer, "position_embedding" + ): pass else: weight_name = None - if hasattr(current_layer, 'embeddings'): - weight_name = 'embeddings' - elif hasattr(current_layer, 'position_embeddings'): - weight_name = 'position_embeddings' + if hasattr(current_layer, "embeddings"): + weight_name = "embeddings" + elif hasattr(current_layer, "position_embeddings"): + weight_name = "position_embeddings" if weight_name: - state_rules[f"^{full_name}\\..*{weight_name}$"] = SplitKeras(world_size, 1, "column") + state_rules[f"^{full_name}\\..*{weight_name}$"] = SplitKeras( + world_size, 1, "column" + ) output_rules[f"^{full_name}$"] = {0: "no_comm"} return - elif isinstance(current_layer, (layers.LayerNormalization, layers.BatchNormalization, layers.GroupNormalization)): + elif isinstance( + current_layer, + ( + layers.LayerNormalization, + 
layers.BatchNormalization, + layers.GroupNormalization, + ), + ): return - if hasattr(current_layer, 'layers') and current_layer.layers: + if hasattr(current_layer, "layers") and current_layer.layers: for sub_layer in current_layer.layers: _find_and_shard_layers( - sub_layer, full_name, module, world_size, - state_rules, output_rules, processed_layers + sub_layer, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, ) for attr_name in dir(current_layer): - if attr_name.startswith('__') and attr_name.endswith('__'): + if attr_name.startswith("__") and attr_name.endswith("__"): continue if hasattr(current_layer, attr_name): attr = getattr(current_layer, attr_name) if isinstance(attr, layers.Layer) and attr is not current_layer: _find_and_shard_layers( - attr, full_name, module, world_size, - state_rules, output_rules, processed_layers + attr, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, ) elif isinstance(attr, (list, tuple)): for item in attr: if isinstance(item, layers.Layer): _find_and_shard_layers( - item, full_name, module, world_size, - state_rules, output_rules, processed_layers + item, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, ) + def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: - """Generates a default tensor parallelism sharding configuration for a model. + """Generates a default tensor parallel sharding configuration for a model. - This function serves as the entry point for automatically creating a sharding + This function serves as entry point for automatically creating a sharding plan for a given Keras model or layer. It initializes the rule dictionaries and starts the recursive layer traversal to populate them based on a default set of heuristics for common architectures like Transformers. @@ -227,10 +285,7 @@ def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: world_size=world_size, state_rules=state_rules, output_rules=output_rules, - processed_layers=processed_layers + processed_layers=processed_layers, ) - return ConfigKeras( - state_rules=state_rules, - output_rules=output_rules - ) \ No newline at end of file + return ConfigKeras(state_rules=state_rules, output_rules=output_rules) diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py index 8e549e11c74d..228a2b184569 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig_test.py +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -1,4 +1,5 @@ import os + import pytest os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" @@ -9,9 +10,10 @@ from keras.src import backend from keras.src import testing from keras.src.distribution import distributed_backend - from keras.src.distribution.tensor_parallel.autoconfig import ( analyze_dense_layer_directly, +) +from keras.src.distribution.tensor_parallel.autoconfig import ( get_default_config_keras, ) from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras @@ -19,7 +21,7 @@ @pytest.mark.skipif( backend.backend() != "jax", - reason="Tensor Parallelism autoconfig tests are only for the JAX backend." 
+ reason="Tensor Parallelism autoconfig tests are only for the JAX backend.", ) class TestAutoConfigKeras(testing.TestCase): def setUp(self): @@ -41,7 +43,9 @@ def _assert_split_keras_equal(self, rule1, rule2): def _assert_rules_equal(self, actual_rules, expected_rules): """Helper to compare two dictionaries of sharding rules.""" - self.assertSetEqual(set(actual_rules.keys()), set(expected_rules.keys())) + self.assertSetEqual( + set(actual_rules.keys()), set(expected_rules.keys()) + ) for key in expected_rules: actual_val = actual_rules[key] expected_val = expected_rules[key] @@ -55,26 +59,31 @@ def test_analyze_dense_layer(self): up_proj_layer = layers.Dense(32) up_proj_layer.build(input_shape=(None, 16)) self.assertEqual( - analyze_dense_layer_directly(up_proj_layer, None, ""), "up_projection" + analyze_dense_layer_directly(up_proj_layer, None, ""), + "up_projection", ) down_proj_layer = layers.Dense(16) down_proj_layer.build(input_shape=(None, 32)) self.assertEqual( - analyze_dense_layer_directly(down_proj_layer, None, ""), "down_projection" + analyze_dense_layer_directly(down_proj_layer, None, ""), + "down_projection", ) generic_layer = layers.Dense(20) generic_layer.build(input_shape=(None, 16)) self.assertEqual( - analyze_dense_layer_directly(generic_layer, None, ""), "generic_dense" + analyze_dense_layer_directly(generic_layer, None, ""), + "generic_dense", ) def test_simple_mlp_sharding(self): """Tests a simple MLP with up and down projection layers.""" inputs = Input(shape=(64,)) x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs) - outputs = layers.Dense(64, name="down_projection_layer", use_bias=True)(x) + outputs = layers.Dense(64, name="down_projection_layer", use_bias=True)( + x + ) model = Model(inputs=inputs, outputs=outputs, name="simple_mlp") config = get_default_config_keras(model, self.device_ids) @@ -136,7 +145,9 @@ def test_embedding_sharding(self): self.world_size, 1, "column" ) } - expected_output_rules = {r"^embed_model.token_embedding$": {0: "no_comm"}} + expected_output_rules = { + r"^embed_model.token_embedding$": {0: "no_comm"} + } self._assert_rules_equal(config.state_rules, expected_state_rules) self._assert_rules_equal(config.output_rules, expected_output_rules) @@ -176,9 +187,7 @@ def test_normalization_layers_ignored(self): x = layers.Dense(64, name="dense1", use_bias=True)(inputs) x = layers.LayerNormalization(name="layernorm")(x) outputs = layers.Dense(64, name="dense2", use_bias=True)(x) - model = Model( - inputs=inputs, outputs=outputs, name="norm_model" - ) + model = Model(inputs=inputs, outputs=outputs, name="norm_model") config = get_default_config_keras(model, self.device_ids) @@ -195,7 +204,9 @@ def test_normalization_layers_ignored(self): def test_nested_model_sharding(self): """Tests that the traversal logic correctly handles nested models.""" inner_inputs = Input(shape=(32,)) - inner_outputs = layers.Dense(128, name="inner_dense", use_bias=True)(inner_inputs) + inner_outputs = layers.Dense(128, name="inner_dense", use_bias=True)( + inner_inputs + ) inner_model = Model( inputs=inner_inputs, outputs=inner_outputs, name="inner_block" ) @@ -208,7 +219,7 @@ def test_nested_model_sharding(self): ) config = get_default_config_keras(outer_model, self.device_ids) - + expected_state_rules = { r"^outer_model.inner_block.inner_dense.kernel$": SplitKeras( self.world_size, 1, "column" @@ -224,7 +235,7 @@ def test_nested_model_sharding(self): r"^outer_model.inner_block.inner_dense$": {0: "gather"}, r"^outer_model.outer_dense$": {0: 
"allreduce"}, } - + self.maxDiff = None self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) \ No newline at end of file + self._assert_rules_equal(config.output_rules, expected_output_rules) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index 99fa58592076..ca7f8e5d5fcc 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -531,11 +531,13 @@ def get_config(self) -> dict[str, Any]: } ) return config - + def update_step(self, gradient, variable, *args, **kwargs): - if hasattr(self.base_optimizer, 'update_step'): + if hasattr(self.base_optimizer, "update_step"): try: - return self.base_optimizer.update_step(gradient, variable, *args, **kwargs) + return self.base_optimizer.update_step( + gradient, variable, *args, **kwargs + ) except TypeError: return self.base_optimizer.update_step(gradient, variable) try: @@ -609,4 +611,4 @@ def iterations(self): """ if self.base_optimizer.iterations is None: return None - return self.base_optimizer.iterations - 1 \ No newline at end of file + return self.base_optimizer.iterations - 1 diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py index ba9438a5d9ab..39cce46de72c 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -173,6 +173,8 @@ def test_sharding_with_prefixed_variable_names(self): self.assertGreater(len(state_to_param), 0) dense_output_kernel = model.get_layer("dense_output").kernel - momentum_path = f"{optimizer.base_optimizer.name}/{dense_output_kernel.path.replace('/', '_')}_momentum" + optimizer_name = optimizer.base_optimizer.name + kernel_path = dense_output_kernel.path.replace("/", "_") + momentum_path = f"{optimizer_name}/{kernel_path}_momentum" - self.assertIs(state_to_param[momentum_path], dense_output_kernel) \ No newline at end of file + self.assertIs(state_to_param[momentum_path], dense_output_kernel) From 8124b080968b5727550aa25be348230ffb4d7431 Mon Sep 17 00:00:00 2001 From: Suhana Date: Wed, 15 Oct 2025 23:59:53 +0530 Subject: [PATCH 10/12] Testing PR1&2 --- keras/src/backend/__init__.py | 7 +- keras/src/backend/jax/__init__.py | 1 + keras/src/backend/jax/distributed_backend.py | 93 +++++ .../tensor_parallel/autoconfig.py | 169 +++------ .../tensor_parallel/autoconfig_test.py | 348 +++++++----------- .../tensor_parallel/coordinated_optimizer.py | 235 +++++++----- .../coordinated_optimizer_test.py | 29 +- .../tensor_parallel/sharding_keras.py | 82 ----- .../tensor_parallel/tensor_layout.py | 164 +++++++++ 9 files changed, 596 insertions(+), 532 deletions(-) create mode 100644 keras/src/backend/jax/distributed_backend.py delete mode 100644 keras/src/distribution/tensor_parallel/sharding_keras.py create mode 100644 keras/src/distribution/tensor_parallel/tensor_layout.py diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index 15f1af2145d5..c89e7d82c90a 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -37,6 +37,8 @@ if backend() == "tensorflow": from keras.src.backend.tensorflow import * # noqa: F403 from keras.src.backend.tensorflow.core import Variable as BackendVariable + + distributed_backend = None 
elif backend() == "jax": from keras.src.backend.jax import * # noqa: F403 from keras.src.backend.jax.core import Variable as BackendVariable @@ -45,16 +47,19 @@ from keras.src.backend.torch.core import Variable as BackendVariable distribution_lib = None + distributed_backend = None elif backend() == "numpy": from keras.src.backend.numpy import * # noqa: F403 from keras.src.backend.numpy.core import Variable as BackendVariable distribution_lib = None + distributed_backend = None elif backend() == "openvino": from keras.src.backend.openvino import * # noqa: F403 from keras.src.backend.openvino.core import Variable as BackendVariable distribution_lib = None + distributed_backend = None else: raise ValueError(f"Unable to import backend : {backend()}") @@ -74,4 +79,4 @@ class name_scope(backend_name_scope): @keras_export("keras.device") def device(device_name): - return device_scope(device_name) # noqa: F405 + return device_scope(device_name) # noqa: F405 \ No newline at end of file diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py index 89ac0fa71c8c..0f703483cb28 100644 --- a/keras/src/backend/jax/__init__.py +++ b/keras/src/backend/jax/__init__.py @@ -1,6 +1,7 @@ from keras.src.backend.config import is_nnx_enabled from keras.src.backend.jax import core from keras.src.backend.jax import distribution_lib +from keras.src.backend.jax import distributed_backend from keras.src.backend.jax import image from keras.src.backend.jax import linalg from keras.src.backend.jax import math diff --git a/keras/src/backend/jax/distributed_backend.py b/keras/src/backend/jax/distributed_backend.py new file mode 100644 index 000000000000..8fd999784d52 --- /dev/null +++ b/keras/src/backend/jax/distributed_backend.py @@ -0,0 +1,93 @@ +import jax +import jax.lax as lax + +def get_device_info(): + """Retrieves information about the available JAX devices. + + This function queries the JAX backend to identify the type and number + of available computational devices (e.g., CPU, GPU, TPU). + + Returns: + dict: A dictionary containing the backend name ('jax'), a list of + device string representations, and the total count of devices. + """ + available_devices = jax.devices() + return { + "backend": "jax", + "devices": [str(d) for d in available_devices], + "device_count": len(available_devices), + } + + +def is_multi_device_capable(): + """Checks if more than one JAX device is available for computation. + + Returns: + bool: True if the local JAX environment has more than one device, + False otherwise. + """ + return jax.local_device_count() > 1 + + +def get_communication_ops(): + """Provides a dictionary of JAX collective communication operations. + + Returns: + dict: A dictionary mapping operation names (e.g., 'all_reduce') to their + corresponding JAX implementation functions. + """ + + def all_reduce(x, op="sum", axis_name="model"): + """Reduces a tensor across a device mesh axis using a collective. + + This function assumes it is called within a `pjit` context that has a + device mesh with the specified `axis_name`. It performs a collective + reduction operation (like sum or mean) across all devices mapped to + that axis. + + Args: + x (jax.Array): The input JAX array (tensor) on the local device. + op (str, optional): The reduction operation to perform. Supported + values are 'sum' and 'mean'. Defaults to 'sum'. + axis_name (str, optional): The name of the mapped axis in the device + mesh over which to communicate. Defaults to 'model'. 
+ + Returns: + jax.Array: The reduced JAX array, which is identical across all + devices participating in the reduction. + """ + if op == "sum": + return lax.psum(x, axis_name=axis_name) + elif op == "mean": + return lax.pmean(x, axis_name=axis_name) + else: + raise ValueError( + f"Unsupported reduction operation: {op}. " + "Supported options are 'sum' and 'mean'." + ) + + def all_gather(x, axis, axis_name="model"): + """Gathers and concatenates tensors from all devices across a mesh axis. + + This function assumes it is called within a `pjit` context. It takes + the local shard `x` from each device along the `axis_name` of the mesh + and concatenates them along the specified tensor `axis` to form a + single, larger tensor that is then replicated on all participating devices. + + Args: + x (jax.Array): The input JAX array (tensor) shard on the local device. + axis (int): The tensor axis along which to concatenate the gathered + shards. + axis_name (str, optional): The name of the mesh axis to gather + from. Defaults to 'model'. + + Returns: + jax.Array: The full, gathered JAX array, which is identical across + all devices participating in the gather. + """ + return lax.all_gather(x, axis_name=axis_name, axis=axis, tiled=True) + + return { + "all_reduce": all_reduce, + "all_gather": all_gather, + } \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py index 32d6734860cc..636775bc14e2 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig.py +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -1,13 +1,8 @@ -from typing import Any -from typing import Dict -from typing import Sequence -from typing import Set +from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap +from keras.src.distribution.tensor_parallel.tensor_layout import Split -from keras.src.distribution.tensor_parallel.config import ConfigKeras -from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras - -def analyze_dense_layer_directly(layer, module, prefix: str) -> str: +def analyze_dense_layer_directly(layer, module, prefix): """Analyzes a Keras Dense layer to classify its sharding strategy. 
This function inspects the input and output dimensions of a Dense layer @@ -30,55 +25,51 @@ def analyze_dense_layer_directly(layer, module, prefix: str) -> str: from keras.src import layers if not isinstance(layer, layers.Dense): - return "generic_dense" + return 'generic_dense' input_dim = None output_dim = None - if hasattr(layer, "kernel") and layer.kernel is not None: + if hasattr(layer, 'kernel') and layer.kernel is not None: kernel_shape = layer.kernel.shape if len(kernel_shape) == 2: input_dim = kernel_shape[0] output_dim = kernel_shape[1] if input_dim is None or output_dim is None: - if hasattr(layer, "units"): + if hasattr(layer, 'units'): output_dim = layer.units else: - return "generic_dense" + return 'generic_dense' - if ( - hasattr(layer, "input_shape") - and layer.input_shape - and len(layer.input_shape) > 1 - ): + if hasattr(layer, 'input_shape') and layer.input_shape and len(layer.input_shape) > 1: input_dim = layer.input_shape[-1] else: - return "generic_dense" + return 'generic_dense' if not input_dim or not output_dim: - return "generic_dense" + return 'generic_dense' expansion_threshold = 1.5 is_expansion = output_dim > input_dim * expansion_threshold is_contraction = input_dim > output_dim * expansion_threshold if is_expansion: - return "up_projection" + return 'up_projection' elif is_contraction: - return "down_projection" + return 'down_projection' else: - return "generic_dense" + return 'generic_dense' def _find_and_shard_layers( current_layer, - prefix: str, + prefix, module, - world_size: int, - state_rules: Dict[str, Any], - output_rules: Dict[str, Any], - processed_layers: Set[int], + world_size, + state_rules, + output_rules, + processed_layers, ): """Recursively traverses the model graph to apply sharding rules. @@ -101,10 +92,10 @@ def _find_and_shard_layers( prefix (str): The hierarchical name prefix for the `current_layer`. module: The top-level Keras model or layer being configured. world_size (int): The total number of devices for sharding. - state_rules (Dict[str, Any]): A dictionary with rules for + state_rules (Dict[str, Any]): A dictionary to be populated with rules for sharding layer weights (state). Keys are regex patterns matching weight names, values are `SplitKeras` actions. - output_rules (Dict[str, Any]): A dictionary with rules + output_rules (Dict[str, Any]): A dictionary to be populated with rules for handling layer outputs. Keys are regex patterns matching layer names, values describe the communication op (e.g., 'allreduce'). 
processed_layers (Set[int]): A set of `id()`s of layers that have @@ -120,137 +111,86 @@ def _find_and_shard_layers( full_name = f"{prefix}.{name}" if prefix else name if isinstance(current_layer, layers.Dense): - mlp_type = analyze_dense_layer_directly( - current_layer, module, full_name - ) + mlp_type = analyze_dense_layer_directly(current_layer, module, full_name) - if mlp_type == "up_projection": - state_rules[f"^{full_name}.kernel$"] = SplitKeras( - world_size, 1, "column" - ) + if mlp_type == 'up_projection': + state_rules[f"^{full_name}.kernel$"] = Split(world_size, 1, "column") if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = SplitKeras( - world_size, 0, "column" - ) + state_rules[f"^{full_name}.bias$"] = Split(world_size, 0, "column") output_rules[f"^{full_name}$"] = {0: "gather"} - elif mlp_type == "down_projection": - state_rules[f"^{full_name}.kernel$"] = SplitKeras( - world_size, 0, "row" - ) + elif mlp_type == 'down_projection': + state_rules[f"^{full_name}.kernel$"] = Split(world_size, 0, "row") output_rules[f"^{full_name}$"] = {0: "allreduce"} else: - state_rules[f"^{full_name}.kernel$"] = SplitKeras( - world_size, 1, "column" - ) + state_rules[f"^{full_name}.kernel$"] = Split(world_size, 1, "column") if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = SplitKeras( - world_size, 0, "column" - ) + state_rules[f"^{full_name}.bias$"] = Split(world_size, 0, "column") output_rules[f"^{full_name}$"] = {0: "gather -1"} return elif isinstance(current_layer, layers.EinsumDense): if "attention_output" in full_name: - state_rules[f"^{full_name}.kernel$"] = SplitKeras( - world_size, 0, "row" - ) - if ( - hasattr(current_layer, "bias") - and current_layer.bias is not None - ): + state_rules[f"^{full_name}.kernel$"] = Split(world_size, 0, "row") + if hasattr(current_layer, 'bias') and current_layer.bias is not None: pass output_rules[f"^{full_name}$"] = {0: "allreduce"} else: - state_rules[f"^{full_name}.kernel$"] = SplitKeras( - world_size, 1, "column" - ) - if ( - hasattr(current_layer, "bias") - and current_layer.bias is not None - ): - state_rules[f"^{full_name}.bias$"] = SplitKeras( - world_size, 0, "column" - ) + state_rules[f"^{full_name}.kernel$"] = Split(world_size, 1, "column") + if hasattr(current_layer, 'bias') and current_layer.bias is not None: + state_rules[f"^{full_name}.bias$"] = Split(world_size, 0, "column") output_rules[f"^{full_name}$"] = {0: "gather -1"} return elif isinstance(current_layer, (layers.Embedding,)): - if hasattr(current_layer, "token_embedding") or hasattr( - current_layer, "position_embedding" - ): + if hasattr(current_layer, 'token_embedding') or hasattr(current_layer, 'position_embedding'): pass else: weight_name = None - if hasattr(current_layer, "embeddings"): - weight_name = "embeddings" - elif hasattr(current_layer, "position_embeddings"): - weight_name = "position_embeddings" + if hasattr(current_layer, 'embeddings'): + weight_name = 'embeddings' + elif hasattr(current_layer, 'position_embeddings'): + weight_name = 'position_embeddings' if weight_name: - state_rules[f"^{full_name}\\..*{weight_name}$"] = SplitKeras( - world_size, 1, "column" - ) + state_rules[f"^{full_name}\\..*{weight_name}$"] = Split(world_size, 1, "column") output_rules[f"^{full_name}$"] = {0: "no_comm"} return - elif isinstance( - current_layer, - ( - layers.LayerNormalization, - layers.BatchNormalization, - layers.GroupNormalization, - ), - ): + elif isinstance(current_layer, (layers.LayerNormalization, layers.BatchNormalization, 
layers.GroupNormalization)): return - if hasattr(current_layer, "layers") and current_layer.layers: + if hasattr(current_layer, 'layers') and current_layer.layers: for sub_layer in current_layer.layers: _find_and_shard_layers( - sub_layer, - full_name, - module, - world_size, - state_rules, - output_rules, - processed_layers, + sub_layer, full_name, module, world_size, + state_rules, output_rules, processed_layers ) for attr_name in dir(current_layer): - if attr_name.startswith("__") and attr_name.endswith("__"): + if attr_name.startswith('__') and attr_name.endswith('__'): continue if hasattr(current_layer, attr_name): attr = getattr(current_layer, attr_name) if isinstance(attr, layers.Layer) and attr is not current_layer: _find_and_shard_layers( - attr, - full_name, - module, - world_size, - state_rules, - output_rules, - processed_layers, + attr, full_name, module, world_size, + state_rules, output_rules, processed_layers ) elif isinstance(attr, (list, tuple)): for item in attr: if isinstance(item, layers.Layer): _find_and_shard_layers( - item, - full_name, - module, - world_size, - state_rules, - output_rules, - processed_layers, + item, full_name, module, world_size, + state_rules, output_rules, processed_layers ) +def get_default_config_keras(module, device_ids): + """Generates a default tensor parallelism sharding configuration for a model. -def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: - """Generates a default tensor parallel sharding configuration for a model. - - This function serves as entry point for automatically creating a sharding + This function serves as the entry point for automatically creating a sharding plan for a given Keras model or layer. It initializes the rule dictionaries and starts the recursive layer traversal to populate them based on a default set of heuristics for common architectures like Transformers. 
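Note: to make the generated rule format concrete, below is a minimal NumPy-only sketch of how regex-keyed state rules of the shape produced here (a pattern mapped to a split action carrying a world size and a dimension) could be applied to a weight array. The ToySplit class and shard_variable helper are hypothetical illustrations, not part of this patch or of the Keras API; the real Split/LayoutMap machinery introduced by the patch performs this inside the framework.

# Minimal sketch (not part of this patch): applying regex-keyed state rules
# of the form {pattern: split_action} to weight arrays with NumPy.
# `ToySplit` and `shard_variable` are illustrative stand-ins only.
import re
import numpy as np


class ToySplit:
    def __init__(self, world_size, dim):
        self.world_size = world_size
        self.dim = dim

    def __call__(self, weight):
        # Split the array into `world_size` shards along `dim`.
        return np.array_split(weight, self.world_size, axis=self.dim)


def shard_variable(path, weight, state_rules, world_size=2):
    # Return per-device shards for the first rule whose pattern matches
    # `path`; replicate the weight on every device if no rule matches.
    for pattern, action in state_rules.items():
        if re.search(pattern, path):
            return action(weight)
    return [weight.copy() for _ in range(world_size)]


state_rules = {
    r"^mlp\.up\.kernel$": ToySplit(world_size=2, dim=1),    # column-parallel
    r"^mlp\.down\.kernel$": ToySplit(world_size=2, dim=0),  # row-parallel
}

kernel = np.ones((32, 128))
shards = shard_variable("mlp.up.kernel", kernel, state_rules)
print([s.shape for s in shards])  # [(32, 64), (32, 64)]

Column splits (dim=1) shard the output features so each device produces a slice of the activations, while row splits (dim=0) shard the input features and leave partial sums that must be combined, which is why the heuristics above pair column splits with "gather" output rules and row splits with "allreduce".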
@@ -285,7 +225,10 @@ def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: world_size=world_size, state_rules=state_rules, output_rules=output_rules, - processed_layers=processed_layers, + processed_layers=processed_layers ) - return ConfigKeras(state_rules=state_rules, output_rules=output_rules) + return LayoutMap( + state_rules=state_rules, + output_rules=output_rules + ) \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py index 228a2b184569..d8b8d5ad0482 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig_test.py +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -1,241 +1,139 @@ -import os - -import pytest - -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" - -from keras import Input -from keras import Model -from keras import layers -from keras.src import backend +import keras +from keras.src import layers from keras.src import testing -from keras.src.distribution import distributed_backend -from keras.src.distribution.tensor_parallel.autoconfig import ( - analyze_dense_layer_directly, -) -from keras.src.distribution.tensor_parallel.autoconfig import ( - get_default_config_keras, -) -from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras - - -@pytest.mark.skipif( - backend.backend() != "jax", - reason="Tensor Parallelism autoconfig tests are only for the JAX backend.", -) -class TestAutoConfigKeras(testing.TestCase): - def setUp(self): - """Set up the test case and common variables.""" - super().setUp() - device_info = distributed_backend.get_device_info() - self.world_size = device_info["device_count"] - self.device_ids = [f"cpu:{i}" for i in range(self.world_size)] - - self.assertGreater( - self.world_size, 1, "Distribution tests require more than 1 device." 
- ) - def _assert_split_keras_equal(self, rule1, rule2): - """Helper to compare two SplitKeras objects by their attributes.""" - self.assertIsInstance(rule1, SplitKeras) - self.assertIsInstance(rule2, SplitKeras) - self.assertDictEqual(vars(rule1), vars(rule2)) +from autoconfig import analyze_dense_layer_directly, get_default_config_keras - def _assert_rules_equal(self, actual_rules, expected_rules): - """Helper to compare two dictionaries of sharding rules.""" - self.assertSetEqual( - set(actual_rules.keys()), set(expected_rules.keys()) - ) - for key in expected_rules: - actual_val = actual_rules[key] - expected_val = expected_rules[key] - if isinstance(expected_val, SplitKeras): - self._assert_split_keras_equal(actual_val, expected_val) - else: - self.assertEqual(actual_val, expected_val) - - def test_analyze_dense_layer(self): - """Tests the direct analysis and classification of Dense layers.""" - up_proj_layer = layers.Dense(32) +class AutoConfigTest(testing.TestCase): + def test_analyze_dense_layer_directly(self): + """Tests the heuristic for classifying Dense layers.""" + up_proj_layer = layers.Dense(64, name="up") up_proj_layer.build(input_shape=(None, 16)) self.assertEqual( - analyze_dense_layer_directly(up_proj_layer, None, ""), - "up_projection", + analyze_dense_layer_directly(up_proj_layer, None, ""), "up_projection" ) - - down_proj_layer = layers.Dense(16) - down_proj_layer.build(input_shape=(None, 32)) + down_proj_layer = layers.Dense(16, name="down") + down_proj_layer.build(input_shape=(None, 64)) self.assertEqual( analyze_dense_layer_directly(down_proj_layer, None, ""), "down_projection", ) - - generic_layer = layers.Dense(20) - generic_layer.build(input_shape=(None, 16)) + generic_layer = layers.Dense(32, name="generic") + generic_layer.build(input_shape=(None, 28)) self.assertEqual( - analyze_dense_layer_directly(generic_layer, None, ""), - "generic_dense", + analyze_dense_layer_directly(generic_layer, None, ""), "generic_dense" ) - - def test_simple_mlp_sharding(self): - """Tests a simple MLP with up and down projection layers.""" - inputs = Input(shape=(64,)) - x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs) - outputs = layers.Dense(64, name="down_projection_layer", use_bias=True)( - x - ) - model = Model(inputs=inputs, outputs=outputs, name="simple_mlp") - - config = get_default_config_keras(model, self.device_ids) - - expected_state_rules = { - r"^simple_mlp.up_projection_layer.kernel$": SplitKeras( - self.world_size, 1, "column" - ), - r"^simple_mlp.up_projection_layer.bias$": SplitKeras( - self.world_size, 0, "column" - ), - r"^simple_mlp.down_projection_layer.kernel$": SplitKeras( - self.world_size, 0, "row" - ), - } - expected_output_rules = { - r"^simple_mlp.up_projection_layer$": {0: "gather"}, - r"^simple_mlp.down_projection_layer$": {0: "allreduce"}, - } - - self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) - - def test_generic_dense_sharding(self): - """Tests a generic Dense layer that isn't an up/down projection.""" - inputs = Input(shape=(64,)) - outputs = layers.Dense(80, name="generic_layer", use_bias=True)(inputs) - model = Model(inputs=inputs, outputs=outputs, name="generic_model") - - config = get_default_config_keras(model, self.device_ids) - - expected_state_rules = { - r"^generic_model.generic_layer.kernel$": SplitKeras( - self.world_size, 1, "column" - ), - r"^generic_model.generic_layer.bias$": SplitKeras( - self.world_size, 0, "column" 
- ), - } - expected_output_rules = { - r"^generic_model.generic_layer$": {0: "gather -1"} - } - - self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) - - def test_embedding_sharding(self): - """Tests an Embedding layer for vocabulary parallelism.""" - inputs = Input(shape=(10,), dtype="int32") - outputs = layers.Embedding( - input_dim=1000, output_dim=128, name="token_embedding" - )(inputs) - model = Model(inputs=inputs, outputs=outputs, name="embed_model") - - config = get_default_config_keras(model, self.device_ids) - - expected_state_rules = { - r"^embed_model.token_embedding\..*embeddings$": SplitKeras( - self.world_size, 1, "column" - ) - } - expected_output_rules = { - r"^embed_model.token_embedding$": {0: "no_comm"} - } - - self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) - - def test_einsum_dense_sharding(self): - """Tests the special handling for EinsumDense layers.""" - inputs = Input(shape=(64,)) - x = layers.EinsumDense( - "bh,hd->bd", output_shape=128, name="query_proj" - )(inputs) - outputs = layers.EinsumDense( - "bd,dh->bh", output_shape=64, name="attention_output" - )(x) - model = Model(inputs=inputs, outputs=outputs, name="einsum_model") - - config = get_default_config_keras(model, self.device_ids) - - expected_state_rules = { - r"^einsum_model.query_proj.kernel$": SplitKeras( - self.world_size, 1, "column" - ), - r"^einsum_model.attention_output.kernel$": SplitKeras( - self.world_size, 0, "row" - ), - } - expected_output_rules = { - r"^einsum_model.query_proj$": {0: "gather -1"}, - r"^einsum_model.attention_output$": {0: "allreduce"}, - } - - self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) - - def test_normalization_layers_ignored(self): - """Tests that normalization layers are correctly ignored.""" - inputs = Input(shape=(64,)) - x = layers.Dense(64, name="dense1", use_bias=True)(inputs) - x = layers.LayerNormalization(name="layernorm")(x) - outputs = layers.Dense(64, name="dense2", use_bias=True)(x) - model = Model(inputs=inputs, outputs=outputs, name="norm_model") - - config = get_default_config_keras(model, self.device_ids) - - for key in config.state_rules: - self.assertNotIn("layernorm", key) - for key in config.output_rules: - self.assertNotIn("layernorm", key) - - self.assertIn(r"^norm_model.dense1.kernel$", config.state_rules) - self.assertIn(r"^norm_model.dense2.kernel$", config.state_rules) - self.assertEqual(len(config.state_rules), 4) - self.assertEqual(len(config.output_rules), 2) - - def test_nested_model_sharding(self): - """Tests that the traversal logic correctly handles nested models.""" - inner_inputs = Input(shape=(32,)) - inner_outputs = layers.Dense(128, name="inner_dense", use_bias=True)( - inner_inputs - ) - inner_model = Model( - inputs=inner_inputs, outputs=inner_outputs, name="inner_block" + non_dense_layer = layers.LayerNormalization() + self.assertEqual( + analyze_dense_layer_directly(non_dense_layer, None, ""), "generic_dense" ) - outer_inputs = Input(shape=(32,)) - x = inner_model(outer_inputs) - outer_outputs = layers.Dense(32, name="outer_dense", use_bias=True)(x) - outer_model = Model( - inputs=outer_inputs, outputs=outer_outputs, name="outer_model" + def test_simple_mlp_model(self): + """Tests rule generation for a standard MLP block (like in a Transformer).""" + world_size = 
2 + devices = [f"gpu:{i}" for i in range(world_size)] + + model = keras.Sequential( + [ + keras.Input(shape=(32,)), + layers.Dense(128, name="mlp_up"), # Up-projection + layers.Dense(32, name="mlp_down"), # Down-projection + ], + name="mlp_block", ) - config = get_default_config_keras(outer_model, self.device_ids) - - expected_state_rules = { - r"^outer_model.inner_block.inner_dense.kernel$": SplitKeras( - self.world_size, 1, "column" - ), - r"^outer_model.inner_block.inner_dense.bias$": SplitKeras( - self.world_size, 0, "column" - ), - r"^outer_model.outer_dense.kernel$": SplitKeras( - self.world_size, 0, "row" - ), - } - expected_output_rules = { - r"^outer_model.inner_block.inner_dense$": {0: "gather"}, - r"^outer_model.outer_dense$": {0: "allreduce"}, - } - - self.maxDiff = None - self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) + layout_map = get_default_config_keras(model, devices) + state_rules = layout_map.state_rules + output_rules = layout_map.output_rules + + # Assertions for State (Weight) Sharding Rules + up_kernel_rule = state_rules["^mlp_block.mlp_up.kernel$"] + self.assertEqual(up_kernel_rule.world_size, world_size) + self.assertEqual(up_kernel_rule.dim, 1) + + down_kernel_rule = state_rules["^mlp_block.mlp_down.kernel$"] + self.assertEqual(down_kernel_rule.world_size, world_size) + self.assertEqual(down_kernel_rule.dim, 0) + + # Assertions for Output Communication Rules + # --- FIX: Removed trailing space. The source code generates "{0: 'gather'}" --- + self.assertEqual(output_rules["^mlp_block.mlp_up$"], {0: "gather"}) + self.assertEqual(output_rules["^mlp_block.mlp_down$"], {0: "allreduce"}) + + def test_model_with_embedding_and_einsumdense(self): + """Tests rule generation for Embedding and EinsumDense layers.""" + world_size = 4 + devices = [f"gpu:{i}" for i in range(world_size)] + + class SimpleTransformer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + # --- FIX: Add explicit `name` arguments to ensure layer names are predictable --- + self.embedding = layers.Embedding( + input_dim=1000, output_dim=64, name="embedding" + ) + self.qkv_proj = layers.EinsumDense( + "abc,cde->abde", + output_shape=(None, 3, 128), + bias_axes="de", + name="qkv_proj", + ) + self.attention_output = layers.EinsumDense( + "abde,cde->abc", + output_shape=(None, 64), + bias_axes="c", + name="attention_output", + ) + + def call(self, inputs): + x = self.embedding(inputs) + x = self.qkv_proj(x) + x = self.attention_output(x) + return x + + model = SimpleTransformer(name="transformer") + model(keras.ops.zeros((1, 10))) + + layout_map = get_default_config_keras(model, devices) + state_rules = layout_map.state_rules + + # --- Assertions --- + # --- FIX: The regex key must match what the provided autoconfig.py generates --- + expected_key = "^transformer.embedding\\..*embeddings$" + self.assertIn(expected_key, state_rules) + emb_rule = state_rules[expected_key] + self.assertEqual(emb_rule.world_size, world_size) + self.assertEqual(emb_rule.dim, 1) + + # These assertions are now correct because the layers are explicitly named + qkv_rule = state_rules["^transformer.qkv_proj.kernel$"] + self.assertEqual(qkv_rule.world_size, world_size) + self.assertEqual(qkv_rule.dim, 1) + + attn_out_rule = state_rules["^transformer.attention_output.kernel$"] + self.assertEqual(attn_out_rule.world_size, world_size) + self.assertEqual(attn_out_rule.dim, 0) + + def test_nested_model(self): + """Tests that 
the recursive traversal finds layers in nested models.""" + # This test is correct and requires no changes. + world_size = 2 + devices = [f"gpu:{i}" for i in range(world_size)] + inner_model = keras.Sequential( + [layers.Dense(64, name="inner_dense")], name="inner_block" + ) + outer_model = keras.Sequential( + [ + keras.Input(shape=(32,)), + layers.Dense(32, name="outer_dense_1"), + inner_model, + ], + name="outer_block", + ) + layout_map = get_default_config_keras(outer_model, devices) + state_rules = layout_map.state_rules + expected_key = "^outer_block.inner_block.inner_dense.kernel$" + self.assertIn(expected_key, state_rules) + inner_rule = state_rules[expected_key] + self.assertEqual(inner_rule.world_size, world_size) + self.assertEqual(inner_rule.dim, 1) \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index ca7f8e5d5fcc..d57dac16d4a5 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -1,12 +1,11 @@ import re -from typing import Any import numpy as np import keras from keras.src import ops from keras.src import optimizers -from keras.src.distribution import distributed_backend +from keras.src.backend import distributed_backend class CoordinatedOptimizer: @@ -36,13 +35,14 @@ class CoordinatedOptimizer: def __init__( self, - base_optimizer: optimizers.Optimizer, - world_size: int, - distributed_backend: str = "auto", - rank: int = 0, - shard_optimizer_states: bool = True, + base_optimizer, + world_size, + distributed_backend="auto", + rank=0, + shard_optimizer_states=True, tensor_parallel_config=None, ): + """Initializes the CoordinatedOptimizer.""" self.base_optimizer = base_optimizer self.world_size = world_size self.shard_optimizer_states = shard_optimizer_states @@ -114,9 +114,7 @@ def _initialize_sharded_states(self): self.base_optimizer.iterations, dim=0 ) - def _partition_state( - self, state_variable: any, dim: int - ) -> list[np.ndarray]: + def _partition_state(self, state_variable, dim): """Splits a single state variable numpy array into chunks. If the variable cannot be split along the given dimension, it is @@ -136,41 +134,20 @@ def _partition_state( else: return [np.copy(state_array) for _ in range(self.world_size)] - def get_config(self) -> dict[str, Any]: - return { - "base_optimizer": self.base_optimizer.get_config(), - "world_size": self.world_size, - "shard_optimizer_states": self.shard_optimizer_states, - } - - def apply_gradients( - self, gradients_and_vars: list[list[tuple]], shard_models: list - ): - """Coordinates gradient synchronization and application. - - This method first synchronizes gradients across all shards based on - tensor parallelism rules. Then, it applies the gradients using either - sharded optimizer states or replicated states. + def apply_gradients(self, grads_and_vars, shard_models): + """ + Applies gradients to the model variables by first synchronizing them + and then applying them using either sharded or replicated optimizer + states. Args: - gradients_and_vars: A list of lists, where each inner list contains - (gradient, variable) tuples for a specific model shard. + grads_and_vars: A list of (gradient, variable) lists from all + shards. shard_models: A list of the sharded model instances. - - Raises: - ValueError: If the number of gradient sets does not match the - world size. 
""" - if len(gradients_and_vars) != self.world_size: - error_msg = ( - f"Expected {self.world_size} gradient sets, " - f"got {len(gradients_and_vars)}" - ) - raise ValueError(error_msg) - - synchronized_gradients = self._synchronize_gradients(gradients_and_vars) + synchronized_gradients = self._synchronize_gradients(grads_and_vars) - if self.shard_optimizer_states and self.sharded_states: + if self.shard_optimizer_states: self._apply_gradients_with_sharded_states( synchronized_gradients, shard_models ) @@ -180,7 +157,7 @@ def apply_gradients( ) def _apply_gradients_with_replicated_states( - self, synchronized_gradients: list[list[tuple]], shard_models: list + self, synchronized_gradients, shard_models ): """Averages gradients across all shards and applies them once. @@ -219,9 +196,14 @@ def _apply_gradients_with_replicated_states( self.base_optimizer.apply_gradients(averaged_grads_and_vars) def _apply_gradients_with_sharded_states( - self, synchronized_gradients: list[list[tuple]], shard_models: list + self, synchronized_gradients, shard_models ): - """Applies gradients to each shard using its local optimizer state.""" + """Applies gradients to each shard using its local optimizer state. + + Args: + synchronized_gradients: The gradients after synchronization. + shard_models: The list of sharded models. + """ for shard_idx in range(self.world_size): local_states = self._get_local_optimizer_states(shard_idx) shard_optimizer = shard_models[shard_idx].optimizer @@ -233,8 +215,16 @@ def _apply_gradients_with_sharded_states( self._update_global_sharded_states(shard_optimizer, shard_idx) - def _get_local_optimizer_states(self, shard_idx: int) -> dict[str, Any]: - """Constructs the state dictionary for a single shard.""" + def _get_local_optimizer_states(self, shard_idx): + """Constructs the state dictionary for a single shard. + + Args: + shard_idx: The index of the shard for which to retrieve the state. + + Returns: + A dictionary containing the local optimizer state for the specified + shard. + """ local_states = {} for state_name, state_value in self.sharded_states.items(): if isinstance(state_value, dict): @@ -247,15 +237,20 @@ def _get_local_optimizer_states(self, shard_idx: int) -> dict[str, Any]: local_states[state_name] = state_value[shard_idx] return local_states - def _update_optimizer_internal_state(self, optimizer, local_states: dict): - """Assigns local sharded state values to the optimizer's variables.""" + def _update_optimizer_internal_state(self, optimizer, local_states): + """Assigns local sharded state values to the optimizer's variables. + + Args: + optimizer: The optimizer instance for a specific shard. + local_states: The dictionary of local states for that shard. 
+ """ if not optimizer.built: return for var in optimizer.variables: if var is optimizer.iterations: if "iterations" in local_states: - ops.assign(var, local_states["iterations"]) + var.assign(local_states["iterations"]) continue param = self._state_variable_to_parameter.get(var.path, None) @@ -269,10 +264,15 @@ def _update_optimizer_internal_state(self, optimizer, local_states: dict): ): local_param_state = local_states[slot_name][param.path] if var.shape == local_param_state.shape: - ops.assign(var, local_param_state) + var.assign(local_param_state) - def _update_global_sharded_states(self, optimizer, shard_idx: int): - """Updates the main sharded_states dictionary after a gradient step.""" + def _update_global_sharded_states(self, optimizer, shard_idx): + """Updates the main sharded_states dictionary after a gradient step. + + Args: + optimizer: The optimizer instance for a specific shard. + shard_idx: The index of the shard that was updated. + """ if not optimizer.built: return @@ -296,9 +296,7 @@ def _update_global_sharded_states(self, optimizer, shard_idx: int): ops.convert_to_numpy(var) ) - def _synchronize_gradients( - self, gradients_and_vars: list[list[tuple]] - ) -> list[list[tuple]]: + def _synchronize_gradients(self, gradients_and_vars): """Synchronizes gradients across shards based on tensor parallel rules. Specifically, it performs an all-reduce operation on gradients of @@ -349,7 +347,7 @@ def _synchronize_gradients( ) return gradients_and_vars - def _allreduce_gradients(self, gradients: list[Any]) -> list[Any]: + def _allreduce_gradients(self, gradients): """Performs a mean all-reduce operation on a list of gradients. If a distributed backend is available, it uses it. Otherwise, it @@ -379,35 +377,37 @@ def _allreduce_gradients(self, gradients: list[Any]) -> list[Any]: mean_grad = ops.mean(stacked_grads, axis=0) return [mean_grad for _ in range(len(gradients))] - def get_weights(self) -> list[np.ndarray]: - """Returns the weights of the base optimizer.""" + def get_weights(self): + """Returns the weights of the base optimizer. + + Returns: + A list of NumPy arrays representing the optimizer's state variables. + """ return [ ops.convert_to_numpy(var) for var in self.base_optimizer.variables ] - def set_weights(self, weights: list[np.ndarray]): - """Sets the weights of the base optimizer.""" + def set_weights(self, weights): + """Sets the weights of the base optimizer. + + Args: + weights: A list of NumPy arrays to set as the optimizer's state. + """ self.base_optimizer.set_weights(weights) - def enable_optimizer_state_sharding(self, variables: list): + def enable_optimizer_state_sharding(self, variables): """Enables and initializes optimizer state sharding. This method is called from `build()`, which is guarded from running multiple times. We can assume this should always execute. + + Args: + variables: A list of model variables to be optimized. """ self.shard_optimizer_states = True self._variables = variables self._initialize_sharded_states() - def disable_optimizer_state_sharding(self): - """Disables sharding and clears any sharded states. - - This reverts the optimizer to using a single, replicated state. - """ - if self.shard_optimizer_states: - self.shard_optimizer_states = False - self.sharded_states = {} - class TensorParallelOptimizer(optimizers.Optimizer): """A Keras Optimizer wrapper for tensor-parallel distributed training. 
@@ -457,11 +457,12 @@ class TensorParallelOptimizer(optimizers.Optimizer): def __init__( self, - base_optimizer: optimizers.Optimizer, - world_size: int, - distributed_backend: str = "auto", + base_optimizer, + world_size, + distributed_backend="auto", tensor_parallel_config=None, ): + """Initializes the TensorParallelOptimizer.""" if isinstance(base_optimizer, str): base_optimizer_instance = optimizers.get(base_optimizer) else: @@ -488,7 +489,7 @@ def __init__( tensor_parallel_config=tensor_parallel_config, ) - def apply_gradients(self, grads_and_vars: list, **kwargs): + def apply_gradients(self, grads_and_vars, **kwargs): """Applies gradients to the model variables. If `grads_and_vars` is a list of lists, it's assumed to be from @@ -514,7 +515,12 @@ def apply_gradients(self, grads_and_vars: list, **kwargs): else: self.base_optimizer.apply_gradients(grads_and_vars) - def get_config(self) -> dict[str, Any]: + def get_config(self): + """Returns the configuration of the optimizer. + + Returns: + A dictionary containing the optimizer's configuration. + """ from keras.src import saving config = super().get_config() @@ -533,20 +539,35 @@ def get_config(self) -> dict[str, Any]: return config def update_step(self, gradient, variable, *args, **kwargs): + """Performs a single optimization step. + + Delegates the update step to the base optimizer if it has a custom + `update_step` implementation, otherwise falls back to the parent + optimizer's logic. + + Args: + gradient: The gradient tensor. + variable: The variable to be updated. + *args: Positional arguments passed to the update function. + **kwargs: Keyword arguments passed to the update function. + """ if hasattr(self.base_optimizer, "update_step"): - try: - return self.base_optimizer.update_step( - gradient, variable, *args, **kwargs - ) - except TypeError: - return self.base_optimizer.update_step(gradient, variable) - try: - return super().update_step(gradient, variable, *args, **kwargs) - except TypeError: - return super().update_step(gradient, variable) + return self.base_optimizer.update_step( + gradient, variable, *args, **kwargs + ) + + return super().update_step(gradient, variable, *args, **kwargs) @classmethod - def from_config(cls, config: dict[str, Any]) -> "TensorParallelOptimizer": + def from_config(cls, config): + """Creates an optimizer from its configuration. + + Args: + config: A Python dictionary, typically the output of `get_config`. + + Returns: + A `TensorParallelOptimizer` instance. + """ from keras.src import saving base_optimizer_config = config.pop("base_optimizer") @@ -560,7 +581,7 @@ def from_config(cls, config: dict[str, Any]) -> "TensorParallelOptimizer": return cls(base_optimizer=base_optimizer, **init_kwargs) - def build(self, variables: list): + def build(self, variables): """Builds the optimizer and initializes sharded states. This method is called the first time the optimizer is used. 
It builds @@ -575,40 +596,58 @@ def build(self, variables: list): self.base_optimizer.build(variables) if variables: + iterations = self.base_optimizer.iterations + original_iterations_val = None + if iterations is not None: + original_iterations_val = ops.convert_to_numpy(iterations.value) + zero_grads = [ops.zeros_like(v) for v in variables] self.base_optimizer.apply_gradients(zip(zero_grads, variables)) + if iterations is not None and original_iterations_val is not None: + iterations.assign(original_iterations_val) + self.coordinated_optimizer.enable_optimizer_state_sharding(variables) super().build(variables) - def get_weights(self) -> list[np.ndarray]: - """Returns the weights of the base optimizer.""" + def get_weights(self): + """Returns the weights of the base optimizer. + + Returns: + A list of NumPy arrays representing the optimizer's state variables. + """ return self.coordinated_optimizer.get_weights() - def set_weights(self, weights: list[np.ndarray]): - """Sets the weights of the base optimizer.""" + def set_weights(self, weights): + """Sets the weights of the base optimizer. + + Args: + weights: A list of NumPy arrays to set as the optimizer's state. + """ self.coordinated_optimizer.set_weights(weights) @property - def variables(self) -> list: - """Returns the list of variables from the base optimizer.""" + def variables(self): + """Returns the list of variables from the base optimizer. + + Returns: + A list of state variables of the base optimizer. + """ return self.base_optimizer.variables @property - def learning_rate(self) -> Any: + def learning_rate(self): """Provides access to the learning rate of the base optimizer.""" return self.base_optimizer.learning_rate @learning_rate.setter def learning_rate(self, value): + """Sets the learning rate of the base optimizer.""" self.base_optimizer.learning_rate = value @property def iterations(self): """ - Returns the training iteration count, compensating for the initial - dummy step in the build method. + Returns the training iteration count directly from the base optimizer. """ - if self.base_optimizer.iterations is None: - return None - return self.base_optimizer.iterations - 1 + return self.base_optimizer.iterations \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py index 39cce46de72c..fc18730c484a 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -1,3 +1,6 @@ +import os +os.environ["KERAS_BACKEND"]="jax" + import numpy as np import pytest @@ -6,18 +9,17 @@ from keras.src import optimizers from keras.src import testing -if keras.backend.backend() == "jax": - from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( - CoordinatedOptimizer, - ) - from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( - TensorParallelOptimizer, - ) - +from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + CoordinatedOptimizer, +) +from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + TensorParallelOptimizer +) +from keras.src import backend @pytest.mark.skipif( - keras.backend.backend() != "jax", - reason="This test is JAX-specific.", + backend.backend() != "jax", + reason="This test is only for the JAX backend." 
) class CoordinatedOptimizerTest(testing.TestCase): def _get_simple_model(self): @@ -79,9 +81,10 @@ def apply_gradients(self, grads_and_vars, *args, **kwargs): coord.apply_gradients(mock_grads, []) self.assertEqual(optimizer.apply_gradients_call_count, 1) + grad_numpy = ops.convert_to_numpy(optimizer.received_grads[0]) self.assertAllClose( - optimizer.received_grads[0], - np.ones_like(optimizer.received_grads[0]) * 2.5, + grad_numpy, + np.ones_like(grad_numpy) * 2.5, ) def test_init_from_string(self): @@ -177,4 +180,4 @@ def test_sharding_with_prefixed_variable_names(self): kernel_path = dense_output_kernel.path.replace("/", "_") momentum_path = f"{optimizer_name}/{kernel_path}_momentum" - self.assertIs(state_to_param[momentum_path], dense_output_kernel) + self.assertIs(state_to_param[momentum_path], dense_output_kernel) \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/sharding_keras.py b/keras/src/distribution/tensor_parallel/sharding_keras.py deleted file mode 100644 index 012234cb77f4..000000000000 --- a/keras/src/distribution/tensor_parallel/sharding_keras.py +++ /dev/null @@ -1,82 +0,0 @@ -from typing import Any -from typing import Collection -from typing import Dict -from typing import List -from typing import Sequence - -from keras.src.distribution.tensor_parallel.config import ConfigKeras - - -class ShardedKeras: - """ - Manages sharded parameters for Keras models. - """ - - def __init__( - self, - model_shards, - replicated_param_names: Collection[str], - tensor_parallel_config: ConfigKeras, - devices: Sequence[str], - output_device_index: int, - ): - """ - Initialize the sharding manager. - - Args: - model_shards: List of model shards - replicated_param_names: Names of parameters that are replicated - tensor_parallel_config: Tensor parallel configuration - devices: List of device IDs - output_device_index: Index of the output device - """ - self.model_shards = model_shards - self.replicated_param_names = set(replicated_param_names) - self.tensor_parallel_config = tensor_parallel_config - self.devices = devices - self.output_device_index = output_device_index - - def get_shard_parameters(self, shard_index: int) -> Dict[str, Any]: - """ - Get parameters for a specific shard. - - Args: - shard_index: Index of the shard - - Returns: - Dictionary of parameter names to values - """ - if shard_index >= len(self.model_shards): - raise ValueError(f"Shard index {shard_index} out of range") - - shard = self.model_shards[shard_index] - params = {} - - for weight in shard.weights: - param_name = weight.path.replace("/", ".") - params[param_name] = weight - - return params - - def get_all_parameters(self) -> List[Dict[str, Any]]: - """ - Get parameters from all shards. - - Returns: - List of parameter dictionaries for each shard - """ - return [ - self.get_shard_parameters(i) for i in range(len(self.model_shards)) - ] - - def apply_sharding(self): - """ - Apply sharding to the model parameters. - """ - pass - - def unshard_parameters(self): - """ - Unshard parameters back to their original form. - """ - pass diff --git a/keras/src/distribution/tensor_parallel/tensor_layout.py b/keras/src/distribution/tensor_parallel/tensor_layout.py new file mode 100644 index 000000000000..442dbe2ca673 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/tensor_layout.py @@ -0,0 +1,164 @@ +import keras + +class LayoutAction: + """Abstract base class for actions that transform tensors for distribution. 
+ + A LayoutAction defines a rule for how a single tensor should be physically + represented across multiple devices. It includes a forward operation (`__call__`) + to shard the tensor and a reverse operation (`undo`) to reconstruct it. + """ + def __call__(self, tensor, rank): + """Applies the distribution action to a tensor for a specific worker. + + Args: + tensor: The input tensor to be distributed. + rank: The integer rank of the current worker/device. + + Raises: + NotImplementedError: This is an abstract method and must be + implemented by subclasses. + + Returns: + A shard or transformation of the input tensor specific to the given + rank. + """ + raise NotImplementedError + + def undo(self, tensors): + """Reverses the distribution action, reconstructing the original tensor. + + Args: + tensors: A sequence of tensor shards, one from each worker. + + Raises: + NotImplementedError: This is an abstract method and must be + implemented by subclasses. + + Returns: + The reconstructed, single tensor. + """ + raise NotImplementedError + + +class _ConcatenateMixin: + """A mixin class providing a common `undo` method via concatenation. + + This class is intended to be used as a mixin for `LayoutAction` subclasses + that can be undone by simple concatenation along a specified axis. + """ + def undo(self, tensors): + """Concatenates a sequence of tensors to reconstruct the original tensor. + + Args: + tensors: A sequence of tensor shards, one from each worker. + + Returns: + The single tensor reconstructed by concatenating the shards. + """ + if self.dim == -1: + dim = keras.ops.ndim(tensors[0]) - 1 + else: + dim = self.dim + return keras.ops.concatenate(tensors, axis=dim) + + +class Split(_ConcatenateMixin, LayoutAction): + """Splits a tensor into shards along a specified dimension. + + This is an internal utility used by a higher-level distribution API. + It implements sharding by slicing a tensor along one of its axes. + It handles cases where the dimension size is not perfectly divisible by the + number of workers by distributing the remainder elements one by one to the + first few workers. + + The `undo` operation is provided by the `_ConcatenateMixin`. + """ + def __init__(self, world_size, dim, sharding_type="auto"): + """Initializes the Split action. + + Args: + world_size: The total number of workers/shards. + dim: The dimension along which to split the tensor. If -1, the + last dimension is used. + sharding_type: If `dim` is -1, this can be 'row' (dim=0) or + 'column' (dim=1) to infer the split axis for 2D tensors. + Defaults to "auto". + """ + super().__init__() + self.world_size = world_size + self.dim = dim + self.sharding_type = sharding_type + + if dim == -1 and sharding_type != "auto": + if sharding_type == "row": + self.dim = 0 + elif sharding_type == "column": + self.dim = 1 + + def __call__(self, tensor, rank): + """Splits the tensor and returns the shard corresponding to the rank. + + This method calculates the correct slice of the tensor for a given + worker rank, handling uneven distributions gracefully. + + Args: + tensor: The full tensor to be sharded. + rank: The rank of the worker for which to get the shard. + + Returns: + A tensor shard corresponding to the given rank. 
+ """ + if self.dim == -1: + dim = keras.ops.ndim(tensor) - 1 + else: + dim = self.dim + + total_size = tensor.shape[dim] + split_size = total_size // self.world_size + remainder = total_size % self.world_size + + start_idx = rank * split_size + min(rank, remainder) + end_idx = start_idx + split_size + (1 if rank < remainder else 0) + + slices = [slice(None)] * keras.ops.ndim(tensor) + slices[dim] = slice(start_idx, end_idx) + return tensor[tuple(slices)] + + +class LayoutMap: + """A mapping that defines layout rules for model states and outputs. + + This is an internal configuration object used to hold layout rules for + how model variables and layer outputs should be distributed across a set + of devices. It acts as a container for `LayoutAction` instances. + + Attributes: + state_rules: A dictionary mapping variable names or patterns to + `LayoutAction` instances. + output_rules: A dictionary mapping layer output names or + patterns to `LayoutAction` instances. + """ + def __init__(self, state_rules, output_rules): + """Initializes the LayoutMap. + + Args: + state_rules: A dictionary of distribution rules for model states. + output_rules: A dictionary of distribution rules for model outputs. + """ + self.state_rules = state_rules + self.output_rules = output_rules + + def create_collective_ops(self, devices): + """Creates the necessary collective communication operations. + + This method is a placeholder for backend-specific logic that would + translate the layout rules into actual communication primitives + (e.g., all-gather, reduce-scatter). + + Args: + devices: A sequence of device identifiers. + + Returns: + The `LayoutMap` instance itself, allowing for method chaining. + """ + return self \ No newline at end of file From 3a4af335f4aecfee03baf36594e114c76431cb41 Mon Sep 17 00:00:00 2001 From: Suhana Date: Thu, 16 Oct 2025 00:06:57 +0530 Subject: [PATCH 11/12] Testing PR1&2 --- keras/src/backend/__init__.py | 2 +- keras/src/backend/jax/__init__.py | 2 +- keras/src/backend/jax/distributed_backend.py | 7 +- .../tensor_parallel/autoconfig.py | 140 ++++++++++++------ .../tensor_parallel/autoconfig_test.py | 25 ++-- .../tensor_parallel/coordinated_optimizer.py | 2 +- .../coordinated_optimizer_test.py | 14 +- .../tensor_parallel/tensor_layout.py | 20 +-- 8 files changed, 125 insertions(+), 87 deletions(-) diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index c89e7d82c90a..df4d6ced9d26 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -79,4 +79,4 @@ class name_scope(backend_name_scope): @keras_export("keras.device") def device(device_name): - return device_scope(device_name) # noqa: F405 \ No newline at end of file + return device_scope(device_name) # noqa: F405 diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py index 0f703483cb28..0a275fb70cf1 100644 --- a/keras/src/backend/jax/__init__.py +++ b/keras/src/backend/jax/__init__.py @@ -1,7 +1,7 @@ from keras.src.backend.config import is_nnx_enabled from keras.src.backend.jax import core -from keras.src.backend.jax import distribution_lib from keras.src.backend.jax import distributed_backend +from keras.src.backend.jax import distribution_lib from keras.src.backend.jax import image from keras.src.backend.jax import linalg from keras.src.backend.jax import math diff --git a/keras/src/backend/jax/distributed_backend.py b/keras/src/backend/jax/distributed_backend.py index 8fd999784d52..a20cdf68605c 100644 --- a/keras/src/backend/jax/distributed_backend.py 
+++ b/keras/src/backend/jax/distributed_backend.py @@ -1,6 +1,7 @@ import jax import jax.lax as lax + def get_device_info(): """Retrieves information about the available JAX devices. @@ -72,10 +73,10 @@ def all_gather(x, axis, axis_name="model"): This function assumes it is called within a `pjit` context. It takes the local shard `x` from each device along the `axis_name` of the mesh and concatenates them along the specified tensor `axis` to form a - single, larger tensor that is then replicated on all participating devices. + single, larger tensor that is then replicated on participating devices. Args: - x (jax.Array): The input JAX array (tensor) shard on the local device. + x (jax.Array): The input JAX array (tensor) shard on local device. axis (int): The tensor axis along which to concatenate the gathered shards. axis_name (str, optional): The name of the mesh axis to gather @@ -90,4 +91,4 @@ def all_gather(x, axis, axis_name="model"): return { "all_reduce": all_reduce, "all_gather": all_gather, - } \ No newline at end of file + } diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py index 636775bc14e2..708d6d603cc6 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig.py +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -25,41 +25,45 @@ def analyze_dense_layer_directly(layer, module, prefix): from keras.src import layers if not isinstance(layer, layers.Dense): - return 'generic_dense' + return "generic_dense" input_dim = None output_dim = None - if hasattr(layer, 'kernel') and layer.kernel is not None: + if hasattr(layer, "kernel") and layer.kernel is not None: kernel_shape = layer.kernel.shape if len(kernel_shape) == 2: input_dim = kernel_shape[0] output_dim = kernel_shape[1] if input_dim is None or output_dim is None: - if hasattr(layer, 'units'): + if hasattr(layer, "units"): output_dim = layer.units else: - return 'generic_dense' + return "generic_dense" - if hasattr(layer, 'input_shape') and layer.input_shape and len(layer.input_shape) > 1: + if ( + hasattr(layer, "input_shape") + and layer.input_shape + and len(layer.input_shape) > 1 + ): input_dim = layer.input_shape[-1] else: - return 'generic_dense' + return "generic_dense" if not input_dim or not output_dim: - return 'generic_dense' + return "generic_dense" expansion_threshold = 1.5 is_expansion = output_dim > input_dim * expansion_threshold is_contraction = input_dim > output_dim * expansion_threshold if is_expansion: - return 'up_projection' + return "up_projection" elif is_contraction: - return 'down_projection' + return "down_projection" else: - return 'generic_dense' + return "generic_dense" def _find_and_shard_layers( @@ -92,9 +96,9 @@ def _find_and_shard_layers( prefix (str): The hierarchical name prefix for the `current_layer`. module: The top-level Keras model or layer being configured. world_size (int): The total number of devices for sharding. - state_rules (Dict[str, Any]): A dictionary to be populated with rules for - sharding layer weights (state). Keys are regex patterns matching - weight names, values are `SplitKeras` actions. + state_rules (Dict[str, Any]): A dictionary to be populated with rules + for sharding layer weights (state). Keys are regex patterns + matching weight names, values are `SplitKeras` actions. output_rules (Dict[str, Any]): A dictionary to be populated with rules for handling layer outputs. Keys are regex patterns matching layer names, values describe the communication op (e.g., 'allreduce'). 
@@ -111,86 +115,133 @@ def _find_and_shard_layers( full_name = f"{prefix}.{name}" if prefix else name if isinstance(current_layer, layers.Dense): - mlp_type = analyze_dense_layer_directly(current_layer, module, full_name) + mlp_type = analyze_dense_layer_directly( + current_layer, module, full_name + ) - if mlp_type == 'up_projection': - state_rules[f"^{full_name}.kernel$"] = Split(world_size, 1, "column") + if mlp_type == "up_projection": + state_rules[f"^{full_name}.kernel$"] = Split( + world_size, 1, "column" + ) if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = Split(world_size, 0, "column") + state_rules[f"^{full_name}.bias$"] = Split( + world_size, 0, "column" + ) output_rules[f"^{full_name}$"] = {0: "gather"} - elif mlp_type == 'down_projection': + elif mlp_type == "down_projection": state_rules[f"^{full_name}.kernel$"] = Split(world_size, 0, "row") output_rules[f"^{full_name}$"] = {0: "allreduce"} else: - state_rules[f"^{full_name}.kernel$"] = Split(world_size, 1, "column") + state_rules[f"^{full_name}.kernel$"] = Split( + world_size, 1, "column" + ) if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = Split(world_size, 0, "column") + state_rules[f"^{full_name}.bias$"] = Split( + world_size, 0, "column" + ) output_rules[f"^{full_name}$"] = {0: "gather -1"} return elif isinstance(current_layer, layers.EinsumDense): if "attention_output" in full_name: state_rules[f"^{full_name}.kernel$"] = Split(world_size, 0, "row") - if hasattr(current_layer, 'bias') and current_layer.bias is not None: + if ( + hasattr(current_layer, "bias") + and current_layer.bias is not None + ): pass output_rules[f"^{full_name}$"] = {0: "allreduce"} else: - state_rules[f"^{full_name}.kernel$"] = Split(world_size, 1, "column") - if hasattr(current_layer, 'bias') and current_layer.bias is not None: - state_rules[f"^{full_name}.bias$"] = Split(world_size, 0, "column") + state_rules[f"^{full_name}.kernel$"] = Split( + world_size, 1, "column" + ) + if ( + hasattr(current_layer, "bias") + and current_layer.bias is not None + ): + state_rules[f"^{full_name}.bias$"] = Split( + world_size, 0, "column" + ) output_rules[f"^{full_name}$"] = {0: "gather -1"} return elif isinstance(current_layer, (layers.Embedding,)): - if hasattr(current_layer, 'token_embedding') or hasattr(current_layer, 'position_embedding'): + if hasattr(current_layer, "token_embedding") or hasattr( + current_layer, "position_embedding" + ): pass else: weight_name = None - if hasattr(current_layer, 'embeddings'): - weight_name = 'embeddings' - elif hasattr(current_layer, 'position_embeddings'): - weight_name = 'position_embeddings' + if hasattr(current_layer, "embeddings"): + weight_name = "embeddings" + elif hasattr(current_layer, "position_embeddings"): + weight_name = "position_embeddings" if weight_name: - state_rules[f"^{full_name}\\..*{weight_name}$"] = Split(world_size, 1, "column") + state_rules[f"^{full_name}\\..*{weight_name}$"] = Split( + world_size, 1, "column" + ) output_rules[f"^{full_name}$"] = {0: "no_comm"} return - elif isinstance(current_layer, (layers.LayerNormalization, layers.BatchNormalization, layers.GroupNormalization)): + elif isinstance( + current_layer, + ( + layers.LayerNormalization, + layers.BatchNormalization, + layers.GroupNormalization, + ), + ): return - if hasattr(current_layer, 'layers') and current_layer.layers: + if hasattr(current_layer, "layers") and current_layer.layers: for sub_layer in current_layer.layers: _find_and_shard_layers( - sub_layer, full_name, module, world_size, - 
state_rules, output_rules, processed_layers + sub_layer, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, ) for attr_name in dir(current_layer): - if attr_name.startswith('__') and attr_name.endswith('__'): + if attr_name.startswith("__") and attr_name.endswith("__"): continue if hasattr(current_layer, attr_name): attr = getattr(current_layer, attr_name) if isinstance(attr, layers.Layer) and attr is not current_layer: _find_and_shard_layers( - attr, full_name, module, world_size, - state_rules, output_rules, processed_layers + attr, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, ) elif isinstance(attr, (list, tuple)): for item in attr: if isinstance(item, layers.Layer): _find_and_shard_layers( - item, full_name, module, world_size, - state_rules, output_rules, processed_layers + item, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, ) + def get_default_config_keras(module, device_ids): - """Generates a default tensor parallelism sharding configuration for a model. + """Generates default tensor parallelism sharding configuration for a model. - This function serves as the entry point for automatically creating a sharding + This function serves as entry point for automatically creating a sharding plan for a given Keras model or layer. It initializes the rule dictionaries and starts the recursive layer traversal to populate them based on a default set of heuristics for common architectures like Transformers. @@ -225,10 +276,7 @@ def get_default_config_keras(module, device_ids): world_size=world_size, state_rules=state_rules, output_rules=output_rules, - processed_layers=processed_layers + processed_layers=processed_layers, ) - return LayoutMap( - state_rules=state_rules, - output_rules=output_rules - ) \ No newline at end of file + return LayoutMap(state_rules=state_rules, output_rules=output_rules) diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py index d8b8d5ad0482..3c7594a9a6e8 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig_test.py +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -1,8 +1,10 @@ +from autoconfig import analyze_dense_layer_directly +from autoconfig import get_default_config_keras + import keras from keras.src import layers from keras.src import testing -from autoconfig import analyze_dense_layer_directly, get_default_config_keras class AutoConfigTest(testing.TestCase): def test_analyze_dense_layer_directly(self): @@ -10,7 +12,8 @@ def test_analyze_dense_layer_directly(self): up_proj_layer = layers.Dense(64, name="up") up_proj_layer.build(input_shape=(None, 16)) self.assertEqual( - analyze_dense_layer_directly(up_proj_layer, None, ""), "up_projection" + analyze_dense_layer_directly(up_proj_layer, None, ""), + "up_projection", ) down_proj_layer = layers.Dense(16, name="down") down_proj_layer.build(input_shape=(None, 64)) @@ -21,15 +24,17 @@ def test_analyze_dense_layer_directly(self): generic_layer = layers.Dense(32, name="generic") generic_layer.build(input_shape=(None, 28)) self.assertEqual( - analyze_dense_layer_directly(generic_layer, None, ""), "generic_dense" + analyze_dense_layer_directly(generic_layer, None, ""), + "generic_dense", ) non_dense_layer = layers.LayerNormalization() self.assertEqual( - analyze_dense_layer_directly(non_dense_layer, None, ""), "generic_dense" + analyze_dense_layer_directly(non_dense_layer, None, ""), + 
"generic_dense", ) def test_simple_mlp_model(self): - """Tests rule generation for a standard MLP block (like in a Transformer).""" + """Tests rule generation for a standard MLP block.""" world_size = 2 devices = [f"gpu:{i}" for i in range(world_size)] @@ -56,7 +61,6 @@ def test_simple_mlp_model(self): self.assertEqual(down_kernel_rule.dim, 0) # Assertions for Output Communication Rules - # --- FIX: Removed trailing space. The source code generates "{0: 'gather'}" --- self.assertEqual(output_rules["^mlp_block.mlp_up$"], {0: "gather"}) self.assertEqual(output_rules["^mlp_block.mlp_down$"], {0: "allreduce"}) @@ -68,7 +72,6 @@ def test_model_with_embedding_and_einsumdense(self): class SimpleTransformer(layers.Layer): def __init__(self, **kwargs): super().__init__(**kwargs) - # --- FIX: Add explicit `name` arguments to ensure layer names are predictable --- self.embedding = layers.Embedding( input_dim=1000, output_dim=64, name="embedding" ) @@ -84,7 +87,7 @@ def __init__(self, **kwargs): bias_axes="c", name="attention_output", ) - + def call(self, inputs): x = self.embedding(inputs) x = self.qkv_proj(x) @@ -97,15 +100,12 @@ def call(self, inputs): layout_map = get_default_config_keras(model, devices) state_rules = layout_map.state_rules - # --- Assertions --- - # --- FIX: The regex key must match what the provided autoconfig.py generates --- expected_key = "^transformer.embedding\\..*embeddings$" self.assertIn(expected_key, state_rules) emb_rule = state_rules[expected_key] self.assertEqual(emb_rule.world_size, world_size) self.assertEqual(emb_rule.dim, 1) - # These assertions are now correct because the layers are explicitly named qkv_rule = state_rules["^transformer.qkv_proj.kernel$"] self.assertEqual(qkv_rule.world_size, world_size) self.assertEqual(qkv_rule.dim, 1) @@ -116,7 +116,6 @@ def call(self, inputs): def test_nested_model(self): """Tests that the recursive traversal finds layers in nested models.""" - # This test is correct and requires no changes. world_size = 2 devices = [f"gpu:{i}" for i in range(world_size)] inner_model = keras.Sequential( @@ -136,4 +135,4 @@ def test_nested_model(self): self.assertIn(expected_key, state_rules) inner_rule = state_rules[expected_key] self.assertEqual(inner_rule.world_size, world_size) - self.assertEqual(inner_rule.dim, 1) \ No newline at end of file + self.assertEqual(inner_rule.dim, 1) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index d57dac16d4a5..62039e2e121f 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -650,4 +650,4 @@ def iterations(self): """ Returns the training iteration count directly from the base optimizer. 
""" - return self.base_optimizer.iterations \ No newline at end of file + return self.base_optimizer.iterations diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py index fc18730c484a..0f96db54fbe2 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -1,25 +1,25 @@ import os -os.environ["KERAS_BACKEND"]="jax" + +os.environ["KERAS_BACKEND"] = "jax" import numpy as np import pytest import keras from keras import ops +from keras.src import backend from keras.src import optimizers from keras.src import testing - from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( CoordinatedOptimizer, ) from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( - TensorParallelOptimizer + TensorParallelOptimizer, ) -from keras.src import backend + @pytest.mark.skipif( - backend.backend() != "jax", - reason="This test is only for the JAX backend." + backend.backend() != "jax", reason="This test is only for the JAX backend." ) class CoordinatedOptimizerTest(testing.TestCase): def _get_simple_model(self): @@ -180,4 +180,4 @@ def test_sharding_with_prefixed_variable_names(self): kernel_path = dense_output_kernel.path.replace("/", "_") momentum_path = f"{optimizer_name}/{kernel_path}_momentum" - self.assertIs(state_to_param[momentum_path], dense_output_kernel) \ No newline at end of file + self.assertIs(state_to_param[momentum_path], dense_output_kernel) diff --git a/keras/src/distribution/tensor_parallel/tensor_layout.py b/keras/src/distribution/tensor_parallel/tensor_layout.py index 442dbe2ca673..6841e4d01a36 100644 --- a/keras/src/distribution/tensor_parallel/tensor_layout.py +++ b/keras/src/distribution/tensor_parallel/tensor_layout.py @@ -1,12 +1,7 @@ import keras -class LayoutAction: - """Abstract base class for actions that transform tensors for distribution. - A LayoutAction defines a rule for how a single tensor should be physically - represented across multiple devices. It includes a forward operation (`__call__`) - to shard the tensor and a reverse operation (`undo`) to reconstruct it. - """ +class LayoutAction: def __call__(self, tensor, rank): """Applies the distribution action to a tensor for a specific worker. @@ -46,15 +41,8 @@ class _ConcatenateMixin: This class is intended to be used as a mixin for `LayoutAction` subclasses that can be undone by simple concatenation along a specified axis. """ - def undo(self, tensors): - """Concatenates a sequence of tensors to reconstruct the original tensor. - - Args: - tensors: A sequence of tensor shards, one from each worker. - Returns: - The single tensor reconstructed by concatenating the shards. - """ + def undo(self, tensors): if self.dim == -1: dim = keras.ops.ndim(tensors[0]) - 1 else: @@ -73,6 +61,7 @@ class Split(_ConcatenateMixin, LayoutAction): The `undo` operation is provided by the `_ConcatenateMixin`. """ + def __init__(self, world_size, dim, sharding_type="auto"): """Initializes the Split action. @@ -138,6 +127,7 @@ class LayoutMap: output_rules: A dictionary mapping layer output names or patterns to `LayoutAction` instances. """ + def __init__(self, state_rules, output_rules): """Initializes the LayoutMap. @@ -161,4 +151,4 @@ def create_collective_ops(self, devices): Returns: The `LayoutMap` instance itself, allowing for method chaining. 
""" - return self \ No newline at end of file + return self From ec0009ae99ab61ef7a46eeaab6e509bf4d05f641 Mon Sep 17 00:00:00 2001 From: Suhana Date: Thu, 16 Oct 2025 00:45:32 +0530 Subject: [PATCH 12/12] removing pr1 contents --- keras/src/backend/__init__.py | 3 - keras/src/backend/jax/distributed_backend.py | 94 ----------- .../tensor_parallel/tensor_layout.py | 154 ------------------ 3 files changed, 251 deletions(-) delete mode 100644 keras/src/backend/jax/distributed_backend.py delete mode 100644 keras/src/distribution/tensor_parallel/tensor_layout.py diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index df4d6ced9d26..d101496f3cbd 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -47,19 +47,16 @@ from keras.src.backend.torch.core import Variable as BackendVariable distribution_lib = None - distributed_backend = None elif backend() == "numpy": from keras.src.backend.numpy import * # noqa: F403 from keras.src.backend.numpy.core import Variable as BackendVariable distribution_lib = None - distributed_backend = None elif backend() == "openvino": from keras.src.backend.openvino import * # noqa: F403 from keras.src.backend.openvino.core import Variable as BackendVariable distribution_lib = None - distributed_backend = None else: raise ValueError(f"Unable to import backend : {backend()}") diff --git a/keras/src/backend/jax/distributed_backend.py b/keras/src/backend/jax/distributed_backend.py deleted file mode 100644 index a20cdf68605c..000000000000 --- a/keras/src/backend/jax/distributed_backend.py +++ /dev/null @@ -1,94 +0,0 @@ -import jax -import jax.lax as lax - - -def get_device_info(): - """Retrieves information about the available JAX devices. - - This function queries the JAX backend to identify the type and number - of available computational devices (e.g., CPU, GPU, TPU). - - Returns: - dict: A dictionary containing the backend name ('jax'), a list of - device string representations, and the total count of devices. - """ - available_devices = jax.devices() - return { - "backend": "jax", - "devices": [str(d) for d in available_devices], - "device_count": len(available_devices), - } - - -def is_multi_device_capable(): - """Checks if more than one JAX device is available for computation. - - Returns: - bool: True if the local JAX environment has more than one device, - False otherwise. - """ - return jax.local_device_count() > 1 - - -def get_communication_ops(): - """Provides a dictionary of JAX collective communication operations. - - Returns: - dict: A dictionary mapping operation names (e.g., 'all_reduce') to their - corresponding JAX implementation functions. - """ - - def all_reduce(x, op="sum", axis_name="model"): - """Reduces a tensor across a device mesh axis using a collective. - - This function assumes it is called within a `pjit` context that has a - device mesh with the specified `axis_name`. It performs a collective - reduction operation (like sum or mean) across all devices mapped to - that axis. - - Args: - x (jax.Array): The input JAX array (tensor) on the local device. - op (str, optional): The reduction operation to perform. Supported - values are 'sum' and 'mean'. Defaults to 'sum'. - axis_name (str, optional): The name of the mapped axis in the device - mesh over which to communicate. Defaults to 'model'. - - Returns: - jax.Array: The reduced JAX array, which is identical across all - devices participating in the reduction. 
- """ - if op == "sum": - return lax.psum(x, axis_name=axis_name) - elif op == "mean": - return lax.pmean(x, axis_name=axis_name) - else: - raise ValueError( - f"Unsupported reduction operation: {op}. " - "Supported options are 'sum' and 'mean'." - ) - - def all_gather(x, axis, axis_name="model"): - """Gathers and concatenates tensors from all devices across a mesh axis. - - This function assumes it is called within a `pjit` context. It takes - the local shard `x` from each device along the `axis_name` of the mesh - and concatenates them along the specified tensor `axis` to form a - single, larger tensor that is then replicated on participating devices. - - Args: - x (jax.Array): The input JAX array (tensor) shard on local device. - axis (int): The tensor axis along which to concatenate the gathered - shards. - axis_name (str, optional): The name of the mesh axis to gather - from. Defaults to 'model'. - - Returns: - jax.Array: The full, gathered JAX array, which is identical across - all devices participating in the gather. - """ - return lax.all_gather(x, axis_name=axis_name, axis=axis, tiled=True) - - return { - "all_reduce": all_reduce, - "all_gather": all_gather, - } diff --git a/keras/src/distribution/tensor_parallel/tensor_layout.py b/keras/src/distribution/tensor_parallel/tensor_layout.py deleted file mode 100644 index 6841e4d01a36..000000000000 --- a/keras/src/distribution/tensor_parallel/tensor_layout.py +++ /dev/null @@ -1,154 +0,0 @@ -import keras - - -class LayoutAction: - def __call__(self, tensor, rank): - """Applies the distribution action to a tensor for a specific worker. - - Args: - tensor: The input tensor to be distributed. - rank: The integer rank of the current worker/device. - - Raises: - NotImplementedError: This is an abstract method and must be - implemented by subclasses. - - Returns: - A shard or transformation of the input tensor specific to the given - rank. - """ - raise NotImplementedError - - def undo(self, tensors): - """Reverses the distribution action, reconstructing the original tensor. - - Args: - tensors: A sequence of tensor shards, one from each worker. - - Raises: - NotImplementedError: This is an abstract method and must be - implemented by subclasses. - - Returns: - The reconstructed, single tensor. - """ - raise NotImplementedError - - -class _ConcatenateMixin: - """A mixin class providing a common `undo` method via concatenation. - - This class is intended to be used as a mixin for `LayoutAction` subclasses - that can be undone by simple concatenation along a specified axis. - """ - - def undo(self, tensors): - if self.dim == -1: - dim = keras.ops.ndim(tensors[0]) - 1 - else: - dim = self.dim - return keras.ops.concatenate(tensors, axis=dim) - - -class Split(_ConcatenateMixin, LayoutAction): - """Splits a tensor into shards along a specified dimension. - - This is an internal utility used by a higher-level distribution API. - It implements sharding by slicing a tensor along one of its axes. - It handles cases where the dimension size is not perfectly divisible by the - number of workers by distributing the remainder elements one by one to the - first few workers. - - The `undo` operation is provided by the `_ConcatenateMixin`. - """ - - def __init__(self, world_size, dim, sharding_type="auto"): - """Initializes the Split action. - - Args: - world_size: The total number of workers/shards. - dim: The dimension along which to split the tensor. If -1, the - last dimension is used. 
- sharding_type: If `dim` is -1, this can be 'row' (dim=0) or - 'column' (dim=1) to infer the split axis for 2D tensors. - Defaults to "auto". - """ - super().__init__() - self.world_size = world_size - self.dim = dim - self.sharding_type = sharding_type - - if dim == -1 and sharding_type != "auto": - if sharding_type == "row": - self.dim = 0 - elif sharding_type == "column": - self.dim = 1 - - def __call__(self, tensor, rank): - """Splits the tensor and returns the shard corresponding to the rank. - - This method calculates the correct slice of the tensor for a given - worker rank, handling uneven distributions gracefully. - - Args: - tensor: The full tensor to be sharded. - rank: The rank of the worker for which to get the shard. - - Returns: - A tensor shard corresponding to the given rank. - """ - if self.dim == -1: - dim = keras.ops.ndim(tensor) - 1 - else: - dim = self.dim - - total_size = tensor.shape[dim] - split_size = total_size // self.world_size - remainder = total_size % self.world_size - - start_idx = rank * split_size + min(rank, remainder) - end_idx = start_idx + split_size + (1 if rank < remainder else 0) - - slices = [slice(None)] * keras.ops.ndim(tensor) - slices[dim] = slice(start_idx, end_idx) - return tensor[tuple(slices)] - - -class LayoutMap: - """A mapping that defines layout rules for model states and outputs. - - This is an internal configuration object used to hold layout rules for - how model variables and layer outputs should be distributed across a set - of devices. It acts as a container for `LayoutAction` instances. - - Attributes: - state_rules: A dictionary mapping variable names or patterns to - `LayoutAction` instances. - output_rules: A dictionary mapping layer output names or - patterns to `LayoutAction` instances. - """ - - def __init__(self, state_rules, output_rules): - """Initializes the LayoutMap. - - Args: - state_rules: A dictionary of distribution rules for model states. - output_rules: A dictionary of distribution rules for model outputs. - """ - self.state_rules = state_rules - self.output_rules = output_rules - - def create_collective_ops(self, devices): - """Creates the necessary collective communication operations. - - This method is a placeholder for backend-specific logic that would - translate the layout rules into actual communication primitives - (e.g., all-gather, reduce-scatter). - - Args: - devices: A sequence of device identifiers. - - Returns: - The `LayoutMap` instance itself, allowing for method chaining. - """ - return self
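
For reference, the remainder handling in `Split.__call__` (tensor_layout.py above) can be exercised in isolation. The sketch below is not part of the patch; it re-implements the same index arithmetic with NumPy so it runs without a Keras backend, and `split_for_rank` is a hypothetical helper name rather than an API from this PR.

```python
# Illustrative sketch: mirrors the remainder-aware slicing used by
# `Split.__call__`, with NumPy standing in for keras.ops.
import numpy as np


def split_for_rank(tensor, world_size, dim, rank):
    # Same arithmetic as Split.__call__: the first `remainder` ranks each
    # receive one extra element along `dim` when the size is not evenly
    # divisible by world_size.
    total_size = tensor.shape[dim]
    split_size, remainder = divmod(total_size, world_size)
    start = rank * split_size + min(rank, remainder)
    end = start + split_size + (1 if rank < remainder else 0)
    slices = [slice(None)] * tensor.ndim
    slices[dim] = slice(start, end)
    return tensor[tuple(slices)]


x = np.arange(10 * 4).reshape(10, 4)  # total_size=10 along dim 0, 4 workers
shards = [split_for_rank(x, world_size=4, dim=0, rank=r) for r in range(4)]
print([s.shape[0] for s in shards])  # [3, 3, 2, 2]
# Concatenation along the same dim reconstructs the original tensor,
# which is exactly what `_ConcatenateMixin.undo` relies on.
print(np.array_equal(np.concatenate(shards, axis=0), x))  # True
```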