Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
044cc13
WIP: Custom dataloader
voegtlel Jul 4, 2025
2ed657a
WIP: Custom dataloader
voegtlel Jul 8, 2025
8cfd43a
WIP: Custom dataloader
voegtlel Jul 9, 2025
c9cc54b
WIP: Custom dataloader
voegtlel Jul 9, 2025
4d97aeb
WIP: Custom dataloader
voegtlel Jul 9, 2025
1d7fda5
Add dataloader test and fix dataloader cleanup
voegtlel Jul 10, 2025
0809440
Add tests for save/restore for forking/threaded data loader
voegtlel Jul 10, 2025
32b2c38
Merge remote-tracking branch 'origin/develop' into feature/custom_dat…
voegtlel Jul 21, 2025
54e98be
WIP Refactoring loader
voegtlel Jul 24, 2025
50f8375
WIP Refactoring loader
voegtlel Jul 24, 2025
be62656
Add pinning
voegtlel Jul 24, 2025
e634631
DataLoader as ctx mgr
voegtlel Jul 24, 2025
ffa25a4
WIP: Thread-local state
voegtlel Jul 30, 2025
139e7f5
WIP: Fix thread local state
voegtlel Aug 12, 2025
e4db0ae
Fix several minor issues,
voegtlel Aug 15, 2025
00263e3
Implement RestoreKey as dataclass. Fix redistribution of checkpoints
voegtlel Aug 19, 2025
3e4ba10
Fix usage of dataclasses.asdict
voegtlel Aug 19, 2025
6436e6a
Fix small issues with threaded worker and thread-local states
voegtlel Aug 20, 2025
d82e7ba
Merge remote-tracking branch 'origin/develop' into feature/custom_dat…
voegtlel Aug 21, 2025
a76d860
Implement group batch buckets as savable object
voegtlel Aug 21, 2025
7c0bec4
Make analyze_debug and debug output work again
voegtlel Aug 21, 2025
c310c88
Add proper favicon
voegtlel Aug 21, 2025
6742388
Publish dataloader.start
voegtlel Aug 22, 2025
338eed8
Clean up rng
voegtlel Aug 25, 2025
ea4367a
Fix rng assignment
philipp-fischer Aug 28, 2025
355e733
Remove THREAD_SAFE
philipp-fischer Aug 29, 2025
dc38c76
Add coverage command
philipp-fischer Sep 1, 2025
7ff6c18
_state_field -> _worker_local_fields
voegtlel Sep 1, 2025
7db2c0e
Fix docs
voegtlel Sep 2, 2025
e8c6c43
Adapt tests to use context managers for loaders
voegtlel Sep 2, 2025
df45aed
Change all tests to pytest. Cleanup a few unused functions. Implement…
voegtlel Sep 4, 2025
e1f8b0e
Merge commit 'd97b66b7d28b975af47688106e19c0c3a866dc42' into feature/…
philipp-fischer Dec 8, 2025
7a07566
Two import fixes
philipp-fischer Dec 8, 2025
00deed0
Ruff
philipp-fischer Dec 8, 2025
7db3f1b
Fix mock file pool and wrapping file stores.
philipp-fischer Dec 8, 2025
e5a30ea
Remove empty file
voegtlel Oct 1, 2025
a3e2980
Fix file store worker init
voegtlel Dec 9, 2025
087a13d
Remove print
voegtlel Dec 9, 2025
f82b86a
Fix SqliteIndexReader shutdown
voegtlel Dec 9, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -1,5 +1,72 @@
[run]
# Source directories to measure coverage for
source = src

# Parallel mode for multiprocessing support
parallel = True

# Branch coverage measurement
branch = True

# Data file location
data_file = .coverage

