Skip to content

Commit d5ecf00

Browse files
edwardzhou130
authored and
recml authors
committed
Add options to load ckpt with optimizer_variables / training steps / lr separately.
PiperOrigin-RevId: 745717557
1 parent 693ee85 commit d5ecf00

File tree

3 files changed

+130
-2
lines changed

3 files changed

+130
-2
lines changed

recml/core/training/keras_trainer.py

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414
"""Keras task and trainer."""
1515

16+
from __future__ import annotations
17+
1618
import abc
1719
from collections.abc import Mapping
1820
import gc

recml/core/utils/keras_utils.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ def restore_keras_model(
183183
checkpoint_dir: str,
184184
step: int | None = None,
185185
restore_optimizer_vars: bool = True,
186+
restore_steps: bool = True,
187+
restore_iterations: bool = True,
186188
):
187189
"""Restores a Keras 3 Jax backend model from an Orbax checkpoint.
188190
@@ -192,6 +194,14 @@ def restore_keras_model(
192194
step: The step to restore the model to. If `None` then the latest checkpoint
193195
will be restored.
194196
restore_optimizer_vars: Whether to restore the optimizer variables.
197+
restore_steps: Whether to restore the model's steps. If `True` then the
198+
model will continue training from the step the checkpoint was saved at. If
199+
`False` then the model will start training from the first step.
200+
restore_iterations: Whether to restore the model's iterations. If `True`
201+
then the model will continue training from the iteration the checkpoint
202+
was saved at. This is an optimizer variable used for controlling the
203+
learning rate schedule. This is not supported if restore_optimizer_vars
204+
is `False`.
195205
196206
Raises:
197207
FileNotFoundError: If no checkpoints are found in the checkpoint directory.
@@ -273,10 +283,13 @@ def restore_keras_model(
273283
"non_trainable_variables": non_trainable_variables,
274284
}
275285
if restore_optimizer_vars:
276-
model._initial_epoch = step + 1 # pylint: disable=protected-access
277286
optimizer_variables = restored_state[2]
278287
model._jax_state["optimizer_variables"] = optimizer_variables # pylint: disable=protected-access
279288
model.jax_state_sync()
289+
if restore_steps:
290+
model._initial_epoch = step + 1 # pylint: disable=protected-access
291+
if restore_optimizer_vars and not restore_iterations:
292+
model.optimizer.iterations.assign(0)
280293

281294

282295
# TODO(b/343544467): Support logging metrics more frequently.

recml/core/utils/keras_utils_test.py

+114-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
import numpy as np
2626
from recml.core.utils import keras_utils
2727

28+
_LEARNING_RATE_SCHEDULE = keras.optimizers.schedules.PolynomialDecay(
29+
initial_learning_rate=0.1,
30+
decay_steps=100,
31+
end_learning_rate=0.01,
32+
power=1.0,
33+
)
34+
2835

2936
def _create_model(input_shapes: Sequence[int]) -> keras.Model:
3037
model = keras_hub.models.BertMaskedLM(
@@ -39,7 +46,7 @@ def _create_model(input_shapes: Sequence[int]) -> keras.Model:
3946
dropout=0.1,
4047
)
4148
)
42-
optimizer = keras.optimizers.Adam(learning_rate=0.1)
49+
optimizer = keras.optimizers.Adam(learning_rate=_LEARNING_RATE_SCHEDULE)
4350
loss = keras.losses.SparseCategoricalCrossentropy()
4451
metrics = [keras.metrics.SparseCategoricalAccuracy()]
4552
model.compile(optimizer, loss, weighted_metrics=metrics)
@@ -242,6 +249,112 @@ def test_metrics_variables_checkpointing(
242249
)
243250
self.assertSequenceEqual(w1.dtype, w2.dtype)
244251

252+
@parameterized.named_parameters(
253+
{
254+
"testcase_name": "restore_all_variables",
255+
"restore_optimizer_vars": True,
256+
"restore_steps": True,
257+
"restore_iterations": True,
258+
"expected_learning_rate": 0.01,
259+
"expected_iterations": 100,
260+
"expected_initial_epoch": 2,
261+
},
262+
{
263+
"testcase_name": "restore_without_optimizer_vars",
264+
"restore_optimizer_vars": False,
265+
"restore_steps": True,
266+
"restore_iterations": True,
267+
"expected_learning_rate": 0.1,
268+
"expected_iterations": 0,
269+
"expected_initial_epoch": 2,
270+
},
271+
{
272+
"testcase_name": "restore_without_steps",
273+
"restore_optimizer_vars": True,
274+
"restore_steps": False,
275+
"restore_iterations": True,
276+
"expected_learning_rate": 0.01,
277+
"expected_iterations": 100,
278+
"expected_initial_epoch": None,
279+
},
280+
{
281+
"testcase_name": "restore_without_iterations",
282+
"restore_optimizer_vars": True,
283+
"restore_steps": True,
284+
"restore_iterations": False,
285+
"expected_learning_rate": 0.1,
286+
"expected_iterations": 0,
287+
"expected_initial_epoch": 2,
288+
},
289+
{
290+
"testcase_name": "restore_only_model_variables",
291+
"restore_optimizer_vars": False,
292+
"restore_steps": False,
293+
"restore_iterations": False,
294+
"expected_learning_rate": 0.1,
295+
"expected_iterations": 0,
296+
"expected_initial_epoch": None,
297+
},
298+
)
299+
def test_restore_keras_model_with_different_options(
300+
self,
301+
restore_optimizer_vars: bool,
302+
restore_steps: bool,
303+
restore_iterations: bool,
304+
expected_learning_rate: float,
305+
expected_iterations: int,
306+
expected_initial_epoch: int | None,
307+
):
308+
checkpoint_dir = self.create_tempdir().full_path
309+
checkpointer = keras_utils.KerasOrbaxCheckpointManager(checkpoint_dir)
310+
epoch = 1
311+
dummy_inputs = {
312+
"token_ids": jax.random.randint(
313+
jax.random.key(0), (64, 128), minval=0, maxval=50_000
314+
),
315+
"segment_ids": jax.random.randint(
316+
jax.random.key(0), (64, 128), minval=0, maxval=7
317+
),
318+
"padding_mask": jax.random.uniform(jax.random.key(0), (64, 128)),
319+
"mask_positions": jax.random.randint(
320+
jax.random.key(0), (64, 20), minval=0, maxval=128
321+
),
322+
}
323+
324+
source_bert_pretrainer = _create_model(
325+
jax.tree.map(jnp.shape, dummy_inputs)
326+
)
327+
source_bert_pretrainer.optimizer.iterations.assign(100)
328+
source_state = source_bert_pretrainer._get_jax_state( # pylint: disable=protected-access
329+
trainable_variables=True,
330+
non_trainable_variables=True,
331+
optimizer_variables=True,
332+
)
333+
checkpointer.save(step=epoch, items=source_state)
334+
checkpointer.wait_until_finished()
335+
336+
target_bert_pretrainer = _create_model(
337+
jax.tree.map(jnp.shape, dummy_inputs)
338+
)
339+
keras_utils.restore_keras_model(
340+
target_bert_pretrainer,
341+
checkpoint_dir,
342+
restore_optimizer_vars=restore_optimizer_vars,
343+
restore_steps=restore_steps,
344+
restore_iterations=restore_iterations,
345+
)
346+
347+
self.assertEqual(
348+
target_bert_pretrainer.optimizer.iterations.value, expected_iterations
349+
)
350+
self.assertEqual(
351+
target_bert_pretrainer.optimizer.learning_rate,
352+
expected_learning_rate,
353+
)
354+
self.assertEqual(
355+
target_bert_pretrainer._initial_epoch, expected_initial_epoch
356+
)
357+
245358

246359
if __name__ == "__main__":
247360
absltest.main()

0 commit comments

Comments (0)