Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/megatron/energon/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import click

from megatron.energon.tools.analyze_debug import command as analyze_debug_command
from megatron.energon.tools.analyze_debug_merge import command as analyze_debug_merge_command
from megatron.energon.tools.checkpoint import command as checkpoint_command
from megatron.energon.tools.info import command as info_command
from megatron.energon.tools.lint import command as lint_command
Expand Down Expand Up @@ -51,6 +52,7 @@ def main(ctx):


main.add_command(analyze_debug_command)
main.add_command(analyze_debug_merge_command)
main.add_command(checkpoint_command)
main.add_command(lint_command)
main.add_command(info_command)
Expand Down
8 changes: 5 additions & 3 deletions src/megatron/energon/flavors/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
import functools
import inspect
import typing
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -345,6 +346,10 @@ def _function_config(fn: Callable) -> str:
mod_name = getattr(fn, "__module__", "<unknown>")
return f"{mod_name}.{getattr(fn, '__qualname__', getattr(fn, '__name__', '<unknown>'))}"

@staticmethod
def _function_config_short(fn: Callable) -> str:
return getattr(fn, "__qualname__", getattr(fn, "__name__", "<unknown>"))

@abstractmethod
def config(self) -> Dict[str, Any]:
"""Return a config dict that can be used to check if datasets have the same settings.
Expand Down Expand Up @@ -447,9 +452,6 @@ def legacy_handler(
],
) -> Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None]:
"""Safely returns the new style three argument handler. If the handler takes 2 arguments, it wraps it."""
import functools
import inspect

handler_sig = inspect.signature(handler)
if len(handler_sig.parameters) != 3:
original_handler = handler
Expand Down
188 changes: 188 additions & 0 deletions src/megatron/energon/flavors/trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import functools
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, TypeVar, Union

from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.tracing import AsyncContext

T_dataset = TypeVar("T_dataset", bound=SavableDataset)
T_call = TypeVar("T_call", bound=Callable)
T_sample = TypeVar("T_sample")


def _flatten_str_list(keys: Any) -> Iterator[Optional[str]]:
"""Flatten a list of keys into a list of strings."""
if isinstance(keys, str):
yield keys
elif isinstance(keys, (list, tuple)):
for key in keys:
yield from _flatten_str_list(key)
else:
yield None


def _flatten_str_list_or_none(keys: Any) -> Optional[List[str]]:
    """Flatten *keys* into a flat list of strings; return ``None`` if any leaf is not a string."""
    result: List[str] = []
    for entry in _flatten_str_list(keys):
        if entry is None:
            # A non-string leaf invalidates the whole key list.
            return None
        result.append(entry)
    return result


def default_get_keys(batch: Any) -> Optional[List[str]]:
    """Default get_keys: heuristically extract sample keys from *batch*.

    Looks for a ``__key__``/``__keys__`` attribute, then the ``__key__``/``__keys__``/
    ``keys`` entries of a dict, recursing into lists. Returns ``None`` if keys
    cannot be determined for every element.
    """
    if isinstance(batch, list):
        collected: List[str] = []
        for element in batch:
            element_keys = default_get_keys(element)
            if element_keys is None:
                # One unresolvable element makes the whole batch unresolvable.
                return None
            collected.extend(element_keys)
        return collected
    # Attribute lookup takes precedence over dict-style lookup.
    for attr_name in ("__key__", "__keys__"):
        if hasattr(batch, attr_name):
            return _flatten_str_list_or_none(getattr(batch, attr_name))
    if isinstance(batch, dict):
        for dict_key in ("__key__", "__keys__", "keys"):
            if dict_key in batch:
                return _flatten_str_list_or_none(batch[dict_key])
    return None


class TraceIter:
    """Per-iteration tracing helper handed to a traced ``__iter__`` implementation.

    Collects per-sample trace arguments (``last_args``) that are attached to the
    next yielded sample, and emits instant trace events for skipped or failed
    samples on the dataset's trace span.
    """

    # Trace args to attach to the next yielded sample; cleared after each yield.
    last_args: Dict[str, Any]

    def __init__(
        self,
        outer_self: T_dataset,
        name: str,
        trace_span: AsyncContext,
        call_args: Dict[str, Union[str, Callable[[T_dataset], Any]]],
    ):
        self.outer_self = outer_self
        self.name = name
        self.trace_span = trace_span
        self.call_args = call_args
        # Bugfix: ``last_args`` must be a per-instance dict. The previous
        # class-level ``= {}`` default was shared by ALL TraceIter instances,
        # so concurrently iterated datasets mixed/clobbered each other's
        # per-sample trace args.
        self.last_args = {}

    def sample_exception(
        self, exception: Exception, samples: Union[T_sample, Sequence[T_sample]]
    ) -> None:
        """Emit an instant trace event for *samples* that raised *exception* and are skipped."""
        self.trace_span.instant(
            f"{self.name}.error/skip",
            args={
                "exception": f"{type(exception).__name__}: {str(exception)}",
                "sample_keys": default_get_keys(samples),
                **{
                    # Callable arg values are evaluated lazily against the dataset.
                    arg_name: arg_value(self.outer_self) if callable(arg_value) else arg_value
                    for arg_name, arg_value in self.call_args.items()
                },
            },
            level=1,
        )

    def skip_sample(self, samples: Sequence[T_sample]) -> None:
        """Emit an instant trace event for *samples* that are skipped without an error."""
        self.trace_span.instant(
            f"{self.name}.skip",
            args={
                "sample_keys": default_get_keys(samples),
            },
            level=1,
        )

    def sample(
        self, sample: Union[T_sample, Sequence[T_sample]], args: Optional[Dict[str, Any]] = None
    ) -> None:
        """Record the keys of the sample about to be yielded plus extra trace *args*.

        Args:
            sample: The sample (or sequence of samples) about to be yielded.
            args: Optional extra trace args merged into ``last_args``. (Default
                changed from a shared mutable ``{}`` to ``None``; observable
                behavior is unchanged since the dict was only ever read.)
        """
        self.last_args["sample_keys"] = default_get_keys(sample)
        if args:
            self.last_args.update(args)

    def wrap_fn(self, fn: T_call) -> T_call:
        """Wrap *fn* so that every call is traced as a span with the configured call args."""
        fn_name = getattr(fn, "__qualname__", getattr(fn, "__name__", "<unknown>"))

        @functools.wraps(fn)
        def wrapped_fn(*args, **kwargs):
            with self.trace_span.span(
                f"{self.name}.{fn_name}.call",
                args={
                    arg_name: arg_value(self.outer_self) if callable(arg_value) else arg_value
                    for arg_name, arg_value in self.call_args.items()
                },
                level=2,
            ):
                return fn(*args, **kwargs)

        return wrapped_fn

    def wrap_inner(self, call_args: Callable[..., Dict[str, Any]] = lambda *args, **kwargs: {}):
        """Return a decorator tracing calls to an inner iterator factory.

        Args:
            call_args: Maps the wrapped function's call arguments to extra trace args.
        """

        def decorator(fn):
            fn_name = getattr(fn, "__qualname__", getattr(fn, "__name__", "<unknown>"))

            @functools.wraps(fn)
            def wrapped_inner_gen(*args, **kwargs):
                with self.trace_span.span(
                    f"{self.name}.{fn_name}.__iter__", args=call_args(*args, **kwargs), level=2
                ):
                    return fn(*args, **kwargs)

            return wrapped_inner_gen

        return decorator


def trace_iter(
    name: Callable[[T_dataset], str] = lambda ds: type(ds).__name__,
    call_args: Dict[str, Union[str, Callable[[T_dataset], Any]]] = {},
    next_args: Dict[str, Union[str, Callable[[T_dataset], Any]]] = {},
) -> Callable[
    [Callable[[T_dataset, TraceIter], Iterator[T_sample]]],
    Callable[[T_dataset], Iterator[T_sample]],
]:
    """Decorator for SavableDataset.__iter__ to trace the iteration using the worker config.

    Args:
        name: Maps the dataset instance to the trace span name (defaults to the class name).
        call_args: Extra trace args recorded on the ``__iter__`` span; values may be
            static, or callables evaluated against the dataset instance.
        next_args: Extra trace args recorded for every yielded sample, same convention.

    Returns:
        A decorator turning an ``__iter__(self, trace_iter)`` implementation into a
        plain ``__iter__(self)`` whose iteration is traced.
    """

    def decorator(
        iter_fn: Callable[[T_dataset, TraceIter], Iterator[T_sample]],
    ) -> Callable[[T_dataset], Iterator[T_sample]]:
        @functools.wraps(iter_fn)
        def wrapper(self: T_dataset) -> Iterator[T_sample]:
            # NOTE(review): assumes the decorated class is a SavableDataset whose
            # worker_config provides worker_trace_span()/worker_trace_writer().
            trace_span = self.worker_config.worker_trace_span()
            span_name = name(self)
            trace_iter = TraceIter(self, span_name, trace_span, call_args)
            with (
                # Span covering the entire iteration of this dataset.
                trace_span.span(
                    f"{span_name}.__iter__",
                    args={
                        # Callable arg values are evaluated lazily against the dataset.
                        arg_name: arg_value(self) if callable(arg_value) else arg_value
                        for arg_name, arg_value in call_args.items()
                    },
                    level=1,
                ),
                # Generator tracer: emits one event per yielded sample.
                self.worker_config.worker_trace_writer().generator(
                    f"{span_name}.__iter__.next",
                    level=2,
                ) as trace_gen,
            ):
                for sample in trace_span.iterable(
                    iter_fn(self, trace_iter),
                    name=f"{span_name}.__iter__.loop",
                    level=2,
                ):
                    with trace_gen.yield_(
                        last_args={
                            **{
                                arg_name: arg_value(self) if callable(arg_value) else arg_value
                                for arg_name, arg_value in next_args.items()
                            },
                            # Args recorded by the inner implementation for this
                            # sample take precedence over the static next_args.
                            **trace_iter.last_args,
                        },
                    ):
                        # Reset per-sample args before handing the sample to the
                        # consumer, so they do not leak into the next iteration.
                        trace_iter.last_args.clear()
                        yield sample

        return wrapper

    return decorator
Loading
Loading