diff --git a/docs/guides/data_input_pipeline/data_input_grain.md b/docs/guides/data_input_pipeline/data_input_grain.md index d411174c8..578d9e61f 100644 --- a/docs/guides/data_input_pipeline/data_input_grain.md +++ b/docs/guides/data_input_pipeline/data_input_grain.md @@ -70,3 +70,8 @@ eval_interval: 10000 eval_steps: 50 grain_eval_files: '/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-validation.array_record*' ``` +8. Experimental: resuming training with a different chip count +In Grain checkpoints, each data-loading host has a corresponding JSON file. If you want to resume training with a different number of data-loading hosts, MaxText provides an experimental feature (see the sketch after this section): +* **Scaling up**: For example, you have a checkpoint from 64 data-loading hosts and want to resume training with 128. A subset of the hosts loads the real data, which is then sent to the remaining hosts. The flag `expansion_factor_real_data` (default -1) controls this behavior: when set to a value greater than 1, `total number of hosts // expansion_factor_real_data` hosts load real data, and each of these hosts loads a batch of `expansion_factor_real_data * per_host_batch_size_to_train`. For code integrity, the non-loading hosts use a `PlaceHolderDataIterator` to generate dummy data, which is later discarded; you can optionally set `max_checkify=true` to enable additional checks that ensure dummy data is never used for training. In this example, set `expansion_factor_real_data=2` to scale from 64 to 128 hosts. +* **Scaling down**: For example, you have a checkpoint from 128 data-loading hosts and want to resume with 64. This is achieved by restoring multiple data iterators on each host: set `expansion_factor_real_data` to a value between 0 and 1, and each host restores `1 / expansion_factor_real_data` data iterators, alternating between them to produce batches. In this example, set `expansion_factor_real_data=0.5` to scale from 128 down to 64 hosts. +* **Note**: In both scaling-up and scaling-down scenarios, `per_device_batch_size` must stay the same as in the original run. Grain records the number of iterations (batches) in the iterator's state, so changing the batch size would skip or duplicate data.
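To make the host-to-checkpoint-file mapping concrete, here is a minimal, self-contained sketch. It is not MaxText code: `files_restored_per_host` is a hypothetical helper, and which hosts are actually selected to load real data is decided by MaxText's process-selection logic; the indices below only illustrate the `process_<index>-of-<count>.json` naming scheme used by the Grain checkpoint handler in this PR.

```python
def files_restored_per_host(host_index: int, num_hosts: int, expansion_factor: float) -> list[str]:
  """Which Grain checkpoint JSON files a given host restores (illustrative only)."""
  if expansion_factor > 1:
    # Scaling up: the checkpoint holds num_hosts // expansion_factor files; a subset of hosts restores them.
    stored = int(num_hosts // expansion_factor)
    return [f"process_{host_index}-of-{stored}.json"] if host_index < stored else []
  if 0 < expansion_factor < 1:
    # Scaling down: each host restores 1 / expansion_factor iterators from the larger original run.
    per_host = int(1 / expansion_factor)
    stored = num_hosts * per_host
    return [f"process_{host_index + i * num_hosts}-of-{stored}.json" for i in range(per_host)]
  # Default (-1): one file per host, same host count as the original run.
  return [f"process_{host_index}-of-{num_hosts}.json"]


# Scaling up 64 -> 128 hosts (expansion_factor_real_data=2): only 64 hosts read real files.
assert files_restored_per_host(0, 128, 2) == ["process_0-of-64.json"]
assert files_restored_per_host(100, 128, 2) == []
# Scaling down 128 -> 64 hosts (expansion_factor_real_data=0.5): host 1 restores files 1 and 65.
assert files_restored_per_host(1, 64, 0.5) == ["process_1-of-128.json", "process_65-of-128.json"]
```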
diff --git a/src/MaxText/checkpointing.py b/src/MaxText/checkpointing.py index 4de5d309d..968db7d6a 100644 --- a/src/MaxText/checkpointing.py +++ b/src/MaxText/checkpointing.py @@ -15,17 +15,17 @@ """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer.""" import time -from typing import Any +from typing import Any, Optional from absl import flags from etils import epath from flax.training import train_state -import grain.python as grain import jax from MaxText import exceptions from MaxText import max_logging from MaxText.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE from MaxText.multihost_dataloading import MultiHostDataLoadIterator +from MaxText.input_pipeline.input_pipeline_interface import PlaceHolderDataIterator import numpy as np import orbax.checkpoint as ocp from orbax.checkpoint import v1 as ocp_v1 @@ -33,6 +33,11 @@ import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager # pylint: disable=too-many-positional-arguments +import dataclasses +import json + +import grain +from grain.python import PyGrainCheckpointHandler CheckpointManager = ocp.CheckpointManager CheckpointManagerOptions = ocp.CheckpointManagerOptions @@ -44,6 +49,83 @@ EmergencyReplicatorCheckpointManager = emergency_replicator_checkpoint_manager.ReplicatorCheckpointManager +class GrainCheckpointHandler(PyGrainCheckpointHandler): + """A CheckpointHandler that allows specifying process_index and process_count.""" + + def save( + self, + directory: epath.Path, + # `item` is for backwards compatibility with older Orbax API, see + # https://orbax.readthedocs.io/en/latest/api_refactor.html. + item: Optional[Any] = None, + args: Any = None, + ): + """Saves the given iterator to the checkpoint in `directory`.""" + item = item or args.item # pytype:disable=attribute-error + + def save_single_process(item, process_index, process_count): + filename = directory / f"process_{process_index}-of-{process_count}.json" + if isinstance(item, grain.DatasetIterator): + state = json.dumps(item.get_state(), indent=4) + else: + state = item.get_state().decode() + filename.write_text(state) + + if isinstance(item, list): + for local_iterator, process_index, process_count in item: + save_single_process(local_iterator, process_index, process_count) + else: + process_index, process_count = jax.process_index(), jax.process_count() + save_single_process(item, process_index, process_count) + + def restore( + self, + directory: epath.Path, + item: Optional[Any] = None, + args: Any = None, + ) -> Any: + """Restores the given iterator from the checkpoint in `directory`.""" + item = item or args.item + process_index = getattr(args, "process_index", None) + process_count = getattr(args, "process_count", None) + + def restore_single_process(item, process_index, process_count): + filename = directory / f"process_{process_index}-of-{process_count}.json" + if not filename.exists(): + raise ValueError(f"File {filename} does not exist.") + state = filename.read_text() + if isinstance(item, grain.DatasetIterator): + state = json.loads(state) + else: + state = state.encode() + item.set_state(state) + return item + + if isinstance(item, list): + restored_items = [] + for data_iter, process_idx in zip(item, process_index): + restored_items.append(restore_single_process(data_iter, process_idx, process_count)) + return restored_items + else: + if process_index is None or 
process_count is None: + process_index, process_count = jax.process_index(), jax.process_count() + return restore_single_process(item, process_index, process_count) + + +@ocp.args.register_with_handler(GrainCheckpointHandler, for_save=True) +@dataclasses.dataclass +class GrainCheckpointSave(ocp.args.CheckpointArgs): + item: Any + + +@ocp.args.register_with_handler(GrainCheckpointHandler, for_restore=True) +@dataclasses.dataclass +class GrainCheckpointRestore(ocp.args.CheckpointArgs): + item: Any + process_index: Optional[int | list[int]] = None + process_count: Optional[int] = None + + def _load_full_state_from_path( path, abstract_unboxed_pre_state, @@ -111,17 +193,18 @@ def create_orbax_checkpoint_manager( max_logging.log(f"Creating checkpoint manager with ocdbt={use_ocdbt} and zarr3={use_zarr3}") + # Base configuration for all dataset types + item_names = ("items",) + # we need to use ocdbt and zarr3 to control max file size in the checkpoint + item_handlers = {"items": PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)} + if dataset_type == "grain": - item_names = ("items", "iter") - else: - item_names = ("items",) + item_names += ("iter",) + item_handlers["iter"] = GrainCheckpointHandler() # local storage checkpoint needs parent directory created p = epath.Path(checkpoint_dir) p.mkdir(exist_ok=True, parents=True) - # we need to use ocdbt and zarr3 to control max file size in the checkpoint - # omitting `iter` uses default handler for `iter` - item_handlers = {"items": PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)} manager = CheckpointManager( p, item_names=item_names, @@ -273,6 +356,97 @@ def _replica_devices(device_array: np.ndarray, replica_axis_idx: int): return np.expand_dims(replica_result, axis=replica_axis_idx) +def _prepare_scaled_down_grain_restore_args( + data_iterator: list, process_count_jax: int, process_count_stored: int, directory: epath.Path +) -> GrainCheckpointRestore: + """ + Prepares the restore arguments for the scaled-down case, where each host restores a list of data iterators. + + This is used when restoring a checkpoint saved with more processes than + the current run (e.g., 64 files onto 32 JAX processes). + """ + # 1. Validation Assertions + assert isinstance(data_iterator, list), ( + f"{process_count_stored} processes found in Grain checkpoint directory {directory}, but only " + f"{process_count_jax} JAX processes in this run, please set expansion_factor_real_data accordingly." + ) + + scaling_factor = len(data_iterator) + expected_scaling_factor = process_count_stored / process_count_jax + assert scaling_factor == expected_scaling_factor, ( + f"Found {process_count_stored} processes in checkpoint and {process_count_jax} " + f"JAX processes, implying a scaling factor of {expected_scaling_factor}. " + f"However, the data_iterator list has {scaling_factor} items." + ) + + # 2. Prepare Arguments + local_iterator_list = [x.local_iterator for x in data_iterator] + # Each JAX process calculates the global indices it's responsible for.
+ # e.g., process 0 with scaling_factor=2 handles checkpoints from processes [0, 32] + # e.g., process 1 with scaling_factor=2 handles checkpoints from processes [1, 33] + process_index_list = [jax.process_index() + i * process_count_jax for i in range(scaling_factor)] + + return GrainCheckpointRestore(local_iterator_list, process_index=process_index_list, process_count=process_count_stored) + + +def _restore_grain_iterator( + checkpoint_manager, + step: int, + data_iterator, + checkpoint_args, + expansion_factor_real_data: int, # value of config.expansion_factor_real_data +) -> tuple[Any, None]: + """ + Handles the logic for restoring a Grain data iterator checkpoint. + This function dispatches to the correct restore strategy based on + the number of stored checkpoint files vs. current JAX processes. + """ + directory = checkpoint_manager.directory / str(step) / "iter" + process_count_jax = jax.process_count() + + # Count the number of checkpoint files + process_count_stored = len(list(directory.glob("process_*-of-*.json"))) + + grain_restore_args = None + + if process_count_stored > process_count_jax: + # Scaling down from a larger number of hosts (e.g., 128 files -> 64 processes). + # In this case, each host restores a list of data iterators. + grain_restore_args = _prepare_scaled_down_grain_restore_args( + data_iterator, process_count_jax, process_count_stored, directory + ) + + elif process_count_stored == process_count_jax: + # Normal case: the number of hosts is unchanged (e.g., 64 files -> 64 processes). + assert not isinstance(data_iterator, list), ( + f"{process_count_stored} processes found in Grain checkpoint directory {directory}, matching the number of " + "JAX processes; please do not set expansion_factor_real_data." + ) + grain_restore_args = GrainCheckpointRestore(data_iterator.local_iterator) + + elif expansion_factor_real_data > 0 and process_count_stored == process_count_jax // expansion_factor_real_data: + # Scaling up to a larger number of hosts (e.g., 32 files -> 64 processes). + # In this case, only a subset of hosts restores the data iterator. + grain_restore_args = GrainCheckpointRestore( + data_iterator.local_iterator, process_index=jax.process_index(), process_count=process_count_stored + ) + + else: + # Mismatch: the stored files are incompatible with the current process count + raise ValueError( + f"Error restoring Grain checkpoint in {directory}: " + f"The number of stored checkpoint files ({process_count_stored}) " + f"is incompatible with the number of JAX processes ({process_count_jax}). " + "If you are resuming training with a different number of chips, see instructions in " + "https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline/" + "data_input_grain.md#using-grain" + ) + + # Call restore once with the composed arguments + restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_restore_args)) + return (restored_state, None) + + def load_state_if_possible( checkpoint_manager: CheckpointManager | None, data_iterator: MultiHostDataLoadIterator | None, @@ -288,6 +462,7 @@ def load_state_if_possible( enable_orbax_v1=False, checkpoint_conversion_fn=None, source_checkpoint_layout="orbax", + expansion_factor_real_data: int = -1, ): """Loads TrainState as possible from the inputs. @@ -354,20 +529,27 @@ def map_to_pspec(data): # or EmergencyReplicatorCheckpointManager. The '_' indicates that 'dataset_type' and # 'data_iterator' can be any value and aren't used in this pattern.
case (checkpoint_manager, _, _) if isinstance( - checkpoint_manager, - (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager), + checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager) ): return ( checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state, None, ) - # Case 2: Matches if dataset type is "grain" and a specific checkpoint file exits for the iterator - # exists within the checkpoint manager's directory for the given step. - case (checkpoint_manager, dataset_type, data_iterator) if dataset_type == "grain" and data_iterator and ( - checkpoint_manager.directory / str(step) / "iter" - ).exists(): - grain_iter = grain.PyGrainCheckpointRestore(data_iterator.local_iterator) - return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_iter)), None) + # Case 2: Matches if dataset type is "grain" and the data iterator is not a + # PlaceHolderDataIterator and a specific checkpoint file exists for the iterator + case ( + checkpoint_manager, + dataset_type, + data_iterator, + ) if ( + dataset_type == "grain" + and data_iterator + and not isinstance(data_iterator, PlaceHolderDataIterator) + and (checkpoint_manager.directory / str(step) / "iter").exists() + ): + return _restore_grain_iterator( + checkpoint_manager, step, data_iterator, checkpoint_args, expansion_factor_real_data + ) # Case 3: Default/Fallback case. # This case acts as a wildcard ('_') and matches if none of the preceding cases were met. case _: @@ -518,15 +700,25 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator= save_args=jax.tree.map(lambda _: ocp.SaveArgs(chunk_byte_size=chunk_byte_size), state), ocdbt_target_data_file_size=chunk_byte_size, ) - - match (checkpoint_manager, config): - case (checkpoint_manager, _) if isinstance( + save_args_composite = {"items": checkpoint_args} + + if config and config.dataset_type == "grain" and not isinstance(data_iterator, PlaceHolderDataIterator): + if not isinstance(data_iterator, list): + data_iterator = [data_iterator] + grain_iters_to_save = [] + process_count_total = jax.process_count() * len(data_iterator) + if config.expansion_factor_real_data > 1: + process_count_total = process_count_total // config.expansion_factor_real_data + for i, data_iter in enumerate(data_iterator): + process_index = jax.process_index() + i * jax.process_count() + grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total)) + save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save) + + match (checkpoint_manager, config, data_iterator): + case (checkpoint_manager, _, _) if isinstance( checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager) ): replicator_error_handler(config) return checkpoint_manager.save(step, args=Composite(state=checkpoint_args), force=force) - case (_, config) if config and config.dataset_type == "grain": - grain_iter = grain.PyGrainCheckpointSave(data_iterator.local_iterator) - return checkpoint_manager.save(step, args=Composite(items=checkpoint_args, iter=grain_iter), force=force) case _: - return checkpoint_manager.save(step, args=Composite(items=checkpoint_args), force=force) + return checkpoint_manager.save(step, args=Composite(**save_args_composite), force=force) diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index c6367bd51..c2b1760c0 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -496,7 +496,11 @@ add_eos: True # 
Dataset per_device_batch_size: 12.0 -expansion_factor_real_data: -1 # if -1 then all hosts will load real data, else total_hosts//expansion_factor_real_data will pull data from GCS. +# When expansion_factor_real_data is > 1, only total_hosts//expansion_factor_real_data hosts load real data. +# Each data-loading host then loads expansion_factor_real_data times its normal per-host batch. +# When set to a value between 0 and 1, the Grain pipeline uses it to resume from a checkpoint written with a larger chip count. +# Details in https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline/data_input_grain.md#using-grain +expansion_factor_real_data: -1.0 eval_per_device_batch_size: 0.0 max_corpus_chars: 10_000_000 train_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected" diff --git a/src/MaxText/data_loader.py b/src/MaxText/data_loader.py index 360382548..3c76144f9 100644 --- a/src/MaxText/data_loader.py +++ b/src/MaxText/data_loader.py @@ -35,10 +35,21 @@ class DataLoader: def __init__(self, config, mesh, data_iterator, goodput_recorder): self.config = config self.goodput_recorder = goodput_recorder - self.data_iterator = data_iterator + if isinstance(data_iterator, list): + self.data_iterator_list = data_iterator + self.data_iterator_index = 0 + self.data_iterator = self.data_iterator_list[self.data_iterator_index] + else: + self.data_iterator = data_iterator self.last_batch = None self.input_data_shardings = sharding.get_input_data_sharding(config, mesh) + def update_data_iterator(self): + """Advance to the next data iterator in the list, if applicable.""" + if hasattr(self, "data_iterator_list"): + self.data_iterator_index = (self.data_iterator_index + 1) % len(self.data_iterator_list) + self.data_iterator = self.data_iterator_list[self.data_iterator_index] + def load_next_batch(self): """Loads the next batch.
Can keep reusing the same batch for performance reasons.""" with maybe_record_goodput(self.goodput_recorder, GoodputEvent.DATA_LOADING): @@ -47,6 +58,7 @@ def load_next_batch(self): example_batch = self.last_batch else: example_batch = next(self.data_iterator) + self.update_data_iterator() # Reshard data from loaded sharding to performant activation sharding self.last_batch = sharding.maybe_shard_with_name( example_batch, diff --git a/src/MaxText/experimental/rl/grpo_trainer.py b/src/MaxText/experimental/rl/grpo_trainer.py index 19393d0e3..16abc9b6d 100644 --- a/src/MaxText/experimental/rl/grpo_trainer.py +++ b/src/MaxText/experimental/rl/grpo_trainer.py @@ -878,8 +878,13 @@ def generation_worker_fn( max_utils.print_mem_stats("After params initialized") metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) - state_to_save = _split_grpo_state(state)[0] - checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) + + if config.save_checkpoint_on_completion: + state_to_save = _split_grpo_state(state)[0] + checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) + elif checkpoint_manager is not None: + # in case the last checkpoint_period checkpoint is still in progress + checkpoint_manager.wait_until_finished() except exceptions.StopTraining as e: max_logging.log(f"Training stopped: {str(e)}") finally: diff --git a/src/MaxText/input_pipeline/_grain_data_processing.py b/src/MaxText/input_pipeline/_grain_data_processing.py index 1749fd8a5..b92a6f194 100644 --- a/src/MaxText/input_pipeline/_grain_data_processing.py +++ b/src/MaxText/input_pipeline/_grain_data_processing.py @@ -122,6 +122,11 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra ) # Pack and Batch examples. batch_size = config.global_batch_size_to_load // jax.process_count() + if config.expansion_factor_real_data > 1: + # global_batch_size_to_load has been expanded in pyconfig.py when expansion_factor_real_data > 1. + # But when using Grain, we want to keep the batch_size consistent with that in the checkpoint. + # We revert the batch_size expansion here, but load multiple batches per step in multihost_dataloading.py. + batch_size = batch_size // config.expansion_factor_real_data if config.packing: length_struct = {col: config.max_target_length for col in data_columns} dataset = grain.experimental.FirstFitPackIterDataset( @@ -194,7 +199,7 @@ def make_grain_train_iterator( assert ( config.global_batch_size_to_load % global_mesh.size == 0 ), "Batch size should be divisible by number of global devices." 
- if not config.colocated_python_data_input: + if not config.colocated_python_data_input and not 0 < config.expansion_factor_real_data < 1: train_ds = get_datasets( config.grain_train_files, config.grain_file_type, @@ -222,7 +227,10 @@ def make_grain_train_iterator( grain_worker_count=config.grain_worker_count, ) return multihost_dataloading.MultiHostDataLoadIterator( - train_dataloader, global_mesh, config.generate_padding_batch_train + train_dataloader, + global_mesh, + config.generate_padding_batch_train, + expansion_loading_factor_for_grain=config.expansion_factor_real_data, ) else: get_ds_fn = functools.partial( @@ -250,8 +258,23 @@ def make_grain_train_iterator( tokenize=config.tokenize_train_data, grain_worker_count=config.grain_worker_count, ) - global_shape = (config.global_batch_size_to_load, config.max_target_length) - return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, global_mesh, global_shape) + if config.colocated_python_data_input: + global_shape = (config.global_batch_size_to_load, config.max_target_length) + return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, global_mesh, global_shape) + else: + # config.expansion_factor_real_data is between 0 and 1 + num_dataloader_to_restore = int(1 / config.expansion_factor_real_data) + train_dataloader_list = [] + dataloading_host_count = len(process_indices) * num_dataloader_to_restore + for i in range(num_dataloader_to_restore): + dataloading_host_index = len(process_indices) * i + process_indices.index(jax.process_index()) + train_ds = get_ds_fn(dataloading_host_index=dataloading_host_index, dataloading_host_count=dataloading_host_count) + train_dataloader = preprocessing_fn(train_ds) + train_dataloader_list.append(train_dataloader) + return [ + multihost_dataloading.MultiHostDataLoadIterator(x, global_mesh, config.generate_padding_batch_train) + for x in train_dataloader_list + ] def make_grain_eval_iterator( diff --git a/src/MaxText/input_pipeline/input_pipeline_interface.py b/src/MaxText/input_pipeline/input_pipeline_interface.py index 34f92946f..27b105bb2 100644 --- a/src/MaxText/input_pipeline/input_pipeline_interface.py +++ b/src/MaxText/input_pipeline/input_pipeline_interface.py @@ -89,7 +89,7 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh): mesh, ) output_train_iterator = create_process_specific_iterator(config, mesh, process_indices_train, train_iterator) - if config.expansion_factor_real_data != -1: # assert number of hosts loading real data + if config.expansion_factor_real_data > 1: # assert number of hosts loading real data assert len(process_indices_train) == jax.process_count() // config.expansion_factor_real_data # Generate output eval iterator @@ -103,7 +103,7 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh): mesh, ) - if config.expansion_factor_real_data != -1: + if config.expansion_factor_real_data > 1: assert len(process_indices_eval) == jax.process_count() // config.expansion_factor_real_data output_eval_iterator = create_process_specific_iterator(config, mesh, process_indices_eval, eval_iterator) return output_train_iterator, output_eval_iterator diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index ac93b4104..6cb8b1e10 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -848,6 +848,7 @@ def setup_initial_state( enable_orbax_v1=config.enable_orbax_v1, checkpoint_conversion_fn=config.checkpoint_conversion_fn, source_checkpoint_layout=config.source_checkpoint_layout, + 
expansion_factor_real_data=config.expansion_factor_real_data, ) if restored: @@ -860,8 +861,7 @@ def setup_initial_state( ): state = restored else: - if "iter" in restored and restored["iter"] is not None: - data_iterator.local_iterator = restored["iter"] + # The data iterator's state is restored in place, so no explicit assignment is needed state = restored["items"] else: init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) diff --git a/src/MaxText/multihost_dataloading.py b/src/MaxText/multihost_dataloading.py index 29a64cfe6..dd9ca2a81 100644 --- a/src/MaxText/multihost_dataloading.py +++ b/src/MaxText/multihost_dataloading.py @@ -66,9 +66,17 @@ def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array: class MultiHostDataLoadIterator: - """fold get_next_batch_sharded into a iterator class""" - - def __init__(self, dataloader: tf.data.Dataset | Iterable, global_mesh: Mesh, generate_padding_batch: bool = False): + """fold get_next_batch_sharded into an iterator class. + expansion_loading_factor_for_grain is only used by the Grain pipeline when only a subset of hosts loads real data. + """ + + def __init__( + self, + dataloader: tf.data.Dataset | Iterable, + global_mesh: Mesh, + generate_padding_batch: bool = False, + expansion_loading_factor_for_grain: int = -1, + ): self.global_mesh = global_mesh self.dataloader = dataloader if isinstance(self.dataloader, tf.data.Dataset): @@ -80,6 +88,7 @@ def __init__(self, dataloader: tf.data.Dataset | Iterable, global_mesh: Mesh, ge self.out_of_data = False self.last_local_data = None self.generate_padding_batch = generate_padding_batch + self.expansion_loading_factor_for_grain = expansion_loading_factor_for_grain def reset(self): if isinstance(self.dataloader, tf.data.Dataset): @@ -111,6 +120,15 @@ def _get_next_batch_sharded(self) -> jax.Array: for _ in range(MAX_DATA_LOAD_ATTEMPTS): try: local_data = next(self.local_iterator) + if self.expansion_loading_factor_for_grain > 1: + # The Grain checkpoint requires a fixed batch size, so we advance the iterator + # expansion_loading_factor_for_grain times and concatenate the batches to form + # the full per-host batch on hosts that load real data.
+ local_data_list = [local_data] + for _ in range(1, self.expansion_loading_factor_for_grain): + next_batch = next(self.local_iterator) + local_data_list.append(next_batch) + local_data = jtu.tree_map(lambda *xs: np.concatenate(xs, axis=0), *local_data_list) break # exit the loop on success except tf.errors.FailedPreconditionError as e: max_logging.log(f"Failed to get next data batch due to {e}, retrying") diff --git a/src/MaxText/pyconfig.py b/src/MaxText/pyconfig.py index c87a7a5a1..14405ad53 100644 --- a/src/MaxText/pyconfig.py +++ b/src/MaxText/pyconfig.py @@ -31,9 +31,8 @@ from MaxText import accelerator_to_spec_map from MaxText import max_logging from MaxText import max_utils -from MaxText.common_types import DecoderBlockType, ShardMode from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR -from MaxText.layers.attentions import AttentionType +from MaxText.common_types import AttentionType, DecoderBlockType, ShardMode from MaxText.utils import gcs_utils @@ -1237,12 +1236,12 @@ def calculate_global_batch_sizes( """Calculates target global batch size from target devices and per_device_batch""" if per_device_batch_size < 1.0: # For per_device_batch_size<1, we load the data as if per_device_batch_size=1 - if expansion_factor_real_data != -1: + if expansion_factor_real_data > 1: micro_batch_size_to_load = num_devices * expansion_factor_real_data else: micro_batch_size_to_load = num_devices else: - if expansion_factor_real_data != -1: + if expansion_factor_real_data > 1: micro_batch_size_to_load = int(num_devices * per_device_batch_size * expansion_factor_real_data) else: micro_batch_size_to_load = int(num_devices * per_device_batch_size) diff --git a/src/MaxText/sft_trainer.py b/src/MaxText/sft_trainer.py index 8075478fe..d48caa37e 100644 --- a/src/MaxText/sft_trainer.py +++ b/src/MaxText/sft_trainer.py @@ -139,7 +139,11 @@ def train_loop(config, recorder, state=None): metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) - checkpointing.maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator) + if config.save_checkpoint_on_completion: + checkpointing.maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator) + elif checkpoint_manager is not None: + # in case the last checkpoint_period checkpoint is still in progress + checkpoint_manager.wait_until_finished() except exceptions.StopTraining as e: max_logging.log(f"Training stopped: {str(e)}") finally: diff --git a/src/MaxText/train.py b/src/MaxText/train.py index 47f330001..2761d0d46 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -459,6 +459,9 @@ def train_loop(config, recorder, state=None): if config.save_checkpoint_on_completion: state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) + if checkpoint_manager is not None: + # in case the last checkpoint_period checkpoint is still in progress + checkpoint_manager.wait_until_finished() except exceptions.StopTraining as e: max_logging.log(f"Training stopped: {str(e)}") finally:
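As a companion to the data-loading changes above, here is a minimal standalone sketch (not MaxText code; it assumes batches are dicts of NumPy arrays) of the two batch-composition behaviors this PR introduces: on scaled-up runs a data-loading host pulls `expansion_factor_real_data` consecutive fixed-size Grain batches and concatenates them, mirroring `MultiHostDataLoadIterator._get_next_batch_sharded`, while on scaled-down runs the `DataLoader` alternates round-robin over the restored iterators, mirroring `DataLoader.update_data_iterator`.

```python
import itertools

import jax.tree_util as jtu
import numpy as np


def concat_batches(iterator, factor: int):
  """Scaling up: pull `factor` fixed-size batches and concatenate them along the batch axis."""
  batches = [next(iterator) for _ in range(factor)]
  return jtu.tree_map(lambda *xs: np.concatenate(xs, axis=0), *batches)


def round_robin(iterators):
  """Scaling down: yield batches by alternating over the restored per-host iterators."""
  for it in itertools.cycle(iterators):
    yield next(it)


# Toy data: each "iterator" yields batches of size 2.
merged = concat_batches(iter([{"ids": np.arange(2)} for _ in range(4)]), factor=2)
print(merged["ids"])  # [0 1 0 1] -- one per-host batch built from two Grain batches

toy = [iter([{"ids": np.arange(2) + 10 * k} for _ in range(3)]) for k in range(2)]
rr = round_robin(toy)
print(next(rr)["ids"], next(rr)["ids"])  # [0 1] [10 11] -- alternates between restored iterators
```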