1515"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
1616
1717import time
18- from typing import Any
18+ from collections .abc import Iterator
19+ from typing import Any , Optional
1920
2021from absl import flags
2122from etils import epath
2223from flax .training import train_state
23- import grain .python as grain
2424import jax
2525from MaxText import exceptions
2626from MaxText import max_logging
3333import orbax .checkpoint .experimental .emergency .checkpoint_manager as emergency_checkpoint_manager
3434import orbax .checkpoint .experimental .emergency .replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
3535# pylint: disable=too-many-positional-arguments
36+ import dataclasses
37+ import json
38+
39+ import grain
40+ from grain .python import PyGrainCheckpointHandler
3641
3742CheckpointManager = ocp .CheckpointManager
3843CheckpointManagerOptions = ocp .CheckpointManagerOptions
4449EmergencyReplicatorCheckpointManager = emergency_replicator_checkpoint_manager .ReplicatorCheckpointManager
4550
4651
52+ class GrainCheckpointHandler (PyGrainCheckpointHandler ):
53+ """A CheckpointHandler that allows specifying process_index and process_count."""
54+
55+ def save (
56+ self ,
57+ directory : epath .Path ,
58+ # `item` is for backwards compatibility with older Orbax API, see
59+ # https://orbax.readthedocs.io/en/latest/api_refactor.html.
60+ item : Optional [Iterator ] = None ,
61+ args : Any = None ,
62+ ):
63+ """Saves the given iterator to the checkpoint in `directory`."""
64+ item = item or args .item # pytype:disable=attribute-error
65+
66+ def save_single_process (item , process_index , process_count ):
67+ filename = directory / f"process_{ process_index } -of-{ process_count } .json"
68+ if isinstance (item , grain .DatasetIterator ):
69+ state = json .dumps (item .get_state (), indent = 4 )
70+ else :
71+ state = item .get_state ().decode ()
72+ filename .write_text (state )
73+
74+ if isinstance (item , list ):
75+ for local_iterator , process_index , process_count in item :
76+ save_single_process (local_iterator , process_index , process_count )
77+ else :
78+ process_index , process_count = jax .process_index (), jax .process_count ()
79+ save_single_process (item , process_index , process_count )
80+
81+ def restore (
82+ self ,
83+ directory : epath .Path ,
84+ item : Optional [Iterator ] = None ,
85+ args : Any = None ,
86+ ) -> Iterator :
87+ """Restores the given iterator from the checkpoint in `directory`."""
88+ item = item or args .item
89+ process_index = getattr (args , "process_index" , None )
90+ process_count = getattr (args , "process_count" , None )
91+
92+ def restore_single_process (item , process_index , process_count ):
93+ filename = directory / f"process_{ process_index } -of-{ process_count } .json"
94+ if not filename .exists ():
95+ raise ValueError (f"File { filename } does not exist." )
96+ state = filename .read_text ()
97+ if isinstance (item , grain .DatasetIterator ):
98+ state = json .loads (state )
99+ else :
100+ state = state .encode ()
101+ item .set_state (state )
102+ return item
103+
104+ if isinstance (item , list ):
105+ restored_items = []
106+ for data_iter , process_idx in zip (item , process_index ):
107+ restored_items .append (restore_single_process (data_iter , process_idx , process_count ))
108+ return restored_items
109+ else :
110+ if process_index is None or process_count is None :
111+ process_index , process_count = jax .process_index (), jax .process_count ()
112+ return restore_single_process (item , process_index , process_count )
113+
114+
115+ @ocp .args .register_with_handler (GrainCheckpointHandler , for_save = True )
116+ @dataclasses .dataclass
117+ class GrainCheckpointSave (ocp .args .CheckpointArgs ):
118+ item : Any
119+
120+
121+ @ocp .args .register_with_handler (GrainCheckpointHandler , for_restore = True )
122+ @dataclasses .dataclass
123+ class GrainCheckpointRestore (ocp .args .CheckpointArgs ):
124+ item : Any
125+ process_index : Optional [int | list [int ]] = None
126+ process_count : Optional [int ] = None
127+
128+
47129def _load_full_state_from_path (
48130 path ,
49131 abstract_unboxed_pre_state ,
@@ -111,17 +193,18 @@ def create_orbax_checkpoint_manager(
111193
112194 max_logging .log (f"Creating checkpoint manager with ocdbt={ use_ocdbt } and zarr3={ use_zarr3 } " )
113195
196+ # Base configuration for all dataset types
197+ item_names = ("items" ,)
198+ # we need to use ocdbt and zarr3 to control max file size in the checkpoint
199+ item_handlers = {"items" : PyTreeCheckpointHandler (use_ocdbt = use_ocdbt , use_zarr3 = use_zarr3 )}
200+
114201 if dataset_type == "grain" :
115- item_names = ("items" , "iter" )
116- else :
117- item_names = ("items" ,)
202+ item_names += ("iter" ,)
203+ item_handlers ["iter" ] = GrainCheckpointHandler ()
118204
119205 # local storage checkpoint needs parent directory created
120206 p = epath .Path (checkpoint_dir )
121207 p .mkdir (exist_ok = True , parents = True )
122- # we need to use ocdbt and zarr3 to control max file size in the checkpoint
123- # omitting `iter` uses default handler for `iter`
124- item_handlers = {"items" : PyTreeCheckpointHandler (use_ocdbt = use_ocdbt , use_zarr3 = use_zarr3 )}
125208 manager = CheckpointManager (
126209 p ,
127210 item_names = item_names ,
@@ -273,6 +356,97 @@ def _replica_devices(device_array: np.ndarray, replica_axis_idx: int):
273356 return np .expand_dims (replica_result , axis = replica_axis_idx )
274357
275358
359+ def _prepare_scaled_down_grain_restore_args (
360+ data_iterator : list , process_count_jax : int , process_count_stored : int , directory : epath .Path
361+ ) -> GrainCheckpointRestore :
362+ """
363+ Prepares the restore arguments for a scaled-up (list) data iterator.
364+
365+ This is used when restoring a checkpoint saved with more processes than
366+ the current run (e.g., 64 files onto 32 JAX processes).
367+ """
368+ # 1. Validation Assertions
369+ assert isinstance (data_iterator , list ), (
370+ f"{ process_count_stored } processes found in Grain checkpoint directory { directory } , but only "
371+ f"{ process_count_jax } jax processes in this run, please set expansion_factor_real_data accordingly."
372+ )
373+
374+ scaling_factor = len (data_iterator )
375+ expected_process_count = process_count_stored / process_count_jax
376+ assert scaling_factor == expected_process_count , (
377+ f"Found { process_count_stored } processes in checkpoint and { process_count_jax } "
378+ f"JAX processes, implying a scaling factor of { expected_process_count } . "
379+ f"However, the data_iterator list has { scaling_factor } items."
380+ )
381+
382+ # 2. Prepare Arguments
383+ local_iterator_list = [x .local_iterator for x in data_iterator ]
384+ # Each JAX process calculates the global indices it's responsible for.
385+ # e.g., process 0 with scaling_factor=2 handles checkpoints from processes [0, 32]
386+ # e.g., process 1 with scaling_factor=2 handles checkpoints from processes [1, 33]
387+ process_index_list = [jax .process_index () + i * process_count_jax for i in range (scaling_factor )]
388+
389+ return GrainCheckpointRestore (local_iterator_list , process_index = process_index_list , process_count = process_count_stored )
390+
391+
392+ def _restore_grain_iterator (
393+ checkpoint_manager ,
394+ step : int ,
395+ data_iterator ,
396+ checkpoint_args ,
397+ expansion_factor_real_data : int , # This must be defined in the outer scope
398+ ) -> tuple [Any , None ]:
399+ """
400+ Handles the complex logic for restoring a Grain data iterator checkpoint.
401+ This function dispatches to the correct restore strategy based on
402+ the number of stored checkpoint files vs. current JAX processes.
403+ """
404+ directory = checkpoint_manager .directory / str (step ) / "iter"
405+ process_count_jax = jax .process_count ()
406+
407+ # Count the number of checkpoint files
408+ process_count_stored = len (list (directory .glob ("process_*-of-*.json" )))
409+
410+ grain_restore_args = None
411+
412+ if process_count_stored > process_count_jax :
413+ # Scaling down from a larger number of hosts. (e.g., 128 files -> 64 processes)
414+ # In this case, each host restores a list of data iterators.
415+ grain_restore_args = _prepare_scaled_down_grain_restore_args (
416+ data_iterator , process_count_jax , process_count_stored , directory
417+ )
418+
419+ elif process_count_stored == process_count_jax :
420+ # Normal case: number of hosts is the same. (e.g., 64 files -> 64 processes)
421+ assert not isinstance (data_iterator , list ), (
422+ f"{ process_count_stored } processes found in Grain checkpoint directory { directory } , matching the number of "
423+ "jax process, please do not set expansion_factor_real_data."
424+ )
425+ grain_restore_args = GrainCheckpointRestore (data_iterator .local_iterator )
426+
427+ elif expansion_factor_real_data > 0 and process_count_stored == process_count_jax // expansion_factor_real_data :
428+ # Scaling up to a larger number of hosts.(e.g., 32 files -> 64 processes)
429+ # In this case, a subset of hosts restore the data iterator.
430+ grain_restore_args = GrainCheckpointRestore (
431+ data_iterator .local_iterator , process_index = jax .process_index (), process_count = process_count_stored
432+ )
433+
434+ else :
435+ # Case 4: Mismatch
436+ raise ValueError (
437+ f"Error restoring Grain checkpoint in { directory } : "
438+ f"The number of stored checkpoint files ({ process_count_stored } ) "
439+ f"is incompatible with the number of JAX processes ({ process_count_jax } ). "
440+ "If you are resuming training with a different number of chips, see instructions in "
441+ "https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline/"
442+ "data_input_grain.md#using-grain"
443+ )
444+
445+ # Call restore once with the composed arguments
446+ restored_state = checkpoint_manager .restore (step , args = Composite (items = checkpoint_args , iter = grain_restore_args ))
447+ return (restored_state , None )
448+
449+
276450def load_state_if_possible (
277451 checkpoint_manager : CheckpointManager | None ,
278452 data_iterator : MultiHostDataLoadIterator | None ,
@@ -288,6 +462,7 @@ def load_state_if_possible(
288462 enable_orbax_v1 = False ,
289463 checkpoint_conversion_fn = None ,
290464 source_checkpoint_layout = "orbax" ,
465+ expansion_factor_real_data : int = - 1 ,
291466):
292467 """Loads TrainState as possible from the inputs.
293468
@@ -348,26 +523,34 @@ def map_to_pspec(data):
348523
349524 restore_args = jax .tree_util .tree_map (map_to_pspec , abstract_unboxed_pre_state )
350525 checkpoint_args = ocp .args .PyTreeRestore (item = abstract_unboxed_pre_state , restore_args = restore_args )
526+ from MaxText .input_pipeline .input_pipeline_interface import PlaceHolderDataIterator # pylint: disable=import-outside-toplevel
351527
352528 match (checkpoint_manager , dataset_type , data_iterator ):
353529 # Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
354530 # or EmergencyReplicatorCheckpointManager. The '_' indicates that 'dataset_type' and
355531 # 'data_iterator' can be any value and aren't used in this pattern.
356532 case (checkpoint_manager , _, _) if isinstance (
357- checkpoint_manager ,
358- (EmergencyCheckpointManager , EmergencyReplicatorCheckpointManager ),
533+ checkpoint_manager , (EmergencyCheckpointManager , EmergencyReplicatorCheckpointManager )
359534 ):
360535 return (
361536 checkpoint_manager .restore (step , args = Composite (state = checkpoint_args )).state ,
362537 None ,
363538 )
364- # Case 2: Matches if dataset type is "grain" and a specific checkpoint file exits for the iterator
365- # exists within the checkpoint manager's directory for the given step.
366- case (checkpoint_manager , dataset_type , data_iterator ) if dataset_type == "grain" and data_iterator and (
367- checkpoint_manager .directory / str (step ) / "iter"
368- ).exists ():
369- grain_iter = grain .PyGrainCheckpointRestore (data_iterator .local_iterator )
370- return (checkpoint_manager .restore (step , args = Composite (items = checkpoint_args , iter = grain_iter )), None )
539+ # Case 2: Matches if dataset type is "grain" and the data iterator is not a
540+ # PlaceHolderDataIterator and a specific checkpoint file exists for the iterator
541+ case (
542+ checkpoint_manager ,
543+ dataset_type ,
544+ data_iterator ,
545+ ) if (
546+ dataset_type == "grain"
547+ and data_iterator
548+ and not isinstance (data_iterator , PlaceHolderDataIterator )
549+ and (checkpoint_manager .directory / str (step ) / "iter" ).exists ()
550+ ):
551+ return _restore_grain_iterator (
552+ checkpoint_manager , step , data_iterator , checkpoint_args , expansion_factor_real_data
553+ )
371554 # Case 3: Default/Fallback case.
372555 # This case acts as a wildcard ('_') and matches if none of the preceding cases were met.
373556 case _:
@@ -518,15 +701,26 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
518701 save_args = jax .tree .map (lambda _ : ocp .SaveArgs (chunk_byte_size = chunk_byte_size ), state ),
519702 ocdbt_target_data_file_size = chunk_byte_size ,
520703 )
521-
522- match (checkpoint_manager , config ):
523- case (checkpoint_manager , _) if isinstance (
704+ save_args_composite = {"items" : checkpoint_args }
705+ from MaxText .input_pipeline .input_pipeline_interface import PlaceHolderDataIterator # pylint: disable=import-outside-toplevel
706+
707+ if config and config .dataset_type == "grain" and not isinstance (data_iterator , PlaceHolderDataIterator ):
708+ if not isinstance (data_iterator , list ):
709+ data_iterator = [data_iterator ]
710+ grain_iters_to_save = []
711+ process_count_total = jax .process_count () * len (data_iterator )
712+ if config .expansion_factor_real_data > 1 :
713+ process_count_total = process_count_total // config .expansion_factor_real_data
714+ for i , data_iter in enumerate (data_iterator ):
715+ process_index = jax .process_index () + i * jax .process_count ()
716+ grain_iters_to_save .append ((data_iter .local_iterator , process_index , process_count_total ))
717+ save_args_composite ["iter" ] = GrainCheckpointSave (item = grain_iters_to_save )
718+
719+ match (checkpoint_manager , config , data_iterator ):
720+ case (checkpoint_manager , _, _) if isinstance (
524721 checkpoint_manager , (EmergencyCheckpointManager , EmergencyReplicatorCheckpointManager )
525722 ):
526723 replicator_error_handler (config )
527724 return checkpoint_manager .save (step , args = Composite (state = checkpoint_args ), force = force )
528- case (_, config ) if config and config .dataset_type == "grain" :
529- grain_iter = grain .PyGrainCheckpointSave (data_iterator .local_iterator )
530- return checkpoint_manager .save (step , args = Composite (items = checkpoint_args , iter = grain_iter ), force = force )
531725 case _:
532- return checkpoint_manager .save (step , args = Composite (items = checkpoint_args ), force = force )
726+ return checkpoint_manager .save (step , args = Composite (** save_args_composite ), force = force )
0 commit comments