diff --git a/src/megatron/energon/cli/main.py b/src/megatron/energon/cli/main.py index c9f72a59..70cfd8d4 100644 --- a/src/megatron/energon/cli/main.py +++ b/src/megatron/energon/cli/main.py @@ -6,6 +6,7 @@ import click from megatron.energon.tools.analyze_debug import command as analyze_debug_command +from megatron.energon.tools.analyze_debug_merge import command as analyze_debug_merge_command from megatron.energon.tools.checkpoint import command as checkpoint_command from megatron.energon.tools.info import command as info_command from megatron.energon.tools.lint import command as lint_command @@ -51,6 +52,7 @@ def main(ctx): main.add_command(analyze_debug_command) +main.add_command(analyze_debug_merge_command) main.add_command(checkpoint_command) main.add_command(lint_command) main.add_command(info_command) diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index 5c309a08..370e3a85 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses +import functools import inspect import typing from abc import ABC, abstractmethod @@ -345,6 +346,10 @@ def _function_config(fn: Callable) -> str: mod_name = getattr(fn, "__module__", "") return f"{mod_name}.{getattr(fn, '__qualname__', getattr(fn, '__name__', ''))}" + @staticmethod + def _function_config_short(fn: Callable) -> str: + return getattr(fn, "__qualname__", getattr(fn, "__name__", "")) + @abstractmethod def config(self) -> Dict[str, Any]: """Return a config dict that can be used to check if datasets have the same settings. @@ -447,9 +452,6 @@ def legacy_handler( ], ) -> Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None]: """Safely returns the new style three argument handler. 
If the handler takes 2 arguments, it wraps it.""" - import functools - import inspect - handler_sig = inspect.signature(handler) if len(handler_sig.parameters) != 3: original_handler = handler diff --git a/src/megatron/energon/flavors/trace.py b/src/megatron/energon/flavors/trace.py new file mode 100644 index 00000000..44a671db --- /dev/null +++ b/src/megatron/energon/flavors/trace.py @@ -0,0 +1,188 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause + +import functools +from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, TypeVar, Union + +from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.tracing import AsyncContext + +T_dataset = TypeVar("T_dataset", bound=SavableDataset) +T_call = TypeVar("T_call", bound=Callable) +T_sample = TypeVar("T_sample") + + +def _flatten_str_list(keys: Any) -> Iterator[Optional[str]]: + """Flatten a list of keys into a list of strings.""" + if isinstance(keys, str): + yield keys + elif isinstance(keys, (list, tuple)): + for key in keys: + yield from _flatten_str_list(key) + else: + yield None + + +def _flatten_str_list_or_none(keys: Any) -> Optional[List[str]]: + """Flatten a list of keys into a list of strings. 
If this cannot be fetched, return None.""" + keys = list(_flatten_str_list(keys)) + if any(k is None for k in keys): + return None + return keys + + +def default_get_keys(batch: Any) -> Optional[List[str]]: + """Default get_keys, which has some heuristics to find the sample keys.""" + if isinstance(batch, list): + all_keys = [] + for b in batch: + k = default_get_keys(b) + if k is None: + return None + all_keys.extend(k) + return all_keys + if hasattr(batch, "__key__"): + return _flatten_str_list_or_none(batch.__key__) + elif hasattr(batch, "__keys__"): + return _flatten_str_list_or_none(batch.__keys__) + elif isinstance(batch, dict): + if "__key__" in batch: + return _flatten_str_list_or_none(batch["__key__"]) + elif "__keys__" in batch: + return _flatten_str_list_or_none(batch["__keys__"]) + elif "keys" in batch: + return _flatten_str_list_or_none(batch["keys"]) + return None + + +class TraceIter: + last_args: Dict[str, Any] = {} + + def __init__( + self, + outer_self: T_dataset, + name: str, + trace_span: AsyncContext, + call_args: Dict[str, Union[str, Callable[[T_dataset], Any]]], + ): + self.outer_self = outer_self + self.name = name + self.trace_span = trace_span + self.call_args = call_args + + def sample_exception( + self, exception: Exception, samples: Union[T_sample, Sequence[T_sample]] + ) -> None: + self.trace_span.instant( + f"{self.name}.error/skip", + args={ + "exception": f"{type(exception).__name__}: {str(exception)}", + "sample_keys": default_get_keys(samples), + **{ + arg_name: arg_value(self.outer_self) if callable(arg_value) else arg_value + for arg_name, arg_value in self.call_args.items() + }, + }, + level=1, + ) + + def skip_sample(self, samples: Sequence[T_sample]) -> None: + self.trace_span.instant( + f"{self.name}.skip", + args={ + "sample_keys": default_get_keys(samples), + }, + level=1, + ) + + def sample( + self, sample: Union[T_sample, Sequence[T_sample]], args: Dict[str, Any] = {} + ) -> None: + self.last_args["sample_keys"] = 
default_get_keys(sample) + self.last_args.update(args) + + def wrap_fn(self, fn: T_call) -> T_call: + fn_name = getattr(fn, "__qualname__", getattr(fn, "__name__", "")) + + @functools.wraps(fn) + def wrapped_fn(*args, **kwargs): + with self.trace_span.span( + f"{self.name}.{fn_name}.call", + args={ + arg_name: arg_value(self.outer_self) if callable(arg_value) else arg_value + for arg_name, arg_value in self.call_args.items() + }, + level=2, + ): + return fn(*args, **kwargs) + + return wrapped_fn + + def wrap_inner(self, call_args: Callable[..., Dict[str, Any]] = lambda *args, **kwargs: {}): + def decorator(fn): + fn_name = getattr(fn, "__qualname__", getattr(fn, "__name__", "")) + + @functools.wraps(fn) + def wrapped_inner_gen(*args, **kwargs): + with self.trace_span.span( + f"{self.name}.{fn_name}.__iter__", args=call_args(*args, **kwargs), level=2 + ): + return fn(*args, **kwargs) + + return wrapped_inner_gen + + return decorator + + +def trace_iter( + name: Callable[[T_dataset], str] = lambda ds: type(ds).__name__, + call_args: Dict[str, Union[str, Callable[[T_dataset], Any]]] = {}, + next_args: Dict[str, Union[str, Callable[[T_dataset], Any]]] = {}, +) -> Callable[ + [Callable[[T_dataset, TraceIter], Iterator[T_sample]]], + Callable[[T_dataset], Iterator[T_sample]], +]: + """Decorator for SavableDataset.__iter__ to trace the iteration using the worker config.""" + + def decorator( + iter_fn: Callable[[T_dataset, TraceIter], Iterator[T_sample]], + ) -> Callable[[T_dataset], Iterator[T_sample]]: + @functools.wraps(iter_fn) + def wrapper(self: T_dataset) -> Iterator[T_sample]: + trace_span = self.worker_config.worker_trace_span() + span_name = name(self) + trace_iter = TraceIter(self, span_name, trace_span, call_args) + with ( + trace_span.span( + f"{span_name}.__iter__", + args={ + arg_name: arg_value(self) if callable(arg_value) else arg_value + for arg_name, arg_value in call_args.items() + }, + level=1, + ), + 
self.worker_config.worker_trace_writer().generator( + f"{span_name}.__iter__.next", + level=2, + ) as trace_gen, + ): + for sample in trace_span.iterable( + iter_fn(self, trace_iter), + name=f"{span_name}.__iter__.loop", + level=2, + ): + with trace_gen.yield_( + last_args={ + **{ + arg_name: arg_value(self) if callable(arg_value) else arg_value + for arg_name, arg_value in next_args.items() + }, + **trace_iter.last_args, + }, + ): + trace_iter.last_args.clear() + yield sample + + return wrapper + + return decorator diff --git a/src/megatron/energon/flavors/webdataset/sample_loader.py b/src/megatron/energon/flavors/webdataset/sample_loader.py index 72b08c0a..3a73fec7 100644 --- a/src/megatron/energon/flavors/webdataset/sample_loader.py +++ b/src/megatron/energon/flavors/webdataset/sample_loader.py @@ -197,42 +197,53 @@ def _slices_iter(self) -> Generator[RawSampleData, None, None]: """Iterates the samples in a list of slices, possibly using multiple parallel iterators over the slices.""" - assert self.slice_offsets is not None - - active_slice_probs = torch.zeros(self.parallel_slice_iters, dtype=torch.float32) - active_slices = self._active_slice_state - pending_slice_indexes = self._pending_slice_indexes - - def slice_at(idx: int) -> SliceState: + trace = self.worker_config.worker_trace_span() + + with trace.span( + "WebdatasetSampleLoaderDataset._slices_iter", + args={ + "base_paths": [str(reader.base_path) for reader in self.join_readers], + "shuffle_over_epochs": self.shuffle_over_epochs, + "parallel_slice_iters": self.parallel_slice_iters, + "sample_count": self._sample_count, + "epoch_idx": self._epoch_count, + "epoch_sample_count": self._epoch_sample_count, + }, + level=1, + ) as fn_span: assert self.slice_offsets is not None - return SliceState( - index=idx, - current=self.slice_offsets[idx], - ) - - # Weight the slices by their size to get a more even distribution of samples - if any(s is not None for s in active_slices) or self._pending_slices_offset is 
not None: - # Having an active state, or pending slices. This means we are resuming an epoch. - if pending_slice_indexes is None: - # Need to restore the pending slices - pending_slice_indexes = self._slices_once() - assert pending_slice_indexes is not None - # Restore the state - assert len(active_slices) == self.parallel_slice_iters - for idx, slice_state in enumerate(active_slices): - if slice_state is not None: - active_slice_probs[idx] = ( - self.slice_offsets[slice_state.index + 1] - - self.slice_offsets[slice_state.index] - ) + active_slice_probs = torch.zeros(self.parallel_slice_iters, dtype=torch.float32) + active_slices = self._active_slice_state + pending_slice_indexes = self._pending_slice_indexes + + def slice_at(idx: int) -> SliceState: + assert self.slice_offsets is not None + return SliceState( + index=idx, + current=self.slice_offsets[idx], + ) - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( + # Weight the slices by their size to get a more even distribution of samples + if any(s is not None for s in active_slices) or self._pending_slices_offset is not None: + # Having an active state, or pending slices. This means we are resuming an epoch. 
+ if pending_slice_indexes is None: + # Need to restore the pending slices + pending_slice_indexes = self._slices_once() + assert pending_slice_indexes is not None + + # Restore the state + assert len(active_slices) == self.parallel_slice_iters + for idx, slice_state in enumerate(active_slices): + if slice_state is not None: + active_slice_probs[idx] = ( + self.slice_offsets[slice_state.index + 1] + - self.slice_offsets[slice_state.index] + ) + + fn_span.update_args( { - "t": "WebdatasetSampleLoaderDataset._slices_iter.resume_epoch", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), + "mode": "resume_epoch", "pending_slice_indexes": pending_slice_indexes, "active_slices": [ ( @@ -245,162 +256,188 @@ def slice_at(idx: int) -> SliceState: ) for state in active_slices ], - "count": self._sample_count, - "epoch": self._epoch_count, - "epoch_count": self._epoch_sample_count, "probs": active_slice_probs.tolist(), } ) + trace.instant( + "WebdatasetSampleLoaderDataset._slices_iter.resume_epoch", + args={ + "pending_slice_indexes": pending_slice_indexes, + "active_slices": [ + ( + None + if state is None + else { + "index": state.index, + "current": state.current, + } + ) + for state in active_slices + ], + "sample_count": self._sample_count, + "epoch_idx": self._epoch_count, + "epoch_sample_count": self._epoch_sample_count, + "probs": active_slice_probs.tolist(), + "shuffle_over_epochs": self.shuffle_over_epochs, + "parallel_slice_iters": self.parallel_slice_iters, + }, + level=1, + ) - else: - # Start a new epoch - assert pending_slice_indexes is None - pending_slice_indexes = self._slices_once() + else: + # Start a new epoch + assert pending_slice_indexes is None + pending_slice_indexes = self._slices_once() - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( + fn_span.update_args( { - "t": "WebdatasetSampleLoaderDataset._slices_iter.next_epoch", - "r": self.worker_config.rank, - "w": 
self.worker_config.rank_worker_id(), + "mode": "next_epoch", "pending_slice_indexes": pending_slice_indexes, - "count": self._sample_count, - "epoch": self._epoch_count, - "epoch_count": self._epoch_sample_count, "probs": active_slice_probs.tolist(), - "shuffle_over_epochs": self.shuffle_over_epochs, } ) - - assert self._pending_slices_offset is not None - - # List of slice iterators, always of length `parallel_slice_iters`. May contain `None`. - active_slices.clear() - # Fill up the slice iterators - while len(pending_slice_indexes) > 0 and len(active_slices) < self.parallel_slice_iters: - slice_index = pending_slice_indexes.pop() - self._pending_slices_offset += 1 - slice_state = slice_at(slice_index) - active_slice_probs[len(active_slices)] = ( - self.slice_offsets[slice_state.index + 1] - - self.slice_offsets[slice_state.index] + trace.instant( + "WebdatasetSampleLoaderDataset._slices_iter.next_epoch", + args={ + "pending_slice_indexes": pending_slice_indexes, + "sample_count": self._sample_count, + "epoch_idx": self._epoch_count, + "epoch_sample_count": self._epoch_sample_count, + "probs": active_slice_probs.tolist(), + "shuffle_over_epochs": self.shuffle_over_epochs, + "parallel_slice_iters": self.parallel_slice_iters, + }, + level=1, ) - active_slices.append(slice_state) - # Fill up the slice iterators with None - for _ in range(len(active_slices), self.parallel_slice_iters): - active_slices.append(None) - - # print( - # f"Next slice iters generated for {self.worker_config.rank}:{self.worker_config.rank_worker_id()}: probs={active_slice_probs}" - # ) - # for slice_state in active_slices: - # if slice_state is None: - # print(" - None") - # else: - # print( - # f" - [{slice_offsets[slice_state.index]}, {slice_offsets[slice_state.index + 1]}] at {slice_state.current}" - # ) - - # Iterate over the slice iterators while there is an iterator left - while torch.count_nonzero(active_slice_probs).item() > 0: - if self.shuffle_over_epochs is None: - # No shuffling, 
deterministic order, always the same - assert self.parallel_slice_iters == 1 - slice_idx = 0 - else: - # Take a random slice iterator - slice_idx = self._worker_rng.choice_idx(active_slice_probs) - slice_state = active_slices[slice_idx] - assert slice_state is not None - sample = self._get_sample(slice_state.current) - # print(f"Read sample at {slice_state.current} -> {'None' if sample is None or sample.data[0] is None else sample.data[0]['__key__']}") - slice_state.current += 1 - self._sample_count += 1 - self._epoch_sample_count += 1 - if slice_state.current >= self.slice_offsets[slice_state.index + 1]: - # Iterator exhausted -> take next / remove from list - if len(pending_slice_indexes) > 0 or self.shuffle_over_epochs == -1: - if len(pending_slice_indexes) > 0: - # Take the next slice (without replacement) - next_idx = pending_slice_indexes.pop() - assert self._pending_slices_offset is not None - self._pending_slices_offset += 1 - else: - # Randomly select a new slice directly (with replacement) - num_slices = len(self.slice_offsets) - 1 - next_idx = self._worker_rng.randbelow(num_slices) - next_slice_state = slice_at(next_idx) - active_slice_probs[slice_idx] = ( - self.slice_offsets[next_slice_state.index + 1] - - self.slice_offsets[next_slice_state.index] - ) - active_slices[slice_idx] = next_slice_state - # print( - # f"Slice iter for {self.worker_config.rank}:{self.worker_config.rank_worker_id()} " - # f"[{slice_offsets[slice_state.index]}, {slice_offsets[slice_state.index + 1]}] exhausted at {slice_state.current}, " - # f"taking next slice {next_slice_state} [{slice_offsets[next_slice_state.index]}, {slice_offsets[next_slice_state.index + 1]}], " - # f"{len(pending_slice_indexes)} slices left, probs={active_slice_probs.tolist()}" - # ) - else: - active_slice_probs[slice_idx] = 0 - active_slices[slice_idx] = None - # print( - # f"Slice iter for {self.worker_config.rank}:{self.worker_config.rank_worker_id()} " - # f"[{slice_offsets[slice_state.index]}, 
{slice_offsets[slice_state.index + 1]}] exhausted at {slice_state.current}, " - # f"no next slice, probs={active_slice_probs.tolist()}" - # ) - if self.worker_config.should_log(level=2): - self.worker_config.worker_log( - { - "t": "WebdatasetSampleLoaderDataset._slices_iter.exhausted", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), - "remaining": len(pending_slice_indexes), - "count": self._sample_count, - "epoch": self._epoch_count, - "epoch_count": self._epoch_sample_count, - "probs": active_slice_probs.tolist(), - } - ) - if sample.data[0] is not None: - # Otherwise the sample was skipped. - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "WebdatasetSampleLoaderDataset._slices_iter.yield", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), - "index": sample.__restore_key__[1], - "key": sample.data[0]["__key__"], - "shard": sample.data[0]["__shard__"], - "count": self._sample_count, - "epoch": self._epoch_count, - "epoch_count": self._epoch_sample_count, - } + assert self._pending_slices_offset is not None + + # List of slice iterators, always of length `parallel_slice_iters`. May contain `None`. 
+ active_slices.clear() + # Fill up the slice iterators + while ( + len(pending_slice_indexes) > 0 + and len(active_slices) < self.parallel_slice_iters + ): + slice_index = pending_slice_indexes.pop() + self._pending_slices_offset += 1 + slice_state = slice_at(slice_index) + active_slice_probs[len(active_slices)] = ( + self.slice_offsets[slice_state.index + 1] + - self.slice_offsets[slice_state.index] ) - # Now, yield the sample - yield sample - del sample - if self.worker_config.should_log(level=2): - self.worker_config.worker_log( - { - "t": "WebdatasetSampleLoaderDataset._slices_iter.all_exhausted", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), - "count": self._sample_count, - "epoch": self._epoch_count, - "epoch_count": self._epoch_sample_count, - } - ) - - # Epoch has finished, reset states. - self._epoch_count += 1 - self._epoch_sample_count = 0 - self._pending_slice_indexes = None - self._pending_slices_offset = None - # print( - # f"slice iters exhausted for {self.worker_config.rank}:{self.worker_config.rank_worker_id()} after {cnt} samples" - # ) + active_slices.append(slice_state) + # Fill up the slice iterators with None + for _ in range(len(active_slices), self.parallel_slice_iters): + active_slices.append(None) + + # print( + # f"Next slice iters generated for {self.worker_config.rank}:{self.worker_config.rank_worker_id()}: probs={active_slice_probs}" + # ) + # for slice_state in active_slices: + # if slice_state is None: + # print(" - None") + # else: + # print( + # f" - [{slice_offsets[slice_state.index]}, {slice_offsets[slice_state.index + 1]}] at {slice_state.current}" + # ) + + # Iterate over the slice iterators while there is an iterator left + while torch.count_nonzero(active_slice_probs).item() > 0: + with trace.span("WebdatasetSampleLoaderDataset._slices_iter.iter", level=1): + if self.shuffle_over_epochs is None: + # No shuffling, deterministic order, always the same + assert self.parallel_slice_iters == 1 + 
slice_idx = 0 + else: + # Take a random slice iterator + slice_idx = self._worker_rng.choice_idx(active_slice_probs) + slice_state = active_slices[slice_idx] + assert slice_state is not None + sample = self._get_sample(slice_state.current) + # print(f"Read sample at {slice_state.current} -> {'None' if sample is None or sample.data[0] is None else sample.data[0]['__key__']}") + slice_state.current += 1 + self._sample_count += 1 + self._epoch_sample_count += 1 + if slice_state.current >= self.slice_offsets[slice_state.index + 1]: + # Iterator exhausted -> take next / remove from list + if len(pending_slice_indexes) > 0 or self.shuffle_over_epochs == -1: + if len(pending_slice_indexes) > 0: + # Take the next slice (without replacement) + next_idx = pending_slice_indexes.pop() + assert self._pending_slices_offset is not None + self._pending_slices_offset += 1 + else: + # Randomly select a new slice directly (with replacement) + num_slices = len(self.slice_offsets) - 1 + next_idx = self._worker_rng.randbelow(num_slices) + next_slice_state = slice_at(next_idx) + active_slice_probs[slice_idx] = ( + self.slice_offsets[next_slice_state.index + 1] + - self.slice_offsets[next_slice_state.index] + ) + active_slices[slice_idx] = next_slice_state + # print( + # f"Slice iter for {self.worker_config.rank}:{self.worker_config.rank_worker_id()} " + # f"[{slice_offsets[slice_state.index]}, {slice_offsets[slice_state.index + 1]}] exhausted at {slice_state.current}, " + # f"taking next slice {next_slice_state} [{slice_offsets[next_slice_state.index]}, {slice_offsets[next_slice_state.index + 1]}], " + # f"{len(pending_slice_indexes)} slices left, probs={active_slice_probs.tolist()}" + # ) + else: + active_slice_probs[slice_idx] = 0 + active_slices[slice_idx] = None + # print( + # f"Slice iter for {self.worker_config.rank}:{self.worker_config.rank_worker_id()} " + # f"[{slice_offsets[slice_state.index]}, {slice_offsets[slice_state.index + 1]}] exhausted at {slice_state.current}, " + # 
f"no next slice, probs={active_slice_probs.tolist()}" + # ) + trace.instant( + "WebdatasetSampleLoaderDataset._slices_iter.exhausted", + args={ + "remaining": len(pending_slice_indexes), + "sample_count": self._sample_count, + "epoch_idx": self._epoch_count, + "epoch_sample_count": self._epoch_sample_count, + "probs": active_slice_probs.tolist(), + }, + level=2, + ) + if sample.data[0] is not None: + # Otherwise the sample was skipped. + trace.instant( + "WebdatasetSampleLoaderDataset._slices_iter.yield", + args={ + "base_path": str(self.join_readers[0].base_path), + "global_sample_index": sample.__restore_key__[1], + "key": sample.data[0]["__key__"], + "shard": sample.data[0]["__shard__"], + "sample_count": self._sample_count, + "epoch_idx": self._epoch_count, + "epoch_sample_count": self._epoch_sample_count, + }, + level=2, + ) + # Now, yield the sample + yield sample + del sample + if self.worker_config.should_log(level=2): + trace.instant( + "WebdatasetSampleLoaderDataset._slices_iter.all_exhausted", + args={ + "sample_count": self._sample_count, + "epoch_idx": self._epoch_count, + "epoch_sample_count": self._epoch_sample_count, + }, + level=2, + ) + + # Epoch has finished, reset states. 
+ self._epoch_count += 1 + self._epoch_sample_count = 0 + self._pending_slice_indexes = None + self._pending_slices_offset = None + # print( + # f"slice iters exhausted for {self.worker_config.rank}:{self.worker_config.rank_worker_id()} after {cnt} samples" + # ) def __len__(self) -> int: return sum( @@ -419,18 +456,6 @@ def __iter__(self) -> Iterator[RawSampleData]: self.ensure_slice_offsets() assert self.slice_offsets is not None - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "WebdatasetSampleLoaderDataset.__iter__", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), - "slice_offsets": self.slice_offsets, - "parallel_slice_iters": self.parallel_slice_iters, - "shuffle_over_epochs": self.shuffle_over_epochs, - } - ) - if len(self.slice_offsets) <= 1: return diff --git a/src/megatron/energon/fork_hook.py b/src/megatron/energon/fork_hook.py index b5184207..72a3d4aa 100644 --- a/src/megatron/energon/fork_hook.py +++ b/src/megatron/energon/fork_hook.py @@ -1,15 +1,14 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -import functools import os import weakref -from typing import Callable, Protocol, Type, TypeVar - -_after_in_child_fork_hooks = weakref.WeakKeyDictionary() -_after_in_parent_fork_hooks = weakref.WeakKeyDictionary() -_before_fork_hooks = weakref.WeakKeyDictionary() +from dataclasses import dataclass +from typing import Callable, TypeVar +_after_in_child_fork_hooks = dict() +_after_in_parent_fork_hooks = dict() +_before_fork_hooks = dict() T = TypeVar("T", bound=Callable[[], None]) @@ -18,34 +17,41 @@ def before_fork_hook(callable: Callable[[], None]): """ Run function before the fork of a worker process. The function must be persistent. 
""" - # Make sure, that callable is a method of object - assert getattr(callable, "__self__", None) is None, ( - f"Callable must not be a method: {callable.__name__}" - ) - # print(f"Adding before_fork_hook for {callable.__name__}\n", end="") - _before_fork_hooks[callable] = callable + if getattr(callable, "__self__", None): + self = callable.__self__ + _before_fork_hooks[id(self)] = callable + weakref.finalize(self, lambda: _before_fork_hooks.pop(id(self))) + else: + _before_fork_hooks[id(callable)] = callable + weakref.finalize(callable, lambda: _before_fork_hooks.pop(id(callable))) -def after_in_parent_fork_hook(callable: T): +def after_in_parent_fork_hook(callable: Callable[[], None]): """ Run function after the fork of a worker process. The function must be persistent. """ # print(f"Adding after_in_child_fork_hook for {callable.__name__}\n", end="") - assert getattr(callable, "__self__", None) is None, ( - f"Callable must not be a method: {callable.__name__}" - ) - _after_in_parent_fork_hooks[callable] = callable + if getattr(callable, "__self__", None): + self = callable.__self__ + _after_in_parent_fork_hooks[id(self)] = callable + weakref.finalize(self, lambda: _after_in_parent_fork_hooks.pop(id(self))) + else: + _after_in_parent_fork_hooks[id(callable)] = callable + weakref.finalize(callable, lambda: _after_in_parent_fork_hooks.pop(id(callable))) -def after_in_child_fork_hook(callable: T): +def after_in_child_fork_hook(callable: Callable[[], None]): """ Run function after the fork of a worker process. The function must be persistent. 
""" # print(f"Adding after_in_child_fork_hook for {callable.__name__}\n", end="") - assert getattr(callable, "__self__", None) is None, ( - f"Callable must not be a method: {callable.__name__}" - ) - _after_in_child_fork_hooks[callable] = callable + if getattr(callable, "__self__", None): + self = callable.__self__ + _after_in_child_fork_hooks[id(self)] = callable + weakref.finalize(self, lambda: _after_in_child_fork_hooks.pop(id(self))) + else: + _after_in_child_fork_hooks[id(callable)] = callable + weakref.finalize(callable, lambda: _after_in_child_fork_hooks.pop(id(callable))) class ForkMixin: @@ -55,18 +61,21 @@ class ForkMixin: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.__post_init__() + + def __post_init__(self): if getattr(self.__before_fork__, "__func__", None) is not ForkMixin.__before_fork__: - _before_fork_hooks[self] = "__before_fork__" + before_fork_hook(self.__before_fork__) if ( getattr(self.__after_in_child_fork__, "__func__", None) is not ForkMixin.__after_in_child_fork__ ): - _after_in_child_fork_hooks[self] = "__after_in_child_fork__" + after_in_child_fork_hook(self.__after_in_child_fork__) if ( getattr(self.__after_in_parent_fork__, "__func__", None) is not ForkMixin.__after_in_parent_fork__ ): - _after_in_parent_fork_hooks[self] = "__after_in_parent_fork__" + after_in_parent_fork_hook(self.__after_in_parent_fork__) def __after_in_child_fork__(self): """ @@ -87,58 +96,43 @@ def __before_fork__(self): pass -class ForkHookProtocol(Protocol): +@dataclass +class DataclassForkMixin: """ - A protocol that defines a method that runs before and after the fork of a worker process. + A mixin that runs a method after the fork of a worker process. 
""" + def __post_init__(self): + if getattr(self.__before_fork__, "__func__", None) is not ForkMixin.__before_fork__: + before_fork_hook(self.__before_fork__) + if ( + getattr(self.__after_in_child_fork__, "__func__", None) + is not ForkMixin.__after_in_child_fork__ + ): + after_in_child_fork_hook(self.__after_in_child_fork__) + if ( + getattr(self.__after_in_parent_fork__, "__func__", None) + is not ForkMixin.__after_in_parent_fork__ + ): + after_in_parent_fork_hook(self.__after_in_parent_fork__) + def __after_in_child_fork__(self): """ A method that runs after the fork in the child process. """ - ... + pass def __after_in_parent_fork__(self): """ A method that runs after the fork in the parent process. """ - ... + pass def __before_fork__(self): """ A method that runs before the fork of a worker process. """ - ... - - -T_CLS = TypeVar("T_CLS", bound=Type[ForkHookProtocol]) - - -def fork_hook_class(cls: T_CLS) -> T_CLS: - """ - A decorator that runs a function after the fork of a worker process. - """ - if hasattr(cls, "__init__"): - orig_init = cls.__init__ - - @functools.wraps(orig_init) - def __init__(self, *args, **kwargs): - _after_in_child_fork_hooks[self] = "__after_in_child_fork__" - _after_in_parent_fork_hooks[self] = "__after_in_parent_fork__" - _before_fork_hooks[self] = "__before_fork__" - orig_init(self, *args, **kwargs) - - cls.__init__ = __init__ - else: - - def __init__(self, *args, **kwargs): - _after_in_child_fork_hooks[cls] = "__after_in_child_fork__" - _after_in_parent_fork_hooks[cls] = "__after_in_parent_fork__" - _before_fork_hooks[cls] = "__before_fork__" - cls(*args, **kwargs) - - cls.__init__ = __init__ - return cls + pass def _run_before_fork_hooks(): @@ -146,12 +140,9 @@ def _run_before_fork_hooks(): Run all the functions that were registered with the before_fork_hook decorator. 
""" # print(f"Running before_fork_hooks for pid={os.getpid()}") - for obj, hook in _before_fork_hooks.items(): + for hook in _before_fork_hooks.values(): # print(f"Running before_fork_hook for {hook}\n", end="") - if callable(hook): - hook() - else: - getattr(obj, hook)() + hook() def _run_after_in_child_fork_hooks(): @@ -159,12 +150,9 @@ def _run_after_in_child_fork_hooks(): Run all the functions that were registered with the after_in_child_fork_hook decorator. """ # print(f"Running after_in_child_fork_hooks for pid={os.getpid()}") - for obj, hook in _after_in_child_fork_hooks.items(): + for hook in _after_in_child_fork_hooks.values(): # print(f"Running after_in_child_fork_hook for {hook}\n", end="") - if callable(hook): - hook() - else: - getattr(obj, hook)() + hook() def _run_after_in_parent_fork_hooks(): @@ -172,12 +160,9 @@ def _run_after_in_parent_fork_hooks(): Run all the functions that were registered with the after_in_parent_fork_hook decorator. """ # print(f"Running after_in_parent_fork_hooks for pid={os.getpid()}") - for obj, hook in _after_in_parent_fork_hooks.items(): + for hook in _after_in_parent_fork_hooks.values(): # print(f"Running after_in_parent_fork_hook for {hook}\n", end="") - if callable(hook): - hook() - else: - getattr(obj, hook)() + hook() os.register_at_fork( diff --git a/src/megatron/energon/savable_loader.py b/src/megatron/energon/savable_loader.py index da771b2b..6372b71e 100644 --- a/src/megatron/energon/savable_loader.py +++ b/src/megatron/energon/savable_loader.py @@ -38,12 +38,12 @@ State, add_sample_restore_key, ) +from megatron.energon.flavors.trace import default_get_keys from megatron.energon.rng import SystemRng, SystemRngState from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset from megatron.energon.wrappers.batch_dataset import BatchDataset from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER, GcDataset, gc_init_worker -from 
megatron.energon.wrappers.log_sample_dataset import default_get_keys from megatron.energon.wrappers.watchdog_dataset import WatchdogDataset T = TypeVar("T") @@ -61,7 +61,24 @@ def _init_worker(seed_per_worker: List[int], worker_id: int): SystemRng.seed(worker_seed) -class SimpleSavableDatasetWrapper(BaseWrapperDataset[T, Tuple[int, int, T]], Generic[T]): +@edataclass +class SavableDatasetState(State): + """State of the dataset wrapper. It stores the global random states and the index of the next + sample to be returned from the dataset. This class is not intended to be used directly, but by + :class:`megatron.energon.SavableDatasetWrapper`.""" + + #: The state of all the system random number generators + rng: SystemRngState + #: The state of the savable dataset + dataset_state: FlexState + #: Index of the next sample to be returned from the dataset + sample_index: int + + def __repr__(self): + return f"SavableDatasetState(rng={self.rng!r}, sample_index={self.sample_index})" + + +class SimpleSavableDatasetWrapper(IterableDataset[Tuple[int, int, T]], Generic[T]): """Wrapper for non-multiprocessing savable datasets. Restarts the inner dataset. This class is not intended to be used directly.""" @@ -71,19 +88,24 @@ class SimpleSavableDatasetWrapper(BaseWrapperDataset[T, Tuple[int, int, T]], Gen _state_restored: bool _sample_index: int - _savable_fields = ("_sample_index",) - def __init__( - self, dataset: SavableDataset[T], worker_config: WorkerConfig, cache_pool: CachePool + self, + dataset: SavableDataset[T], + worker_config: WorkerConfig, + cache_pool: CachePool, + dataloader_id: int, ): """ Args: dataset: The dataset to wrap. worker_config: The worker config to use for the dataset. cache_pool: The cache pool to use for the dataset. + dataloader_id: The id of the data loader for logging purposes. 
""" - super().__init__(dataset, worker_config=worker_config) + self.dataset = dataset + self.worker_config = worker_config self.cache_pool = cache_pool + self.dataloader_id = dataloader_id self.reset_state_own() @@ -91,54 +113,105 @@ def reset_state_own(self) -> None: self._sample_index = 0 self._state_restored = False - def __len__(self): + def inner_len(self): return len(self.dataset) def __iter__(self): - self._state_restored = True - worker_id = self.worker_config.rank_worker_id() - global_worker_id = self.worker_config.global_worker_id() - while self._state_restored: - self._state_restored = False - self.worker_config.worker_activate(self._sample_index, cache_pool=self.cache_pool) - worker_active = True - try: - for src_data in self.dataset: - self.worker_config.worker_deactivate() - worker_active = False - sample_index = self._sample_index - src_data = add_sample_restore_key( - src_data, global_worker_id, sample_index, src=self - ) - self._sample_index += 1 - yield worker_id, sample_index, src_data - if self._state_restored: - # Restart iterator after restore - break - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - finally: - if worker_active: - self.worker_config.worker_deactivate() + trace_writer = self.worker_config.worker_trace_writer() + with trace_writer.span( + "SimpleSavableDatasetWrapper.__iter__", + args={ + "config": self.config(), + "loader_id": self.dataloader_id, + "rank": self.worker_config.rank, + "worker_id": self.worker_config.rank_worker_id(), + }, + level=1, + ): + self._state_restored = True + worker_id = self.worker_config.rank_worker_id() + global_worker_id = self.worker_config.global_worker_id() + while self._state_restored: + self._state_restored = False + self.worker_config.worker_activate(self._sample_index, cache_pool=self.cache_pool) + worker_active = True + trace_sample_flow: dict = {} + try: + # For tracing, this contains the current sample flow + + def 
_trace_next(): + # Trace the next sample flow + nonlocal trace_sample_flow + span = trace_writer.span( + name="SimpleSavableDatasetWrapper.__iter__.loop.dataset.next", + args={"sample_index": self._sample_index}, + level=1, + ) + trace_sample_flow = trace_writer.flow( + f"w{global_worker_id}_s{self._sample_index}", + level=1, + ).save() + return span + + for src_data in trace_writer.iterable(self.dataset, next=_trace_next): + self.worker_config.worker_deactivate() + worker_active = False + sample_index = self._sample_index + src_data = add_sample_restore_key( + src_data, global_worker_id, sample_index, src=self + ) + self._sample_index += 1 + trace_writer.instant( + "SimpleSavableDatasetWrapper.__iter__.loop.yield", + args={"sample_index": sample_index}, + level=1, + ) + yield worker_id, sample_index, src_data, trace_sample_flow + if self._state_restored: + # Restart iterator after restore + break + self.worker_config.worker_activate( + self._sample_index, cache_pool=self.cache_pool + ) + worker_active = True + finally: + if worker_active: + self.worker_config.worker_deactivate() + trace_writer.instant("SimpleSavableDatasetWrapper.__iter__.break", level=1) - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T: - id, global_worker_id, sample_idx = restore_key[:3] - assert id == type(self).__name__ - restore_key = restore_key[3:] - self.worker_config.worker_activate( - sample_idx, override_global_rank=global_worker_id, cache_pool=self.cache_pool + def save_state(self): + return SavableDatasetState( + rng=None, + dataset_state=self.dataset.save_state(), + sample_index=self._sample_index, ) - try: - return add_sample_restore_key( - self.dataset.restore_sample(restore_key), - global_worker_id, - sample_idx, - src=self, + + def restore_state(self, state: SavableDatasetState): + self._sample_index = state.sample_index + self.dataset.restore_state(state.dataset_state) + + def can_restore_sample(self) -> bool: + return 
self.dataset.can_restore_sample() + + def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T: + with self.worker_config.worker_trace_writer().span( + "SimpleSavableDatasetWrapper.restore_sample", args={"restore_key": restore_key}, level=1 + ): + id, global_worker_id, sample_idx = restore_key[:3] + assert id == type(self).__name__ + restore_key = restore_key[3:] + self.worker_config.worker_activate( + sample_idx, override_global_rank=global_worker_id, cache_pool=self.cache_pool ) - finally: - self.worker_config.worker_deactivate() + try: + return add_sample_restore_key( + self.dataset.restore_sample(restore_key), + global_worker_id, + sample_idx, + src=self, + ) + finally: + self.worker_config.worker_deactivate() def config(self) -> Dict[str, Any]: return self.dataset.config() @@ -147,23 +220,6 @@ def __str__(self): return f"SimpleSavableDatasetWrapper(dataset={self.dataset})" -@edataclass -class SavableDatasetState(State): - """State of the dataset wrapper. It stores the global random states and the index of the next - sample to be returned from the dataset. This class is not intended to be used directly, but by - :class:`megatron.energon.SavableDatasetWrapper`.""" - - #: The state of all the system random number generators - rng: SystemRngState - #: The state of the savable dataset - dataset_state: FlexState - #: Index of the next sample to be returned from the dataset - sample_index: int - - def __repr__(self): - return f"SavableDatasetState(rng={self.rng!r}, sample_index={self.sample_index})" - - @edataclass class SavableCheckpoint: """Checkpoint data for :class:`megatron.energon.SavableDatasetWrapper`. An instance is created @@ -235,6 +291,7 @@ def __init__( cmd_queues: List[torch.multiprocessing.Queue], result_queues: List[torch.multiprocessing.Queue], cache_pool: CachePool, + dataloader_id: int, ): """ Create the savable dataset wrapper for multiprocessing data loading. 
@@ -250,6 +307,7 @@ def __init__( cmd_queues: The command queues for communicating with the worker processes. result_queues: The result queues for communicating with the worker processes. cache_pool: The cache pool to use for the dataset. + dataloader_id: The id of the data loader for logging purposes. """ num_workers = max(worker_config.num_workers, 1) @@ -266,6 +324,7 @@ def __init__( self._cmd_queues = cmd_queues self._result_queues = result_queues self.cache_pool = cache_pool + self.dataloader_id = dataloader_id @staticmethod def _command_thread(self: "SavableDatasetWrapper"): @@ -276,35 +335,52 @@ def _command_thread(self: "SavableDatasetWrapper"): # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread starting") assert self._command_lock is not None - try: - while self._running: - try: - cmd_args = self._cmd_queues[self._worker_id].get(timeout=1) - except queue.Empty: - continue - # print(f"recv cmd {cmd_args}") - with self._command_lock: - cmd = cmd_args[0] - if cmd is None: - break + trace_writer = self.worker_config.worker_trace_writer() + trace_writer.metadata_thread_name("command_thread") + + with trace_writer.span( + "SavableDatasetWrapper._command_thread", args={"config": self.config()}, level=1 + ): + try: + while self._running: try: - fn = getattr(self, cmd) - self._result_queues[self._worker_id].put( - {self._worker_id: fn(*cmd_args[1:])} - ) - # print(f"result sent") - except Exception as e: - traceback.print_exc() - self._result_queues[self._worker_id].put({self._worker_id: e}) - # print(f"exc sent") - except BaseException: - traceback.print_exc() - raise - finally: - pass - # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread closing") + cmd_args = self._cmd_queues[self._worker_id].get(timeout=1) + except queue.Empty: + continue + # print(f"recv cmd {cmd_args}") + with self._command_lock: + cmd = cmd_args[0] + if cmd is None: + break + with trace_writer.span( + 
f"SavableDatasetWrapper._command_thread.{cmd}", level=1 + ): + try: + fn = getattr(self, cmd) + self._result_queues[self._worker_id].put( + {self._worker_id: fn(*cmd_args[1:])} + ) + # print(f"result sent") + except Exception as e: + traceback.print_exc() + self._result_queues[self._worker_id].put({self._worker_id: e}) + trace_writer.instant( + "SavableDatasetWrapper._command_thread.cmd_lock.exception", + args={ + "exc": f"{type(e).__name__}: {e}", + "tb": traceback.format_exc(), + }, + level=1, + ) + # print(f"exc sent") + except BaseException: + traceback.print_exc() + raise + finally: + pass + # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread closing") - def __len__(self): + def inner_len(self): return len(self.dataset) def __del__(self): @@ -317,96 +393,146 @@ def __del__(self): # print(f"{id(self)}:{multiprocessing.current_process().ident} Cmd thread closed") def __iter__(self): + trace_writer = self.worker_config.worker_trace_writer() # First: Set the worker offset globally for the current worker WorkerConfig.worker_id_offset = self._worker_offset self._worker_id = self.worker_config.rank_worker_id() global_worker_id = self.worker_config.global_worker_id() - if self._cmd_thread is None: - self._running = True - self._command_lock = threading.RLock() - weakref_self = weakref.proxy(self) - self._cmd_thread = threading.Thread( - target=SavableDatasetWrapper._command_thread, - name="command_thread", - args=(weakref_self,), - daemon=True, - ) - self._cmd_thread.start() - # atexit.register(lambda: weakref_self.__del__()) - try: - assert self._command_lock is not None - with self._command_lock: - if self._workers_restore_from[self._worker_id] is not None: - my_state = self._workers_restore_from[self._worker_id] - my_ds_state = my_state.dataset_state - assert my_state is not None - if my_ds_state is None: - self.dataset.reset_state_deep() + with trace_writer.span( + "SavableDatasetWrapper.__iter__", + args={ + "config": 
self.config(), + "loader_id": self.dataloader_id, + "rank": self.worker_config.rank, + "worker_id": self._worker_id, + "global_worker_id": global_worker_id, + }, + level=1, + ): + if self._cmd_thread is None: + self._running = True + self._command_lock = threading.RLock() + weakref_self = weakref.proxy(self) + self._cmd_thread = threading.Thread( + target=SavableDatasetWrapper._command_thread, + name="command_thread", + args=(weakref_self,), + daemon=True, + ) + self._cmd_thread.start() + # atexit.register(lambda: weakref_self.__del__()) + try: + assert self._command_lock is not None + with self._command_lock: + if self._workers_restore_from[self._worker_id] is not None: + with trace_writer.span("SavableDatasetWrapper.__iter__.restore", level=1): + my_state = self._workers_restore_from[self._worker_id] + my_ds_state = my_state.dataset_state + assert my_state is not None + if my_ds_state is None: + self.dataset.reset_state_deep() + else: + self.dataset.restore_state(my_ds_state) + self._restore_state(my_state) + self._workers_restore_from[self._worker_id] = None else: - self.dataset.restore_state(my_ds_state) - self._restore_state(my_state) - self._workers_restore_from[self._worker_id] = None - else: - # Store the initial state of the worker if we stop before the first sample - self._store_checkpoint() - # If skipping, also restart the iterator to reach the start of the restored - # checkpoint - last_was_skip = True - while last_was_skip: - dataset_has_samples = False - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - try: - for src_data in self.dataset: - self.worker_config.worker_deactivate() - worker_active = False - dataset_has_samples = True - if self._workers_skip_samples[self._worker_id] > 0: - # Skip ahead to reach the start of the restored checkpoint - # print(f"Skip [{self._sample_index}:{self._worker_id}] {src_data}") - self._workers_skip_samples[self._worker_id] -= 1 - self._sample_index += 
1 - last_was_skip = True - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - continue - last_was_skip = False - sample_index = self._sample_index - add_sample_restore_key( - src_data, global_worker_id, sample_index, src=self - ) - self._sample_index += 1 - self._store_checkpoint() - try: - self._command_lock.release() - # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock released") - # Commands may be executed only when data was yielded, not during - # iteration fetching. - # print(f"Yield next data [{sample_index}:{self._worker_id}] {src_data}") - yield self._worker_id, sample_index, src_data - finally: - # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquiring") - self._command_lock.acquire() - # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquired") - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - finally: - if worker_active: - self.worker_config.worker_deactivate() - - # If the dataset is empty, don't try again and again - if not dataset_has_samples: - break - finally: - # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker iter closing") - # Always store a final checkpoint (it's likely to be saved) - self._store_checkpoint(force=True) + # Store the initial state of the worker if we stop before the first sample + self._store_checkpoint() + + # If skipping, also restart the iterator to reach the start of the restored + # checkpoint + last_was_skip = True + while last_was_skip: + dataset_has_samples = False + self.worker_config.worker_activate( + self._sample_index, cache_pool=self.cache_pool + ) + worker_active = True + trace_sample_flow: dict = {} + try: + with trace_writer.span("SavableDatasetWrapper.__iter__.loop", level=1): + + def _trace_next(): + nonlocal trace_sample_flow + span = trace_writer.span( + "SavableDatasetWrapper.__iter__.loop.dataset.next", 
+ args={ + "sample_index": self._sample_index, + }, + level=1, + ) + + trace_sample_flow = trace_writer.flow( + f"w{global_worker_id}_s{self._sample_index}", + level=1, + ).save() + return span + + for src_data in trace_writer.iterable( + self.dataset, + next=_trace_next, + ): + self.worker_config.worker_deactivate() + worker_active = False + dataset_has_samples = True + if self._workers_skip_samples[self._worker_id] > 0: + with trace_writer.span( + "SavableDatasetWrapper.__iter__.loop.skip", level=1 + ): + # Skip ahead to reach the start of the restored checkpoint + # print(f"Skip [{self._sample_index}:{self._worker_id}] {src_data}") + self._workers_skip_samples[self._worker_id] -= 1 + self._sample_index += 1 + last_was_skip = True + self.worker_config.worker_activate( + self._sample_index, cache_pool=self.cache_pool + ) + worker_active = True + continue + last_was_skip = False + sample_index = self._sample_index + add_sample_restore_key( + src_data, global_worker_id, sample_index, src=self + ) + self._sample_index += 1 + self._store_checkpoint() + try: + self._command_lock.release() + # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock released") + # Commands may be executed only when data was yielded, not during + # iteration fetching. 
+ # print(f"Yield next data [{sample_index}:{self._worker_id}] {src_data}") + trace_writer.instant( + "SavableDatasetWrapper.__iter__.loop.yield", + args={"sample_index": sample_index}, + level=1, + ) + yield ( + self._worker_id, + sample_index, + src_data, + trace_sample_flow, + ) + finally: + # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquiring") + self._command_lock.acquire() + # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquired") + self.worker_config.worker_activate( + self._sample_index, cache_pool=self.cache_pool + ) + worker_active = True + finally: + if worker_active: + self.worker_config.worker_deactivate() + + # If the dataset is empty, don't try again and again + if not dataset_has_samples: + trace_writer.instant("SavableDatasetWrapper.__iter__.break", level=1) + break + finally: + # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker iter closing") + # Always store a final checkpoint (it's likely to be saved) + self._store_checkpoint(force=True) def _store_checkpoint(self, force: bool = False) -> None: """ @@ -417,6 +543,7 @@ def _store_checkpoint(self, force: bool = False) -> None: Args: force: If true, ignore time or frequency condition. 
""" + trace_writer = self.worker_config.worker_trace_writer() if ( force or ( @@ -427,26 +554,24 @@ def _store_checkpoint(self, force: bool = False) -> None: ) or self._sample_index <= 1 ): - # print(f"Storing checkpoint at {self._worker_id}:{self._sample_index}") - self._last_checkpoints.append( - SavableCheckpoint( - state=self._save_state(), - checkpoint_time=time.perf_counter(), - sample_index=self._sample_index, + with trace_writer.span( + "SavableDatasetWrapper._store_checkpoint", + args={"force": force, "sample_index": self._sample_index}, + level=1, + ): + # print(f"Storing checkpoint at {self._worker_id}:{self._sample_index}") + self._last_checkpoints.append( + SavableCheckpoint( + state=self._save_state(), + checkpoint_time=time.perf_counter(), + sample_index=self._sample_index, + ) ) - ) - if len(self._last_checkpoints) > self.n_checkpoints: - self._last_checkpoints.pop(0) + if len(self._last_checkpoints) > self.n_checkpoints: + self._last_checkpoints.pop(0) def _save_state(self) -> SavableDatasetState: """Saves the internal state""" - ( - np_tp, - np_state, - pos, - has_gauss, - cached_gaussian, - ) = np.random.get_state() return SavableDatasetState( rng=SystemRng.save_state(), dataset_state=self.dataset.save_state(), @@ -557,19 +682,24 @@ def can_restore_sample(self) -> bool: return self.dataset.can_restore_sample() def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T: - id, global_worker_id, sample_idx = restore_key[:3] - assert id == type(self).__name__ - restore_key = restore_key[3:] - self.worker_config.worker_activate(sample_idx, override_global_rank=global_worker_id) - try: - return add_sample_restore_key( - self.dataset.restore_sample(restore_key), - global_worker_id, - sample_idx, - src=self, + with self.worker_config.worker_trace_writer().span( + "SavableDatasetWrapper.restore_sample", args={"restore_key": restore_key}, level=1 + ): + id, global_worker_id, sample_idx = restore_key[:3] + assert id == 
type(self).__name__ + restore_key = restore_key[3:] + self.worker_config.worker_activate( + sample_idx, override_global_rank=global_worker_id, cache_pool=self.cache_pool ) - finally: - self.worker_config.worker_deactivate() + try: + return add_sample_restore_key( + self.dataset.restore_sample(restore_key), + global_worker_id, + sample_idx, + src=self, + ) + finally: + self.worker_config.worker_deactivate() def config(self) -> Dict[str, Any]: return self.dataset.config() @@ -584,7 +714,7 @@ class SavableDataLoaderState(State): processed of a single rank.""" #: The internal state of the dataset (for each worker process) - worker_states: List[Union[SavableDatasetCheckpoint, FlexState]] + worker_states: Union[List[SavableDatasetCheckpoint], List[SavableDatasetState]] #: Which worker will be the next to emit a sample. Used to restore the proper order next_worker_id: int @@ -658,6 +788,8 @@ class SavableDataLoader(DataLoader[T], Generic[T]): #: Instance of the current data iterator. There shall be only one active iterator, such that the # dataset is not iterated multiple times in parallel. The state will proceed. _persistent_iterator: Optional[Iterator[T]] = None + #: Whether the dataloader has running workers. + _has_workers: bool = False #: The index of the current worker. -1 if not started yet. 
_worker_sample_counters: List[int] #: Id of the next worker to retrieve data from @@ -752,10 +884,11 @@ def __init__( cmd_queues=self.cmd_queues, result_queues=self.result_queues, cache_pool=cache_pool, + dataloader_id=self.id, ) else: dataset = SimpleSavableDatasetWrapper( - dataset, self.worker_config, cache_pool=cache_pool + dataset, self.worker_config, cache_pool=cache_pool, dataloader_id=self.id ) self._worker_sample_counters = [-1] * num_procs @@ -788,16 +921,16 @@ def __init__( **kwargs, ) - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "SavableLoader.__init__", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "config": dataset.config(), - } - ) + self.worker_config.worker_trace_writer().trace_object_async( + self, + "SavableDataLoader", + args={ + "loader_id": self.id, + "worker_config": self.worker_config.config(), + "config": dataset.config(), + }, + level=1, + ) @staticmethod def next_id() -> int: @@ -805,97 +938,107 @@ def next_id() -> int: SavableDataLoader._next_id += 1 return next_id - def __iter__(self): - outerself = self - - class InnerIterator: - """Internal class which keeps the iterator alive across multiple `iter()` calls. - If the inner iterator is exhausted, will also exhaust and a new instance is needed. - Also saves the last sample index and the next worker id. 
- """ - - finished: bool = False - iter_idx: int = 0 - id: int - - def __init__(self, iterator): - self._iterator = iterator - self.id = outerself.next_id() - if outerself.worker_config.should_log(level=1): - outerself.worker_config.worker_log( - { - "t": "SavableDataLoader.iter", - "r": outerself.worker_config.rank, - "w": None, - "id": outerself.id, - "iter_id": self.id, - } - ) - - # self._debugf = open( - # f"worker_samples_rank{outerself.worker_config.rank:02}_t{int(time.time())}.log", "w" - # ) - - def __iter__(self): - return self + def __len__(self): + # We override this, because otherwise we'll see warnings + return self.dataset.inner_len() - def __next__(self): + def __iter__(self): + def _inner_generator(iterator): + iter_idx = 0 + id = self.next_id() + trace_writer = self.worker_config.worker_trace_writer() + trace_span = self.worker_config.worker_trace_span() + trace_writer.instant( + "SavableDataLoader.__iter__", + args={ + "world_size": self.worker_config.world_size, + "num_workers": self.worker_config.num_workers, + "loader_id": self.id, + "iter_id": id, + }, + level=1, + ) + with ( + trace_span.span( + "SavableDataLoader.__iter__", + args={ + "loader_id": self.id, + "iter_id": id, + }, + level=1, + ), + trace_writer.generator( + "SavableDataLoader.__iter__.next", + level=1, + ) as trace_generator, + ): try: - worker_id, sample_idx, sample = next(self._iterator) - outerself._worker_sample_counters[worker_id] = sample_idx - # If the next sample will be from the first worker, we can safely resume - outerself._next_worker_id = (worker_id + 1) % max(outerself.num_workers, 1) - # self._debugf.write( - # f"[w={worker_id}, s={sample_idx}] {self._sample_str(sample)}\n" - # ) - # self._debugf.flush() - if outerself.worker_config.should_log(level=1): - keys = default_get_keys(sample) - outerself.worker_config.worker_log( - { - **{ - "t": "SavableDataLoader.yield", - "r": outerself.worker_config.rank, - "w": None, - "id": outerself.id, - "iter_id": self.id, + 
for worker_id, sample_idx, sample, trace_sample_flow in iterator: + self._worker_sample_counters[worker_id] = sample_idx + # If the next sample will be from the first worker, we can safely resume + self._next_worker_id = (worker_id + 1) % max(self.num_workers, 1) + if self.worker_config.should_log(level=1): + trace_writer.resume_flow(trace_sample_flow).end( + bind_enclosing_slice=True, level=1 + ) + keys = default_get_keys(sample) + trace_span.instant( + "SavableDataLoader.yield", + args={ + "loader_id": self.id, + "iter_id": id, "worker_id": worker_id, - "worker_idx": sample_idx, - "idx": outerself._sample_idx, - "iter_idx": self.iter_idx, - "global_idx": outerself._global_sample_idx, + "worker_sample_idx": sample_idx, + "sample_idx": self._sample_idx, + "iter_idx": iter_idx, + "global_sample_idx": self._global_sample_idx, + **({} if keys is None else {"keys": keys}), }, + level=1, + ) + else: + keys = None + with trace_generator.yield_( + last_args={ + "loader_id": self.id, + "iter_id": id, + "worker_id": worker_id, + "worker_sample_idx": sample_idx, + "sample_idx": self._sample_idx, + "iter_idx": iter_idx, + "global_sample_idx": self._global_sample_idx, **({} if keys is None else {"keys": keys}), } - ) - outerself._sample_idx += 1 - outerself._global_sample_idx += 1 - self.iter_idx += 1 - return sample - except StopIteration: - self.finished = True - outerself._next_worker_id = 0 - if outerself.worker_config.should_log(level=1): - outerself.worker_config.worker_log( - { - "t": "SavableDataLoader.StopIteration", - "r": outerself.worker_config.rank, - "w": None, - "id": outerself.id, - "iter_id": self.id, - } - ) - raise + ): + self._sample_idx += 1 + self._global_sample_idx += 1 + iter_idx += 1 + yield sample + # After the source is exhausted, not for GeneratorExit. 
+ self._persistent_iterator = None + self._next_worker_id = 0 + finally: + trace_span.instant( + "SavableDataLoader.StopIteration", + level=1, + args={"loader_id": self.id, "iter_id": id}, + ) + trace_writer.instant( + "SavableDataLoader.StopIteration", + level=1, + args={"loader_id": self.id, "iter_id": id}, + ) if self.num_workers > 0: # Always keep same iterator alive, as long as it yields data - if self._persistent_iterator is None or self._persistent_iterator.finished: - self._persistent_iterator = InnerIterator(super().__iter__()) + if self._persistent_iterator is None: + self._persistent_iterator = _inner_generator(super().__iter__()) self._sample_idx = 0 + self._has_workers = True # print("New Iterator", self._persistent_iterator) return self._persistent_iterator else: - return InnerIterator(super().__iter__()) + return _inner_generator(super().__iter__()) def _worker_command(self, *cmd_args) -> List[Any]: """Executes a command in all workers and returns the results.""" @@ -914,7 +1057,7 @@ def _worker_command(self, *cmd_args) -> List[Any]: def _get_batch_size(self) -> Optional[int]: """Try to infer micro batch size from the dataset""" - if isinstance(self.dataset, SavableDatasetWrapper): + if isinstance(self.dataset, (SavableDatasetWrapper, SimpleSavableDatasetWrapper)): dataset = self.dataset.dataset else: dataset = self.dataset @@ -942,14 +1085,14 @@ def save_state_rank(self) -> Optional[SavableDataLoaderState]: assert isinstance(self.dataset, SimpleSavableDatasetWrapper) worker_states = [self.dataset.save_state()] assert self._next_worker_id == 0 - elif self._persistent_iterator is None: + elif self._has_workers: + # Fetch from worker processes + worker_states = self._worker_command("get_checkpoint", self._worker_sample_counters) + else: # Workers configured, but not started yet. # If a state has already been restored, it will be returned. 
assert isinstance(self.dataset, SavableDatasetWrapper) worker_states = self.dataset.get_initial_checkpoint() - else: - # Fetch from worker processes - worker_states = self._worker_command("get_checkpoint", self._worker_sample_counters) if worker_states is None: return None @@ -980,15 +1123,18 @@ def restore_state_rank(self, state: Optional[SavableDataLoaderState]) -> None: old_micro_batch_size = state.micro_batch_size micro_batch_size = self._get_batch_size() - if isinstance(self.dataset, SavableDataset): + if self.num_workers == 0: + # No workers configured + assert isinstance(self.dataset, SimpleSavableDatasetWrapper) assert micro_batch_size == old_micro_batch_size, ( "Changing micro batch size is not allowed without workers" ) assert len(state.worker_states) == 1 - assert isinstance(state.worker_states[0], FlexState) + assert isinstance(state.worker_states[0], SavableDatasetState) self.dataset.restore_state(state.worker_states[0]) else: + # Workers configured assert isinstance(self.dataset, SavableDatasetWrapper) assert all(isinstance(s, SavableDatasetCheckpoint) for s in state.worker_states) @@ -1267,7 +1413,7 @@ def __init__( ) dataset = SimpleSavableDatasetWrapper( - dataset, worker_config=self.worker_config, cache_pool=cache_pool + dataset, worker_config=self.worker_config, dataloader_id=self.id, cache_pool=cache_pool ) self._worker_sample_counters = [0] * max(self.worker_config.num_workers, 1) @@ -1293,96 +1439,117 @@ def __init__( worker_init_fn=partial(_init_worker, seed_per_worker), **kwargs, ) - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "BasicDataLoader.__init__", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "config": self.config(), - } - ) - def __iter__(self): - outerself = self - - class InnerIterator: - """Internal class which keeps the iterator alive across multiple `iter()` calls. - If the inner iterator is exhausted, will also exhaust and a new instance is needed. 
- Also saves the last sample index and the next worker id. - """ - - iter_idx: int = 0 - id: int - - def __init__(self, iterator): - self._iterator = iterator - self.id = SavableDataLoader.next_id() - - if outerself.worker_config.should_log(level=1): - outerself.worker_config.worker_log( - { - "t": "BasicDataLoader.iter", - "r": outerself.worker_config.rank, - "w": None, - "id": outerself.id, - "iter_id": self.id, - } - ) + self.worker_config.worker_trace_writer().trace_object_async( + self, + "BasicDataLoader", + args={ + "loader_id": self.id, + "worker_config": self.worker_config.config(), + "config": self.config(), + }, + level=1, + ) - def __iter__(self): - return self + def __len__(self): + # We override this, because otherwise we'll see warnings + return self.dataset.inner_len() + + def __iter__(self): + def _inner_generator(iterator): + iter_idx = 0 + id = SavableDataLoader.next_id() + + trace_writer = self.worker_config.worker_trace_writer() + trace_span = self.worker_config.worker_trace_span() + + trace_writer.instant( + "BasicDataLoader.__iter__", + args={ + "rank": self.worker_config.rank, + "world_size": self.worker_config.world_size, + "num_workers": self.worker_config.num_workers, + "loader_id": self.id, + "iter_id": id, + }, + level=1, + ) - def __next__(self): + with ( + trace_span.span( + "BasicDataLoader.iter", + args={ + "loader_id": self.id, + "iter_id": id, + }, + level=1, + ), + trace_writer.generator( + "BasicDataLoader.iter", + level=1, + ) as trace_generator, + ): try: - worker_id, sample_idx, sample = next(self._iterator) - # If the next sample will be from the first worker, we can safely resume - self.next_worker_id = (worker_id + 1) % max(outerself.num_workers, 1) - if outerself.worker_config.should_log(level=1): - keys = default_get_keys(sample) - outerself.worker_config.worker_log( - { - **{ - "t": "BasicDataLoader.yield", - "r": outerself.worker_config.rank, - "w": None, - "id": outerself.id, - "iter_id": self.id, + for worker_id, 
sample_idx, sample, trace_sample_flow in iterator: + if self.worker_config.should_log(level=1): + trace_writer.resume_flow(trace_sample_flow).end( + bind_enclosing_slice=True, level=1 + ) + keys = default_get_keys(sample) + trace_span.instant( + "BasicDataLoader.yield", + args={ + "loader_id": self.id, + "iter_id": id, "worker_id": worker_id, - "worker_idx": sample_idx, - "idx": self.iter_idx, - "iter_idx": self.iter_idx, - "global_idx": outerself._sample_idx, + "worker_sample_idx": sample_idx, + "sample_idx": iter_idx, + "iter_idx": iter_idx, + "global_sample_idx": self._sample_idx, + **({} if keys is None else {"keys": keys}), }, + level=1, + ) + else: + keys = None + with trace_generator.yield_( + last_args={ + "loader_id": self.id, + "iter_id": id, + "worker_id": worker_id, + "worker_sample_idx": sample_idx, + "sample_idx": iter_idx, + "iter_idx": iter_idx, + "global_sample_idx": self._sample_idx, **({} if keys is None else {"keys": keys}), } - ) - outerself._sample_idx += 1 - self.iter_idx += 1 - return sample - except StopIteration: - self.next_worker_id = 0 - if outerself.worker_config.should_log(level=1): - outerself.worker_config.worker_log( - { - "t": "BasicDataLoader.StopIteration", - "r": outerself.worker_config.rank, - "w": None, - "id": outerself.id, - "iter_id": self.id, - } - ) - raise + ): + self._sample_idx += 1 + iter_idx += 1 + yield sample + finally: + trace_span.instant( + "BasicDataLoader.StopIteration", + args={ + "loader_id": self.id, + "iter_id": id, + }, + level=1, + ) + trace_writer.instant( + "BasicDataLoader.StopIteration", + level=1, + args={"loader_id": self.id, "iter_id": id}, + ) - return InnerIterator(super().__iter__()) + return _inner_generator(super().__iter__()) def config(self): """Get the configuration, which defines the dataset. 
Useful in conjunction with `save_state` and `restore_state` to match the configuration as well.""" return { "type": type(self).__qualname__, - "num_workers": self.num_workers, + "num_workers": self.worker_config.num_workers, "persistent_workers": self.persistent_workers, "pin_memory": self.pin_memory, "prefetch_factor": None if self.num_workers == 0 else self.prefetch_factor, diff --git a/src/megatron/energon/source_info.py b/src/megatron/energon/source_info.py index 1a90c228..8bf1153b 100644 --- a/src/megatron/energon/source_info.py +++ b/src/megatron/energon/source_info.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from dataclasses import dataclass -from typing import Any, Union +from typing import Any, Sequence, Union from megatron.energon.epathlib import EPath @@ -36,3 +36,14 @@ def add_source_info(sample: Any, *source_infos: SourceInfo) -> Any: else: sample["__sources__"] = (*sample["__sources__"], *source_infos) return sample + + +def get_source_info(sample: Any) -> Sequence[SourceInfo]: + """Get the source info of a sample.""" + from megatron.energon.flavors.base_dataset import Sample + + if isinstance(sample, Sample) or hasattr(sample, "__sources__"): + return sample.__sources__ or () + elif isinstance(sample, dict): + return sample.get("__sources__") or () + return () diff --git a/src/megatron/energon/tools/analyze_debug.py b/src/megatron/energon/tools/analyze_debug.py index e1eb1754..6133fb5b 100644 --- a/src/megatron/energon/tools/analyze_debug.py +++ b/src/megatron/energon/tools/analyze_debug.py @@ -2,20 +2,13 @@ # SPDX-License-Identifier: BSD-3-Clause import json -import time -import traceback -from concurrent.futures.process import ProcessPoolExecutor +import re +from collections import defaultdict from dataclasses import dataclass from pathlib import Path from typing import ( - Container, - Dict, Generator, - Iterable, - List, Literal, - Optional, - Tuple, TypedDict, Union, ) @@ -23,7 +16,8 @@ import click import numpy as np from PIL 
import Image -from tqdm import tqdm + +from megatron.energon.tools.analyze_debug_merge import merge_log_reader cpal = np.array( [ @@ -97,54 +91,6 @@ ) -class YieldBatchLogLine(TypedDict): - # Json example: - # { - # "t": "yield_batch", - # "r": 1, - # "w": 1, - # "m": "train", - # "idx": 1, - # "keys": ["parts/data-train-000051.tar/528866", ...], - # } - t: Literal["yield_batch"] - r: int - w: int - m: Literal["train", "val"] - idx: int - keys: List[str] - - -class SampleLoaderYieldLogLine(TypedDict): - # Json example: - # { - # "t": "WebdatasetSampleLoaderDataset._slices_iter.yield", - # "r": 1, - # "w": 1, - # "index": 528800, - # "key": "parts/data-train-000051.tar/528866", - # "shard": "parts/data-train-000051.tar", - # "count": 633, - # "epoch": 0, - # "epoch_count": 633 - # } - t: Literal["WebdatasetSampleLoaderDataset._slices_iter.yield"] - r: int - w: int - #: The global index in the underlying dataset (concats of all shards) - index: int - #: The sample key from the shard, concatenated as f"{shard}/{key}" - key: str - #: Name of the shard - shard: str - #: Number of samples yielded from the sample loader over all epochs - count: int - #: Number of repetitions of the dataset (=epochs). First epoch is 0. - epoch: int - #: Number of samples yielded from the sample loader in the current epoch - epoch_count: int - - class AutosizingHeatmapWriter: """Writes a heatmap, automatically resizing it if necessary.""" @@ -165,6 +111,7 @@ def add(self, sample_id: int, step: int, src: int) -> None: Args: sample_id: The sample id (y-axis) step: The step (x-axis) + src: The source rank (colorizing) """ # Resize heatmap? 
while self.heatmap.shape[0] * self.heatmap_sample_factor <= sample_id: @@ -211,8 +158,7 @@ def save(self, path: Union[Path, str], gain: float): @click.command(name="analyze-debug") @click.argument( - "log_paths", - nargs=-1, + "log_path", type=click.Path(exists=True, file_okay=True, dir_okay=True, path_type=Path), ) @click.option( @@ -238,24 +184,6 @@ def save(self, path: Union[Path, str], gain: float): default=10, help="Gain (=multiplication factor) for the heatmap", ) -@click.option( - "--force-loading-order", - is_flag=True, - default=False, - help="If true, force using the dataloader loading order instead of batch data", -) -@click.option( - "--include-modality", - type=str, - default="train", - help="Choose which modality/modalities (train,val) to include. Comma separate for multiple.", -) -@click.option( - "--skip", - type=int, - default=0, - help="If >0, skip this many steps at the beginning of log file parsing.", -) @click.option( "--no-colors", is_flag=True, @@ -263,360 +191,286 @@ def save(self, path: Union[Path, str], gain: float): help="If set, disable colorizing ranks.", ) def command( - log_paths: List[Path], + log_path: Path, heatmap_path: Path, heatmap_steps: int, heatmap_samples: int, heatmap_gain: float, - force_loading_order: bool, - include_modality: str, - skip: int, no_colors: bool, ): """Internal tool to analyze randomness. 
The LOG_PATH should point to the folder with the debug log, or to a single log file.""" - if len(log_paths) == 0: - raise click.ClickException("No log paths specified") - log_files = [] - for log_path in log_paths: - if log_path.is_dir(): - log_files.extend(sorted(log_path.glob("*.jsonl"))) - elif log_path.is_file(): - log_files.append(log_path) - else: - raise click.ClickException(f"Invalid log path: {log_path}") + heatmap = AutosizingHeatmapWriter(heatmap_samples, heatmap_steps, colorize=not no_colors) - if len(log_files) == 0: - raise click.ClickException("No log files found") + print(f"Analyzing log {log_path}...") - heatmap = AutosizingHeatmapWriter(heatmap_samples, heatmap_steps, colorize=not no_colors) + if log_path.is_dir(): + log_paths = list(log_path.glob("*.json")) + else: + log_paths = [log_path] + + print(f"Analyzing {len(log_paths)} logs...") + + loader_log_loader = LogLoader(log_paths) - print(f"Analyzing {len(log_files)} logs...") - - modalities = [m.strip() for m in include_modality.split(",")] - - key_index = {} - count = 0 - if not force_loading_order: - loaders = [LoaderLogIter(log_file, start_idx=skip) for log_file in log_files] - loaders_by_id: Dict[int, Tuple[LoaderInfo, List[LoaderLogIter]]] = {} - with ProcessPoolExecutor(max_workers=16) as executor: - for loader, loader_info in tqdm( - executor.map(_proc_map_loader, loaders), total=len(loaders) - ): - for loader_id, loader_info in loader_info.items(): - if loader_id in loaders_by_id: - existing_loader_info, existing_loaders = loaders_by_id[loader_id] - assert ( - existing_loader_info.modality == loader_info.modality - and existing_loader_info.path == loader_info.path - ), ( - f"Found multiple loaders for {loader_id}: {existing_loader_info.modality, existing_loader_info.path} and {loader_info.modality, loader_info.path}" - ) - existing_loader_info.global_count = max( - existing_loader_info.global_count, loader_info.global_count - ) - existing_loaders.append(loader) - else: - 
loaders_by_id[loader_id] = (loader_info, [loader]) - print("Available loaders:") - selected_loader_id = None - must_select = False - for loader_id, (loader_info, _iters) in loaders_by_id.items(): + key_index: dict[str, int] = defaultdict(lambda: len(key_index)) + + for entry in loader_log_loader.read_entries(): + if isinstance(entry, LogLoader.LoaderIterator): print( - f" {loader_id}: {loader_info.modality} {loader_info.path} {loader_info.global_count} steps" + f"Loader rank={entry.rank} loader_id={entry.loader_id} iter_id={entry.iter_id} nw={entry.num_workers} ws={entry.world_size}" ) - if loader_info.modality in modalities: - if selected_loader_id is None: - selected_loader_id = loader_id - else: - # Have multiple loaders - must_select = True - if must_select: - while True: - loader_id_str = input("Choose loader id: ") - try: - selected_loader_id = int(loader_id_str) - except ValueError: - print(f"Invalid loader id {loader_id_str} 1") - continue - if selected_loader_id in loaders_by_id: - break - print(f"Invalid loader id {selected_loader_id}") - assert selected_loader_id is not None - selected_loader_info, selected_loader_readers = loaders_by_id[selected_loader_id] - print( - f"Reading for loader {selected_loader_id}: {selected_loader_info.modality} {selected_loader_info.path}" - ) - log_iters = [ - (idx, loader.log_entries(loader_ids={selected_loader_id})) - for idx, loader in enumerate(selected_loader_readers) - ] - with tqdm(total=selected_loader_info.global_count) as pbar: - while len(log_iters) > 0: - cur_count = 0 - # Iterate over all iterators for this count and put into heatmap - for src_idx, log_iter in tuple(log_iters): - # Iterate until None (=next count) is encountered - while True: - try: - log_keys = next(log_iter) - except StopIteration: - log_iters.remove((src_idx, log_iter)) - break - except OSError: - traceback.print_exc() - log_iters.remove((src_idx, log_iter)) - break - else: - if log_keys is None: - break - for log_key in log_keys: - key_id 
= key_index.setdefault(log_key, len(key_index)) - heatmap.add(key_id, count, src_idx) - cur_count += 1 - if cur_count == 0: - print(f"No data for step {count}") - count += 1 - pbar.update(1) + elif isinstance(entry, LogLoader.Worker): + print( + f"Worker rank={entry.loader.rank} loader_id={entry.loader.loader_id} iter_id={entry.loader.iter_id} worker_id={entry.worker_id}" + ) + # elif isinstance(entry, LogLoader.LoadSample): + # print(f"LoadSample {entry.worker.worker_id} {entry.worker.loader.loader_id} {entry.worker.loader.rank} {entry.worker.loader.num_workers} {entry.base_path} {entry.key} {entry.index} {entry.epoch} {entry.epoch_count}") + elif isinstance(entry, LogLoader.YieldSample): + # print(f"YieldSample rank={entry.worker.loader.rank} loader_id={entry.worker.loader.loader_id} iter_id={entry.worker.loader.iter_id} wrk_id={entry.worker.worker_id} sample_idx={entry.sample_idx} iter_idx={entry.iter_idx} global_sample_idx={entry.global_sample_idx} keys={entry.keys}") + if entry.keys is not None: + for key in entry.keys: + heatmap.add( + key_index[key], entry.global_sample_idx, src=entry.worker.loader.rank + ) + elif isinstance(entry, LogLoader.LoadNextEpoch): + # print(f"LoadNextEpoch rank={entry.worker.loader.rank} loader_id={entry.worker.loader.loader_id} iter_id={entry.worker.loader.iter_id} wrk_id={entry.worker.worker_id} epoch_idx={entry.epoch_idx} epoch_sample_count={entry.epoch_sample_count}") + pass + elif isinstance(entry, LogLoader.StopIteration): + # print(f"StopIteration rank={entry.loader.rank} loader_id={entry.loader.loader_id} iter_id={entry.loader.iter_id}") + pass if len(key_index) == 0: - if force_loading_order: - print("Forcing to use sample loader logs") - else: - print("No batch information in logs, trying sample loader logs...") - if modalities != {"train", "val"}: - print(" Data includes all modalities (train and val)") - print( - " Shuffle buffer and batching will not be considered, only the loading order from disk" - ) - log_iters = [ 
- _iter_sl_log_line_keys(_iter_sl_log_samples(log_file), start_idx=skip) - for log_file in log_files - ] - key_index = {} - count = 0 - start = time.time() - while len(log_iters) > 0: - cur_count = 0 - # Iterate over all iterators for this count and put into heatmap - for log_iter in tuple(log_iters): - # Iterate until None (=next count) is encountered - while True: - try: - log_key = next(log_iter) - except StopIteration: - log_iters.remove(log_iter) - break - except OSError: - traceback.print_exc() - log_iters.remove(log_iter) - break - else: - if log_key is None: - break - key_id = key_index.setdefault(log_key, len(key_index)) - heatmap.add(key_id, count) - cur_count += 1 - if cur_count == 0: - print(f"No data for step {count}") - if time.time() - start > 10: - print(f" Step {count}") - start = time.time() - count += 1 - - if count == 0: raise click.ClickException("No data found in logs") - print(f"Found {len(key_index)} unique sample keys, {count} steps") + print(f"Found {len(key_index)} unique sample keys, {heatmap.heatmap_step_max + 1} steps") # print(f"Heatmap factors: {heatmap_sample_factor} samples, {heatmap_step_factor} steps") # print(f"Heatmap max: {heatmap_sample_max} samples, {heatmap_step_max} steps") - n_samples, n_steps = heatmap.save(heatmap_path, heatmap_gain) + max_sample, max_step = heatmap.save(heatmap_path, heatmap_gain) print(f"Wrote heatmap to {heatmap_path}") print("Heatmap axes:") - print(f" x-axis: {n_steps} worker steps") - print(f" y-axis: {n_samples} samples") - - -class LoaderInitLogLine(TypedDict): - t: Literal["SavableLoader.__init__", "BasicDataLoader.__init__"] - r: int - w: None - id: int - config: dict - - -class LoaderIterLogLine(TypedDict): - t: Literal["SavableDataLoader.iter", "BasicDataLoader.iter"] - r: int - w: None - id: int - iter_id: int - - -class LoaderYieldLogLine(TypedDict): - t: Literal["SavableDataLoader.yield", "BasicDataLoader.yield"] - r: int - w: None + print(f" x-axis: {max_step + 1} worker steps") + 
print(f" y-axis: {max_sample + 1} samples") + + +class LogEntry(TypedDict): + """ + Chrome tracing log entry. + *ph*ase values: + - B: Begin + - E: End + - i: Instant + - b: Begin (async) + - e: End (async) + - n: Instant (async) + - C: Counter + - M: Metadata + - s: Flow start + - t: Flow step + - f: Flow end + """ + + ph: Literal["B", "E", "i", "b", "e", "n", "C", "M", "s", "t", "f"] + name: str id: int - iter_id: int - worker_id: int - worker_idx: int - idx: int - iter_idx: int - global_idx: int - keys: Optional[List[str]] - - -class LoaderStopLogLine(TypedDict): - t: Literal["SavableDataLoader.StopIteration", "BasicDataLoader.StopIteration"] - r: int - w: None - id: int - iter_id: int - - -LoaderLines = Union[ - LoaderInitLogLine, - LoaderIterLogLine, - LoaderYieldLogLine, - LoaderStopLogLine, -] - -LOADER_LOG_LINE_TYPES_T = ( - "SavableLoader.__init__", - "BasicDataLoader.__init__", - "SavableDataLoader.iter", - "BasicDataLoader.iter", - "SavableDataLoader.yield", - "BasicDataLoader.yield", - "SavableDataLoader.StopIteration", - "BasicDataLoader.StopIteration", -) - - -@dataclass -class LoaderInfo: - id: int - modality: str - path: str - global_count: int - - -class LoaderLogIter: - def __init__(self, path: Path, start_idx: int = 0): - self._path = path - self._start_idx = start_idx - - def _iter_log_lines(self, which: Iterable[str]) -> Generator[LoaderLines, None, None]: - try: - with self._path.open("r") as rf: - for line in rf: - if any(f'"t": "{t}"' in line for t in which): - try: - yield json.loads(line.strip()) - except json.JSONDecodeError: - print("Cannot decode line", repr(line)) - except IOError as e: - print(f"Ignoring IOError: {e} for {self._path}") - - @staticmethod - def _find_config_modality(config: dict) -> Literal["train", "val"]: - assert isinstance(config, dict) - if "map_fn_config" in config and "training" in config["map_fn_config"]: - return "train" if config["map_fn_config"]["training"] else "val" - elif "dataset" in config: - return 
LoaderLogIter._find_config_modality(config["dataset"]) - elif "dataset_weights" in config: - return LoaderLogIter._find_config_modality(config["dataset_weights"][0][0]) - elif "datasets" in config: - return LoaderLogIter._find_config_modality(config["datasets"][0]) - assert False, f"Unrecognized config {config}" - - @staticmethod - def _find_config_path(config: dict) -> str: - assert isinstance(config, dict) - if "map_fn_config" in config and "_path" in config["map_fn_config"]: - return config["map_fn_config"]["_path"] - elif "dataset" in config: - return LoaderLogIter._find_config_path(config["dataset"]) - elif "dataset_weights" in config: - return LoaderLogIter._find_config_path(config["dataset_weights"][0][0]) - elif "datasets" in config: - return LoaderLogIter._find_config_path(config["datasets"][0]) - assert False, f"Unrecognized config {config}" - - def loaders(self) -> Dict[int, LoaderInfo]: - loaders = {} - for log_line in self._iter_log_lines( - ( - "SavableLoader.__init__", - "BasicDataLoader.__init__", - "SavableDataLoader.yield", - "BasicDataLoader.yield", - ) - ): - if log_line["t"] in ("SavableLoader.__init__", "BasicDataLoader.__init__"): - loaders[log_line["id"]] = LoaderInfo( - id=log_line["id"], - modality=self._find_config_modality(log_line["config"]), - path=self._find_config_path(log_line["config"]), - global_count=0, - ) - elif log_line["t"] in ("SavableDataLoader.yield", "BasicDataLoader.yield"): - loaders[log_line["id"]].global_count = log_line["global_idx"] - return loaders - - def log_entries(self, loader_ids: Container[int]) -> Generator[Optional[List[str]], None, None]: - idx = self._start_idx - for log_line in self._iter_log_lines(("SavableDataLoader.yield", "BasicDataLoader.yield")): - if ( - log_line["t"] in ("SavableDataLoader.yield", "BasicDataLoader.yield") - and log_line["id"] in loader_ids - ): - assert log_line["global_idx"] >= idx, ( - f"Found entry {log_line} with wrong idx <{idx}" - ) - while log_line["global_idx"] != idx: - 
yield None - idx += 1 - if "keys" in log_line: - yield log_line["keys"] - - def __repr__(self) -> str: - return f"log({str(self._path)})" - - -def _proc_map_loader(loader: LoaderLogIter) -> Tuple[LoaderLogIter, Dict[int, LoaderInfo]]: - return (loader, loader.loaders()) - - -def _iter_sl_log_line_keys( - log_lines: Iterable[SampleLoaderYieldLogLine], - start_idx: int = 0, -) -> Generator[Optional[str], None, None]: - count = start_idx - for log_line in log_lines: - if log_line["count"] < start_idx: - continue - assert log_line["count"] >= count - while log_line["count"] != count: - yield None - count += 1 - yield log_line["key"] - - -def _iter_sl_log_samples(path: Path) -> Generator[SampleLoaderYieldLogLine, None, None]: - with path.open("r") as rf: - for line in rf: - if '"t": "WebdatasetSampleLoaderDataset._slices_iter.yield"' in line: - try: - yield json.loads(line.strip()) - except json.JSONDecodeError: - print("Cannot decode line", repr(line)) + ts: int + pid: int + tid: int + args: dict + s: Literal["t", "p", "g"] + + +class LogLoader: + """Loads a chrome tracing log file. 
Extract specific information from it.""" + + _re_pname = re.compile(r"^dprank(\d+)(?:_worker(\d+))?$") + + def __init__(self, paths: list[Path]): + self._paths = paths + + def _log_reader(self, path: Path) -> Generator[LogEntry, None, None]: + """Reads a log file and yields a tuple of the line and the ts.""" + had_end = False + with open(path, "rb") as f: + assert f.read(2) == b"[\n", "Log file must start with a JSON array" + for line in f: + if not line: + assert had_end, "Log file must end with a JSON array" + if line.endswith(b"]\n"): + had_end = True + else: + assert line.endswith(b",\n"), f"Log file must be newline-terminated: {line}" + yield json.loads(line[:-2]) + assert had_end, "Log file must end with a JSON array" + + def _log_reader_all(self) -> Generator[LogEntry, None, None]: + """Reads all log files and yields a tuple of the line and the ts.""" + if len(self._paths) == 1: + yield from self._log_reader(self._paths[0]) + else: + for entry in merge_log_reader(self._paths): + yield json.loads(entry) + + @dataclass + class LoaderIterator: + world_size: int + rank: int + num_workers: int + loader_id: int + iter_id: int + + @dataclass + class Worker: + worker_id: int + loader: "LogLoader.LoaderIterator" + + @dataclass + class LoadSample: + worker: "LogLoader.Worker" + base_path: str + key: str + global_sample_index: int + sample_count: int + epoch_idx: int + epoch_sample_count: int + + @dataclass + class LoadNextEpoch: + worker: "LogLoader.Worker" + epoch_idx: int + epoch_sample_count: int + + @dataclass + class YieldSample: + worker: "LogLoader.Worker" + worker_sample_idx: int + sample_idx: int + iter_idx: int + global_sample_idx: int + keys: list[str] | None + + @dataclass + class StopIteration: + loader: "LogLoader.LoaderIterator" + + def read_entries(self): + # Maps pid to (rank, worker_id|None) + procs: dict[int, tuple[int, int | None]] = dict() + # Maps (pid, tid) to worker_id|None, only for main threads + proc_workers: dict[tuple[int, int], int | 
None] = dict() + # Maps (pid, tid) to worker + workers_by_pid_tid: dict[tuple[int, int], LogLoader.Worker] = dict() + # Maps (rank, loader_id, worker_id) to worker + workers_by_rank_loader_id_iter_id_worker_id: dict[ + tuple[int, int, int], LogLoader.Worker + ] = dict() + # Maps (rank, loader_id) to loader + loaders_by_rank_loader_id: dict[tuple[int, int], LogLoader.LoaderIterator] = dict() + # Maps (rank, loader_id, iter_id) to loader + loaders_by_rank_loader_id_iter_id: dict[tuple[int, int, int], LogLoader.LoaderIterator] = ( + dict() + ) + for log_entry in self._log_reader_all(): + ph = log_entry["ph"] + name = log_entry.get("name") + if ph == "M": + if name == "process_name": + pid = log_entry["pid"] + pname = log_entry["args"]["name"] + m = self._re_pname.match(pname) + if m: + rank = int(m.group(1)) + if m.group(2) is not None: + worker_id = int(m.group(2)) + else: + worker_id = None + procs[log_entry["pid"]] = (rank, worker_id) + if name == "thread_name": + thread_name = log_entry["args"]["name"] + pid = log_entry["pid"] + tid = log_entry["tid"] + if thread_name in ("main", "worker_main"): + proc_workers[(pid, tid)] = procs[pid][1] + if ph == "n": + if name == "WebdatasetSampleLoaderDataset._slices_iter.yield": + yield LogLoader.LoadSample( + worker=workers_by_pid_tid[(log_entry["pid"], log_entry["tid"])], + base_path=log_entry["args"]["base_path"], + key=log_entry["args"]["key"], + global_sample_index=log_entry["args"]["global_sample_index"], + sample_count=log_entry["args"]["sample_count"], + epoch_idx=log_entry["args"]["epoch_idx"], + epoch_sample_count=log_entry["args"]["epoch_sample_count"], + ) + elif name == "WebdatasetSampleLoaderDataset._slices_iter.next_epoch": + yield LogLoader.LoadNextEpoch( + worker=workers_by_pid_tid[(log_entry["pid"], log_entry["tid"])], + epoch_idx=log_entry["args"]["epoch_idx"], + epoch_sample_count=log_entry["args"]["epoch_sample_count"], + ) + elif name in ("SavableDataLoader.yield", "BasicDataLoader.yield"): + rank = 
procs[log_entry["pid"]][0] + yield LogLoader.YieldSample( + worker=workers_by_rank_loader_id_iter_id_worker_id[ + (rank, log_entry["args"]["loader_id"], log_entry["args"]["worker_id"]) + ], + worker_sample_idx=log_entry["args"]["worker_sample_idx"], + sample_idx=log_entry["args"]["sample_idx"], + iter_idx=log_entry["args"]["iter_idx"], + global_sample_idx=log_entry["args"]["global_sample_idx"], + keys=log_entry["args"].get("keys", None), + ) + elif name in ("SavableDataLoader.StopIteration", "BasicDataLoader.StopIteration"): + rank = procs[log_entry["pid"]][0] + yield LogLoader.StopIteration( + loader=loaders_by_rank_loader_id_iter_id[ + (rank, log_entry["args"]["loader_id"], log_entry["args"]["iter_id"]) + ], + ) + elif ph == "B": + if name in ( + "SavableDatasetWrapper.__iter__", + "SimpleSavableDatasetWrapper.__iter__", + ): + rank = procs[log_entry["pid"]][0] + # This is not 100% correct, but it's the best mapping we can get right now. + loader = loaders_by_rank_loader_id[(rank, log_entry["args"]["loader_id"])] + worker = LogLoader.Worker( + worker_id=log_entry["args"]["worker_id"], + loader=loader, + ) + workers_by_pid_tid[(log_entry["pid"], log_entry["tid"])] = worker + workers_by_rank_loader_id_iter_id_worker_id[ + (rank, loader.loader_id, worker.worker_id) + ] = worker + yield worker + elif ph == "b": + if name in ("SavableDataLoader.__iter__", "BasicDataLoader.__iter__"): + rank = procs[log_entry["pid"]][0] + loader = loaders_by_rank_loader_id[(rank, log_entry["args"]["loader_id"])] + loader.iter_id = log_entry["args"]["iter_id"] + loaders_by_rank_loader_id_iter_id[(rank, loader.loader_id, loader.iter_id)] = ( + loader + ) + yield loader + elif name in ("SavableDataLoader", "BasicDataLoader"): + cfg_rank = log_entry["args"]["worker_config"]["rank"] + rank = procs[log_entry["pid"]][0] + assert rank == cfg_rank, f"Rank mismatch: {rank} != {cfg_rank}" + + loader = LogLoader.LoaderIterator( + world_size=log_entry["args"]["worker_config"]["world_size"], + 
rank=rank, + num_workers=log_entry["args"]["worker_config"]["num_workers"], + loader_id=log_entry["args"]["loader_id"], + iter_id=-1, + ) + # This is not 100% correct, but it's the best mapping we can get right now. + loaders_by_rank_loader_id[(rank, log_entry["args"]["loader_id"])] = loader + yield loader if __name__ == "__main__": diff --git a/src/megatron/energon/tools/analyze_debug_merge.py b/src/megatron/energon/tools/analyze_debug_merge.py new file mode 100644 index 00000000..3fcddb91 --- /dev/null +++ b/src/megatron/energon/tools/analyze_debug_merge.py @@ -0,0 +1,146 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause + +import functools +import os +import re +from pathlib import Path +from typing import ( + Callable, + Generator, + List, +) + +import click + +# Regular expressions for parsing the log file efficiently +_re_ts = re.compile(rb'"ts":(\d+)') +_re_pid = re.compile(rb'"pid":(\d+)') + + +@click.command(name="analyze-debug-merge") +@click.argument( + "log_paths", + nargs=-1, + type=click.Path(exists=True, file_okay=True, dir_okay=True, path_type=Path), +) +@click.argument( + "output_path", + type=click.Path(exists=False, writable=True, dir_okay=False, path_type=Path), +) +def command( + log_paths: List[Path], + output_path: Path, +): + """Internal tool to merge multiple debug logs into a single file. 
+ + The LOG_PATH should point to the folder with the debug log, or to a single log file.""" + + if len(log_paths) == 0: + raise click.ClickException("No log paths specified") + log_files = [] + for log_path in log_paths: + if log_path.is_dir(): + log_files.extend(sorted(log_path.glob("*.json"))) + elif log_path.is_file(): + log_files.append(log_path) + else: + raise click.ClickException(f"Invalid log path: {log_path}") + + if len(log_files) == 0: + raise click.ClickException("No log files found") + + print(f"Merging {len(log_files)} log files into {output_path}") + + entry_count = 0 + with open(output_path, "wb") as f: + f.write(b"[\n") + for entry in merge_log_reader(log_files): + f.write(entry + b",\n") + entry_count += 1 + f.seek(-2, os.SEEK_END) + f.write(b"]\n") + print(f"Merged {len(log_files)} log files with {entry_count} entries into {output_path}") + + +def merge_log_reader(log_files: List[Path]) -> Generator[bytes, None, None]: + """Merges multiple log files into a single stream of entries.""" + + # Map of (file_idx, pid) to new pid + repid_map = {} + + def get_repid(file_idx: int, pid: int) -> int: + if (file_idx, pid) in repid_map: + return repid_map[(file_idx, pid)] + repid_map[(file_idx, pid)] = len(repid_map) + return repid_map[(file_idx, pid)] + + log_readers = [ + _log_reader(log_file, functools.partial(get_repid, idx)) + for idx, log_file in enumerate(log_files) + ] + log_entries = [] + for idx in reversed(range(len(log_readers))): + reader = log_readers[idx] + try: + while True: + entry, ts = next(reader) + if ts is not None: + log_entries.append((entry, ts)) + break + yield entry + except StopIteration: + log_readers.pop(idx) + # Read the entries ordered by ts + while len(log_entries) > 0: + # Find the smallest entry, get that entry and fetch the next entry from the reader + min_ts = log_entries[0][1] + min_entry_idx = 0 + for entry_idx, (_, ts) in enumerate(log_entries[1:], 1): + if ts < min_ts: + min_ts = ts + min_entry_idx = entry_idx + 
min_entry, _ = log_entries[min_entry_idx] + yield min_entry + while True: + try: + next_entry, ts = next(log_readers[min_entry_idx]) + if ts is not None: + log_entries[min_entry_idx] = (next_entry, ts) + break + yield next_entry + except StopIteration: + del log_readers[min_entry_idx] + del log_entries[min_entry_idx] + break + + +def _log_reader( + log_file: Path, pidmap: Callable[[int], int] +) -> Generator[tuple[bytes, int | None], None, None]: + """Reads a log file and yields a tuple of the line and the ts.""" + + def pidmap_sub(match: re.Match[bytes]) -> bytes: + return b'"pid":' + str(pidmap(int(match.group(1)))).encode() + + had_end = False + with open(log_file, "rb") as f: + assert f.read(2) == b"[\n", "Log file must start with a JSON array" + for line in f: + if not line: + assert had_end, "Log file must end with a JSON array" + if line.endswith(b"]\n"): + had_end = True + else: + assert line.endswith(b",\n"), f"Log file must be newline-terminated: {line}" + line = _re_pid.sub(pidmap_sub, line) + ts = _re_ts.search(line) + if ts is None: + yield line[:-2], None + else: + yield line[:-2], int(ts.group(1)) + assert had_end, "Log file must end with a JSON array" + + +if __name__ == "__main__": + command() diff --git a/src/megatron/energon/tracing.py b/src/megatron/energon/tracing.py new file mode 100644 index 00000000..dc367784 --- /dev/null +++ b/src/megatron/energon/tracing.py @@ -0,0 +1,1609 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: BSD-3-Clause + +import gc +import json +import os +import threading +import time +import traceback +import weakref +from contextlib import AbstractContextManager, contextmanager +from dataclasses import asdict, is_dataclass +from pathlib import Path +from typing import ( + IO, + Any, + Callable, + ClassVar, + ContextManager, + Dict, + Iterable, + Literal, + Mapping, + Optional, + TypeVar, + Union, + cast, +) + +import numpy as np +import torch + +__all__ = [ + "TraceWriter", + "Span", + "AsyncSpan", + "AsyncContext", + "Flow", + "ObjectTrace", + "NoopTraceWriter", +] + +T = TypeVar("T") + +_JSON_OPEN = b"[\n" +_JSON_NEXT = b",\n" +_JSON_CLOSE = b"]\n" + + +def _timestamp_us() -> int: + """Return current time in micro-seconds as int.""" + # Use time_ns, such that it's synchronized between processes. + return time.time_ns() // 1_000 # convert ns -> µs + + +def _cur_thread_id() -> int: + """Return current thread id as int.""" + tid = threading.get_ident() + while tid > 0xFFFFFFFF: + tid = (tid & 0xFFFFFFFF) ^ (tid >> 32) + return tid + + +class JsonEncoder(json.JSONEncoder): + """Custom JSON encoder that handles numpy arrays, torch tensors, and dataclasses.""" + + def default(self, o: Any) -> Any: + # Handle numpy arrays directly + if isinstance(o, (np.ndarray, torch.Tensor)): + try: + return o.tolist() + except Exception: + return str(o)[:250] + + # Handle dataclass *instances* (exclude dataclass *types*). + if is_dataclass(o) and not isinstance(o, type): + return {"__type__": type(o).__name__, **asdict(o)} + + return super().default(o) + + +class TraceWriter(AbstractContextManager): + """Chrome-trace writer with live-streaming capabilities. + + This helper produces trace logs that follow the *Trace Event Format* as + consumed by Chrome's ``chrome://tracing`` and the Perfetto UI. We output + the simplest JSON variant – a flat **array of event objects** – because it + can be concatenated on the fly. 
+ + The public surface consists of one generic :py:meth:`emit` method that + serialises an *event dictionary* directly plus a set of convenience + helpers – :py:meth:`span`, :py:meth:`instant`, :py:meth:`async_begin`, + :py:meth:`flow_start`, :py:meth:`counter`, :py:meth:`object_new`, … – that + wrap the *phase* field (``ph``) semantics defined in the Chromium spec: + + ============ ============================================================= + Phase (``ph``) Helper(s) + ------------ ------------------------------------------------------------- + ``B``/``E`` :py:meth:`span` (or :pyclass:`Span` ctx-mgr) + ``i`` :py:meth:`instant` + ``b``/``n``/``e`` :py:meth:`async_begin`, :py:meth:`async_instant`, + :py:meth:`async_end` and the :pyclass:`AsyncSpan` + context-manager + ``s``/``t``/``f`` :py:meth:`flow_start`, :py:meth:`flow_step`, + :py:meth:`flow_end` + ``C`` :py:meth:`counter` + ``N``/``O``/``D`` :py:meth:`object_new`, :py:meth:`object_snapshot`, + :py:meth:`object_delete` and :pyclass:`ObjectTrace` + ============ ============================================================= + + For further background on each event family refer to the *Event + Descriptions* section in the Trace-Event specification. + """ + + _write_lock: threading.Lock + _pid: int + _events: int + _closed: bool + _stream: IO[bytes] + _own_stream: Optional[IO[bytes]] + _flush_interval: int + _pending: int + _log_level: int + + _global_next_id_lock: ClassVar[threading.Lock] = threading.Lock() + _global_next_id: ClassVar[int] = 0 + + def __init__( + self, + stream: Union[str, Path, IO[bytes]], + *, + pid: int | None = None, + log_level: int = 0, + ) -> None: + self._pid = pid if pid is not None else os.getpid() + self._events = 0 + self._closed = False + self._write_lock = threading.Lock() + + if isinstance(stream, (str, Path)): + # Ensure parent directory exists when stream is a path. 
+ path = Path(stream) + path.parent.mkdir(parents=True, exist_ok=True) + self._stream = path.open("wb+") + buffering = os.stat(stream).st_blksize + self._own_stream = self._stream + self._flush_interval = int(buffering * 0.8) + else: + self._stream = stream + self._own_stream = None + try: + buffering = os.stat(stream).st_blksize + except Exception: + buffering = 4096 + self._flush_interval = int(buffering * 0.8) + + self._pending = 0 + + # logging level (lower is more verbose) — default 0 + self._log_level = log_level + + # Initialise the JSON array with a closing bracket so the file is + # syntactically complete right away. + self._stream.write(_JSON_OPEN + _JSON_CLOSE) + self._stream.flush() + + # --------------------------------------------------------------------- + # Low-level helpers + # --------------------------------------------------------------------- + + @classmethod + def _next_id(cls) -> int: + """Return a new unique identifier.""" + with cls._global_next_id_lock: + cls._global_next_id += 1 + return cls._global_next_id + + def _write_raw(self, json_event: bytes, *, flush: bool = False) -> None: + """Write raw *json_event* bytes keeping the trace JSON valid. Flushes the stream if needed. + + Args: + json_event: A fully-serialised event as UTF-8 encoded JSON bytes. + flush: If *True* the underlying stream is flushed after the write. 
+ """ + with self._write_lock: + self._stream.seek(-len(_JSON_CLOSE), os.SEEK_END) + if self._events > 0: + json_event = _JSON_NEXT + json_event + _JSON_CLOSE + else: + json_event = json_event + _JSON_CLOSE + self._stream.write(json_event) + self._pending += len(json_event) + if flush or self._pending >= self._flush_interval: + self._stream.flush() + self._pending = 0 + + def close(self) -> None: + if not self._closed: + self._closed = True + with self._write_lock: + self._stream.flush() + if self._own_stream is not None: + self._own_stream.close() + self._own_stream = None + + def flush(self) -> None: + with self._write_lock: + self._stream.flush() + + def _emit(self, event: Dict[str, Any]) -> None: + """Serialize *event* mapping and append it to the trace. + + Args: + event: A dictionary that already fulfills the Trace-Event schema + expectations. + """ + json_event = json.dumps( + event, separators=(",", ":"), ensure_ascii=False, cls=JsonEncoder + ).encode("utf-8") + self._write_raw(json_event) + self._events += 1 + + # Convenience helpers -------------------------------------------------- + + def duration_begin( + self, + name: str, + *, + cat: str | None = None, + args: Optional[Dict[str, Any]] = None, + level: int = 0, + ) -> None: + """Emit a *duration* event pair. + + Args: + name: Displayed slice name. + cat: Optional comma-separated category list. + args: Extra arguments object to attach to both *B* and *E* events. + level: Logging level. + """ + if level > self._log_level: + return + event = { + "name": name, + "ph": "B", + "ts": _timestamp_us(), + "pid": self._pid, + "tid": _cur_thread_id(), + } + if cat is not None: + event["cat"] = cat + if args: + event["args"] = dict(args) + self._emit(event) + + def duration_end( + self, + name: str, + *, + cat: str | None = None, + args: Optional[Dict[str, Any]] = None, + level: int = 0, + ) -> None: + """Emit the *end* of a *duration* event pair (``ph='E'``). + + Args: + name: Displayed slice name. 
+ cat: Optional comma-separated category list. + args: Extra arguments object to attach to both *B* and *E* events. + level: Logging level. + """ + if level > self._log_level: + return + event = { + "name": name, + "ph": "E", + "ts": _timestamp_us(), + "pid": self._pid, + "tid": _cur_thread_id(), + } + if cat is not None: + event["cat"] = cat + if args: + event["args"] = dict(args) + self._emit(event) + + def span( + self, + name: str, + *, + cat: str | None = None, + args: Optional[Dict[str, Any]] = None, + level: int = 0, + ) -> "Span": + """Return a context manager capturing a *duration* event pair. + + Args: + name: Displayed slice name. + cat: Optional comma-separated category list. + args: Extra arguments object to attach to both *B* and *E* events. + level: Logging level. + + Returns: + Span – a context manager emitting matching ``B``/``E`` events. + """ + if level > self._log_level: + return _NOOP_SPAN + return Span(self, name=name, cat=cat, args=args) + + def iterable( + self, + iterable: Iterable[T], + *, + name: Optional[str] = None, + next: Optional[Callable[[], ContextManager]] = None, + level: int = 0, + ) -> Iterable[T]: + """Wrap an iterable to emit trace events for each `next` call.""" + if level > self._log_level: + return iterable + assert (name is not None) != (next is not None), "Either name xor next must be provided" + if name is not None: + return iterable_wrapper(iterable, span=lambda: self.span(name)) + else: + assert next is not None + return iterable_wrapper(iterable, span=next) + + def instant( + self, + name: str, + *, + cat: str | None = None, + scope: Optional[Literal["t", "p", "g"]] = None, + args: Optional[Dict[str, Any]] = None, + level: int = 0, + ) -> None: + """Emit a zero-duration *instant* event (``ph='i'``). + + Args: + name: Display name. + cat: Optional categories. + scope: Trace-viewer scope selector – ``t`` (thread), ``p`` (process) + or ``g`` (global). Defaults to ``t``. + args: Optional arguments payload. 
+ level: Logging level. + """ + if level > self._log_level: + return + event = { + "name": name, + "ph": "i", + "ts": _timestamp_us(), + "pid": self._pid, + "tid": _cur_thread_id(), + } + if scope is not None: + event["s"] = scope + if cat is not None: + event["cat"] = cat + if args: + event["args"] = dict(args) + self._emit(event) + + def generator( + self, + name: str, + *, + cat: str | None = None, + next_args: Optional[Dict[str, Any]] = None, + level: int = 0, + ) -> "GeneratorContext": + if level > self._log_level: + return _NOOP_GENERATOR_CONTEXT + return GeneratorContext(self, name=name, cat=cat, next_args=next_args) + + # Async events -------------------------------------------------------- + + def async_begin( + self, + name: str, + *, + id: Union[int, str, None] = None, + cat: str | None = None, + args: Optional[Dict[str, Any]] = None, + level: int = 0, + ) -> Union[int, str]: + """Start a *nestable async* chain (``ph='b'``). + + Args: + name: Event display name. + id: Correlation identifier (int or str). + cat: Optional categories. + args: Optional argument object. + level: Logging level. + """ + if id is None: + id = self._next_id() + if level > self._log_level: + return id + + event = { + "name": name, + "ph": "b", + "id": id, + "ts": _timestamp_us(), + "pid": self._pid, + "tid": _cur_thread_id(), + } + if cat is not None: + event["cat"] = cat + if args: + event["args"] = dict(args) + self._emit(event) + return id + + def async_instant( + self, + name: str, + *, + id: Union[int, str, None] = None, + cat: str | None = None, + args: Optional[Dict[str, Any]] = None, + level: int = 0, + ) -> None: + """Emit an *instant* step for a nestable async chain (``ph='n'``). + + Args: + name: Event name. + id: Correlation identifier. + cat: Categories. + args: Additional arguments. + level: Logging level. 
+ """ + if level > self._log_level: + return + if id is None: + id = self._next_id() + + event = { + "name": name, + "ph": "n", + "id": id, + "ts": _timestamp_us(), + "pid": self._pid, + "tid": _cur_thread_id(), + } + if cat is not None: + event["cat"] = cat + if args: + event["args"] = dict(args) + self._emit(event) + + def async_end( + self, + name: str, + *, + id: Union[int, str], + cat: str | None = None, + args: Optional[Dict[str, Any]] = None, + level: int = 0, + ) -> None: + """Finish a *nestable async* chain (``ph='e'``). + + Args: + id: Correlation identifier. + cat: Categories. + args: Additional arguments. + level: Logging level. + """ + if level > self._log_level: + return + event = { + "ph": "e", + "id": id, + "ts": _timestamp_us(), + "pid": self._pid, + "tid": _cur_thread_id(), + } + if cat is not None: + event["cat"] = cat + if args: + event["args"] = dict(args) + self._emit(event) + + def async_span( + self, + name: str, + *, + id: Union[int, str, None] = None, + cat: str | None = None, + args: Optional[Dict[str, Any]] = None, + level: int = 0, + ) -> "AsyncSpan": + """Return an *AsyncSpan* context-manager for a nestable async chain. + + Args: + name: Display name. + id: Correlation identifier to keep events together. + cat: Categories. + args: Arguments attached to the begin event. + level: Logging level. + + Returns: + AsyncSpan context manager. + """ + if level > self._log_level: + return _NOOP_ASYNC_SPAN + if id is None: + id = self._next_id() + + return AsyncSpan( + self, + name=name, + id=id, + cat=cat, + args=args, + ) + + def async_flow( + self, + *, + id: Union[int, str, None] = None, + cat: str | None = None, + level: int = 0, + ) -> "AsyncContext": + """Return an *AsyncFlow* context-manager for a nestable async chain. + + Args: + id: Correlation identifier. + cat: Categories. + level: Logging level. 
+ """ + if level > self._log_level: + return _NOOP_ASYNC_CONTEXT + if id is None: + id = self._next_id() + + return AsyncContext( + self, + id=id, + cat=cat, + ) + + def async_generator( + self, + name: str, + *, + id: Union[int, str, None] = None, + cat: str | None = None, + next_args: Optional[Dict[str, Any]] = None, + level: int = 0, + ) -> "AsyncGeneratorContext": + """Emit an async *generator* (``ph='g'``) event within this async flow.""" + if level > self._log_level: + return _NOOP_ASYNC_GENERATOR_CONTEXT + if id is None: + id = self._next_id() + return AsyncGeneratorContext( + self, + name=name, + id=id, + cat=cat, + next_args=next_args, + ) + + # Counter events ------------------------------------------------------ + + def counter( + self, + name: str, + value: Union[int, float, Dict[str, Union[int, float]]], + *, + id: Union[int, str, None] = None, + cat: str | None = None, + level: int = 0, + ) -> None: + """Emit a numerical *counter* (``ph='C'``). + + Args: + name: Counter track name. + value: Either a single numeric value or a mapping of series-name to + numeric value. + id: Optional counter identifier (name+id pair becomes counter key). + cat: Categories. + level: Logging level. + """ + if level > self._log_level: + return + if isinstance(value, Mapping): + args_field = value + else: + args_field = {"value": value} + if id is None: + id = self._next_id() + + event = { + "name": name, + "ph": "C", + "ts": _timestamp_us(), + "pid": self._pid, + "tid": _cur_thread_id(), + "args": args_field, + } + if id is not None: + event["id"] = id + if cat is not None: + event["cat"] = cat + self._emit(event) + + def async_object_trace( + self, + name: str, + *, + id: Union[int, str, None] = None, + cat: str | None = None, + snapshot: Optional[Dict[str, Any]] = None, + level: int = 0, + ) -> "ObjectTrace": + """Create an :class:`ObjectTrace` helper. + + Args: + name: Object type/name. + id: Identifier to correlate with future snapshots/deletion. + cat: Categories. 
+            snapshot: Optional initial snapshot emitted right after creation.
+            level: Logging level.
+
+        Returns:
+            ObjectTrace instance.
+ """ + event = { + "name": name, + "ph": "M", + "pid": self._pid, + } + if tid is not None: + event["tid"] = tid + if args: + event["args"] = dict(args) + self._emit(event) + + def metadata_process_name(self, name: str) -> None: + """Set the current process name.""" + self.metadata("process_name", args=dict(name=name)) + + def metadata_thread_name(self, name: str) -> None: + """Set the current thread name.""" + self.metadata("thread_name", args=dict(name=name), tid=_cur_thread_id()) + + # Flow events -------------------------------------------------------- + + def flow_start( + self, + name: str, + *, + id: Optional[int] = None, + cat: str | None = None, + level: int = 0, + ) -> Union[int, str]: + """Emit a *flow start* (``ph='s'``) event. The flow is bound to the enclosing slice. + + Args: + name: Display name. + id: Correlation identifier. + cat: Categories. + args: Additional arguments. + level: Logging level. + """ + if id is None: + id = self._next_id() + if level > self._log_level: + return id + + event: Dict[str, Any] = { + "name": name, + "ph": "s", + "id": id, + "ts": _timestamp_us(), + "pid": self._pid, + "tid": _cur_thread_id(), + } + if cat is not None: + event["cat"] = cat + self._emit(event) + return id + + def flow_step( + self, + name: str, + *, + id: int, + cat: str | None = None, + level: int = 0, + ) -> None: + """ + Emit a *flow step* (``ph='t'``) event. The flow is bound to the enclosing slice. + + Args: + name: The name of the flow. + id: The id of the flow. + cat: The category of the flow. + level: The level of the flow. 
+ """ + if level > self._log_level: + return + + event: Dict[str, Any] = { + "name": name, + "ph": "t", + "id": id, + "ts": _timestamp_us(), + "pid": self._pid, + "tid": _cur_thread_id(), + } + if cat is not None: + event["cat"] = cat + self._emit(event) + + def flow_end( + self, + name: str, + *, + id: int, + cat: str | None = None, + bind_enclosing_slice: bool = False, + level: int = 0, + ) -> None: + """Emit a *flow end* (``ph='f'``) event. The flow is finished either in the enclosing slice or at the next slice. + + Args: + name: The name of the flow. + id: The id of the flow. + cat: The category of the flow. + bind_enclosing_slice: If *True*, adds ``bp='e'`` to bind to the + enclosing slice (see Trace Event Format), otherwise binds to the next slice. + level: The level of the flow. + """ + if level > self._log_level: + return + + event: Dict[str, Any] = { + "name": name, + "ph": "f", + "id": id, + "ts": _timestamp_us(), + "pid": self._pid, + "tid": _cur_thread_id(), + } + if cat is not None: + event["cat"] = cat + if bind_enclosing_slice: + event["bp"] = "e" + self._emit(event) + + def flow( + self, + name: str, + *, + id: Optional[int] = None, + cat: str | None = None, + level: int = 0, + ) -> "Flow": + """Emit a *flow* event.""" + if level > self._log_level: + return _NOOP_FLOW + if id is None: + id = self._next_id() + return Flow(self, name=name, id=id, cat=cat) + + def resume_flow(self, saved_flow: dict) -> "Flow": + """Resume a flow from a dictionary.""" + if len(saved_flow) == 0: + return _NOOP_FLOW + return Flow(self, **saved_flow, resuming=True) + + # Exception --------------------------------------------------------- + + def async_exc( + self, + *, + name: str, + id: Union[int, str, None] = None, + cat: str | None = None, + level: int = 0, + ) -> None: + """Emit an *exception* event as an async instant (``ph='n'``). + + This is primarily used by :class:`AsyncFlow` to surface + exceptions that happened inside a flow. 
+ """ + + if id is None: + id = self._next_id() + if level > self._log_level: + return + + # Represent exception as string to keep JSON serialisable. + exc_repr = traceback.format_exc().splitlines() + + self.async_instant(name, id=id, cat=cat, args={"exception": exc_repr}, level=level) + + # Context management --------------------------------------------------- + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): # noqa: D401, N802 + self.close() + return False + + # ------------------------------------------------------------------ + # Representation helpers + # ------------------------------------------------------------------ + + def __repr__(self) -> str: # pragma: no cover + status = "closed" if self._closed else f"{self._events} events" + return f"" + + +class Span(AbstractContextManager): + """Context manager for *duration* events. + + See :py:meth:`TraceWriter.span`. + """ + + __slots__ = ("_writer", "_name", "_cat", "_args", "_begin_ts") + _writer: Optional[TraceWriter] + _name: str + _cat: Optional[str] + _args: Dict[str, Any] | None + + def __init__( + self, + writer: TraceWriter, + *, + name: str, + cat: str | None = None, + args: Optional[Dict[str, Any]] = None, + ) -> None: + self._writer = writer + self._name = name + self._cat = cat + self._writer.duration_begin(self._name, cat=self._cat, args=args or None) + self._args = None + + def update_args(self, args: Dict[str, Any]) -> None: + if self._args is None: + self._args = args + else: + self._args.update(args) + + def end(self, args: Optional[Dict[str, Any]] = None) -> None: + if self._writer is None: + return + if self._args and args: + self._args.update(args) + self._writer.duration_end(self._name, cat=self._cat, args=self._args or args or None) + self._args = None + self._writer = None + + # ------------------------------------------------------------------ + # Context management + # ------------------------------------------------------------------ + + def 
+        self._writer.async_end(self._name, id=self._id, cat=self._cat, args=self._args or args or None)
+        """Open an async *span* (``ph='b'``/``'e'``) within this async flow."""
+    def start(self, *args, **kwargs) -> None:
+        pass
+
+    def end(self, *args, **kwargs) -> None:
+        pass
+        if self._writer is None:
+            yield self; return
+        assert self._active_scope is not None
+        if self._writer is None:
+            yield self; return
+        assert self._active_scope is not None
+    Emits an async begin (``b``) on construction, :py:meth:`snapshot` (``n``), and an async end (``e``) upon
+ """ + + __slots__ = ( + "_writer", + "_name", + "_id", + "_cat", + ) + + _writer: Optional[TraceWriter] + _name: str + _id: Union[int, str] + _cat: Optional[str] + + def __init__( + self, + writer: TraceWriter, + *, + name: str, + id: Union[int, str], + cat: str | None = None, + initial_snapshot: Optional[Dict[str, Any]] = None, + ) -> None: + self._writer = writer + self._name = name + self._id = id + self._cat = cat + + # Emit object creation event + self._writer.async_begin(name, id=id, cat=cat, args=initial_snapshot) + + # ------------------------------------------------------------------ + # API + # ------------------------------------------------------------------ + + def snapshot(self, data: Dict[str, Any], *, level: int = 0) -> None: + """Emit snapshot for current state of the object.""" + if self._writer is None: + raise RuntimeError("Cannot snapshot deleted traced object") + self._writer.async_instant( + self._name, + id=self._id, + args=data, + cat=self._cat, + level=level, + ) + + def delete(self) -> None: + """Emit delete event if not already emitted.""" + if self._writer is None: + return + self._writer.async_end( + self._name, + id=self._id, + cat=self._cat, + ) + self._writer = None + + +class NoopObjectTrace: + def snapshot(self, *args, **kwargs) -> None: + pass + + def delete(self, *args, **kwargs) -> None: + pass + + +_NOOP_OBJECT_TRACE = cast(ObjectTrace, NoopObjectTrace()) + + +class NoopTraceWriter: + """A trace writer that does nothing. 
+    def __repr__(self) -> str:
+        return "<NoopTraceWriter>"
+        Args:
+            level: The level of the flow (checked against the writer's
+                configured log level).
+            bind_enclosing_slice: Whether to bind the flow to the enclosing slice (otherwise bind to the next slice).
+ self._writer = None + + def save(self) -> dict: + """Return a dictionary representation of the flow, allowing resuming the flow in another process.""" + return { + "name": self._name, + "id": self._id, + "cat": self._cat, + } + + +class NoopFlow: + def step(self, *args, **kwargs) -> None: + pass + + def end(self, *args, **kwargs) -> None: + pass + + def save(self) -> dict: + return {} + + +_NOOP_FLOW = cast(Flow, NoopFlow()) diff --git a/src/megatron/energon/watchdog.py b/src/megatron/energon/watchdog.py index 848d97b7..604ecf09 100644 --- a/src/megatron/energon/watchdog.py +++ b/src/megatron/energon/watchdog.py @@ -279,21 +279,25 @@ def watch_iter(self, iterable: Iterable[T]) -> Iterator[T]: """ it = iter(iterable) - while True: - # Automatically resets timer - self.enable() - try: - item = next(it) - except StopIteration: - self.disable() - break - except: - # On any error, disable watchdog before re-raising - self.disable() - raise - else: - self.disable() - yield item + try: + while True: + # Automatically resets timer + self.enable() + try: + item = next(it) + except StopIteration: + self.disable() + break + except: + # On any error, disable watchdog before re-raising + self.disable() + raise + else: + self.disable() + yield item + finally: + if hasattr(it, "close"): + it.close() def repr_short(obj: Any) -> str: diff --git a/src/megatron/energon/worker.py b/src/megatron/energon/worker.py index ec01e0b2..ee59412f 100644 --- a/src/megatron/energon/worker.py +++ b/src/megatron/energon/worker.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: BSD-3-Clause import hashlib -import json import multiprocessing from dataclasses import dataclass from pathlib import Path @@ -12,6 +11,8 @@ import torch.utils.data from megatron.energon.cache import CachePool +from megatron.energon.fork_hook import DataclassForkMixin +from megatron.energon.tracing import NOOP_TRACE_WRITER, AsyncContext, Flow, TraceWriter __all__ = ("WorkerConfig",) @@ -19,7 +20,7 @@ @dataclass(slots=True, 
kw_only=True, eq=False) -class WorkerConfig: +class WorkerConfig(DataclassForkMixin): """ Provides information about the current worker and the global configuration. This gives each data parallel rank its proper config. Every `rank` (up to `world_size-1`) must be used. @@ -57,10 +58,12 @@ class WorkerConfig: worker_debug_path: Optional[str] = None #: Log level for worker logging. worker_log_level: int = 0 + #: The current trace writer for the worker. + _worker_trace_writer: Optional[TraceWriter] = None + #: The current trace writer for the worker. + _worker_trace_sample_flow: Optional[Flow] = None #: The opened file for the current worker. Should not be set from outside. _worker_debug_file: Optional[TextIO] = None - #: worker_id of the opened worker debug file - _worker_debug_file_worker_id: Optional[int] = None #: The current sample index within the current iterating worker _sample_index_stack: ClassVar[Optional[List[int]]] = None @@ -68,7 +71,7 @@ class WorkerConfig: active_worker_config: ClassVar[Optional["WorkerConfig"]] = None #: The global rank override for the worker. Required for restoring samples. - _worker_override_global_rank: ClassVar[Optional[List[int]]] = None + _worker_override_global_rank: ClassVar[Optional[int]] = None #: The current cache pool for the worker. 
_cache_pool: "ClassVar[Optional[CachePool]]" = None @@ -252,27 +255,50 @@ def config(self) -> Dict[str, Any]: def should_log(self, level: int) -> bool: return level <= self.worker_log_level - def worker_log(self, data: dict) -> None: - """Logs the given data to the worker debug file.""" + def __after_in_child_fork__(self): + if self._worker_trace_writer is not None: + self._worker_trace_writer.close() + self._worker_trace_writer = None + self._worker_trace_sample_flow = None + + def __before_fork__(self): + if self._worker_trace_writer is not None: + self._worker_trace_writer.flush() + + def worker_trace_writer(self) -> TraceWriter: if self.worker_debug_path is None: - print(json.dumps(data) + "\n", end="", flush=True) - else: + return NOOP_TRACE_WRITER + if self._worker_trace_writer is None: in_worker = torch.utils.data.get_worker_info() is not None # Additional "worker" with rank_worker_id=0 is the main process. All workers have +1 # as their worker_id. worker_id = ( self.rank * (self.num_workers + 1) + self.rank_worker_id() + (1 if in_worker else 0) ) - if self._worker_debug_file is None or self._worker_debug_file_worker_id != worker_id: - if self._worker_debug_file is not None: - self._worker_debug_file.close() - path = Path( - self.worker_debug_path.format( - worker_id=worker_id, pid=multiprocessing.current_process().ident - ) + if self._worker_trace_writer is not None: + self._worker_trace_writer.close() + path = Path( + self.worker_debug_path.format( + worker_id=worker_id, pid=multiprocessing.current_process().ident ) - path.parent.mkdir(exist_ok=True, parents=True) - self._worker_debug_file = path.open("w") - self._worker_debug_file_worker_id = worker_id - self._worker_debug_file.write(json.dumps(data) + "\n") - self._worker_debug_file.flush() + ) + path.parent.mkdir(exist_ok=True, parents=True) + proc_name = f"dprank{self.global_rank()}" + if in_worker: + proc_name += f"_worker{self.rank_worker_id()}" + self._worker_trace_writer = TraceWriter(path, 
log_level=self.worker_log_level) + self._worker_trace_writer.metadata_process_name(proc_name) + if in_worker: + self._worker_trace_writer.metadata_thread_name("worker_main") + else: + self._worker_trace_writer.metadata_thread_name("main") + self._worker_trace_writer.flush() + return self._worker_trace_writer + + def worker_trace_span(self) -> AsyncContext: + return self.worker_trace_writer().async_flow() + + def worker_trace_sample_flow(self, level: int): + if self._worker_trace_sample_flow is None: + return + self._worker_trace_sample_flow.step(level=level) diff --git a/src/megatron/energon/wrappers/_log_exception.py b/src/megatron/energon/wrappers/_log_exception.py index a20959aa..791be674 100644 --- a/src/megatron/energon/wrappers/_log_exception.py +++ b/src/megatron/energon/wrappers/_log_exception.py @@ -2,12 +2,12 @@ # SPDX-License-Identifier: BSD-3-Clause import traceback -from typing import Optional +from typing import Optional, Sequence from megatron.energon.source_info import SourceInfo -def log_exception(_e: Exception, sample, sources: Optional[list[SourceInfo]] = None): +def log_exception(_e: Exception, sample, sources: Optional[Sequence[SourceInfo]] = None): traceback.print_exc() print("-" * 10) diff --git a/src/megatron/energon/wrappers/base.py b/src/megatron/energon/wrappers/base.py index 2f4dba89..842df313 100644 --- a/src/megatron/energon/wrappers/base.py +++ b/src/megatron/energon/wrappers/base.py @@ -78,17 +78,26 @@ def _find_wrapped_dataset(self, cls: Type[SavableDataset]) -> Optional[SavableDa def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample_out: if len(self.datasets) == 1: - return self.datasets[0].restore_sample(restore_key) + with self.worker_config.worker_trace_writer().span( + f"{type(self).__name__}.restore_sample", + level=1, + ): + return self.datasets[0].restore_sample(restore_key) else: id, ds_idx = restore_key[:2] assert id == type(self).__name__ restore_key = restore_key[2:] assert 
isinstance(ds_idx, int) - return add_sample_restore_key( - self.datasets[ds_idx].restore_sample(restore_key), - ds_idx, - src=self, - ) + with self.worker_config.worker_trace_writer().span( + f"{type(self).__name__}.restore_sample", + args={"ds_idx": ds_idx}, + level=1, + ): + return add_sample_restore_key( + self.datasets[ds_idx].restore_sample(restore_key), + ds_idx, + src=self, + ) def save_state(self) -> FlexState: own_state = super().save_state() diff --git a/src/megatron/energon/wrappers/batch_dataset.py b/src/megatron/energon/wrappers/batch_dataset.py index a7b36a0e..695f90dd 100644 --- a/src/megatron/energon/wrappers/batch_dataset.py +++ b/src/megatron/energon/wrappers/batch_dataset.py @@ -11,14 +11,19 @@ Iterator, List, Optional, + Sequence, Tuple, TypeVar, Union, ) from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError -from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key -from megatron.energon.source_info import SourceInfo +from megatron.energon.flavors.base_dataset import ( + SavableDataset, + set_sample_restore_key, +) +from megatron.energon.flavors.trace import TraceIter, trace_iter +from megatron.energon.source_info import SourceInfo, get_source_info from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers._log_exception import log_exception from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key @@ -34,7 +39,7 @@ class BatchDataset(BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_ batch_size: int batcher: Callable[[List[T_batch_sample]], T_batch] drop_last: bool - error_handler: Callable[[Exception, list[T_batch_sample], list[SourceInfo]], None] + error_handler: Callable[[Exception, List[T_batch_sample], Sequence[SourceInfo]], None] _sample_index: SampleIndex _generator_sample_keys: Optional[Any] _generator_offset: Optional[int] @@ -51,7 +56,7 @@ def __init__( batcher_config: Optional[Union[Dict[str, Any], Callable[[], 
Dict[str, Any]]]] = None, drop_last: bool = False, error_handler: Callable[ - [Exception, List[T_batch_sample], List[SourceInfo]], None + [Exception, List[T_batch_sample], Sequence[SourceInfo]], None ] = log_exception, failure_tolerance: Optional[int] = 100, worker_config: WorkerConfig, @@ -107,48 +112,29 @@ def __len__(self): + n_batches_per_worker_ceil * remaining_n_sample_workers ) - def __iter__(self) -> Iterator[T_batch]: + @trace_iter( + name=lambda self: f"BatchDataset({self._function_config(self.batcher)})", + call_args={ + "config": lambda self: self._own_config(), + }, + next_args={ + "idx": lambda self: self._sample_index.current_idx, + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_batch]: batch: List[T_batch_sample] = [] sample_restore_keys = [] last_batch_failures = 0 - if self._generator_sample_keys is not None: - sample_restore_keys = self._generator_sample_keys - assert self._generator_offset is not None - batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys] - with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx: - batch_sample = self.batcher(batch) - assert isinstance(batch_sample, Generator) - assert inspect.isgeneratorfunction(self.batcher), ( - f"Generator in {self.batcher} but not marked as such." 
- ) - target_offset = self._generator_offset - self._generator_offset = 0 - for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( - self._sample_index.iter_ctx(batch_sample, sample_idx) - ): - # Skip other samples - if batch_sub_idx >= target_offset: - self._generator_offset = batch_sub_idx + 1 - yield set_sample_restore_key( - inner_batch_sample, - sample_idx, - batch_sub_idx, - *sample_restore_keys, - src=self, - ) - self._generator_sample_keys = None - self._generator_offset = None - batch.clear() - sample_restore_keys = [] + batcher = trace_iter.wrap_fn(self.batcher) - def flush(): + def flush() -> Generator[T_batch, None, None]: nonlocal last_batch_failures try: with self._sample_index.ctx() as sample_idx: - batch_sample = self.batcher(batch) + batch_sample = batcher(batch) if isinstance(batch_sample, Generator): assert inspect.isgeneratorfunction(self.batcher), ( f"Generator in {self.batcher} but not marked as such." @@ -172,15 +158,18 @@ def flush(): else: last_batch_failures = 0 set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) + trace_iter.sample(batch_sample, {"sample_idx": sample_idx}) yield batch_sample + sample_restore_keys.clear() except GeneratorExit: raise except SkipSample: - pass + trace_iter.skip_sample(batch) except SYSTEM_EXCEPTIONS: raise FatalSampleError.from_sample(batch) except Exception as e: - self.error_handler(e, batch) + self.error_handler(e, batch, get_source_info(batch)) + trace_iter.sample_exception(e, batch) last_batch_failures += 1 if ( self.failure_tolerance is not None @@ -190,8 +179,36 @@ def flush(): batch, f"BatchDataset {self.batcher} failed {last_batch_failures} times in a row. 
Likely your code or dataset are broken.", ) - finally: - sample_restore_keys.clear() + + if self._generator_sample_keys is not None: + sample_restore_keys = self._generator_sample_keys + assert self._generator_offset is not None + batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys] + with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx: + batch_sample = batcher(batch) + assert isinstance(batch_sample, Generator) + assert inspect.isgeneratorfunction(self.batcher), ( + f"Generator in {self.batcher} but not marked as such." + ) + target_offset = self._generator_offset + self._generator_offset = 0 + for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( + self._sample_index.iter_ctx(batch_sample, sample_idx) + ): + # Skip other samples + if batch_sub_idx >= target_offset: + self._generator_offset = batch_sub_idx + 1 + yield set_sample_restore_key( + inner_batch_sample, + sample_idx, + batch_sub_idx, + *sample_restore_keys, + src=self, + ) + self._generator_sample_keys = None + self._generator_offset = None + batch.clear() + sample_restore_keys = [] for sample in self.dataset: batch.append(sample) @@ -214,40 +231,80 @@ def assert_can_restore(self) -> None: super().assert_can_restore() def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_batch: - # We need to store multiple indices to restore a batch. 
- self.assert_can_restore() - if inspect.isgeneratorfunction(self.batcher): - id, sample_idx, batch_sub_idx, *samples_restore_keys = restore_key - assert id == type(self).__name__ - else: - id, sample_idx, *samples_restore_keys = restore_key - assert id == type(self).__name__ - batch = [self.dataset.restore_sample(inner_idx) for inner_idx in samples_restore_keys] - with self._sample_index.ctx(sample_idx): - batch_sample = self.batcher(batch) - if isinstance(batch_sample, Generator): - assert inspect.isgeneratorfunction(self.batcher), ( - f"Generator in {self.batcher} but not marked as such." - ) - for cur_batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( - self._sample_index.iter_ctx(batch_sample, sample_idx) + trace_span = self.worker_config.worker_trace_span() + with trace_span.span( + "BatchDataset.restore_sample", args={"restore_key": restore_key}, level=1 + ): + # We need to store multiple indices to restore a batch. + self.assert_can_restore() + if inspect.isgeneratorfunction(self.batcher): + id, sample_idx, batch_sub_idx, *samples_restore_keys = restore_key + assert id == type(self).__name__ + else: + id, sample_idx, *samples_restore_keys = restore_key + assert id == type(self).__name__ + with trace_span.span( + "BatchDataset.restore_sample.restore", + args={"len": len(samples_restore_keys)}, + level=2, ): - if cur_batch_sub_idx == batch_sub_idx: - return set_sample_restore_key( - inner_batch_sample, - sample_idx, - batch_sub_idx, - *samples_restore_keys, - src=self, + batch = [ + self.dataset.restore_sample(inner_idx) for inner_idx in samples_restore_keys + ] + with ( + self._sample_index.ctx(sample_idx), + trace_span.span( + f"BatchDataset.restore_sample.batcher:{self._function_config(self.batcher)}", + args={"sample_idx": sample_idx, "len": len(batch)}, + level=2, + ), + ): + batch_sample = self.batcher(batch) + if isinstance(batch_sample, Generator): + assert inspect.isgeneratorfunction(self.batcher), ( + f"Generator in {self.batcher} but not 
marked as such." + ) + for cur_batch_sub_idx, (sample_idx, inner_batch_sample) in trace_span.iterable( + self._sample_index.iter_ctx(batch_sample, sample_idx), + name=f"BatchDataset.restore_sample.batcher:{self._function_config(self.batcher)}.next", + level=2, + ): + if cur_batch_sub_idx == batch_sub_idx: + return set_sample_restore_key( + inner_batch_sample, + sample_idx, + batch_sub_idx, + *samples_restore_keys, + src=self, + ) + assert False, f"Batch sub-index {batch_sub_idx} not found in batch" + else: + return set_sample_restore_key( + batch_sample, + sample_idx, + *samples_restore_keys, + src=self, + ) + + def _own_config(self) -> Dict[str, Any]: + return { + "batch_size": self.batch_size, + "batcher": self._function_config(self.batcher), + **( + { + "batcher_config": ( + self.batcher_config() + if callable(self.batcher_config) + else self.batcher_config ) - assert False, f"Batch sub-index {batch_sub_idx} not found in batch" - else: - return set_sample_restore_key( - batch_sample, - sample_idx, - *samples_restore_keys, - src=self, - ) + } + if self.batcher_config + else {} + ), + "batcher_stateless": self.batcher_stateless, + "drop_last": self.drop_last, + "error_handler": self._function_config(self.error_handler), + } def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/blend_dataset.py b/src/megatron/energon/wrappers/blend_dataset.py index e7c4e972..08bbb0ad 100644 --- a/src/megatron/energon/wrappers/blend_dataset.py +++ b/src/megatron/energon/wrappers/blend_dataset.py @@ -1,11 +1,15 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Dict, Iterator, List, Tuple, TypeVar +from typing import Any, Dict, Iterator, List, Sequence, Tuple, TypeVar import torch -from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key +from megatron.energon.flavors.base_dataset import ( + SavableDataset, + add_sample_restore_key, +) +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.rng import WorkerRng from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -19,7 +23,9 @@ class BlendDataset(BaseWrapperDataset[T_sample, T_sample]): The datasets may be infinite. This dataset is always infinite. """ + datasets: List[SavableDataset[T_sample]] weights: Tuple[float, ...] + dataset_weights: Sequence[Tuple[SavableDataset[T_sample], float]] exhausted: List[bool] _worker_rng: WorkerRng @@ -52,9 +58,15 @@ def reset_state_own(self) -> None: def __len__(self) -> int: # Give the number of samples in inner datasets, disregarding the weight - return sum(len(dataset) for dataset, weight in self.dataset_weights) - - def __iter__(self) -> Iterator[T_sample]: + return sum(len(dataset) for dataset in self.datasets) + + @trace_iter( + name=lambda self: "BlendDataset", + call_args={ + "config": lambda self: self._own_config(), + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: assert self.worker_has_samples(), "Cannot blend all empty datasets" # Create a list of datasets and their weights, but @@ -107,6 +119,11 @@ def __iter__(self) -> Iterator[T_sample]: self.exhausted = [False] * len(self.dataset_weights) + def _own_config(self) -> Dict[str, Any]: + return { + "weights": self.weights, + } + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/concat_dataset.py b/src/megatron/energon/wrappers/concat_dataset.py index 7388e3b2..b997987c 100644 --- 
a/src/megatron/energon/wrappers/concat_dataset.py +++ b/src/megatron/energon/wrappers/concat_dataset.py @@ -3,7 +3,11 @@ from typing import Any, Dict, Generic, Iterator, TypeVar -from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key +from megatron.energon.flavors.base_dataset import ( + SavableDataset, + add_sample_restore_key, +) +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -32,7 +36,13 @@ def reset_state_own(self) -> None: def __len__(self): return sum(len(dataset) for dataset in self.datasets) - def __iter__(self) -> Iterator[T_sample]: + @trace_iter( + name=lambda self: "ConcatDataset", + call_args={ + "config": lambda self: self._own_config(), + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: for ds_idx, dataset in enumerate(self.datasets): for sample in dataset: yield add_sample_restore_key( @@ -41,6 +51,9 @@ def __iter__(self) -> Iterator[T_sample]: src=self, ) + def _own_config(self) -> Dict[str, Any]: + return {} + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/epochize_dataset.py b/src/megatron/energon/wrappers/epochize_dataset.py index 6058cd27..5b6d4688 100644 --- a/src/megatron/energon/wrappers/epochize_dataset.py +++ b/src/megatron/energon/wrappers/epochize_dataset.py @@ -3,7 +3,11 @@ from typing import Any, Dict, Generic, Iterator, Optional, TypeVar -from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.base_dataset import ( + SavableDataset, + add_sample_restore_key, +) +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -48,9 +52,14 @@ def __init__( def reset_state_own(self) -> None: self._offset = 0 - def 
__iter__(self) -> Iterator[T_sample]: + @trace_iter( + name=lambda self: "EpochizeDataset", + call_args={ + "config": lambda self: self._own_config(), + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: # Compute the local length for this worker, i.e. all worker's lengths sum up to the total - if self.worker_config.num_workers <= 1: local_length = self.length else: @@ -58,48 +67,40 @@ def __iter__(self) -> Iterator[T_sample]: if self.worker_config.rank_worker_id() < self.length % self.worker_config.num_workers: local_length += 1 - if self.worker_config.should_log(level=2): - self.worker_config.worker_log( - { - "t": "EpochizeDataset.epoch_start", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), - "offset": self._offset, - "local_length": local_length, - "length": self.length, - } - ) - - offset_range = list(range(self._offset, local_length)) - - # Only iterate if there are samples to iterate - if len(offset_range) > 0: - if self._active_iter is None: - self._active_iter = iter(self.dataset) - - for idx in offset_range: - self._offset = (idx + 1) % local_length + while self._offset < local_length: + try: + if self._active_iter is None: + self._active_iter = iter(self.dataset) + + sample_offset = self._offset + self._offset += 1 try: sample = next(self._active_iter) except StopIteration: + self._active_iter = None break - yield sample - - if self.worker_config.should_log(level=2): - self.worker_config.worker_log( - { - "t": "EpochizeDataset.epoch_end", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), - "offset": self._offset, - "local_length": local_length, - "length": self.length, - } - ) + + yield add_sample_restore_key( + sample, + sample_offset, + src=self, + ) + except GeneratorExit: + if self._active_iter is not None and hasattr(self._active_iter, "close"): + self._active_iter.close() + self._active_iter = None + raise + if self._offset >= local_length: + self._offset = 0 def 
__len__(self) -> int: return self.length + def _own_config(self) -> Dict[str, Any]: + return { + "length": self.length, + } + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/filter_dataset.py b/src/megatron/energon/wrappers/filter_dataset.py index 8b99800b..5c0deda8 100644 --- a/src/megatron/energon/wrappers/filter_dataset.py +++ b/src/megatron/energon/wrappers/filter_dataset.py @@ -3,7 +3,11 @@ from typing import Any, Callable, Dict, Generic, Iterator, Optional, TypeVar, Union -from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.base_dataset import ( + SavableDataset, + add_sample_restore_key, +) +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex @@ -50,12 +54,43 @@ def reset_state_own(self) -> None: def __len__(self): return len(self.dataset) - def __iter__(self) -> Iterator[T_sample]: + @trace_iter( + name=lambda self: f"FilterDataset({self._function_config_short(self.filter_fn)})", + call_args={ + "config": lambda self: self._own_config(), + }, + next_args={ + "sample_idx": lambda self: self._sample_index.current_idx, + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: + filter_fn = trace_iter.wrap_fn(self.filter_fn) + for sample in self.dataset: - with self._sample_index.ctx(): - filter_res = self.filter_fn(sample) + with self._sample_index.ctx() as sample_idx: + filter_res = filter_fn(sample) if filter_res: - yield sample + yield add_sample_restore_key( + sample, + sample_idx, + src=self, + ) + + def _own_config(self) -> Dict[str, Any]: + return { + "filter_fn": self._function_config(self.filter_fn), + **( + { + "filter_fn_config": ( + self.filter_fn_config() + if callable(self.filter_fn_config) + else self.filter_fn_config + ) + } + if self.filter_fn_config + else {} + ), + } def 
config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/gc_dataset.py b/src/megatron/energon/wrappers/gc_dataset.py index b8a0d4c4..e4fc670a 100644 --- a/src/megatron/energon/wrappers/gc_dataset.py +++ b/src/megatron/energon/wrappers/gc_dataset.py @@ -11,6 +11,7 @@ from torch.distributed.distributed_c10d import reduce_op from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -100,27 +101,50 @@ def reset_state_own(self) -> None: def __len__(self): return len(self.dataset) - def __iter__(self) -> Iterator[T_sample]: + @trace_iter( + call_args={ + "config": lambda self: self._own_config(), + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: in_worker = torch.utils.data.get_worker_info() is not None if in_worker and not _frozen_cuda_tensors_initialized: raise GcFreezeError( "You are using GcDataset with multiple workers, but forgot to call gc_init_worker() in at least one forked worker process." 
) - if self.freeze: - gc.collect() + @trace_iter.wrap_inner() + def gc_freeze(): gc.freeze() + + @trace_iter.wrap_inner() + def gc_collect(): + gc.collect() + + @trace_iter.wrap_inner() + def gc_unfreeze(): + gc.unfreeze() + + if self.freeze: + gc_collect() + gc_freeze() try: iter = 0 for sample in self.dataset: yield sample iter += 1 if iter >= self.every_n_iter: - gc.collect() + gc_collect() iter = 0 finally: if self.freeze: - gc.unfreeze() + gc_unfreeze() + + def _own_config(self) -> Dict[str, Any]: + return { + "every_n_iter": self.every_n_iter, + "freeze": self.freeze, + } def config(self) -> Dict[str, Any]: # This is transparent, no config to be saved (it does not affect the dataset) diff --git a/src/megatron/energon/wrappers/group_batch_dataset.py b/src/megatron/energon/wrappers/group_batch_dataset.py index 8807d6b5..e16c56d6 100644 --- a/src/megatron/energon/wrappers/group_batch_dataset.py +++ b/src/megatron/energon/wrappers/group_batch_dataset.py @@ -12,6 +12,7 @@ Iterator, List, Optional, + Sequence, Tuple, TypeVar, Union, @@ -24,8 +25,9 @@ SavableDataset, set_sample_restore_key, ) +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.savable import Savable -from megatron.energon.source_info import SourceInfo +from megatron.energon.source_info import SourceInfo, get_source_info from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers._log_exception import log_exception from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex @@ -67,7 +69,7 @@ class GroupBatchDataset( sample_group_key: Callable[[T_batch_sample], Tuple[Hashable, Optional[int]]] batcher: Callable[[List[T_batch_sample]], T_batch] drop_last: bool - error_handler: Callable[[Exception, List[T_batch_sample], list[SourceInfo]], None] + error_handler: Callable[[Exception, List[T_batch_sample], Sequence[SourceInfo]], None] _group_key_sample_index: SampleIndex _batch_sample_index: SampleIndex _buckets: Dict[Hashable, 
Bucket[T_batch_sample]] @@ -83,7 +85,7 @@ def __init__( batcher_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None, drop_last: bool = False, error_handler: Callable[ - [Exception, List[T_batch_sample], list[SourceInfo]], None + [Exception, List[T_batch_sample], Sequence[SourceInfo]], None ] = log_exception, failure_tolerance: Optional[int] = 100, worker_config: WorkerConfig, @@ -126,7 +128,12 @@ def __len__(self): # Return an upper bound. This is for sure not correct. return len(self.dataset) - def __iter__(self) -> Iterator[T_batch]: + @trace_iter( + next_args={ + "idx": lambda self: self._batch_sample_index.current_idx, + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_batch]: buckets = self._buckets last_batch_failures = 0 @@ -134,6 +141,8 @@ def __iter__(self) -> Iterator[T_batch]: if buckets is None: buckets = self._buckets = dict() + batcher = trace_iter.wrap_fn(self.batcher) + # Load saved state if available for bucket in buckets.values(): bucket.samples.worker_start() @@ -144,7 +153,13 @@ def __iter__(self) -> Iterator[T_batch]: # bucket.samples.debug_print(" ") # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial done\n", end="") - def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: + @trace_iter.wrap_inner( + call_args=lambda key, bucket: { + "key": key, + "len": len(bucket.samples), + }, + ) + def flush(key: Any, bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: nonlocal last_batch_failures # Debug print the state # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flush GroupBatchDataset state:\n", end="") @@ -155,19 +170,29 @@ def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flushed: len(batch)={len(batch_items)} len(samples)={len(bucket.samples)}\n", end="") try: with self._batch_sample_index.ctx() as sample_idx: - 
batch_sample = self.batcher(batch_items) + trace_iter.sample( + batch_items, + { + "bucket": str(key), + "bucket_size": bucket.batch_size, + "sample_idx": sample_idx, + "len": len(batch_items), + }, + ) + batch_sample = batcher(batch_items) assert not isinstance(batch_sample, Generator), ( - f"Batcher {self.batcher} returned a generator, which is not supported for grouped batching yet." + f"Batcher {batcher} returned a generator, which is not supported for grouped batching yet." ) last_batch_failures = 0 set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) yield batch_sample except SkipSample: - pass + trace_iter.skip_sample(batch_items) except SYSTEM_EXCEPTIONS: raise FatalSampleError.from_sample(batch_items) except Exception as e: - self.error_handler(e, batch_items) + self.error_handler(e, batch_items, get_source_info(batch_items)) + trace_iter.sample_exception(e, batch_items) last_batch_failures += 1 if ( self.failure_tolerance is not None @@ -190,11 +215,13 @@ def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: if self.fixed_batch_size is not None: batch_size = self.fixed_batch_size except SkipSample: + trace_iter.skip_sample(sample) continue except SYSTEM_EXCEPTIONS: raise FatalSampleError.from_sample(sample) except Exception as e: - self.error_handler(e, [sample]) + self.error_handler(e, [sample], get_source_info(sample)) + trace_iter.sample_exception(e, sample) continue bucket = buckets.get(bucket_key) if bucket is None: @@ -209,12 +236,12 @@ def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: ) bucket.samples.append(sample) if len(bucket.samples) >= bucket.batch_size: - yield from flush(bucket) + yield from flush(bucket_key, bucket) # Flush out last samples if not self.drop_last: - for bucket in buckets.values(): + for bucket_key, bucket in buckets.items(): if len(bucket.samples) > 0: - yield from flush(bucket) + yield from flush(bucket_key, bucket) # Clear the buckets 
self._buckets.clear() @@ -248,18 +275,57 @@ def assert_can_restore(self) -> None: super().assert_can_restore() def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_batch: - self.assert_can_restore() - id, sample_idx, *sample_restore_keys = index - assert id == type(self).__name__ - batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys] - with self._batch_sample_index.ctx(sample_idx): - batch_sample = self.batcher(batch) - set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) - return batch_sample + trace_span = self.worker_config.worker_trace_span() + with trace_span.span( + "GroupBatchDataset.restore_sample", + args={"index": index}, + level=1, + ): + self.assert_can_restore() + id, sample_idx, *sample_restore_keys = index + assert id == type(self).__name__ + with trace_span.span( + "GroupBatchDataset.restore_sample.dataset", + args={"sample_idx": sample_idx, "len": len(sample_restore_keys)}, + level=2, + ): + batch = [ + self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys + ] + with ( + self._batch_sample_index.ctx(sample_idx), + trace_span.span( + f"GroupBatchDataset.restore_sample.batcher:{self._function_config(self.batcher)}", + args={"sample_idx": sample_idx, "len": len(batch)}, + level=2, + ), + ): + batch_sample = self.batcher(batch) + set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) + return batch_sample + + def _own_config(self) -> Dict[str, Any]: + return { + "bucket": self._function_config(self.sample_group_key), + "batcher": self._function_config(self.batcher), + **( + { + "batcher_config": ( + self.batcher_config() + if callable(self.batcher_config) + else self.batcher_config + ) + } + if self.batcher_config + else {} + ), + "batcher_stateless": self.batcher_stateless, + "drop_last": self.drop_last, + "error_handler": self._function_config(self.error_handler), + } def config(self) -> Dict[str, Any]: return { - "type": 
type(self).__qualname__, "bucket": self._function_config(self.sample_group_key), "batcher": self._function_config(self.batcher), **( @@ -276,8 +342,8 @@ def config(self) -> Dict[str, Any]: "batcher_stateless": self.batcher_stateless, "drop_last": self.drop_last, "error_handler": self._function_config(self.error_handler), - "worker_config": self.worker_config.config(), "dataset": self.dataset.config(), + "worker_config": self.worker_config.config(), } def __str__(self): diff --git a/src/megatron/energon/wrappers/iter_map_dataset.py b/src/megatron/energon/wrappers/iter_map_dataset.py index 807ad75e..f0434e0b 100644 --- a/src/megatron/energon/wrappers/iter_map_dataset.py +++ b/src/megatron/energon/wrappers/iter_map_dataset.py @@ -17,7 +17,11 @@ from torch.utils.data import IterableDataset from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError -from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key +from megatron.energon.flavors.base_dataset import ( + SavableDataset, + set_sample_restore_key, +) +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.source_info import SourceInfo from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers._log_exception import log_exception @@ -91,7 +95,17 @@ def reset_state_own(self) -> None: def __len__(self): return self.len_map_fn(len(self.dataset)) - def __iter__(self) -> Iterator[T_sample_out]: + @trace_iter( + name=lambda self: f"IterMapDataset.__iter__.iter_map_fn:{self._function_config(self.iter_map_fn)}", + call_args={ + "config": lambda self: self._own_config(), + }, + next_args={ + "idx": lambda self: self._sample_index.current_idx, + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample_out]: + iter_map_fn = trace_iter.wrap_fn(self.iter_map_fn) last_sample_wrapper = _LastSampleWrapper(self.dataset) # The iter_map_fn is stateless. 
Thus we need to know which inner sample created the # outer sample, and the relative outer sample index, so we can restore it. @@ -111,26 +125,30 @@ def reset_idx_iter() -> Generator[T_sample, None, None]: ds_iter = iter(reset_idx_iter()) - # While True will break when the inner dataset is exhausted, but may continue on exception - while True: - iter_idx = 0 - try: - for sample_idx, sample in self._sample_index.iter_ctx(self.iter_map_fn(ds_iter)): - yield set_sample_restore_key( - sample, - sample_idx, - iter_idx, - *sample_restore_keys, - src=self, - ) - sample_restore_keys.clear() - iter_idx += 1 - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(last_sample_wrapper.last_sample) - except Exception as e: - self.error_handler(e, last_sample_wrapper.last_sample) - else: - break + try: + # While True will break when the inner dataset is exhausted, but may continue on exception + while True: + iter_idx = 0 + try: + for sample_idx, sample in self._sample_index.iter_ctx(iter_map_fn(ds_iter)): + yield set_sample_restore_key( + sample, + sample_idx, + iter_idx, + *sample_restore_keys, + src=self, + ) + sample_restore_keys.clear() + iter_idx += 1 + except SYSTEM_EXCEPTIONS: + raise FatalSampleError.from_sample(last_sample_wrapper.last_sample) + except Exception as e: + self.error_handler(e, last_sample_wrapper.last_sample) + trace_iter.sample_exception(e, last_sample_wrapper.last_sample) + else: + break + finally: + ds_iter.close() def can_restore_sample(self) -> bool: return super().can_restore_sample() and self.stateless_iter_fn @@ -142,38 +160,70 @@ def assert_can_restore(self) -> None: super().assert_can_restore() def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample: + trace_span = self.worker_config.worker_trace_span() + iter_name = self._function_config(self.iter_map_fn) self.assert_can_restore() - id, sample_idx, iter_idx, *sample_restore_keys = restore_key - assert id == type(self).__name__ - assert 
isinstance(iter_idx, int) - inner_iter = iter( - self.iter_map_fn( - (self.dataset.restore_sample(inner_index) for inner_index in sample_restore_keys) + with trace_span.span( + "IterMapDataset.restore_sample", + args={"restore_key": restore_key}, + level=1, + ): + id, sample_idx, iter_idx, *sample_restore_keys = restore_key + assert id == type(self).__name__ + assert isinstance(iter_idx, int) + inner_iter = iter( + trace_span.iterable( + self.iter_map_fn( + ( + self.dataset.restore_sample(inner_index) + for inner_index in sample_restore_keys + ) + ), + name=f"{iter_name}.next", + level=2, + ) ) - ) - try: - # Skip inner yielded samples to get the correct sample - for skip_idx in range(iter_idx): - with self._sample_index.ctx(sample_idx - iter_idx + skip_idx): - next(inner_iter) - # This is the sample to restore - with self._sample_index.ctx(sample_idx): - sample = next(inner_iter) - return set_sample_restore_key( - sample, - sample_idx, - iter_idx, - *sample_restore_keys, - src=self, - ) - except StopIteration: - raise RuntimeError( - "Generator did not yield enough samples, but is marked stateless/deterministic." - ) - finally: - # Properly close if it's a generator - if hasattr(inner_iter, "close"): - inner_iter.close() + try: + # Skip inner yielded samples to get the correct sample + for skip_idx in range(iter_idx): + with self._sample_index.ctx(sample_idx - iter_idx + skip_idx): + next(inner_iter) + # This is the sample to restore + with self._sample_index.ctx(sample_idx): + sample = next(inner_iter) + return set_sample_restore_key( + sample, + sample_idx, + iter_idx, + *sample_restore_keys, + src=self, + ) + except StopIteration: + raise RuntimeError( + "Generator did not yield enough samples, but is marked stateless/deterministic." 
+ ) + finally: + # Properly close if it's a generator + if hasattr(inner_iter, "close"): + inner_iter.close() + + def _own_config(self) -> Dict[str, Any]: + return { + "iter_map_fn": self._function_config(self.iter_map_fn), + **( + { + "iter_map_fn_config": ( + self.iter_map_fn_config() + if callable(self.iter_map_fn_config) + else self.iter_map_fn_config + ) + } + if self.iter_map_fn_config + else {} + ), + "len_map_fn": self._function_config(self.len_map_fn), + "error_handler": self._function_config(self.error_handler), + } def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/limit_dataset.py b/src/megatron/energon/wrappers/limit_dataset.py index d7a491ea..a63270cb 100644 --- a/src/megatron/energon/wrappers/limit_dataset.py +++ b/src/megatron/energon/wrappers/limit_dataset.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Generic, Iterator, TypeVar from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -46,7 +47,12 @@ def reset_state_own(self) -> None: def __len__(self) -> int: return min(self.length, len(self.dataset)) - def __iter__(self) -> Iterator[T_sample]: + @trace_iter( + call_args={ + "config": lambda self: self._own_config(), + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: worker_id = self.worker_config.rank_worker_id() # Compute the local limit for this worker, i.e. 
all worker's limits sum up to the total @@ -57,18 +63,6 @@ def __iter__(self) -> Iterator[T_sample]: if worker_id < self.length % self.worker_config.num_workers: local_limit += 1 - if self.worker_config.should_log(level=2): - self.worker_config.worker_log( - { - "t": "LimitDataset.start", - "r": self.worker_config.rank, - "w": worker_id, - "offset": self.current_offset, - "local_limit": local_limit, - "limit": self.length, - } - ) - offset_range = list(range(self.current_offset, local_limit)) # Only iterate self.dataset if there are samples to iterate if len(offset_range) > 0: @@ -79,20 +73,7 @@ def __iter__(self) -> Iterator[T_sample]: self.current_offset = offset + 1 yield sample - if self.worker_config.should_log(level=2): - self.worker_config.worker_log( - { - "t": "LimitDataset.done", - "r": self.worker_config.rank, - "w": worker_id, - "offset": self.current_offset, - "local_limit": local_limit, - "limit": self.length, - } - ) - # Reset the inner dataset - self.dataset.reset_state_deep() self.current_offset = 0 if self.reset_after_epoch: self.dataset.reset_state_deep() @@ -100,6 +81,12 @@ def __iter__(self) -> Iterator[T_sample]: def worker_has_samples(self) -> bool: return super().worker_has_samples() and self.length > 0 + def _own_config(self) -> Dict[str, Any]: + return { + "length": self.length, + "reset_after_epoch": self.reset_after_epoch, + } + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/log_sample_dataset.py b/src/megatron/energon/wrappers/log_sample_dataset.py index 5947b84a..1caae23b 100644 --- a/src/megatron/energon/wrappers/log_sample_dataset.py +++ b/src/megatron/energon/wrappers/log_sample_dataset.py @@ -4,49 +4,13 @@ from typing import Any, Callable, Dict, Generic, Iterator, List, Literal, Optional, TypeVar from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.trace import TraceIter, default_get_keys, trace_iter from 
megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset T_sample = TypeVar("T_sample") -def default_get_keys(batch: Any) -> Optional[List[str]]: - """Default get_keys, which has some heuristics to find the sample keys.""" - if isinstance(batch, list): - batch = batch[0] - if ( - hasattr(batch, "__key__") - and isinstance(batch.__key__, list) - and all(isinstance(k, str) for k in batch.__key__) - ): - return batch.__key__ - elif ( - hasattr(batch, "__keys__") - and isinstance(batch.__keys__, list) - and all(isinstance(k, str) for k in batch.__keys__) - ): - return batch.__keys__ - elif ( - isinstance(batch, dict) - and "__key__" in batch - and all(isinstance(k, str) for k in batch["__key__"]) - ): - return batch["__key__"] - elif ( - isinstance(batch, dict) - and "__keys__" in batch - and all(isinstance(k, str) for k in batch["__keys__"]) - ): - return batch["__keys__"] - elif ( - isinstance(batch, dict) - and "keys" in batch - and all(isinstance(k, str) for k in batch["keys"]) - ): - return batch["keys"] - return None - - class LogSampleDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]): """This dataset logs every yielded sample to the debug logs.""" @@ -80,25 +44,15 @@ def reset_state_own(self) -> None: def __len__(self): return len(self.dataset) - def _log(self, sample: T_sample) -> None: - if self.worker_config.should_log(level=1): - log_entry = { - "t": "yield_batch", - "r": self.worker_config.rank, - "w": self.worker_config.global_worker_id(), - "m": self.mode, - "idx": self._step, - } - keys = self.get_keys_fn(sample) - if keys is not None: - log_entry["keys"] = keys - - self.worker_config.worker_log(log_entry) - - def __iter__(self) -> Iterator[T_sample]: + @trace_iter( + next_args={ + "idx": lambda self: self._step, + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: for sample in self.dataset: - self._log(sample) self._step += 1 + trace_iter.sample(sample) yield 
sample def config(self) -> Dict[str, Any]: diff --git a/src/megatron/energon/wrappers/map_dataset.py b/src/megatron/energon/wrappers/map_dataset.py index 2d544c63..8564a6a8 100644 --- a/src/megatron/energon/wrappers/map_dataset.py +++ b/src/megatron/energon/wrappers/map_dataset.py @@ -10,14 +10,19 @@ Generic, Iterator, Optional, + Sequence, Tuple, TypeVar, Union, ) from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError -from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key -from megatron.energon.source_info import SourceInfo +from megatron.energon.flavors.base_dataset import ( + SavableDataset, + add_sample_restore_key, +) +from megatron.energon.flavors.trace import TraceIter, trace_iter +from megatron.energon.source_info import SourceInfo, get_source_info from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers._log_exception import log_exception from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key @@ -31,7 +36,7 @@ class MapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T """This dataset wrapper applies a custom function to transform each sample.""" map_fn: Callable[[T_sample], Union[T_sample_out, Generator[T_sample_out, None, None]]] - error_handler: Callable[[Exception, T_sample, list[SourceInfo]], None] + error_handler: Callable[[Exception, T_sample, Sequence[SourceInfo]], None] stateless_map_fn: bool map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] _sample_index: SampleIndex @@ -49,7 +54,7 @@ def __init__( dataset: SavableDataset[T_sample], map_fn: Callable[[T_sample], Union[T_sample_out, Generator[T_sample_out, None, None]]], *, - error_handler: Callable[[Exception, T_sample, list[SourceInfo]], None] = log_exception, + error_handler: Callable[[Exception, T_sample, Sequence[SourceInfo]], None] = log_exception, stateless_map_fn: bool = False, map_fn_config: Optional[Union[Dict[str, Any], 
Callable[[], Dict[str, Any]]]] = None, failure_tolerance: Optional[int] = 100, @@ -90,15 +95,26 @@ def reset_state_own(self) -> None: def __len__(self): return len(self.dataset) - def __iter__(self) -> Iterator[T_sample_out]: + @trace_iter( + name=lambda self: f"MapDataset({self._function_config_short(self.map_fn)})", + call_args={ + "config": lambda self: self._own_config(), + }, + next_args={ + "sample_idx": lambda self: self._sample_index.current_idx, + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample_out]: last_map_failures = 0 + map_fn = trace_iter.wrap_fn(self.map_fn) + if self._generator_sample_key is not None: assert self._generator_offset is not None sample = self.dataset.restore_sample(self._generator_sample_key) # Do not increment the sample index, use previous index with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx: - mapped_sample = self.map_fn(sample) + mapped_sample = map_fn(sample) assert isinstance(mapped_sample, Generator) assert inspect.isgeneratorfunction(self.map_fn), ( f"Generator in {self.map_fn} but not marked as such." @@ -124,7 +140,7 @@ def __iter__(self) -> Iterator[T_sample_out]: restore_key = get_sample_restore_key(sample) try: with self._sample_index.ctx() as sample_idx: - mapped_sample = self.map_fn(sample) + mapped_sample = map_fn(sample) if isinstance(mapped_sample, Generator): assert inspect.isgeneratorfunction(self.map_fn), ( f"Generator in {self.map_fn} but not marked as such." 
@@ -156,11 +172,12 @@ def __iter__(self) -> Iterator[T_sample_out]: except GeneratorExit: raise except SkipSample: - pass + trace_iter.skip_sample(sample) except SYSTEM_EXCEPTIONS: raise FatalSampleError.from_sample(sample) except Exception as e: - self.error_handler(e, sample) + self.error_handler(e, sample, get_source_info(sample)) + trace_iter.sample_exception(e, sample) last_map_failures += 1 if ( self.failure_tolerance is not None @@ -181,33 +198,68 @@ def assert_can_restore(self) -> None: super().assert_can_restore() def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample_out: + trace_span = self.worker_config.worker_trace_span() self.assert_can_restore() - if inspect.isgeneratorfunction(self.map_fn): - id, sample_idx, local_idx = restore_key[:3] - assert id == type(self).__name__ - restore_key = restore_key[3:] - assert isinstance(local_idx, int) - else: - id, sample_idx = restore_key[:2] - assert id == type(self).__name__ - restore_key = restore_key[2:] - inner_sample = self.dataset.restore_sample(restore_key) - with self._sample_index.ctx(sample_idx): - mapped_sample = self.map_fn(inner_sample) - if isinstance(mapped_sample, Generator): - assert inspect.isgeneratorfunction(self.map_fn), ( - f"Generator in {self.map_fn} but not marked as such." 
- ) - for idx, (sample_idx, res_sample) in enumerate( - self._sample_index.iter_ctx(mapped_sample, sample_idx) + with trace_span.span( + "MapDataset.restore_sample", + args={"restore_key": restore_key}, + level=1, + ): + if inspect.isgeneratorfunction(self.map_fn): + id, sample_idx, local_idx = restore_key[:3] + assert id == type(self).__name__ + restore_key = restore_key[3:] + assert isinstance(local_idx, int) + else: + id, sample_idx = restore_key[:2] + assert id == type(self).__name__ + restore_key = restore_key[2:] + with trace_span.span( + "MapDataset.restore_sample.dataset", + args={"restore_key": restore_key}, + level=2, ): - if idx == local_idx: - return add_sample_restore_key(res_sample, sample_idx, local_idx, src=self) - assert False, ( - "Generator did not yield enough samples, but is marked stateless/deterministic." - ) - else: - return add_sample_restore_key(mapped_sample, sample_idx, src=self) + inner_sample = self.dataset.restore_sample(restore_key) + with ( + self._sample_index.ctx(sample_idx), + trace_span.span( + f"MapDataset.restore_sample.map_fn:{self._function_config(self.map_fn)}", + args={"sample_idx": sample_idx}, + level=2, + ), + ): + mapped_sample = self.map_fn(inner_sample) + if isinstance(mapped_sample, Generator): + assert inspect.isgeneratorfunction(self.map_fn), ( + f"Generator in {self.map_fn} but not marked as such." + ) + for idx, (sample_idx, res_sample) in trace_span.iterable( + enumerate(self._sample_index.iter_ctx(mapped_sample, sample_idx)), + name=f"MapDataset.restore_sample.map_fn:{self._function_config(self.map_fn)}.next", + level=2, + ): + if idx == local_idx: + return add_sample_restore_key(res_sample, sample_idx, local_idx, src=self) + assert False, ( + "Generator did not yield enough samples, but is marked stateless/deterministic." 
+ ) + else: + return add_sample_restore_key(mapped_sample, sample_idx, src=self) + + def _own_config(self) -> Dict[str, Any]: + return { + "map_fn": self._function_config(self.map_fn), + **( + { + "map_fn_config": ( + self.map_fn_config() if callable(self.map_fn_config) else self.map_fn_config + ) + } + if self.map_fn_config + else {} + ), + "map_fn_stateless": self.stateless_map_fn, + } def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/packing_dataset.py b/src/megatron/energon/wrappers/packing_dataset.py index c56665ea..c31917aa 100644 --- a/src/megatron/energon/wrappers/packing_dataset.py +++ b/src/megatron/energon/wrappers/packing_dataset.py @@ -12,6 +12,7 @@ Iterator, List, Optional, + Sequence, TypeVar, Union, ) @@ -22,7 +23,8 @@ add_sample_restore_key, set_sample_restore_key, ) -from megatron.energon.source_info import SourceInfo +from megatron.energon.flavors.trace import TraceIter, trace_iter +from megatron.energon.source_info import SourceInfo, get_source_info from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers._log_exception import log_exception from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key @@ -48,7 +50,7 @@ class PackingDataset( final_packer: Callable[[List[T_encoded_sample]], T_batch_sample] final_packer_stateless: bool packer_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] - error_handler: Callable[[Exception, List[T_sample], list[SourceInfo]], None] + error_handler: Callable[[Exception, List[T_sample], Sequence[SourceInfo]], None] #: The buffer for collecting the samples that shall be packed. _reading_buffer: SavableSampleBuffer @@ -61,7 +63,7 @@ class PackingDataset( #: The samples are stored sequentially in the pre_packing_buffer because #: SavableSampleBuffer doesn't support nesting. But to keep the groups #: separate, we need to store the lengths of the groups here. 
- _pre_packing_lengths: List[List[int]] + _pre_packing_lengths: List[int] #: Sample index for the pre_packer _pre_packing_sample_index: SampleIndex @@ -89,11 +91,11 @@ def __init__( final_packer: Callable[[List[T_encoded_sample]], T_batch_sample], *, final_packer_stateless: bool = False, - sample_encoder: Optional[Callable[[List[T_sample]], T_encoded_sample]] = None, + sample_encoder: Optional[Callable[[T_sample], T_encoded_sample]] = None, sample_encoder_stateless: bool = False, packer_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None, error_handler: Callable[ - [Exception, List[T_sample], list[SourceInfo]], None + [Exception, List[T_sample], Sequence[SourceInfo]], None ] = log_exception, pre_packer_failure_tolerance: Optional[int] = 100, final_packer_failure_tolerance: Optional[int] = 100, @@ -161,63 +163,43 @@ def __len__(self): return len(self.dataset) - def _fill_reading_buffer(self, source_iter: Iterator, log_progress: bool = False) -> bool: - """ - Fill the reading buffer with samples from the dataset source iterator. - - Args: - source_iter: Iterator of samples from the dataset. - log_progress: If True, log the progress of the filling. - - Returns: - True if samples are successfully read into the buffer, False if no more data. 
- """ - - if log_progress: - import tqdm - - pbar_ctx = pbar = tqdm.tqdm(total=self.buffer_size, desc="Filling reading buffer") + @trace_iter( + name=lambda self: f"PackingDataset({self._function_config(self.pre_packer)}, {self._function_config(self.final_packer)})", + call_args={ + "config": lambda self: self._own_config(), + }, + next_args={ + "idx": lambda self: self._pre_packing_sample_index.current_idx, + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_batch_sample]: + pre_packer = trace_iter.wrap_fn(self.pre_packer) + final_packer = trace_iter.wrap_fn(self.final_packer) + if self.sample_encoder is not None: + sample_encoder = trace_iter.wrap_fn(self.sample_encoder) else: - pbar_ctx = contextlib.nullcontext() - pbar = None - - with pbar_ctx: - while len(self._reading_buffer) + len(self._pre_packing_buffer) < self.buffer_size: - try: - sample = next(source_iter) - self._reading_buffer.append(sample) - if pbar is not None: - pbar.update(1) - except StopIteration: - return False - return True - - def __iter__(self) -> Iterator[T_batch_sample]: - pre_packing_lengths = self._pre_packing_lengths - # The source dataset - src_iter = iter(self.dataset) + sample_encoder = None last_pre_pack_failures = 0 last_final_pack_failures = 0 last_sample_encoder_failures = 0 - self._pre_packing_buffer.worker_start() - self._reading_buffer.worker_start() - - is_initial_pack = True - + @trace_iter.wrap_inner( + call_args=lambda pack: { + "len": len(pack), + "sample_encoder_idx": self._sample_encoder_sample_index.current_idx, + } + ) def encode_pack_samples(pack: List[T_sample]) -> List[T_encoded_sample]: """Encode the samples in the pack using the sample encoder.""" nonlocal last_sample_encoder_failures - + assert sample_encoder is not None # Apply the sample encoder to the pack - if self.sample_encoder is None: - return pack encoded_pack = [] for sample in pack: try: with self._sample_encoder_sample_index.ctx() as encode_idx: - encoded_sample = 
self.sample_encoder(sample) + encoded_sample = sample_encoder(sample) assert not isinstance(encoded_sample, Generator), "Generator not supported" encoded_pack.append( add_sample_restore_key( @@ -227,22 +209,28 @@ def encode_pack_samples(pack: List[T_sample]) -> List[T_encoded_sample]: ) ) except SkipSample: - pass + trace_iter.skip_sample(pack) except SYSTEM_EXCEPTIONS: raise FatalSampleError.from_sample(pack) except Exception as e: - self.error_handler(e, [sample]) - last_sample_encoder_failures += 1 + self.error_handler(e, pack, get_source_info(pack)) + trace_iter.sample_exception(e, pack) if ( self.sample_encoder_failure_tolerance is not None and last_sample_encoder_failures >= self.sample_encoder_failure_tolerance ): raise FatalSampleError.from_sample( pack, - f"Sample encoder {self.sample_encoder} failed {last_sample_encoder_failures} times. Likely your code or dataset are broken.", + f"Sample encoder {sample_encoder} failed {last_sample_encoder_failures} times. Likely your code or dataset are broken.", ) return encoded_pack + @trace_iter.wrap_inner( + call_args=lambda: { + "len": len(self._reading_buffer), + "pre_packing_idx": self._pre_packing_sample_index.current_idx, + } + ) def next_pre_pack(): """Take the samples from the reading buffer and select groups of samples to be packed together.""" @@ -254,17 +242,19 @@ def next_pre_pack(): samples = list(self._reading_buffer) # Clear buffer and pre_packing_lengths self._reading_buffer.clear() - pre_packing_lengths.clear() + self._pre_packing_lengths.clear() # Now pre pack the samples try: with self._pre_packing_sample_index.ctx(): - pre_packs = self.pre_packer(samples) + pre_packs = pre_packer(samples) except SkipSample: pre_packs = [] + trace_iter.skip_sample(samples) except SYSTEM_EXCEPTIONS: raise FatalSampleError.from_sample(samples) except Exception as e: - self.error_handler(e, samples) + self.error_handler(e, samples, get_source_info(samples)) + trace_iter.sample_exception(e, samples) pre_packs = [] 
last_pre_pack_failures += 1 if ( @@ -273,7 +263,7 @@ def next_pre_pack(): ): raise FatalSampleError.from_sample( samples, - f"Pre packer {self.pre_packer} failed {last_pre_pack_failures} times. Likely your code or dataset are broken.", + f"Pre packer {pre_packer} failed {last_pre_pack_failures} times. Likely your code or dataset are broken.", ) # Put the pre-packed samples into the pre_packing_buffer @@ -283,23 +273,33 @@ def next_pre_pack(): for pre_pack in pre_packs: if len(pre_pack) > 0: self._pre_packing_buffer.extend(pre_pack) - pre_packing_lengths.append(len(pre_pack)) + self._pre_packing_lengths.append(len(pre_pack)) + @trace_iter.wrap_inner( + call_args=lambda pack: { + "len": len(pack), + "final_packing_idx": self._final_packing_sample_index.current_idx, + } + ) def next_final_pack() -> Generator[T_batch_sample, None, None]: """Yield the next packs from the buffer. The final packer is called on the fly.""" nonlocal last_final_pack_failures - pack = list(self._pre_packing_buffer[: pre_packing_lengths[0]]) + pack = list(self._pre_packing_buffer[: self._pre_packing_lengths[0]]) if len(pack) == 0: return - pack = encode_pack_samples(pack) - - del self._pre_packing_buffer[: pre_packing_lengths[0]] - del pre_packing_lengths[0] + if sample_encoder is not None: + pack = encode_pack_samples(pack) + if len(pack) == 0: + # All samples in the pack were skipped + return + + del self._pre_packing_buffer[: self._pre_packing_lengths[0]] + del self._pre_packing_lengths[0] try: pack_restore_keys = tuple(get_sample_restore_key(sample) for sample in pack) with self._final_packing_sample_index.ctx() as pack_idx: - final_packed_sample = self.final_packer(pack) + final_packed_sample = final_packer(pack) if isinstance(final_packed_sample, Generator): assert inspect.isgeneratorfunction(self.final_packer), ( f"Generator in {self.final_packer} but not marked as such." 
@@ -322,11 +322,12 @@ def next_final_pack() -> Generator[T_batch_sample, None, None]: src=self, ) except SkipSample: - pass + trace_iter.skip_sample(pack) except SYSTEM_EXCEPTIONS: raise FatalSampleError.from_sample(pack) except Exception as e: - self.error_handler(e, pack) + self.error_handler(e, pack, get_source_info(pack)) + trace_iter.sample_exception(e, pack) last_final_pack_failures += 1 if ( self.final_packer_failure_tolerance is not None @@ -337,46 +338,101 @@ def next_final_pack() -> Generator[T_batch_sample, None, None]: f"Final packer {self.final_packer} failed {last_final_pack_failures} times. Likely your code or dataset are broken.", ) - # Main loop: - pre_pack_round = 0 - while True: - if pre_pack_round > self.pre_packer_failure_tolerance: - raise RuntimeError( - f"Pre packer {self.pre_packer} did not yield any packs after {pre_pack_round} rounds. Likely your code or dataset are broken." - ) - # Fill a portion of the buffer - if not self._fill_reading_buffer(src_iter, log_progress=is_initial_pack): - # Break out of the main loop when the source is exhausted. - break - is_initial_pack = False - - # Create new pre packs if necessary - if len(pre_packing_lengths) == 0: - assert len(self._pre_packing_buffer) == 0 - assert len(self._reading_buffer) == self.buffer_size - next_pre_pack() - if len(pre_packing_lengths) == 0: - # Retry packing, nothing was returned. - pre_pack_round += 1 - continue + @trace_iter.wrap_inner( + call_args=lambda source_iter, log_progress: { + "to_fill": self.buffer_size + - len(self._reading_buffer) + - len(self._pre_packing_buffer), + "reading_buffer": len(self._reading_buffer), + "pre_packing_buffer": len(self._pre_packing_buffer), + "buffer_size": self.buffer_size, + } + ) + def fill_reading_buffer( + source_iter: Iterator[T_sample], log_progress: bool = False + ) -> bool: + """ + Fill the reading buffer with samples from the dataset source iterator. + + Args: + source_iter: Iterator of samples from the dataset. 
+ log_progress: If True, log the progress of the filling. + + Returns: + True if samples are successfully read into the buffer, False if no more data. + """ + + if log_progress: + import tqdm + + pbar_ctx = pbar = tqdm.tqdm(total=self.buffer_size, desc="Filling reading buffer") + else: + pbar_ctx = contextlib.nullcontext() + pbar = None + + with pbar_ctx: + while len(self._reading_buffer) + len(self._pre_packing_buffer) < self.buffer_size: + try: + sample = next(source_iter) + self._reading_buffer.append(sample) + if pbar is not None: + pbar.update(1) + except StopIteration: + return False + return True + + # The source dataset + src_iter = iter(self.dataset) + + try: + self._pre_packing_buffer.worker_start() + self._reading_buffer.worker_start() + + is_initial_pack = True + + pre_pack_round = 0 - if len(pre_packing_lengths) > 0: + # Main loop: + while True: + if pre_pack_round > self.pre_packer_failure_tolerance: + raise RuntimeError( + f"Pre packer {self.pre_packer} did not yield any packs after {pre_pack_round} rounds. Likely your code or dataset are broken." + ) + # Fill a portion of the buffer + if not fill_reading_buffer(src_iter, log_progress=is_initial_pack): + # Break out of the main loop when the source is exhausted. + break + is_initial_pack = False + + # Create new pre packs if necessary + if len(self._pre_packing_lengths) == 0: + assert len(self._pre_packing_buffer) == 0 + assert len(self._reading_buffer) == self.buffer_size + next_pre_pack() + if len(self._pre_packing_lengths) == 0: + # Retry packing, nothing was returned. 
+ pre_pack_round += 1 + continue + # Reset the pre pack round counter for failing pre_pack_round = 0 - yield from next_final_pack() + yield from next_final_pack() - # Yield the remaining packs, flushing the collecting buffer - while len(pre_packing_lengths) > 0: - yield from next_final_pack() + # Yield the remaining packs, flushing the collecting buffer + while len(self._pre_packing_lengths) > 0: + yield from next_final_pack() - # If there are still samples in the partial reading buffer, pre-pack them and yield the - # resulting (partial) packs - if len(self._reading_buffer) > 0: - next_pre_pack() + # If there are still samples in the partial reading buffer, pre-pack them and yield the + # resulting (partial) packs + if len(self._reading_buffer) > 0: + next_pre_pack() - # Yield the remaining packs, flushing the collecting buffer - while len(pre_packing_lengths) > 0: - yield from next_final_pack() + # Yield the remaining packs, flushing the collecting buffer + while len(self._pre_packing_lengths) > 0: + yield from next_final_pack() + finally: + if hasattr(src_iter, "close"): + src_iter.close() def can_restore_sample(self) -> bool: # Cannot really verify if the returned elements contain a __restore_key__. @@ -394,52 +450,96 @@ def assert_can_restore(self): super().assert_can_restore() def restore_sample(self, restore_key: Any) -> T_sample: + trace_span = self.worker_config.worker_trace_span() # We need to store multiple indices to restore a batch. 
self.assert_can_restore() - if inspect.isgeneratorfunction(self.final_packer): - id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key - id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key - assert id == type(self).__name__ - else: - id, pack_idx, *pack_restore_keys = restore_key - id, pack_idx, *pack_restore_keys = restore_key - assert id == type(self).__name__ - - pack = [] - for inner_idx in pack_restore_keys: - if self.sample_encoder is not None: - id, sample_idx, *inner_idx = inner_idx + with trace_span.span( + "PackingDataset.restore_sample", args={"restore_key": restore_key}, level=1 + ): + if inspect.isgeneratorfunction(self.final_packer): + id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key assert id == type(self).__name__ - id, sample_idx, *inner_idx = inner_idx + else: + id, pack_idx, *pack_restore_keys = restore_key assert id == type(self).__name__ - assert isinstance(sample_idx, int) - sample = self.dataset.restore_sample(inner_idx) - if self.sample_encoder is not None: - with self._sample_encoder_sample_index.ctx(sample_idx): - sample = self.sample_encoder(sample) - assert not isinstance(sample, Generator), "Generator not supported" - sample = add_sample_restore_key(sample, sample_idx, src=self) - pack.append(sample) - with self._final_packing_sample_index.ctx(pack_idx): - final_pack = self.final_packer(pack) - if isinstance(final_pack, Generator): - assert inspect.isgeneratorfunction(self.final_packer), ( - f"Generator in {self.final_packer} but not marked as such." 
- ) - for cur_batch_sub_idx, (pack_idx, inner_batch_sample) in enumerate( - self._final_packing_sample_index.iter_ctx(final_pack, pack_idx) + + with trace_span.span( + "PackingDataset.restore_sample.restore_samples", + args={"len": len(pack_restore_keys)}, + level=2, ): - if cur_batch_sub_idx == pack_sub_idx: - return set_sample_restore_key( - inner_batch_sample, - pack_idx, - pack_sub_idx, - *pack_restore_keys, - src=self, + pack = [] + for inner_sample_idx, inner_idx in enumerate(pack_restore_keys): + if self.sample_encoder is not None: + id, sample_idx, *inner_idx = inner_idx + assert id == type(self).__name__ + assert isinstance(sample_idx, int) + with trace_span.span( + "PackingDataset.restore_sample.dataset", + args={"sample_idx": inner_sample_idx}, + level=2, + ): + sample = self.dataset.restore_sample(inner_idx) + if self.sample_encoder is not None: + with ( + self._sample_encoder_sample_index.ctx(sample_idx), + trace_span.span( + f"PackingDataset.restore_sample.sample_encoder:{self._function_config(self.sample_encoder)}", + args={"sample_idx": sample_idx}, + level=2, + ), + ): + sample = self.sample_encoder(sample) + assert not isinstance(sample, Generator), "Generator not supported" + sample = add_sample_restore_key(sample, sample_idx, src=self) + pack.append(sample) + with ( + self._final_packing_sample_index.ctx(pack_idx), + trace_span.span( + f"PackingDataset.restore_sample.final_packer:{self._function_config(self.final_packer)}", + args={"pack_idx": pack_idx}, + level=2, + ), + ): + final_pack = self.final_packer(pack) + if isinstance(final_pack, Generator): + assert inspect.isgeneratorfunction(self.final_packer), ( + f"Generator in {self.final_packer} but not marked as such." 
+ ) + for cur_batch_sub_idx, (pack_idx, inner_batch_sample) in trace_span.iterable( + self._final_packing_sample_index.iter_ctx(final_pack, pack_idx), + name=f"PackingDataset.restore_sample.final_packer:{self._function_config(self.final_packer)}.next", + level=2, + ): + if cur_batch_sub_idx == pack_sub_idx: + return set_sample_restore_key( + inner_batch_sample, + pack_idx, + pack_sub_idx, + *pack_restore_keys, + src=self, + ) + assert False, f"Pack sub-index {pack_sub_idx} not found in pack" + else: + return set_sample_restore_key(final_pack, pack_idx, *pack_restore_keys, src=self) + + def _own_config(self) -> Dict[str, Any]: + return { + "buffer_size": self.buffer_size, + "pre_packer": self._function_config(self.pre_packer), + "final_packer": self._function_config(self.final_packer), + "final_packer_stateless": self.final_packer_stateless, + **( + { + "packer_config": ( + self.packer_config() if callable(self.packer_config) else self.packer_config ) - assert False, f"Pack sub-index {pack_sub_idx} not found in pack" - else: - return set_sample_restore_key(final_pack, pack_idx, *pack_restore_keys, src=self) + } + if self.packer_config + else {} + ), + "error_handler": self._function_config(self.error_handler), + } def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/repeat_dataset.py b/src/megatron/energon/wrappers/repeat_dataset.py index 1261adf8..613b7450 100644 --- a/src/megatron/energon/wrappers/repeat_dataset.py +++ b/src/megatron/energon/wrappers/repeat_dataset.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Generic, Iterator, Optional, TypeVar, Union from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -52,11 +53,36 @@ def __len__(self): return len(self.dataset) return int(len(self.dataset) * self.repeats) - def __iter__(self) -> 
Iterator[T_sample]: + @trace_iter( + next_args={ + "idx": lambda self: self._index, + }, + call_args={ + "repetition": lambda self: self._repetition, + "inner_len": lambda self: len(self.dataset), + "config": lambda self: self._own_config(), + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: assert self.repeats is not None or self.dataset.worker_has_samples(), ( "Cannot repeat empty dataset indefinitely" ) + @trace_iter.wrap_inner( + call_args=lambda stop_after: { + "repetition": self._repetition, + "inner_len": len(self.dataset), + "stop_after": stop_after, + } + ) + def repeat(stop_after: Optional[int]): + for sample in self.dataset: + self._index += 1 + yield sample + + if stop_after is not None and self._index >= stop_after: + break + ds_len = len(self.dataset) while self.repeats is None or self._repetition < self.repeats: @@ -70,22 +96,7 @@ def __iter__(self) -> Iterator[T_sample]: else: stop_after = None - for sample in self.dataset: - self._index += 1 - yield sample - if stop_after is not None and self._index >= stop_after: - break - - if self.worker_config.should_log(level=2): - self.worker_config.worker_log( - { - "t": "RepeatDataset.repeat", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), - "offset": self._repetition, - "repeats": self.repeats, - } - ) + yield from repeat(stop_after) self._repetition += 1 self._index = 0 @@ -95,6 +106,11 @@ def __iter__(self) -> Iterator[T_sample]: # No more repeats self._repetition = math.ceil(self.repeats) + def _own_config(self) -> Dict[str, Any]: + return { + "repeats": self.repeats, + } + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py index 1aa41496..11c23343 100644 --- a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py +++ b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py @@ -1,9 +1,10 @@ # 
Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Dict, Generic, Iterator, Tuple, TypeVar, Union +from typing import Any, Dict, Generic, Iterator, List, TypeVar from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.rng import WorkerRng from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -18,8 +19,10 @@ class ShuffleBufferDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sam size: int _worker_rng: WorkerRng _active_buffer: SavableSampleBuffer[T_sample] + _iterations: int + _sample_creation: List[int] - _savable_fields = ("_active_buffer", "_worker_rng") + _savable_fields = ("_active_buffer", "_worker_rng", "_iterations", "_sample_creation") def __init__( self, @@ -36,28 +39,55 @@ def __init__( def reset_state_own(self) -> None: self._worker_rng = WorkerRng(self.worker_config) self._active_buffer = SavableSampleBuffer(self.dataset, worker_config=self.worker_config) + self._iterations = 0 + self._sample_creation = [] def __len__(self) -> int: return len(self.dataset) - def __iter__(self) -> Iterator[T_sample]: + @trace_iter( + call_args={ + "config": lambda self: self._own_config(), + }, + next_args={ + "idx": lambda self: self._sample_creation[-1], + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: self._active_buffer.worker_start() it = iter(self._active_buffer.append_iter()) - while True: - if len(self._active_buffer) >= self.size: - pop_idx = self._worker_rng.randbelow(len(self._active_buffer)) - yield self._active_buffer.pop(pop_idx) - else: - try: - next(it) - except StopIteration: - break + try: + while True: + if len(self._active_buffer) >= self.size: + pop_idx = self._worker_rng.randbelow(len(self._active_buffer)) + sample_creation = self._sample_creation.pop(pop_idx) + trace_iter.sample( + 
self._active_buffer.pop(pop_idx), + { + "idx": pop_idx, + "sample_creation": sample_creation, + "sample_age": self._iterations - sample_creation, + }, + ) + yield self._active_buffer.pop(pop_idx) + else: + try: + next(it) + self._sample_creation.append(self._iterations) + self._iterations += 1 + except StopIteration: + break + finally: + if hasattr(it, "close"): + it.close() while len(self._active_buffer) > 0: pop_idx = self._worker_rng.randbelow(len(self._active_buffer)) yield self._active_buffer.pop(pop_idx) - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample: - return self._active_buffer.restore_sample(restore_key) + def _own_config(self) -> Dict[str, Any]: + return { + "size": self.size, + } def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/watchdog_dataset.py b/src/megatron/energon/wrappers/watchdog_dataset.py index e11e7b5d..41c97806 100644 --- a/src/megatron/energon/wrappers/watchdog_dataset.py +++ b/src/megatron/energon/wrappers/watchdog_dataset.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Generic, Iterator, Optional, TypeVar from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.watchdog import Watchdog from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -45,6 +46,7 @@ def __len__(self): return len(self.dataset) def _watchdog_trigger(self) -> None: + self.worker_config.worker_trace_span().instant("WatchdogDataset._watchdog_trigger", level=2) if self.fail_on_timeout: # Raising an exception here will kill the whole process raise TimeoutError( @@ -56,7 +58,13 @@ def _watchdog_trigger(self) -> None: RuntimeWarning, ) - def __iter__(self) -> Iterator[T_sample]: + @trace_iter( + name=lambda self: f"WatchdogDataset({self._function_config(self.dataset)})", + call_args={ + "config": lambda self: self._own_config(), + }, + ) + def 
__iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: if self.timeout_seconds is None: yield from self.dataset else: @@ -68,6 +76,13 @@ def __iter__(self) -> Iterator[T_sample]: ) yield from watchdog.watch_iter(self.dataset) + def _own_config(self) -> Dict[str, Any]: + return { + "timeout_seconds": self.timeout_seconds, + "initial_timeout_seconds": self.initial_timeout_seconds, + "fail_on_timeout": self.fail_on_timeout, + } + def config(self) -> Dict[str, Any]: # Watchdog is transparent, it won't change the samples return self.dataset.config() diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 43567c9d..cf841041 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1665,18 +1665,22 @@ def test_debug_dataset(self): world_size=1, num_workers=2, worker_log_level=3, - worker_debug_path=str(self.dataset_path) + "/worker_debug/{worker_id}.jsonl", + worker_debug_path=str(self.dataset_path) + "/worker_debug/{worker_id}.json", + # worker_debug_path="./tmp_worker_debug/{worker_id}.json", ) # Reset this to 0 to make sure the test is deterministic SavableDataLoader._next_id = 0 loader = get_savable_loader( - get_val_dataset( + get_train_dataset( self.dataset_path, split_part="train", batch_size=5, worker_config=worker_config, + shuffle_buffer_size=10, + max_samples_per_sequence=None, + virtual_epoch_length=10, ), ) @@ -1684,35 +1688,35 @@ def test_debug_dataset(self): samples = [[batch.__key__ for batch in loader] for _ in range(2)] print(samples) + del loader + gc.collect() debug_log_path = self.dataset_path / "worker_debug" - assert (debug_log_path / "0.jsonl").is_file() - assert (debug_log_path / "1.jsonl").is_file() - assert (debug_log_path / "2.jsonl").is_file() - - collected_keys_order = [[None] * 10 for _ in range(2)] - with (debug_log_path / "0.jsonl").open() as rf: - for line in rf: - line_data = json.loads(line) - if line_data["t"] == "SavableDataLoader.yield": - print(line_data) - for i in range(len(collected_keys_order)): - if 
collected_keys_order[i][line_data["idx"]] is None: - collected_keys_order[i][line_data["idx"]] = line_data["keys"] - break - else: - assert False, "Too many entries for key" - - print(collected_keys_order) - assert collected_keys_order == samples + assert (debug_log_path / "0.json").is_file(), f"{list(debug_log_path.iterdir())}" + assert (debug_log_path / "1.json").is_file(), f"{list(debug_log_path.iterdir())}" + assert (debug_log_path / "2.json").is_file(), f"{list(debug_log_path.iterdir())}" + + collected_keys = defaultdict(list) + with (debug_log_path / "0.json").open() as rf: + raw = json.load(rf) + for entry in raw: + if entry["ph"] == "n" and entry["name"] == "SavableDataLoader.yield": + # print(entry) + collected_keys[entry["args"]["global_sample_idx"]].extend(entry["args"]["keys"]) + assert len(raw) > 0 + dst_keys = [ + [collected_keys[i] for i in range(10)], + [collected_keys[i] for i in range(10, 20)], + ] + + print(dst_keys) + assert dst_keys == samples runner = CliRunner() result = runner.invoke( analyze_debug_command, [ str(debug_log_path), - "--include-modality", - "train,val", "--heatmap-path", str(self.dataset_path / "heatmap.png"), ], diff --git a/tests/test_epathlib.py b/tests/test_epathlib.py index c344a066..f3aabd92 100644 --- a/tests/test_epathlib.py +++ b/tests/test_epathlib.py @@ -139,6 +139,7 @@ def test_s3_path_resolution(self): from multistorageclient.rclone import read_rclone_config read_rclone_config.cache_clear() + try: # Test globbing p = EPath("msc://s3/tmp/path/subpath.txt")