25
25
import numpy as np
26
26
from recml .core .utils import keras_utils
27
27
28
# Learning-rate schedule handed to the Adam optimizer of the test models.
# Configured as a polynomial decay with power 1.0 (i.e. linear), starting at
# 0.1 and ending at 0.01 with 100 decay steps, so the optimizer's iteration
# counter determines the learning rate the restore tests observe.
_LEARNING_RATE_SCHEDULE = keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=0.1,
    decay_steps=100,
    end_learning_rate=0.01,
    power=1.0,
)
34
+
28
35
29
36
def _create_model (input_shapes : Sequence [int ]) -> keras .Model :
30
37
model = keras_hub .models .BertMaskedLM (
@@ -39,7 +46,7 @@ def _create_model(input_shapes: Sequence[int]) -> keras.Model:
39
46
dropout = 0.1 ,
40
47
)
41
48
)
42
- optimizer = keras .optimizers .Adam (learning_rate = 0.1 )
49
+ optimizer = keras .optimizers .Adam (learning_rate = _LEARNING_RATE_SCHEDULE )
43
50
loss = keras .losses .SparseCategoricalCrossentropy ()
44
51
metrics = [keras .metrics .SparseCategoricalAccuracy ()]
45
52
model .compile (optimizer , loss , weighted_metrics = metrics )
@@ -242,6 +249,112 @@ def test_metrics_variables_checkpointing(
242
249
)
243
250
self .assertSequenceEqual (w1 .dtype , w2 .dtype )
244
251
252
@parameterized.named_parameters(
    {
        "testcase_name": "restore_all_variables",
        "restore_optimizer_vars": True,
        "restore_steps": True,
        "restore_iterations": True,
        "expected_learning_rate": 0.01,
        "expected_iterations": 100,
        "expected_initial_epoch": 2,
    },
    {
        "testcase_name": "restore_without_optimizer_vars",
        "restore_optimizer_vars": False,
        "restore_steps": True,
        "restore_iterations": True,
        "expected_learning_rate": 0.1,
        "expected_iterations": 0,
        "expected_initial_epoch": 2,
    },
    {
        "testcase_name": "restore_without_steps",
        "restore_optimizer_vars": True,
        "restore_steps": False,
        "restore_iterations": True,
        "expected_learning_rate": 0.01,
        "expected_iterations": 100,
        "expected_initial_epoch": None,
    },
    {
        "testcase_name": "restore_without_iterations",
        "restore_optimizer_vars": True,
        "restore_steps": True,
        "restore_iterations": False,
        "expected_learning_rate": 0.1,
        "expected_iterations": 0,
        "expected_initial_epoch": 2,
    },
    {
        "testcase_name": "restore_only_model_variables",
        "restore_optimizer_vars": False,
        "restore_steps": False,
        "restore_iterations": False,
        "expected_learning_rate": 0.1,
        "expected_iterations": 0,
        "expected_initial_epoch": None,
    },
)
def test_restore_keras_model_with_different_options(
    self,
    restore_optimizer_vars: bool,
    restore_steps: bool,
    restore_iterations: bool,
    expected_learning_rate: float,
    expected_iterations: int,
    expected_initial_epoch: int | None,
):
  """Checks `restore_keras_model` under each combination of restore flags.

  A source model has its optimizer iteration counter set to 100 and its JAX
  state (trainable, non-trainable and optimizer variables) is checkpointed at
  step 1. The checkpoint is then restored into a freshly created model with
  the given flags, and the resulting optimizer iteration count, current
  learning rate and `_initial_epoch` are compared with the expected values.

  Args:
    restore_optimizer_vars: Forwarded to `restore_keras_model`.
    restore_steps: Forwarded to `restore_keras_model`.
    restore_iterations: Forwarded to `restore_keras_model`.
    expected_learning_rate: Learning rate expected on the restored model's
      optimizer.
    expected_iterations: Iteration count expected on the restored model's
      optimizer.
    expected_initial_epoch: Expected value of the restored model's
      `_initial_epoch` attribute.
  """
  checkpoint_dir = self.create_tempdir().full_path
  checkpointer = keras_utils.KerasOrbaxCheckpointManager(checkpoint_dir)
  epoch = 1
  # Only the shapes of these arrays are passed to `_create_model` below.
  dummy_inputs = {
      "token_ids": jax.random.randint(
          jax.random.key(0), (64, 128), minval=0, maxval=50_000
      ),
      "segment_ids": jax.random.randint(
          jax.random.key(0), (64, 128), minval=0, maxval=7
      ),
      "padding_mask": jax.random.uniform(jax.random.key(0), (64, 128)),
      "mask_positions": jax.random.randint(
          jax.random.key(0), (64, 20), minval=0, maxval=128
      ),
  }

  source_bert_pretrainer = _create_model(
      jax.tree.map(jnp.shape, dummy_inputs)
  )
  # Make the checkpointed iteration counter distinguishable from a fresh
  # optimizer's counter.
  source_bert_pretrainer.optimizer.iterations.assign(100)
  source_state = source_bert_pretrainer._get_jax_state(  # pylint: disable=protected-access
      trainable_variables=True,
      non_trainable_variables=True,
      optimizer_variables=True,
  )
  checkpointer.save(step=epoch, items=source_state)
  # Block until the save has completed before restoring from the directory.
  checkpointer.wait_until_finished()

  target_bert_pretrainer = _create_model(
      jax.tree.map(jnp.shape, dummy_inputs)
  )
  keras_utils.restore_keras_model(
      target_bert_pretrainer,
      checkpoint_dir,
      restore_optimizer_vars=restore_optimizer_vars,
      restore_steps=restore_steps,
      restore_iterations=restore_iterations,
  )

  self.assertEqual(
      target_bert_pretrainer.optimizer.iterations.value, expected_iterations
  )
  # The optimizer is built with `_LEARNING_RATE_SCHEDULE`, so the learning
  # rate is the result of floating-point arithmetic (presumably float32 --
  # TODO confirm). Compare as a Python float with a tolerance instead of
  # requiring bit-exact equality with the literal.
  self.assertAlmostEqual(
      float(target_bert_pretrainer.optimizer.learning_rate),
      expected_learning_rate,
      places=6,
  )
  self.assertEqual(
      target_bert_pretrainer._initial_epoch,  # pylint: disable=protected-access
      expected_initial_epoch,
  )
357
+
245
358
246
359
# Script entry point: hand control to absltest's test runner.
if __name__ == "__main__":
  absltest.main()
0 commit comments