diff --git a/docs/source/advanced/epochized_blending.md b/docs/source/advanced/epochized_blending.md index dcd68e11..1d8f0a77 100644 --- a/docs/source/advanced/epochized_blending.md +++ b/docs/source/advanced/epochized_blending.md @@ -25,20 +25,24 @@ splits: split_part: val ``` -Now, the call to `get_train_dataset` requires the additional parameter `repeat=False` to interrupt iterating after one epoch: +Now, the call to `get_train_dataset` requires the additional parameter `repeat=False` to interrupt iterating after one epoch. +In addition, the `RedistributeLoader` is used to redistribute the last samples across ranks when some ranks have already exhausted their data: ```py from megatron.energon import get_train_dataset, get_loader, WorkerConfig +from megatron.energon.sync_end import RedistributeLoader -loader = get_loader(get_train_dataset( +# Set repeat=False to avoid repeating the dataset. +# Use RedistributeLoader to synchronize the ranks when they exhaust at different times. This only works if torch distributed is initialized. +loader = RedistributeLoader(get_loader(get_train_dataset( 'metadataset.yaml', batch_size=2, shuffle_buffer_size=100, max_samples_per_sequence=100, worker_config=WorkerConfig.default_worker_config(), repeat=False, -)) +))) # This will now stop iterating after the datasets have been iterated (coco 5 times, coyo-train 2 # times and coyo-val 1 time). Of course, the data is still being shuffled between all those @@ -54,3 +58,10 @@ for batch in loader: If used as dataset for `get_val_dataset`, the `repetitions` are ignored. The metadataset would also work without setting `repeat=False`, but then the shuffle buffer will shuffle samples across boundaries of epochs. + +## Sample Redistribution + +There are two options for handling the end of the data iteration: + +* `RedistributeLoader`: Redistributes samples when a rank is exhausted before the other ranks. On the next epoch, the samples of the last incomplete batch are included. +* `StopFirstLoader`: Stops as soon as the first rank is exhausted. The next epoch restarts the previously exhausted rank and iterates until the next rank is exhausted, restarting all ranks once. \ No newline at end of file diff --git a/src/megatron/energon/sync_end/__init__.py b/src/megatron/energon/sync_end/__init__.py new file mode 100644 index 00000000..f47ad656 --- /dev/null +++ b/src/megatron/energon/sync_end/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +from megatron.energon.sync_end.redistribute import RedistributeDataLoaderState, RedistributeLoader +from megatron.energon.sync_end.stop_first_end import StopFirstDataLoaderState, StopFirstLoader + +""" +Provides wrappers for the dataset loaders that allow for synchronization at the end of the dataset. +That is, when running a training with repeat=False, the loaders will typically exhaust at different times, which may require +synchronization across ranks. + +The wrappers are: +- RedistributeLoader: Redistributes the last samples to the ranks that are not exhausted. +- StopFirstLoader: Stops iterating as soon as the first rank is exhausted. +""" + +__all__ = [ + "RedistributeLoader", + "RedistributeDataLoaderState", + "StopFirstLoader", + "StopFirstDataLoaderState", +] diff --git a/src/megatron/energon/sync_end/redistribute.py b/src/megatron/energon/sync_end/redistribute.py new file mode 100644 index 00000000..7e432189 --- /dev/null +++ b/src/megatron/energon/sync_end/redistribute.py @@ -0,0 +1,452 @@ +# Copyright (c) 2025, NVIDIA CORPORATION.
+# SPDX-License-Identifier: BSD-3-Clause +from typing import Any, Dict, Generic, Iterator, Optional, Sequence, Tuple, TypeVar, Union + +import torch +import torch.distributed as dist + +from megatron.energon.edataclass import edataclass +from megatron.energon.flavors.base_dataset import State +from megatron.energon.savable_loader import ( + BasicDataLoader, + SavableDataLoader, + SavableDataLoaderState, +) +from megatron.energon.worker import WorkerConfig +from megatron.energon.wrappers.base import get_sample_restore_key + +__all__ = ["RedistributeLoader", "RedistributeDataLoaderState"] + + +T = TypeVar("T") + + +@edataclass +class RedistributeDataLoaderState(State): + inner_state: SavableDataLoaderState | None + + exhausted_state: bool + overuse_count: int + + next_samples_restore_keys: list[Tuple[Union[str, int, tuple], ...]] | None + + def __repr__(self): + return f"RedistributeDataLoaderState(inner_state={self.inner_state!r}, exhausted_state={self.exhausted_state!r}, overuse_count={self.overuse_count!r})" + + +class RedistributeLoader(Generic[T]): + """ + A loader that wraps the actual loader and redistributes the last samples to the ranks that are not exhausted. + The last incomplete batch (i.e. where not all ranks have data available) is not iterated. + + It is useful for trainings where the dataset is not repeated. + + Stages: + First stage: Iterate until one rank is exhausted. + Second stage: Iterate until all ranks are exhausted and the global batch is incomplete. + Collect how many samples are required to satisfy the need for a global batch. + Fetch those additionally needed samples from the ranks that have the least overuse count. + Directly communicate the samples from the overfetched ranks to the exhausted ranks in round-robin fashion. + Distribute the samples to the ranks that are not exhausted. + If starting a new iterator after global exhaustion, perform another epoch (also emitting the samples from the last + incomplete batch). + """ + + # large int64 number we'll never reach for overuse counts.
+ OVERUSE_COUNT_MAX = 0x1000000000000000 + + loader: SavableDataLoader | BasicDataLoader + worker_config: WorkerConfig + distributed_device: str + + overuse_counts: torch.Tensor + exhausted_states: torch.Tensor + _exhausted_states_list: list[torch.Tensor] + + _iterator: Iterator[T] | None = None + + _next_samples: list[T] | None = None + _next_sample_restore_keys: list[Tuple[Union[str, int, tuple], ...]] | None = None + + def __init__(self, loader: SavableDataLoader | BasicDataLoader): + self.loader = loader + self.worker_config = loader.worker_config + self.distributed_device = ( + "cuda" + if dist.is_available() + and dist.is_initialized() + and dist.get_backend() == dist.Backend.NCCL + else "cpu" + ) + self.overuse_counts = torch.zeros( + self.worker_config.world_size, dtype=torch.int64, requires_grad=False + ) + self.exhausted_states = torch.zeros( + self.worker_config.world_size, + dtype=torch.uint8, + device=self.distributed_device, + requires_grad=False, + ) + self._exhausted_states_list = [ + self.exhausted_states[i] for i in range(self.worker_config.world_size) + ] + + def _find_ranks_to_oversample(self, needed_samples: int) -> int: + oversample_self = 0 + while needed_samples > 0: + min_overuse_idx = torch.where(self.overuse_counts == torch.min(self.overuse_counts))[ + 0 + ].cpu() + # print(f"[r={self.worker_config.rank}]: Min overuse idx: {min_overuse_idx}\n", end="") + for rank in min_overuse_idx: + self.overuse_counts[rank] += 1 + if rank == self.worker_config.rank: + oversample_self += 1 + needed_samples -= 1 + if needed_samples == 0: + break + + return oversample_self + + def _as_global_rank(self, rank: int) -> int: + if self.worker_config.data_parallel_group is not None: + return dist.get_global_rank(self.worker_config.data_parallel_group, rank) + else: + return rank + + def __iter__(self): + if self._iterator is None: + self._iterator = iter(self.loader) + + samples: list[T] = [] + + if self._next_samples is not None: + samples.extend(self._next_samples) + self._next_samples = None + self._next_sample_restore_keys = None + elif self._next_sample_restore_keys is not None: + samples.extend( + self.restore_sample(restore_key) for restore_key in self._next_sample_restore_keys + ) + self._next_sample_restore_keys = None + + rank = self.worker_config.rank + + # Ensure the initial state is synchronized (e.g. 
if restored from a checkpoint) + dist.all_gather( + self._exhausted_states_list, + self.exhausted_states[rank], + group=self.worker_config.data_parallel_group, + ) + overuse_count_sync = torch.zeros( + self.worker_config.world_size, + dtype=torch.int64, + device=self.distributed_device, + requires_grad=False, + ) + dist.all_gather( + [overuse_count_sync[i] for i in range(self.worker_config.world_size)], + self.overuse_counts[rank].to(device=self.distributed_device), + group=self.worker_config.data_parallel_group, + ) + self.overuse_counts[:] = overuse_count_sync.cpu() + + # Iterate until any rank is exhausted + self_exhausted = 0 + while not self.exhausted_states.any(): + if len(samples) > 0: + # First use pending samples from previous iteration + sample = samples.pop(0) + else: + try: + sample = next(self._iterator) + except StopIteration: + # print(f"[r={rank}]: StopIteration\n", end="") + self.exhausted_states[rank] = self_exhausted = 1 + dist.all_reduce( + self.exhausted_states[rank], + op=dist.ReduceOp.MAX, + group=self.worker_config.data_parallel_group, + ) + global_any_exhausted = bool(self.exhausted_states[rank].item()) + + if global_any_exhausted: + # print(f"[r={rank}]: One rank exhausted\n", end="") + self.exhausted_states[rank] = self_exhausted + if not self_exhausted: + # print(f"[r={rank}]: Not exhausted, storing sample\n", end="") + samples.append(sample) + break + + yield sample + + sync_ranks = True + sample_count = torch.zeros( + self.worker_config.world_size, + dtype=torch.int64, + device=self.distributed_device, + requires_grad=False, + ) + sample_count_list = [sample_count[i] for i in range(self.worker_config.world_size)] + + # Redistribute the samples until all ranks are exhausted + # * The ranks with the least overuse count shall fetch more sample(s) as needed + # * The ranks which are already exhausted shall receive one sample of the additionally fetched samples + while not self.exhausted_states.all() or sync_ranks: + if sync_ranks: + # Share all exhausted states + dist.all_gather( + self._exhausted_states_list, + self.exhausted_states[rank], + group=self.worker_config.data_parallel_group, + ) + exhausted_cpu = self.exhausted_states.cpu() + if exhausted_cpu.all(): + break + for i in range(self.worker_config.world_size): + if exhausted_cpu[i]: + self.overuse_counts[i] = self.OVERUSE_COUNT_MAX + + # Check if there are enough samples to satisfy the need + dist.all_gather( + sample_count_list, + torch.tensor(len(samples), dtype=torch.int64, device=self.distributed_device), + group=self.worker_config.data_parallel_group, + ) + needed_samples = self.worker_config.world_size - sample_count.sum().item() + + # print(f"[r={rank}]: Exhausted: {self.exhausted_states.cpu()}, Sample count: {sample_count.cpu()}, overuse counts: {self.overuse_counts}\n", end="") + + # The ranks are now in sync with all dataloader states and sample counts + sync_ranks = False + if needed_samples > 0: + # print(f"[r={rank}]: Need {needed_samples} samples\n", end="") + # Not enough samples to satisfy the need -> fetch more on non-exhausted ranks + oversample_self = self._find_ranks_to_oversample(needed_samples) + # print(f"[r={rank}]: Oversample self {oversample_self} samples\n", end="") + while oversample_self > 0: + try: + samples.append(next(self._iterator)) + # print(f"[r={rank}]: Got {len(samples)} samples\n", end="") + except StopIteration: + # print(f"[r={rank}]: Exhausted\n", end="") + self.exhausted_states[rank] = 1 + break + else: + oversample_self -= 1 + # print(f"[r={rank}]: Got 
{len(samples)} samples\n", end="") + # Loop again, in case another rank exhausted now and did not get a sample, sync ranks again to be sure + sync_ranks = True + continue + else: + # All ranks in sum have enough samples -> distribute the samples. + assert needed_samples == 0, ( + f"Needed {needed_samples} samples, but have {sample_count.sum().item()}" + ) + + # For each sample, compute the samples that can be distributed to other ranks + sending_ranks = [ + (rank, idx) + for rank, count in enumerate(sample_count) + for idx in range(1, count) + ] + # Compute the rank that is going to send a sample to this rank + # xor the ranks that are going to receive a sample from this rank + self_source_rank: int | None = None + # List of (target_rank, source_sample_idx) that are going to receive a sample from this rank + self_target_ranks: list[tuple[int, int]] = [] + for chk_rank in range(self.worker_config.world_size): + if sample_count[chk_rank] == 0: + # This rank is not going to receive a sample, because it has no samples + # Take the first sample from the sending ranks and send it to this rank + src_rank, src_idx = sending_ranks.pop() + # print(f"[r={rank}]: Sending sample {src_idx} from rank {src_rank} to rank {chk_rank}\n", end="") + if chk_rank == rank: + # If the self rank is the receiving rank, store the rank we're receiving from + self_source_rank = src_rank + elif src_rank == rank: + # If the self rank is the sending rank, store the rank we're sending to and which sample + self_target_ranks.append((chk_rank, src_idx)) + + if self_source_rank is not None: + # print(f"[r={rank}]: Receiving sample from rank {self_source_rank}\n", end="") + # This rank is going to receive a sample from that other rank + object_list = [None] + # Receive the sample from the source rank. Requires the global rank, disregarding the group + dist.recv_object_list(object_list, src=self._as_global_rank(self_source_rank)) + samples.append(object_list[0]) + elif len(self_target_ranks) > 0: + # This rank is going to send a sample to that other rank(s) + for dst_rank, sample_idx in self_target_ranks: + # print(f"[r={rank}]: Sending sample {sample_idx} to rank {dst_rank}\n", end="") + object_list = [samples[sample_idx]] + samples[sample_idx] = None + # Send the sample to the destination rank. Requires the global rank, disregarding the group + dist.send_object_list(object_list, dst=self._as_global_rank(dst_rank)) + + # Remove the samples that have been distributed + for i in range(len(samples) - 1, -1, -1): + if samples[i] is None: + del samples[i] + if len(samples) > 1: + # It may happen, that there are more samples than needed (case: a rank had oversampled, but not enough + # to provide for all ranks; restarting the loop, and one rank has no initial samples, but others do have + # samples now, then the oversampled samples cannot be distributed to all ranks) + # Need to save the samples in case the loop is interrupted at the yield + self._next_samples = samples + else: + # Ensure no next samples are set, all were consumed + self._next_samples = None + # assert len(samples) == 1, f"Every rank should have one sample now, have {len(samples)} on rank {self.worker_config.rank}" + # Important: When yielding, the samples list is empty. It's not part of the state, so it does not need + # to be saved. 
+ yield samples.pop(0) + # print(f"[r={rank}]: Yielded sample {len(samples)}, getting next\n", end="") + + needed_samples = self.worker_config.world_size + + # print(f"[r={rank}]: Done iterating\n", end="") + + self._next_samples = samples + self._next_sample_restore_keys = None + + # Done iterating, reset the iterator + self._iterator = None + self.exhausted_states.fill_(0) + self.overuse_counts.fill_(0) + + def __len__(self): + return len(self.loader) + + def save_state_rank(self) -> RedistributeDataLoaderState: + assert isinstance(self.loader, SavableDataLoader) + if self._next_sample_restore_keys is not None: + restore_keys = self._next_sample_restore_keys + elif self._next_samples is not None: + restore_keys = self._next_sample_restore_keys = [ + get_sample_restore_key(sample) for sample in self._next_samples + ] + else: + restore_keys = None + return RedistributeDataLoaderState( + inner_state=self.loader.save_state_rank(), + overuse_count=int(self.overuse_counts[self.worker_config.rank].item()), + exhausted_state=bool(self.exhausted_states[self.worker_config.rank].item()), + next_samples_restore_keys=restore_keys, + ) + + def restore_state_rank(self, state: RedistributeDataLoaderState) -> None: + assert isinstance(self.loader, SavableDataLoader) + self.loader.restore_state_rank(state.inner_state) + self._next_sample_restore_keys = state.next_samples_restore_keys + self._next_samples = None + self.overuse_counts[self.worker_config.rank] = state.overuse_count + self.exhausted_states[self.worker_config.rank] = state.exhausted_state + + def can_restore_sample(self) -> bool: + return self.loader.can_restore_sample() + + def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T: + return self.loader.restore_sample(restore_key) + + def save_state_global( + self, global_dst_rank: int + ) -> Optional[Sequence[RedistributeDataLoaderState]]: + """ + See :meth:`megatron.energon.SavableDataLoader.save_state_global` + """ + # Fetch current rank's worker's state + merged_state = self.save_state_rank() + + # Gather the merged states + if self.worker_config.world_size > 1: + output: Optional[Sequence[RedistributeDataLoaderState]] + if self.worker_config.global_rank() == global_dst_rank: + output = [None] * self.worker_config.world_size + else: + # Check if the global_dst_rank is in the same group at all + if self.worker_config.data_parallel_group is not None: + try: + _ = torch.distributed.get_group_rank( + self.worker_config.data_parallel_group, global_dst_rank + ) + except RuntimeError: + raise ValueError( + f"global_dst_rank {global_dst_rank} is not in the group of the current rank's worker config" + ) + + output = None + + torch.distributed.gather_object( + merged_state, + output, + global_dst_rank, + group=self.worker_config.data_parallel_group, + ) + + return output + else: + # Not distributed -> return the merged state + return [merged_state] + + def restore_state_global( + self, + state: Optional[Sequence[RedistributeDataLoaderState]], + *, + src_rank: Optional[int] = None, + ) -> None: + """ + See :meth:`megatron.energon.SavableDataLoader.restore_state_global` + """ + assert self._iterator is None, "Cannot restore state while workers are running" + + if src_rank is None or self.worker_config.world_size == 1: + assert isinstance(state, list), "State must be a list in distributed setup" + assert len(state) == self.worker_config.world_size, ( + "State must be a list of size world_size" + ) + + rank_state = state[self.worker_config.rank] + else: + if 
self.worker_config.data_parallel_group is not None: + # Only the src_rank has the state within this dp group + try: + global_src_rank = torch.distributed.get_global_rank( + self.worker_config.data_parallel_group, src_rank + ) + except RuntimeError: + raise ValueError( + f"src_rank {src_rank} is not in the group of the current rank's worker config" + ) + else: + # If no DP group is given, we assume the global rank is + # the same as the data parallel rank + global_src_rank = src_rank + + if self.worker_config.rank != src_rank: + # Send the state to all other ranks + assert state is None + # Must still be a list of Nones + state = [None] * self.worker_config.world_size + else: + assert isinstance(state, list), "State must be a list in distributed setup" + assert len(state) == self.worker_config.world_size, ( + "State must be a list of size world_size" + ) + + local_object = [None] + torch.distributed.scatter_object_list( + local_object, + state, + src=global_src_rank, + group=self.worker_config.data_parallel_group, + ) + rank_state = local_object[0] + + self.restore_state_rank(rank_state) + + def config(self) -> Dict[str, Any]: + return { + "type": type(self).__qualname__, + "loader": self.loader.config(), + } diff --git a/src/megatron/energon/sync_end/stop_first_end.py b/src/megatron/energon/sync_end/stop_first_end.py new file mode 100644 index 00000000..7dffef34 --- /dev/null +++ b/src/megatron/energon/sync_end/stop_first_end.py @@ -0,0 +1,259 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +from typing import Any, Dict, Generic, Iterator, Optional, Sequence, Tuple, TypeVar, Union + +import torch +import torch.distributed as dist + +from megatron.energon.edataclass import edataclass +from megatron.energon.flavors.base_dataset import State +from megatron.energon.savable_loader import ( + BasicDataLoader, + SavableDataLoader, + SavableDataLoaderState, +) +from megatron.energon.worker import WorkerConfig +from megatron.energon.wrappers.base import get_sample_restore_key + +__all__ = ["StopFirstLoader", "StopFirstDataLoaderState"] + + +T = TypeVar("T") + + +@edataclass +class StopFirstDataLoaderState(State): + inner_state: SavableDataLoaderState | None + iterating_from_start: bool + next_sample_restore_key: Tuple[Union[str, int, tuple], ...] | None + + def __repr__(self): + return f"StopFirstDataLoaderState(inner_state={self.inner_state!r}, iterating_from_start={self.iterating_from_start!r}, next_sample_restore_key={self.next_sample_restore_key!r})" + + +class StopFirstLoader(Generic[T]): + """ + A loader that stops as soon as the first rank is exhausted. + If continuing a second time, it will restart the previously exhausted rank and iterate until the next rank is + exhausted, restarting all ranks once. + + This is useful for trainings where the dataset is not repeated. + """ + + loader: SavableDataLoader | BasicDataLoader + worker_config: WorkerConfig + distributed_device: str + _iterator: Iterator[T] | None = None + _iterating_from_start: bool = True + _next_sample: T | None = None + _next_sample_restore_key: Tuple[Union[str, int, tuple], ...] 
| None = None + + def __init__(self, loader: SavableDataLoader | BasicDataLoader): + self.loader = loader + self.worker_config = loader.worker_config + self.distributed_device = ( + "cuda" + if dist.is_available() + and dist.is_initialized() + and dist.get_backend() == dist.Backend.NCCL + else "cpu" + ) + + def __iter__(self): + if self._iterator is None: + self._iterator = iter(self.loader) + + # Check if torch distributed is using cuda + flag = torch.zeros( + 1, dtype=torch.uint8, device=self.distributed_device, requires_grad=False + ) + + # If there is a pending sample, use it + if self._next_sample is not None: + sample = self._next_sample + local_has_sample = 1 + self._next_sample = None + self._next_sample_restore_key = None + # print(f"[r={self.worker_config.rank}]: Using pending sample\n", end="") + elif self._next_sample_restore_key is not None: + sample = self.restore_sample(self._next_sample_restore_key) + self._next_sample = None + self._next_sample_restore_key = None + local_has_sample = 1 + # print(f"[r={self.worker_config.rank}]: Using restored pending sample\n", end="") + else: + sample = None + local_has_sample = 0 + # print(f"[r={self.worker_config.rank}]: No pending sample\n", end="") + + while True: + if not local_has_sample: + try: + sample = next(self._iterator) + local_has_sample = 1 + except StopIteration: + if not self._iterating_from_start: + # If not iterating from start (i.e. another rank already exhausted and ended the epoch), + # The second epoch should ignore ending iterators and continue iterating. + self._iterator = iter(self.loader) + self._iterating_from_start = True + # print(f"[r={self.worker_config.rank}]: Restarting iterator\n", end="") + continue + # print(f"[r={self.worker_config.rank}]: No samples left\n", end="") + local_has_sample = 0 + + flag.fill_(local_has_sample) + # Compute *global* logical *AND* over all ranks. Using MIN as this + # is equivalent to logical AND for 0/1 bits. + dist.all_reduce( + flag, op=dist.ReduceOp.MIN, group=self.worker_config.data_parallel_group + ) + global_all_have_sample = bool(flag.item()) + + if not global_all_have_sample: + if local_has_sample == 0: + self._next_sample = sample + else: + self._next_sample = None + # At least one rank is exhausted – terminate *all* ranks. We + # purposely ignore any *local* sample obtained in this round + # to keep the step count aligned across ranks. + # The rank(s) which ended is iterating from start again, the other ranks are not. + self._iterating_from_start = local_has_sample == 0 + if local_has_sample == 0: + self._iterator = None + # print(f"[r={self.worker_config.rank}]: All exhausted, iter_from_start={self._iterating_from_start}\n", end="") + break + + # Otherwise every rank had a sample in this step – yield it. 
+ yield sample + local_has_sample = 0 + sample = None + + def save_state_rank(self) -> StopFirstDataLoaderState: + assert isinstance(self.loader, SavableDataLoader) + if self._next_sample_restore_key is not None: + restore_key = self._next_sample_restore_key + elif self._next_sample is not None: + restore_key = self._next_sample_restore_key = get_sample_restore_key(self._next_sample) + else: + restore_key = None + return StopFirstDataLoaderState( + inner_state=self.loader.save_state_rank(), + iterating_from_start=self._iterating_from_start, + next_sample_restore_key=restore_key, + ) + + def restore_state_rank(self, state: StopFirstDataLoaderState) -> None: + assert isinstance(self.loader, SavableDataLoader) + self._iterating_from_start = state.iterating_from_start + self._next_sample_restore_key = state.next_sample_restore_key + self.loader.restore_state_rank(state.inner_state) + + def can_restore_sample(self) -> bool: + return self.loader.can_restore_sample() + + def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T: + return self.loader.restore_sample(restore_key) + + def save_state_global( + self, global_dst_rank: int + ) -> Optional[Sequence[StopFirstDataLoaderState]]: + """ + See :meth:`megatron.energon.SavableDataLoader.save_state_global` + """ + # Fetch current rank's worker's state + merged_state = self.save_state_rank() + + # Gather the merged states + if self.worker_config.world_size > 1: + output: Optional[Sequence[StopFirstDataLoaderState]] + if self.worker_config.global_rank() == global_dst_rank: + output = [None] * self.worker_config.world_size + else: + # Check if the global_dst_rank is in the same group at all + if self.worker_config.data_parallel_group is not None: + try: + _ = torch.distributed.get_group_rank( + self.worker_config.data_parallel_group, global_dst_rank + ) + except RuntimeError: + raise ValueError( + f"global_dst_rank {global_dst_rank} is not in the group of the current rank's worker config" + ) + + output = None + + torch.distributed.gather_object( + merged_state, + output, + global_dst_rank, + group=self.worker_config.data_parallel_group, + ) + + return output + else: + # Not distributed -> return the merged state + return [merged_state] + + def restore_state_global( + self, + state: Optional[Sequence[StopFirstDataLoaderState]], + *, + src_rank: Optional[int] = None, + ) -> None: + """ + See :meth:`megatron.energon.SavableDataLoader.restore_state_global` + """ + assert self._iterator is None, "Cannot restore state while workers are running" + + if src_rank is None or self.worker_config.world_size == 1: + assert isinstance(state, list), "State must be a list in distributed setup" + assert len(state) == self.worker_config.world_size, ( + "State must be a list of size world_size" + ) + + rank_state = state[self.worker_config.rank] + else: + if self.worker_config.data_parallel_group is not None: + # Only the src_rank has the state within this dp group + try: + global_src_rank = torch.distributed.get_global_rank( + self.worker_config.data_parallel_group, src_rank + ) + except RuntimeError: + raise ValueError( + f"src_rank {src_rank} is not in the group of the current rank's worker config" + ) + else: + # If no DP group is given, we assume the global rank is + # the same as the data parallel rank + global_src_rank = src_rank + + if self.worker_config.rank != src_rank: + # Send the state to all other ranks + assert state is None + # Must still be a list of Nones + state = [None] * self.worker_config.world_size + else: + assert 
isinstance(state, list), "State must be a list in distributed setup" + assert len(state) == self.worker_config.world_size, ( + "State must be a list of size world_size" + ) + + local_object = [None] + torch.distributed.scatter_object_list( + local_object, + state, + src=global_src_rank, + group=self.worker_config.data_parallel_group, + ) + rank_state = local_object[0] + + self.restore_state_rank(rank_state) + + def config(self) -> Dict[str, Any]: + return { + "type": type(self).__qualname__, + "loader": self.loader.config(), + } diff --git a/tests/test_av_decoder.py b/tests/test_av_decoder.py index 7e28acb3..3c0286d6 100644 --- a/tests/test_av_decoder.py +++ b/tests/test_av_decoder.py @@ -18,12 +18,6 @@ from megatron.energon.av import AVDecoder, AVMetadata, get_clips_uniform, get_single_frames_uniform -# Set multiprocessing start method to 'spawn' on macOS to avoid DataLoader cleanup issues -if sys.platform == "darwin": - import multiprocessing - - multiprocessing.set_start_method("spawn", force=True) - def load_video_to_tensor(video_path: str) -> torch.Tensor: """Load a video file into a tensor using PyAV directly. diff --git a/tests/test_dataset_det.py b/tests/test_dataset_det.py index 92919b0a..946a1882 100644 --- a/tests/test_dataset_det.py +++ b/tests/test_dataset_det.py @@ -640,6 +640,7 @@ def init_process(rank, world_size, shared_dict, fn, backend="gloo"): for p in processes: p.join() + assert p.exitcode == 0 # Phase 2 (restore state) processes = [] @@ -650,6 +651,7 @@ def init_process(rank, world_size, shared_dict, fn, backend="gloo"): for p in processes: p.join() + assert p.exitcode == 0 def test_restore_state_workers(self): worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2) diff --git a/tests/test_sync.py b/tests/test_sync.py new file mode 100644 index 00000000..68564d2f --- /dev/null +++ b/tests/test_sync.py @@ -0,0 +1,732 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause + +"""This module defines tests for the dataset.""" + +import gc +import logging +import sys +import tempfile +import unittest +import warnings +from datetime import timedelta +from multiprocessing import Manager, Process +from pathlib import Path + +import torch +import torch.distributed as dist +import webdataset as wds + +from megatron.energon import ( + DefaultTaskEncoder, + SkipSample, + TextSample, + WorkerConfig, + get_savable_loader, + get_train_dataset, + stateless, +) +from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME +from megatron.energon.sync_end import RedistributeLoader, StopFirstLoader + +# Speed up tests significantly by reducing the torch status check interval for broken worker shutdown +try: + torch.utils.data._utils.worker.MP_STATUS_CHECK_INTERVAL = 0.1 + torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL = 0.1 +except AttributeError: + pass + + +def _norng_state(state): + if isinstance(state, bytes): + if len(state) > 100: + return state[:5] + f"...".encode() + return state + elif isinstance(state, str): + if len(state) > 100: + return state[:5] + f"..." 
+ return state + elif isinstance(state, dict): + return {k: _norng_state(v) for k, v in state.items()} + elif isinstance(state, (list, tuple)): + if len(state) > 100: + state = state[:5] + return type(state)(_norng_state(v) for v in state) + else: + return state + + +class TestDataset(unittest.TestCase): + # Set up the test fixture + def setUp(self): + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + # Create a temporary directory + self.temp_dir = tempfile.TemporaryDirectory() + self.dataset_path = Path(self.temp_dir.name) + # self.dataset_path = Path("./test_dataset") + + self.dataset_path.mkdir(exist_ok=True, parents=True) + + # Create a small dummy captioning dataset + self.create_text_test_dataset(self.dataset_path) + + print(self.dataset_path) + + def tearDown(self): + # Remove all temporary files + gc.collect() + self.temp_dir.cleanup() + + @staticmethod + def create_text_test_dataset(path: Path): + """Creates a small dummy test dataset for testing purposes.""" + + # Create num_samples unique captions + (path / "parts").mkdir(exist_ok=True, parents=True) + + # Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=100) as shard_writer: + for idx in range(55): + # Write individual files to shards + shard_writer.write( + { + "__key__": f"{idx:06d}", + "txt": f"{idx}".encode(), + }, + ) + # Also create smaller shards, to verify distributions + if idx in (1, 3, 6, 10, 20, 30, 40, 50): + shard_writer.next_stream() + total_shards = shard_writer.shard + + from megatron.energon.flavors import BaseWebdatasetFactory + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: TextSample", + "field_map:", + " text: txt", + ] + ) + ) + + def test_distribute_stop_first(self): + world_size = 3 + + offsets = [9, 0, 4] + + all_items = [f"{i}" for i in range(55)] + rank_subsets = [ + set(all_items[:19]), + set(all_items[19 : 19 + 18]), + set(all_items[19 + 18 :]), + ] + + def phase1(rank: int, world_size: int, shared_dict: dict): + worker_config = WorkerConfig(rank=rank, world_size=world_size, num_workers=0) + + torch.manual_seed(42) + + epoch_reset = False + epoch_offset = 0 + + class LocalTaskEncoder( + DefaultTaskEncoder[TextSample, TextSample, TextSample, TextSample] + ): + @stateless + def encode_sample(self, sample: TextSample) -> TextSample: + nonlocal epoch_reset, epoch_offset + if epoch_reset: + epoch_offset = self.current_batch_index + epoch_reset = False + if self.current_batch_index >= 5 + offsets[rank] + epoch_offset: + print( + f"[r={rank}] Skip sample bi={self.current_batch_index} si={self.current_sample_index}\n", + end="", + ) + raise SkipSample() + print( + f"[r={rank}] Return sample bi={self.current_batch_index} si={self.current_sample_index}\n", + end="", + ) + return sample + + # First verify that the loader is working as expected + ref_loader = get_savable_loader( + get_train_dataset( + self.dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=2, + parallel_shard_iters=1, + repeat=False, + task_encoder=LocalTaskEncoder(), + ), + checkpoint_every_sec=0, + checkpoint_every_min_n_samples=1, + ) + + order_all = 
[data.text[0] for idx, data in zip(range(100), ref_loader)] + assert len(order_all) == 5 + offsets[rank], f"Rank {rank} has {len(order_all)} samples" + assert all(item in rank_subsets[rank] for item in order_all), ( + f"Rank {rank} has {order_all} samples" + ) + + epoch_reset = True + + order_all = [data.text[0] for idx, data in zip(range(100), ref_loader)] + assert len(order_all) == 5 + offsets[rank], f"Rank {rank} has {len(order_all)} samples" + assert all(item in rank_subsets[rank] for item in order_all), ( + f"Rank {rank} has {order_all} samples" + ) + + epoch_offset = 0 + + # To the actual test + loader = StopFirstLoader( + get_savable_loader( + get_train_dataset( + self.dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=2, + parallel_shard_iters=1, + repeat=False, + task_encoder=LocalTaskEncoder(), + ), + checkpoint_every_sec=0, + checkpoint_every_min_n_samples=1, + ) + ) + + state_0 = loader.save_state_global(global_dst_rank=0) + order_1 = [data.text[0] for idx, data in zip(range(4), loader)] + assert len(order_1) == 4, f"Rank {rank} has {len(order_1)} samples" + assert all(item in rank_subsets[rank] for item in order_1), ( + f"Rank {rank} has {order_1} samples" + ) + + print(f"Rank {rank} order_1: {order_1}\n", end="") + + # Exhaust the synchronized loader. It should only give a single sample. + state_1 = loader.save_state_global(global_dst_rank=0) + order_2 = [data.text[0] for idx, data in zip(range(10), loader)] + assert len(order_2) == 1, f"Rank {rank} has {len(order_2)} samples" + assert all(item in rank_subsets[rank] for item in order_2), ( + f"Rank {rank} has {order_2} samples" + ) + + print(f"Rank {rank} order_2: {order_2}\n", end="") + + # Restart iterating until exhausted. 
+ epoch_reset = True + + state_2 = loader.save_state_global(global_dst_rank=0) + order_3 = [data.text[0] for idx, data in zip(range(30), loader)] + assert len(order_3) == 5, f"Rank {rank} has {len(order_3)} samples" + + print(f"Rank {rank} order_3: {order_3}\n", end="") + + assert all(item in rank_subsets[rank] for item in order_3), ( + f"Rank {rank} has {order_3} samples" + ) + + shared_dict[(rank, "order_1")] = order_1 + shared_dict[(rank, "order_2")] = order_2 + shared_dict[(rank, "order_3")] = order_3 + + if rank == 0: + shared_dict["state_0"] = state_0 + shared_dict["state_1"] = state_1 + shared_dict["state_2"] = state_2 + + print(f"Rank {rank} finished phase 1\n", end="") + dist.barrier() + + def phase2(rank: int, world_size: int, shared_dict: dict): + torch.manual_seed(213) + + epoch_reset = False + epoch_offset = 0 + + class LocalTaskEncoder( + DefaultTaskEncoder[TextSample, TextSample, TextSample, TextSample] + ): + @stateless + def encode_sample(self, sample: TextSample) -> TextSample: + nonlocal epoch_reset, epoch_offset + if epoch_reset: + epoch_offset = self.current_batch_index + epoch_reset = False + if self.current_batch_index >= 5 + offsets[rank] + epoch_offset: + raise SkipSample() + return sample + + order_1 = shared_dict[(rank, "order_1")] + order_2 = shared_dict[(rank, "order_2")] + order_3 = shared_dict[(rank, "order_3")] + + if rank == 0: + state_0 = shared_dict["state_0"] + state_1 = shared_dict["state_1"] + state_2 = shared_dict["state_2"] + else: + state_0 = None + state_1 = None + state_2 = None + + worker_config = WorkerConfig(rank=rank, world_size=world_size, num_workers=0) + + loader = StopFirstLoader( + get_savable_loader( + get_train_dataset( + self.dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=2, + parallel_shard_iters=1, + repeat=False, + task_encoder=LocalTaskEncoder(), + ), + checkpoint_every_sec=0, + checkpoint_every_min_n_samples=1, + ) + ) + loader.restore_state_global(state_0, src_rank=0) + + order_1_r = [data.text[0] for idx, data in zip(range(4), loader)] + assert order_1_r == order_1 + + order_2_r = [data.text[0] for idx, data in zip(range(10), loader)] + assert order_2_r == order_2 + + loader = StopFirstLoader( + get_savable_loader( + get_train_dataset( + self.dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=2, + parallel_shard_iters=1, + repeat=False, + task_encoder=LocalTaskEncoder(), + ), + checkpoint_every_sec=0, + checkpoint_every_min_n_samples=1, + ) + ) + loader.restore_state_global(state_1, src_rank=0) + + order_2_r = [data.text[0] for idx, data in zip(range(10), loader)] + assert order_2_r == order_2 + + epoch_reset = True + + order_3_r = [data.text[0] for idx, data in zip(range(30), loader)] + assert order_3_r == order_3 + + loader = StopFirstLoader( + get_savable_loader( + get_train_dataset( + self.dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=2, + parallel_shard_iters=1, + repeat=False, + task_encoder=LocalTaskEncoder(), + ), + checkpoint_every_sec=0, + checkpoint_every_min_n_samples=1, + ) + ) + loader.restore_state_global(state_2, src_rank=0) + + epoch_reset = True + + order_3_r = [data.text[0] for idx, data in zip(range(30), loader)] + assert order_3_r == order_3 + + dist.barrier() + 
print(f"Rank {rank} finished phase 2\n", end="") + + def init_process(rank, world_size, shared_dict, fn, backend="gloo"): + """Initializes the distributed environment.""" + dist.init_process_group( + backend=backend, + init_method="tcp://127.0.0.1:12355", + world_size=world_size, + rank=rank, + timeout=timedelta(seconds=5), + ) + fn(rank, world_size, shared_dict) + dist.destroy_process_group() + + with Manager() as manager: + shared_dict = manager.dict() + + # Phase 1 (save state) + processes = [] + for rank in range(world_size): + p = Process(target=init_process, args=(rank, world_size, shared_dict, phase1)) + p.start() + processes.append(p) + + for p in processes: + p.join() + assert p.exitcode == 0 + + # Phase 2 (restore state) + processes = [] + for rank in range(world_size): + p = Process(target=init_process, args=(rank, world_size, shared_dict, phase2)) + p.start() + processes.append(p) + + for p in processes: + p.join() + assert p.exitcode == 0 + + def test_distribute_stop_redistribute(self): + world_size = 3 + + offsets = [9, 0, 4] + + all_items = [f"{i}" for i in range(55)] + rank_subsets = [ + set(all_items[:19]), + set(all_items[19 : 19 + 18]), + set(all_items[19 + 18 :]), + ] + + def phase1(rank: int, world_size: int, shared_dict: dict): + worker_config = WorkerConfig(rank=rank, world_size=world_size, num_workers=0) + + torch.manual_seed(42) + + epoch_reset = False + epoch_offset = 0 + + class LocalTaskEncoder( + DefaultTaskEncoder[TextSample, TextSample, TextSample, TextSample] + ): + @stateless + def encode_sample(self, sample: TextSample) -> TextSample: + nonlocal epoch_reset, epoch_offset + if epoch_reset: + epoch_offset = self.current_batch_index + epoch_reset = False + if self.current_batch_index >= 5 + offsets[rank] + epoch_offset: + # print(f"[r={rank}] Skip sample bi={self.current_batch_index} si={self.current_sample_index}\n", end="") + raise SkipSample() + # print(f"[r={rank}] Return sample bi={self.current_batch_index} si={self.current_sample_index}\n", end="") + return sample + + # This seed is used by the dataset to shuffle the data + + ref_loader = get_savable_loader( + get_train_dataset( + self.dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=2, + parallel_shard_iters=1, + repeat=False, + task_encoder=LocalTaskEncoder(), + ), + checkpoint_every_sec=0, + checkpoint_every_min_n_samples=1, + ) + + order_all = [data.text[0] for idx, data in zip(range(100), ref_loader)] + assert len(order_all) == 5 + offsets[rank], f"Rank {rank} has {len(order_all)} samples" + assert all(item in rank_subsets[rank] for item in order_all), ( + f"Rank {rank} has {order_all} samples" + ) + + epoch_reset = True + + order_all = [data.text[0] for idx, data in zip(range(100), ref_loader)] + assert len(order_all) == 5 + offsets[rank], f"Rank {rank} has {len(order_all)} samples" + assert all(item in rank_subsets[rank] for item in order_all), ( + f"Rank {rank} has {order_all} samples" + ) + + epoch_offset = 0 + + loader = RedistributeLoader( + get_savable_loader( + get_train_dataset( + self.dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=2, + parallel_shard_iters=1, + repeat=False, + task_encoder=LocalTaskEncoder(), + ), + checkpoint_every_sec=0, + checkpoint_every_min_n_samples=1, + ) + ) + + state_0 = loader.save_state_global(global_dst_rank=0) + order_1 = 
[data.text[0] for idx, data in zip(range(5), loader)] + assert len(order_1) == 5, f"Rank {rank} has {len(order_1)} samples" + assert all(item in rank_subsets[rank] for item in order_1), ( + f"Rank {rank} has {order_1} samples" + ) + + # print(f"Rank {rank}: order_1", order_1) + + state_1 = loader.save_state_global(global_dst_rank=0) + order_2 = [data.text[0] for idx, data in zip(range(10), loader)] + assert len(order_2) == 4, f"Rank {rank} has {len(order_2)} samples" + # States: + # [-1] remaining samples [9, 0, 4] + # [0] remaining samples [7, 0, 3] + # [1] remaining samples [6, 0, 1] + # [2] remaining samples [4, 0, 0] + # [3] remaining samples [1, 0, 0] + # End, not enough samples left + if rank == 0: + assert all(item in rank_subsets[0] for item in order_2), ( + f"Rank {rank} has {order_2} samples" + ) + elif rank == 1: + assert order_2[0] in rank_subsets[0] + assert order_2[1] in rank_subsets[2] + assert order_2[2] in rank_subsets[0] + assert order_2[3] in rank_subsets[0] + elif rank == 2: + assert order_2[0] in rank_subsets[2] + assert order_2[1] in rank_subsets[2] + assert order_2[2] in rank_subsets[2] + assert order_2[3] in rank_subsets[0] + + epoch_reset = True + + state_2 = loader.save_state_global(global_dst_rank=0) + order_3 = [data.text[0] for idx, data in zip(range(30), loader)] + assert len(order_3) == 9, f"Rank {rank} has {len(order_3)} samples" + + assert all(item in rank_subsets[rank] for item in order_3[:5]), ( + f"Rank {rank} has {order_3} samples" + ) + if rank == 0: + assert all(item in rank_subsets[0] for item in order_3[5:]), ( + f"Rank {rank} has {order_3} samples" + ) + elif rank == 1: + assert order_3[5] in rank_subsets[0] + assert order_3[6] in rank_subsets[2] + assert order_3[7] in rank_subsets[0] + assert order_3[8] in rank_subsets[0] + elif rank == 2: + assert order_3[5] in rank_subsets[2] + assert order_3[6] in rank_subsets[2] + assert order_3[7] in rank_subsets[2] + assert order_3[8] in rank_subsets[0] + + shared_dict[(rank, "order_1")] = order_1 + shared_dict[(rank, "order_2")] = order_2 + shared_dict[(rank, "order_3")] = order_3 + + if rank == 0: + shared_dict["state_0"] = state_0 + shared_dict["state_1"] = state_1 + shared_dict["state_2"] = state_2 + + print(f"Rank {rank} finished phase 1\n", end="") + dist.barrier() + + def phase2(rank: int, world_size: int, shared_dict: dict): + torch.manual_seed(213) + + epoch_reset = False + epoch_offset = 0 + + class LocalTaskEncoder( + DefaultTaskEncoder[TextSample, TextSample, TextSample, TextSample] + ): + @stateless + def encode_sample(self, sample: TextSample) -> TextSample: + nonlocal epoch_reset, epoch_offset + if epoch_reset: + epoch_offset = self.current_batch_index + epoch_reset = False + if self.current_batch_index >= 5 + offsets[rank] + epoch_offset: + raise SkipSample() + return sample + + order_1 = shared_dict[(rank, "order_1")] + order_2 = shared_dict[(rank, "order_2")] + order_3 = shared_dict[(rank, "order_3")] + + if rank == 0: + state_0 = shared_dict["state_0"] + state_1 = shared_dict["state_1"] + state_2 = shared_dict["state_2"] + else: + state_0 = None + state_1 = None + state_2 = None + + worker_config = WorkerConfig(rank=rank, world_size=world_size, num_workers=0) + + loader = RedistributeLoader( + get_savable_loader( + get_train_dataset( + self.dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=2, + parallel_shard_iters=1, + repeat=False, + task_encoder=LocalTaskEncoder(), + ), + 
checkpoint_every_sec=0, + checkpoint_every_min_n_samples=1, + ) + ) + loader.restore_state_global(state_0, src_rank=0) + + order_1_r = [data.text[0] for idx, data in zip(range(5), loader)] + assert order_1_r == order_1 + + order_2_r = [data.text[0] for idx, data in zip(range(10), loader)] + assert order_2_r == order_2 + + loader = RedistributeLoader( + get_savable_loader( + get_train_dataset( + self.dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=2, + parallel_shard_iters=1, + repeat=False, + task_encoder=LocalTaskEncoder(), + ), + checkpoint_every_sec=0, + checkpoint_every_min_n_samples=1, + ) + ) + loader.restore_state_global(state_1, src_rank=0) + + order_2_r = [data.text[0] for idx, data in zip(range(10), loader)] + assert order_2_r == order_2 + + epoch_reset = True + + order_3_r = [data.text[0] for idx, data in zip(range(30), loader)] + assert order_3_r == order_3 + + loader = RedistributeLoader( + get_savable_loader( + get_train_dataset( + self.dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=2, + parallel_shard_iters=1, + repeat=False, + task_encoder=LocalTaskEncoder(), + ), + checkpoint_every_sec=0, + checkpoint_every_min_n_samples=1, + ) + ) + loader.restore_state_global(state_2, src_rank=0) + + epoch_reset = True + + order_3_r = [data.text[0] for idx, data in zip(range(30), loader)] + assert order_3_r == order_3 + + dist.barrier() + print(f"Rank {rank} finished phase 2\n", end="") + + def init_process(rank, world_size, shared_dict, fn, backend="gloo"): + """Initializes the distributed environment.""" + dist.init_process_group( + backend=backend, + init_method="tcp://127.0.0.1:12355", + world_size=world_size, + rank=rank, + timeout=timedelta(seconds=5), + ) + fn(rank, world_size, shared_dict) + dist.destroy_process_group() + + with Manager() as manager: + shared_dict = manager.dict() + + # Phase 1 (save state) + processes = [] + for rank in range(world_size): + p = Process(target=init_process, args=(rank, world_size, shared_dict, phase1)) + p.start() + processes.append(p) + + for p in processes: + p.join() + assert p.exitcode == 0 + + # Phase 2 (restore state) + processes = [] + for rank in range(world_size): + p = Process(target=init_process, args=(rank, world_size, shared_dict, phase2)) + p.start() + processes.append(p) + + for p in processes: + p.join() + assert p.exitcode == 0 + + +if __name__ == "__main__": + unittest.main()
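For reference, a minimal usage sketch of the new wrappers combined with the savable-loader checkpointing API, assembled from the documentation and tests above. The dataset path, batch size, and the choice of `StopFirstLoader` are placeholder assumptions; torch distributed must already be initialized, as noted in the docs.

```py
# Minimal usage sketch, assuming a prepared dataset/metadataset at "metadataset.yaml" (placeholder).
import torch.distributed as dist

from megatron.energon import WorkerConfig, get_savable_loader, get_train_dataset
from megatron.energon.sync_end import StopFirstLoader  # or RedistributeLoader

# The wrappers require torch distributed to be initialized (normally done by the training framework).
# Using gloo as in the tests; RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT must be set in the environment.
dist.init_process_group(backend="gloo")

worker_config = WorkerConfig.default_worker_config()

# Wrap the savable loader so that all ranks stop in sync once the first rank runs out of data.
loader = StopFirstLoader(
    get_savable_loader(
        get_train_dataset(
            "metadataset.yaml",  # placeholder path
            batch_size=2,
            shuffle_buffer_size=100,
            max_samples_per_sequence=100,
            worker_config=worker_config,
            repeat=False,
        )
    )
)

# One pass over the data; iteration ends for all ranks together.
for batch in loader:
    ...  # training step

# Checkpoint the loader state between iterations: rank 0 receives a list of per-rank states,
# all other ranks receive None. Persist it together with the model checkpoint.
state = loader.save_state_global(global_dst_rank=0)

# To resume later, construct a fresh loader in the same way and scatter the state from rank 0:
# fresh_loader.restore_state_global(state, src_rank=0)
```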