From f5e68e217811a5df2cd7ab3bf7118b9165fc2755 Mon Sep 17 00:00:00 2001
From: Zixiang Zhou
Date: Tue, 6 May 2025 12:47:01 -0700
Subject: [PATCH] Add options to load ckpt with optimizer_variables / training
 steps / lr separately.

PiperOrigin-RevId: 755482380
---
 recml/core/training/keras_trainer.py |   2 +
 recml/core/utils/keras_utils.py      |  15 +++-
 recml/core/utils/keras_utils_test.py | 115 ++++++++++++++++++++++++++-
 recml/examples/dlrm_experiment.py    |   8 +-
 4 files changed, 136 insertions(+), 4 deletions(-)

diff --git a/recml/core/training/keras_trainer.py b/recml/core/training/keras_trainer.py
index 2bb4c49..0fc6416 100644
--- a/recml/core/training/keras_trainer.py
+++ b/recml/core/training/keras_trainer.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 """Keras task and trainer."""
 
+from __future__ import annotations
+
 import abc
 from collections.abc import Mapping
 import gc
diff --git a/recml/core/utils/keras_utils.py b/recml/core/utils/keras_utils.py
index 4e6b103..52c6381 100644
--- a/recml/core/utils/keras_utils.py
+++ b/recml/core/utils/keras_utils.py
@@ -183,6 +183,8 @@ def restore_keras_model(
     checkpoint_dir: str,
     step: int | None = None,
     restore_optimizer_vars: bool = True,
+    restore_steps: bool = True,
+    restore_iterations: bool = True,
 ):
   """Restores a Keras 3 Jax backend model from an Orbax checkpoint.
 
@@ -192,6 +194,14 @@
     step: The step to restore the model to. If `None` then the latest
       checkpoint will be restored.
     restore_optimizer_vars: Whether to restore the optimizer variables.
+    restore_steps: Whether to restore the model's training step. If `True`
+      then the model will continue training from the step the checkpoint was
+      saved at. If `False` then the model will start training from the first
+      step.
+    restore_iterations: Whether to restore the optimizer's iterations counter.
+      If `True` then training will continue from the iteration the checkpoint
+      was saved at. This is an optimizer variable used to control the learning
+      rate schedule. Not supported when `restore_optimizer_vars` is `False`.
 
   Raises:
     FileNotFoundError: If no checkpoints are found in the checkpoint directory.
@@ -273,10 +283,13 @@
       "non_trainable_variables": non_trainable_variables,
   }
   if restore_optimizer_vars:
-    model._initial_epoch = step + 1  # pylint: disable=protected-access
     optimizer_variables = restored_state[2]
     model._jax_state["optimizer_variables"] = optimizer_variables  # pylint: disable=protected-access
   model.jax_state_sync()
+  if restore_steps:
+    model._initial_epoch = step + 1  # pylint: disable=protected-access
+  if restore_optimizer_vars and not restore_iterations:
+    model.optimizer.iterations.assign(0)
 
 
 # TODO(b/343544467): Support logging metrics more frequently.
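Usage note (illustrative sketch, not part of the patch): `restore_keras_model`,
`restore_optimizer_vars`, `restore_steps`, and `restore_iterations` come from the
hunks above; the model constructor and checkpoint path below are hypothetical
placeholders. With `restore_iterations=False` the optimizer moments are restored
but `optimizer.iterations` is reset to 0, so a learning rate schedule restarts
from its initial value; with `restore_steps=False` Keras starts again from the
first epoch rather than resuming at `step + 1`.

    from recml.core.utils import keras_utils

    # Hypothetical helper returning a compiled Keras 3 JAX-backend model.
    model = build_compiled_model()

    # Warm-start: reuse weights and optimizer moments from the checkpoint,
    # but restart the step count and the learning rate schedule.
    keras_utils.restore_keras_model(
        model,
        "/tmp/checkpoints",  # hypothetical checkpoint directory
        restore_optimizer_vars=True,
        restore_steps=False,       # model._initial_epoch is left unset
        restore_iterations=False,  # optimizer.iterations is reset to 0
    )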
diff --git a/recml/core/utils/keras_utils_test.py b/recml/core/utils/keras_utils_test.py
index cdf92fc..010707a 100644
--- a/recml/core/utils/keras_utils_test.py
+++ b/recml/core/utils/keras_utils_test.py
@@ -25,6 +25,13 @@
 import numpy as np
 from recml.core.utils import keras_utils
 
+_LEARNING_RATE_SCHEDULE = keras.optimizers.schedules.PolynomialDecay(
+    initial_learning_rate=0.1,
+    decay_steps=100,
+    end_learning_rate=0.01,
+    power=1.0,
+)
+
 
 def _create_model(input_shapes: Sequence[int]) -> keras.Model:
   model = keras_hub.models.BertMaskedLM(
@@ -39,7 +46,7 @@
           dropout=0.1,
       )
   )
-  optimizer = keras.optimizers.Adam(learning_rate=0.1)
+  optimizer = keras.optimizers.Adam(learning_rate=_LEARNING_RATE_SCHEDULE)
   loss = keras.losses.SparseCategoricalCrossentropy()
   metrics = [keras.metrics.SparseCategoricalAccuracy()]
   model.compile(optimizer, loss, weighted_metrics=metrics)
@@ -242,6 +249,112 @@ def test_metrics_variables_checkpointing(
     )
     self.assertSequenceEqual(w1.dtype, w2.dtype)
 
+  @parameterized.named_parameters(
+      {
+          "testcase_name": "restore_all_variables",
+          "restore_optimizer_vars": True,
+          "restore_steps": True,
+          "restore_iterations": True,
+          "expected_learning_rate": 0.01,
+          "expected_iterations": 100,
+          "expected_initial_epoch": 2,
+      },
+      {
+          "testcase_name": "restore_without_optimizer_vars",
+          "restore_optimizer_vars": False,
+          "restore_steps": True,
+          "restore_iterations": True,
+          "expected_learning_rate": 0.1,
+          "expected_iterations": 0,
+          "expected_initial_epoch": 2,
+      },
+      {
+          "testcase_name": "restore_without_steps",
+          "restore_optimizer_vars": True,
+          "restore_steps": False,
+          "restore_iterations": True,
+          "expected_learning_rate": 0.01,
+          "expected_iterations": 100,
+          "expected_initial_epoch": None,
+      },
+      {
+          "testcase_name": "restore_without_iterations",
+          "restore_optimizer_vars": True,
+          "restore_steps": True,
+          "restore_iterations": False,
+          "expected_learning_rate": 0.1,
+          "expected_iterations": 0,
+          "expected_initial_epoch": 2,
+      },
+      {
+          "testcase_name": "restore_only_model_variables",
+          "restore_optimizer_vars": False,
+          "restore_steps": False,
+          "restore_iterations": False,
+          "expected_learning_rate": 0.1,
+          "expected_iterations": 0,
+          "expected_initial_epoch": None,
+      },
+  )
+  def test_restore_keras_model_with_different_options(
+      self,
+      restore_optimizer_vars: bool,
+      restore_steps: bool,
+      restore_iterations: bool,
+      expected_learning_rate: float,
+      expected_iterations: int,
+      expected_initial_epoch: int | None,
+  ):
+    checkpoint_dir = self.create_tempdir().full_path
+    checkpointer = keras_utils.KerasOrbaxCheckpointManager(checkpoint_dir)
+    epoch = 1
+    dummy_inputs = {
+        "token_ids": jax.random.randint(
+            jax.random.key(0), (64, 128), minval=0, maxval=50_000
+        ),
+        "segment_ids": jax.random.randint(
+            jax.random.key(0), (64, 128), minval=0, maxval=7
+        ),
+        "padding_mask": jax.random.uniform(jax.random.key(0), (64, 128)),
+        "mask_positions": jax.random.randint(
+            jax.random.key(0), (64, 20), minval=0, maxval=128
+        ),
+    }
+
+    source_bert_pretrainer = _create_model(
+        jax.tree.map(jnp.shape, dummy_inputs)
+    )
+    source_bert_pretrainer.optimizer.iterations.assign(100)
+    source_state = source_bert_pretrainer._get_jax_state(  # pylint: disable=protected-access
+        trainable_variables=True,
+        non_trainable_variables=True,
+        optimizer_variables=True,
+    )
+    checkpointer.save(step=epoch, items=source_state)
+    checkpointer.wait_until_finished()
+
+    target_bert_pretrainer = _create_model(
+        jax.tree.map(jnp.shape, dummy_inputs)
+    )
+    keras_utils.restore_keras_model(
+        target_bert_pretrainer,
+        checkpoint_dir,
+        restore_optimizer_vars=restore_optimizer_vars,
+        restore_steps=restore_steps,
+        restore_iterations=restore_iterations,
+    )
+
+    self.assertEqual(
+        target_bert_pretrainer.optimizer.iterations.value, expected_iterations
+    )
+    self.assertEqual(
+        target_bert_pretrainer.optimizer.learning_rate,
+        expected_learning_rate,
+    )
+    self.assertEqual(
+        target_bert_pretrainer._initial_epoch, expected_initial_epoch
+    )
+
 
 if __name__ == "__main__":
   absltest.main()
diff --git a/recml/examples/dlrm_experiment.py b/recml/examples/dlrm_experiment.py
index 53c5cde..ade10da 100644
--- a/recml/examples/dlrm_experiment.py
+++ b/recml/examples/dlrm_experiment.py
@@ -73,10 +73,14 @@ def __post_init__(self):
     )
 
   def dense_features(self) -> FeatureSet[DenseFeature]:
-    return FeatureSet([f for f in self if isinstance(f, DenseFeature)])
+    return FeatureSet[DenseFeature](
+        [f for f in self if isinstance(f, DenseFeature)]
+    )
 
   def sparse_features(self) -> FeatureSet[SparseFeature]:
-    return FeatureSet([f for f in self if isinstance(f, SparseFeature)])
+    return FeatureSet[SparseFeature](
+        [f for f in self if isinstance(f, SparseFeature)]
+    )
 
   def __iter__(self) -> Iterator[FeatureT]:
     return iter(self.features)
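Note on the dlrm_experiment.py change (illustrative sketch, not part of the
patch): at runtime, constructing through the subscripted class,
`FeatureSet[DenseFeature]([...])`, builds the same object as calling
`FeatureSet` directly; Python routes the subscript through
`__class_getitem__` and the resulting generic alias forwards the call to the
class. The payoff is that the constructed value lines up with the annotated
return type `FeatureSet[DenseFeature]` for static type checkers. The classes
below are stand-ins, not the real recml definitions:

    import dataclasses
    from typing import Generic, TypeVar

    FeatureT = TypeVar("FeatureT")


    @dataclasses.dataclass
    class DenseFeature:
      name: str


    @dataclasses.dataclass
    class FeatureSet(Generic[FeatureT]):
      """Stand-in for recml's FeatureSet; just enough to show the pattern."""

      features: list[FeatureT]


    # Both call forms construct an ordinary FeatureSet at runtime; the
    # subscripted form also records the element type on the instance.
    fs = FeatureSet[DenseFeature]([DenseFeature("age"), DenseFeature("price")])
    assert isinstance(fs, FeatureSet)
    print(fs.__orig_class__)  # FeatureSet[DenseFeature]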