Skip to content
Draft
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
044cc13
WIP: Custom dataloader
voegtlel Jul 4, 2025
2ed657a
WIP: Custom dataloader
voegtlel Jul 8, 2025
8cfd43a
WIP: Custom dataloader
voegtlel Jul 9, 2025
c9cc54b
WIP: Custom dataloader
voegtlel Jul 9, 2025
4d97aeb
WIP: Custom dataloader
voegtlel Jul 9, 2025
1d7fda5
Add dataloader test and fix dataloader cleanup
voegtlel Jul 10, 2025
0809440
Add tests for save/restore for forking/threaded data loader
voegtlel Jul 10, 2025
32b2c38
Merge remote-tracking branch 'origin/develop' into feature/custom_dat…
voegtlel Jul 21, 2025
54e98be
WIP Refactoring loader
voegtlel Jul 24, 2025
50f8375
WIP Refactoring loader
voegtlel Jul 24, 2025
be62656
Add pinning
voegtlel Jul 24, 2025
e634631
DataLoader as ctx mgr
voegtlel Jul 24, 2025
ffa25a4
WIP: Thread-local state
voegtlel Jul 30, 2025
139e7f5
WIP: Fix thread local state
voegtlel Aug 12, 2025
e4db0ae
Fix several minor issues,
voegtlel Aug 15, 2025
00263e3
Implement RestoreKey as dataclass. Fix redistribution of checkpoints
voegtlel Aug 19, 2025
3e4ba10
Fix usage of dataclasses.asdict
voegtlel Aug 19, 2025
6436e6a
Fix small issues with threaded worker and thread-local states
voegtlel Aug 20, 2025
d82e7ba
Merge remote-tracking branch 'origin/develop' into feature/custom_dat…
voegtlel Aug 21, 2025
a76d860
Implement group batch buckets as savable object
voegtlel Aug 21, 2025
7c0bec4
Make analyze_debug and debug output work again
voegtlel Aug 21, 2025
c310c88
Add proper favicon
voegtlel Aug 21, 2025
6742388
Publish dataloader.start
voegtlel Aug 22, 2025
338eed8
Clean up rng
voegtlel Aug 25, 2025
ea4367a
Fix rng assignment
philipp-fischer Aug 28, 2025
355e733
Remove THREAD_SAFE
philipp-fischer Aug 29, 2025
dc38c76
Add coverage command
philipp-fischer Sep 1, 2025
7ff6c18
_state_field -> _worker_local_fields
voegtlel Sep 1, 2025
7db2c0e
Fix docs
voegtlel Sep 2, 2025
e8c6c43
Adapt tests to use context managers for loaders
voegtlel Sep 2, 2025
df45aed
Change all tests to pytest. Cleanup a few unused functions. Implement…
voegtlel Sep 4, 2025
e1f8b0e
Merge commit 'd97b66b7d28b975af47688106e19c0c3a866dc42' into feature/…
philipp-fischer Dec 8, 2025
7a07566
Two import fixes
philipp-fischer Dec 8, 2025
00deed0
Ruff
philipp-fischer Dec 8, 2025
7db3f1b
Fix mock file pool and wrapping file stores.
philipp-fischer Dec 8, 2025
e5a30ea
Remove empty file
voegtlel Oct 1, 2025
a3e2980
Fix file store worker init
voegtlel Dec 9, 2025
087a13d
Remove print
voegtlel Dec 9, 2025
f82b86a
Fix SqliteIndexReader shutdown
voegtlel Dec 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/source/_static/android-chrome-192x192.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/android-chrome-512x512.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/apple-touch-icon.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/favicon-16x16.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/favicon-32x32.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/favicon.ico
Binary file not shown.
1 change: 1 addition & 0 deletions docs/source/_static/site.webmanifest
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"name":"Megatron Energon Dataloader Documentation","short_name":"Megatron Energon","icons":[{"src":"android-chrome-192x192.png","sizes":"192x192","type":"image/png"},{"src":"android-chrome-512x512.png","sizes":"512x512","type":"image/png"}],"theme_color":"#ED467A","background_color":"#411046","display":"standalone"}
5 changes: 5 additions & 0 deletions docs/source/_templates/favicon.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<link rel="apple-touch-icon" sizes="180x180" href="{{ pathto('_static/apple-touch-icon.png', 1) }}">
<link rel="icon" type="image/png" sizes="32x32" href="{{ pathto('_static/favicon-32x32.png', 1) }}">
<link rel="icon" type="image/png" sizes="16x16" href="{{ pathto('_static/favicon-16x16.png', 1) }}">
<link rel="shortcut icon" href="{{ pathto('_static/favicon.ico', 1) }}">
<link rel="manifest" href="{{ pathto('_static/site.webmanifest', 1) }}">
8 changes: 8 additions & 0 deletions docs/source/_templates/layout.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{# Extend the base theme layout #}
{% extends "!layout.html" %}

{# Add custom favicon links to the head section #}
{% block extrahead %}
{{ super() }}
{% include "favicon.html" %}
{% endblock %}
3 changes: 3 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@
html_static_path = ["_static"]
html_css_files = ["css/custom.css"]

# Favicon configuration
html_favicon = "_static/favicon.ico"

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
Expand Down
2 changes: 0 additions & 2 deletions src/megatron/energon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
load_dataset,
prepare_metadataset,
)
from megatron.energon.savable_loader import SavableDataLoader
from megatron.energon.source_info import SourceInfo
from megatron.energon.task_encoder import (
AugmentTaskEncoder,
Expand Down Expand Up @@ -184,7 +183,6 @@
"RepeatDataset",
"Sample",
"SampleDecoder",
"SavableDataLoader",
"SavableDataset",
"ShuffleBufferDataset",
"SimilarityInterleavedSample",
Expand Down
15 changes: 15 additions & 0 deletions src/megatron/energon/dataloader/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from .dataloader import DataLoader
from .pin_memory import NoPinMemory, PinMemory, PinMemoryThread
from .workers import DataLoaderWorker, ForkDataLoaderWorker, ThreadDataLoaderWorker

# Names exported by `from megatron.energon.dataloader import *`; each one is imported from a
# submodule above.
__all__ = [
"DataLoader",
"PinMemory",
"NoPinMemory",
"PinMemoryThread",
"DataLoaderWorker",
"ThreadDataLoaderWorker",
"ForkDataLoaderWorker",
]
14 changes: 14 additions & 0 deletions src/megatron/energon/dataloader/asynchronous/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from .base import Asynchronous, QueueProtocol, WorkerCommand, WorkerResult
from .fork import ForkAsynchronous
from .thread import ThreadAsynchronous

# Names exported by `from megatron.energon.dataloader.asynchronous import *`; each one is
# imported from a submodule above.
__all__ = [
"Asynchronous",
"QueueProtocol",
"WorkerCommand",
"WorkerResult",
"ForkAsynchronous",
"ThreadAsynchronous",
]
287 changes: 287 additions & 0 deletions src/megatron/energon/dataloader/asynchronous/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import threading
import traceback
from abc import abstractmethod
from typing import Any, Callable, ParamSpec, Protocol, TypeVar

from megatron.energon.dataloader.future import CancelledError, Future
from megatron.energon.edataclass import edataclass

# Type variables: `P` and `R` capture the parameters and the return type of the function passed
# to `Asynchronous._worker_call`; `T` is the item type of `QueueProtocol`.
P = ParamSpec("P")
T = TypeVar("T")
R = TypeVar("R", covariant=True)


# Verbosity of the debug prints in this module. 0 prints nothing; >= 1 prints wait, cancel and
# shutdown messages; >= 2 additionally prints a message for every command and result.
DEBUG_LEVEL = 0


class QueueProtocol(Protocol[T]):
    """Structural type of the queues used to exchange commands and results with a worker.

    Only the three operations used by `Asynchronous` are required.
    """

    def get(self, /) -> T:
        """Remove and return the next item from the queue."""

    def put(self, item: T, /) -> None:
        """Add `item` to the queue."""

    def qsize(self, /) -> int:
        """Return the number of items in the queue (only used for debug output)."""


@edataclass
class WorkerCommand:
    """Message sent over the command queue asking the worker to call one of its methods."""

    # Name of the method to call; the worker resolves it with `getattr` on itself.
    cmd: str
    # Positional arguments forwarded to the method.
    args: tuple[Any, ...]
    # Keyword arguments forwarded to the method.
    kwargs: dict[str, Any]
    # Id of the future waiting for this call; the worker echoes it in the `WorkerResult`.
    future_id: int
future_id: int


@edataclass
class WorkerResult:
    """Message sent over the result queue carrying the outcome of one `WorkerCommand`."""

    # Id of the future this outcome belongs to.
    future_id: int
    # Return value of the call; only used when `exception` is None.
    result: Any = None
    # Exception raised by the call (or a cancellation), or None if the call succeeded.
    exception: Exception | None = None


class FutureImpl(Future[Any]):
    """Future returned by `Asynchronous._worker_call` for a single call in the worker.

    Completion is encoded by attribute presence: the future is done as soon as either the
    `_result` or the `_exception` slot has been assigned.
    """

    __slots__ = ("_worker", "_future_id", "_result", "_exception", "_cancelled")

    _worker: "Asynchronous"
    _future_id: int
    _result: Any
    _exception: Exception

    def __init__(self, worker: "Asynchronous", future_id: int):
        self._future_id = future_id
        self._worker = worker

    def get(self) -> Any:
        """Return the result of the call, waiting for the worker if it is not available yet.

        Raises:
            The stored exception, if the call failed or the future was cancelled.
        """
        if not self.done():
            self._worker._wait_for_worker_result(self)
        if not hasattr(self, "_exception"):
            return self._result
        raise self._exception

    def cancel(self) -> bool:
        """Mark the future as cancelled and tell the owning worker object about it.

        Returns:
            True if the future was cancelled, False if it already had a result or exception.
        """
        if not self.done():
            self._exception = CancelledError.with_current_traceback()
            self._worker._cancel_future(self._future_id)
            return True
        if DEBUG_LEVEL >= 1:
            print(
                f"[{self._worker._name}, fut={self._future_id}] already has result or exception\n",
                end="",
            )
        return False

    def done(self) -> bool:
        """Return True once a result or an exception (including cancellation) is stored."""
        return any(hasattr(self, slot) for slot in ("_result", "_exception"))

    def _set_result(self, result: Any) -> None:
        """Store the return value of the call."""
        self._result = result

    def _set_exception(self, exception: Exception) -> None:
        """Store the exception that `get` will raise."""
        self._exception = exception

    def __str__(self) -> str:
        return f"FutureImpl(worker={self._worker._name!r}, future_id={self._future_id!r}, done={self.done()!r}, exception={getattr(self, '_exception', '<no exception>')})"


class Asynchronous:
    """Base class for objects that execute their methods in a worker via two queues.

    The caller side submits calls with `_worker_call`, which puts a `WorkerCommand` on the
    command queue and returns a future. The worker side runs `_worker_run`, which executes the
    commands one at a time and answers each one with a `WorkerResult` on the result queue.
    Subclasses provide the queues (`_queues`), the lifecycle methods (`start`, `shutdown`,
    `running`) and the `_in_worker` / `_assert_running` checks.
    """

    # Queue carrying commands from the caller to the worker.
    _cmd_queue: QueueProtocol[WorkerCommand]
    # Queue carrying results from the worker back to the caller.
    _result_queue: QueueProtocol[WorkerResult]
    # Id assigned to the next future created by `_worker_call`.
    _next_future_id: int
    # Futures whose result has not been read from the result queue yet, keyed by future id.
    _pending_futures: dict[int, FutureImpl]
    # Name used as prefix in the debug output.
    _name: str
    # Serializes readers of the result queue in `_wait_for_worker_result`.
    _result_lock: threading.Lock

    def _asynchronous_init(self, name: str) -> None:
        """
        Initialize the queues and the future bookkeeping.

        Args:
            name: The name used as prefix in the debug output.
        """
        self._cmd_queue, self._result_queue = self._queues()
        self._next_future_id = 0
        self._pending_futures = {}
        self._name = name
        self._result_lock = threading.Lock()

    @abstractmethod
    def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]:
        """Create and return the command queue and the result queue (in that order)."""
        ...

    def _wait_for_worker_result(self, future: FutureImpl) -> None:
        """
        Wait for the result of a future.
        If another result comes first, update the corresponding future.
        Results whose future is no longer pending are ignored.

        Args:
            future: The future to wait for.
        """
        if DEBUG_LEVEL >= 1:
            print(f"[{self._name}, fut={future._future_id}] waiting for result\n", end="")
        with self._result_lock:
            if future.done():
                # If calling get() from multiple threads, the future may be done now, because
                # the other thread already set the result.
                return
            if DEBUG_LEVEL >= 2:
                print(f"[{self._name}, fut={future._future_id}] got future\n", end="")
            while True:
                res = self._result_queue.get()
                # A cancelled future yields two results for the same id (the cancellation put
                # by `_cancel_future` and the worker's own answer), and `_cancel_futures`
                # clears the registry while cancellation results are still queued. Skip such
                # results instead of failing with a KeyError while waiting for another future.
                fut = self._pending_futures.pop(res.future_id, None)
                if fut is None:
                    if DEBUG_LEVEL >= 2:
                        print(
                            f"[{self._name}, fut={future._future_id}] ignoring result for unknown {res.future_id=}\n",
                            end="",
                        )
                    continue
                if res.exception is not None:
                    fut._set_exception(res.exception)
                else:
                    fut._set_result(res.result)
                if res.future_id == future._future_id:
                    if DEBUG_LEVEL >= 2:
                        print(
                            f"[{self._name}, fut={future._future_id}] got result, return\n", end=""
                        )
                    return
                else:
                    if DEBUG_LEVEL >= 2:
                        print(
                            f"[{self._name}, fut={future._future_id}] got result for {res.future_id=}, continue\n",
                            end="",
                        )
                    continue

    def _cancel_future(self, future_id: int) -> None:
        """
        Cancel a future by putting a cancellation result for it on the result queue.

        Args:
            future_id: The id of the future to cancel.
        """
        if DEBUG_LEVEL >= 1:
            print(f"[{self._name}, fut={future_id}] cancelling future\n", end="")
        # In case the main process is waiting for this future to complete, add the result
        self._result_queue.put(
            WorkerResult(future_id=future_id, exception=CancelledError.with_current_traceback())
        )

    def _cancel_futures(self) -> None:
        """Cancel all futures after worker shutdown."""
        for fut in self._pending_futures.values():
            fut.cancel()
        self._pending_futures.clear()

    def _worker_call(self, fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """
        Call a function in the worker and return a future for getting the result.
        The function must be an instance method of `self`. Uses the name to identify the function in the worker
        instance.

        Args:
            fn: The function to call.
            *args: The arguments to pass to the function.
            **kwargs: The keyword arguments to pass to the function.

        Returns:
            The future that receives the result (or exception) of the call.
        """
        self._assert_running()
        assert not self._in_worker(), "worker_call must not be called in the worker"
        future_id = self._next_future_id
        self._next_future_id += 1

        # Register the future before sending the command, so the result can always be matched.
        self._pending_futures[future_id] = future = FutureImpl(self, future_id)
        if DEBUG_LEVEL >= 2:
            print(
                f"[{self._name}] worker_call {fn.__name__=} {future_id=}\n",
                end="",
            )
        self._cmd_queue.put(
            WorkerCommand(cmd=fn.__name__, args=args, kwargs=kwargs, future_id=future_id)
        )
        if DEBUG_LEVEL >= 2:
            print(f"[{self._name}] cmd_queue: {self._cmd_queue.qsize()=}\n", end="")
        return future

    def _worker_run(
        self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult]
    ) -> None:
        """
        The worker main loop.
        It waits for commands via the command queue and executes them.
        The functions to call are identified by their name.
        The result of the call is put into the result queue.
        The worker exits when the command `_wrk_shutdown_worker` is received.

        Args:
            cmd_queue: The command queue to wait for commands.
            result_queue: The result queue to put the results into.
        """
        assert self._in_worker(), "_worker_run must be called in the worker"
        try:
            while True:
                if DEBUG_LEVEL >= 2:
                    print(
                        f"[{self._name}] waiting for command {cmd_queue.qsize()=}\n",
                        end="",
                    )
                cmd = cmd_queue.get()
                if DEBUG_LEVEL >= 2:
                    print(
                        f"[{self._name}, fut={cmd.future_id}] got command {cmd.cmd=}\n",
                        end="",
                    )
                try:
                    fn = getattr(self, cmd.cmd)
                    result = fn(*cmd.args, **cmd.kwargs)
                except Exception as e:
                    # Failures of the called method are reported to the caller, the loop goes on.
                    if DEBUG_LEVEL >= 2:
                        print(f"[{self._name}, fut={cmd.future_id}] send exception {e!r}\n", end="")
                    result_queue.put(WorkerResult(future_id=cmd.future_id, exception=e))
                    if DEBUG_LEVEL >= 2:
                        print(f"[{self._name}] result_queue: {result_queue.qsize()=}\n", end="")
                else:
                    if DEBUG_LEVEL >= 2:
                        print(f"[{self._name}, fut={cmd.future_id}] send result\n", end="")
                    result_queue.put(WorkerResult(future_id=cmd.future_id, result=result))
                    if DEBUG_LEVEL >= 2:
                        print(f"[{self._name}] result_queue: {result_queue.qsize()=}\n", end="")
                    # Drop the local reference while blocking on the next command.
                    del result
                # cmd_queue.task_done()
                if cmd.cmd == self._wrk_shutdown_worker.__name__:
                    if DEBUG_LEVEL >= 1:
                        print(
                            f"[{self._name}, fut={cmd.future_id}] got shutdown command, exit\n",
                            end="",
                        )
                    break
                if DEBUG_LEVEL >= 2:
                    print(
                        f"[{self._name}, fut={cmd.future_id}] processed, waiting for next command\n",
                        end="",
                    )
        except BaseException:
            # Print the traceback here before re-raising; BaseException is caught on purpose so
            # that KeyboardInterrupt/SystemExit are printed as well.
            traceback.print_exc()
            raise

    @abstractmethod
    def _assert_running(self) -> bool:
        """Check that the worker is running. Called by `_worker_call` before a command is sent."""
        ...

    @abstractmethod
    def _in_worker(self) -> bool:
        """Check if the execution is within the worker."""
        ...

    def _wrk_shutdown_worker(self) -> None:
        """Does nothing. The actual shutdown is handled in the _worker_run method."""
        assert self._in_worker(), "_wrk_shutdown_worker must be called in the worker"

    def _shutdown_worker(self) -> None:
        """Shutdown the worker. The actual shutdown is handled in the _worker_run method."""
        assert not self._in_worker(), "shutdown_worker must not be called in the worker"
        # This is not actually a recursive call, because the worker loop will exit before calling this method.
        self._worker_call(self._wrk_shutdown_worker).get()
        self._cancel_futures()
        if DEBUG_LEVEL >= 1:
            print(f"[{self._name}] shutdown\n", end="")

    @abstractmethod
    def start(self) -> None:
        """Start the worker."""
        ...

    @abstractmethod
    def shutdown(self) -> None:
        """Shut down the worker."""
        ...

    @abstractmethod
    def running(self) -> bool:
        """Return whether the worker is running."""
        ...
Loading