1515"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
1616
1717import time
18- from typing import Any
18+ from typing import Any , Optional
1919
2020from absl import flags
2121from etils import epath
2222from flax .training import train_state
23- import grain .python as grain
2423import jax
2524from MaxText import exceptions
2625from MaxText import max_logging
2726from MaxText .globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
2827from MaxText .multihost_dataloading import MultiHostDataLoadIterator
28+ from MaxText .input_pipeline .input_pipeline_interface import PlaceHolderDataIterator
2929import numpy as np
3030import orbax .checkpoint as ocp
3131from orbax .checkpoint import v1 as ocp_v1
3232from orbax .checkpoint ._src .arrays import sharding as sharding_utils
3333import orbax .checkpoint .experimental .emergency .checkpoint_manager as emergency_checkpoint_manager
3434import orbax .checkpoint .experimental .emergency .replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
3535# pylint: disable=too-many-positional-arguments
36+ import dataclasses
37+ import json
38+
39+ import grain
40+ from grain .python import PyGrainCheckpointHandler
3641
3742CheckpointManager = ocp .CheckpointManager
3843CheckpointManagerOptions = ocp .CheckpointManagerOptions
4449EmergencyReplicatorCheckpointManager = emergency_replicator_checkpoint_manager .ReplicatorCheckpointManager
4550
4651
class GrainCheckpointHandler(PyGrainCheckpointHandler, ocp.CheckpointHandler):
  """A CheckpointHandler that allows specifying process_index and process_count.

  Unlike the stock PyGrain handler, save/restore here can act on behalf of a
  different (or several) logical data-loading process(es), which is what makes
  resuming with a different number of hosts possible.
  """

  def save(
      self,
      directory: epath.Path,
      # `item` is for backwards compatibility with older Orbax API, see
      # https://orbax.readthedocs.io/en/latest/api_refactor.html.
      item: Optional[Any] = None,
      args: Any = None,
  ):
    """Saves the given iterator to the checkpoint in `directory`."""
    item = item or args.item  # pytype:disable=attribute-error

    def _write_state(iterator, idx, count):
      # One JSON state file per logical data-loading process.
      target = directory / f"process_{idx}-of-{count}.json"
      if isinstance(iterator, grain.DatasetIterator):
        payload = json.dumps(iterator.get_state(), indent=4)
      else:
        # Older PyGrain iterators return serialized bytes from get_state().
        payload = iterator.get_state().decode()
      target.write_text(payload)

    if isinstance(item, list):
      # Saving for several logical processes at once: each entry is a
      # (local_iterator, process_index, process_count) triple.
      for local_iterator, idx, count in item:
        _write_state(local_iterator, idx, count)
    else:
      _write_state(item, jax.process_index(), jax.process_count())

  def restore(
      self,
      directory: epath.Path,
      item: Optional[Any] = None,
      args: Any = None,
  ) -> Any:
    """Restores the given iterator from the checkpoint in `directory`."""
    item = item or args.item
    process_index = getattr(args, "process_index", None)
    process_count = getattr(args, "process_count", None)

    def _load_state(iterator, idx, count):
      source = directory / f"process_{idx}-of-{count}.json"
      if not source.exists():
        raise ValueError(f"File {source} does not exist.")
      payload = source.read_text()
      if isinstance(iterator, grain.DatasetIterator):
        iterator.set_state(json.loads(payload))
      else:
        iterator.set_state(payload.encode())
      return iterator

    if isinstance(item, list):
      # One restored iterator per requested global process index.
      return [_load_state(data_iter, idx, process_count) for data_iter, idx in zip(item, process_index)]
    if process_index is None or process_count is None:
      process_index, process_count = jax.process_index(), jax.process_count()
    return _load_state(item, process_index, process_count)
113+
114+
@ocp.args.register_with_handler(GrainCheckpointHandler, for_save=True)
@dataclasses.dataclass
class GrainCheckpointSave(ocp.args.CheckpointArgs):
  """Save arguments for GrainCheckpointHandler.

  Attributes:
    item: the iterator to save, or a list of
      (local_iterator, process_index, process_count) triples when one host
      saves state on behalf of several logical processes.
  """

  item: Any
