diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index 15f1af2145d5..8f2c6d42646e 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -37,6 +37,7 @@ if backend() == "tensorflow": from keras.src.backend.tensorflow import * # noqa: F403 from keras.src.backend.tensorflow.core import Variable as BackendVariable + elif backend() == "jax": from keras.src.backend.jax import * # noqa: F403 from keras.src.backend.jax.core import Variable as BackendVariable @@ -44,7 +45,6 @@ from keras.src.backend.torch import * # noqa: F403 from keras.src.backend.torch.core import Variable as BackendVariable - distribution_lib = None elif backend() == "numpy": from keras.src.backend.numpy import * # noqa: F403 from keras.src.backend.numpy.core import Variable as BackendVariable diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py new file mode 100644 index 000000000000..1f1dfd04d884 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -0,0 +1,260 @@ +from keras.src import layers +from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap +from keras.src.distribution.tensor_parallel.tensor_layout import Split + + +def analyze_dense_layer(layer): + """Analyzes a Keras Dense layer to classify its sharding strategy. + + This function inspects the input and output dimensions of a Dense layer + to determine if it functions as an expansion layer ("up-projection"), a + contraction layer ("down-projection"), or neither ("dense"). This + classification is a heuristic commonly used to apply tensor parallelism + in Transformer-based models, such as in an MLP block where an up-projection + is followed by a down-projection. + + The classification is based on an `expansion_threshold` (set to 1.5). + + Args: + layer: The Keras `layers.Dense` instance to analyze. + + Returns: + str: A string classifying the layer as 'up_projection', + 'down_projection', or 'dense'. + """ + + if not isinstance(layer, layers.Dense): + return "dense" + + input_dim = None + output_dim = None + + if hasattr(layer, "kernel") and layer.kernel is not None: + kernel_shape = layer.kernel.shape + if len(kernel_shape) == 2: + input_dim = kernel_shape[0] + output_dim = kernel_shape[1] + + if input_dim is None or output_dim is None: + if hasattr(layer, "units"): + output_dim = layer.units + else: + return "dense" + + if ( + hasattr(layer, "input_shape") + and layer.input_shape + and len(layer.input_shape) > 1 + ): + input_dim = layer.input_shape[-1] + else: + return "dense" + + if not input_dim or not output_dim: + return "dense" + + expansion_threshold = 1.5 + is_expansion = output_dim > input_dim * expansion_threshold + is_contraction = input_dim > output_dim * expansion_threshold + + if is_expansion: + return "up_projection" + elif is_contraction: + return "down_projection" + else: + return "dense" + + +def _recursive_layer_traversal( + current_layer, + prefix, + device_count, + state_rules, + output_rules, + processed_layers, +): + """Recursively traverses the model graph to apply sharding rules. + + This function is necessary because Keras Model.layers property does not + recursively find all sub-layers in all architectures. It applies sharding + rules based on layer type and heuristic classification (e.g., up/down + projection). + + - Split Logic: + - 'up_projection' (or general 'dense'): Column-wise sharding + (`Split(..., 1, "column")`) on kernel. Requires output to be + gathered (`gather`). 
+ - 'down_projection' (or attention output): Row-wise sharding + (`Split(..., 0, "row")`) on kernel. Requires output to be + reduced (`allreduce`). + - Embedding: Column-wise sharding (`Split(..., 1, "column")`). + + Args: + current_layer: The Keras layer instance currently being inspected. + prefix: The fully qualified name prefix for the current layer's scope. + device_count: The number of devices (replicas) in the parallelism group. + state_rules: A dictionary to accumulate variable sharding rules + (`LayoutMap.state_rules`). + output_rules: A dictionary to accumulate layer output communication + rules (`LayoutMap.output_rules`). + processed_layers: A set of layer IDs to prevent infinite recursion + in graph structures. + """ + if id(current_layer) in processed_layers: + return + processed_layers.add(id(current_layer)) + + name = current_layer.name + full_name = f"{prefix}.{name}" if prefix else name + + if isinstance(current_layer, layers.Dense): + mlp_type = analyze_dense_layer(current_layer) + + if mlp_type == "up_projection": + # Column-wise sharding for the first MLP layer + state_rules[f"{full_name}.kernel"] = Split( + device_count, 1, "column" + ) + if current_layer.use_bias: + state_rules[f"{full_name}.bias"] = Split( + device_count, 0, "column" + ) + # The result needs to be gathered back to a full tensor. + output_rules[f"{full_name}"] = {0: "gather"} + + elif mlp_type == "down_projection": + # Row-wise sharding for the second MLP layer (down-projection) + state_rules[f"{full_name}.kernel"] = Split(device_count, 0, "row") + # Results from different devices needs to be summed (all-reduced). + output_rules[f"{full_name}"] = {0: "allreduce"} + + else: + # Fallback for generic dense layers (treat as column-wise split) + state_rules[f"{full_name}.kernel"] = Split( + device_count, 1, "column" + ) + if current_layer.use_bias: + state_rules[f"{full_name}.bias"] = Split( + device_count, 0, "column" + ) + output_rules[f"{full_name}"] = {0: "gather -1"} + + elif isinstance(current_layer, layers.EinsumDense): + if "attention_output" in full_name: + # Row-wise sharding for the attention output layer + state_rules[f"{full_name}.kernel"] = Split(device_count, 0, "row") + output_rules[f"{full_name}"] = {0: "allreduce"} + else: + # Column-wise sharding for key/query/value projections + state_rules[f"{full_name}.kernel"] = Split( + device_count, 1, "column" + ) + if ( + hasattr(current_layer, "bias") + and current_layer.bias is not None + ): + state_rules[f"{full_name}.bias"] = Split( + device_count, 0, "column" + ) + output_rules[f"{full_name}"] = {0: "gather -1"} + + elif isinstance(current_layer, (layers.Embedding,)): + weight_name = None + + if hasattr(current_layer, "embeddings"): + weight_name = "embeddings" + elif hasattr(current_layer, "position_embeddings"): + weight_name = "position_embeddings" + + if weight_name: + # Column-wise sharding on the embedding dimension + state_rules[f"{full_name}.{weight_name}"] = Split( + device_count, 1, "column" + ) + # Output requires no communication + output_rules[f"{full_name}"] = {0: "no_comm"} + + elif isinstance( + current_layer, + ( + layers.LayerNormalization, + layers.BatchNormalization, + layers.GroupNormalization, + ), + ): + pass + + if hasattr(current_layer, "layers") and current_layer.layers: + for sub_layer in current_layer.layers: + _recursive_layer_traversal( + sub_layer, + full_name, + device_count, + state_rules, + output_rules, + processed_layers, + ) + + for attr_name in dir(current_layer): + if attr_name.startswith("__") and 
attr_name.endswith("__"): + continue + if hasattr(current_layer, attr_name): + attr = getattr(current_layer, attr_name) + + if isinstance(attr, layers.Layer) and attr is not current_layer: + _recursive_layer_traversal( + attr, + full_name, + device_count, + state_rules, + output_rules, + processed_layers, + ) + elif isinstance(attr, (list, tuple)): + for item in attr: + if isinstance(item, layers.Layer): + _recursive_layer_traversal( + item, + full_name, + device_count, + state_rules, + output_rules, + processed_layers, + ) + + +def get_default_config_keras(module, device_ids): + """Generates a default tensor parallelism sharding configuration. + + This function leverages model-traversal and heuristic layer analysis to + automatically generate sharding rules (for state and layer outputs) + optimized for large-scale language models (Transformers). + + Args: + module: The root Keras `Model` or `Layer` instance representing the + module to be sharded. + device_ids: A list of device identifiers (e.g., strings) that define + the parallelism group. The length of this list determines + the number of slices (`device_count`). + + Returns: + LayoutMap: An object containing the generated `state_rules` (variable + sharding) and `output_rules` (layer communication). + """ + + device_count = len(device_ids) + state_rules = {} + output_rules = {} + + processed_layers = set() + + _recursive_layer_traversal( + current_layer=module, + prefix="", + device_count=device_count, + state_rules=state_rules, + output_rules=output_rules, + processed_layers=processed_layers, + ) + + return LayoutMap(state_rules=state_rules, output_rules=output_rules) diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py new file mode 100644 index 000000000000..3e12a0598699 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -0,0 +1,134 @@ +from autoconfig import analyze_dense_layer +from autoconfig import get_default_config_keras + +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import testing + + +class AutoConfigTest(testing.TestCase): + def test_analyze_dense_layer_directly(self): + """Tests the heuristic for classifying Dense layers.""" + + up_proj_layer = layers.Dense(64, name="up") + up_proj_layer.build(input_shape=(None, 16)) + self.assertEqual(analyze_dense_layer(up_proj_layer), "up_projection") + down_proj_layer = layers.Dense(16, name="down") + down_proj_layer.build(input_shape=(None, 64)) + self.assertEqual( + analyze_dense_layer(down_proj_layer), + "down_projection", + ) + generic_layer = layers.Dense(32, name="generic") + generic_layer.build(input_shape=(None, 28)) + self.assertEqual(analyze_dense_layer(generic_layer), "dense") + non_dense_layer = layers.LayerNormalization() + self.assertEqual(analyze_dense_layer(non_dense_layer), "dense") + + def test_simple_mlp_model(self): + """Tests rule generation for a standard MLP block.""" + device_count = 2 + devices = [f"gpu:{i}" for i in range(device_count)] + + model = models.Sequential( + [ + layers.Input(shape=(32,)), + layers.Dense(128, name="mlp_up"), + layers.Dense(32, name="mlp_down"), + ], + name="mlp_block", + ) + + layout_map = get_default_config_keras(model, devices) + state_rules = layout_map.state_rules + output_rules = layout_map.output_rules + + up_kernel_key = "mlp_block.mlp_up.kernel" + up_kernel_rule = state_rules[up_kernel_key] + self.assertEqual(up_kernel_rule.device_count, device_count) + 
self.assertEqual(up_kernel_rule.dim, 1) + + down_kernel_key = "mlp_block.mlp_down.kernel" + down_kernel_rule = state_rules[down_kernel_key] + self.assertEqual(down_kernel_rule.device_count, device_count) + self.assertEqual(down_kernel_rule.dim, 0) + + self.assertEqual(output_rules["mlp_block.mlp_up"], {0: "gather"}) + self.assertEqual(output_rules["mlp_block.mlp_down"], {0: "allreduce"}) + + def test_model_with_embedding_and_einsumdense(self): + """Tests rule generation for Embedding and EinsumDense layers.""" + device_count = 4 + devices = [f"gpu:{i}" for i in range(device_count)] + + class SimpleTransformer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.embedding = layers.Embedding( + input_dim=1000, output_dim=64, name="embedding" + ) + self.qkv_proj = layers.EinsumDense( + "abc,cde->abde", + output_shape=(None, 3, 128), + bias_axes="de", + name="qkv_proj", + ) + self.attention_output = layers.EinsumDense( + "abde,cde->abc", + output_shape=(None, 64), + bias_axes="c", + name="attention_output", + ) + + def call(self, inputs): + x = self.embedding(inputs) + x = self.qkv_proj(x) + x = self.attention_output(x) + return x + + model = SimpleTransformer(name="transformer") + model(ops.zeros((1, 10))) + + layout_map = get_default_config_keras(model, devices) + state_rules = layout_map.state_rules + + expected_key = "transformer.embedding.embeddings" + self.assertIn(expected_key, state_rules) + emb_rule = state_rules[expected_key] + self.assertEqual(emb_rule.device_count, device_count) + self.assertEqual(emb_rule.dim, 1) + + qkv_key = "transformer.qkv_proj.kernel" + qkv_rule = state_rules[qkv_key] + self.assertEqual(qkv_rule.device_count, device_count) + self.assertEqual(qkv_rule.dim, 1) + + attn_out_key = "transformer.attention_output.kernel" + attn_out_rule = state_rules[attn_out_key] + self.assertEqual(attn_out_rule.device_count, device_count) + self.assertEqual(attn_out_rule.dim, 0) + + def test_nested_model(self): + """Tests that the recursive traversal finds layers in nested models.""" + device_count = 2 + devices = [f"gpu:{i}" for i in range(device_count)] + inner_model = models.Sequential( + [layers.Dense(64, name="inner_dense")], name="inner_block" + ) + outer_model = models.Sequential( + [ + layers.Input(shape=(32,)), + layers.Dense(32, name="outer_dense_1"), + inner_model, + ], + name="outer_block", + ) + layout_map = get_default_config_keras(outer_model, devices) + state_rules = layout_map.state_rules + + expected_key = "outer_block.inner_block.inner_dense.kernel" + self.assertIn(expected_key, state_rules) + inner_rule = state_rules[expected_key] + self.assertEqual(inner_rule.device_count, device_count) + self.assertEqual(inner_rule.dim, 1) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py new file mode 100644 index 000000000000..9a56890f116b --- /dev/null +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -0,0 +1,682 @@ +import re + +import numpy as np + +from keras.src import ops +from keras.src import optimizers +from keras.src import saving +from keras.src.backend import core +from keras.src.backend import distribution_lib + + +class CoordinatedOptimizer: + """Manages an optimizer's state for distributed training. + + This class is an internal coordinator that handles the complexities of + sharding optimizer states across multiple devices (shards) and + synchronizing gradients according to tensor parallelism rules. 
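+    Gradients for column-parallel parameters are all-reduced across shards
+    before they are applied, so every shard works with the same synchronized
+    gradient.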
+ It is the core logic engine for the `TensorParallelOptimizer`. + + Args: + base_optimizer: The Keras optimizer instance. + device_count: The total number of devices/processes in the distributed + setup. + shard_optimizer_states: If `True`, the optimizer's state variables + will be partitioned across `device_count` devices. Defaults to + `True`. + tensor_parallel_config: An optional configuration object that defines + rules for tensor parallelism. Defaults to `None`. + """ + + def __init__( + self, + base_optimizer, + device_count, + shard_optimizer_states=True, + tensor_parallel_config=None, + ): + self.base_optimizer = base_optimizer + self.device_count = device_count + self.shard_optimizer_states = shard_optimizer_states + self.tensor_parallel_config = tensor_parallel_config + self.sharded_states = {} + self._slot_path_to_parameter = {} + self._variable_to_slot_name = {} + self._variables = None + + def _initialize_sharded_states(self): + """Partitions the optimizer's state variables across shards. + + This method inspects the variables created by the base optimizer (slots + like 'momentum', 'velocity', etc.) and partitions them across the + available devices if `shard_optimizer_states` is True. + + It relies on the pre-calculated mapping between a model + parameter and its corresponding optimizer state variable to + determine the sharding dimension from the tensor parallelism rules. + """ + if not self.shard_optimizer_states or not self.base_optimizer.built: + return + + self.sharded_states = {} + + for state_var in self.base_optimizer.variables: + if state_var is self.base_optimizer.iterations: + continue + + found_param = self._slot_path_to_parameter.get(state_var.path, None) + slot_name = self._variable_to_slot_name.get(state_var.path, None) + + if found_param is not None and slot_name is not None: + sharding_dim = 0 + if self.tensor_parallel_config: + norm_param_name = found_param.path.replace("/", ".") + for ( + p, + a, + ) in self.tensor_parallel_config.state_rules.items(): + if re.search(p, norm_param_name) and hasattr(a, "dim"): + sharding_dim = a.dim + break + + partitioned_state = self._partition_state( + state_var, dim=sharding_dim + ) + self.sharded_states.setdefault(slot_name, {})[ + found_param.path + ] = partitioned_state + + if self.base_optimizer.iterations is not None: + self.sharded_states["iterations"] = self._partition_state( + self.base_optimizer.iterations, dim=0 + ) + + def _partition_state(self, state_variable, dim): + """Splits a single state variable numpy array into chunks for sharding. + + Splits the variable's array along the specified dimension (`dim`) for + all available devices. If the variable's shape is incompatible with + sharding, it is replicated across all devices instead. + + Args: + state_variable: The Keras variable representing the optimizer state. + dim: The dimension along which to split the state array. + + Returns: + A list of numpy arrays, where each element is a shard of the + original state array. + """ + state_array = ops.convert_to_numpy(state_variable) + if ( + state_array.ndim > dim + and state_array.shape[dim] >= self.device_count + ): + return np.array_split(state_array, self.device_count, axis=dim) + else: + return [np.copy(state_array) for _ in range(self.device_count)] + + def apply_gradients(self, gradients_and_vars, shard_models): + """Coordinates gradient synchronization and application. + + First, synchronizes the gradients across all shards based on tensor + parallelism rules (if any). 
Then, applies the synchronized gradients + using either sharded or replicated optimizer states. + + Args: + gradients_and_vars: A list of lists, where each inner list + contains `(gradient, variable)` tuples for a single device. + shard_models: A list of Keras models, one for each shard, which + contain the local base optimizers. + + Raises: + ValueError: If the number of gradient sets does not match the + configured device count. + """ + if len(gradients_and_vars) != self.device_count: + raise ValueError( + f"Expected {self.device_count} sets of gradients, " + f"but received {len(gradients_and_vars)}." + ) + + synchronized_gradients = self._synchronize_gradients(gradients_and_vars) + + if self.shard_optimizer_states: + self._apply_gradients_with_sharded_states( + synchronized_gradients, shard_models + ) + else: + self._apply_gradients_with_replicated_states( + synchronized_gradients, shard_models + ) + + def _apply_gradients_with_replicated_states( + self, synchronized_gradients, shard_models + ): + """Averages gradients across all shards and applies them once. + + Used when optimizer states are *not* sharded + (`shard_optimizer_states=False`). The gradients are averaged across + all replicas and then applied to the single, non-sharded base + optimizer. + """ + num_vars = len(synchronized_gradients[0]) + averaged_grads_and_vars = [] + + for i in range(num_vars): + variable = synchronized_gradients[0][i][1] + grads_for_var = [ + shard_grads[i][0] + for shard_grads in synchronized_gradients + if shard_grads[i][0] is not None + ] + + if not grads_for_var: + continue + + if len(grads_for_var) > 1: + stacked_grads = ops.stack(grads_for_var, axis=0) + averaged_grad = ops.mean(stacked_grads, axis=0) + else: + averaged_grad = grads_for_var[0] + + averaged_grads_and_vars.append((averaged_grad, variable)) + + if averaged_grads_and_vars: + self.base_optimizer.apply_gradients(averaged_grads_and_vars) + + def _apply_gradients_with_sharded_states( + self, synchronized_gradients, shard_models + ): + """Applies gradients to each shard using its local optimizer state. + + Used when optimizer states are sharded (`shard_optimizer_states=True`). + Iterates over each shard: + 1. Retrieves the local shard of the optimizer state. + 2. Assigns the local state to the shard's base optimizer. + 3. Applies the synchronized gradients specific to that shard. + 4. Updates the global `self.sharded_states` with the new local state. + """ + for shard_idx in range(self.device_count): + local_states = self._get_local_optimizer_states(shard_idx) + # Access the base optimizer inside TensorParallelOptimizer wrapper + shard_optimizer = shard_models[shard_idx].optimizer.base_optimizer + + self._update_optimizer_internal_state(shard_optimizer, local_states) + + shard_grads_and_vars = synchronized_gradients[shard_idx] + shard_optimizer.apply_gradients(shard_grads_and_vars) + + self._update_global_sharded_states(shard_optimizer, shard_idx) + + def _get_local_optimizer_states(self, shard_idx): + """Constructs the state dictionary for a single shard. + + Extracts the `shard_idx`-th chunk of all partitioned state variables + (e.g., 'momentum', 'velocity') to create the complete local state + dictionary required by the optimizer on that specific device. + + Args: + shard_idx: The index of the device/shard. + + Returns: + A dictionary containing the local optimizer state for the shard. 
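+
+            For an Adam optimizer the result has the form
+            ``{"dense_1_kernel_momentum": {"dense_1/kernel": <chunk>}, ...,
+            "iterations": <value>}``, where each entry is the
+            ``shard_idx``-th partition of the corresponding sharded state
+            (names here follow the test models in this PR).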
+ """ + local_states = {} + for state_name, state_value in self.sharded_states.items(): + if isinstance(state_value, dict): + local_states[state_name] = {} + for param_name, param_states in state_value.items(): + local_states[state_name][param_name] = param_states[ + shard_idx + ] + else: + local_states[state_name] = state_value[shard_idx] + return local_states + + def _update_optimizer_internal_state(self, optimizer, local_states): + """Assigns local sharded state values to the optimizer's variables. + + This effectively loads the local state partition into the base + optimizer's internal Keras `Variable` objects before a gradient update. + + Args: + optimizer: The base Keras optimizer instance for the current shard. + local_states: The local optimizer state dictionary for the shard. + """ + if not optimizer.built: + return + + for var in optimizer.variables: + if var is optimizer.iterations: + if "iterations" in local_states: + var.assign(local_states["iterations"]) + continue + + param = self._slot_path_to_parameter.get(var.path, None) + slot_name = self._variable_to_slot_name.get(var.path) + + if ( + param + and slot_name + and slot_name in local_states + and param.path in local_states[slot_name] + ): + local_param_state = local_states[slot_name][param.path] + if var.shape == local_param_state.shape: + var.assign(local_param_state) + + def _update_global_sharded_states(self, optimizer, shard_idx): + """Updates the main sharded_states dictionary after a gradient step. + + After the base optimizer has applied gradients and updated its internal + state variables (like momentum), this method copies those updated values + back into the global `self.sharded_states` at the `shard_idx` partition. + + Args: + optimizer: The base Keras optimizer instance for the current shard. + shard_idx: The index of the device/shard. + """ + if not optimizer.built: + return + + for var in optimizer.variables: + if var is optimizer.iterations: + self.sharded_states["iterations"][shard_idx] = ( + ops.convert_to_numpy(var) + ) + continue + + param = self._slot_path_to_parameter.get(var.path, None) + slot_name = self._variable_to_slot_name.get(var.path) + + if ( + param + and slot_name + and slot_name in self.sharded_states + and param.path in self.sharded_states[slot_name] + ): + self.sharded_states[slot_name][param.path][shard_idx] = ( + ops.convert_to_numpy(var) + ) + + def _synchronize_gradients(self, gradients_and_vars): + """Synchronizes gradients across shards based on tensor parallel rules. + + Gradients corresponding to parameters that are column-parallel are + aggregated (all-reduced) across devices to ensure all devices see the + full, accumulated gradient for that parameter. + + Args: + gradients_and_vars: A list of lists, where each inner list contains + `(gradient, variable)` tuples for a single device/shard. + + Returns: + The list of lists of gradients and variables, with column-parallel + gradients replaced by their synchronized (all-reduced) version. 
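+
+        A parameter is treated as column-parallel when its path matches a
+        state rule whose action has ``sharding_type == "column"``.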
+ """ + if not self.tensor_parallel_config: + return gradients_and_vars + + rules = self.tensor_parallel_config.state_rules.items() + column_parallel_patterns = { + pattern + for pattern, action in rules + if hasattr(action, "sharding_type") + and action.sharding_type == "column" + } + + if not column_parallel_patterns: + return gradients_and_vars + + num_weights = len(gradients_and_vars[0]) + for i in range(num_weights): + variable = gradients_and_vars[0][i][1] + var_name = getattr(variable, "path", getattr(variable, "name", "")) + + if any( + re.search(pattern, var_name) + for pattern in column_parallel_patterns + ): + grads_to_reduce = [ + g_and_v[i][0] + for g_and_v in gradients_and_vars + if g_and_v[i][0] is not None + ] + if grads_to_reduce: + synced_grad = self._allreduce_gradients(grads_to_reduce)[0] + for shard_idx in range(self.device_count): + if gradients_and_vars[shard_idx][i][0] is not None: + gradients_and_vars[shard_idx][i] = ( + synced_grad, + variable, + ) + return gradients_and_vars + + def _allreduce_gradients(self, gradients): + """Performs a mean all-reduce operation on a list of gradients. + + If multiple devices are detected by the backend, this uses the efficient + on-device communication primitive (`core.all_reduce`) to + average gradients. Otherwise, it performs the mean reduction on the host + using Keras backend operations. + + Args: + gradients: A list of gradient tensors, one from each device. + + Returns: + A list containing the mean-reduced gradient, replicated for + all devices. + """ + if not gradients: + return [] + + if distribution_lib.get_device_count() > 1: + local_grad = gradients[0] + synced_tensor = core.all_reduce( + local_grad, op="mean", axis_name="model" + ) + + return [synced_tensor for _ in range(self.device_count)] + + if len(gradients) == 1: + mean_grad = gradients[0] + else: + stacked_grads = ops.stack(gradients, axis=0) + mean_grad = ops.mean(stacked_grads, axis=0) + + return [mean_grad for _ in range(len(gradients))] + + def get_weights(self): + """Returns the weights of the base optimizer. + + Delegates to the coordinated optimizer which handles sharded/unsharded + state. + + Returns: + A list of numpy arrays representing the optimizer's state variables. + """ + return [ + ops.convert_to_numpy(var) for var in self.base_optimizer.variables + ] + + def set_weights(self, weights): + """Sets the weights of the base optimizer. + + Args: + weights: A list of numpy arrays to assign to the optimizer's + state variables. + """ + self.base_optimizer.set_weights(weights) + + def enable_optimizer_state_sharding(self, variables): + """Enables and initializes optimizer state sharding. + + This method is typically called after the model variables have been + created and the base optimizer has been built. It establishes the + necessary mapping between optimizer state variables (slots) and model + parameters. + + Args: + variables: The list of model variables/parameters. 
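+
+        Slot variables are matched to parameters by checking whether the
+        sanitized parameter path is a substring of the sanitized slot path,
+        keeping the longest such match.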
+ """ + self.shard_optimizer_states = True + self._variables = variables + + self._slot_path_to_parameter = {} + self._variable_to_slot_name = {} + + for slot_var in self.base_optimizer.variables: + if slot_var is self.base_optimizer.iterations: + continue + + slot_path_clean = slot_var.path.replace("/", "_") + + best_match_param = None + + for param in variables: + param_path_clean = param.path.replace("/", "_") + + if param_path_clean in slot_path_clean: + if best_match_param is None or len(param_path_clean) > len( + best_match_param.path.replace("/", "_") + ): + best_match_param = param + + if best_match_param is not None: + path_without_suffix = slot_var.path.rsplit(":", 1)[0] + slot_name = path_without_suffix.rsplit("/", 1)[-1] + + self._slot_path_to_parameter[slot_var.path] = best_match_param + self._variable_to_slot_name[slot_var.path] = slot_name + + self._initialize_sharded_states() + + +class TensorParallelOptimizer(optimizers.Optimizer): + """A Keras Optimizer wrapper for tensor-parallel distributed training. + + This class serves as the public interface (inherits + `optimizers.Optimizer`). It delegates the complex tasks of state + management, gradient synchronization, and sharding to the internal + `CoordinatedOptimizer` instance. This separation adheres to the + principle of keeping the public API clean while encapsulating complex + distribution logic. + + Args: + base_optimizer: A Keras optimizer instance or a string identifier. + device_count: The total number of devices/processes in the distributed + setup. + tensor_parallel_config: An optional configuration object. Defaults to + `None`. + """ + + def __init__( + self, + base_optimizer, + device_count, + tensor_parallel_config=None, + ): + if isinstance(base_optimizer, str): + base_optimizer_instance = optimizers.get(base_optimizer) + else: + base_optimizer_instance = base_optimizer + + learning_rate = base_optimizer_instance.learning_rate + if callable(learning_rate): + lr_value = float(ops.convert_to_numpy(learning_rate(0))) + else: + lr_value = float(ops.convert_to_numpy(learning_rate)) + + super().__init__( + learning_rate=lr_value, + name=f"TensorParallel_{base_optimizer_instance.name}", + ) + + self.base_optimizer = base_optimizer_instance + self.device_count = device_count + self.coordinated_optimizer = CoordinatedOptimizer( + self.base_optimizer, + device_count, + tensor_parallel_config=tensor_parallel_config, + ) + + def apply_gradients(self, grads_and_vars, **kwargs): + """Applies gradients to the model variables. + + This method acts as a dispatcher: + - If `grads_and_vars` is a list of lists (sharded gradients), it + delegates to the internal `CoordinatedOptimizer` for synchronization + and sharded updates. + - Otherwise (non-sharded gradients), it delegates directly to the + `base_optimizer`. + + Args: + grads_and_vars: A list of `(gradient, variable)` tuples, or a list + of such lists if gradients are sharded across devices. + **kwargs: Must include `shard_models` if sharded gradients are + provided. + + Raises: + ValueError: If sharded gradients are provided without the + `shard_models` keyword argument. + """ + is_sharded_grads = ( + isinstance(grads_and_vars, list) + and grads_and_vars + and isinstance(grads_and_vars[0], list) + ) + if is_sharded_grads: + if "shard_models" not in kwargs: + raise ValueError( + "The `shard_models` keyword argument is required when " + "applying sharded gradients (a list of lists)." 
+ ) + shard_models = kwargs.get("shard_models") + self.coordinated_optimizer.apply_gradients( + grads_and_vars, shard_models + ) + else: + self.base_optimizer.apply_gradients(grads_and_vars) + + def update_step(self, gradient, variable, *args, **kwargs): + """Delegates the single-variable update step to the base optimizer. + + This ensures that when called on a single gradient/variable pair (e.g., + in a non-distributed context), the base optimizer's logic is used. + + Args: + gradient: The gradient tensor. + variable: The model variable to update. + *args: Positional arguments for the base optimizer's update step. + **kwargs: Keyword arguments for the base optimizer's update step. + + Returns: + The result of the base optimizer's `update_step`. + """ + if hasattr(self.base_optimizer, "update_step"): + try: + return self.base_optimizer.update_step( + gradient, variable, *args, **kwargs + ) + except TypeError: + return self.base_optimizer.update_step(gradient, variable) + + try: + return super().update_step(gradient, variable, *args, **kwargs) + except TypeError: + return super().update_step(gradient, variable) + + def get_config(self): + config = super().get_config() + + config["base_optimizer"] = saving.serialize_keras_object( + self.base_optimizer + ) + + config["device_count"] = self.device_count + config["tensor_parallel_config"] = ( + self.coordinated_optimizer.tensor_parallel_config + ) + + return config + + @classmethod + def from_config(cls, config): + """Creates an optimizer instance from its configuration. + + This static method is used during deserialization/loading. + + Args: + config: The optimizer configuration dictionary. + + Returns: + A new `TensorParallelOptimizer` instance. + """ + from keras.src import saving + + base_optimizer_config = config.pop("base_optimizer") + base_optimizer = saving.deserialize_keras_object(base_optimizer_config) + + init_kwargs = { + "device_count": config.get("device_count"), + "tensor_parallel_config": config.get("tensor_parallel_config"), + } + + config.pop("device_count", None) + config.pop("tensor_parallel_config", None) + + return cls(base_optimizer=base_optimizer, **init_kwargs) + + def build(self, variables): + """Builds the optimizer and initializes sharded states. + + This method first builds the internal base optimizer, then performs a + dummy `apply_gradients` call with zero gradients to force the base + optimizer to create its state variables (slots). Finally, it instructs + the `CoordinatedOptimizer` to initialize sharding of those state + variables. + + Args: + variables: The list of model variables/parameters. + """ + if self.built: + return + + self.base_optimizer.build(variables) + if variables: + iterations = self.base_optimizer.iterations + original_iterations_val = None + if iterations is not None: + original_iterations_val = ops.convert_to_numpy(iterations.value) + + zero_grads = [ops.zeros_like(v) for v in variables] + self.base_optimizer.apply_gradients(zip(zero_grads, variables)) + + if iterations is not None and original_iterations_val is not None: + iterations.assign(original_iterations_val) + + self.coordinated_optimizer.enable_optimizer_state_sharding(variables) + super().build(variables) + + def get_weights(self): + """Returns the weights of the base optimizer. + + Delegates to the coordinated optimizer which handles sharded/unsharded + state. + + Returns: + A list of numpy arrays representing the optimizer's state variables. 
+ """ + return self.coordinated_optimizer.get_weights() + + def set_weights(self, weights): + """Sets the weights of the base optimizer. + + Delegates to the coordinated optimizer to set the weights, which + updates the underlying state variables. + + Args: + weights: A list of numpy arrays to assign to the optimizer's + state variables. + """ + self.coordinated_optimizer.set_weights(weights) + + @property + def variables(self): + """Returns the list of variables from the base optimizer.""" + return self.base_optimizer.variables + + @property + def learning_rate(self): + """Provides access to the learning rate of the base optimizer.""" + return self.base_optimizer.learning_rate + + @learning_rate.setter + def learning_rate(self, value): + """Sets the learning rate of the base optimizer.""" + self.base_optimizer.learning_rate = value + + @property + def iterations(self): + """Returns the training iteration count directly from base optimizer.""" + return self.base_optimizer.iterations diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py new file mode 100644 index 000000000000..47b2e207e02d --- /dev/null +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -0,0 +1,190 @@ +import numpy as np +import pytest + +import keras +from keras import ops +from keras.src import backend +from keras.src import optimizers +from keras.src import testing +from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + CoordinatedOptimizer, +) +from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + TensorParallelOptimizer, +) + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="This test is for the JAX backend only.", +) +class CoordinatedOptimizerTest(testing.TestCase): + def _get_simple_model(self): + """Creates a simple, uncompiled Keras model.""" + inputs = keras.Input(shape=(10,)) + x = keras.layers.Dense(20, name="dense_1")(inputs) + outputs = keras.layers.Dense(5, name="dense_2")(x) + return keras.Model(inputs, outputs) + + def _get_mock_gradients_and_vars(self, model, device_count): + """Generates mock gradients and variables for N shards.""" + model.build(input_shape=(None, 10)) + variables = model.trainable_variables + grads_and_vars_per_shard = [] + for i in range(device_count): + multiplier = float(i + 1) + gradients = [ + ops.convert_to_tensor( + np.ones_like(v.numpy()) * multiplier, dtype="float32" + ) + for v in variables + ] + grads_and_vars_per_shard.append(list(zip(gradients, variables))) + return grads_and_vars_per_shard + + def test_initialization(self): + """Tests that the optimizer initializes with the correct defaults.""" + base_optimizer = optimizers.Adam() + coord = CoordinatedOptimizer(base_optimizer, device_count=4) + self.assertEqual(coord.base_optimizer, base_optimizer) + self.assertTrue(coord.shard_optimizer_states) + self.assertEqual(coord.sharded_states, {}) + + def test_apply_gradients_with_replicated_states(self): + """Tests that replicated gradients are averaged and applied once.""" + + class AdamWithCallCounter(optimizers.Adam): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.apply_gradients_call_count = 0 + self.received_grads = [] + + def apply_gradients(self, grads_and_vars, *args, **kwargs): + self.apply_gradients_call_count += 1 + self.received_grads = [g for g, v in grads_and_vars] + super().apply_gradients(grads_and_vars, *args, **kwargs) + + device_count = 4 + 
model = self._get_simple_model() + optimizer = AdamWithCallCounter() + model.build((None, 10)) + mock_grads = self._get_mock_gradients_and_vars(model, device_count) + + coord = CoordinatedOptimizer( + optimizer, + device_count, + shard_optimizer_states=False, + ) + coord.apply_gradients(mock_grads, []) + + self.assertEqual(optimizer.apply_gradients_call_count, 1) + grad_numpy = ops.convert_to_numpy(optimizer.received_grads[0]) + self.assertAllClose( + grad_numpy, + np.ones_like(grad_numpy) * 2.5, + ) + + def test_init_from_string(self): + optimizer = TensorParallelOptimizer("adam", device_count=4) + self.assertIsInstance(optimizer.base_optimizer, optimizers.Adam) + + def test_apply_gradients_delegation(self): + """Tests that apply_gradients correctly delegates.""" + device_count = 4 + base_opt = optimizers.Adam() + optimizer = TensorParallelOptimizer(base_opt, device_count) + model = self._get_simple_model() + mock_grads = self._get_mock_gradients_and_vars(model, device_count) + + coord_apply_tracker = {"called": False} + + def coord_apply_mock(*args, **kwargs): + coord_apply_tracker["called"] = True + + optimizer.coordinated_optimizer.apply_gradients = coord_apply_mock + + base_apply_tracker = {"called": False} + + def base_apply_mock(*args, **kwargs): + base_apply_tracker["called"] = True + + optimizer.base_optimizer.apply_gradients = base_apply_mock + + optimizer.apply_gradients(mock_grads, shard_models=[]) + self.assertTrue(coord_apply_tracker["called"]) + self.assertFalse(base_apply_tracker["called"]) + + coord_apply_tracker["called"] = False + unsharded_grads = mock_grads[0] + optimizer.apply_gradients(unsharded_grads) + self.assertTrue(base_apply_tracker["called"]) + self.assertFalse(coord_apply_tracker["called"]) + + def test_build_and_state_sharding(self): + """Tests that the build method correctly initializes sharded states.""" + optimizer = TensorParallelOptimizer(optimizers.Adam(), device_count=4) + model = self._get_simple_model() + model.build(input_shape=(None, 10)) + + self.assertEqual(optimizer.coordinated_optimizer.sharded_states, {}) + optimizer.build(model.trainable_variables) + self.assertTrue(optimizer.built) + + sharded_states = optimizer.coordinated_optimizer.sharded_states + + self.assertTrue( + any("momentum" in key for key in sharded_states), + msg="Momentum slot not found in sharded_states keys.", + ) + self.assertTrue( + any("velocity" in key for key in sharded_states), + msg="Velocity slot not found in sharded_states keys.", + ) + self.assertIn("iterations", sharded_states) + + dense_1_kernel_path = model.get_layer("dense_1").kernel.path + + momentum_slot_key = "dense_1_kernel_momentum" + + self.assertIn(dense_1_kernel_path, sharded_states[momentum_slot_key]) + self.assertEqual( + len(sharded_states[momentum_slot_key][dense_1_kernel_path]), 4 + ) + + def test_serialization(self): + device_count = 4 + base_opt = optimizers.Adam(learning_rate=0.1) + + optimizer = TensorParallelOptimizer(base_opt, device_count) + + config = optimizer.get_config() + recreated = TensorParallelOptimizer.from_config(config) + + self.assertEqual(recreated.device_count, device_count) + self.assertIsInstance(recreated.base_optimizer, optimizers.Adam) + self.assertAllClose(recreated.base_optimizer.learning_rate, 0.1) + + def test_sharding_with_prefixed_variable_names(self): + """Tests that state is correctly mapped with prefixed variable names.""" + inputs = keras.Input(shape=(10,)) + x = keras.layers.Dense(4, name="dense")(inputs) + outputs = keras.layers.Dense(2, 
name="dense_output")(x) + model = keras.Model(inputs, outputs) + model.build(input_shape=(None, 10)) + + optimizer = TensorParallelOptimizer(optimizers.Adam(), device_count=2) + optimizer.build(model.trainable_variables) + + state_to_param = optimizer.coordinated_optimizer._slot_path_to_parameter + self.assertGreater(len(state_to_param), 0) + + dense_output_kernel = model.get_layer("dense_output").kernel + + found_slot_path = None + for slot_path, param in state_to_param.items(): + if param is dense_output_kernel: + found_slot_path = slot_path + break + + self.assertIsNotNone(found_slot_path) + self.assertIs(state_to_param[found_slot_path], dense_output_kernel)