[report]
include = ./src/megatron/energon/**
# Include source code in the report
include = src/megatron/energon/**

# Minimum coverage percentage to pass
fail_under = 80

# Show missing lines in the report
show_missing = True

# Skip covered files in the report
skip_covered = False

# Skip empty files
skip_empty = True

# Precision for coverage percentages
precision = 2

# Sort order for the report
sort = filename

# Exclude lines from coverage
exclude_lines =
# Have to re-enable the standard pragma
pragma: no cover

# Don't complain about missing debug-only code:
def __repr__
if self\.debug

# Don't complain if tests don't hit exceptions:
raise .*

# Don't complain if non-runnable code isn't run:
if 0:
if __name__ == .__main__.:

# Don't complain about abstract methods, they aren't run:
@(abc\.)?abstractmethod

# Don't complain about type checking code:
if TYPE_CHECKING:


[html]
# Directory to put the HTML report
directory = htmlcov

# Title for the HTML report
title = Megatron Energon Coverage Report

[xml]
output = ./coverage.xml
# Output file for XML report
output = coverage.xml

[lcov]
# Output file for LCOV report
output = lcov.info
10 changes: 9 additions & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,15 @@ check: dev-sync

# Execute all unit tests
test: dev-sync
uv run -m unittest discover -v -s tests
uv run pytest tests -v

coverage: dev-sync
COVERAGE_PROCESS_START=.coveragerc uv run -m coverage run --parallel-mode --concurrency=multiprocessing -m pytest tests
# COVERAGE_PROCESS_START=.coveragerc uv run -m coverage run --parallel-mode --concurrency=multiprocessing -m pytest tests/test_dataloader.py
# COVERAGE_PROCESS_START=.coveragerc uv run -m coverage run --parallel-mode --concurrency=multiprocessing -m pytest tests/test_typed_converter.py
uv run -m coverage combine
uv run -m coverage lcov
echo "Coverage LCOV report generated at ./lcov.info"

# Build the docs
docs: dev-sync
Expand Down
16 changes: 16 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ dependencies = [

[project.optional-dependencies]
dev = [
"coverage",
"pytest",
"ruff",
"sphinxcontrib-napoleon",
"sphinx",
Expand Down Expand Up @@ -114,3 +116,17 @@ exclude = [
"docs",
]

[tool.coverage.run]
branch = true
parallel = true
concurrency = ["multiprocessing"]

[tool.coverage.report]
show_missing = true
skip_covered = true

[tool.coverage.html]
show_contexts = true

[tool.pytest.ini_options]
addopts = "--tb=native --capture=no"
2 changes: 0 additions & 2 deletions src/megatron/energon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
load_dataset,
prepare_metadataset,
)
from megatron.energon.savable_loader import SavableDataLoader
from megatron.energon.source_info import SourceInfo
from megatron.energon.task_encoder import (
AugmentTaskEncoder,
Expand Down Expand Up @@ -189,7 +188,6 @@
"reraise_exception",
"Sample",
"SampleDecoder",
"SavableDataLoader",
"SavableDataset",
"ShuffleBufferDataset",
"SimilarityInterleavedSample",
Expand Down
40 changes: 39 additions & 1 deletion src/megatron/energon/cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
T = TypeVar("T")


class FileStore(Generic[T]):
class FileStore(ABC, Generic[T]):
"""Base type for a dataset that can be accessed randomly by sample key."""

@abstractmethod
Expand Down Expand Up @@ -40,6 +40,21 @@ def get_path(self) -> str:
"""Returns the path to the dataset."""
...

@abstractmethod
def worker_init(self) -> None:
"""Initializes the file store for the current worker."""
raise NotImplementedError("worker_init is not implemented for this file store")

@abstractmethod
def worker_close(self) -> None:
"""Closes the file store for the current worker."""
raise NotImplementedError("worker_close is not implemented for this file store")

@abstractmethod
def close(self) -> None:
"""Closes the file store."""
raise NotImplementedError("close is not implemented for this file store")

def get_media_metadata(self, key: str) -> MediaMetadataBase:
"""Return the media metadata for the given key if available."""

Expand Down Expand Up @@ -73,6 +88,15 @@ def _decode_raw(self, data: T, **kwargs) -> T:
"""
return self._inner._decode_raw(data, **kwargs)

def worker_init(self) -> None:
self._inner.worker_init()

def worker_close(self) -> None:
self._inner.worker_close()

def close(self) -> None:
self._inner.close()


@edataclass
class Lazy(Generic[T]):
Expand Down Expand Up @@ -169,6 +193,20 @@ def get_lazy(self, ds: FileStore, fname: str) -> Lazy:
"""
...

@abstractmethod
def worker_init(self) -> None:
"""
Initialize the cache pool for the current worker.
"""
...

@abstractmethod
def worker_close(self) -> None:
"""
Close the cache pool for the current worker.
"""
...

@abstractmethod
def to_cache(self, data: T, name: str) -> Lazy[T]:
"""
Expand Down
11 changes: 11 additions & 0 deletions src/megatron/energon/cache/file_cache_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ class FileStoreCachePool(CachePool, ForkMixin):
# Whether the pool is shutting down
_shutting_down: bool = False

_workers_initialized: dict[int, bool] = {}

def __init__(
self,
*,
Expand Down Expand Up @@ -278,6 +280,9 @@ def _cache_out_task(self, ds: FileStore, fname: str, entry: _PendingTask) -> boo
with self._lock:
if self._shutting_down:
return False
if not self._workers_initialized.get(threading.get_ident(), False):
ds.worker_init()
self._workers_initialized[threading.get_ident()] = True

# Perform the data read
if self.method == "raw":
Expand Down Expand Up @@ -363,6 +368,12 @@ def get_lazy(self, ds: FileStore, fname: str) -> FileCacheLazy:

return FileCacheLazy(ds=ds, fname=fname, pool=self, entry=entry)

def worker_init(self) -> None:
pass

def worker_close(self) -> None:
pass

def to_cache(self, data: T, name: str) -> CacheFileLazy[T]:
"""
Move the data to the cache and return a lazy to fetch it later.
Expand Down
13 changes: 11 additions & 2 deletions src/megatron/energon/cache/file_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ def __getitem__(self, key: str) -> tuple[bytes, SourceInfo]:
file_names=(key,),
)

def worker_init(self) -> None:
pass

def worker_close(self) -> None:
pass

def close(self) -> None:
pass

def get_path(self) -> str:
"""Returns the path to the dataset."""
return str(self.base_dir)
Expand Down Expand Up @@ -157,7 +166,7 @@ def get_path(self) -> str:
def get_media_metadata(self, key: str) -> MediaMetadataBase:
if self._media_metadata_available is None:
try:
has_metadata = self.sqlite_reader.db_has_media_metadata()
has_metadata = self._sqlite_reader.db_has_media_metadata()
except sqlite3.Error as exc: # pragma: no cover - defensive
raise RuntimeError(
"Failed to inspect media metadata table. Re-run `energon prepare --media-metadata-by-...`."
Expand All @@ -172,7 +181,7 @@ def get_media_metadata(self, key: str) -> MediaMetadataBase:
self._media_metadata_available = True

try:
row = self.sqlite_reader.get_media_metadata(key)
row = self._sqlite_reader.get_media_metadata(key)
except sqlite3.Error as exc: # pragma: no cover - defensive
raise RuntimeError(
"Failed to load media metadata. Re-run `energon prepare --media-metadata-by-...`."
Expand Down
6 changes: 6 additions & 0 deletions src/megatron/energon/cache/no_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def get(self, ds: FileStore, fname: str, sample: Any = None) -> Any:
def get_lazy(self, ds: FileStore, fname: str) -> DirectLazy:
return DirectLazy(ds=ds, fname=fname, pool=self)

def worker_init(self) -> None:
pass

def worker_close(self) -> None:
pass

def to_cache(self, data: T, name: str) -> DirectLazy:
return MockLazy(fname=name, get_fn=lambda _: data)

Expand Down
15 changes: 15 additions & 0 deletions src/megatron/energon/dataloader/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
"""Public interface of the dataloader package.

Re-exports the :class:`DataLoader`, the memory-pinning helpers, and the
worker implementations from their submodules so callers can import
everything from a single place.
"""
from .dataloader import DataLoader
from .pin_memory import NoPinMemory, PinMemory, PinMemoryThread
from .workers import DataLoaderWorker, ForkDataLoaderWorker, ThreadDataLoaderWorker

# Names exported via `from megatron.energon.dataloader import *`.
__all__ = [
    "DataLoader",
    "PinMemory",
    "NoPinMemory",
    "PinMemoryThread",
    "DataLoaderWorker",
    "ThreadDataLoaderWorker",
    "ForkDataLoaderWorker",
]
14 changes: 14 additions & 0 deletions src/megatron/energon/dataloader/asynchronous/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
"""Public interface of the asynchronous worker package.

Re-exports the abstract :class:`Asynchronous` base machinery (queue protocol
and worker command/result types) together with the fork- and thread-based
implementations, so callers can import them from a single place.
"""
from .base import Asynchronous, QueueProtocol, WorkerCommand, WorkerResult
from .fork import ForkAsynchronous
from .thread import ThreadAsynchronous

# Names exported via `from megatron.energon.dataloader.asynchronous import *`.
__all__ = [
    "Asynchronous",
    "QueueProtocol",
    "WorkerCommand",
    "WorkerResult",
    "ForkAsynchronous",
    "ThreadAsynchronous",
]
Loading