119+
120+
@ocp.args.register_with_handler(GrainCheckpointHandler, for_restore=True)
@dataclasses.dataclass
class GrainCheckpointRestore(ocp.args.CheckpointArgs):
  """Restore arguments for GrainCheckpointHandler.

  Attributes:
    item: the iterator (or list of iterators) whose state will be restored.
    process_index: global process index (or list of indices) whose stored
      state file(s) to read; defaults to this host's jax.process_index().
    process_count: number of processes that wrote the checkpoint; defaults
      to jax.process_count().
  """

  item: Any
  process_index: Optional[int | list[int]] = None
  process_count: Optional[int] = None
127+
128+
47129def _load_full_state_from_path (
48130 path ,
49131 abstract_unboxed_pre_state ,
@@ -111,17 +193,18 @@ def create_orbax_checkpoint_manager(
111193
112194 max_logging .log (f"Creating checkpoint manager with ocdbt={ use_ocdbt } and zarr3={ use_zarr3 } " )
113195
196+ # Base configuration for all dataset types
197+ item_names = ("items" ,)
198+ # we need to use ocdbt and zarr3 to control max file size in the checkpoint
199+ item_handlers = {"items" : PyTreeCheckpointHandler (use_ocdbt = use_ocdbt , use_zarr3 = use_zarr3 )}
200+
114201 if dataset_type == "grain" :
115- item_names = ("items" , "iter" )
116- else :
117- item_names = ("items" ,)
202+ item_names += ("iter" ,)
203+ item_handlers ["iter" ] = GrainCheckpointHandler ()
118204
119205 # local storage checkpoint needs parent directory created
120206 p = epath .Path (checkpoint_dir )
121207 p .mkdir (exist_ok = True , parents = True )
122- # we need to use ocdbt and zarr3 to control max file size in the checkpoint
123- # omitting `iter` uses default handler for `iter`
124- item_handlers = {"items" : PyTreeCheckpointHandler (use_ocdbt = use_ocdbt , use_zarr3 = use_zarr3 )}
125208 manager = CheckpointManager (
126209 p ,
127210 item_names = item_names ,
@@ -273,9 +356,101 @@ def _replica_devices(device_array: np.ndarray, replica_axis_idx: int):
273356 return np .expand_dims (replica_result , axis = replica_axis_idx )
274357
275358
def _prepare_scaled_down_grain_restore_args(
    data_iterator: list, process_count_jax: int, process_count_stored: int, directory: epath.Path
) -> GrainCheckpointRestore:
  """Prepares the restore arguments for a scaled-up (list) data iterator.

  This is used when restoring a checkpoint saved with more processes than
  the current run (e.g., 64 files onto 32 JAX processes): each current JAX
  process restores several stored per-process iterator states.

  Args:
    data_iterator: list of data iterators, one per stored process this host
      is responsible for; each must expose a `.local_iterator`.
    process_count_jax: number of JAX processes in the current run.
    process_count_stored: number of per-process state files in the checkpoint.
    directory: checkpoint `iter` directory, used only for error messages.

  Returns:
    A GrainCheckpointRestore covering all global process indices this host
    must restore.
  """
  # 1. Validation
  assert isinstance(data_iterator, list), (
      f"{process_count_stored} processes found in Grain checkpoint directory {directory}, but only "
      f"{process_count_jax} jax processes in this run, please set expansion_factor_real_data accordingly."
  )

  scaling_factor = len(data_iterator)
  # Use integer arithmetic instead of comparing an int against a float
  # quotient: require exact divisibility explicitly.
  expected_process_count, remainder = divmod(process_count_stored, process_count_jax)
  assert remainder == 0 and scaling_factor == expected_process_count, (
      f"Found {process_count_stored} processes in checkpoint and {process_count_jax} "
      f"JAX processes, implying a scaling factor of {process_count_stored / process_count_jax}. "
      f"However, the data_iterator list has {scaling_factor} items."
  )

  # 2. Prepare Arguments
  local_iterator_list = [x.local_iterator for x in data_iterator]
  # Each JAX process calculates the global indices it's responsible for.
  # e.g. process 0 with scaling_factor=2 handles checkpoints from processes [0, 32]
  # e.g. process 1 with scaling_factor=2 handles checkpoints from processes [1, 33]
  process_index_list = [jax.process_index() + i * process_count_jax for i in range(scaling_factor)]

  return GrainCheckpointRestore(local_iterator_list, process_index=process_index_list, process_count=process_count_stored)
390+
391+
def _restore_grain_iterator(
    checkpoint_manager,
    step: int,
    data_iterator,
    checkpoint_args,
    expansion_factor_real_data: int,
) -> tuple[Any, None]:
  """Restores a Grain data-iterator checkpoint for the given step.

  Dispatches to the correct restore strategy by comparing the number of
  per-process state files stored in the checkpoint against the number of
  JAX processes in the current run.
  """
  directory = checkpoint_manager.directory / str(step) / "iter"
  n_processes = jax.process_count()
  # Each saving process wrote exactly one "process_<i>-of-<n>.json" file,
  # so the file count tells us how many processes produced the checkpoint.
  n_stored = len(list(directory.glob("process_*-of-*.json")))

  if n_stored > n_processes:
    # Scaling down from a larger number of hosts (e.g., 128 files -> 64
    # processes): each host restores a list of data iterators.
    grain_restore_args = _prepare_scaled_down_grain_restore_args(data_iterator, n_processes, n_stored, directory)
  elif n_stored == n_processes:
    # Same topology as the saving run (e.g., 64 files -> 64 processes).
    assert not isinstance(data_iterator, list), (
        f"{n_stored} processes found in Grain checkpoint directory {directory}, matching the number of "
        "jax process, please do not set expansion_factor_real_data."
    )
    grain_restore_args = GrainCheckpointRestore(data_iterator.local_iterator)
  elif expansion_factor_real_data > 1 and n_stored == n_processes // expansion_factor_real_data:
    # Scaling up to a larger number of hosts (e.g., 32 files -> 64
    # processes): only a subset of hosts restores real iterator state.
    assert not isinstance(data_iterator, list), "when expansion_factor_real_data > 1, the data iterator should not be a list."
    grain_restore_args = GrainCheckpointRestore(
        data_iterator.local_iterator, process_index=jax.process_index(), process_count=n_stored
    )
  else:
    # Incompatible topology: refuse with pointers to the resharding docs.
    raise ValueError(
        f"Error restoring Grain checkpoint in {directory}: "
        f"The number of stored checkpoint files ({n_stored}) "
        f"is incompatible with the number of JAX processes ({n_processes}). "
        "If you are resuming training with a different number of chips, see instructions in "
        "https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline/"
        "data_input_grain.md#using-grain"
    )

  # Call restore once with the composed arguments.
  restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_restore_args))
  return (restored_state, None)
449+
450+
276451def load_state_if_possible (
277452 checkpoint_manager : CheckpointManager | None ,
278- data_iterator : MultiHostDataLoadIterator | None ,
453+ data_iterator : MultiHostDataLoadIterator | list [ MultiHostDataLoadIterator ] | None ,
279454 load_parameters_from_path : str ,
280455 load_full_state_from_path : str ,
281456 checkpoint_storage_concurrent_gb : int ,
@@ -288,6 +463,7 @@ def load_state_if_possible(
288463 enable_orbax_v1 = False ,
289464 checkpoint_conversion_fn = None ,
290465 source_checkpoint_layout = "orbax" ,
466+ expansion_factor_real_data : int = - 1 ,
291467):
292468 """Loads TrainState as possible from the inputs.
293469
@@ -354,20 +530,27 @@ def map_to_pspec(data):
354530 # or EmergencyReplicatorCheckpointManager. The '_' indicates that 'dataset_type' and
355531 # 'data_iterator' can be any value and aren't used in this pattern.
356532 case (checkpoint_manager , _, _) if isinstance (
357- checkpoint_manager ,
358- (EmergencyCheckpointManager , EmergencyReplicatorCheckpointManager ),
533+ checkpoint_manager , (EmergencyCheckpointManager , EmergencyReplicatorCheckpointManager )
359534 ):
360535 return (
361536 checkpoint_manager .restore (step , args = Composite (state = checkpoint_args )).state ,
362537 None ,
363538 )
364- # Case 2: Matches if dataset type is "grain" and a specific checkpoint file exits for the iterator
365- # exists within the checkpoint manager's directory for the given step.
366- case (checkpoint_manager , dataset_type , data_iterator ) if dataset_type == "grain" and data_iterator and (
367- checkpoint_manager .directory / str (step ) / "iter"
368- ).exists ():
369- grain_iter = grain .PyGrainCheckpointRestore (data_iterator .local_iterator )
370- return (checkpoint_manager .restore (step , args = Composite (items = checkpoint_args , iter = grain_iter )), None )
539+ # Case 2: Matches if dataset type is "grain" and the data iterator is not a
540+ # PlaceHolderDataIterator and a specific checkpoint file exists for the iterator
541+ case (
542+ checkpoint_manager ,
543+ dataset_type ,
544+ data_iterator ,
545+ ) if (
546+ dataset_type == "grain"
547+ and data_iterator
548+ and not isinstance (data_iterator , PlaceHolderDataIterator )
549+ and (checkpoint_manager .directory / str (step ) / "iter" ).exists ()
550+ ):
551+ return _restore_grain_iterator (
552+ checkpoint_manager , step , data_iterator , checkpoint_args , expansion_factor_real_data
553+ )
371554 # Case 3: Default/Fallback case.
372555 # This case acts as a wildcard ('_') and matches if none of the preceding cases were met.
373556 case _:
@@ -518,15 +701,25 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
518701 save_args = jax .tree .map (lambda _ : ocp .SaveArgs (chunk_byte_size = chunk_byte_size ), state ),
519702 ocdbt_target_data_file_size = chunk_byte_size ,
520703 )
521-
522- match (checkpoint_manager , config ):
523- case (checkpoint_manager , _) if isinstance (
704+ save_args_composite = {"items" : checkpoint_args }
705+
706+ if config and config .dataset_type == "grain" and not isinstance (data_iterator , PlaceHolderDataIterator ):
707+ if not isinstance (data_iterator , list ):
708+ data_iterator = [data_iterator ]
709+ grain_iters_to_save = []
710+ process_count_total = jax .process_count () * len (data_iterator )
711+ if config .expansion_factor_real_data > 1 :
712+ process_count_total = process_count_total // config .expansion_factor_real_data
713+ for i , data_iter in enumerate (data_iterator ):
714+ process_index = jax .process_index () + i * jax .process_count ()
715+ grain_iters_to_save .append ((data_iter .local_iterator , process_index , process_count_total ))
716+ save_args_composite ["iter" ] = GrainCheckpointSave (item = grain_iters_to_save )
717+
718+ match (checkpoint_manager , config , data_iterator ):
719+ case (checkpoint_manager , _, _) if isinstance (
524720 checkpoint_manager , (EmergencyCheckpointManager , EmergencyReplicatorCheckpointManager )
525721 ):
526722 replicator_error_handler (config )
527723 return checkpoint_manager .save (step , args = Composite (state = checkpoint_args ), force = force )
528- case (_, config ) if config and config .dataset_type == "grain" :
529- grain_iter = grain .PyGrainCheckpointSave (data_iterator .local_iterator )
530- return checkpoint_manager .save (step , args = Composite (items = checkpoint_args , iter = grain_iter ), force = force )
531724 case _:
532- return checkpoint_manager .save (step , args = Composite (items = checkpoint_args ), force = force )
725+ return checkpoint_manager .save (step , args = Composite (** save_args_composite ), force = force )
0 commit comments