
Commit 86937a7

Merge pull request #2537 from AI-Hypercomputer:aireen/grain_ckpt_scale
PiperOrigin-RevId: 827022073
2 parents a022ce3 + e7128bc commit 86937a7

File tree

12 files changed (+311, -45 lines)

docs/guides/data_input_pipeline/data_input_grain.md

Lines changed: 5 additions & 0 deletions
@@ -70,3 +70,8 @@ eval_interval: 10000
 eval_steps: 50
 grain_eval_files: '/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-validation.array_record*'
 ```
+8. Experimental: resuming training with a different chip count
+In Grain checkpoints, each data-loading host has a corresponding JSON file. For cases where a user wants to resume training with a different number of data-loading hosts, MaxText provides an experimental feature:
+* **Scaling up**: For example, suppose you have a checkpoint from 64 data-loading hosts and want to resume training with 128. This is achieved by having a subset of the hosts load the real data, which is then sent to the other hosts. The flag `expansion_factor_real_data` (default -1) controls this behavior. When set to a value greater than 1, the number of hosts loading real data is `total number of hosts // expansion_factor_real_data`, and each of these data-loading hosts loads `expansion_factor_real_data * per_host_batch_size_to_train`. For code integrity, the non-loading hosts use a `PlaceHolderDataIterator` to generate dummy data, which is later discarded. A user can optionally set `max_checkify=true` to enable additional checks that ensure dummy data is not used for training. In this example, you would set `expansion_factor_real_data=2` to scale from 64 to 128 hosts.
+* **Scaling down**: For example, suppose you have a checkpoint from 128 data-loading hosts and want to resume with 64. This is achieved by restoring multiple data iterators on each host: set `expansion_factor_real_data` so that each host restores `1 / expansion_factor_real_data` data iterators, and MaxText then alternates between these iterators to produce batches. In this example, you would set `expansion_factor_real_data=0.5` to scale from 128 down to 64 hosts.
+* **Note**: In both the scaling-up and scaling-down scenarios, `per_device_batch_size` must stay the same. Grain records the number of iterations (batches) in the iterator's state, so changing the batch size would result in either skipping or duplicating data.
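To make the scaling arithmetic above concrete, here is a minimal editor's sketch (not MaxText code); `grain_scaling_plan`, `total_hosts`, and `per_host_batch_size` are hypothetical names used only for illustration:

```python
# Illustrative sketch of the host/iterator arithmetic described in the doc section above.
def grain_scaling_plan(total_hosts: int, expansion_factor_real_data: float, per_host_batch_size: int):
  """Returns (hosts_loading_real_data, batch_size_per_loading_host, iterators_per_host)."""
  if expansion_factor_real_data > 1:
    # Scaling up: only a subset of hosts read real data; the rest use placeholder iterators.
    loading_hosts = total_hosts // int(expansion_factor_real_data)
    return loading_hosts, int(per_host_batch_size * expansion_factor_real_data), 1
  if 0 < expansion_factor_real_data < 1:
    # Scaling down: every host loads data, but restores several checkpointed iterators.
    return total_hosts, per_host_batch_size, int(1 / expansion_factor_real_data)
  # Default (-1): every host loads real data with one iterator.
  return total_hosts, per_host_batch_size, 1

# Resuming a 64-host checkpoint on 128 hosts -> 64 hosts load 2x-sized batches each.
print(grain_scaling_plan(128, 2, 8))   # (64, 16, 1)
# Resuming a 128-host checkpoint on 64 hosts -> each host restores 2 iterators.
print(grain_scaling_plan(64, 0.5, 8))  # (64, 8, 2)
```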

src/MaxText/checkpointing.py

Lines changed: 218 additions & 25 deletions
@@ -15,24 +15,29 @@
 """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
 
 import time
-from typing import Any
+from typing import Any, Optional
 
 from absl import flags
 from etils import epath
 from flax.training import train_state
-import grain.python as grain
 import jax
 from MaxText import exceptions
 from MaxText import max_logging
 from MaxText.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
 from MaxText.multihost_dataloading import MultiHostDataLoadIterator
+from MaxText.input_pipeline.input_pipeline_interface import PlaceHolderDataIterator
 import numpy as np
 import orbax.checkpoint as ocp
 from orbax.checkpoint import v1 as ocp_v1
 from orbax.checkpoint._src.arrays import sharding as sharding_utils
 import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
 import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
 # pylint: disable=too-many-positional-arguments
+import dataclasses
+import json
+
+import grain
+from grain.python import PyGrainCheckpointHandler
 
 CheckpointManager = ocp.CheckpointManager
 CheckpointManagerOptions = ocp.CheckpointManagerOptions
@@ -44,6 +49,83 @@
 EmergencyReplicatorCheckpointManager = emergency_replicator_checkpoint_manager.ReplicatorCheckpointManager
 
 
+class GrainCheckpointHandler(PyGrainCheckpointHandler, ocp.CheckpointHandler):
+  """A CheckpointHandler that allows specifying process_index and process_count."""
+
+  def save(
+      self,
+      directory: epath.Path,
+      # `item` is for backwards compatibility with older Orbax API, see
+      # https://orbax.readthedocs.io/en/latest/api_refactor.html.
+      item: Optional[Any] = None,
+      args: Any = None,
+  ):
+    """Saves the given iterator to the checkpoint in `directory`."""
+    item = item or args.item  # pytype:disable=attribute-error
+
+    def save_single_process(item, process_index, process_count):
+      filename = directory / f"process_{process_index}-of-{process_count}.json"
+      if isinstance(item, grain.DatasetIterator):
+        state = json.dumps(item.get_state(), indent=4)
+      else:
+        state = item.get_state().decode()
+      filename.write_text(state)
+
+    if isinstance(item, list):
+      for local_iterator, process_index, process_count in item:
+        save_single_process(local_iterator, process_index, process_count)
+    else:
+      process_index, process_count = jax.process_index(), jax.process_count()
+      save_single_process(item, process_index, process_count)
+
+  def restore(
+      self,
+      directory: epath.Path,
+      item: Optional[Any] = None,
+      args: Any = None,
+  ) -> Any:
+    """Restores the given iterator from the checkpoint in `directory`."""
+    item = item or args.item
+    process_index = getattr(args, "process_index", None)
+    process_count = getattr(args, "process_count", None)
+
+    def restore_single_process(item, process_index, process_count):
+      filename = directory / f"process_{process_index}-of-{process_count}.json"
+      if not filename.exists():
+        raise ValueError(f"File {filename} does not exist.")
+      state = filename.read_text()
+      if isinstance(item, grain.DatasetIterator):
+        state = json.loads(state)
+      else:
+        state = state.encode()
+      item.set_state(state)
+      return item
+
+    if isinstance(item, list):
+      restored_items = []
+      for data_iter, process_idx in zip(item, process_index):
+        restored_items.append(restore_single_process(data_iter, process_idx, process_count))
+      return restored_items
+    else:
+      if process_index is None or process_count is None:
+        process_index, process_count = jax.process_index(), jax.process_count()
+      return restore_single_process(item, process_index, process_count)
+
+
+@ocp.args.register_with_handler(GrainCheckpointHandler, for_save=True)
+@dataclasses.dataclass
+class GrainCheckpointSave(ocp.args.CheckpointArgs):
+  item: Any
+
+
+@ocp.args.register_with_handler(GrainCheckpointHandler, for_restore=True)
+@dataclasses.dataclass
+class GrainCheckpointRestore(ocp.args.CheckpointArgs):
+  item: Any
+  process_index: Optional[int | list[int]] = None
+  process_count: Optional[int] = None
+
+
 def _load_full_state_from_path(
     path,
     abstract_unboxed_pre_state,
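As a usage note for the handler added above: each (iterator, process index, process count) tuple ends up in its own `process_{index}-of-{count}.json` file. The snippet below is an editor's sketch of that naming convention only; `grain_state_filenames` is a hypothetical helper, not part of the diff.

```python
# Editor's sketch of the per-process file layout produced by GrainCheckpointHandler.save
# (filenames only; the real handler writes each iterator's state as JSON text into its file).
from pathlib import Path

def grain_state_filenames(directory: Path, process_indices: list[int], process_count: int) -> list[Path]:
  # One JSON file per (iterator, process index) pair, e.g. process_1-of-64.json.
  return [directory / f"process_{idx}-of-{process_count}.json" for idx in process_indices]

# A normal host saves one iterator under its own jax.process_index().
print(grain_state_filenames(Path("ckpt/100/iter"), [3], 64)[0])   # ckpt/100/iter/process_3-of-64.json
# A host that restored two iterators (scale-down) saves both of the files it owns.
print(grain_state_filenames(Path("ckpt/100/iter"), [3, 67], 128))
```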
@@ -111,17 +193,18 @@ def create_orbax_checkpoint_manager(
 
   max_logging.log(f"Creating checkpoint manager with ocdbt={use_ocdbt} and zarr3={use_zarr3}")
 
+  # Base configuration for all dataset types
+  item_names = ("items",)
+  # we need to use ocdbt and zarr3 to control max file size in the checkpoint
+  item_handlers = {"items": PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)}
+
   if dataset_type == "grain":
-    item_names = ("items", "iter")
-  else:
-    item_names = ("items",)
+    item_names += ("iter",)
+    item_handlers["iter"] = GrainCheckpointHandler()
 
   # local storage checkpoint needs parent directory created
   p = epath.Path(checkpoint_dir)
   p.mkdir(exist_ok=True, parents=True)
-  # we need to use ocdbt and zarr3 to control max file size in the checkpoint
-  # omitting `iter` uses default handler for `iter`
-  item_handlers = {"items": PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)}
   manager = CheckpointManager(
       p,
       item_names=item_names,
@@ -273,9 +356,101 @@ def _replica_devices(device_array: np.ndarray, replica_axis_idx: int):
   return np.expand_dims(replica_result, axis=replica_axis_idx)
 
 
+def _prepare_scaled_down_grain_restore_args(
+    data_iterator: list, process_count_jax: int, process_count_stored: int, directory: epath.Path
+) -> GrainCheckpointRestore:
+  """
+  Prepares the restore arguments for a scaled-down run, where each host restores a list of data iterators.
+
+  This is used when restoring a checkpoint saved with more processes than
+  the current run (e.g., 64 files onto 32 JAX processes).
+  """
+  # 1. Validation Assertions
+  assert isinstance(data_iterator, list), (
+      f"{process_count_stored} processes found in Grain checkpoint directory {directory}, but only "
+      f"{process_count_jax} jax processes in this run, please set expansion_factor_real_data accordingly."
+  )
+
+  scaling_factor = len(data_iterator)
+  expected_process_count = process_count_stored / process_count_jax
+  assert scaling_factor == expected_process_count, (
+      f"Found {process_count_stored} processes in checkpoint and {process_count_jax} "
+      f"JAX processes, implying a scaling factor of {expected_process_count}. "
+      f"However, the data_iterator list has {scaling_factor} items."
+  )
+
+  # 2. Prepare Arguments
+  local_iterator_list = [x.local_iterator for x in data_iterator]
+  # Each JAX process calculates the global indices it's responsible for.
+  # e.g., process 0 with scaling_factor=2 handles checkpoints from processes [0, 32]
+  # e.g., process 1 with scaling_factor=2 handles checkpoints from processes [1, 33]
+  process_index_list = [jax.process_index() + i * process_count_jax for i in range(scaling_factor)]
+
+  return GrainCheckpointRestore(local_iterator_list, process_index=process_index_list, process_count=process_count_stored)
+
+
+def _restore_grain_iterator(
+    checkpoint_manager,
+    step: int,
+    data_iterator,
+    checkpoint_args,
+    expansion_factor_real_data: int,  # This must be defined in the outer scope
+) -> tuple[Any, None]:
+  """
+  Handles the complex logic for restoring a Grain data iterator checkpoint.
+  This function dispatches to the correct restore strategy based on
+  the number of stored checkpoint files vs. current JAX processes.
+  """
+  directory = checkpoint_manager.directory / str(step) / "iter"
+  process_count_jax = jax.process_count()
+
+  # Count the number of checkpoint files
+  process_count_stored = len(list(directory.glob("process_*-of-*.json")))
+
+  grain_restore_args = None
+
+  if process_count_stored > process_count_jax:
+    # Scaling down from a larger number of hosts (e.g., 128 files -> 64 processes).
+    # In this case, each host restores a list of data iterators.
+    grain_restore_args = _prepare_scaled_down_grain_restore_args(
+        data_iterator, process_count_jax, process_count_stored, directory
+    )
+
+  elif process_count_stored == process_count_jax:
+    # Normal case: number of hosts is the same (e.g., 64 files -> 64 processes).
+    assert not isinstance(data_iterator, list), (
+        f"{process_count_stored} processes found in Grain checkpoint directory {directory}, matching the number of "
+        "jax processes, please do not set expansion_factor_real_data."
+    )
+    grain_restore_args = GrainCheckpointRestore(data_iterator.local_iterator)
+
+  elif expansion_factor_real_data > 1 and process_count_stored == process_count_jax // expansion_factor_real_data:
+    # Scaling up to a larger number of hosts (e.g., 32 files -> 64 processes).
+    # In this case, a subset of hosts restores the data iterator.
+    assert not isinstance(data_iterator, list), "when expansion_factor_real_data > 1, the data iterator should not be a list."
+    grain_restore_args = GrainCheckpointRestore(
+        data_iterator.local_iterator, process_index=jax.process_index(), process_count=process_count_stored
+    )
+
+  else:
+    # Case 4: Mismatch
+    raise ValueError(
+        f"Error restoring Grain checkpoint in {directory}: "
+        f"The number of stored checkpoint files ({process_count_stored}) "
+        f"is incompatible with the number of JAX processes ({process_count_jax}). "
+        "If you are resuming training with a different number of chips, see instructions in "
+        "https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline/"
+        "data_input_grain.md#using-grain"
+    )
+
+  # Call restore once with the composed arguments
+  restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_restore_args))
+  return (restored_state, None)
+
+
 def load_state_if_possible(
     checkpoint_manager: CheckpointManager | None,
-    data_iterator: MultiHostDataLoadIterator | None,
+    data_iterator: MultiHostDataLoadIterator | list[MultiHostDataLoadIterator] | None,
     load_parameters_from_path: str,
     load_full_state_from_path: str,
     checkpoint_storage_concurrent_gb: int,
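The index mapping computed by `_prepare_scaled_down_grain_restore_args` can be checked with a small stand-alone sketch (editor's illustration with assumed values, not MaxText code):

```python
# Sketch of the mapping used when 128 checkpoint files are restored onto 64 JAX processes.
process_count_stored = 128  # JSON files written by the larger job
process_count_jax = 64      # processes in the current, smaller job
scaling_factor = process_count_stored // process_count_jax  # 2 iterators per host

for jax_process_index in (0, 1, 63):
  # Host i restores the files written by stored processes i, i + 64, ...
  owned = [jax_process_index + i * process_count_jax for i in range(scaling_factor)]
  print(f"process {jax_process_index} restores checkpoint files {owned}")
# process 0 restores checkpoint files [0, 64]
# process 1 restores checkpoint files [1, 65]
# process 63 restores checkpoint files [63, 127]
```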
@@ -288,6 +463,7 @@ def load_state_if_possible(
     enable_orbax_v1=False,
     checkpoint_conversion_fn=None,
     source_checkpoint_layout="orbax",
+    expansion_factor_real_data: int = -1,
 ):
   """Loads TrainState as possible from the inputs.
 
@@ -354,20 +530,27 @@ def map_to_pspec(data):
     # or EmergencyReplicatorCheckpointManager. The '_' indicates that 'dataset_type' and
     # 'data_iterator' can be any value and aren't used in this pattern.
     case (checkpoint_manager, _, _) if isinstance(
-        checkpoint_manager,
-        (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager),
+        checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager)
     ):
       return (
           checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state,
          None,
       )
-    # Case 2: Matches if dataset type is "grain" and a specific checkpoint file exits for the iterator
-    # exists within the checkpoint manager's directory for the given step.
-    case (checkpoint_manager, dataset_type, data_iterator) if dataset_type == "grain" and data_iterator and (
-        checkpoint_manager.directory / str(step) / "iter"
-    ).exists():
-      grain_iter = grain.PyGrainCheckpointRestore(data_iterator.local_iterator)
-      return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_iter)), None)
+    # Case 2: Matches if dataset type is "grain" and the data iterator is not a
+    # PlaceHolderDataIterator and a specific checkpoint file exists for the iterator
+    case (
+        checkpoint_manager,
+        dataset_type,
+        data_iterator,
+    ) if (
+        dataset_type == "grain"
+        and data_iterator
+        and not isinstance(data_iterator, PlaceHolderDataIterator)
+        and (checkpoint_manager.directory / str(step) / "iter").exists()
+    ):
+      return _restore_grain_iterator(
+          checkpoint_manager, step, data_iterator, checkpoint_args, expansion_factor_real_data
+      )
     # Case 3: Default/Fallback case.
     # This case acts as a wildcard ('_') and matches if none of the preceding cases were met.
     case _:
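The guard on Case 2 above can be read as a single predicate; the sketch below is a hypothetical restatement by the editor, not a function in the diff:

```python
# Editor's sketch: the Grain iterator is restored only for grain runs, only on hosts that hold a
# real (non-placeholder) iterator, and only when the step directory has an "iter" sub-checkpoint.
from pathlib import Path

def should_restore_grain_iter(dataset_type: str, has_iterator: bool, is_placeholder: bool, step_dir: Path) -> bool:
  return dataset_type == "grain" and has_iterator and not is_placeholder and (step_dir / "iter").exists()
```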
@@ -518,15 +701,25 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
         save_args=jax.tree.map(lambda _: ocp.SaveArgs(chunk_byte_size=chunk_byte_size), state),
         ocdbt_target_data_file_size=chunk_byte_size,
     )
-
-  match (checkpoint_manager, config):
-    case (checkpoint_manager, _) if isinstance(
+  save_args_composite = {"items": checkpoint_args}
+
+  if config and config.dataset_type == "grain" and not isinstance(data_iterator, PlaceHolderDataIterator):
+    if not isinstance(data_iterator, list):
+      data_iterator = [data_iterator]
+    grain_iters_to_save = []
+    process_count_total = jax.process_count() * len(data_iterator)
+    if config.expansion_factor_real_data > 1:
+      process_count_total = process_count_total // config.expansion_factor_real_data
+    for i, data_iter in enumerate(data_iterator):
+      process_index = jax.process_index() + i * jax.process_count()
+      grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total))
+    save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save)
+
+  match (checkpoint_manager, config, data_iterator):
+    case (checkpoint_manager, _, _) if isinstance(
         checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager)
     ):
       replicator_error_handler(config)
       return checkpoint_manager.save(step, args=Composite(state=checkpoint_args), force=force)
-    case (_, config) if config and config.dataset_type == "grain":
-      grain_iter = grain.PyGrainCheckpointSave(data_iterator.local_iterator)
-      return checkpoint_manager.save(step, args=Composite(items=checkpoint_args, iter=grain_iter), force=force)
     case _:
-      return checkpoint_manager.save(step, args=Composite(items=checkpoint_args), force=force)
+      return checkpoint_manager.save(step, args=Composite(**save_args_composite), force=force)
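For the save path above, a small editor's sketch (with stand-in values for `jax.process_index()` and `jax.process_count()`) shows which files a host writes after a scale-down restore:

```python
# Illustrative values only: every host saves one tuple per local iterator, tagged with a
# global process index, so the directory again contains one JSON file per stored iterator.
jax_process_count = 64    # stand-in for jax.process_count()
jax_process_index = 3     # stand-in for jax.process_index()
num_local_iterators = 2   # e.g., after a 128 -> 64 host scale-down restore
expansion_factor_real_data = -1.0

process_count_total = jax_process_count * num_local_iterators  # 128 files overall
if expansion_factor_real_data > 1:
  process_count_total //= int(expansion_factor_real_data)      # fewer real-data hosts when scaling up

for i in range(num_local_iterators):
  process_index = jax_process_index + i * jax_process_count
  print(f"iterator {i} -> process_{process_index}-of-{process_count_total}.json")
# iterator 0 -> process_3-of-128.json
# iterator 1 -> process_67-of-128.json
```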

src/MaxText/configs/base.yml

Lines changed: 5 additions & 1 deletion
@@ -497,7 +497,11 @@ add_eos: True
 
 # Dataset
 per_device_batch_size: 12.0
-expansion_factor_real_data: -1 # if -1 then all hosts will load real data, else total_hosts//expansion_factor_real_data will pull data from GCS.
+# When expansion_factor_real_data is set to > 1, total_hosts//expansion_factor_real_data hosts will load data.
+# Each data-loading host will load per_device_batch_size * expansion_factor_real_data.
+# When set to a value between 0 and 1, the Grain pipeline uses a smaller chip count to read a checkpoint written with a larger chip count.
+# Details in https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_grain.md#using-grain
+expansion_factor_real_data: -1.0
 eval_per_device_batch_size: 0.0
 max_corpus_chars: 10_000_000
 train_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected"
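A quick way to sanity-check the flag against the old and new host counts is sketched below (editor's illustration; `check_expansion_factor` is a hypothetical helper, not part of MaxText):

```python
# Validate a proposed expansion_factor_real_data before launching the resumed job.
def check_expansion_factor(hosts_in_checkpoint: int, hosts_now: int, factor: float) -> None:
  if factor > 1:        # scaling up: fewer hosts load real data
    assert hosts_now // factor == hosts_in_checkpoint, "hosts_now // factor must equal the checkpointed host count"
  elif 0 < factor < 1:  # scaling down: each host restores several iterators
    assert hosts_now / factor == hosts_in_checkpoint, "hosts_now / factor must equal the checkpointed host count"
  else:                 # default -1: host count must be unchanged
    assert hosts_now == hosts_in_checkpoint, "set expansion_factor_real_data when the host count changes"

check_expansion_factor(64, 128, 2)    # scale up:   128 // 2  == 64
check_expansion_factor(128, 64, 0.5)  # scale down: 64 / 0.5 == 128
```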

src/MaxText/data_loader.py

Lines changed: 13 additions & 1 deletion
@@ -35,10 +35,21 @@ class DataLoader:
   def __init__(self, config, mesh, data_iterator, goodput_recorder):
     self.config = config
     self.goodput_recorder = goodput_recorder
-    self.data_iterator = data_iterator
+    if isinstance(data_iterator, list):
+      self.data_iterator_list = data_iterator
+      self.data_iterator_index = 0
+      self.data_iterator = self.data_iterator_list[self.data_iterator_index]
+    else:
+      self.data_iterator = data_iterator
     self.last_batch = None
     self.input_data_shardings = sharding.get_input_data_sharding(config, mesh)
 
+  def update_data_iterator(self):
+    """Update to the next data iterator in the list, if applicable."""
+    if hasattr(self, "data_iterator_list"):
+      self.data_iterator_index = (self.data_iterator_index + 1) % len(self.data_iterator_list)
+      self.data_iterator = self.data_iterator_list[self.data_iterator_index]
+
   def load_next_batch(self):
     """Loads the next batch. Can keep reusing the same batch for performance reasons."""
     with maybe_record_goodput(self.goodput_recorder, GoodputEvent.DATA_LOADING):
@@ -47,6 +58,7 @@ def load_next_batch(self):
         example_batch = self.last_batch
       else:
         example_batch = next(self.data_iterator)
+        self.update_data_iterator()
       # Reshard data from loaded sharding to performant activation sharding
       self.last_batch = sharding.maybe_shard_with_name(
           example_batch,
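The alternation added to `DataLoader` behaves like a simple round-robin over the restored iterators; the class below is an editor's stand-in written in plain Python, not the MaxText implementation:

```python
# With two restored iterators, consecutive batches alternate between them.
class RoundRobinLoader:
  def __init__(self, iterators):
    self.iterators = iterators
    self.index = 0

  def load_next_batch(self):
    batch = next(self.iterators[self.index])              # draw from the current iterator
    self.index = (self.index + 1) % len(self.iterators)   # then advance, as update_data_iterator does
    return batch

loader = RoundRobinLoader([iter(["a0", "a1"]), iter(["b0", "b1"])])
print([loader.load_next_batch() for _ in range(4)])  # ['a0', 'b0', 'a1', 'b1']
```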

src/MaxText/experimental/rl/grpo_trainer.py

Lines changed: 7 additions & 2 deletions
@@ -878,8 +878,13 @@ def generation_worker_fn(
       max_utils.print_mem_stats("After params initialized")
 
       metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta)
-    state_to_save = _split_grpo_state(state)[0]
-    checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator)
+
+    if config.save_checkpoint_on_completion:
+      state_to_save = _split_grpo_state(state)[0]
+      checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator)
+    elif checkpoint_manager is not None:
+      # in case the last checkpoint_period checkpoint is still in progress
+      checkpoint_manager.wait_until_finished()
   except exceptions.StopTraining as e:
     max_logging.log(f"Training stopped: {str(e)}")
   finally:
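The new end-of-training branch follows a simple policy; the sketch below is a hypothetical restatement by the editor (the real code lives in `generation_worker_fn` above):

```python
# Either write one final checkpoint, or just drain any in-flight async save so the
# job can exit without truncating the last periodic checkpoint.
def finalize_checkpointing(save_on_completion: bool, checkpoint_manager, save_final_checkpoint) -> None:
  if save_on_completion:
    save_final_checkpoint()
  elif checkpoint_manager is not None:
    checkpoint_manager.wait_until_finished()
```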
