5 changes: 5 additions & 0 deletions docs/guides/data_input_pipeline/data_input_grain.md
@@ -70,3 +70,8 @@ eval_interval: 10000
eval_steps: 50
grain_eval_files: '/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-validation.array_record*'
```
8. Experimental: resuming training with a different chip count
In a Grain checkpoint, each data-loading host writes a corresponding JSON file. To resume training with a different number of data-loading hosts, MaxText provides an experimental feature (see the example configuration after this list):
* **Scaling up**: For example, you have a checkpoint written by 64 data-loading hosts and want to resume training on 128. This is achieved by having a subset of the hosts load real data, which is then sent to the remaining hosts. The flag `expansion_factor_real_data` (default -1) controls this behavior. When set to a value greater than 1, the number of hosts loading real data is `total number of hosts // expansion_factor_real_data`, and each of these hosts loads `expansion_factor_real_data * per_host_batch_size_to_train`. To keep the training loop uniform, the non-loading hosts use a `PlaceHolderDataIterator` that generates dummy data, which is later discarded. Optionally, set `max_checkify=true` to enable additional checks that ensure dummy data is never used for training. In this example, you would set `expansion_factor_real_data=2` to scale from 64 to 128 hosts.
* **Scaling down**: For example, you have a checkpoint written by 128 data-loading hosts and want to resume on 64. This is achieved by restoring multiple data iterators on each host: with `expansion_factor_real_data` set to a value between 0 and 1, each host restores `1 / expansion_factor_real_data` data iterators, and training alternates between them to produce batches. In this example, you would set `expansion_factor_real_data=0.5` to scale from 128 down to 64 hosts.
* **Note**: In both scenarios, `per_device_batch_size` must stay the same as in the original run. Grain records the number of iterations (batches) in the iterator's state, so changing the batch size would skip or duplicate data.
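A minimal illustrative sketch of the relevant config values for both cases (the `per_device_batch_size` value below is arbitrary; it only needs to match the original run):

```
# Scale up: resume a checkpoint written by 64 data-loading hosts on 128 hosts
expansion_factor_real_data: 2
per_device_batch_size: 4.0  # must match the original run

# Scale down: resume a checkpoint written by 128 data-loading hosts on 64 hosts
expansion_factor_real_data: 0.5
per_device_batch_size: 4.0  # must match the original run
```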
240 changes: 216 additions & 24 deletions src/MaxText/checkpointing.py
@@ -15,24 +15,29 @@
"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""

import time
from typing import Any
from typing import Any, Optional

from absl import flags
from etils import epath
from flax.training import train_state
import grain.python as grain
import jax
from MaxText import exceptions
from MaxText import max_logging
from MaxText.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
from MaxText.multihost_dataloading import MultiHostDataLoadIterator
from MaxText.input_pipeline.input_pipeline_interface import PlaceHolderDataIterator
import numpy as np
import orbax.checkpoint as ocp
from orbax.checkpoint import v1 as ocp_v1
from orbax.checkpoint._src.arrays import sharding as sharding_utils
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
# pylint: disable=too-many-positional-arguments
import dataclasses
import json

import grain
from grain.python import PyGrainCheckpointHandler

CheckpointManager = ocp.CheckpointManager
CheckpointManagerOptions = ocp.CheckpointManagerOptions
@@ -44,6 +49,83 @@
EmergencyReplicatorCheckpointManager = emergency_replicator_checkpoint_manager.ReplicatorCheckpointManager


class GrainCheckpointHandler(PyGrainCheckpointHandler):
"""A CheckpointHandler that allows specifying process_index and process_count."""

def save(
self,
directory: epath.Path,
# `item` is for backwards compatibility with older Orbax API, see
# https://orbax.readthedocs.io/en/latest/api_refactor.html.
item: Optional[Any] = None,
args: Any = None,
):
"""Saves the given iterator to the checkpoint in `directory`."""
item = item or args.item # pytype:disable=attribute-error

def save_single_process(item, process_index, process_count):
filename = directory / f"process_{process_index}-of-{process_count}.json"
if isinstance(item, grain.DatasetIterator):
state = json.dumps(item.get_state(), indent=4)
else:
state = item.get_state().decode()
filename.write_text(state)

if isinstance(item, list):
for local_iterator, process_index, process_count in item:
save_single_process(local_iterator, process_index, process_count)
else:
process_index, process_count = jax.process_index(), jax.process_count()
save_single_process(item, process_index, process_count)

def restore(
self,
directory: epath.Path,
item: Optional[Any] = None,
args: Any = None,
) -> Any:
"""Restores the given iterator from the checkpoint in `directory`."""
item = item or args.item
process_index = getattr(args, "process_index", None)
process_count = getattr(args, "process_count", None)

def restore_single_process(item, process_index, process_count):
filename = directory / f"process_{process_index}-of-{process_count}.json"
if not filename.exists():
raise ValueError(f"File {filename} does not exist.")
state = filename.read_text()
if isinstance(item, grain.DatasetIterator):
state = json.loads(state)
else:
state = state.encode()
item.set_state(state)
return item

if isinstance(item, list):
restored_items = []
for data_iter, process_idx in zip(item, process_index):
restored_items.append(restore_single_process(data_iter, process_idx, process_count))
return restored_items
else:
if process_index is None or process_count is None:
process_index, process_count = jax.process_index(), jax.process_count()
return restore_single_process(item, process_index, process_count)


@ocp.args.register_with_handler(GrainCheckpointHandler, for_save=True)
@dataclasses.dataclass
class GrainCheckpointSave(ocp.args.CheckpointArgs):
item: Any


@ocp.args.register_with_handler(GrainCheckpointHandler, for_restore=True)
@dataclasses.dataclass
class GrainCheckpointRestore(ocp.args.CheckpointArgs):
item: Any
process_index: Optional[int | list[int]] = None
process_count: Optional[int] = None


def _load_full_state_from_path(
path,
abstract_unboxed_pre_state,
@@ -111,17 +193,18 @@ def create_orbax_checkpoint_manager(

max_logging.log(f"Creating checkpoint manager with ocdbt={use_ocdbt} and zarr3={use_zarr3}")

# Base configuration for all dataset types
item_names = ("items",)
# we need to use ocdbt and zarr3 to control max file size in the checkpoint
item_handlers = {"items": PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)}

if dataset_type == "grain":
item_names = ("items", "iter")
else:
item_names = ("items",)
item_names += ("iter",)
item_handlers["iter"] = GrainCheckpointHandler()

# local storage checkpoint needs parent directory created
p = epath.Path(checkpoint_dir)
p.mkdir(exist_ok=True, parents=True)
# we need to use ocdbt and zarr3 to control max file size in the checkpoint
# omitting `iter` uses default handler for `iter`
item_handlers = {"items": PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)}
manager = CheckpointManager(
p,
item_names=item_names,
@@ -273,6 +356,97 @@ def _replica_devices(device_array: np.ndarray, replica_axis_idx: int):
return np.expand_dims(replica_result, axis=replica_axis_idx)


def _prepare_scaled_down_grain_restore_args(
data_iterator: list, process_count_jax: int, process_count_stored: int, directory: epath.Path
) -> GrainCheckpointRestore:
"""
Prepares the restore arguments for a scaled-down restore, where the data iterator is a list of per-process iterators.

This is used when restoring a checkpoint saved with more processes than
the current run (e.g., 64 files onto 32 JAX processes).
"""
# 1. Validation Assertions
assert isinstance(data_iterator, list), (
f"{process_count_stored} processes found in Grain checkpoint directory {directory}, but only "
f"{process_count_jax} jax processes in this run, please set expansion_factor_real_data accordingly."
)

scaling_factor = len(data_iterator)
expected_scaling_factor = process_count_stored / process_count_jax
assert scaling_factor == expected_scaling_factor, (
f"Found {process_count_stored} processes in checkpoint and {process_count_jax} "
f"JAX processes, implying a scaling factor of {expected_scaling_factor}. "
f"However, the data_iterator list has {scaling_factor} items."
)

# 2. Prepare Arguments
local_iterator_list = [x.local_iterator for x in data_iterator]
# Each JAX process calculates the global indices it's responsible for.
# e.g., process 0 with scaling_factor=2 handles checkpoints from processes [0, 32]
# e.g., process 1 with scaling_factor=2 handles checkpoints from processes [1, 33]
process_index_list = [jax.process_index() + i * process_count_jax for i in range(scaling_factor)]

return GrainCheckpointRestore(local_iterator_list, process_index=process_index_list, process_count=process_count_stored)


def _restore_grain_iterator(
checkpoint_manager,
step: int,
data_iterator,
checkpoint_args,
expansion_factor_real_data: int, # from config.expansion_factor_real_data
) -> tuple[Any, None]:
"""
Handles the complex logic for restoring a Grain data iterator checkpoint.
This function dispatches to the correct restore strategy based on
the number of stored checkpoint files vs. current JAX processes.
"""
directory = checkpoint_manager.directory / str(step) / "iter"
process_count_jax = jax.process_count()

# Count the number of checkpoint files
process_count_stored = len(list(directory.glob("process_*-of-*.json")))

grain_restore_args = None

if process_count_stored > process_count_jax:
# Scaling down from a larger number of hosts (e.g., 128 files -> 64 processes).
# In this case, each host restores a list of data iterators.
grain_restore_args = _prepare_scaled_down_grain_restore_args(
data_iterator, process_count_jax, process_count_stored, directory
)

elif process_count_stored == process_count_jax:
# Normal case: number of hosts is the same. (e.g., 64 files -> 64 processes)
assert not isinstance(data_iterator, list), (
f"{process_count_stored} processes found in Grain checkpoint directory {directory}, matching the number of "
"jax process, please do not set expansion_factor_real_data."
)
grain_restore_args = GrainCheckpointRestore(data_iterator.local_iterator)

elif expansion_factor_real_data > 0 and process_count_stored == process_count_jax // expansion_factor_real_data:
# Scaling up to a larger number of hosts (e.g., 32 files -> 64 processes).
# In this case, only a subset of the hosts restores the data iterator.
grain_restore_args = GrainCheckpointRestore(
data_iterator.local_iterator, process_index=jax.process_index(), process_count=process_count_stored
)

else:
# Mismatch: the stored checkpoint count is incompatible with the current process count.
raise ValueError(
f"Error restoring Grain checkpoint in {directory}: "
f"The number of stored checkpoint files ({process_count_stored}) "
f"is incompatible with the number of JAX processes ({process_count_jax}). "
"If you are resuming training with a different number of chips, see instructions in "
"https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline/"
"data_input_grain.md#using-grain"
)

# Call restore once with the composed arguments
restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_restore_args))
return (restored_state, None)


def load_state_if_possible(
checkpoint_manager: CheckpointManager | None,
data_iterator: MultiHostDataLoadIterator | None,
@@ -288,6 +462,7 @@ def load_state_if_possible(
enable_orbax_v1=False,
checkpoint_conversion_fn=None,
source_checkpoint_layout="orbax",
expansion_factor_real_data: int = -1,
):
"""Loads TrainState as possible from the inputs.

@@ -354,20 +529,27 @@ def map_to_pspec(data):
# or EmergencyReplicatorCheckpointManager. The '_' indicates that 'dataset_type' and
# 'data_iterator' can be any value and aren't used in this pattern.
case (checkpoint_manager, _, _) if isinstance(
checkpoint_manager,
(EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager),
checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager)
):
return (
checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state,
None,
)
# Case 2: Matches if dataset type is "grain" and a specific checkpoint file exits for the iterator
# exists within the checkpoint manager's directory for the given step.
case (checkpoint_manager, dataset_type, data_iterator) if dataset_type == "grain" and data_iterator and (
checkpoint_manager.directory / str(step) / "iter"
).exists():
grain_iter = grain.PyGrainCheckpointRestore(data_iterator.local_iterator)
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_iter)), None)
# Case 2: Matches if dataset type is "grain" and the data iterator is not a
# PlaceHolderDataIterator and a specific checkpoint file exists for the iterator
case (
checkpoint_manager,
dataset_type,
data_iterator,
) if (
dataset_type == "grain"
and data_iterator
and not isinstance(data_iterator, PlaceHolderDataIterator)
and (checkpoint_manager.directory / str(step) / "iter").exists()
):
return _restore_grain_iterator(
checkpoint_manager, step, data_iterator, checkpoint_args, expansion_factor_real_data
)
# Case 3: Default/Fallback case.
# This case acts as a wildcard ('_') and matches if none of the preceding cases were met.
case _:
@@ -518,15 +700,25 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
save_args=jax.tree.map(lambda _: ocp.SaveArgs(chunk_byte_size=chunk_byte_size), state),
ocdbt_target_data_file_size=chunk_byte_size,
)

match (checkpoint_manager, config):
case (checkpoint_manager, _) if isinstance(
save_args_composite = {"items": checkpoint_args}

if config and config.dataset_type == "grain" and not isinstance(data_iterator, PlaceHolderDataIterator):
if not isinstance(data_iterator, list):
data_iterator = [data_iterator]
grain_iters_to_save = []
process_count_total = jax.process_count() * len(data_iterator)
if config.expansion_factor_real_data > 1:
process_count_total = process_count_total // config.expansion_factor_real_data
for i, data_iter in enumerate(data_iterator):
process_index = jax.process_index() + i * jax.process_count()
grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total))
save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save)

match (checkpoint_manager, config, data_iterator):
case (checkpoint_manager, _, _) if isinstance(
checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager)
):
replicator_error_handler(config)
return checkpoint_manager.save(step, args=Composite(state=checkpoint_args), force=force)
case (_, config) if config and config.dataset_type == "grain":
grain_iter = grain.PyGrainCheckpointSave(data_iterator.local_iterator)
return checkpoint_manager.save(step, args=Composite(items=checkpoint_args, iter=grain_iter), force=force)
case _:
return checkpoint_manager.save(step, args=Composite(items=checkpoint_args), force=force)
return checkpoint_manager.save(step, args=Composite(**save_args_composite), force=force)
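The save and restore paths above share one piece of index bookkeeping: when a host owns several Grain iterators, iterator `i` is mapped to stored process index `jax.process_index() + i * jax.process_count()`. A minimal standalone sketch of that mapping (plain Python, no JAX; the function name is illustrative, not part of MaxText):

```python
def stored_indices_for_host(jax_process_index: int, jax_process_count: int, scaling_factor: int) -> list[int]:
  """Which process_*-of-*.json files a host reads when restoring onto fewer hosts."""
  return [jax_process_index + i * jax_process_count for i in range(scaling_factor)]


if __name__ == "__main__":
  # Example: 128 stored files restored onto 64 JAX processes, i.e. scaling_factor = 2.
  assert stored_indices_for_host(0, 64, 2) == [0, 64]
  assert stored_indices_for_host(1, 64, 2) == [1, 65]
  # Across all 64 hosts, every one of the 128 stored files is claimed exactly once.
  claimed = sorted(i for p in range(64) for i in stored_indices_for_host(p, 64, 2))
  assert claimed == list(range(128))
```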
6 changes: 5 additions & 1 deletion src/MaxText/configs/base.yml
@@ -496,7 +496,11 @@ add_eos: True

# Dataset
per_device_batch_size: 12.0
expansion_factor_real_data: -1 # if -1 then all hosts will load real data, else total_hosts//expansion_factor_real_data will pull data from GCS.
# When expansion_factor_real_data > 1, only total_hosts // expansion_factor_real_data hosts will load real data.
# Each data-loading host will load per_device_batch_size * expansion_factor_real_data.
# When set to a value between 0 and 1, the Grain pipeline uses a smaller chip count to read a checkpoint written by a job with a larger chip count.
# Details in https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline/data_input_grain.md#using-grain
expansion_factor_real_data: -1.0
eval_per_device_batch_size: 0.0
max_corpus_chars: 10_000_000
train_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected"
14 changes: 13 additions & 1 deletion src/MaxText/data_loader.py
@@ -35,10 +35,21 @@ class DataLoader:
def __init__(self, config, mesh, data_iterator, goodput_recorder):
self.config = config
self.goodput_recorder = goodput_recorder
self.data_iterator = data_iterator
if isinstance(data_iterator, list):
self.data_iterator_list = data_iterator
self.data_iterator_index = 0
self.data_iterator = self.data_iterator_list[self.data_iterator_index]
else:
self.data_iterator = data_iterator
self.last_batch = None
self.input_data_shardings = sharding.get_input_data_sharding(config, mesh)

def update_data_iterator(self):
"""Update to the next data iterator in the list, if applicable."""
if hasattr(self, "data_iterator_list"):
self.data_iterator_index = (self.data_iterator_index + 1) % len(self.data_iterator_list)
self.data_iterator = self.data_iterator_list[self.data_iterator_index]

def load_next_batch(self):
"""Loads the next batch. Can keep reusing the same batch for performance reasons."""
with maybe_record_goodput(self.goodput_recorder, GoodputEvent.DATA_LOADING):
@@ -47,6 +58,7 @@ def load_next_batch(self):
example_batch = self.last_batch
else:
example_batch = next(self.data_iterator)
self.update_data_iterator()
# Reshard data from loaded sharding to performant activation sharding
self.last_batch = sharding.maybe_shard_with_name(
example_batch,
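For intuition about the alternation that `update_data_iterator` introduces when `DataLoader` receives a list of restored iterators, here is a minimal standalone sketch (plain Python; the class name is illustrative, not part of MaxText):

```python
class RoundRobinLoader:
  """Toy stand-in for DataLoader's list handling: draw batches from the iterators in turn."""

  def __init__(self, iterators):
    self._iterators = list(iterators)
    self._index = 0

  def next_batch(self):
    batch = next(self._iterators[self._index])
    # Advance to the next iterator, wrapping around, mirroring update_data_iterator.
    self._index = (self._index + 1) % len(self._iterators)
    return batch


if __name__ == "__main__":
  loader = RoundRobinLoader([iter([0, 2, 4]), iter([1, 3, 5])])
  assert [loader.next_batch() for _ in range(6)] == [0, 1, 2, 3, 4, 5]
```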
9 changes: 7 additions & 2 deletions src/MaxText/experimental/rl/grpo_trainer.py
@@ -878,8 +878,13 @@ def generation_worker_fn(
max_utils.print_mem_stats("After params initialized")

metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta)
state_to_save = _split_grpo_state(state)[0]
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator)

if config.save_checkpoint_on_completion:
state_to_save = _split_grpo_state(state)[0]
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator)
elif checkpoint_manager is not None:
# in case the most recent periodic (checkpoint_period) checkpoint save is still in progress
checkpoint_manager.wait_until_finished()
except exceptions.StopTraining as e:
max_logging.log(f"Training stopped: {str(e)}")
finally: