diff --git a/.coveragerc b/.coveragerc index 0e334746..27e98e10 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,5 +1,72 @@ +[run] +# Source directories to measure coverage for +source = src + +# Parallel mode for multiprocessing support +parallel = True + +# Branch coverage measurement +branch = True + +# Data file location +data_file = .coverage + [report] -include = ./src/megatron/energon/** +# Include source code in the report +include = src/megatron/energon/** + +# Minimum coverage percentage to pass +fail_under = 80 + +# Show missing lines in the report +show_missing = True + +# Skip covered files in the report +skip_covered = False + +# Skip empty files +skip_empty = True + +# Precision for coverage percentages +precision = 2 + +# Sort order for the report +sort = filename + +# Exclude lines from coverage +exclude_lines = + # Have to re-enable the standard pragma + pragma: no cover + + # Don't complain about missing debug-only code: + def __repr__ + if self\.debug + + # Don't complain if tests don't hit exceptions: + raise .* + + # Don't complain if non-runnable code isn't run: + if 0: + if __name__ == .__main__.: + + # Don't complain about abstract methods, they aren't run: + @(abc\.)?abstractmethod + + # Don't complain about type checking code: + if TYPE_CHECKING: + + +[html] +# Directory to put the HTML report +directory = htmlcov + +# Title for the HTML report +title = Megatron Energon Coverage Report [xml] -output = ./coverage.xml \ No newline at end of file +# Output file for XML report +output = coverage.xml + +[lcov] +# Output file for LCOV report +output = lcov.info diff --git a/justfile b/justfile index d1e43ff7..1fe3896e 100644 --- a/justfile +++ b/justfile @@ -25,7 +25,15 @@ check: dev-sync # Execute all unit tests test: dev-sync - uv run -m unittest discover -v -s tests + uv run pytest tests -v + +coverage: dev-sync + COVERAGE_PROCESS_START=.coveragerc uv run -m coverage run --parallel-mode --concurrency=multiprocessing -m pytest tests + # 
COVERAGE_PROCESS_START=.coveragerc uv run -m coverage run --parallel-mode --concurrency=multiprocessing -m pytest tests/test_dataloader.py + # COVERAGE_PROCESS_START=.coveragerc uv run -m coverage run --parallel-mode --concurrency=multiprocessing -m pytest tests/test_typed_converter.py + uv run -m coverage combine + uv run -m coverage lcov + echo "Coverage LCOV report generated at ./lcov.info" # Build the docs docs: dev-sync diff --git a/pyproject.toml b/pyproject.toml index 793fb161..698fb8fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,8 @@ dependencies = [ [project.optional-dependencies] dev = [ + "coverage", + "pytest", "ruff", "sphinxcontrib-napoleon", "sphinx", @@ -114,3 +116,17 @@ exclude = [ "docs", ] +[tool.coverage.run] +branch = true +parallel = true +concurrency = ["multiprocessing"] + +[tool.coverage.report] +show_missing = true +skip_covered = true + +[tool.coverage.html] +show_contexts = true + +[tool.pytest.ini_options] +addopts = "--tb=native --capture=no" diff --git a/src/megatron/energon/__init__.py b/src/megatron/energon/__init__.py index 0d92c702..123ef788 100644 --- a/src/megatron/energon/__init__.py +++ b/src/megatron/energon/__init__.py @@ -70,7 +70,6 @@ load_dataset, prepare_metadataset, ) -from megatron.energon.savable_loader import SavableDataLoader from megatron.energon.source_info import SourceInfo from megatron.energon.task_encoder import ( AugmentTaskEncoder, @@ -189,7 +188,6 @@ "reraise_exception", "Sample", "SampleDecoder", - "SavableDataLoader", "SavableDataset", "ShuffleBufferDataset", "SimilarityInterleavedSample", diff --git a/src/megatron/energon/cache/base.py b/src/megatron/energon/cache/base.py index c4bf4ec4..3de61184 100644 --- a/src/megatron/energon/cache/base.py +++ b/src/megatron/energon/cache/base.py @@ -11,7 +11,7 @@ T = TypeVar("T") -class FileStore(Generic[T]): +class FileStore(ABC, Generic[T]): """Base type for a dataset that can be accessed randomly by sample key.""" @abstractmethod @@ -40,6 +40,21 
@@ def get_path(self) -> str: """Returns the path to the dataset.""" ... + @abstractmethod + def worker_init(self) -> None: + """Initializes the file store for the current worker.""" + raise NotImplementedError("worker_init is not implemented for this file store") + + @abstractmethod + def worker_close(self) -> None: + """Closes the file store for the current worker.""" + raise NotImplementedError("worker_close is not implemented for this file store") + + @abstractmethod + def close(self) -> None: + """Closes the file store.""" + raise NotImplementedError("close is not implemented for this file store") + def get_media_metadata(self, key: str) -> MediaMetadataBase: """Return the media metadata for the given key if available.""" @@ -73,6 +88,15 @@ def _decode_raw(self, data: T, **kwargs) -> T: """ return self._inner._decode_raw(data, **kwargs) + def worker_init(self) -> None: + self._inner.worker_init() + + def worker_close(self) -> None: + self._inner.worker_close() + + def close(self) -> None: + self._inner.close() + @edataclass class Lazy(Generic[T]): @@ -169,6 +193,20 @@ def get_lazy(self, ds: FileStore, fname: str) -> Lazy: """ ... + @abstractmethod + def worker_init(self) -> None: + """ + Initialize the cache pool for the current worker. + """ + ... + + @abstractmethod + def worker_close(self) -> None: + """ + Close the cache pool for the current worker. + """ + ... 
+ @abstractmethod def to_cache(self, data: T, name: str) -> Lazy[T]: """ diff --git a/src/megatron/energon/cache/file_cache_pool.py b/src/megatron/energon/cache/file_cache_pool.py index e075759c..68b5af9e 100644 --- a/src/megatron/energon/cache/file_cache_pool.py +++ b/src/megatron/energon/cache/file_cache_pool.py @@ -157,6 +157,8 @@ class FileStoreCachePool(CachePool, ForkMixin): # Whether the pool is shutting down _shutting_down: bool = False + _workers_initialized: dict[int, bool] = {} + def __init__( self, *, @@ -278,6 +280,9 @@ def _cache_out_task(self, ds: FileStore, fname: str, entry: _PendingTask) -> boo with self._lock: if self._shutting_down: return False + if not self._workers_initialized.get(threading.get_ident(), False): + ds.worker_init() + self._workers_initialized[threading.get_ident()] = True # Perform the data read if self.method == "raw": @@ -363,6 +368,12 @@ def get_lazy(self, ds: FileStore, fname: str) -> FileCacheLazy: return FileCacheLazy(ds=ds, fname=fname, pool=self, entry=entry) + def worker_init(self) -> None: + pass + + def worker_close(self) -> None: + pass + def to_cache(self, data: T, name: str) -> CacheFileLazy[T]: """ Move the data to the cache and return a lazy to fetch it later. 
diff --git a/src/megatron/energon/cache/file_store.py b/src/megatron/energon/cache/file_store.py index 556b0bdd..82b1b052 100644 --- a/src/megatron/energon/cache/file_store.py +++ b/src/megatron/energon/cache/file_store.py @@ -83,6 +83,15 @@ def __getitem__(self, key: str) -> tuple[bytes, SourceInfo]: file_names=(key,), ) + def worker_init(self) -> None: + pass + + def worker_close(self) -> None: + pass + + def close(self) -> None: + pass + def get_path(self) -> str: """Returns the path to the dataset.""" return str(self.base_dir) @@ -157,7 +166,7 @@ def get_path(self) -> str: def get_media_metadata(self, key: str) -> MediaMetadataBase: if self._media_metadata_available is None: try: - has_metadata = self.sqlite_reader.db_has_media_metadata() + has_metadata = self._sqlite_reader.db_has_media_metadata() except sqlite3.Error as exc: # pragma: no cover - defensive raise RuntimeError( "Failed to inspect media metadata table. Re-run `energon prepare --media-metadata-by-...`." @@ -172,7 +181,7 @@ def get_media_metadata(self, key: str) -> MediaMetadataBase: self._media_metadata_available = True try: - row = self.sqlite_reader.get_media_metadata(key) + row = self._sqlite_reader.get_media_metadata(key) except sqlite3.Error as exc: # pragma: no cover - defensive raise RuntimeError( "Failed to load media metadata. Re-run `energon prepare --media-metadata-by-...`." 
diff --git a/src/megatron/energon/cache/no_cache.py b/src/megatron/energon/cache/no_cache.py index ba14916a..6e461923 100644 --- a/src/megatron/energon/cache/no_cache.py +++ b/src/megatron/energon/cache/no_cache.py @@ -47,6 +47,12 @@ def get(self, ds: FileStore, fname: str, sample: Any = None) -> Any: def get_lazy(self, ds: FileStore, fname: str) -> DirectLazy: return DirectLazy(ds=ds, fname=fname, pool=self) + def worker_init(self) -> None: + pass + + def worker_close(self) -> None: + pass + def to_cache(self, data: T, name: str) -> DirectLazy: return MockLazy(fname=name, get_fn=lambda _: data) diff --git a/src/megatron/energon/dataloader/__init__.py b/src/megatron/energon/dataloader/__init__.py new file mode 100644 index 00000000..9839f8e2 --- /dev/null +++ b/src/megatron/energon/dataloader/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +from .dataloader import DataLoader +from .pin_memory import NoPinMemory, PinMemory, PinMemoryThread +from .workers import DataLoaderWorker, ForkDataLoaderWorker, ThreadDataLoaderWorker + +__all__ = [ + "DataLoader", + "PinMemory", + "NoPinMemory", + "PinMemoryThread", + "DataLoaderWorker", + "ThreadDataLoaderWorker", + "ForkDataLoaderWorker", +] diff --git a/src/megatron/energon/dataloader/asynchronous/__init__.py b/src/megatron/energon/dataloader/asynchronous/__init__.py new file mode 100644 index 00000000..f77fef80 --- /dev/null +++ b/src/megatron/energon/dataloader/asynchronous/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: BSD-3-Clause +from .base import Asynchronous, QueueProtocol, WorkerCommand, WorkerResult +from .fork import ForkAsynchronous +from .thread import ThreadAsynchronous + +__all__ = [ + "Asynchronous", + "QueueProtocol", + "WorkerCommand", + "WorkerResult", + "ForkAsynchronous", + "ThreadAsynchronous", +] diff --git a/src/megatron/energon/dataloader/asynchronous/base.py b/src/megatron/energon/dataloader/asynchronous/base.py new file mode 100644 index 00000000..4ba11813 --- /dev/null +++ b/src/megatron/energon/dataloader/asynchronous/base.py @@ -0,0 +1,287 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import threading +import traceback +from abc import abstractmethod +from typing import Any, Callable, ParamSpec, Protocol, TypeVar + +from megatron.energon.dataloader.future import CancelledError, Future +from megatron.energon.edataclass import edataclass + +P = ParamSpec("P") +T = TypeVar("T") +R = TypeVar("R", covariant=True) + + +DEBUG_LEVEL = 0 + + +class QueueProtocol(Protocol[T]): + """Protocol for a queue.""" + + def get(self, /) -> T: ... + + def put(self, item: T, /) -> None: ... + + def qsize(self, /) -> int: ... + + +@edataclass +class WorkerCommand: + """Internal class for communicating a command to the worker via the command queue.""" + + cmd: str + args: tuple[Any, ...] 
+    kwargs: dict[str, Any]
+    future_id: int
+
+
+@edataclass
+class WorkerResult:
+    """Internal class for communicating a result from the worker via the result queue."""
+
+    future_id: int
+    result: Any = None
+    exception: Exception | None = None
+
+
+class FutureImpl(Future[Any]):
+    """Class for returning a future result from the worker."""
+
+    __slots__ = ("_worker", "_future_id", "_result", "_exception", "_cancelled")
+
+    _worker: "Asynchronous"
+    _future_id: int
+    _result: Any
+    _exception: Exception
+
+    def __init__(self, worker: "Asynchronous", future_id: int):
+        self._worker = worker
+        self._future_id = future_id
+
+    def get(self) -> Any:
+        if not hasattr(self, "_result") and not hasattr(self, "_exception"):
+            self._worker._wait_for_worker_result(self)
+        if hasattr(self, "_exception"):
+            raise self._exception
+        return self._result
+
+    def cancel(self) -> bool:
+        if hasattr(self, "_result") or hasattr(self, "_exception"):
+            if DEBUG_LEVEL >= 1:
+                print(
+                    f"[{self._worker._name}, fut={self._future_id}] already has result or exception\n",
+                    end="",
+                )
+            return False
+        self._exception = CancelledError.with_current_traceback()
+        self._worker._cancel_future(self._future_id)
+        return True
+
+    def done(self) -> bool:
+        return hasattr(self, "_result") or hasattr(self, "_exception")
+
+    def _set_result(self, result: Any) -> None:
+        self._result = result
+
+    def _set_exception(self, exception: Exception) -> None:
+        self._exception = exception
+
+    def __str__(self) -> str:
+        return f"FutureImpl(worker={self._worker._name!r}, future_id={self._future_id!r}, done={self.done()!r}, exception={getattr(self, '_exception', '')})"
+
+
+class Asynchronous:
+    """Asynchronous base class."""
+
+    _cmd_queue: QueueProtocol[WorkerCommand]
+    _result_queue: QueueProtocol[WorkerResult]
+    _next_future_id: int
+    _pending_futures: dict[int, FutureImpl]
+    _name: str
+    _result_lock: threading.Lock
+
+    def _asynchronous_init(self, name: str) -> None:
+        self._cmd_queue, self._result_queue = 
self._queues()
+        self._next_future_id = 0
+        self._pending_futures = {}
+        self._name = name
+        self._result_lock = threading.Lock()
+
+    @abstractmethod
+    def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]: ...
+
+    def _wait_for_worker_result(self, future: FutureImpl) -> None:
+        """
+        Wait for the result of a future.
+        If another result comes first, update the corresponding future.
+
+        Args:
+            future: The future to wait for.
+        """
+        if DEBUG_LEVEL >= 1:
+            print(f"[{self._name}, fut={future._future_id}] waiting for result\n", end="")
+        with self._result_lock:
+            if future.done():
+                # If calling get() from multiple threads, the future may be done now, because
+                # the other thread already set the result.
+                return
+            if DEBUG_LEVEL >= 2:
+                print(f"[{self._name}, fut={future._future_id}] got future\n", end="")
+            while True:
+                res = self._result_queue.get()
+                fut = self._pending_futures.pop(res.future_id)
+                if res.exception is not None:
+                    fut._set_exception(res.exception)
+                else:
+                    fut._set_result(res.result)
+                if res.future_id == future._future_id:
+                    if DEBUG_LEVEL >= 2:
+                        print(
+                            f"[{self._name}, fut={future._future_id}] got result, return\n", end=""
+                        )
+                    return
+                else:
+                    if DEBUG_LEVEL >= 2:
+                        print(
+                            f"[{self._name}, fut={future._future_id}] got result for {res.future_id=}, continue\n",
+                            end="",
+                        )
+                    continue
+
+    def _cancel_future(self, future_id: int) -> None:
+        """Cancel a future."""
+        if DEBUG_LEVEL >= 1:
+            print(f"[{self._name}, fut={future_id}] cancelling future\n", end="")
+        # In case the main process is waiting for this future to complete, add the result
+        self._result_queue.put(
+            WorkerResult(future_id=future_id, exception=CancelledError.with_current_traceback())
+        )
+
+    def _cancel_futures(self) -> None:
+        """Cancel all futures after worker shutdown."""
+        for fut in self._pending_futures.values():
+            fut.cancel()
+        self._pending_futures.clear()
+
+    def _worker_call(self, fn: Callable[P, R], *args: P.args, **kwargs: 
P.kwargs) -> Future[R]: + """ + Call a function in the worker and return a future for getting the result. + The function must be an instance method of `self`. Uses the name to identify the function in the worker + instance. + + Args: + fn: The function to call. + *args: The arguments to pass to the function. + **kwargs: The keyword arguments to pass to the function. + """ + self._assert_running() + assert not self._in_worker(), "worker_call must not be called in the worker" + future_id = self._next_future_id + self._next_future_id += 1 + + self._pending_futures[future_id] = future = FutureImpl(self, future_id) + if DEBUG_LEVEL >= 2: + print( + f"[{self._name}] worker_call {fn.__name__=} {future_id=}\n", + end="", + ) + self._cmd_queue.put( + WorkerCommand(cmd=fn.__name__, args=args, kwargs=kwargs, future_id=future_id) + ) + if DEBUG_LEVEL >= 2: + print(f"[{self._name}] cmd_queue: {self._cmd_queue.qsize()=}\n", end="") + return future + + def _worker_run( + self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult] + ) -> None: + """ + The worker main loop. + It waits for commands via the command queue and executes them. + The functions to call are identified by their name. + The result of the call is put into the result queue. + The worker exits when the command `_shutdown_worker` is received. + + Args: + cmd_queue: The command queue to wait for commands. + result_queue: The result queue to put the results into. 
+ """ + assert self._in_worker(), "_worker_run must be called in the worker" + try: + while True: + if DEBUG_LEVEL >= 2: + print( + f"[{self._name}] waiting for command {cmd_queue.qsize()=}\n", + end="", + ) + cmd = cmd_queue.get() + if DEBUG_LEVEL >= 2: + print( + f"[{self._name}, fut={cmd.future_id}] got command {cmd.cmd=}\n", + end="", + ) + try: + fn = getattr(self, cmd.cmd) + result = fn(*cmd.args, **cmd.kwargs) + except Exception as e: + if DEBUG_LEVEL >= 2: + print(f"[{self._name}, fut={cmd.future_id}] send exception {e!r}\n", end="") + result_queue.put(WorkerResult(future_id=cmd.future_id, exception=e)) + if DEBUG_LEVEL >= 2: + print(f"[{self._name}] result_queue: {result_queue.qsize()=}\n", end="") + else: + if DEBUG_LEVEL >= 2: + print(f"[{self._name}, fut={cmd.future_id}] send result\n", end="") + result_queue.put(WorkerResult(future_id=cmd.future_id, result=result)) + if DEBUG_LEVEL >= 2: + print(f"[{self._name}] result_queue: {result_queue.qsize()=}\n", end="") + del result + # cmd_queue.task_done() + if cmd.cmd == self._wrk_shutdown_worker.__name__: + if DEBUG_LEVEL >= 1: + print( + f"[{self._name}, fut={cmd.future_id}] got shutdown command, exit\n", + end="", + ) + break + if DEBUG_LEVEL >= 2: + print( + f"[{self._name}, fut={cmd.future_id}] processed, waiting for next command\n", + end="", + ) + except: + traceback.print_exc() + raise + + @abstractmethod + def _assert_running(self) -> bool: + """Check if the execution is within the worker.""" + ... + + @abstractmethod + def _in_worker(self) -> bool: + """Check if the execution is within the worker.""" + ... + + def _wrk_shutdown_worker(self) -> None: + """Does nothing. The actual shutdown is handled in the _worker_run method.""" + assert self._in_worker(), "_wrk_shutdown_worker must be called in the worker" + + def _shutdown_worker(self) -> None: + """Shutdown the worker. 
The actual shutdown is handled in the _worker_run method.""" + assert not self._in_worker(), "shutdown_worker must not be called in the worker" + # This is not actually a recursive call, because the worker loop will exit before calling this method. + self._worker_call(self._wrk_shutdown_worker).get() + self._cancel_futures() + if DEBUG_LEVEL >= 1: + print(f"[{self._name}] shutdown\n", end="") + + @abstractmethod + def start(self) -> None: ... + + @abstractmethod + def shutdown(self) -> None: ... + + @abstractmethod + def running(self) -> bool: ... diff --git a/src/megatron/energon/dataloader/asynchronous/fork.py b/src/megatron/energon/dataloader/asynchronous/fork.py new file mode 100644 index 00000000..a99625e6 --- /dev/null +++ b/src/megatron/energon/dataloader/asynchronous/fork.py @@ -0,0 +1,167 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import multiprocessing +import os +import sys +import threading +import warnings + +import torch.multiprocessing + +from megatron.energon.dataloader.asynchronous.base import ( + Asynchronous, + QueueProtocol, + WorkerCommand, + WorkerResult, +) + +DEBUG_LEVEL = 1 + + +class ForkAsynchronous(Asynchronous): + """Mixin for asynchronous workers that use processes.""" + + _process: multiprocessing.Process | None = None + _cmd_queue: multiprocessing.Queue + _result_queue: multiprocessing.Queue + + _threaded_shutdown: threading.Thread | None = None + + _spawning_process: int + + def _asynchronous_init(self, name: str) -> None: + super()._asynchronous_init(name) + self._spawning_process = os.getpid() + + def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]: + return torch.multiprocessing.Queue(), torch.multiprocessing.Queue() + + def _check_parent_process(self, evt_exit: threading.Event) -> None: + """Check if the parent process is alive. 
If it is dead, exit the worker process.""" + parent_proc = torch.multiprocessing.parent_process() + parent_pid = os.getppid() + if parent_proc is None: + print(f"[{self._name}] No parent process, exiting", file=sys.stderr) + os._exit(-1) + while not evt_exit.wait(1): + if parent_proc.exitcode is not None or os.getppid() != parent_pid: + print(f"[{self._name}] Parent process died, exiting", file=sys.stderr) + os._exit(-1) + + def _worker_run( + self, + cmd_queue: multiprocessing.Queue, + result_queue: multiprocessing.Queue, + ) -> None: + try: + from torch.utils.data._utils import signal_handling + + signal_handling._set_worker_signal_handlers() + except (ImportError, AttributeError): + pass + + try: + torch.multiprocessing._set_thread_name("pt_data_worker") + except (ImportError, AttributeError): + pass + + # Disable torch internal multithreading, it may deadlock the forked process. + torch.set_num_threads(1) + + # cmd_queue is read only, so we can cancel the join thread. + cmd_queue.cancel_join_thread() + worker_exit_evt = threading.Event() + parent_check_thread = threading.Thread( + target=self._check_parent_process, args=(worker_exit_evt,), daemon=True + ) + parent_check_thread.start() + try: + super()._worker_run(cmd_queue, result_queue) + finally: + if DEBUG_LEVEL >= 1: + print(f"[{self._name}] shutting down\n", end="") + worker_exit_evt.set() + if DEBUG_LEVEL >= 1: + print( + f"[{self._name}] shutting down, wait for parent_check_thread\n", + end="", + ) + parent_check_thread.join() + if DEBUG_LEVEL >= 1: + print(f"[{self._name}] shutting down, close queues\n", end="") + result_queue.close() + result_queue.join_thread() + cmd_queue.close() + cmd_queue.cancel_join_thread() + if DEBUG_LEVEL >= 1: + print(f"[{self._name}] shutting down, done\n", end="") + + def _in_worker(self) -> bool: + return torch.multiprocessing.current_process() == self._process + + def start(self) -> None: + torch.multiprocessing.set_start_method("fork", force=True) + orig_num_threads = 
torch.get_num_threads() + # Disable torch internal multithreading, it may deadlock the forked process. + torch.set_num_threads(1) + self._process = torch.multiprocessing.Process( + target=self._worker_run, + args=(self._cmd_queue, self._result_queue), + daemon=True, + name=f"ForkDataLoaderWorker-{self._name}", + ) + self._process.start() + # Revert the original number of threads in the main process. + torch.set_num_threads(orig_num_threads) + + def shutdown(self, in_del: bool = False) -> None: + if self._spawning_process != os.getpid(): + # Should avoid forked process containing a forked worker on exit. + warnings.warn( + "Shutting down worker from a different process than the one that spawned it, skipping" + ) + return + if self._process is not None: + if in_del: + # It seems that the ResourceWarning does not work in the gc loop? Also print a warning here. + warnings.warn( + "Explicitly call DataLoader.shutdown() to avoid leaking processes. Terminating worker process.", + ResourceWarning, + ) + print( + "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking processes. Terminating worker process.\n", + end="", + file=sys.stderr, + ) + self._cmd_queue.close() + self._cmd_queue.cancel_join_thread() + self._result_queue.close() + self._result_queue.cancel_join_thread() + # Kill the process, because we cannot communicate with it in the gc loop. 
+ self._process.terminate() + self._process = None + self._cancel_futures() + else: + try: + self._shutdown_worker() + except Exception: + self._process.join(10) + if self._process.is_alive(): + self._process.terminate() + else: + self._process.join() + assert self._process.exitcode == 0, ( + f"Process exit code {self._process.exitcode}" + ) + self._process = None + self._cmd_queue.close() + self._cmd_queue.cancel_join_thread() + self._result_queue.close() + self._result_queue.cancel_join_thread() + + def running(self) -> bool: + return self._process is not None + + def _assert_running(self) -> None: + assert self._process is not None, "Worker must be started first" + assert self._process.is_alive(), "Worker died" diff --git a/src/megatron/energon/dataloader/asynchronous/thread.py b/src/megatron/energon/dataloader/asynchronous/thread.py new file mode 100644 index 00000000..8b108659 --- /dev/null +++ b/src/megatron/energon/dataloader/asynchronous/thread.py @@ -0,0 +1,67 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import queue +import sys +import threading +import warnings + +from megatron.energon.dataloader.asynchronous.base import ( + Asynchronous, + QueueProtocol, + WorkerCommand, + WorkerResult, +) + + +class ThreadAsynchronous(Asynchronous): + """Mixin for asynchronous workers that use threads.""" + + _thread: threading.Thread | None = None + + def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]: + return queue.Queue(), queue.Queue() + + def _in_worker(self) -> bool: + return threading.current_thread() == self._thread + + def start(self) -> None: + self._thread = threading.Thread( + target=self._worker_run, + args=(self._cmd_queue, self._result_queue), + daemon=True, + name=f"{self._name}", + ) + self._thread.start() + + def shutdown(self, in_del: bool = False) -> None: + if self._thread is not None: + if in_del: + # It seems that the ResourceWarning does not work in the gc loop? 
Also print a warning here. + warnings.warn( + "Explicitly call DataLoader.shutdown() to avoid leaking threads.", + ResourceWarning, + ) + print( + "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking threads.\n", + end="", + file=sys.stderr, + ) + # Just try to enqueue the shutdown command to the thread and hope for the best. Ignore the result. + self._cmd_queue.put( + WorkerCommand( + cmd=self._wrk_shutdown_worker.__name__, args=(), kwargs={}, future_id=-1 + ) + ) + self._cancel_futures() + self._thread = None + else: + self._shutdown_worker() + self._thread.join() + self._thread = None + + def running(self) -> bool: + return self._thread is not None + + def _assert_running(self) -> None: + assert self._thread is not None, "Thread must be started first" + assert self._thread.is_alive(), "Thread died" diff --git a/src/megatron/energon/dataloader/dataloader.py b/src/megatron/energon/dataloader/dataloader.py new file mode 100644 index 00000000..86ee950a --- /dev/null +++ b/src/megatron/energon/dataloader/dataloader.py @@ -0,0 +1,668 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: BSD-3-Clause +import os +import sys +import warnings +from typing import ( + Any, + ClassVar, + Generator, + Generic, + Literal, + Protocol, + Sequence, + TypeVar, +) + +import torch.distributed + +from megatron.energon.cache.base import CachePool +from megatron.energon.dataloader.future import Future +from megatron.energon.dataloader.pin_memory import NoPinMemory, PinMemory, PinMemoryThread +from megatron.energon.dataloader.workers.base_worker import ( + DataLoaderNoWorker, + DataLoaderWorker, + WorkerSampleRestoreKey, + WorkerState, +) +from megatron.energon.dataloader.workers.fork_worker import ForkDataLoaderWorker +from megatron.energon.edataclass import edataclass +from megatron.energon.flavors.base_dataset import RestoreKey, SavableDataset +from megatron.energon.worker import WorkerConfig +from megatron.energon.wrappers.base import BaseWrapperDataset, get_sample_restore_key +from megatron.energon.wrappers.batch_dataset import BatchDataset +from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER, GcDataset +from megatron.energon.wrappers.log_sample_dataset import default_get_batch_keys +from megatron.energon.wrappers.watchdog_dataset import WatchdogDataset + +TSample = TypeVar("TSample", covariant=True) + +DEBUG_LEVEL = 1 + + +@edataclass +class RankState: + """ + State of a rank. + """ + + #: This is a list (per worker) of lists of (batch) sample keys, which have been (asynchronously) prefetched from workers + # but not been fetched yet by iterating. + prefetched_restore_keys: list[list[RestoreKey | None]] + #: This is a list of worker states, which have been saved from the workers (or `None` for the initial state). + worker_states: list[WorkerState | None] + #: The next worker ID to prefetch from (i.e. append to the prefetched samples). + next_worker_id: int + #: The micro batch size of the dataset, or `None` if not known. Needed for redistributing the state. 
+ micro_batch_size: int | None + + +class WorkerType(Protocol[TSample]): + """Protocol for a worker type, i.e. for the constructor of a worker class.""" + + def __call__( + self, + dataset: SavableDataset, + worker_config: WorkerConfig, + rank_worker_id: int, + cache_pool: CachePool | None, + ) -> DataLoaderWorker[TSample]: ... + + +class DataLoader(Generic[TSample]): + """ + Implementation for a data loader. Orchestrates the workers for prefetching samples. + Opposing the `torch.utils.data.DataLoader`, this loader needs explicit shutdown when done, + to avoid leaking workers (fixes a bug). + """ + + _next_id: ClassVar[int] = 0 + _id: int + _next_epoch_id: int = 0 + + _workers: list[DataLoaderWorker[TSample]] | None = None + _exhausted_workers: list[bool] + _next_worker_id: int = 0 + + _restore_state: RankState | None = None + + _dataset: SavableDataset + _worker_config: WorkerConfig + _prefetch_factor: int + _worker_type: WorkerType + _prefetching_samples: list[list[Future[TSample]]] + _pin_memory: PinMemory[TSample] + + _current_epoch_iter: Generator[TSample, None, None] | None = None + + _spawning_process: int + + _global_sample_idx: int = 0 + + def __init__( + self, + dataset: SavableDataset, + *, + prefetch_factor: int = 1, + worker_type: WorkerType = ForkDataLoaderWorker, + cache_pool: CachePool | None = None, + # Garbage collection configuration + gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER, + gc_freeze_at_start: bool = True, + # Watchdog configuration + watchdog_timeout_seconds: float | None = 60, + watchdog_initial_timeout_seconds: float | None = None, + fail_on_timeout: bool = False, + # Pin memory configuration + pin_memory: PinMemory[TSample] | None | Literal["automatic"] = "automatic", + ): + """ + Create the dataloader supporting saving and restoring the state. + + Args: + dataset: The dataset to load. The loader takes ownership of the dataset, i.e. it cannot be shared and will be closed on shutdown. 
+ prefetch_factor: The number of samples to prefetch from each worker. + worker_type: The type of worker to use. + cache_pool: If set, the cache pool to use for the dataset. + gc_collect_every_n_steps: The number of steps after which the garbage collector is + called. As we're usually handling large (but few) tensors here, and the python + garbage collection is already full of objects just by importing, this can improve + the memory footprint quite a lot, and may even be necessary to avoid memory + overflow. + gc_freeze_at_start: If true, the garbage collector is frozen at the start of the worker + processes. This improves the garbage collection performance by a lot. + In rare cases, this may cause issues and can be disabled. Keep enabled if you + experience no issues. + watchdog_timeout_seconds: The timeout in seconds. If `None`, the watchdog is disabled. + watchdog_initial_timeout_seconds: The initial timeout in seconds. If `None`, the timeout is the same as `watchdog_timeout_seconds`. + fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace. + pin_memory: The memory pinner to use. If `None`, no memory is not pinned. + If "automatic", the memory is pinned automatically if cuda is available. + If a `PinMemory` instance, the instance may only be used for one `DataLoader`. + """ + self._id = DataLoader._next_id + DataLoader._next_id += 1 + + if getattr(dataset, "__dataloader_id", None) is not None: + raise ValueError( + f"Dataset {dataset} is already associated with dataloader {getattr(dataset, '__dataloader_id')}. Initialize one dataset per dataloader." 
+ ) + setattr(dataset, "__dataloader_id", self._id) + + if dataset.worker_config.num_workers == 0 and worker_type == ForkDataLoaderWorker: + worker_type = DataLoaderNoWorker + + if watchdog_timeout_seconds is not None: + dataset = WatchdogDataset( + dataset, + worker_config=dataset.worker_config, + timeout_seconds=watchdog_timeout_seconds, + initial_timeout_seconds=watchdog_initial_timeout_seconds, + fail_on_timeout=fail_on_timeout, + ) + + if gc_collect_every_n_steps > 0: + dataset = GcDataset( + dataset, + worker_config=dataset.worker_config, + every_n_iter=gc_collect_every_n_steps, + freeze=gc_freeze_at_start, + ) + + self._dataset = dataset + self._worker_config = dataset.worker_config + self._prefetch_factor = prefetch_factor + self._worker_type = worker_type + self._cache_pool = cache_pool + self._prefetching_samples = [[] for _ in range(self._worker_config.safe_num_workers)] + self._exhausted_workers = [False] * self._worker_config.safe_num_workers + if pin_memory == "automatic": + # Automatic pinning + if torch.cuda.is_available(): + # Use cuda + self._pin_memory = PinMemoryThread(torch.device("cuda")) + else: + self._pin_memory = NoPinMemory() + else: + if pin_memory is None: + self._pin_memory = NoPinMemory() + else: + self._pin_memory = pin_memory + + if self._worker_config.num_workers == 0: + assert prefetch_factor == 1, "prefetch_factor must be 1 for num_workers == 0" + else: + assert prefetch_factor > 0, "prefetch_factor must be > 0 for num_workers > 0" + + self._spawning_process = os.getpid() + + if self._worker_config.should_log(level=1): + self._worker_config.worker_log( + { + "t": "DataLoader.__init__", + "r": self._worker_config.rank, + "w": None, + "id": self._id, + "config": dataset.config(), + } + ) + + def start(self) -> None: + """Start the workers and restore the state if available.""" + self._workers = [ + self._worker_type(self._dataset, self._worker_config, local_worker_id, self._cache_pool) + for local_worker_id in 
range(self._worker_config.safe_num_workers) + ] + for worker in self._workers: + worker.start() + + if self._restore_state is None: + worker_states = [None] * self._worker_config.safe_num_workers + else: + worker_states = self._restore_state.worker_states + + assert len(worker_states) == self._worker_config.safe_num_workers, ( + "Number of initial states must match number of workers" + ) + + for worker, worker_state in zip(self._workers, worker_states): + worker.dataset_init(worker_state) + + if self._restore_state is not None: + self._prefetching_samples = [ + [ + self._pin_memory(self._restore_sample(restore_key)) + for restore_key in prefetched_restore_keys + ] + for prefetched_restore_keys in self._restore_state.prefetched_restore_keys + ] + self._next_worker_id = self._restore_state.next_worker_id + self._exhausted_workers = [ + False if worker_state is None else worker_state.exhausted + for worker_state in worker_states + ] + # State was restored, clear + self._restore_state = None + + def shutdown(self, in_del: bool = False) -> None: + """ + Shutdown the workers and the pin memory thread. + + Args: + in_del: Whether the shutdown is called from the garbage collector (in __del__). + Users should not need to set this. + """ + if self._workers is not None: + if in_del: + warnings.warn( + "Explicitly call DataLoader.shutdown() to avoid leaking workers or run as context manager.", + ResourceWarning, + ) + print( + "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking workers or run as context manager.\n", + end="", + file=sys.stderr, + ) + for worker in self._workers: + worker.shutdown(in_del=in_del) + self._workers = None + self._dataset.close() + self._pin_memory.shutdown(in_del=in_del) + + def __del__(self) -> None: + self.shutdown(in_del=True) + + def __enter__(self) -> "DataLoader[TSample]": + # Already start if using the context manager. This ensures the lifecycle is fixed. + # Otherwise, will start when iterating. 
+ self.start() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.shutdown() + + def _epoch_iter(self) -> Generator[TSample, None, None]: + """Iterate over the dataset for one epoch (i.e. all workers StopIteration). + One epoch may also be infinite (if looping the dataset).""" + epoch_id = self._next_epoch_id + self._next_epoch_id += 1 + + if self._worker_config.should_log(level=1): + self._worker_config.worker_log( + { + "t": "DataLoader.epoch_iter", + "r": self._worker_config.rank, + "w": None, + "id": self._id, + "epoch_id": epoch_id, + } + ) + + if self._workers is None: + self.start() + assert self._workers is not None, "DataLoader not started" + + if all(self._exhausted_workers): + # All workers are exhausted, restart for the next epoch. + for worker in self._workers: + worker.new_iter() + self._exhausted_workers = [False] * self._worker_config.safe_num_workers + # Ensure deterministic interleaving across epochs by starting from worker 0 + self._next_worker_id = 0 + + # For all workers, enqueue prefetching samples. + for worker_idx, (worker, exhausted) in enumerate( + zip(self._workers, self._exhausted_workers) + ): + while ( + len(self._prefetching_samples[worker_idx]) < self._prefetch_factor and not exhausted + ): + self._prefetching_samples[worker_idx].append( + self._pin_memory(worker.prefetch_next()) + ) + + # Main loop: + # - Get the next worker to prefetch samples from. + # - Prefetch samples from the worker. + # - Pop the first sample future from the prefetching samples. + # - Get the sample from the sample future (may wait for the sample to be prefetched). + # - Yield the sample. + if DEBUG_LEVEL >= 1: + print(f"{self._exhausted_workers=}\n", end="") + epoch_sample_idx = 0 + while not all(self._exhausted_workers): + # Get the next worker to prefetch samples from. 
+ worker_idx = self._next_worker_id + worker = self._workers[worker_idx] + if DEBUG_LEVEL >= 2: + print(f"{worker_idx=} {worker=}\n", end="") + self._next_worker_id = (worker_idx + 1) % self._worker_config.safe_num_workers + if self._exhausted_workers[worker_idx]: + if DEBUG_LEVEL >= 1: + print(f"{worker_idx=} exhausted, continue with next worker\n", end="") + continue + # Pop the first sample future from the prefetching samples. + sample_future = self._prefetching_samples[worker_idx].pop(0) + if DEBUG_LEVEL >= 2: + print(f"{sample_future=}\n", end="") + # Prefetch samples from the worker. + while len(self._prefetching_samples[worker_idx]) < self._prefetch_factor: + # Add a new sample future to the prefetching samples if the worker has not prefetched enough samples. + self._prefetching_samples[worker_idx].append( + self._pin_memory(worker.prefetch_next()) + ) + try: + # Get the sample from the sample future (may wait for the sample to be ready). + sample = sample_future.get() + except StopIteration: + if DEBUG_LEVEL >= 1: + print(f"{worker_idx=} exhausted, remove from prefetching samples\n", end="") + # If the sample future raises StopIteration, remove the worker from the list. 
+ self._prefetching_samples[worker_idx] = [] + self._exhausted_workers[worker_idx] = True + if self._worker_config.should_log(level=1): + self._worker_config.worker_log( + { + "t": "DataLoader.epoch_iter.StopIteration", + "r": self._worker_config.rank, + "w": None, + "id": self._id, + "epoch_id": epoch_id, + } + ) + continue + else: + if DEBUG_LEVEL >= 2: + print(f"{worker_idx=} got sample, yield\n", end="") + if self._worker_config.should_log(level=1): + keys = default_get_batch_keys(sample) + restore_key = get_sample_restore_key(sample) + self._worker_config.worker_log( + { + **{ + "t": "DataLoader.epoch_iter.yield", + "r": self._worker_config.rank, + "w": None, + "id": self._id, + "epoch_id": epoch_id, + "worker_id": worker_idx, + "worker_sample_idx": restore_key.sample_idx + if isinstance(restore_key, WorkerSampleRestoreKey) + else None, + "epoch_sample_idx": epoch_sample_idx, + "global_sample_idx": self._global_sample_idx, + }, + **({} if keys is None else {"keys": keys}), + **( + {} + if restore_key is None + else {"restore_key": restore_key.as_tuple()} + ), + } + ) + epoch_sample_idx += 1 + self._global_sample_idx += 1 + # Yield the sample. + yield sample + + def __iter__(self) -> Generator[TSample, None, None]: + # Restart the epoch iterator if was not created yet. Otherwise, the existing epoch iterator will be continued. + # That happens e.g. when iteration was interrupted. + if self._current_epoch_iter is None: + if DEBUG_LEVEL >= 1: + print("DL: Starting epoch iterator") + self._current_epoch_iter = self._epoch_iter() + else: + if DEBUG_LEVEL >= 1: + print("DL: Continuing epoch iterator") + assert self._current_epoch_iter is not None + # Important: Do not use yield from here, as it will delegate .close to the inner generator. + for sample in self._current_epoch_iter: + yield sample + # Reset the epoch iterator, it was exhausted. 
+ if DEBUG_LEVEL >= 1: + print("DL: Closing epoch iterator") + self._current_epoch_iter.close() + self._current_epoch_iter = None + + def __len__(self): + return len(self._dataset) + + def _get_batch_size(self) -> int | None: + """Try to infer micro batch size from the dataset""" + if ( + isinstance(self._dataset, BaseWrapperDataset) + and (bds := self._dataset._find_wrapped_dataset(BatchDataset)) is not None + ): + assert isinstance(bds, BatchDataset) + return bds.batch_size + else: + return None + + def save_state_rank(self) -> RankState: + if self._restore_state is not None: + return self._restore_state + prefetched_restore_keys = [ + [get_sample_restore_key(sample_fut.get()) for sample_fut in prefetching_sample] + for prefetching_sample in self._prefetching_samples + ] + worker_states: list[WorkerState | None] + if self._workers is None: + worker_states = [None] * self._worker_config.safe_num_workers + else: + worker_states = [worker.save_state() for worker in self._workers] + + # Make sure that the exhausted_workers match the individual worker states + assert all( + worker_state is None or worker_state.exhausted == exhausted_worker + for worker_state, exhausted_worker in zip(worker_states, self._exhausted_workers) + ), "Exhausted workers mismatch" + + return RankState( + prefetched_restore_keys=prefetched_restore_keys, + worker_states=worker_states, + next_worker_id=self._next_worker_id, + micro_batch_size=self._get_batch_size(), + ) + + def save_state_global(self, global_dst_rank: int) -> Sequence[RankState | None] | None: + """ + Saves the state of the dataset globally, collecting the state from all ranks using torch + distributed. Allows for restoring the state later using `restore_state_global`, given the + result of this method. + Typical scenario: Save the state to disk only on the `dst_rank`, the other ranks do not + save the state. 
Later, restore the state either only loaded on the `dst_rank` or + loading on all ranks separately using `restore_state_global`. + + Note: If you want to save/restore the state per rank separately, use `save_state_rank` and + the corresponding `restore_state_rank`. Also, these do not rely on torch distributed. + + Args: + global_dst_rank: The state will be gathered to this rank. The rank refers to the + global rank, not the rank within the data parallel group. + + Returns: + The state of the dataset (or `None`, if not on `dst_rank`). + """ + # Fetch current rank's worker's state + merged_state = self.save_state_rank() + + # Gather the merged states + if self._worker_config.world_size > 1: + output: Sequence[RankState | None] | None + if self._worker_config.global_rank() == global_dst_rank: + output = [None] * self._worker_config.world_size + else: + # Check if the global_dst_rank is in the same group at all + if self._worker_config.data_parallel_group is not None: + try: + _ = torch.distributed.get_group_rank( + self._worker_config.data_parallel_group, global_dst_rank + ) + except RuntimeError: + raise ValueError( + f"global_dst_rank {global_dst_rank} is not in the group of the current rank's worker config" + ) + + output = None + + torch.distributed.gather_object( + merged_state, + output, + global_dst_rank, + group=self._worker_config.data_parallel_group, + ) + + return output + else: + # Not distributed -> return the merged state + return [merged_state] + + def restore_state_rank(self, state: RankState | None) -> None: + """ + Restore the state of the DataLoader on the current rank. + The state is actually restored when the processes are started, in the iterator. + """ + assert self._workers is None and self._current_epoch_iter is None, ( + "Cannot restore state while workers are running" + ) + assert self._restore_state is None, "Restore state already set" + + if state is None: + # Assume initial state. 
+ return + + assert isinstance(state, RankState) + assert state.micro_batch_size == self._get_batch_size(), "Micro batch size mismatch" + + self._restore_state = state + + def restore_state_global( + self, + state: Sequence[RankState | None] | None, + *, + src_rank: int | None = None, + ) -> None: + """ + Restores the saved state from `save_state_global` (in torch distributed setup). + The global state needs be loaded on every rank that has a data loader instance. + + Optionally, one can specify a src_rank and only provide the state once. + In case of multiple data parallel groups, you must provide the state once + in each data parallel group. In this case the `src_rank` is the rank within the + data parallel group. + + Args: + state: The state to restore, as saved by `save_state_global`. + src_rank: The rank from which the state is broadcasted (within the data parallel group, if using DP groups). + """ + + assert self._workers is None and self._current_epoch_iter is None, ( + "Cannot restore state while workers are running" + ) + assert self._restore_state is None, "Restore state already set" + + # Only restore multi-rank if state is actually a list and we are in a distributed setup. + # Otherwise treat as single rank state. 
+ if src_rank is None or self._worker_config.world_size == 1: + assert isinstance(state, list), "State must be a list in distributed setup" + assert len(state) == self._worker_config.world_size, ( + "State must be a list of size world_size" + ) + + # All ranks have the state + # Select the state of the current rank + rank_state = state[self._worker_config.rank] + else: + if self._worker_config.data_parallel_group is not None: + # Only the src_rank has the state within this dp group + try: + global_src_rank = torch.distributed.get_global_rank( + self._worker_config.data_parallel_group, src_rank + ) + except RuntimeError: + raise ValueError( + f"src_rank {src_rank} is not in the group of the current rank's worker config" + ) + else: + # If no DP group is given, we assume the global rank is + # the same as the data parallel rank + global_src_rank = src_rank + + if self._worker_config.rank != src_rank: + # Send the state to all other ranks + assert state is None + # Must still be a list of Nones + state = [None] * self._worker_config.world_size + else: + assert isinstance(state, list), "State must be a list in distributed setup" + assert len(state) == self._worker_config.world_size, ( + "State must be a list of size world_size" + ) + + local_object = [None] + torch.distributed.scatter_object_list( + local_object, + state, + src=global_src_rank, + group=self._worker_config.data_parallel_group, + ) + rank_state = local_object[0] + + self.restore_state_rank(rank_state) + + def _restore_sample(self, restore_key: RestoreKey) -> Future[TSample]: + """ + Restore a sample from a restore key. + + Args: + restore_key: The restore key to restore the sample from. + + Returns: + A future that will be resolved to the restored sample. 
+ """ + assert isinstance(restore_key, WorkerSampleRestoreKey) + assert self._workers is not None, "Workers must be started before restoring a sample" + rank_worker_id = self._worker_config.rank_worker_id( + override_global_worker_id=restore_key.worker_id + ) + return self._workers[rank_worker_id].restore_sample(restore_key) + + def restore_sample(self, restore_key: RestoreKey) -> TSample: + """ + Restore a sample from a restore key. + + Args: + restore_key: The restore key to restore the sample from. + + Returns: + The restored sample. + """ + return self._restore_sample(restore_key).get() + + def with_restored_state_rank(self, state: RankState | None) -> "DataLoader[TSample]": + """ + Use this data loader and restore the state. Useful for chaining commands. See `save_state_rank` for more details. + """ + self.restore_state_rank(state) + return self + + def with_restored_state_global( + self, state: Sequence[RankState | None] | None, src_rank: int | None = None + ) -> "DataLoader[TSample]": + """ + Use this data loader and restore the state. Useful for chaining commands. See `save_state_global` for more details. + """ + self.restore_state_global(state, src_rank=src_rank) + return self + + def can_restore_sample(self) -> bool: + """Check if the dataset can save and restore samples.""" + return self._dataset.can_restore_sample() + + def config(self) -> dict[str, Any]: + """Get the configuration of the dataset.""" + return self._dataset.config() + + def __str__(self) -> str: + return f"DataLoader(_id={self._id}, prefetch_factor={self._prefetch_factor}, worker_type={self._worker_type.__name__})" diff --git a/src/megatron/energon/dataloader/future.py b/src/megatron/energon/dataloader/future.py new file mode 100644 index 00000000..b8f1e53c --- /dev/null +++ b/src/megatron/energon/dataloader/future.py @@ -0,0 +1,112 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: BSD-3-Clause +from abc import abstractmethod +from typing import Any, Callable, Generic, TypeVar + +R = TypeVar("R", covariant=True) +T = TypeVar("T", covariant=True) + + +class CancelledError(Exception): + """Exception raised when a future was cancelled.""" + + @classmethod + def with_current_traceback(cls): + try: + raise cls() + except cls as e: + if e.__traceback__ is not None and e.__traceback__.tb_next is not None: + return e.with_traceback(e.__traceback__.tb_next) + return e + + +class Future(Generic[R]): + """Base class for abstract futures.""" + + @abstractmethod + def get(self) -> R: + """Get the result of the future. Waits until the future is done.""" + ... + + @abstractmethod + def cancel(self) -> bool: + """Cancel the future. + + Returns: + True if the future was cancelled, False if already done. + """ + ... + + +class DoneFuture(Future[R]): + """Future that is already done.""" + + def __init__(self, result: R): + self._result = result + + def get(self) -> R: + return self._result + + def cancel(self) -> bool: + return False + + +class CallableFuture(Future[R]): + """Future that calls a callable to get the result.""" + + __slots__ = ("_callable", "_value", "_exception", "_cancelled") + + _callable: Callable[[], R] + _value: R + _exception: Exception + _cancelled: bool + + def __init__(self, callable: Callable[[], R]): + self._callable = callable + + def get(self) -> R: + if getattr(self, "_cancelled", False): + raise CancelledError("Future was cancelled") + if not hasattr(self, "_value") and not hasattr(self, "_exception"): + try: + self._value = self._callable() + except Exception as e: + self._exception = e + if hasattr(self, "_exception"): + raise self._exception + return self._value + + def cancel(self) -> bool: + if getattr(self, "_cancelled", False): + return True + if hasattr(self, "_value") or hasattr(self, "_exception"): + return False + self._cancelled = True + return True + + @staticmethod + def chain(future: 
Future[T], fn: Callable[[Future[T]], R]) -> Future[R]: + """ + Chain a function to a future. + + Args: + future: The future which provides the input for the function. + fn: The function to call on the result of the future, to transform the result. + + Returns: + A future that will be resolved to the result of the function given the result of the future. + """ + return CallableFuture(lambda: fn(future)) + + +class ExceptionFuture(Future[Any]): + """Future that raises an exception.""" + + def __init__(self, exception: Exception): + self._exception = exception + + def get(self) -> Any: + raise self._exception + + def cancel(self) -> bool: + return False diff --git a/src/megatron/energon/dataloader/pin_memory.py b/src/megatron/energon/dataloader/pin_memory.py new file mode 100644 index 00000000..60b29a00 --- /dev/null +++ b/src/megatron/energon/dataloader/pin_memory.py @@ -0,0 +1,102 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import threading +from typing import Generic, TypeVar, cast + +import torch + +from megatron.energon.dataloader.asynchronous import ThreadAsynchronous +from megatron.energon.dataloader.future import CallableFuture, Future +from megatron.energon.flavors.base_dataset import PinMemoryMixin + +TSample = TypeVar("TSample") +T = TypeVar("T") + +DEBUG_LEVEL = 1 + + +class PinMemory(Generic[TSample]): + """Base class for pinning memory of samples. + + This class is used to pin memory of samples in the primary process. + """ + + def __init__(self, device: str | torch.device): + self._device = device + + def _pin_memory(self, sample: TSample) -> TSample: + return PinMemoryMixin.sample_pin_memory(sample, self._device) + + def __call__(self, sample: Future[TSample]) -> Future[TSample]: + """Pin the memory of a sample. 
The default implementation runs in the main thread.""" + return CallableFuture.chain(sample, lambda fut: self._pin_memory(fut.get())) + + def shutdown(self, in_del: bool = False) -> None: + """ + Shutdown any running threads. + + Args: + in_del: Whether the shutdown is called from the garbage collector. + """ + pass + + +class NoPinMemory(PinMemory[TSample]): + """No-op implementation of :class:`PinMemory`. + + Does not pin the memory of samples. + """ + + def __init__(self): + super().__init__(device="cpu") + + def __call__(self, sample: Future[TSample]) -> Future[TSample]: + return sample + + +class PinMemoryThread(PinMemory[TSample], ThreadAsynchronous, Generic[TSample]): + """Threaded implementation of :class:`PinMemory`. + + Pins the memory of samples in a separate thread in the background. + + Creates the thread on first use. + """ + + _SHUTDOWN = cast(Future[TSample], object()) + + _thread: threading.Thread | None = None + + def __init__( + self, + device: str | torch.device, + ): + super().__init__(device) + self._asynchronous_init(name="pin-memory") + + def _worker_run(self, *args, **kwargs) -> None: + try: + torch.multiprocessing._set_thread_name("pt_data_pin") + except AttributeError: + pass + super()._worker_run(*args, **kwargs) + + def _wrk_pin_memory(self, sample: Future[TSample]) -> TSample: + if DEBUG_LEVEL >= 2: + print( + f"[{self._name}] Pinning memory of sample {sample}, waiting for sample data\n", + end="", + ) + sample_data = sample.get() + if DEBUG_LEVEL >= 2: + print(f"[{self._name}] Got sample data\n", end="") + return self._pin_memory(sample_data) + + def __call__(self, sample: Future[TSample]) -> Future[TSample]: + """ + Pin the memory of a sample. + Submits the sample future to the thread to fetch it and pins the memory in the thread, + then returns a future for fetching the pinned sample. 
+ """ + if not self.running(): + self.start() + return self._worker_call(self._wrk_pin_memory, sample) diff --git a/src/megatron/energon/dataloader/workers/__init__.py b/src/megatron/energon/dataloader/workers/__init__.py new file mode 100644 index 00000000..ee41db0f --- /dev/null +++ b/src/megatron/energon/dataloader/workers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +from .base_worker import DataLoaderWorker +from .fork_worker import ForkDataLoaderWorker +from .thread_worker import ThreadDataLoaderWorker + +__all__ = [ + "DataLoaderWorker", + "ThreadDataLoaderWorker", + "ForkDataLoaderWorker", +] diff --git a/src/megatron/energon/dataloader/workers/async_worker.py b/src/megatron/energon/dataloader/workers/async_worker.py new file mode 100644 index 00000000..34d16bb3 --- /dev/null +++ b/src/megatron/energon/dataloader/workers/async_worker.py @@ -0,0 +1,120 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import threading +from typing import Generic, TypeVar + +import torch.utils.data._utils + +from megatron.energon.cache.base import CachePool +from megatron.energon.dataloader.asynchronous import ( + Asynchronous, + QueueProtocol, + WorkerCommand, + WorkerResult, +) +from megatron.energon.dataloader.future import Future +from megatron.energon.dataloader.workers.base_worker import ( + DataLoaderWorker, + WorkerSampleRestoreKey, + WorkerState, +) +from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.rng import SystemRng +from megatron.energon.worker import WorkerConfig + +TSample = TypeVar("TSample", covariant=True) + +# Patching the torch worker info. 
+ +_thread_local_worker_info = threading.local() + + +def torch_set_worker_info(id: int, num_workers: int, seed: int, dataset: SavableDataset): + _thread_local_worker_info._worker_info = torch.utils.data._utils.worker.WorkerInfo( + id=id, + num_workers=num_workers, + seed=seed, + dataset=dataset, + ) + + +def _patch_get_worker_info(): + return getattr(_thread_local_worker_info, "_worker_info", None) + + +torch.utils.data.get_worker_info = _patch_get_worker_info + + +class DataLoaderAsynchronousWorker(DataLoaderWorker[TSample], Asynchronous, Generic[TSample]): + """ + Extension of the `DataLoaderWorker`, which implements commands via a command and results queue. + + There are different implementations of the async worker: + - :class:`ForkDataLoaderWorker` - A worker that forks a new process for each worker. + - :class:`ThreadDataLoaderWorker` - A worker that uses threads to execute the commands. + """ + + def __init__( + self, + dataset: SavableDataset, + worker_config: WorkerConfig, + rank_worker_id: int, + cache_pool: CachePool | None, + ): + super().__init__(dataset, worker_config, rank_worker_id, cache_pool) + assert worker_config.num_workers > 0, "Async workers require num_workers > 0" + self._asynchronous_init(name=f"wkr-{rank_worker_id}") + + # ------------------------------------------------------------------------------------------------ + # Section: Worker methods - now calling to workers via queues. + + def _worker_run( + self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult] + ) -> None: + SystemRng.seed(self._seed) + self._global_worker_id = self.worker_config.global_worker_id() + + super()._worker_run(cmd_queue, result_queue) + + def _wrk_prefetch_next(self) -> TSample: + """Wraps the super class method to call it in the worker process.""" + # The super class implementation already returns a resolved future (to be interface compatible), + # so immediately resolve the future to the result (get returns immediately). 
+ return super().prefetch_next().get() + + def _wrk_restore_sample(self, restore_key: WorkerSampleRestoreKey) -> TSample: + """Wraps the super class method to call it in the worker process.""" + # The super class implementation already returns a resolved future (to be interface compatible), + # so immediately resolve the future to the result (get returns immediately). + return super().restore_sample(restore_key).get() + + def dataset_init(self, initial_state: WorkerState | None) -> None: + if self._in_worker(): + return super().dataset_init(initial_state) + else: + return self._worker_call(self.dataset_init, initial_state).get() + + def new_iter(self) -> None: + if self._in_worker(): + return super().new_iter() + else: + return self._worker_call(self.new_iter).get() + + def prefetch_next(self) -> Future[TSample]: + # Do not resolve the future here, but return it. + if self._in_worker(): + return super().prefetch_next() + return self._worker_call(self._wrk_prefetch_next) + + def restore_sample(self, restore_key: WorkerSampleRestoreKey) -> Future[TSample]: + # Do not resolve the future here, but return it. + assert isinstance(restore_key, WorkerSampleRestoreKey) + if self._in_worker(): + return super().restore_sample(restore_key) + return self._worker_call(self._wrk_restore_sample, restore_key) + + def save_state(self) -> WorkerState: + if self._in_worker(): + return super().save_state() + else: + return self._worker_call(self.save_state).get() diff --git a/src/megatron/energon/dataloader/workers/base_worker.py b/src/megatron/energon/dataloader/workers/base_worker.py new file mode 100644 index 00000000..e35c27d6 --- /dev/null +++ b/src/megatron/energon/dataloader/workers/base_worker.py @@ -0,0 +1,238 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: BSD-3-Clause +from dataclasses import dataclass +from typing import Generic, TypeVar + +from megatron.energon.cache.base import CachePool +from megatron.energon.dataloader.future import DoneFuture, ExceptionFuture, Future +from megatron.energon.edataclass import edataclass +from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key +from megatron.energon.rng import SystemRng, SystemRngState +from megatron.energon.state import FlexState +from megatron.energon.worker import WorkerConfig +from megatron.energon.wrappers.base import WrappedRestoreKey, wrap_sample_restore_key + +TSample = TypeVar("TSample", covariant=True) + + +@dataclass(kw_only=True, slots=True, frozen=True) +class WorkerSampleRestoreKey(WrappedRestoreKey): + worker_id: int + sample_idx: int + + +@edataclass +class WorkerState: + """ + State of a worker. + """ + + rng: SystemRngState + dataset: FlexState + exhausted: bool + sample_index: int + + def __str__(self): + from hashlib import sha256 + + rng_hash = sha256(str(self.rng).encode()).hexdigest() + return f"WorkerState(dataset={self.dataset}, exhausted={self.exhausted}, sample_index={self.sample_index}, rng_hash={rng_hash})" + + +class DataLoaderWorker(Generic[TSample]): + """ + A worker for a :class:`DataLoader`. + + The basic implementation iterates the dataset. + The async extension implements the main commands via a command and results queue. + """ + + dataset: SavableDataset[TSample] + worker_config: WorkerConfig + + _rank_worker_id: int + _global_worker_id: int + _seed: int + _cache_pool: CachePool | None + _sample_index: int = 0 + _exhausted: bool = True + + def __init__( + self, + dataset: SavableDataset[TSample], + worker_config: WorkerConfig, + rank_worker_id: int, + cache_pool: CachePool | None, + ): + """ + Initialize the worker. + + Args: + dataset: The dataset to iterate over. + worker_config: The worker configuration. + rank_worker_id: The rank of the worker. 
+ cache_pool: The cache pool to use. + """ + self.dataset = dataset + self.worker_config = worker_config + self._rank_worker_id = rank_worker_id + self._global_worker_id = worker_config.global_worker_id(rank_worker_id) + self._seed = self.worker_config.worker_seed(rank_worker_id) + self._cache_pool = cache_pool + + # ------------------------------------------------------------------------------------------------ + # Section: Main control methods + + def start(self) -> None: + """ + Start the worker. + """ + pass + + def shutdown(self, in_del: bool = False) -> None: + """ + Shutdown the worker. + + Args: + in_del: If True, the worker is being deleted. + """ + self.dataset.worker_close() + self.dataset.close() + + def running(self) -> bool: + """ + Check if the worker is running. + """ + return True + + def _assert_running(self) -> None: + """ + Assert that the worker is running and alive. + """ + assert self.running(), "Worker must be running" + + def __del__(self) -> None: + self.shutdown(in_del=True) + + # ------------------------------------------------------------------------------------------------ + # Section: Worker methods + + def dataset_init(self, state: WorkerState | None) -> None: + """ + Initialize the worker (may restore the state). + Calls `new_iter` if the worker is not exhausted and also initially (`state=None`). + + Args: + state: The state to restore the worker from or None for using the initial state. + """ + # This is called in the worker context (process/thread). 
+        assert self._global_worker_id == self.worker_config.global_worker_id(), (
+            "Global worker ID mismatch"
+        )
+        assert self._seed == self.worker_config.worker_seed(self._rank_worker_id), "Seed mismatch"
+        # NOTE(review): leftover debug print removed (was writing to stdout in worker code)
+        self.dataset.reset_state()
+        if state is None:
+            self._sample_index = 0
+            # NOTE(review): leftover debug print removed
+            self.new_iter()
+            # NOTE(review): leftover debug print removed
+        else:
+            # NOTE(review): leftover debug print removed
+            self._sample_index = state.sample_index
+            SystemRng.restore_state(state.rng)
+            self.dataset.restore_state(state.dataset)
+            if not state.exhausted:
+                self.new_iter()
+            assert self._exhausted == state.exhausted, "Exhausted state mismatch"
+
+    def new_iter(self) -> None:
+        """
+        Start a new iterator of the dataset.
+        Called after the dataset is initialized and to start a new epoch (if the dataset is not infinite).
+        The iterator is stored in the worker and is used by the `prefetch_next` method, which calls `next` on it.
+        Updates the exhausted flag to False.
+        """
+        # This is called in the worker context (process/thread).
+        # NOTE(review): leftover debug print removed
+        self._dataset_iter = iter(self.dataset)
+        self._exhausted = False
+        # NOTE(review): leftover debug print removed
+
+    def prefetch_next(self) -> Future[TSample]:
+        """
+        Fetch the next sample (i.e. call `next` on the iterator) and return a future for getting the result.
+        Updates the exhausted flag if the iterator is exhausted.
+
+        Returns:
+            A future that will either be resolved to the next sample or raise StopIteration if the iterator is exhausted.
+        """
+        # This is called in the worker context (process/thread).
+        assert self._dataset_iter is not None, "start_iter must be called before prefetch_next"
+        if self._exhausted:
+            try:
+                raise StopIteration()
+            except StopIteration as e:
+                return ExceptionFuture(e)
+        sample_idx = self._sample_index
+        self.worker_config.worker_activate(sample_idx, cache_pool=self._cache_pool)
+        try:
+            next_sample = next(self._dataset_iter)
+            self._sample_index += 1
+            next_sample = wrap_sample_restore_key(
+                next_sample,
+                WorkerSampleRestoreKey,
+                worker_id=self._global_worker_id,
+                sample_idx=sample_idx,
+            )
+        except StopIteration as e:
+            self._exhausted = True
+            return ExceptionFuture(e)
+        finally:
+            self.worker_config.worker_deactivate()
+        return DoneFuture(next_sample)
+
+    def restore_sample(self, restore_key: WorkerSampleRestoreKey) -> Future[TSample]:
+        """
+        Restore a sample from a restore key in the worker.
+
+        Args:
+            restore_key: The restore key of the sample to restore.
+
+        Returns:
+            A future that will be resolved to the restored sample.
+        """
+        assert isinstance(restore_key, WorkerSampleRestoreKey)
+        assert self._global_worker_id == restore_key.worker_id, "Global worker ID mismatch"
+        self.worker_config.worker_activate(
+            restore_key.sample_idx,
+            cache_pool=self._cache_pool,
+        )
+        try:
+            return DoneFuture(
+                set_sample_restore_key(self.dataset.restore_sample(restore_key.inner), restore_key)
+            )
+        finally:
+            self.worker_config.worker_deactivate()
+
+    def save_state(self) -> WorkerState:
+        """
+        Save the state of the worker.
+        """
+        # This is called in the worker context (process/thread).
+        # NOTE(review): leftover debug print removed (was writing to stdout in worker code)
+        return WorkerState(
+            rng=SystemRng.save_state(),
+            dataset=self.dataset.save_state(),
+            exhausted=self._exhausted,
+            sample_index=self._sample_index,
+        )
+
+
+class DataLoaderNoWorker(DataLoaderWorker[TSample], Generic[TSample]):
+    """
+    DataLoader without async worker.
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) diff --git a/src/megatron/energon/dataloader/workers/fork_worker.py b/src/megatron/energon/dataloader/workers/fork_worker.py new file mode 100644 index 00000000..0897dcb9 --- /dev/null +++ b/src/megatron/energon/dataloader/workers/fork_worker.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import multiprocessing +from typing import Generic, TypeVar + +from megatron.energon.dataloader.asynchronous import ForkAsynchronous +from megatron.energon.dataloader.workers.async_worker import ( + DataLoaderAsynchronousWorker, + torch_set_worker_info, +) +from megatron.energon.wrappers.gc_dataset import gc_init_worker + +TSample = TypeVar("TSample", covariant=True) + + +class ForkDataLoaderWorker( + ForkAsynchronous, DataLoaderAsynchronousWorker[TSample], Generic[TSample] +): + """ + Implements the `DataLoaderWorker` interface using processes. + """ + + def _worker_run( + self, + cmd_queue: multiprocessing.Queue, + result_queue: multiprocessing.Queue, + ) -> None: + gc_init_worker(self._rank_worker_id) + + torch_set_worker_info( + id=self._rank_worker_id, + num_workers=self.worker_config.num_workers, + seed=self._seed, + dataset=self.dataset, + ) + + try: + super()._worker_run(cmd_queue, result_queue) + finally: + self.dataset.worker_close() + self.dataset.close() diff --git a/src/megatron/energon/dataloader/workers/thread_worker.py b/src/megatron/energon/dataloader/workers/thread_worker.py new file mode 100644 index 00000000..a8a2dea5 --- /dev/null +++ b/src/megatron/energon/dataloader/workers/thread_worker.py @@ -0,0 +1,34 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: BSD-3-Clause +from typing import Generic, TypeVar + +from megatron.energon.dataloader.asynchronous import ThreadAsynchronous, WorkerCommand, WorkerResult +from megatron.energon.dataloader.asynchronous.base import QueueProtocol +from megatron.energon.dataloader.workers.async_worker import ( + DataLoaderAsynchronousWorker, + torch_set_worker_info, +) + +TSample = TypeVar("TSample", covariant=True) + + +class ThreadDataLoaderWorker( + ThreadAsynchronous, DataLoaderAsynchronousWorker[TSample], Generic[TSample] +): + """ + Implements the `DataLoaderWorker` interface using threads. + """ + + def _worker_run( + self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult] + ) -> None: + torch_set_worker_info( + id=self._rank_worker_id, + num_workers=self.worker_config.num_workers, + seed=self._seed, + dataset=self.dataset, + ) + try: + return super()._worker_run(cmd_queue, result_queue) + finally: + self.dataset.worker_close() diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index f51aef01..e08bb5c1 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -3,6 +3,7 @@ import dataclasses import inspect +import threading import typing from abc import ABC, abstractmethod from copy import deepcopy @@ -26,6 +27,7 @@ from torch.utils.data import IterableDataset from typing_extensions import Self +import megatron.energon from megatron.energon.cache import FileStore from megatron.energon.edataclass import edataclass from megatron.energon.epathlib import EPath @@ -41,32 +43,39 @@ class PinMemoryMixin: """A mixin class providing a generic `pin_memory` function.""" - def _pin_memory(self, batch: T, device: Union[torch.device, str, None] = None) -> T: + @classmethod + def sample_pin_memory(cls, batch: T, device: Union[torch.device, str, None] = None) -> T: """Pin memory of a batch. 
Uses recursion to handle nested structures. Supports nested structures of dicts, dataclasses, namedtuples, lists and tuples.""" - if isinstance(batch, torch.Tensor): + if hasattr(batch, "pin_memory"): return batch.pin_memory(device) - elif isinstance(batch, dict): - return {key: self._pin_memory(value, device) for key, value in batch.items()} + if isinstance(batch, dict): + return {key: cls.sample_pin_memory(value, device) for key, value in batch.items()} elif dataclasses.is_dataclass(batch): return type(batch)( **{ - field.name: self._pin_memory(getattr(batch, field.name), device) + field.name: cls.sample_pin_memory(getattr(batch, field.name), device) for field in dataclasses.fields(batch) } ) - elif isinstance(batch, (tuple, list)): + elif not isinstance(batch, (str, bytes)) and isinstance(batch, (tuple, list)): if hasattr(batch, "_fields"): # NamedTuple - return type(batch)(*[self._pin_memory(val, device) for val in batch]) + return type(batch)(*[cls.sample_pin_memory(val, device) for val in batch]) else: # list / tuple - return type(batch)(self._pin_memory(val, device) for val in batch) + return type(batch)(cls.sample_pin_memory(val, device) for val in batch) else: return batch - def pin_memory(self: Self) -> Self: - return self._pin_memory(self) + def pin_memory(self: Self, device: torch.device | str | None = None) -> Self: + assert dataclasses.is_dataclass(self), "Must be a dataclass" + return type(self)( + **{ + field.name: self.sample_pin_memory(getattr(self, field.name), device) + for field in dataclasses.fields(self) + } + ) class ExtendableDataclassMixin: @@ -99,7 +108,7 @@ class MyExtendedClass(MyBaseClass): The extended dataclass instance. 
""" assert is_dataclass(cls), "Must be a dataclass" - assert issubclass(cls, type(src)), "Cannot extend class of different type" + # assert issubclass(cls, type(src)), "Cannot extend class of different type" for f in dataclasses.fields(src): if not f.init or f.type is ClassVar or typing.get_origin(f.type) is ClassVar: @@ -122,7 +131,8 @@ class Sample(ABC, PinMemoryMixin, ExtendableDataclassMixin): __key__: str #: Key for restoring the sample. This is used to restore the sample from a checkpoint. It # should be a (nested) tuple of strings and integers, which can be used to index the dataset. - __restore_key__: Tuple[Union[str, int, tuple], ...] + # May be None in some cases, but it may then not be restorable. + __restore_key__: "RestoreKey | None" #: A dataset may define a subflavors to distinguish between samples of the same sample type. __subflavors__: Optional[Dict[str, Any]] = None @@ -249,7 +259,7 @@ class SavableDataset(IterableDataset[T_sample], Savable, Generic[T_sample], ABC) How dataset state saving works: 1. The dataset state needs to be saved in all forked worker processes which contain a copy of - the main dataset instance (see :class:`megatron.energon.SavableDataLoader`). Each worker returns + the main dataset instance (see :class:`megatron.energon.DataLoader`). Each worker returns only its own state. 2. The main process merges the states via the :meth:`megatron.energon.SavableDataset.merge_states` method in the main process on the main dataset instance (which doesn't hold the worker states, @@ -263,8 +273,12 @@ class SavableDataset(IterableDataset[T_sample], Savable, Generic[T_sample], ABC) #: List of names of the fields that are saved and restored in the state. _savable_fields: ClassVar[Tuple[str, ...]] = () + #: List of names of the fields that are not saved, but are still part of the state (i.e. not shared between workers). 
+ _worker_local_fields: ClassVar[Tuple[str, ...]] = () + def __init__(self, worker_config: WorkerConfig): self.worker_config = worker_config + self._thread_state = threading.local() @abstractmethod def len_worker(self, worker_idx: int | None = None) -> int: @@ -341,14 +355,18 @@ def restore_state(self, state: FlexState) -> None: else: setattr(self, key, value) - @abstractmethod - def reset_state_own(self) -> None: - """Resets the state of the dataset to the initial state. Can only be called in a worker process.""" - ... + def reset_state(self) -> None: + """ + Resets the state of the dataset. Called at least once in the worker process before iterating. + Recursively resets the state of all wrapped datasets as well. + """ + pass - def reset_state_deep(self) -> None: - """Resets the state of the dataset to the initial state. Can only be called in a worker process.""" - self.reset_state_own() + def worker_close(self) -> None: + """ + Closes all worker-local resources. + """ + pass @abstractmethod def worker_has_samples(self) -> bool: @@ -382,18 +400,49 @@ def assert_can_restore(self) -> None: """Asserts that the dataset can restore a sample from a key.""" assert self.can_restore_sample(), "This dataset cannot restore samples." - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample: + def restore_sample(self, restore_key: "RestoreKey") -> T_sample: """ - Generic key type, because it might be either an integer (for a core dataset), or something - more complex (e.g. for blended datasets). + Restores a sample from a restore key. + + Args: + restore_key: The restore key to restore the sample from. - Default raises an exception (assumed non-deterministic if not implemented, does not - guarantee determinism). + Returns: + The restored sample. """ raise NotImplementedError( - "This dataset does not support indexing, because it is not safely deterministic." + "This dataset does not support restoring, because it is not safely deterministic." 
) + def close(self) -> None: + """Closes all shared resources.""" + pass + + def __getattribute__(self, name: str) -> Any: + if name in ("_savable_fields", "_worker_local_fields", "_thread_state", "worker_config"): + return object.__getattribute__(self, name) + elif name in self._savable_fields or name in self._worker_local_fields: + try: + return getattr(self._thread_state, name) + except AttributeError: + return object.__getattribute__(self, name) + else: + return object.__getattribute__(self, name) + + def __delattr__(self, name: str) -> None: + if name in self._savable_fields or name in self._worker_local_fields: + delattr(self._thread_state, name) + else: + object.__delattr__(self, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name in ("_savable_fields", "_worker_local_fields", "_thread_state", "worker_config"): + object.__setattr__(self, name, value) + elif name in self._savable_fields or name in self._worker_local_fields: + setattr(self._thread_state, name, value) + else: + object.__setattr__(self, name, value) + class BaseCoreDatasetFactory(Generic[T_sample], ABC): """Base type for an inner dataset sample loader. This factory can be used to construct a sample loader, or for @@ -420,39 +469,60 @@ def __len__(self) -> int: ... -def add_sample_restore_key( - sample: T_sample, *key: Union[int, str], src: Any, fail_otherwise: bool = False -) -> T_sample: - """Adds a key to a sample. The sample must be a valid `Sample` or dict containing - __restore_key__, which is a tuple of keys that can be used to restore the inner sample. 
-    This restore key is prepended with the `key`."""
-    if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"):
-        try:
-            sample.__restore_key__ = (type(src).__name__, *key, *sample.__restore_key__)
-        except KeyError:
-            pass
-    elif isinstance(sample, dict) and "__restore_key__" in sample:
-        sample["__restore_key__"] = (type(src).__name__, *key, *sample["__restore_key__"])
-    elif fail_otherwise:
-        raise RuntimeError(
-            "Did not yield a sample with a restore key, but is marked stateless/deterministic."
+@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
+class RestoreKey(ABC):
+    """Base class for restore keys."""
+
+    def _tupleify(self, value: Any) -> Any:
+        if isinstance(value, (int, str, float, bool)):
+            return value
+        elif isinstance(value, RestoreKey):
+            return value.as_tuple()
+        elif isinstance(value, (list, tuple)):
+            return tuple(self._tupleify(v) for v in value)
+        else:
+            return value
+
+    def as_tuple(self) -> tuple[Any, ...]:
+        return (
+            self.__class__.__name__,
+            *(self._tupleify(getattr(self, field.name)) for field in dataclasses.fields(self)),
        )
-    return sample
+
+    @staticmethod
+    def _untupleify(value: Any) -> Any:
+        if isinstance(value, (int, str, float, bool)):
+            return value
+        elif isinstance(value, RestoreKey):
+            return value
+        elif isinstance(value, (list, tuple)):
+            if value and isinstance(value[0], str) and hasattr(megatron.energon, value[0]):
+                return RestoreKey.from_tuple(tuple(value))
+            else:
+                return tuple(RestoreKey._untupleify(v) for v in value)
+        return value
+    @staticmethod
+    def from_tuple(json: tuple[Any, ...]) -> "RestoreKey":
+        cls = getattr(megatron.energon, json[0])
+        kwargs = {}
+        for field, value in zip(dataclasses.fields(cls), json[1:]):
+            kwargs[field.name] = RestoreKey._untupleify(value)
+        return cls(**kwargs)
 
 
 def set_sample_restore_key(
-    sample: T_sample, *key: Union[int, str], src: Any, fail_otherwise: bool = False
+    sample: T_sample, restore_key: RestoreKey, fail_otherwise: bool = False
 ) -> T_sample:
"""Sets the restore key for a sample. The sample must be a valid `Sample` or dict containing __restore_key__, which is a tuple of keys that can be used to restore the inner sample. This restore key is prepended with the `key`.""" if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"): try: - sample.__restore_key__ = (type(src).__name__, *key) + sample.__restore_key__ = restore_key except KeyError: pass elif isinstance(sample, dict) and "__restore_key__" in sample: - sample["__restore_key__"] = (type(src).__name__, *key) + sample["__restore_key__"] = restore_key elif fail_otherwise: raise RuntimeError( "Did not yield a sample with a restore key, but is marked stateless/deterministic." diff --git a/src/megatron/energon/flavors/jsonl/ijsonl.py b/src/megatron/energon/flavors/jsonl/ijsonl.py index c292aa2b..7952ced1 100644 --- a/src/megatron/energon/flavors/jsonl/ijsonl.py +++ b/src/megatron/energon/flavors/jsonl/ijsonl.py @@ -112,9 +112,13 @@ class CachedIJsonlOffsetReader: cache_size: The number of entries to keep in the cache. By default, we keep 32. 
""" + ijsonl_index_reader_cache: Dict[int, CacheEntry] + cache_size: int + jsonl_file: EPath + def __init__(self, jsonl_file: Union[str, EPath], cache_size: int = 32): # Maps current_offset -> CacheEntry - self.ijsonl_index_reader_cache: Dict[int, CacheEntry] = {} + self.ijsonl_index_reader_cache = {} self.cache_size = cache_size self.jsonl_file = EPath(jsonl_file) @@ -219,11 +223,6 @@ def get_ijsonl_byte_offset( return result_byte_offset, length - def __len__(self) -> int: - if len(self.ijsonl_index_reader_cache) == 0: - return IJsonlIndexReader.count_samples(self.jsonl_file) - return len(next(iter(self.ijsonl_index_reader_cache.values())).ijsonl_index_reader) - 1 - def get_total_size(self) -> int: if len(self.ijsonl_index_reader_cache) == 0: self.ijsonl_index_reader_cache[0] = CacheEntry( @@ -259,6 +258,10 @@ class IJsonlFile: """ def __init__(self, fileobj: BinaryIO): + """ + Args: + fileobj: The file object to read from. Takes ownership of the file object. + """ self.fileobj = fileobj def seek(self, offset: int): diff --git a/src/megatron/energon/flavors/jsonl/ijsonl_reader.py b/src/megatron/energon/flavors/jsonl/ijsonl_reader.py index 202f1926..b3304f01 100644 --- a/src/megatron/energon/flavors/jsonl/ijsonl_reader.py +++ b/src/megatron/energon/flavors/jsonl/ijsonl_reader.py @@ -1,7 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause -from abc import ABC +import threading from typing import ( Callable, Generator, @@ -17,13 +17,15 @@ IJsonlIndexReader, IJsonlSamplePointer, ) -from megatron.energon.flavors.webdataset.structs import FilteredSample +from megatron.energon.flavors.webdataset.itar_reader import RawSampleReaderInterface +from megatron.energon.flavors.webdataset.multi_key_cache import MultiKeyCache +from megatron.energon.flavors.webdataset.structs import FilteredSample, WebdatasetRestoreKey from megatron.energon.source_info import SourceInfo T_index = TypeVar("T_index", covariant=False) -class IJsonlReader(ABC): +class IJsonlReader(RawSampleReaderInterface[int | str]): """ Class for reading indexed jsonl files containing json samples. @@ -40,8 +42,12 @@ class IJsonlReader(ABC): jsonl_path: EPath sample_filter: Optional[Callable[[str], bool]] - cached_offset_reader: CachedIJsonlOffsetReader - ijsonl_file: IJsonlFile | None = None + _length: int + _total_size: int + + thread_local: threading.local + cache_lock: threading.Lock + ijsonl_files_cache: MultiKeyCache[int, IJsonlFile] def __init__( self, @@ -51,16 +57,62 @@ def __init__( ): self.jsonl_path = jsonl_path self.sample_filter = sample_filter - self.cached_offset_reader = CachedIJsonlOffsetReader( - jsonl_path, cache_size=index_cache_size - ) + self.index_cache_size = index_cache_size + self.thread_local = threading.local() + self.ijsonl_files_cache = MultiKeyCache() + self.cache_lock = threading.Lock() + + with IJsonlIndexReader(jsonl_path) as ijsonl_index_reader: + # Number of samples + self._length = len(ijsonl_index_reader) - 1 + # Byte size + self._total_size = ijsonl_index_reader[self._length] def __len__(self) -> int: - return len(self.cached_offset_reader) + return self._length def __str__(self) -> str: return f"IJsonlReader(jsonl_path={self.jsonl_path})" + @property + def _cached_offset_reader(self) -> CachedIJsonlOffsetReader: + return self.thread_local._cached_offset_reader + + def 
worker_init(self): + self.thread_local._cached_offset_reader = CachedIJsonlOffsetReader( + self.jsonl_path, cache_size=self.index_cache_size + ) + + def worker_close(self): + if hasattr(self.thread_local, "_cached_offset_reader"): + self.thread_local._cached_offset_reader.close() + del self.thread_local._cached_offset_reader + + def _get_ijsonl_file_cached(self, sample_idx: int) -> IJsonlFile: + """ + Get the IJsonlFile object for the given sample index. + If the file is not already open, open it. + """ + with self.cache_lock: + reader = self.ijsonl_files_cache.pop(sample_idx) + if reader is None: + if len(self.ijsonl_files_cache) < self.index_cache_size: + reader = IJsonlFile(fileobj=self.jsonl_path.open(mode="rb")) + else: + # Reuse the oldest file + reader = self.ijsonl_files_cache.pop() + return reader + + def _update_ijsonl_file_cache(self, sample_idx: int, reader: IJsonlFile) -> None: + """ + Update the IJsonlFile object for the given sample index. + """ + with self.cache_lock: + while len(self.ijsonl_files_cache) >= self.index_cache_size: + # Evict the oldest file + self.ijsonl_files_cache.pop().close() + self.ijsonl_files_cache.add(sample_idx, reader) + def _get_item_by_sample_pointer( self, sample_pointer: IJsonlSamplePointer, @@ -69,8 +121,7 @@ def _get_item_by_sample_pointer( Get a sample from the dataset or slice it. Args: - sample_pointer: The sample pointer to get the sample from. - sample_index: The global index of the sample in the dataset. + sample_pointer: Pointer to the sample in the jsonl file. Returns: The sample or None if the sample is invalid. 
@@ -80,20 +131,22 @@ def _get_item_by_sample_pointer( if self.sample_filter is not None and not self.sample_filter(key): return None - if self.ijsonl_file is None: - self.ijsonl_file = IJsonlFile(self.jsonl_path.open("rb")) + ijsonl_file = self._get_ijsonl_file_cached(sample_pointer.index) + + json_data = ijsonl_file.next(sample_pointer.byte_offset, sample_pointer.byte_size) - json_data = self.ijsonl_file.next(sample_pointer.byte_offset, sample_pointer.byte_size) if json_data is None: return None + self._update_ijsonl_file_cache(sample_pointer.index + 1, ijsonl_file) + return FilteredSample( __key__=key, __shard__=self.jsonl_path.name, - __restore_key__=("Webdataset", sample_pointer.index), + __restore_key__=WebdatasetRestoreKey(index=sample_pointer.index), __sources__=( SourceInfo( - dataset_path=str(self.jsonl_path), + dataset_path=self.jsonl_path, index=sample_pointer.index, shard_name=self.jsonl_path.name, file_names=(f"{key}.json",), @@ -118,7 +171,7 @@ def __getitem__(self, idx: int | str) -> FilteredSample | tuple[bytes, SourceInf except ValueError: raise ValueError(f"Invalid JSONL sample key: {idx}") - byte_offset, byte_size = self.cached_offset_reader.get_ijsonl_byte_offset(idx) + byte_offset, byte_size = self._cached_offset_reader.get_ijsonl_byte_offset(idx) sample: FilteredSample | None = self._get_item_by_sample_pointer( IJsonlSamplePointer( index=idx, @@ -178,13 +231,13 @@ def list_sample_parts(self, sample_key: str) -> Generator[Tuple[str, int, int], except ValueError: raise ValueError(f"Invalid JSONL sample key: {sample_key}") - _, byte_size = self.cached_offset_reader.get_ijsonl_byte_offset(sample_idx) + _, byte_size = self._cached_offset_reader.get_ijsonl_byte_offset(sample_idx) yield f"{sample_key}.json", byte_size, 0 def get_total_size(self) -> int: - return self.cached_offset_reader.get_total_size() + return self._total_size def close(self): - if self.ijsonl_file is not None: - self.ijsonl_file.close() - self.cached_offset_reader.close() + with 
self.cache_lock: + for ijsonl_file in self.ijsonl_files_cache.flush(): + ijsonl_file.close() diff --git a/src/megatron/energon/flavors/webdataset/indexing.py b/src/megatron/energon/flavors/webdataset/indexing.py index b3ef98d1..fa524183 100644 --- a/src/megatron/energon/flavors/webdataset/indexing.py +++ b/src/megatron/energon/flavors/webdataset/indexing.py @@ -507,7 +507,7 @@ def get_media_metadata(self, entry_key: str) -> Tuple[str, str] | None: def close(self): """Close the database connection.""" - if self.db is not None: + if getattr(self, "db", None) is not None: self.db.thread_close() del self.db diff --git a/src/megatron/energon/flavors/webdataset/itar.py b/src/megatron/energon/flavors/webdataset/itar.py index ef6e435e..e5299add 100644 --- a/src/megatron/energon/flavors/webdataset/itar.py +++ b/src/megatron/energon/flavors/webdataset/itar.py @@ -1,15 +1,12 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -import contextlib import struct import tarfile -from types import TracebackType -from typing import BinaryIO, Dict, Generator, Optional, Tuple, Type, Union +from typing import Dict, Generator, Optional, Tuple, Union from megatron.energon.edataclass import edataclass from megatron.energon.epathlib import EPath -from megatron.energon.retry_stream import RetryReadStream ITAR_SUFFIX = ".tar.idx" @@ -106,76 +103,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.close(finalize=exc_val is None) -class SubFileReader(BinaryIO): - """A file-like object that reads a subfile (i.e. 
offset, size defined portion) of a larger - file.""" - - def __init__(self, stream: BinaryIO, offset: int, size: int): - self.offset = offset - self._pos = 0 - self.size = size - self.stream = stream - self.stream.seek(self.offset) - - def read(self, n: int = -1) -> bytes: - if n == -1: - n = self.size - self._pos - else: - n = min(n, self.size - self._pos) - if n == 0: - return b"" - read = self.stream.read(n) - self._pos += len(read) - return read - - def seek(self, offset: int, whence: int = 0) -> int: - if whence == 0: - self._pos = offset - elif whence == 1: - self._pos += offset - elif whence == 2: - self._pos = self.size + offset - else: - raise ValueError("Invalid whence value") - self._pos = max(0, min(self._pos, self.size)) - self.stream.seek(self.offset + self._pos) - return self._pos - - def tell(self) -> int: - return self._pos - - def __enter__(self) -> BinaryIO: - return self - - def __exit__( - self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType - ) -> None: - self.close() - - def close(self) -> None: - self.stream.close() - - def isatty(self) -> bool: - return False - - def seekable(self) -> bool: - return True - - def writable(self) -> bool: - return False - - -def get_itar_byte_offset( - path: Union[str, EPath], - sample_offset: int = 0, -) -> int: - """Gets the byte offset from sample offsets.""" - if sample_offset == 0: - return 0 - with TarIndexReader(path) as itar: - return itar[sample_offset] - - @edataclass class CacheEntry: tar_index_reader: TarIndexReader @@ -308,6 +235,11 @@ def get_itar_byte_offset( return result_byte_offset, length + def close(self): + for cache_entry in self.tar_index_reader_cache.values(): + cache_entry.tar_index_reader.close() + self.tar_index_reader_cache.clear() + class ITarFile(tarfile.TarFile): """This class is a subclass of tarfile.TarFile that allows for reading a tarfile, @@ -341,40 +273,3 @@ def next(self): self.fileobj.seek(self.offset) return super().next() - - 
-@contextlib.contextmanager -def open_itar(path: Union[str, EPath], byte_offset: int = 0, byte_size: Optional[int] = None): - """ - Open an indexed tarfile with offset and size. - Args: - path: Path to the tarfile to open - byte_offset: Byte offset within the file - byte_size: Size of the file to read - - Returns: - The opened tarfile - """ - path = EPath(path) - - # TODO: if tar file startswith(b"\x1f\x8b\x08") -> Seekable gzip file - with path.open("rb") as f: - if f.read(3) == b"\x1f\x8b\x08": - # Open as seekable tgz - raise ValueError("Seekable tgz not supported yet") - - if byte_offset != 0 or byte_size is not None: - if byte_size is None: - byte_size = path.size() - byte_offset - with RetryReadStream(path) as stream: - with SubFileReader( - stream, - offset=byte_offset, - size=byte_size, - ) as fileobj: - with ITarFile.open(fileobj=fileobj, mode="r:") as f: - yield f - else: - with RetryReadStream(path) as fileobj: - with ITarFile.open(fileobj=fileobj, mode="r:") as f: - yield f diff --git a/src/megatron/energon/flavors/webdataset/itar_reader.py b/src/megatron/energon/flavors/webdataset/itar_reader.py index 0134f338..7ca59785 100644 --- a/src/megatron/energon/flavors/webdataset/itar_reader.py +++ b/src/megatron/energon/flavors/webdataset/itar_reader.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause +import threading from abc import ABC, abstractmethod from bisect import bisect_right from typing import ( @@ -30,13 +31,64 @@ ITarSamplePointer, ) from megatron.energon.flavors.webdataset.metadata import get_info_shard_files -from megatron.energon.flavors.webdataset.structs import FilteredSample, ShardInfo +from megatron.energon.flavors.webdataset.multi_key_cache import MultiKeyCache +from megatron.energon.flavors.webdataset.structs import ( + FilteredSample, + ShardInfo, + WebdatasetRestoreKey, +) from megatron.energon.source_info import SourceInfo T_index = TypeVar("T_index", covariant=False) -class ITarReader(ABC, Generic[T_index]): +class RawSampleReaderInterface(ABC, Generic[T_index]): + """ + An abstract base class for reading a sequence of raw samples. + """ + + @abstractmethod + def __len__(self) -> int: + """Returns the total number of samples in the reader.""" + ... + + @abstractmethod + def __str__(self) -> str: + """ + Must return a descriptive string of the concrete reader. + """ + ... + + @abstractmethod + def worker_init(self): + """ + Initialize the reader for the worker. + """ + ... + + @abstractmethod + def worker_close(self): + """ + Close the reader for the worker. + """ + ... + + @abstractmethod + def close(self): + """ + Close the reader and clear all shared resources. + """ + ... + + @abstractmethod + def __getitem__(self, idx: T_index) -> FilteredSample | None: + """ + Get a sample from the dataset or slice it. Thread-safe. + """ + ... + + +class ITarReader(RawSampleReaderInterface[T_index], Generic[T_index]): """ An abstract base class for reading a sequence of tar files containing samples. 
    def close(self):
        """Effectively clears the internal shared cache.

        Closes every cached ITarFile and its underlying file object while
        holding the cache lock, so no concurrent check-out can race the
        teardown.
        """
        with self.cache_lock:
            for tar_file in self.itar_files_cache.flush():
                # Grab the underlying file object before closing the tar
                # wrapper, then close both (wrapper first, raw file second).
                fileobj = tar_file.fileobj
                tar_file.close()
                if fileobj is not None:
                    fileobj.close()
    def _get_itarfile_cached(self, tar_file_id: int) -> ITarFile:
        """
        Get the ITarFile object for the given tar file id.
        If the file is not already open, open it.

        The returned handle is *checked out* of the cache (removed via
        ``pop``), so two concurrent callers never share one file position.
        Callers must hand the handle back through ``_update_itarfile_cache``
        when done (see the ``finally`` block in ``_get_item_by_sample_pointer``).
        """
        with self.cache_lock:
            reader = self.itar_files_cache.pop(tar_file_id)
            if reader is None:
                # Cache miss: open a fresh raw file and wrap it in an ITarFile.
                file_object = self.tar_filepaths[tar_file_id].open(mode="rb")
                reader = ITarFile.open(fileobj=file_object, mode="r:")
            return reader
skip_meta_re.match(fname): - continue - - # Extract the base_name and extension - m = split_name_re.match(fname) - if not m: - continue - cur_base_name, cur_ext = m.groups() - + tar_file = self._get_itarfile_cached(sample_pointer.tar_file_id) + try: + tar_file.offset = sample_pointer.byte_offset + + while tar_file.offset < sample_pointer.byte_offset + sample_pointer.byte_size: + tarinfo = tar_file.next() + if tarinfo is None: + if tar_file.offset == sample_pointer.byte_offset + sample_pointer.byte_size: + break + else: + raise ValueError( + f"Unexpected end of tar file: {self.tar_filenames[sample_pointer.tar_file_id]}" + ) + fname = tarinfo.name + if not tarinfo.isfile() or fname is None: + continue + if skip_meta_re.match(fname): + continue + + # Extract the base_name and extension + m = split_name_re.match(fname) + if not m: + continue + cur_base_name, cur_ext = m.groups() + + if sample_base_name is None: + sample_base_name = cur_base_name + sample_name = f"{shard_name}/{cur_base_name}" + if self.sample_filter is not None and not self.sample_filter(sample_name): + return None + else: + if sample_base_name != cur_base_name: + raise ValueError( + f"Inconsistent sample base name: {sample_base_name} vs {cur_base_name}" + ) + + if entry_match_fn is not None: + # If entry_match_fn is provided, use it to determine if we should take this entry + take_entry = entry_match_fn(fname) + else: + # If no entry_match_fn is provided, use the part_filter to determine if we should take this entry + take_entry = self.part_filter is None or self.part_filter(cur_ext) + + if take_entry: + member_bytes = tar_file.extractfile(tarinfo).read() + group_parts[cur_ext] = member_bytes + file_names.append(fname) if sample_base_name is None: - sample_base_name = cur_base_name - sample_name = f"{shard_name}/{cur_base_name}" - if self.sample_filter is not None and not self.sample_filter(sample_name): - return None - else: - if sample_base_name != cur_base_name: - raise ValueError( - f"Inconsistent 
sample base name: {sample_base_name} vs {cur_base_name}" - ) - - if entry_match_fn is not None: - # If entry_match_fn is provided, use it to determine if we should take this entry - take_entry = entry_match_fn(fname) - else: - # If no entry_match_fn is provided, use the part_filter to determine if we should take this entry - take_entry = self.part_filter is None or self.part_filter(cur_ext) - - if take_entry: - member_bytes = tar_file.extractfile(tarinfo).read() - group_parts[cur_ext] = member_bytes - file_names.append(fname) - if sample_base_name is None: - raise ValueError(f"No valid files found in sample {sample_pointer}") + raise ValueError(f"No valid files found in sample {sample_pointer}") + finally: + # Return the reader to the cache + self._update_itarfile_cache(sample_pointer.tar_file_id, tar_file) return FilteredSample( __key__=sample_base_name, __shard__=self.tar_filenames[sample_pointer.tar_file_id], - __restore_key__=("Webdataset", restore_index), + __restore_key__=WebdatasetRestoreKey(index=restore_index), __sources__=( SourceInfo( dataset_path=self.base_path, @@ -259,7 +311,9 @@ class JoinIndexFileITarReader(ITarReader[int]): index_file: EPath column: int - index_reader_cache: Dict[int, JoinIndexReader] + index_reader_cache_lock: threading.Lock + index_reader_cache: MultiKeyCache[int, JoinIndexReader] + active_readers: int = 0 index_reader_cache_size: int def __init__( @@ -278,7 +332,8 @@ def __init__( # Create the full path to each tar file tar_filepaths = [base_path / fn for fn in tar_filenames] - self.index_reader_cache = {} + self.index_reader_cache_lock = threading.Lock() + self.index_reader_cache = MultiKeyCache() self.index_reader_cache_size = itar_cache_size super().__init__( @@ -290,24 +345,36 @@ def __init__( sample_filter=sample_filter, ) + def worker_init(self): + pass + + def worker_close(self): + pass + def _get_join_index_reader_cached(self, sample_idx: int) -> JoinIndexReader: """ Get the JoinIndexReader object for the given sample 
    def _update_index_reader_cache(self, sample_idx: int, reader: JoinIndexReader) -> None:
        """
        Return a JoinIndexReader to the cache, keyed by the sample index it is
        now positioned at (the caller passes the *next* row offset, not a tar
        file id). Evicts and closes the least recently used readers while the
        cache is over capacity.
        """
        with self.index_reader_cache_lock:
            # If we hit the limit of open files, close the least recently used file
            while len(self.index_reader_cache) >= self.index_reader_cache_size:
                self.index_reader_cache.pop().close()
            self.index_reader_cache.add(sample_idx, reader)
def _cached_offset_reader(self) -> CachedItarOffsetReader: + return self._thread_local._cached_offset_reader + + def worker_init(self): + self._thread_local._cached_offset_reader = CachedItarOffsetReader( + cache_size=self._itar_cache_size + ) + + def worker_close(self): + if hasattr(self._thread_local, "_cached_offset_reader"): + self._thread_local._cached_offset_reader.close() + del self._thread_local._cached_offset_reader + def _get_itar_sample_pointer(self, idx: int) -> ITarSamplePointer: """ Get the ITarSample object for the given index. @@ -423,7 +507,7 @@ def _get_itar_sample_pointer(self, idx: int) -> ITarSamplePointer: # Now we know the tar file and the sample offset in the file. # We need to figure out the byte offset and size of the sample, # by looking it up in the .tar.idx file. - byte_offset, byte_size = self.cached_offset_reader.get_itar_byte_offset( + byte_offset, byte_size = self._cached_offset_reader.get_itar_byte_offset( shard.path, sample_idx_in_shard_file ) @@ -450,9 +534,10 @@ class SqliteITarEntryReader(ITarReader[str]): A concrete ITarReader that constructs its internal sample list from a SQLite database. 
""" - sqlite_reader: SqliteIndexReader db_has_sample_parts: int + thread_local: threading.local + def __init__( self, base_path: EPath, @@ -469,12 +554,12 @@ def __init__( tar_filepaths = [base_path / fn for fn in tar_filenames] # Initialize the SQLite reader - sqlite_path = base_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME - self.sqlite_reader = SqliteIndexReader(sqlite_path) - - self.db_has_sample_parts = self.sqlite_reader.db_has_sample_parts() + self.sqlite_path = base_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME + with SqliteIndexReader(self.sqlite_path) as check_db: + self.db_has_sample_parts = check_db.db_has_sample_parts() self.key_is_full_entryname = key_is_full_entryname + self.thread_local = threading.local() super().__init__( base_path=base_path, @@ -485,12 +570,24 @@ def __init__( sample_filter=sample_filter, ) + @property + def _sqlite_reader(self) -> SqliteIndexReader: + return self.thread_local._sqlite_reader + + def worker_init(self): + self.thread_local._sqlite_reader = SqliteIndexReader(self.sqlite_path) + + def worker_close(self): + if getattr(self.thread_local, "_sqlite_reader", None) is not None: + self.thread_local._sqlite_reader.close() + del self.thread_local._sqlite_reader + def _get_itar_sample_pointer(self, sample_key: str) -> ITarSamplePointer: """ Get the ITarSample object for the given index. """ - return self.sqlite_reader.get_sample_pointer_by_key(sample_key) + return self._sqlite_reader.get_sample_pointer_by_key(sample_key) def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]: """List all samples in the jsonl file. @@ -498,7 +595,7 @@ def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]: Returns: A generator of tuples of (sample_key, size, tar_file_id) """ - return self.sqlite_reader.list_all_samples() + return self._sqlite_reader.list_all_samples() def list_all_sample_parts(self) -> Generator[Tuple[str, int, int], None, None]: """List all sample parts in the jsonl file. 
@@ -506,7 +603,7 @@ def list_all_sample_parts(self) -> Generator[Tuple[str, int, int], None, None]: Returns: A generator of tuples of (sample_key + "." + part_name, size, tar_file_id) """ - return self.sqlite_reader.list_all_sample_parts() + return self._sqlite_reader.list_all_sample_parts() def list_sample_parts( self, sample_key: str, slow_mode: bool = False @@ -528,7 +625,7 @@ def list_sample_parts( """ if not slow_mode: - yield from self.sqlite_reader.list_sample_parts(sample_key) + yield from self._sqlite_reader.list_sample_parts(sample_key) else: sample_pointer = self._get_itar_sample_pointer(sample_key) @@ -540,7 +637,7 @@ def list_sample_parts( yield ext, len(sample[ext]), sample_pointer.tar_file_id def get_total_size(self) -> int: - return self.sqlite_reader.get_total_size() + return self._sqlite_reader.get_total_size() @overload def __getitem__(self, key: str) -> Union[FilteredSample, tuple[bytes, SourceInfo]]: ... @@ -571,7 +668,7 @@ def __getitem__( if self.db_has_sample_parts: # Directly fetch the sample part (byte offset and size) from the database - raw_sample_pointer = self.sqlite_reader.get_sample_part(sample_key, sample_ext) + raw_sample_pointer = self._sqlite_reader.get_sample_part(sample_key, sample_ext) raw_data, source_info = self._get_part_by_raw_sample_pointer( raw_sample_pointer, key ) @@ -598,7 +695,7 @@ def __getitem__( def __len__(self) -> int: """Return the total number of samples in the database.""" - return self.sqlite_reader.get_sample_count() + return self._sqlite_reader.get_sample_count() def __str__(self) -> str: """Return a descriptive string of this reader.""" @@ -612,12 +709,11 @@ def __str__(self) -> str: def close(self): """Close the SQLite reader and any open ITarFiles.""" # Close the SQLite reader - if hasattr(self, "sqlite_reader") and self.sqlite_reader is not None: - self.sqlite_reader.close() + if hasattr(self, "_sqlite_reader") and self._sqlite_reader is not None: + self._sqlite_reader.close() # Close any open 
ITarFiles (using parent class implementation) - for tar_file_id in list(self.itar_files_cache.keys()): - tar_file = self.itar_files_cache[tar_file_id] + for tar_file in self.itar_files_cache.flush(): if ( tar_file is not None and hasattr(tar_file, "fileobj") @@ -626,4 +722,3 @@ def close(self): tar_file.fileobj.close() if tar_file is not None and hasattr(tar_file, "close"): tar_file.close() - del self.itar_files_cache[tar_file_id] diff --git a/src/megatron/energon/flavors/webdataset/joined_webdataset.py b/src/megatron/energon/flavors/webdataset/joined_webdataset.py index 49394dd2..29c3f16c 100644 --- a/src/megatron/energon/flavors/webdataset/joined_webdataset.py +++ b/src/megatron/energon/flavors/webdataset/joined_webdataset.py @@ -220,8 +220,7 @@ def load_sample(self, samples: RawSampleData) -> T_sample: # Then combine the loaded smaples into the final type return set_sample_restore_key( self._sample_joiner(*loaded_samples), - *samples.__restore_key__, - src=self, + samples.__restore_key__, fail_otherwise=True, ) diff --git a/src/megatron/energon/flavors/webdataset/multi_key_cache.py b/src/megatron/energon/flavors/webdataset/multi_key_cache.py new file mode 100644 index 00000000..40865591 --- /dev/null +++ b/src/megatron/energon/flavors/webdataset/multi_key_cache.py @@ -0,0 +1,72 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +from typing import Generator, Generic, TypeVar, overload + +T_key = TypeVar("T_key") +T_value = TypeVar("T_value") + + +class MultiKeyCache(Generic[T_key, T_value]): + """A cache that can store multiple values for the same key.""" + + _size: int + _cache: dict[T_key, list[T_value]] + _lru_keys: list[T_key] + + def __init__(self) -> None: + self._size = 0 + self._cache = {} + self._lru_keys = [] + + @overload + def pop(self, key: None = None) -> T_value: ... + + @overload + def pop(self, key: T_key) -> T_value | None: ... 
+ + def pop(self, key: T_key | None = None) -> T_value | None: + """Pop the value for the given key from the cache. + + If no key is provided, pop the oldest key from the cache. + + Args: + key: The key to pop from the cache. If None, pop the oldest key from the cache. + + Returns: + The value popped from the cache. + """ + if key is None: + key = self._lru_keys.pop(0) + elif key not in self._cache: + return None + else: + self._lru_keys.pop(len(self._lru_keys) - 1 - self._lru_keys[::-1].index(key)) + + l = self._cache[key] + value = l.pop(0) + if len(l) == 0: + del self._cache[key] + self._size -= 1 + return value + + def add(self, key: T_key, value: T_value) -> None: + """Add a value to the cache.""" + if key not in self._cache: + self._cache[key] = [value] + else: + self._cache[key].insert(0, value) + + self._lru_keys.append(key) + self._size += 1 + + def flush(self) -> Generator[T_value, None, None]: + """Flush the cache.""" + for values in self._cache.values(): + for value in values: + yield value + self._cache.clear() + self._lru_keys.clear() + self._size = 0 + + def __len__(self) -> int: + return self._size diff --git a/src/megatron/energon/flavors/webdataset/sample_loader.py b/src/megatron/energon/flavors/webdataset/sample_loader.py index df4ad103..f067c370 100644 --- a/src/megatron/energon/flavors/webdataset/sample_loader.py +++ b/src/megatron/energon/flavors/webdataset/sample_loader.py @@ -1,14 +1,14 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple import torch from megatron.energon.edataclass import edataclass -from megatron.energon.flavors.base_dataset import FlexState, SavableDataset -from megatron.energon.flavors.webdataset.itar_reader import ITarReader -from megatron.energon.flavors.webdataset.structs import FilteredSample +from megatron.energon.flavors.base_dataset import FlexState, RestoreKey, SavableDataset +from megatron.energon.flavors.webdataset.itar_reader import RawSampleReaderInterface +from megatron.energon.flavors.webdataset.structs import FilteredSample, WebdatasetRestoreKey from megatron.energon.rng import WorkerRng from megatron.energon.worker import WorkerConfig @@ -18,7 +18,7 @@ class RawSampleData: """Represents the iteration state of a single slice slice to the index.""" #: Index of the sample. This is also the restore key - __restore_key__: Tuple[str, int] + __restore_key__: WebdatasetRestoreKey #: The sample data data: Tuple[Optional[FilteredSample], ...] @@ -37,10 +37,11 @@ class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]): """Internal class for loading samples from webdataset slices""" #: The readers for each joined dataset - join_readers: Sequence[ITarReader] + join_readers: Sequence[RawSampleReaderInterface[int]] - #: The offsets of the slice slices to iterate over for the current worker - slice_offsets: Optional[Sequence[int]] + #: The offsets of the slice slices to iterate over for each worker + # On worker initialization, this is set to _slice_offsets for the current worker. + workers_slice_offsets: Sequence[Sequence[int]] # If = 1, every sample is seen exactly once per epoch. If > 1, samples # (or rather slice slices) are shuffled within this number of epochs (i.e. 
randomly @@ -53,6 +54,9 @@ class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]): # Worker's random generator _worker_rng: WorkerRng + #: The offsets of the slice slices to iterate over for the current worker + _slice_offsets: Optional[Sequence[int]] + #: The RNG state to be used for regenerating the pending slices _pending_slices_rng_state: Optional[FlexState] #: The number of slices that have already been opened / processed and thus been removed from the @@ -81,9 +85,11 @@ class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]): "_epoch_sample_count", ) + _worker_local_fields = ("_slice_offsets",) + def __init__( self, - join_readers: Sequence[ITarReader], + join_readers: Sequence[RawSampleReaderInterface[int]], workers_sample_slice_offsets: Sequence[Sequence[int]], *, worker_config: WorkerConfig, @@ -95,7 +101,7 @@ def __init__( Args: join_readers: A sequence of the joined readers (or just a single reader) to iterate over. - worker_slice_offsets: The offsets of the slice slices to iterate over, for each worker. + workers_sample_slice_offsets: The offsets of the slice slices to iterate over, for each worker. worker_config: The worker configuration. shuffle_over_epochs: If None, disable shuffling. If = 1, every sample is seen exactly once per epoch. 
    def _get_sample(self, index: int) -> RawSampleData:
        """Fetch the raw sample at the given global index from every joined
        reader. The same index doubles as the sample's restore key."""
        return RawSampleData(
            __restore_key__=WebdatasetRestoreKey(index=index),
            data=tuple(reader[index] for reader in self.join_readers),
        )
Possibly shuffles the list.""" - assert self.slice_offsets is not None + assert self._slice_offsets is not None - num_slices = len(self.slice_offsets) - 1 + num_slices = len(self._slice_offsets) - 1 slices_offset = self._pending_slices_offset if self.shuffle_over_epochs is None: @@ -197,17 +195,17 @@ def _slices_iter(self) -> Generator[RawSampleData, None, None]: """Iterates the samples in a list of slices, possibly using multiple parallel iterators over the slices.""" - assert self.slice_offsets is not None + assert self._slice_offsets is not None active_slice_probs = torch.zeros(self.parallel_slice_iters, dtype=torch.float32) active_slices = self._active_slice_state pending_slice_indexes = self._pending_slice_indexes def slice_at(idx: int) -> SliceState: - assert self.slice_offsets is not None + assert self._slice_offsets is not None return SliceState( index=idx, - current=self.slice_offsets[idx], + current=self._slice_offsets[idx], ) # Weight the slices by their size to get a more even distribution of samples @@ -223,8 +221,8 @@ def slice_at(idx: int) -> SliceState: for idx, slice_state in enumerate(active_slices): if slice_state is not None: active_slice_probs[idx] = ( - self.slice_offsets[slice_state.index + 1] - - self.slice_offsets[slice_state.index] + self._slice_offsets[slice_state.index + 1] + - self._slice_offsets[slice_state.index] ) if self.worker_config.should_log(level=1): @@ -282,8 +280,8 @@ def slice_at(idx: int) -> SliceState: self._pending_slices_offset += 1 slice_state = slice_at(slice_index) active_slice_probs[len(active_slices)] = ( - self.slice_offsets[slice_state.index + 1] - - self.slice_offsets[slice_state.index] + self._slice_offsets[slice_state.index + 1] + - self._slice_offsets[slice_state.index] ) active_slices.append(slice_state) # Fill up the slice iterators with None @@ -317,7 +315,7 @@ def slice_at(idx: int) -> SliceState: slice_state.current += 1 self._sample_count += 1 self._epoch_sample_count += 1 - if slice_state.current >= 
self.slice_offsets[slice_state.index + 1]: + if slice_state.current >= self._slice_offsets[slice_state.index + 1]: # Iterator exhausted -> take next / remove from list if len(pending_slice_indexes) > 0 or self.shuffle_over_epochs == -1: if len(pending_slice_indexes) > 0: @@ -327,12 +325,12 @@ def slice_at(idx: int) -> SliceState: self._pending_slices_offset += 1 else: # Randomly select a new slice directly (with replacement) - num_slices = len(self.slice_offsets) - 1 + num_slices = len(self._slice_offsets) - 1 next_idx = self._worker_rng.randbelow(num_slices) next_slice_state = slice_at(next_idx) active_slice_probs[slice_idx] = ( - self.slice_offsets[next_slice_state.index + 1] - - self.slice_offsets[next_slice_state.index] + self._slice_offsets[next_slice_state.index + 1] + - self._slice_offsets[next_slice_state.index] ) active_slices[slice_idx] = next_slice_state # print( @@ -370,7 +368,7 @@ def slice_at(idx: int) -> SliceState: "t": "WebdatasetSampleLoaderDataset._slices_iter.yield", "r": self.worker_config.rank, "w": self.worker_config.rank_worker_id(), - "index": sample.__restore_key__[1], + "index": sample.__restore_key__.index, "key": sample.data[0]["__key__"], "shard": sample.data[0]["__shard__"], "count": self._sample_count, @@ -411,15 +409,13 @@ def len_worker(self, worker_idx: int | None = None) -> int: def worker_has_samples(self) -> bool: self.worker_config.assert_worker() - self.ensure_slice_offsets() - assert self.slice_offsets is not None - return len(self.slice_offsets) > 1 + assert self._slice_offsets is not None + return len(self._slice_offsets) > 1 def __iter__(self) -> Iterator[RawSampleData]: self.worker_config.assert_worker() - self.ensure_slice_offsets() - assert self.slice_offsets is not None + assert self._slice_offsets is not None if self.worker_config.should_log(level=1): self.worker_config.worker_log( @@ -427,13 +423,13 @@ def __iter__(self) -> Iterator[RawSampleData]: "t": "WebdatasetSampleLoaderDataset.__iter__", "r": 
    def restore_sample(self, restore_key: RestoreKey) -> RawSampleData:
        """Re-fetch a sample by its restore key, bypassing iteration order."""
        # The key is joined in the dataset's typed joining (i.e. load_sample of JoinedWebdatasetFactory).
        assert isinstance(restore_key, WebdatasetRestoreKey)
        assert isinstance(restore_key.index, int), (
            "WebdatasetRestoreKey.index must be an integer, cannot restore by sample key"
        )
        return self._get_sample(restore_key.index)

    def worker_close(self) -> None:
        """Release per-worker resources of all joined readers, then the base class."""
        for reader in self.join_readers:
            reader.worker_close()
        super().worker_close()

    def close(self) -> None:
        """Release shared resources of all joined readers, then the base class."""
        for reader in self.join_readers:
            reader.close()
        super().close()
# SPDX-License-Identifier: BSD-3-Clause -from dataclasses import field -from typing import Dict, List, Optional, Tuple, TypedDict +from dataclasses import dataclass, field +from typing import Dict, List, Optional, TypedDict from megatron.energon.edataclass import edataclass from megatron.energon.epathlib import EPath +from megatron.energon.flavors.base_dataset import RestoreKey from megatron.energon.source_info import SourceInfo @@ -42,6 +43,15 @@ class ShardInfo: count: int +@dataclass(kw_only=True, slots=True, frozen=True) +class WebdatasetRestoreKey(RestoreKey): + """Restore key for the webdataset sample loader.""" + + #: The global index of the sample. Either an integer representing the global sample index, + # or a string representing the sample key within the shard. + index: int | str + + class FilteredSample(TypedDict): """This is just a definition for the internal loaders. Not exposed to the user.""" @@ -52,8 +62,7 @@ class FilteredSample(TypedDict): #: The base name of the shard file e.g. "shard_000" __shard__: str #: Globally unique key to restore a sample from disk. - #: For example `("Webdataset", 123)` would restore the sample at index 123. - __restore_key__: Tuple[str, int] + __restore_key__: WebdatasetRestoreKey #: The source information for the sample. __sources__: tuple[SourceInfo, ...] diff --git a/src/megatron/energon/loader.py b/src/megatron/energon/loader.py index b973b672..48f5cfb0 100644 --- a/src/megatron/energon/loader.py +++ b/src/megatron/energon/loader.py @@ -1,12 +1,17 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Optional, TypeVar +from typing import Literal, Optional, TypeVar from megatron.energon.cache import CachePool +from megatron.energon.dataloader import ( + DataLoader, + DataLoaderWorker, + ForkDataLoaderWorker, + ThreadDataLoaderWorker, +) from megatron.energon.deprecation import warn_deprecated from megatron.energon.flavors import SavableDataset -from megatron.energon.savable_loader import BasicDataLoader, SavableDataLoader from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER @@ -17,16 +22,16 @@ def get_savable_loader( dataset: SavableDataset[T], *, worker_config: Optional[WorkerConfig] = None, - checkpoint_every_sec: float = 60, - checkpoint_every_min_n_samples: Optional[int] = None, - n_checkpoints: Optional[int] = None, + worker_type: Literal["main", "fork", "thread"] | type[DataLoaderWorker] = "fork", + gc_freeze_at_start: bool = True, gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER, prefetch_factor: int = 2, cache_pool: Optional[CachePool] = None, watchdog_timeout_seconds: Optional[float] = 60, watchdog_initial_timeout_seconds: Optional[float] = None, fail_on_timeout: bool = False, -) -> SavableDataLoader[T]: + pin_memory: bool = True, +) -> DataLoader[T]: """ Get a dataloader for the given dataset. @@ -34,21 +39,21 @@ def get_savable_loader( Args: dataset: The dataset to create a loader for. worker_config: Deprecated. Please pass this to the dataset instead. - checkpoint_every_sec: This is the time in seconds after which an internal checkpoint is - saved. It may take the same duration to restore a checkpoint, but introduces additional - overhead during reading data from the dataset, so this should be chosen accordingly. - Only applies if using workers. - checkpoint_every_min_n_samples: Overwrites the minimum number of samples between - checkpoints. Defaults to `number of workers * 2`. Only applies if using workers. 
- n_checkpoints: The number of internal checkpoints to keep. Only applies if using workers. - If None, computes a suitable value. + worker_type: The type of worker to use. Options: + "fork": forked workers (default), + "thread": threaded workers (should be used with free-threaded python), + "main": iterate data in the main process without parallelization. + gc_freeze_at_start: If True, the garbage collector is frozen at the start of the loader. + gc_collect_every_n_steps: The number of steps after which the garbage collector is called. + prefetch_factor: The factor by which to prefetch the dataset. cache_pool: If set, the cache pool to use for the dataset. watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled. watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds. fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace. + pin_memory: If True, the data iterated by the dataset is pinned to memory, such that it can be quickly used by CUDA. + Returns: - The instantiated :class:`megatron.energon.SavableDataLoader`, yielding batches from the dataset, - allowing to save the state of the dataset. + The instantiated :class:`megatron.energon.DataLoader`, yielding batches from the dataset. """ if worker_config is not None: if worker_config != dataset.worker_config: @@ -61,17 +66,32 @@ def get_savable_loader( "Passing a worker_config to get_savable_loader() is deprecated and will have no effect." 
) - return SavableDataLoader( + if worker_type == "fork": + worker_type = ForkDataLoaderWorker + elif worker_type == "thread": + worker_type = ThreadDataLoaderWorker + elif worker_type == "main": + worker_type = DataLoaderWorker + elif not issubclass(worker_type, DataLoaderWorker): + raise ValueError(f"Invalid worker type: {worker_type}") + if dataset.worker_config.num_workers == 0: + assert prefetch_factor == 2 + prefetch_factor = 1 + pin_memory_arg = None + else: + pin_memory_arg = "automatic" if pin_memory else None + + return DataLoader( dataset, - checkpoint_every_sec=checkpoint_every_sec, - checkpoint_every_min_n_samples=checkpoint_every_min_n_samples, - n_checkpoints=n_checkpoints, - gc_collect_every_n_steps=gc_collect_every_n_steps, prefetch_factor=prefetch_factor, + worker_type=worker_type, cache_pool=cache_pool, + gc_collect_every_n_steps=gc_collect_every_n_steps, + gc_freeze_at_start=gc_freeze_at_start, watchdog_timeout_seconds=watchdog_timeout_seconds, watchdog_initial_timeout_seconds=watchdog_initial_timeout_seconds, fail_on_timeout=fail_on_timeout, + pin_memory=pin_memory_arg, ) @@ -84,7 +104,7 @@ def get_loader( watchdog_timeout_seconds: Optional[float] = 60, watchdog_initial_timeout_seconds: Optional[float] = None, fail_on_timeout: bool = False, -) -> BasicDataLoader[T]: +) -> DataLoader[T]: """ Get a dataloader for the given dataset. @@ -95,8 +115,9 @@ def get_loader( watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled. watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds. fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace. + Returns: - The instantiated :class:`torch.data.DataLoader`, yielding batches from the dataset. + The instantiated :class:`DataLoader`, yielding batches from the dataset. 
""" if worker_config is not None: if worker_config != dataset.worker_config: @@ -109,11 +130,19 @@ def get_loader( "Passing a worker_config to get_loader() is deprecated and will have no effect." ) - return BasicDataLoader( + if dataset.worker_config.num_workers == 0: + assert prefetch_factor == 2 + prefetch_factor = 1 + pin_memory = None + else: + pin_memory = "automatic" + + return DataLoader( dataset, prefetch_factor=prefetch_factor, cache_pool=cache_pool, watchdog_timeout_seconds=watchdog_timeout_seconds, watchdog_initial_timeout_seconds=watchdog_initial_timeout_seconds, fail_on_timeout=fail_on_timeout, + pin_memory=pin_memory, ) diff --git a/src/megatron/energon/retry_stream.py b/src/megatron/energon/retry_stream.py deleted file mode 100644 index b678dddf..00000000 --- a/src/megatron/energon/retry_stream.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. -# SPDX-License-Identifier: BSD-3-Clause - -from types import TracebackType -from typing import BinaryIO, Optional, Type - -from megatron.energon.epathlib import EPath - - -class RetryReadStream(BinaryIO): - """A stream that retries reading from a file. Only supports reading bytes.""" - - _path: EPath - _file: Optional[BinaryIO] - _pos: int - _size: int - - def __init__(self, path: EPath): - """Construct a RetryReadStream. 
It reads only bytes from a file.""" - self._path = path - self._file = None - self._pos = 0 - self._size = path.size() - - def __enter__(self) -> "RetryReadStream": - return self - - def __exit__( - self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType - ) -> None: - self.close() - - def close(self) -> None: - if self._file is not None: - self._file.close() - - def read(self, n: int = -1) -> bytes: - buf = b"" - for retry in range(10): - try: - if self._file is None: - self._file = self._path.open("rb") - self._file.seek(self._pos) - res = self._file.read(n) - self._pos += len(res) - buf += res - if ( - (n == -1 and self._pos >= self._size) - or len(buf) == n - or self._pos >= self._size - ): - return res - except IOError: - try: - self._file.close() - except IOError: - pass - self._file = None - if retry == 9: - raise - continue - - def seek(self, offset: int, whence: int = 0) -> int: - if whence == 0: - pass - elif whence == 1: - offset += self._pos - elif whence == 2: - offset += self._size - else: - raise ValueError(f"Invalid whence value: {whence}") - offset = min(max(offset, 0), self._size) - self._pos = offset - try: - if self._file is not None: - self._file.seek(offset) - except IOError: - pass - return self._pos - - def tell(self) -> int: - return self._pos - - def isatty(self) -> bool: - return False - - def readable(self) -> bool: - return True - - def seekable(self) -> bool: - return True - - def writable(self) -> bool: - return False diff --git a/src/megatron/energon/rng.py b/src/megatron/energon/rng.py index b2991580..2773af08 100644 --- a/src/megatron/energon/rng.py +++ b/src/megatron/energon/rng.py @@ -17,6 +17,31 @@ T = TypeVar("T") +@edataclass +class WorkerRngState: + rng: Any + + def _hashable_value(self, value: Any) -> Any: + if isinstance(value, (int, float, bool, str)) or value is None: + return value + elif isinstance(value, torch.Tensor): + return self._hashable_value(value.tolist()) + elif isinstance(value, 
numpy.ndarray): + return self._hashable_value(value.tolist()) + elif isinstance(value, Mapping): + return tuple( + (self._hashable_value(k), self._hashable_value(v)) for k, v in value.items() + ) + elif isinstance(value, Sequence): + return tuple(self._hashable_value(v) for v in value) + else: + raise ValueError(f"Cannot hash value of type {type(value)}: {value!r}") + + def __repr__(self): + # If the hash is the same, the state is the same. Should suffice to identify the state. + return f"WorkerRngState(hash={hash(self._hashable_value((self.rng)))})" + + class WorkerRng(Savable): """Helper class for getting a worker random generator, which is still in itself deterministic. If not in a worker, uses the global random generator's seed to initialize a new rng.""" @@ -79,14 +104,57 @@ def shuffle(self, l: List[T]) -> List[T]: def rand_pop(self, l: List[T]) -> T: return l.pop(self.randbelow(len(l))) - def save_state(self) -> FlexState: - return FlexState(rng=None if self.rng is None else bytes(self.rng.get_state().tolist())) + def save_state(self) -> WorkerRngState: + return WorkerRngState( + rng=None if self.rng is None else bytes(self.rng.get_state().tolist()) + ) - def restore_state(self, state: FlexState): - if state["rng"] is None: + def restore_state(self, state: WorkerRngState): + if state.rng is None: self._restore_state = None else: - self._restore_state = state["rng"] + self._restore_state = state.rng + + +class UserRng: + """User random generators. 
To be used within the task encoder, providing local seeding.""" + + def __init__(self, seed: int): + self.torch = torch.Generator() + self.torch.manual_seed(seed) + if torch.cuda.is_available(): + self.torch_cuda = torch.Generator(device="cuda") + self.torch_cuda.manual_seed(seed) + + self.numpy = numpy.random.default_rng(seed=seed) + self.random = random.Random(seed) + + def seed(self, seed: int) -> None: + self.torch.manual_seed(seed) + if torch.cuda.is_available(): + self.torch_cuda.manual_seed(seed) + self.numpy.bit_generator.state = numpy.random.default_rng(seed).bit_generator.state + self.random.seed(seed) + + def seed_args(self, *args: Any) -> None: + self.seed(SystemRng.get_seed_from_args(*args)) + + def save_state(self) -> FlexState: + state = FlexState( + torch=self.torch.get_state().tolist(), + numpy=self.numpy.bit_generator.state, + random=self.random.getstate(), + ) + if torch.cuda.is_available(): + state["torch_cuda"] = self.torch_cuda.get_state().tolist() + return state + + def restore_state(self, state: FlexState): + self.torch.set_state(torch.as_tensor(state["torch"])) + if torch.cuda.is_available(): + self.torch_cuda.set_state(torch.as_tensor(state["torch_cuda"], device="cuda")) + self.numpy.bit_generator.state = state["numpy"] + self.random.setstate(state["random"]) @edataclass @@ -169,4 +237,5 @@ def get_seed_from_args(*args: Any) -> int: @staticmethod def seed_args(*args: Any) -> None: """Seeds the global random generators deterministically from the given arguments.""" + # print(f"Seeding with args: {args}") SystemRng.seed(SystemRng.get_seed_from_args(*args)) diff --git a/src/megatron/energon/savable_loader.py b/src/megatron/energon/savable_loader.py deleted file mode 100644 index 06ce51ce..00000000 --- a/src/megatron/energon/savable_loader.py +++ /dev/null @@ -1,1409 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. 
-# SPDX-License-Identifier: BSD-3-Clause - -import gc -import multiprocessing -import queue -import threading -import time -import traceback -import weakref -from functools import partial -from typing import ( - Any, - ClassVar, - Dict, - Generic, - Iterator, - List, - Optional, - Sequence, - Tuple, - TypeVar, - Union, -) - -import numpy as np -import torch -import torch.distributed -import torch.multiprocessing -from torch.utils.data import DataLoader, IterableDataset - -from megatron.energon.cache import CachePool -from megatron.energon.deprecation import deprecated -from megatron.energon.edataclass import edataclass -from megatron.energon.flavors.base_dataset import ( - FlexState, - SavableDataset, - State, - add_sample_restore_key, -) -from megatron.energon.rng import SystemRng, SystemRngState -from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset -from megatron.energon.wrappers.batch_dataset import BatchDataset -from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER, GcDataset, gc_init_worker -from megatron.energon.wrappers.log_sample_dataset import default_get_batch_keys -from megatron.energon.wrappers.watchdog_dataset import WatchdogDataset - -T = TypeVar("T") - - -def _init_worker(seed_per_worker: List[int], worker_id: int): - """Initializes the the worker process. - - Sets the random seeds and prepare EPath for the forked worker process. - """ - gc_init_worker(worker_id) - - worker_seed = seed_per_worker[worker_id] - - SystemRng.seed(worker_seed) - - -class SimpleSavableDatasetWrapper(BaseWrapperDataset[T, Tuple[int, int, T]], Generic[T]): - """Wrapper for non-multiprocessing savable datasets. Restarts the inner dataset. This class is - not intended to be used directly.""" - - #: The cache pool to use for the dataset. 
- cache_pool: CachePool - - _state_restored: bool - _sample_index: int - - _savable_fields = ("_sample_index",) - - def __init__( - self, dataset: SavableDataset[T], worker_config: WorkerConfig, cache_pool: CachePool - ): - """ - Args: - dataset: The dataset to wrap. - worker_config: The worker config to use for the dataset. - cache_pool: The cache pool to use for the dataset. - """ - super().__init__(dataset, worker_config=worker_config) - self.cache_pool = cache_pool - - self.reset_state_own() - - def reset_state_own(self) -> None: - self._sample_index = 0 - self._state_restored = False - - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - - @property - def __len__(self): - # Note: This disables hasattr(self, "__len__"), because that attr will - raise AttributeError("Disabled direct length access to avoid DataLoader warnings.") - - def __iter__(self): - self._state_restored = True - worker_id = self.worker_config.rank_worker_id() - global_worker_id = self.worker_config.global_worker_id() - while self._state_restored: - self._state_restored = False - self.worker_config.worker_activate(self._sample_index, cache_pool=self.cache_pool) - worker_active = True - try: - for src_data in self.dataset: - self.worker_config.worker_deactivate() - worker_active = False - sample_index = self._sample_index - src_data = add_sample_restore_key( - src_data, global_worker_id, sample_index, src=self - ) - self._sample_index += 1 - yield worker_id, sample_index, src_data - if self._state_restored: - # Restart iterator after restore - break - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - finally: - if worker_active: - self.worker_config.worker_deactivate() - - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T: - id, global_worker_id, sample_idx = restore_key[:3] - assert id == type(self).__name__ - restore_key = restore_key[3:] - 
self.worker_config.worker_activate( - sample_idx, override_global_rank=global_worker_id, cache_pool=self.cache_pool - ) - try: - return add_sample_restore_key( - self.dataset.restore_sample(restore_key), - global_worker_id, - sample_idx, - src=self, - ) - finally: - self.worker_config.worker_deactivate() - - def config(self) -> Dict[str, Any]: - return self.dataset.config() - - def __str__(self): - return f"SimpleSavableDatasetWrapper(dataset={self.dataset})" - - -@edataclass -class SavableDatasetState(State): - """State of the dataset wrapper. It stores the global random states and the index of the next - sample to be returned from the dataset. This class is not intended to be used directly, but by - :class:`megatron.energon.SavableDatasetWrapper`.""" - - #: The state of all the system random number generators - rng: SystemRngState - #: The state of the savable dataset - dataset_state: FlexState - #: Index of the next sample to be returned from the dataset - sample_index: int - - def __repr__(self): - return f"SavableDatasetState(rng={self.rng!r}, sample_index={self.sample_index})" - - -@edataclass -class SavableCheckpoint: - """Checkpoint data for :class:`megatron.energon.SavableDatasetWrapper`. An instance is created - regularly to be able to save the state of the dataset wrapper before the currently emitted - sample. - """ - - #: The state of the wrapper - state: Optional[SavableDatasetState] - #: The time at which the checkpoint was created - checkpoint_time: float - #: Index of the next sample to be returned from the dataset after restoring the checkpoint - sample_index: int - - -@edataclass -class SavableDatasetCheckpoint(State): - """Checkpoint data for :class:`megatron.energon.SavableDatasetWrapper`. The checkpoint state - represents a state before that checkpoint, with an offset (i.e. samples to be skipped).""" - - #: The state of the wrapper at the sample index when the checkpoint was created. 
- state: Optional[SavableDatasetState] - #: Offset of the checkpoint to the actual sample index to be restored. - offset: int - - -class SavableDatasetWrapper(IterableDataset[Tuple[int, int, T]], Generic[T]): - """Internal class for wrapping a savable dataset for a worker process. Provides communication - with the :class:`megatron.energon.SavableDataLoader`. This class is not intended to be used directly. - See :class:`megatron.energon.SavableDataLoader` for more information.""" - - #: The wrapped dataset - dataset: SavableDataset[T] - #: The configuration of the worker process - worker_config: WorkerConfig - #: The time interval in seconds to wait at minimum between two checkpoints - checkpoint_every_sec: float - #: The minimum number of samples to be emitted between two checkpoints. Should be `number of - # workers * 2`. - checkpoint_every_min_n_samples: int - #: The number of checkpoints to keep in memory, before discarding. Should be 2. - n_checkpoints: int - #: The cache pool to use for the dataset. - cache_pool: CachePool - #: The queue of the worker process to receive commands from the `SavableDataLoader`. - _cmd_queues: List[torch.multiprocessing.Queue] - #: The queue of the worker process to send results to the `SavableDataLoader`. 
- _result_queues: List[torch.multiprocessing.Queue] - - _sample_index: int = 0 - _worker_offset: int = 0 - _last_checkpoints: List[SavableCheckpoint] - - _workers_restore_from: List[Optional[SavableDatasetState]] = list() - _workers_skip_samples: List[int] - - _running: bool = False - _command_lock: Optional[threading.RLock] = None - _cmd_thread: Optional[threading.Thread] = None - - def __init__( - self, - dataset: SavableDataset[T], - worker_config: WorkerConfig, - checkpoint_every_sec: float, - checkpoint_every_min_n_samples: int, - n_checkpoints: int = 2, - *, - cmd_queues: List[torch.multiprocessing.Queue], - result_queues: List[torch.multiprocessing.Queue], - cache_pool: CachePool, - ): - """ - Create the savable dataset wrapper for multiprocessing data loading. - - Args: - dataset: The dataset to wrap - worker_config: The worker config as used by all datasets - checkpoint_every_sec: The time interval in seconds to wait at minimum between two - checkpoints. - checkpoint_every_min_n_samples: The minimum number of samples to be emitted between - two checkpoints. Should be `number of workers * 2`. - n_checkpoints: Number of checkpoints to keep. - cmd_queues: The command queues for communicating with the worker processes. - result_queues: The result queues for communicating with the worker processes. - cache_pool: The cache pool to use for the dataset. 
- """ - num_workers = max(worker_config.num_workers, 1) - - self.dataset = dataset - self.worker_config = worker_config - self.checkpoint_every_sec = checkpoint_every_sec - self.checkpoint_every_min_n_samples = checkpoint_every_min_n_samples - self.n_checkpoints = n_checkpoints - self._last_checkpoints = [ - SavableCheckpoint(state=None, checkpoint_time=time.perf_counter(), sample_index=-1) - ] - self._workers_restore_from = [None] * num_workers - self._workers_skip_samples = [0] * num_workers - self._cmd_queues = cmd_queues - self._result_queues = result_queues - self.cache_pool = cache_pool - - @staticmethod - def _command_thread(self: "SavableDatasetWrapper"): - """The internal thread, which processes the command and result queues. This thread is - static, because `self` is actually passed as weakref proxy, to avoid keeping the dataset - alive via the thread. - """ - # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread starting") - assert self._command_lock is not None - - try: - while self._running: - try: - cmd_args = self._cmd_queues[self._worker_id].get(timeout=1) - except queue.Empty: - continue - # print(f"recv cmd {cmd_args}") - with self._command_lock: - cmd = cmd_args[0] - if cmd is None: - break - try: - fn = getattr(self, cmd) - self._result_queues[self._worker_id].put( - {self._worker_id: fn(*cmd_args[1:])} - ) - # print(f"result sent") - except Exception as e: - traceback.print_exc() - self._result_queues[self._worker_id].put({self._worker_id: e}) - # print(f"exc sent") - except BaseException: - traceback.print_exc() - raise - finally: - pass - # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread closing") - - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - - def len_rank(self): - return self.dataset.len_rank() - - @property - def __len__(self): - # Note: This disables hasattr(self, "__len__"), because that attr will - raise 
AttributeError("Disabled direct length access to avoid DataLoader warnings.") - - def __del__(self): - if self._cmd_thread is not None: - # print(f"{id(self)}:{multiprocessing.current_process().ident} Closing cmd thread") - self._running = False - self._cmd_thread.join() - self._command_lock = None - self._cmd_thread = None - # print(f"{id(self)}:{multiprocessing.current_process().ident} Cmd thread closed") - - def __iter__(self): - # First: Set the worker offset globally for the current worker - WorkerConfig.worker_id_offset = self._worker_offset - self._worker_id = self.worker_config.rank_worker_id() - global_worker_id = self.worker_config.global_worker_id() - if self._cmd_thread is None: - self._running = True - self._command_lock = threading.RLock() - weakref_self = weakref.proxy(self) - self._cmd_thread = threading.Thread( - target=SavableDatasetWrapper._command_thread, - name="command_thread", - args=(weakref_self,), - daemon=True, - ) - self._cmd_thread.start() - # atexit.register(lambda: weakref_self.__del__()) - try: - assert self._command_lock is not None - with self._command_lock: - if self._workers_restore_from[self._worker_id] is not None: - my_state = self._workers_restore_from[self._worker_id] - my_ds_state = my_state.dataset_state - assert my_state is not None - if my_ds_state is None: - self.dataset.reset_state_deep() - else: - self.dataset.restore_state(my_ds_state) - self._restore_state(my_state) - self._workers_restore_from[self._worker_id] = None - else: - # Store the initial state of the worker if we stop before the first sample - self._store_checkpoint() - # If skipping, also restart the iterator to reach the start of the restored - # checkpoint - last_was_skip = True - while last_was_skip: - dataset_has_samples = False - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - try: - for src_data in self.dataset: - self.worker_config.worker_deactivate() - worker_active = False - 
dataset_has_samples = True - if self._workers_skip_samples[self._worker_id] > 0: - # Skip ahead to reach the start of the restored checkpoint - # print(f"Skip [{self._sample_index}:{self._worker_id}] {src_data}") - self._workers_skip_samples[self._worker_id] -= 1 - self._sample_index += 1 - last_was_skip = True - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - continue - last_was_skip = False - sample_index = self._sample_index - add_sample_restore_key( - src_data, global_worker_id, sample_index, src=self - ) - self._sample_index += 1 - self._store_checkpoint() - try: - self._command_lock.release() - # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock released") - # Commands may be executed only when data was yielded, not during - # iteration fetching. - # print(f"Yield next data [{sample_index}:{self._worker_id}] {src_data}") - yield self._worker_id, sample_index, src_data - finally: - # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquiring") - self._command_lock.acquire() - # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquired") - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - finally: - if worker_active: - self.worker_config.worker_deactivate() - - # If the dataset is empty, don't try again and again - if not dataset_has_samples: - break - finally: - # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker iter closing") - # Always store a final checkpoint (it's likely to be saved) - self._store_checkpoint(force=True) - - def _store_checkpoint(self, force: bool = False) -> None: - """ - Internally create a checkpoint for the current state. This is required to store states - from the past, which have already been yielded here, but not yet been retrieved from the - intermediate queues. - - Args: - force: If true, ignore time or frequency condition. 
- """ - if ( - force - or ( - self._last_checkpoints[-1].checkpoint_time + self.checkpoint_every_sec - < time.perf_counter() - and self._last_checkpoints[-1].sample_index + self.checkpoint_every_min_n_samples - <= self._sample_index - ) - or self._sample_index <= 1 - ): - # print(f"Storing checkpoint at {self._worker_id}:{self._sample_index}") - self._last_checkpoints.append( - SavableCheckpoint( - state=self._save_state(), - checkpoint_time=time.perf_counter(), - sample_index=self._sample_index, - ) - ) - if len(self._last_checkpoints) > self.n_checkpoints: - self._last_checkpoints.pop(0) - - def _save_state(self) -> SavableDatasetState: - """Saves the internal state""" - return SavableDatasetState( - rng=SystemRng.save_state(), - dataset_state=self.dataset.save_state(), - sample_index=self._sample_index, - ) - - def _restore_state(self, state: SavableDatasetState) -> None: - """Restores the internal worker state""" - assert torch.utils.data.get_worker_info() is not None, "Can only restore in worker process" - if state.rng is None: - SystemRng.seed(torch.initial_seed() & 0xFFFFFFFF) - else: - SystemRng.restore_state(state.rng) - - self._sample_index = state.sample_index - self._last_checkpoints = [ - SavableCheckpoint( - state=self._save_state(), - checkpoint_time=time.perf_counter(), - sample_index=self._sample_index, - ) - ] - - def get_checkpoint(self, last_sample_indexes: List[int]) -> SavableDatasetCheckpoint: - """ - Get a checkpoint given the last emitted sample indexes for all workers. - - Args: - last_sample_indexes: The last emitted sample indexes for all workers. 
- - Returns: - The found checkpoint including the offset to the next sample index - """ - sample_index = last_sample_indexes[self._worker_id] + 1 - for checkpoint in reversed(self._last_checkpoints): - if checkpoint.sample_index <= sample_index: - # print(f"Found cp for {sample_index} at {checkpoint.sample_index}") - return SavableDatasetCheckpoint( - state=checkpoint.state, - offset=sample_index - checkpoint.sample_index, - ) - - # Immediate save after restore - if len(self._last_checkpoints) == 0: - return SavableDatasetCheckpoint( - state=self._workers_restore_from[self._worker_id], - offset=self._workers_skip_samples[self._worker_id], - ) - raise ValueError("No checkpoint found") - - def restore_checkpoint( - self, - worker_states: Optional[List[SavableDatasetCheckpoint]], - worker_offset: int, - ) -> None: - """ - Restores the merged checkpoint from all worker processes. - - Args: - worker_states: The state to restore for each worker - worker_offset: The offset of the last worker which has emitted a sample. This will be - set in all worker processes to ensure the right worker starts as first. 
- """ - assert torch.utils.data.get_worker_info() is None, "Cannot restore in worker process" - num_workers = max(self.worker_config.num_workers, 1) - - if worker_states is None: - self._workers_restore_from = [None] * num_workers - assert worker_offset == 0 - self._worker_offset = 0 - self._workers_skip_samples = [0] * num_workers - else: - assert isinstance(worker_states, list) - assert len(worker_states) == num_workers - assert isinstance(worker_states[0], SavableDatasetCheckpoint) - - self._worker_offset = worker_offset - - # Tear the state_list apart (which has len=num_workers) - # and store the states in the internal arrays - self._workers_restore_from = [state.state for state in worker_states] - self._workers_skip_samples = [state.offset for state in worker_states] - - def get_initial_checkpoint(self) -> Optional[List[SavableDatasetCheckpoint]]: - """ - Get the initial checkpoint for all worker processes if they have not started yet. - - Returns: - The initial checkpoint for all worker processes and the worker offset. 
- """ - assert torch.utils.data.get_worker_info() is None, ( - "Cannot get initial checkpoint in worker process" - ) - if all(s is None for s in self._workers_restore_from): - assert all(s == 0 for s in self._workers_skip_samples) - # Initial state, no checkpoint - return None - - return [ - SavableDatasetCheckpoint( - state=state, - offset=offset, - ) - for state, offset in zip(self._workers_restore_from, self._workers_skip_samples) - ] - - def can_restore_sample(self) -> bool: - return self.dataset.can_restore_sample() - - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T: - id, global_worker_id, sample_idx = restore_key[:3] - assert id == type(self).__name__ - restore_key = restore_key[3:] - self.worker_config.worker_activate(sample_idx, override_global_rank=global_worker_id) - try: - return add_sample_restore_key( - self.dataset.restore_sample(restore_key), - global_worker_id, - sample_idx, - src=self, - ) - finally: - self.worker_config.worker_deactivate() - - def config(self) -> Dict[str, Any]: - return self.dataset.config() - - def __str__(self): - return f"SavableDatasetWrapper(dataset={self.dataset})" - - -@edataclass -class SavableDataLoaderState(State): - """Saved state of the :class:`megatron.energon.SavableDataLoader`. Contains the state for all worker - processed of a single rank.""" - - #: The internal state of the dataset (for each worker process) - worker_states: List[Union[SavableDatasetCheckpoint, FlexState]] - #: Which worker will be the next to emit a sample. Used to restore the proper order - next_worker_id: int - - #: The micro batch size that was used, if available. - #: On restore, this is used to compare the new and old micro batch size. - micro_batch_size: Optional[int] - - -class SavableDataLoader(DataLoader[T], Generic[T]): - """DataLoader that supports saving and restoring the state of the dataset. - When restoring, the dataloader and dataset must be instantiated with the exactly same - parameters. 
- - How this works (for no worker processes) - ---------------------------------------- - - 1. The state of the dataset is saved using :meth:`megatron.energon.SavableDataset.save_state` - 2. (for compatibility) The state of the dataset is converted to using inner arrays using - :meth:`megatron.energon.SavableDataset.merge_states`. - 3. The state can be restored using :meth:`megatron.energon.SavableDataset.restore_state` given the - previously saved (and merged) state. - - How this works (for worker processes) - ------------------------------------- - - - First issue is, that worker processes work with internal queues between processes to pass - loaded samples to the main process (also to perform collating). This means that the whole - state of the dataset is not directly accessible from the main process. - - To solve this issue, the dataset regularly saves a checkpoint of its state to be able to - resume from that state (and skip the samples that have already been yielded). - - To have a consistent state, the sample index from the latest yielded samples is saved for all - worker instances. Thus, the main process knows exactly which sample indexes should come next - from which worker. - - Internally, pytorch iterates through the workers in order to retrieve the next worker's - samples. Unfortunately, that next worker index cannot be restored in pytorch's dataloader, - thus the workers are shifted internally by that offset - (see :attr:`megatron.energon.WorkerConfig.worker_id_offset`). - - 1. The dataset is wrapped in a :class:`megatron.energon.SavableDatasetWrapper`. This allows the main - process to communicate with the worker and send commands to the workers and retrieve the - results. - 2. The state of the dataset is saved using - :meth:`megatron.energon.SavableDatasetWrapper.get_checkpoint`. This gives the last checkpoint - from the requested sample index and stores the offset (i.e. number of samples to skip) from - that checkpoint. - 3. 
The state is merged using :meth:`megatron.energon.SavableDatasetWrapper.merge_checkpoints`. This - merges the states of all workers and returns a single state that can be used to restore the - state of the dataset. - 4. The state can be restored using :meth:`megatron.energon.SavableDatasetWrapper.restore_state` - before a worker is started, such that all workers initially receive the same state array. - The worker firstly sets the worker index offset, then uses its (shifted) own index to get its - required state from the merged state array. - - """ - - #: The worker config - worker_config: WorkerConfig - #: The wrapped dataset. For multiprocessing, this is a :class:`megatron.energon.SavableDatasetWrapper` - dataset: Union[SavableDatasetWrapper[T], SimpleSavableDatasetWrapper[T]] - - #: The global ID counter - _next_id: ClassVar[int] = 0 - #: Class instance id - id: int = 0 - - #: The queues used to send commands to the workers - cmd_queues: List[torch.multiprocessing.Queue] - #: The queues used to receive results from the workers - result_queues: List[torch.multiprocessing.Queue] - - #: Instance of the current data iterator. There shall be only one active iterator, such that the - # dataset is not iterated multiple times in parallel. The state will continue between epochs. - _epoch_iterator: Optional[Iterator[T]] = None - #: Whether the dataloader has running workers. - _has_workers: bool = False - #: The index of the current worker. -1 if not started yet. 
- _worker_sample_counters: List[int] - #: Id of the next worker to retrieve data from - _next_worker_id: int = 0 - #: Global index of the last yielded sample - _global_sample_idx: int = 0 - #: Current iterator index of the last yielded sample - _sample_idx: int = 0 - - def __init__( - self, - dataset: SavableDataset[T], - *, - checkpoint_every_sec: float = 60, - checkpoint_every_min_n_samples: Optional[int] = None, - n_checkpoints: Optional[int] = None, - gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER, - gc_freeze_at_start: bool = True, - prefetch_factor: int = 2, - cache_pool: Optional[CachePool] = None, - watchdog_timeout_seconds: Optional[float] = 60, - watchdog_initial_timeout_seconds: Optional[float] = None, - fail_on_timeout: bool = False, - ): - """ - Create the dataloader supporting saving and restoring the state. - - Args: - dataset: The dataset to load. - worker_config: The worker config to use - checkpoint_every_sec: This is the time in seconds after which a checkpoint is saved. - It may take the same duration to restore a checkpoint, but introduces additional - overhead during reading data from the dataset, so this should be chosen accordingly. - Only applies if using workers. - checkpoint_every_min_n_samples: Overwrites the minimum number of samples between - checkpoints. Defaults to `number of workers * 2`. Only applies if using workers. - n_checkpoints: The number of checkpoints to keep in memory. Only applies if using - workers. If None, computes a suitable value. - gc_collect_every_n_steps: The number of steps after which the garbage collector is - called. As we're usually handling large (but few) tensors here, and the python - garbage collection is already full of objects just by importing, this can improve - the memory footprint quite a lot, and may even be necessary to avoid memory - overflow. - gc_freeze_at_start: If true, the garbage collector is frozen at the start of the worker - processes. 
This improves the garbage collection performance by a lot. - In rare cases, this may cause issues and can be disabled. Keep enabled if you - experience no issues. - cache_pool: If set, the cache pool to use for the dataset. - watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled. - watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds. - fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace. - """ - self.worker_config = dataset.worker_config - self.id = self.next_id() - - dataset = WatchdogDataset( - dataset, - worker_config=self.worker_config, - timeout_seconds=watchdog_timeout_seconds, - initial_timeout_seconds=watchdog_initial_timeout_seconds, - fail_on_timeout=fail_on_timeout, - ) - - if gc_collect_every_n_steps > 0: - dataset = GcDataset( - dataset, - worker_config=self.worker_config, - every_n_iter=gc_collect_every_n_steps, - freeze=gc_freeze_at_start, - ) - - self.cmd_queues = [multiprocessing.Queue() for _ in range(self.worker_config.num_workers)] - self.result_queues = [ - multiprocessing.Queue() for _ in range(self.worker_config.num_workers) - ] - - num_procs = max(self.worker_config.num_workers, 1) - - if n_checkpoints is None: - n_checkpoints = prefetch_factor * num_procs + 1 - - if self.worker_config.num_workers > 0: - if checkpoint_every_min_n_samples is None: - checkpoint_every_min_n_samples = self.worker_config.num_workers * 2 - - dataset = SavableDatasetWrapper( - dataset, - self.worker_config, - checkpoint_every_sec=checkpoint_every_sec, - checkpoint_every_min_n_samples=checkpoint_every_min_n_samples, - n_checkpoints=n_checkpoints, - cmd_queues=self.cmd_queues, - result_queues=self.result_queues, - cache_pool=cache_pool, - ) - else: - dataset = SimpleSavableDatasetWrapper( - dataset, self.worker_config, cache_pool=cache_pool - ) - - self._worker_sample_counters = [-1] * num_procs - - kwargs = {} - if 
self.worker_config.num_workers > 0: - kwargs["persistent_workers"] = True - kwargs["prefetch_factor"] = prefetch_factor - - # Assert that prefetch_factor works well with num_checkpoints. - # This ensures that the oldest checkpoint is old enough to cover - # all the buffered samples in the torch dataloader. - assert prefetch_factor * num_procs + 1 <= n_checkpoints, ( - "When increasing prefetch_factor, also increase n_checkpoints, so that " - "the number of checkpoints is at least as large as num_workers * prefetch_factor + 1" - ) - - # Compute seeds for each worker, based on current rank - seed_per_worker = [ - self.worker_config.worker_seed(i) for i in range(self.worker_config.num_workers) - ] - - super().__init__( - dataset, - batch_size=None, - shuffle=False, - num_workers=self.worker_config.num_workers, - pin_memory=True, - worker_init_fn=partial(_init_worker, seed_per_worker), - **kwargs, - ) - - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "SavableLoader.__init__", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "config": dataset.config(), - } - ) - - @staticmethod - def next_id() -> int: - next_id = SavableDataLoader._next_id - SavableDataLoader._next_id += 1 - return next_id - - def __len__(self): - # We override this, because otherwise we'll see warnings - return self.dataset.len_rank() - - def _epoch_iter(self): - """Iterator for one epoch, i.e. 
until the inner dataset raises StopIteration.""" - iter_idx = 0 - id = self.next_id() - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "SavableDataLoader.iter", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "iter_id": id, - } - ) - try: - for worker_id, sample_idx, sample in super().__iter__(): - self._worker_sample_counters[worker_id] = sample_idx - # If the next sample will be from the first worker, we can safely resume - self._next_worker_id = (worker_id + 1) % max(self.num_workers, 1) - # self._debugf.write( - # f"[w={worker_id}, s={sample_idx}] {self._sample_str(sample)}\n" - # ) - # self._debugf.flush() - if self.worker_config.should_log(level=1): - keys = default_get_batch_keys(sample) - self.worker_config.worker_log( - { - **{ - "t": "SavableDataLoader.yield", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "iter_id": id, - "worker_id": worker_id, - "worker_idx": sample_idx, - "idx": self._sample_idx, - "iter_idx": iter_idx, - "global_idx": self._global_sample_idx, - }, - **({} if keys is None else {"keys": keys}), - } - ) - self._sample_idx += 1 - self._global_sample_idx += 1 - iter_idx += 1 - yield sample - self._epoch_iterator = None - self._next_worker_id = 0 - finally: - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "SavableDataLoader.StopIteration", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "iter_id": self.id, - } - ) - - def __iter__(self): - if self.num_workers > 0: - # Always keep same iterator alive, as long as it yields data - if self._epoch_iterator is None: - self._epoch_iterator = self._epoch_iter() - self._sample_idx = 0 - self._has_workers = True - # print("New Iterator", self._persistent_iterator) - return self._epoch_iterator - else: - return self._epoch_iter() - - def _worker_command(self, *cmd_args) -> List[Any]: - """Executes a command in all workers and returns the results.""" - # print(f"cmd: 
{cmd_args}") - for cmd_queue in self.cmd_queues: - cmd_queue.put(cmd_args) - # print(f"waiting for res") - assert len(self.result_queues) == self.worker_config.num_workers - res = {k: v for results_queue in self.result_queues for k, v in results_queue.get().items()} - res = [res[i] for i in range(len(res))] - # print(f"res: {res}") - for r in res: - if isinstance(r, Exception): - raise r - return res - - def _get_batch_size(self) -> Optional[int]: - """Try to infer micro batch size from the dataset""" - if isinstance(self.dataset, (SavableDatasetWrapper, SimpleSavableDatasetWrapper)): - dataset = self.dataset.dataset - else: - dataset = self.dataset - - if ( - isinstance(dataset, BaseWrapperDataset) - and (bds := dataset._find_wrapped_dataset(BatchDataset)) is not None - ): - assert isinstance(bds, BatchDataset) - return bds.batch_size - else: - return None - - def save_state_rank(self) -> Optional[SavableDataLoaderState]: - """ - Saves the state of the dataset for the current rank. Allows for restoring the state later - using `restore_state_rank`, given the result of this method. - - Returns: - The state of the dataset. - """ - # Fetch current rank's worker's state - if self.num_workers == 0: - # No workers configured - assert isinstance(self.dataset, SimpleSavableDatasetWrapper) - worker_states = [self.dataset.save_state()] - assert self._next_worker_id == 0 - elif self._has_workers: - # Fetch from worker processes - worker_states = self._worker_command("get_checkpoint", self._worker_sample_counters) - else: - # Workers configured, but not started yet. - # If a state has already been restored, it will be returned. 
- assert isinstance(self.dataset, SavableDatasetWrapper) - worker_states = self.dataset.get_initial_checkpoint() - - if worker_states is None: - return None - - # Merge the states - merged_state = SavableDataLoaderState( - worker_states=worker_states, - next_worker_id=self._next_worker_id, - micro_batch_size=self._get_batch_size(), - ) - - # Not distributed -> return the merged state - return merged_state - - def restore_state_rank(self, state: Optional[SavableDataLoaderState]) -> None: - """ - Restores the saved state for the current rank. - - Args: - state: The state to restore, as saved by `save_state_rank`. - """ - assert not self._has_workers, "Cannot restore state while workers are running" - if state is None: - # Assume initial state - return - assert isinstance(state, SavableDataLoaderState) - - old_micro_batch_size = state.micro_batch_size - micro_batch_size = self._get_batch_size() - - if self.num_workers == 0: - # No workers configured - assert isinstance(self.dataset, SimpleSavableDatasetWrapper) - assert micro_batch_size == old_micro_batch_size, ( - "Changing micro batch size is not allowed without workers" - ) - - assert len(state.worker_states) == 1 - assert isinstance(state.worker_states[0], FlexState) - self.dataset.restore_state(state.worker_states[0]) - else: - # Workers configured - assert isinstance(self.dataset, SavableDatasetWrapper) - assert all(isinstance(s, SavableDatasetCheckpoint) for s in state.worker_states) - - # Check batch sizes (before and after) - if micro_batch_size != old_micro_batch_size: - assert micro_batch_size is not None and old_micro_batch_size is not None, ( - "Cannot resume with different batching mode " - "(batching to non-batching or vice versa)" - ) - - if micro_batch_size > old_micro_batch_size: - raise ValueError( - "Resuming with larger micro batch size is not allowed: " - f"{micro_batch_size} > {state.micro_batch_size}" - ) - elif ( - micro_batch_size < old_micro_batch_size - and old_micro_batch_size % 
micro_batch_size != 0 - ): - raise ValueError( - "Resuming with smaller micro batch size only allowed if the old " - f"micro batch size is a multiple of the new one: {micro_batch_size} < {state.micro_batch_size}" - ) - batch_size_ratio = old_micro_batch_size // micro_batch_size - for worker_state in state.worker_states: - assert isinstance(worker_state, SavableDatasetCheckpoint) - # When resuming with a smaller micro batch size, the offset must be scaled - # up to the new micro batch size to skip the same number of samples as before. - worker_state.offset *= batch_size_ratio - - self.dataset.restore_checkpoint(state.worker_states, worker_offset=state.next_worker_id) - - # Initialize the worker-sample counters so that every worker owns a valid - # "last emitted sample" index. Workers that have not emitted anything yet keep - # the default value ``-1``. - - assert isinstance(state.worker_states, list) - - self._worker_sample_counters = [ - ( - ws.state.sample_index - 1 - if (isinstance(ws, SavableDatasetCheckpoint) and ws.state is not None) - else -1 - ) - for ws in state.worker_states - ] - - self._next_worker_id = state.next_worker_id - - @deprecated( - "`save_state` is deprecated and was renamed to `save_state_global` and will be removed " - "in a future update. If you actually do not want to gather the states to a rank, use " - "`save_state_rank` instead." - ) - def save_state(self, dst_rank: int) -> Optional[Sequence[Optional[SavableDataLoaderState]]]: - """Deprecated. Use `save_state_global` (or `save_state_rank`) instead.""" - - return self.save_state_global(dst_rank) - - def save_state_global( - self, global_dst_rank: int - ) -> Optional[Sequence[Optional[SavableDataLoaderState]]]: - """ - Saves the state of the dataset globally, collecting the state from all ranks using torch - distributed. Allows for restoring the state later using `restore_state_global`, given the - result of this method. 
- Typical scenario: Save the state to disk only on the `dst_rank`, the other ranks do not - save the state. Later, restore the state either only loaded on the `dst_rank` or - loading on all ranks separately using `restore_state_global`. - - Note: If you want to save/restore the state per rank separately, use `save_state_rank` and - the corresponding `restore_state_rank`. Also, these do not rely on torch distributed. - - Args: - global_dst_rank: The state will be gathered to this rank. The rank refers to the - global rank, not the rank within the data parallel group. - - Returns: - The state of the dataset (or `None`, if not on `dst_rank`). - """ - # Fetch current rank's worker's state - merged_state = self.save_state_rank() - - # Gather the merged states - if self.worker_config.world_size > 1: - output: Optional[Sequence[Optional[SavableDataLoaderState]]] - if self.worker_config.global_rank() == global_dst_rank: - output = [None] * self.worker_config.world_size - else: - # Check if the global_dst_rank is in the same group at all - if self.worker_config.data_parallel_group is not None: - try: - _ = torch.distributed.get_group_rank( - self.worker_config.data_parallel_group, global_dst_rank - ) - except RuntimeError: - raise ValueError( - f"global_dst_rank {global_dst_rank} is not in the group of the current rank's worker config" - ) - - output = None - - torch.distributed.gather_object( - merged_state, - output, - global_dst_rank, - group=self.worker_config.data_parallel_group, - ) - - return output - else: - # Not distributed -> return the merged state - return [merged_state] - - @deprecated( - "`restore_state` was renamed to `restore_state_global` and will be removed in a future update." - ) - def restore_state( - self, - state: Optional[Sequence[Optional[SavableDataLoaderState]]], - ) -> None: - """Deprecated. 
Use `restore_state_global` (or `restore_state_rank`) instead.""" - - return self.restore_state_global(state) - - def restore_state_global( - self, - state: Optional[Sequence[Optional[SavableDataLoaderState]]], - *, - src_rank: Optional[int] = None, - ) -> None: - """ - Restores the saved state from `save_state_global` (in torch distributed setup). - The global state needs be loaded on every rank that has a data loader instance. - - Optionally, one can specify a src_rank and only provide the state once. - In case of multiple data parallel groups, you must provide the state once - in each data parallel group. In this case the `src_rank` is the rank within the - data parallel group. - - Args: - state: The state to restore, as saved by `save_state_global`. - src_rank: The rank from which the state is broadcasted (within the data parallel group, if using DP groups). - """ - - assert self._epoch_iterator is None, "Cannot restore state while workers are running" - - # Only restore multi-rank if state is actually a list and we are in a distributed setup. - # Otherwise treat as single rank state. 
- if src_rank is None or self.worker_config.world_size == 1: - assert isinstance(state, list), "State must be a list in distributed setup" - assert len(state) == self.worker_config.world_size, ( - "State must be a list of size world_size" - ) - - # All ranks have the state - # Select the state of the current rank - rank_state = state[self.worker_config.rank] - else: - if self.worker_config.data_parallel_group is not None: - # Only the src_rank has the state within this dp group - try: - global_src_rank = torch.distributed.get_global_rank( - self.worker_config.data_parallel_group, src_rank - ) - except RuntimeError: - raise ValueError( - f"src_rank {src_rank} is not in the group of the current rank's worker config" - ) - else: - # If no DP group is given, we assume the global rank is - # the same as the data parallel rank - global_src_rank = src_rank - - if self.worker_config.rank != src_rank: - # Send the state to all other ranks - assert state is None - # Must still be a list of Nones - state = [None] * self.worker_config.world_size - else: - assert isinstance(state, list), "State must be a list in distributed setup" - assert len(state) == self.worker_config.world_size, ( - "State must be a list of size world_size" - ) - - local_object = [None] - torch.distributed.scatter_object_list( - local_object, - state, - src=global_src_rank, - group=self.worker_config.data_parallel_group, - ) - rank_state = local_object[0] - - self.restore_state_rank(rank_state) - - def can_restore_sample(self) -> bool: - return self.dataset.can_restore_sample() - - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T: - """Restores a sample from a key. This is useful to debug the dataset.""" - return self.dataset.restore_sample(restore_key) - - def config(self): - """Get the configuration, which defines the dataset. 
Useful in conjunction with `save_state` - and `restore_state` to match the configuration as well.""" - return { - "type": type(self).__qualname__, - "num_workers": self.num_workers, - "persistent_workers": self.persistent_workers, - "pin_memory": self.pin_memory, - "prefetch_factor": None if self.num_workers == 0 else self.prefetch_factor, - "dataset": self.dataset.config(), - } - - -class BasicDataLoader(DataLoader[T], Generic[T]): - """DataLoader that supports debugging the dataset without saving capability (e.g. for val/eval).""" - - #: The worker config - worker_config: WorkerConfig - #: The wrapped dataset. For multiprocessing, this is a :class:`megatron.energon.SavableDatasetWrapper` - dataset: Union[SavableDatasetWrapper[T], SavableDataset[T]] - - id: int - _sample_idx: int = 0 - - def __init__( - self, - dataset: SavableDataset[T], - gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER, - gc_freeze_at_start: bool = True, - prefetch_factor: int = 2, - cache_pool: Optional[CachePool] = None, - watchdog_timeout_seconds: Optional[float] = 60, - watchdog_initial_timeout_seconds: Optional[float] = None, - fail_on_timeout: bool = False, - ): - """ - Create the dataloader supporting saving and restoring the state. - - Args: - dataset: The dataset to load. - gc_collect_every_n_steps: The number of steps after which the garbage collector is - called. As we're usually handling large (but few) tensors here, and the python - garbage collection is already full of objects just by importing, this can improve - the memory footprint quite a lot, and may even be necessary to avoid memory - overflow. - gc_freeze_at_start: If true, the garbage collector is frozen at the start of the worker - processes. This improves the garbage collection performance by a lot. - In rare cases, this may cause issues and can be disabled. Keep enabled if you - experience no issues. - cache_pool: If set, the cache pool to use for the dataset. - watchdog_timeout_seconds: The timeout in seconds. 
If None, the watchdog is disabled. - watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds. - fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace. - """ - self.worker_config = dataset.worker_config - - self.id = SavableDataLoader.next_id() - - dataset = WatchdogDataset( - dataset, - worker_config=self.worker_config, - timeout_seconds=watchdog_timeout_seconds, - initial_timeout_seconds=watchdog_initial_timeout_seconds, - fail_on_timeout=fail_on_timeout, - ) - - if gc_collect_every_n_steps > 0: - dataset = GcDataset( - dataset, - worker_config=self.worker_config, - every_n_iter=gc_collect_every_n_steps, - freeze=gc_freeze_at_start, - ) - - dataset = SimpleSavableDatasetWrapper( - dataset, worker_config=self.worker_config, cache_pool=cache_pool - ) - - self._worker_sample_counters = [0] * max(self.worker_config.num_workers, 1) - - kwargs = {} - if self.worker_config.num_workers > 0: - # These must not be specified for num_workers =0 - kwargs["persistent_workers"] = True - kwargs["prefetch_factor"] = prefetch_factor - - seed_per_worker = [ - self.worker_config.worker_seed(i) for i in range(self.worker_config.num_workers) - ] - - gc.collect() # This ensures that we don't include any old worker refs in the newly forked worker processes - - super().__init__( - dataset, - batch_size=None, - shuffle=False, - num_workers=self.worker_config.num_workers, - pin_memory=True, - worker_init_fn=partial(_init_worker, seed_per_worker), - **kwargs, - ) - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "BasicDataLoader.__init__", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "config": self.config(), - } - ) - - def __len__(self): - # We override this, because otherwise we'll see warnings - return self.dataset.len_rank() - - def __iter__(self): - def _inner_generator(iterator): - iter_idx = 0 - id = 
SavableDataLoader.next_id() - - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "BasicDataLoader.iter", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "iter_id": id, - } - ) - - try: - for worker_id, sample_idx, sample in iterator: - # If the next sample will be from the first worker, we can safely resume - if self.worker_config.should_log(level=1): - keys = default_get_batch_keys(sample) - self.worker_config.worker_log( - { - **{ - "t": "BasicDataLoader.yield", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "iter_id": self.id, - "worker_id": worker_id, - "worker_idx": sample_idx, - "idx": iter_idx, - "iter_idx": iter_idx, - "global_idx": self._sample_idx, - }, - **({} if keys is None else {"keys": keys}), - } - ) - self._sample_idx += 1 - iter_idx += 1 - yield sample - finally: - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "BasicDataLoader.StopIteration", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "iter_id": id, - } - ) - - return _inner_generator(super().__iter__()) - - def config(self): - """Get the configuration, which defines the dataset. Useful in conjunction with `save_state` - and `restore_state` to match the configuration as well.""" - return { - "type": type(self).__qualname__, - "num_workers": self.worker_config.num_workers, - "persistent_workers": self.persistent_workers, - "pin_memory": self.pin_memory, - "prefetch_factor": None if self.num_workers == 0 else self.prefetch_factor, - "dataset": self.dataset.config(), - } - - def can_restore_sample(self) -> bool: - return self.dataset.can_restore_sample() - - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T: - """Restores a sample from a key. 
This is useful to debug the dataset.""" - return self.dataset.restore_sample(restore_key) - - -def _sample_str(self, sample): - """Returns a human readable debug string for a single sample, also uniquely identifying it.""" - import dataclasses - import hashlib - - if isinstance(sample, torch.Tensor): - return f"Tensor(shape={sample.shape}, dtype={sample.dtype}, sha256={hashlib.sha256(sample.detach().cpu().numpy().tobytes()).hexdigest()!r})" - elif isinstance(sample, np.ndarray): - return f"ndarray(shape={sample.shape}, dtype={sample.dtype}, sha256={hashlib.sha256(sample.tobytes()).hexdigest()!r})" - elif isinstance(sample, (str, bytes)): - if len(sample) > 100: - return f"{sample[:100]!r}..." - return repr(sample) - elif isinstance(sample, (int, float)): - return repr(sample) - elif isinstance(sample, dict): - innerstr = ", ".join(f"{k}={self._sample_str(v)}, " for k, v in sample.items()) - return f"dict({innerstr})" - elif isinstance(sample, (list, tuple)): - innerstr = ", ".join(self._sample_str(v) for v in sample) - return f"{type(sample)}({innerstr})" - elif dataclasses.is_dataclass(sample): - innerstr = ", ".join( - f"{field.name}={self._sample_str(getattr(sample, field.name))}, " - for field in dataclasses.fields(sample) - ) - return f"{type(sample).__name__}({innerstr})" - else: - return repr(sample) diff --git a/src/megatron/energon/task_encoder/base.py b/src/megatron/energon/task_encoder/base.py index 5335d018..2b4d9db7 100644 --- a/src/megatron/energon/task_encoder/base.py +++ b/src/megatron/energon/task_encoder/base.py @@ -4,6 +4,7 @@ import dataclasses import functools import inspect +import threading from abc import ABC from dataclasses import is_dataclass from types import MethodType @@ -38,9 +39,9 @@ SampleDecoder, SavableDataset, ) -from megatron.energon.flavors.base_dataset import ExtendableDataclassMixin +from megatron.energon.flavors.base_dataset import ExtendableDataclassMixin, RestoreKey from megatron.energon.metadataset.loader_interface import 
DatasetBlendMode, LoadedDataset -from megatron.energon.rng import SystemRng +from megatron.energon.rng import SystemRng, UserRng from megatron.energon.source_info import SourceInfo from megatron.energon.task_encoder.cooking import Cooker from megatron.energon.worker import WorkerConfig @@ -56,6 +57,7 @@ PackingDataset, ShuffleBufferDataset, ) +from megatron.energon.wrappers.file_store_init_wrapper import FileStoreInitWrapper from megatron.energon.wrappers.repeat_dataset import RepeatDataset T = TypeVar("T") @@ -135,6 +137,7 @@ def stateless( fn: Optional[Callable[..., T]] = None, *, restore_seeds: bool = False, + restore_task_encoder_seeds: bool = False, failure_tolerance: Optional[int] = None, ) -> Union[Callable[[Callable[..., T]], Callable[..., T]], Callable[..., T]]: """Decorator to mark a function of the task encoder as restorable. @@ -144,6 +147,9 @@ def stateless( restore_seeds: Whether to restore the seeds for the function. I.e. the seeds are set from the sample index and the worker seed, such that they can be restored when a sample is restored from that function. + restore_task_encoder_seeds: Whether to restore the seeds for the task encoder. I.e. the seeds are set + from the sample index and the worker seed, such that they can be restored when a sample + is restored from that function. failure_tolerance: The number of consecutive exceptions that are handled, after which a `FatalSampleError` is raised for this function. Set to 0 to disable. 
@@ -168,77 +174,132 @@ def encode_sample(self, sample: T_sample) -> T_encoded_sample: ) if restore_seeds: worker_seed = None + orig_fn = fn - @functools.wraps(fn) - def seed_wrapper_generator(self, *args, **kwargs): - nonlocal worker_seed - if worker_seed is None: - worker_seed = WorkerConfig.active_worker_config.worker_seed() + if inspect.isgeneratorfunction(orig_fn): - # Save the RNG states and set the new seed - outer_rng_state = SystemRng.save_state() + @functools.wraps(orig_fn) + def seed_wrapper_generator(self, *args, **kwargs): + nonlocal worker_seed + if worker_seed is None: + worker_seed = WorkerConfig.active_worker_config.worker_seed() - # Before constructing the generator and before the first - # iteration, set inner RNG based on seed computed - # from worker_seed and current sample index - SystemRng.seed_args(worker_seed, self.current_sample_index) + # Save the RNG states and set the new seed + outer_rng_state = SystemRng.save_state() + + # Before constructing the generator and before the first + # iteration, set inner RNG based on seed computed + # from worker_seed and current sample index + SystemRng.seed_args(worker_seed, self.current_sample_index) + + it = iter(orig_fn(self, *args, **kwargs)) + + inner_rand_state = None + + while True: + if inner_rand_state is not None: + # Restore inner random state before calling the generator + # This will not be done on the first iteration + SystemRng.restore_state(inner_rand_state) + + try: + # Now call the generator. 
This will yield the sample + # But note it may also throw an exception or a StopIteration + sample = next(it) + + # Save inner random state after calling the generator + inner_rand_state = SystemRng.save_state() + except StopIteration: + # We're stopping here, but the outer random state + # will be restored before returning (in finally below) + break + finally: + # Restore outer rand state before yielding or when an exception was raised + SystemRng.restore_state(outer_rng_state) + + # Now yield the sample. + # This will give control back to the caller who may + # change the random state. + yield sample + + # Save outer random state after yielding + outer_rng_state = SystemRng.save_state() + + fn = seed_wrapper_generator + else: - it = iter(fn(self, *args, **kwargs)) + @functools.wraps(orig_fn) + def seed_wrapper(self, *args, **kwargs): + nonlocal worker_seed + if worker_seed is None: + worker_seed = WorkerConfig.active_worker_config.worker_seed() - inner_rand_state = None + # Save the RNG states and set the new seed + rng_state = SystemRng.save_state() - while True: - if inner_rand_state is not None: - # Restore inner random state before calling the generator - # This will not be done on the first iteration - SystemRng.restore_state(inner_rand_state) + SystemRng.seed_args(worker_seed, self.current_sample_index) try: - # Now call the generator. This will yield the sample - # But note it may also throw an exception or a StopIteration - sample = next(it) - - # Save inner random state after calling the generator - inner_rand_state = SystemRng.save_state() - except StopIteration: - # We're stopping here, but the outer random state - # will be restored before returning (in finally below) - break + return orig_fn(self, *args, **kwargs) finally: - # Restore outer rand state before yielding or when an exception was raised - SystemRng.restore_state(outer_rng_state) + # Restore the RNGs + SystemRng.restore_state(rng_state) - # Now yield the sample. 
- # This will give control back to the caller who may - # change the random state. - yield sample + fn = seed_wrapper - # Save outer random state after yielding - outer_rng_state = SystemRng.save_state() + if restore_task_encoder_seeds: + te_orig_fn = fn + worker_seed = None + if inspect.isgeneratorfunction(te_orig_fn): - @functools.wraps(fn) - def seed_wrapper(self, *args, **kwargs): - nonlocal worker_seed - if worker_seed is None: - worker_seed = WorkerConfig.active_worker_config.worker_seed() + @functools.wraps(te_orig_fn) + def seed_wrapper_generator(self, *args, **kwargs): + nonlocal worker_seed + if worker_seed is None: + worker_seed = WorkerConfig.active_worker_config.worker_seed() - # Save the RNG states and set the new seed - rng_state = SystemRng.save_state() + te_outer_rng_state = self.rng.save_state() - SystemRng.seed_args(worker_seed, self.current_sample_index) + self.rng.seed_args(worker_seed, self.current_sample_index) - try: - return fn(self, *args, **kwargs) - finally: - # Restore the RNGs - SystemRng.restore_state(rng_state) - - if inspect.isgeneratorfunction(fn): - setattr(seed_wrapper_generator, "__stateless__", True) - return seed_wrapper_generator + it = iter(te_orig_fn(self, *args, **kwargs)) + + inner_rand_state = None + + while True: + if inner_rand_state is not None: + self.rng.restore_state(inner_rand_state) + try: + sample = next(it) + inner_rand_state = self.rng.save_state() + except StopIteration: + break + finally: + self.rng.restore_state(te_outer_rng_state) + + yield sample + + te_outer_rng_state = self.rng.save_state() else: - setattr(seed_wrapper, "__stateless__", True) - return seed_wrapper + + @functools.wraps(te_orig_fn) + def seed_wrapper(self, *args, **kwargs): + nonlocal worker_seed + if worker_seed is None: + worker_seed = WorkerConfig.active_worker_config.worker_seed() + + # Save the RNG states and set the new seed + te_rng_state = self.rng.save_state() + + self.rng.seed_args(worker_seed, self.current_sample_index) + + 
try: + return te_orig_fn(self, *args, **kwargs) + finally: + # Restore the RNGs + self.rng.restore_state(te_rng_state) + + fn = seed_wrapper setattr(fn, "__stateless__", True) if failure_tolerance is not None: @@ -266,7 +327,7 @@ class Batch(PinMemoryMixin, ExtendableDataclassMixin): __key__: list[str] #: Key for restoring the sample. This is used to restore the sample from a checkpoint. It # should be a (nested) tuple of strings and integers, which can be used to index the dataset. - __restore_key__: Tuple[Union[str, int, tuple], ...] + __restore_key__: Tuple[RestoreKey | None, ...] #: A dataset may define a subflavors to distinguish between samples of the same sample type. __subflavors__: Optional[list[Optional[Dict[str, Any]]]] = None @@ -347,7 +408,7 @@ def from_samples(cls: Type[T_batch], samples: Sequence[Sample], **kwargs) -> T_b return cls(**init_args) -class TaskEncoder(ABC, Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch]): +class TaskEncoder(Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch]): """ Base class for task encoders. @@ -373,6 +434,13 @@ class TaskEncoder(ABC, Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch] #: The decoder to use for decoding samples. Set manually as needed to override options. decoder: Optional[SampleDecoder] = SampleDecoder() + #: Thread-local state. Used for properties, that are worker-local. + _worker_local: threading.local + + def __init__(self): + # Create a thread-local state for the workers. 
+ self._worker_local = threading.local() + def _is_overridden( self, bound_method: Callable[..., Any], bases: Optional[Sequence[Type[Any]]] = None ) -> bool: @@ -626,6 +694,7 @@ def build_batch( fixed_batch_size=batch_size, sample_group_key=self.batch_group_criterion, batcher=self.batch, + batcher_stateless=get_stateless(self.batch), drop_last=batch_drop_last, worker_config=worker_config, failure_tolerance=get_failure_tolerance( @@ -697,9 +766,13 @@ def build_cook_crude_sample( else: raise ValueError(f"No cooker found for subflavors: {subflavors}") + all_aux_datasets = list(aux.values()) + if cooker.need_primary and "primary" not in aux: try: primary_aux = get_primary_aux() + primary_aux.worker_init() + all_aux_datasets.append(primary_aux) assert primary_aux is not None, "Primary auxiliary dataset must always exist" if self.decoder is not None: primary_aux = DecodeFileStore(primary_aux, decoder=self.decoder) @@ -710,19 +783,25 @@ def build_cook_crude_sample( cook_fn = functools.partial(self.cook_crude_sample, cooker=cooker, aux=aux) - return MapDataset( - dataset, - cook_fn, - worker_config=worker_config, - stateless_map_fn=get_stateless(self.cook_crude_sample), - map_fn_config=dict( - cooker=dict( - cook=SavableDataset._function_config(cooker.cook), - has_subflavors=cooker.has_subflavors, - aux={k: {"_path": str(v.get_path())} for k, v in aux.items()}, + return FileStoreInitWrapper( + MapDataset( + dataset, + cook_fn, + worker_config=worker_config, + stateless_map_fn=get_stateless(self.cook_crude_sample), + map_fn_config=dict( + cooker=dict( + cook=SavableDataset._function_config(cooker.cook), + has_subflavors=cooker.has_subflavors, + aux={k: {"_path": str(v.get_path())} for k, v in aux.items()}, + ), + ), + failure_tolerance=get_failure_tolerance( + cook_fn, self.__default_failure_tolerance__ ), ), - failure_tolerance=get_failure_tolerance(cook_fn, self.__default_failure_tolerance__), + auxiliary_datasets=all_aux_datasets, + worker_config=worker_config, ) def 
_load_dataset( @@ -981,10 +1060,21 @@ def cache(self) -> CachePool: assert WorkerConfig.active_worker_config is not None, ( "The cache can only be fetched within the worker, and to be usable, you must use the get_(savable_)loader methods provided from the package." ) - assert WorkerConfig.active_worker_config._cache_pool is not None, ( - "Cache pool must be set by the loader." - ) - return WorkerConfig.active_worker_config._cache_pool + cache_pool = WorkerConfig.active_worker_cache_pool() + assert cache_pool is not None, "Cache pool must be set by the loader." + return cache_pool + + # State fields, they are initialized when the dataloader is started. + @property + def rng(self) -> UserRng: + """The random generator that should be used within user methods (like `encode_sample`) for reproducibility (and + thus for savability). + """ + if not hasattr(self._worker_local, "rng"): + # Initialize when needed. + self._worker_local.rng = UserRng(WorkerConfig.active_worker_config.worker_seed()) + + return self._worker_local.rng class DefaultTaskEncoder( diff --git a/src/megatron/energon/tools/analyze_debug.py b/src/megatron/energon/tools/analyze_debug.py index e1eb1754..c0755d50 100644 --- a/src/megatron/energon/tools/analyze_debug.py +++ b/src/megatron/energon/tools/analyze_debug.py @@ -97,24 +97,6 @@ ) -class YieldBatchLogLine(TypedDict): - # Json example: - # { - # "t": "yield_batch", - # "r": 1, - # "w": 1, - # "m": "train", - # "idx": 1, - # "keys": ["parts/data-train-000051.tar/528866", ...], - # } - t: Literal["yield_batch"] - r: int - w: int - m: Literal["train", "val"] - idx: int - keys: List[str] - - class SampleLoaderYieldLogLine(TypedDict): # Json example: # { @@ -442,7 +424,7 @@ def command( class LoaderInitLogLine(TypedDict): - t: Literal["SavableLoader.__init__", "BasicDataLoader.__init__"] + t: Literal["DataLoader.__init__"] r: int w: None id: int @@ -450,33 +432,32 @@ class LoaderInitLogLine(TypedDict): class LoaderIterLogLine(TypedDict): - t: 
Literal["SavableDataLoader.iter", "BasicDataLoader.iter"] + t: Literal["DataLoader.epoch_iter"] r: int w: None id: int - iter_id: int + epoch_id: int class LoaderYieldLogLine(TypedDict): - t: Literal["SavableDataLoader.yield", "BasicDataLoader.yield"] + t: Literal["DataLoader.epoch_iter.yield"] r: int w: None id: int - iter_id: int + epoch_id: int worker_id: int - worker_idx: int - idx: int - iter_idx: int - global_idx: int + worker_sample_idx: int + epoch_sample_idx: int + global_sample_idx: int keys: Optional[List[str]] class LoaderStopLogLine(TypedDict): - t: Literal["SavableDataLoader.StopIteration", "BasicDataLoader.StopIteration"] + t: Literal["DataLoader.epoch_iter.StopIteration"] r: int w: None id: int - iter_id: int + epoch_id: int LoaderLines = Union[ @@ -487,14 +468,10 @@ class LoaderStopLogLine(TypedDict): ] LOADER_LOG_LINE_TYPES_T = ( - "SavableLoader.__init__", - "BasicDataLoader.__init__", - "SavableDataLoader.iter", - "BasicDataLoader.iter", - "SavableDataLoader.yield", - "BasicDataLoader.yield", - "SavableDataLoader.StopIteration", - "BasicDataLoader.StopIteration", + "DataLoader.__init__", + "DataLoader.epoch_iter", + "DataLoader.epoch_iter.yield", + "DataLoader.epoch_iter.StopIteration", ) @@ -553,34 +530,29 @@ def loaders(self) -> Dict[int, LoaderInfo]: loaders = {} for log_line in self._iter_log_lines( ( - "SavableLoader.__init__", - "BasicDataLoader.__init__", - "SavableDataLoader.yield", - "BasicDataLoader.yield", + "DataLoader.__init__", + "DataLoader.epoch_iter.yield", ) ): - if log_line["t"] in ("SavableLoader.__init__", "BasicDataLoader.__init__"): + if log_line["t"] == "DataLoader.__init__": loaders[log_line["id"]] = LoaderInfo( id=log_line["id"], modality=self._find_config_modality(log_line["config"]), path=self._find_config_path(log_line["config"]), global_count=0, ) - elif log_line["t"] in ("SavableDataLoader.yield", "BasicDataLoader.yield"): - loaders[log_line["id"]].global_count = log_line["global_idx"] + elif log_line["t"] == 
"DataLoader.epoch_iter.yield": + loaders[log_line["id"]].global_count = log_line["global_sample_idx"] return loaders def log_entries(self, loader_ids: Container[int]) -> Generator[Optional[List[str]], None, None]: idx = self._start_idx - for log_line in self._iter_log_lines(("SavableDataLoader.yield", "BasicDataLoader.yield")): - if ( - log_line["t"] in ("SavableDataLoader.yield", "BasicDataLoader.yield") - and log_line["id"] in loader_ids - ): - assert log_line["global_idx"] >= idx, ( + for log_line in self._iter_log_lines(("DataLoader.epoch_iter.yield",)): + if log_line["t"] == "DataLoader.epoch_iter.yield" and log_line["id"] in loader_ids: + assert log_line["global_sample_idx"] >= idx, ( f"Found entry {log_line} with wrong idx <{idx}" ) - while log_line["global_idx"] != idx: + while log_line["global_sample_idx"] != idx: yield None idx += 1 if "keys" in log_line: diff --git a/src/megatron/energon/tools/checkpoint.py b/src/megatron/energon/tools/checkpoint.py index 08115a35..121954c1 100644 --- a/src/megatron/energon/tools/checkpoint.py +++ b/src/megatron/energon/tools/checkpoint.py @@ -1,17 +1,22 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause +import dataclasses import re -from typing import List, Optional +from typing import Callable, Generator, List, Optional import click import torch +from megatron.energon.dataloader.dataloader import RankState +from megatron.energon.dataloader.workers.base_worker import WorkerState from megatron.energon.epathlib import EPath -from megatron.energon.savable_loader import SavableDataLoaderState +from megatron.energon.flavors.base_dataset import RestoreKey +from megatron.energon.wrappers.base import WrappedRestoreKey +from megatron.energon.wrappers.batch_dataset import BatchRestoreKey -def natural_sort_key(s): +def natural_sort_key(s: str) -> List[str | int]: """ Function to use for natural sorting of filenames. 
@@ -21,7 +26,7 @@ def natural_sort_key(s): return [int(text) if text.isdigit() else text.lower() for text in re.split(r"(\d+)", s)] -def detect_and_replicate_pattern(file_list): +def detect_and_replicate_pattern(file_list: List[str]) -> Callable[[int], str]: """ Given a list of file paths, detect the single numeric pattern and return a function that, when called with integer n (starting from 0), generates @@ -156,7 +161,7 @@ def __init__(self, state_files: List[EPath]): else: self.megatron_style = False - if isinstance(first_state, SavableDataLoaderState): + if isinstance(first_state, RankState): if self.megatron_style: self.rank_states = [first_state] + [ torch.load(str(state_file), weights_only=False)["dataloader_state_dict"] @@ -170,7 +175,7 @@ def __init__(self, state_files: List[EPath]): self.is_global_checkpoint = False elif isinstance(first_state, list): assert len(state_files) == 1, "Global checkpoint must contain exactly one file" - assert all(isinstance(state, SavableDataLoaderState) for state in first_state) + assert all(isinstance(state, RankState) for state in first_state) self.rank_states = first_state self.is_global_checkpoint = True else: @@ -184,9 +189,12 @@ def __init__(self, state_files: List[EPath]): self.rank_num_workers[0] == num_workers for num_workers in self.rank_num_workers ), "All ranks must have the same number of workers." - def write_new_states_to_folder( - self, output_folder: EPath, new_states: List[SavableDataLoaderState] - ): + assert all( + rank_state.micro_batch_size == self.rank_states[0].micro_batch_size + for rank_state in self.rank_states[1:] + ), "All ranks must have the same micro batch size." 
+ + def write_new_states_to_folder(self, output_folder: EPath, new_states: List[RankState]): for rank_idx, rank_state in enumerate(new_states): output_file = output_folder / self.file_pattern_func(rank_idx) if self.megatron_style: @@ -197,20 +205,69 @@ def write_new_states_to_folder( else: torch.save(rank_state, str(output_file)) - def get_num_ranks(self): + def get_num_ranks(self) -> int: return len(self.rank_states) - def get_num_workers(self): + def get_num_workers(self) -> int: return self.rank_num_workers[0] - def get_micro_batch_size(self): + def get_micro_batch_size(self) -> int | None: return self.rank_states[0].micro_batch_size - def __iter__(self): - """Iterates the SavableDatasetCheckpoints of mulitple ranks in a round-robin fashion.""" - for rank, state in enumerate(self.rank_states): - for worker_state in state.worker_states: - yield worker_state + def __iter__(self) -> Generator[tuple[WorkerState | None, list[RestoreKey | None]], None, None]: + """Iterates the WorkerStates of multiple ranks in a round-robin fashion.""" + for rank_state in self.rank_states: + for worker_state, prefetched_samples_keys in zip( + rank_state.worker_states, rank_state.prefetched_restore_keys + ): + yield worker_state, prefetched_samples_keys + + +def split_batch_restore_key( + restore_key: RestoreKey | None, batch_split_factor: int +) -> list[RestoreKey | None]: + """Split the given restore_key into multiple restore keys, one for each batch.""" + if restore_key is None: + raise ValueError("Cannot split None restore key") + if isinstance(restore_key, BatchRestoreKey): + # Split the inner keys into batch_split_factor keys + # Duplicate the sample_idx for each batch + assert len(restore_key.inner) % batch_split_factor == 0, ( + "Batch size must be a multiple of the batch split factor" + ) + split_size = len(restore_key.inner) // batch_split_factor + return [ + BatchRestoreKey( + inner=tuple(restore_key.inner[i : i + split_size]), + sample_idx=restore_key.sample_idx, + ) + for 
i in range(0, len(restore_key.inner), split_size) + ] + elif isinstance(restore_key, WrappedRestoreKey): + inner_restore_keys = split_batch_restore_key(restore_key.inner, batch_split_factor) + inner_kwargs = { + field.name: getattr(restore_key, field.name) + for field in dataclasses.fields(restore_key) + } + inner_kwargs.pop("inner") + return [ + type(restore_key)(**inner_kwargs, inner=inner_restore_key) + for inner_restore_key in inner_restore_keys + ] + else: + raise ValueError(f"Unsupported restore key type for splitting batch: {type(restore_key)}") + + +def split_batch_restore_keys( + restore_keys: list[RestoreKey | None], batch_split_factor: int +) -> list[RestoreKey | None]: + if batch_split_factor == 1: + return restore_keys + return [ + new_restore_key + for restore_key in restore_keys + for new_restore_key in split_batch_restore_key(restore_key, batch_split_factor) + ] @click.command(name="redist") @@ -227,8 +284,12 @@ def __iter__(self): @click.option( "--new-world-size", type=int, help="Number of ranks to redistribute to", required=False ) +@click.option("--new-micro-batch-size", type=int, help="New micro batch size", required=False) def command_redist( - input_files: List[EPath], output_path: EPath, new_world_size: Optional[int] = None + input_files: List[EPath], + output_path: EPath, + new_world_size: Optional[int] = None, + new_micro_batch_size: Optional[int] = None, ): """Redistribute a checkpoint. 
@@ -267,22 +328,52 @@ def command_redist( # Ensure output directory exists output_path.mkdir(exist_ok=True, parents=True) - new_rank_states = [list() for _ in range(new_world_size)] + # A list (rank) of lists (workers) of (worker_state, prefetched_sample_keys) for each new rank + new_rank_states = [[] for _ in range(new_world_size)] rsi_iter = iter(rsi) for rank_idx in range(new_world_size): for _ in range(new_workers_per_rank): - state = next(rsi_iter) - new_rank_states[rank_idx].append(state) + worker_state, prefetched_sample_keys = next(rsi_iter) + new_rank_states[rank_idx].append((worker_state, prefetched_sample_keys)) assert all( - len(new_rank_states[0]) == len(new_rank_states[rank]) for rank in range(1, new_world_size) + len(new_rank_states[0]) == len(rank_states) for rank_states in new_rank_states[1:] ), "All ranks must have the same number of workers, also for the new distribution." + # Check batch sizes (before and after) + old_micro_batch_size = rsi.get_micro_batch_size() + if old_micro_batch_size is not None and new_micro_batch_size != old_micro_batch_size: + assert new_micro_batch_size is not None and old_micro_batch_size is not None, ( + "Cannot resume with different batching mode (batching to non-batching or vice versa)" + ) + + if new_micro_batch_size > old_micro_batch_size: + raise ValueError( + "Resuming with larger micro batch size is not allowed: " + f"{new_micro_batch_size} > {old_micro_batch_size}" + ) + elif ( + new_micro_batch_size < old_micro_batch_size + and old_micro_batch_size % new_micro_batch_size != 0 + ): + raise ValueError( + "Resuming with smaller micro batch size only allowed if the old " + f"micro batch size is a multiple of the new one: {new_micro_batch_size} < {old_micro_batch_size}" + ) + batch_split_factor = old_micro_batch_size // new_micro_batch_size + print(f"Splitting batches by {batch_split_factor}x") + else: + batch_split_factor = 1 + new_states = [ - SavableDataLoaderState( - worker_states=new_rank_state, + RankState( 
+ worker_states=[worker_state for worker_state, prefetched_sample_keys in new_rank_state], next_worker_id=0, # Reset the next worker ID - micro_batch_size=rsi.get_micro_batch_size(), + micro_batch_size=new_micro_batch_size, + prefetched_restore_keys=[ + split_batch_restore_keys(prefetched_sample_keys, batch_split_factor) + for worker_state, prefetched_sample_keys in new_rank_state + ], ) for new_rank_state in new_rank_states ] diff --git a/src/megatron/energon/typed_converter.py b/src/megatron/energon/typed_converter.py index 23c4adf7..9d3f39c7 100644 --- a/src/megatron/energon/typed_converter.py +++ b/src/megatron/energon/typed_converter.py @@ -5,10 +5,8 @@ import dataclasses import importlib import inspect -import json -import re import typing -from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Literal, Optional, Tuple, Type, TypeVar, Union from megatron.energon.edataclass import edataclass from megatron.energon.epathlib import EPath @@ -63,6 +61,8 @@ def _check_instance_type(cls, inst_type: Type) -> bool: @edataclass class JsonParser: strict: bool = True + #: If true, parse '__module__' and '__class__/__function__' attributes to allow explicit instantiation of types + allow_imports: bool = True def _resolve_object( self, @@ -77,7 +77,7 @@ def _resolve_object( """Resolve an object reference to the object.""" try: module = importlib.import_module(module_name) - except ModuleNotFoundError: + except ModuleNotFoundError: # pragma: no cover try_energon_module = importlib.import_module("megatron.energon", package=None) if hasattr(try_energon_module, object_name): module = try_energon_module @@ -85,7 +85,7 @@ def _resolve_object( raise try: return getattr(module, object_name) - except AttributeError: + except AttributeError: # pragma: no cover raise ModuleNotFoundError(f"Object {object_name} not found in {module_name}") def raw_to_instance( @@ -188,7 +188,7 @@ def raw_to_instance( inst = cls else: # 
Do not assert the other cases, we fallback to the passed cls - inst = self.safe_call_function(kwargs, cls, allow_imports=True) + inst = self.safe_call_function(kwargs, cls) assert not isinstance(cls, type) or _check_instance_type(type(inst), inst_type), ( f"Expected {inst_type}, got {cls}" ) @@ -198,7 +198,6 @@ def raw_to_typed( # noqa: C901 self, raw_data: Union[dict, list, str, int, bool, float, None], inst_type: Type[TType], - allow_imports: bool = False, _path: str = "root", _stage: Tuple[int, ...] = (), ) -> TType: @@ -217,8 +216,6 @@ class MyNamedTuple(NamedTuple): Args: raw_data: The raw (e.g. json) data to be made as `inst_type` inst_type: The type to return - allow_imports: If true, parse '__module__' and '__class__/__function__' attributes to allow explicit - instantiation of types _path: (internal for recursive call) The path to the object being converted from the root _stage: (internal for recursive call) Numbers representing the position of the current object being converted from the root @@ -227,7 +224,7 @@ class MyNamedTuple(NamedTuple): The input data as `inst_type`. 
""" type_name = getattr(inst_type, "__name__", repr(inst_type)) - if raw_data is _missing_value: + if raw_data is _missing_value: # pragma: no cover raise JsonValueError( f"Missing value at {_path}", inst_type, @@ -239,7 +236,7 @@ class MyNamedTuple(NamedTuple): # Literal types or missing data if not isinstance(raw_data, inst_type) and not ( isinstance(raw_data, int) and inst_type is float - ): + ): # pragma: no cover raise JsonValueError( f"Type does not match, expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -250,7 +247,7 @@ class MyNamedTuple(NamedTuple): return raw_data elif inst_type is Any: if ( - allow_imports + self.allow_imports and isinstance(raw_data, dict) and "__module__" in raw_data and ("__class__" in raw_data or "__function__" in raw_data) @@ -261,7 +258,7 @@ class MyNamedTuple(NamedTuple): elif typing.get_origin(inst_type) is Literal: # Literal[value[, ...]] values = typing.get_args(inst_type) - if raw_data not in values: + if raw_data not in values: # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -284,7 +281,6 @@ class MyNamedTuple(NamedTuple): return self.raw_to_typed( raw_data, subtype, - allow_imports, f"{_path} -> {getattr(subtype, '__name__', repr(subtype))}", _stage + (1,), ) @@ -304,7 +300,7 @@ class MyNamedTuple(NamedTuple): except JsonValueError as e: cur_exc = e raise cur_exc - else: + else: # pragma: no cover raise JsonValueError( f"Expected {inst_type} at {_path}, got {raw_data!r}", inst_type, @@ -312,6 +308,13 @@ class MyNamedTuple(NamedTuple): _path, _stage, ) + elif ( + self.allow_imports + and isinstance(raw_data, dict) + and "__module__" in raw_data + and ("__class__" in raw_data or "__function__" in raw_data) + ): + return self.raw_to_instance(raw_data, inst_type, _path=_path, _stage=_stage) elif ( isinstance(inst_type, type) and issubclass(inst_type, tuple) @@ -333,7 +336,6 @@ class MyNamedTuple(NamedTuple): field_name: self.raw_to_typed( 
raw_data.get(field_name, defaults.get(field_name, _missing_value)), field_type, - allow_imports, f"{_path} -> {type_name}:{field_name}", _stage + (idx,), ) @@ -359,7 +361,7 @@ class MyNamedTuple(NamedTuple): ) elif dataclasses.is_dataclass(inst_type): # dataclass - if not isinstance(raw_data, dict): + if not isinstance(raw_data, dict): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -367,25 +369,24 @@ class MyNamedTuple(NamedTuple): _path, _stage, ) - kwargs = { - field.name: self.raw_to_typed( - raw_data.get( - field.name, - ( - ( - _missing_value - if field.default_factory is dataclasses.MISSING - else field.default_factory() - ) - if field.default is dataclasses.MISSING - else field.default - ), - ), + + def get_field_value(field: dataclasses.Field, idx: int) -> Any: + value = raw_data.get(field.name, _missing_value) + if value is _missing_value: + # Use the factory value directly, without going through the conversion + if field.default_factory is not dataclasses.MISSING: + return field.default_factory() + elif field.default is not dataclasses.MISSING: + return field.default + return self.raw_to_typed( + value, field.type, - allow_imports, f"{_path} -> {type_name}:{field.name}", _stage + (idx,), ) + + kwargs = { + field.name: get_field_value(field, idx) for idx, field in enumerate(dataclasses.fields(inst_type)) if field.init } @@ -401,7 +402,7 @@ class MyNamedTuple(NamedTuple): ) try: return inst_type(**kwargs) - except BaseException: + except BaseException: # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -412,7 +413,7 @@ class MyNamedTuple(NamedTuple): elif typing.get_origin(inst_type) is list: # List[inner_type] (inner_type,) = typing.get_args(inst_type) - if not isinstance(raw_data, list): + if not isinstance(raw_data, list): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ 
-421,15 +422,13 @@ class MyNamedTuple(NamedTuple): _stage, ) return [ - self.raw_to_typed( - val, inner_type, allow_imports, f"{_path} -> {idx}", _stage + (idx,) - ) + self.raw_to_typed(val, inner_type, f"{_path} -> {idx}", _stage + (idx,)) for idx, val in enumerate(raw_data) ] elif typing.get_origin(inst_type) is set: # Set[inner_type] (inner_type,) = typing.get_args(inst_type) - if not isinstance(raw_data, list): + if not isinstance(raw_data, list): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -438,12 +437,10 @@ class MyNamedTuple(NamedTuple): _stage, ) res = set( - self.raw_to_typed( - val, inner_type, allow_imports, f"{_path} -> {idx}", _stage + (idx,) - ) + self.raw_to_typed(val, inner_type, f"{_path} -> {idx}", _stage + (idx,)) for idx, val in enumerate(raw_data) ) - if len(res) != len(raw_data): + if len(res) != len(raw_data): # pragma: no cover raise JsonValueError( f"Duplicate element at {_path}", inst_type, @@ -455,7 +452,7 @@ class MyNamedTuple(NamedTuple): elif typing.get_origin(inst_type) is tuple: # Tuple[inner_types[0], inner_types[1], ...] or Tuple[inner_types[0], Ellipsis/...] inner_types = typing.get_args(inst_type) - if not isinstance(raw_data, list): + if not isinstance(raw_data, (list, tuple)): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -467,15 +464,13 @@ class MyNamedTuple(NamedTuple): # Tuple of arbitrary length, all elements same type # Tuple[inner_types[0], Ellipsis/...] return tuple( - self.raw_to_typed( - val, inner_types[0], allow_imports, f"{_path} -> {idx}", _stage + (idx,) - ) + self.raw_to_typed(val, inner_types[0], f"{_path} -> {idx}", _stage + (idx,)) for idx, val in enumerate(raw_data) ) else: # Fixed size/typed tuple # Tuple[inner_types[0], inner_types[1], ...] 
- if len(raw_data) != len(inner_types): + if len(raw_data) != len(inner_types): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -483,17 +478,15 @@ class MyNamedTuple(NamedTuple): _path, _stage, ) - return [ - self.raw_to_typed( - val, inner_type, allow_imports, f"{_path} -> {idx}", _stage + (idx,) - ) + return tuple( + self.raw_to_typed(val, inner_type, f"{_path} -> {idx}", _stage + (idx,)) for idx, (val, inner_type) in enumerate(zip(raw_data, inner_types)) - ] + ) elif typing.get_origin(inst_type) is dict: # Dict[str, value_type] key_type, value_type = typing.get_args(inst_type) assert key_type is str - if not isinstance(raw_data, dict): + if not isinstance(raw_data, dict): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -502,14 +495,12 @@ class MyNamedTuple(NamedTuple): _stage, ) return { - key: self.raw_to_typed( - val, value_type, allow_imports, f"{_path} -> {key!r}", _stage + (idx,) - ) + key: self.raw_to_typed(val, value_type, f"{_path} -> {key!r}", _stage + (idx,)) for idx, (key, val) in enumerate(raw_data.items()) } elif inst_type in (dict, list): # dict, list (no subtyping) - if not isinstance(raw_data, inst_type): + if not isinstance(raw_data, inst_type): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -521,7 +512,7 @@ class MyNamedTuple(NamedTuple): elif inst_type is EPath: if isinstance(raw_data, str): return EPath(raw_data) - elif not isinstance(raw_data, EPath): + elif not isinstance(raw_data, EPath): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -530,13 +521,6 @@ class MyNamedTuple(NamedTuple): _stage, ) return raw_data - elif ( - allow_imports - and isinstance(raw_data, dict) - and "__module__" in raw_data - and ("__class__" in raw_data or "__function__" in raw_data) - ): - return 
self.raw_to_instance(raw_data, inst_type, _path=_path, _stage=_stage) else: return raw_data @@ -544,7 +528,6 @@ def safe_call_function( self, raw_data: Union[dict, list, str, int, bool, float, None], fn: Callable[..., TType], - allow_imports: bool = False, ) -> TType: """ Converts raw data (i.e. dicts, lists and primitives) to typed call arguments. @@ -562,7 +545,6 @@ def fn(arg1: float, arg2: MyType, arg3) -> Any: raw_data: The raw (e.g. json) data to be made as `inst_type` fn: The function to call with the converted data strict: If true, don't allow additional attributes - allow_imports: If true, allow instantiating objects by specifying __module__ and __class__/__function__. Returns: The return value of `fn` @@ -587,15 +569,12 @@ def fn(arg1: float, arg2: MyType, arg3) -> Any: kwargs[key] = self.raw_to_typed( unused_args.pop(key, param.default), t, - allow_imports, _path=key, _stage=(idx,), ) elif param.kind == inspect.Parameter.VAR_KEYWORD: for arg_key, arg_val in unused_args.items(): - kwargs[arg_key] = self.raw_to_typed( - arg_val, t, allow_imports, _path=key, _stage=(idx,) - ) + kwargs[arg_key] = self.raw_to_typed(arg_val, t, _path=key, _stage=(idx,)) unused_args.clear() elif param.kind == inspect.Parameter.VAR_POSITIONAL: # No way to pass positional arguments @@ -607,7 +586,7 @@ def fn(arg1: float, arg2: MyType, arg3) -> Any: raise RuntimeError(f"Unknown parameter kind {param.kind!r}") if self.strict and len(unused_args) > 0: raise ValueError(f"Unexpected arguments: {unused_args!r}") - elif isinstance(raw_data, list): + elif isinstance(raw_data, list): # pragma: no cover unused_args = raw_data.copy() for idx, (key, param) in enumerate(parameters): t = Any if param.annotation is inspect.Parameter.empty else param.annotation @@ -616,11 +595,7 @@ def fn(arg1: float, arg2: MyType, arg3) -> Any: raise ValueError( f"Missing required positional-only argument {key!r} at index {idx}" ) - args.append( - self.raw_to_typed( - unused_args.pop(), t, allow_imports, 
_path=key, _stage=(idx,) - ) - ) + args.append(self.raw_to_typed(unused_args.pop(), t, _path=key, _stage=(idx,))) elif param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: if param.default is inspect.Parameter.empty and len(unused_args) == 0: raise ValueError( @@ -630,14 +605,10 @@ def fn(arg1: float, arg2: MyType, arg3) -> Any: arg_val = param.default else: arg_val = unused_args.pop() - args.append( - self.raw_to_typed(arg_val, t, allow_imports, _path=key, _stage=(idx,)) - ) + args.append(self.raw_to_typed(arg_val, t, _path=key, _stage=(idx,))) elif param.kind == inspect.Parameter.VAR_POSITIONAL: for arg_val in unused_args: - args.append( - self.raw_to_typed(arg_val, t, allow_imports, _path=key, _stage=(idx,)) - ) + args.append(self.raw_to_typed(arg_val, t, _path=key, _stage=(idx,))) unused_args.clear() elif param.kind == inspect.Parameter.VAR_KEYWORD: # No way to pass keyword arguments @@ -648,424 +619,12 @@ def fn(arg1: float, arg2: MyType, arg3) -> Any: raise RuntimeError(f"Unknown parameter kind {param.kind!r}") if self.strict and len(unused_args) > 0: raise ValueError(f"Unexpected arguments: {unused_args!r}") - else: + else: # pragma: no cover raise ValueError( f"Cannot call function with raw data of type {type(raw_data)!r}, require list or dict" ) return fn(*args, **kwargs) - def override( # noqa: C901 - self, - value: TType, - overrides: Any, - inst_type: Optional[Type[TType]] = None, - allow_imports: bool = False, - _path: str = "root", - _stage: Tuple[int, ...] = (), - ) -> TType: - """ - Allows overriding values of a typed object using environment config. - Allows overriding single config variables, or whole objects. 
- - Examples:: - - class MyNamedTuple(NamedTuple): - x: int - y: str - - class MyNested(NamedTuple): - nested: MyNamedTuple - - assert override( - MyNested(nested=MyNamedTuple(x=42, y="foo")), - {'nested.x': 5}, - ) == MyNested(nested=MyNamedTuple(x=5, y="foo")) - assert override( - MyNested(nested=MyNamedTuple(x=42, y="foo")), - {'nested': '{"x": 5, "y": "bar"}'}, - ) == MyNested(nested=MyNamedTuple(x=5, y="bar")) - - Args: - value: The base value to override. - overrides: The overrides to apply - strict: If true, no additional keys are allowed - inst_type: If given, validate against this base type instead of the type of `value`. - allow_imports: If true, allow instantiating types with dicts of __module__ and __class__/__function__. - _path: Internal: The path to the current value. - _stage: Internal: The current stage of the override. - - Returns: - Same type as the input object (or `inst_type` if set), copied and updated from the - overrides. - """ - if inst_type is None: - inst_type = type(value) - type_name = getattr(inst_type, "__name__", repr(inst_type)) - if inst_type in (str, int, float, bool, None, type(None)): - # Literal types - if inst_type in (None, type(None)) and overrides == "None": - overrides = None - elif inst_type is bool and overrides in ("True", "true", "1", "False", "false", "0"): - overrides = overrides in ("True", "true", "1") - elif inst_type in (int, float) and isinstance(overrides, str): - overrides = inst_type(overrides) - if not isinstance(overrides, inst_type) and not ( - isinstance(overrides, int) and inst_type is float - ): - raise JsonValueError( - f"Type does not match, expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - return overrides - elif inst_type is Any: - # Any - if isinstance(overrides, str): - if overrides.isnumeric(): - return int(overrides) - elif overrides == "True": - return True - elif overrides == "False": - return True - return overrides - if isinstance(value, 
(dict, list, tuple)): - # Merge with dict, list, str - return self.override(value, overrides, type(value), allow_imports, _path, _stage) - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - elif typing.get_origin(inst_type) is Literal: - # Literal[value] - (value,) = typing.get_args(inst_type) - if value != overrides: - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - return value - elif typing.get_origin(inst_type) is Union: - # Union[union_types[0], union_types[1], ...] - union_types = typing.get_args(inst_type) - if isinstance(overrides, str): - for subtype in union_types: - if subtype is None and overrides == "None": - return None - elif subtype is bool: - if overrides == "True": - return True - elif overrides == "False": - return False - elif subtype is int and overrides.strip().isnumeric(): - return int(overrides) - elif subtype is str: - return overrides - elif subtype is float and float_pattern.fullmatch(overrides): - return float(overrides) - if overrides.lstrip().startswith("{") or overrides.lstrip().startswith("["): - overrides = json.loads(overrides) - return self.raw_to_typed( - overrides, - inst_type, - allow_imports, - _path, - _stage, - ) - for subtype in union_types: - if _isinstance_deep(value, subtype): - return self.override( - value, - overrides, - subtype, - allow_imports, - f"{_path} -> {getattr(subtype, '__name__', repr(subtype))}", - _stage + (1,), - ) - raise JsonValueError( - f"Expected {type_name} at {_path}, existing is {value!r} which is invalid", - inst_type, - value, - _path, - _stage, - ) - elif ( - isinstance(inst_type, type) - and issubclass(inst_type, tuple) - and hasattr(inst_type, "__annotations__") - ): - # class MyClass(NamedTuple): ... 
- if not isinstance(overrides, (dict, str)): - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - if isinstance(overrides, str): - return self.raw_to_typed( - json.loads(overrides), - inst_type, - allow_imports, - _path, - _stage, - ) - local_overrides = _split_dict_keys(overrides) - if getattr(inst_type, "__dash_keys__", "False"): - local_overrides = { - key.replace("-", "_"): val for key, val in local_overrides.items() - } - kwargs = { - field_name: ( - self.override( - getattr(value, field_name), - local_overrides.pop(field_name), - field_type, - allow_imports, - f"{_path} -> {type_name}:{field_name}", - _stage + (idx,), - ) - if field_name in local_overrides - else getattr(value, field_name) - ) - for idx, (field_name, field_type) in enumerate(inst_type.__annotations__.items()) - } - if self.strict and len(local_overrides) != 0: - raise JsonValueError( - f"Invalid config keys {', '.join(local_overrides.keys())} for {type_name} at " - f"{_path}", - inst_type, - overrides, - _path, - _stage, - ) - try: - return inst_type(**kwargs) - except BaseException: - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - elif dataclasses.is_dataclass(inst_type): - # dataclass - if not isinstance(overrides, (dict, str)): - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - if isinstance(overrides, str): - return self.raw_to_typed( - json.loads(overrides), - inst_type, - allow_imports, - _path, - _stage, - ) - local_overrides = _split_dict_keys(overrides) - if getattr(inst_type, "__dash_keys__", "False"): - local_overrides = { - key.replace("-", "_"): val for key, val in local_overrides.items() - } - kwargs = { - field.name: ( - self.override( - getattr(value, field.name), - local_overrides.pop(field.name), - field.type, - allow_imports, - 
f"{_path} -> {type_name}:{field.name}", - _stage + (idx,), - ) - if field.name in local_overrides - else getattr(value, field.name) - ) - for idx, field in enumerate(dataclasses.fields(inst_type)) - if field.init - } - if self.strict and len(local_overrides) != 0: - raise JsonValueError( - f"Invalid config keys {', '.join(local_overrides.keys())} for {type_name} at " - f"{_path}", - inst_type, - overrides, - _path, - _stage, - ) - try: - return inst_type(**kwargs) - except BaseException: - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - elif ( - typing.get_origin(inst_type) is list - or typing.get_origin(inst_type) is tuple - or inst_type in (list, tuple) - ): - # List[inner_type] or Tuple[inner_type, Ellipsis] or - # Tuple[inner_type[0], inner_type[1], ...] - if inst_type is list: - inner_type = Any - inner_types = [] - cls = list - elif inst_type is tuple: - inner_type = Any - inner_types = [] - cls = tuple - elif typing.get_origin(inst_type) is list: - (inner_type,) = typing.get_args(inst_type) - inner_types = [] - cls = list - else: - inner_types = typing.get_args(inst_type) - if len(inner_types) == 2 and inner_types[1] is Ellipsis: - inner_type = inner_types[0] - else: - inner_type = None - cls = tuple - if not isinstance(overrides, (dict, str)): - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - if isinstance(overrides, str): - return self.raw_to_typed( - json.loads(overrides), - inst_type, - allow_imports, - _path, - _stage, - ) - local_overrides = _split_dict_keys(overrides) - if not all(key.isnumeric() for key in local_overrides.keys()): - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}, expected integer keys", - inst_type, - overrides, - _path, - _stage, - ) - local_overrides_int = {int(key): value for key, value in local_overrides.items()} - new_max_idx = 
max(local_overrides_int.keys()) - original_max_idx = len(value) - if inner_type is None and new_max_idx >= len(inner_types): - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}, index {new_max_idx} out of " - f"bounds", - inst_type, - overrides, - _path, - _stage, - ) - for i in range(original_max_idx, new_max_idx): - if i not in local_overrides_int: - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}, missing value for index " - f"{i}", - inst_type, - overrides, - _path, - _stage, - ) - return cls( - ( - self.override( - value[idx], - local_overrides_int[idx], - inner_type, - allow_imports, - f"{_path} -> {idx}", - _stage + (idx,), - ) - if idx in local_overrides_int - else value[idx] - ) - for idx in range(max(new_max_idx + 1, original_max_idx)) - ) - elif typing.get_origin(inst_type) is dict or inst_type is dict: - # Dict[str, value_type] - if inst_type is dict: - value_type = Any - else: - key_type, value_type = typing.get_args(inst_type) - assert key_type is str - if not isinstance(overrides, (dict, str)): - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - if isinstance(overrides, str): - return self.raw_to_typed( - json.loads(overrides), - inst_type, - allow_imports, - _path, - _stage, - ) - local_overrides = _split_dict_keys(overrides) - if getattr(inst_type, "__dash_keys__", "False"): - local_overrides = { - key.replace("-", "_"): val for key, val in local_overrides.items() - } - res = { - key: ( - self.override( - subvalue, - local_overrides.pop(key), - value_type, - allow_imports, - f"{_path} -> {type_name}:{key!r}", - _stage + (idx,), - ) - if key in local_overrides - else subvalue - ) - for idx, (key, subvalue) in value.items() - } - for key, val in local_overrides.items(): - if not isinstance(val, str): - raise JsonValueError( - f"Expected new {type_name} at {_path} -> {type_name}:{key!r}, got {val!r}", - inst_type, - 
overrides, - _path, - _stage, - ) - res[key] = self.raw_to_typed( - json.loads(val), - value_type, - allow_imports, - f"{_path} -> {type_name}:{key!r}", - _stage + (len(res),), - ) - return res - else: - raise RuntimeError(f"Unknown type {inst_type}") - def to_json_object(obj: Any) -> Any: """ @@ -1086,6 +645,16 @@ def to_json_object(obj: Any) -> Any: field_name: to_json_object(getattr(obj, field_name)) for field_name in obj.__annotations__.keys() } + elif isinstance(obj, type): + return { + "__module__": obj.__module__, + "__class__": obj.__name__, + } + elif isinstance(obj, Callable): + return { + "__module__": obj.__module__, + "__function__": obj.__name__, + } elif dataclasses.is_dataclass(obj): # dataclass return { @@ -1093,7 +662,7 @@ def to_json_object(obj: Any) -> Any: for field in dataclasses.fields(obj) if field.init } - elif isinstance(obj, (list, tuple)): + elif isinstance(obj, (list, tuple, set)): return [to_json_object(val) for val in obj] elif isinstance(obj, dict): return {key: to_json_object(val) for key, val in obj.items()} @@ -1101,41 +670,18 @@ def to_json_object(obj: Any) -> Any: raise RuntimeError(f"Unknown type {type(obj)}") -float_pattern = re.compile(r"[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?") - - -def _split_dict_keys(dct: Dict[str, Any]) -> Dict[str, Any]: - """Splits the given dict keys by first '.' to subdicts.""" - res = {} - for key, value in dct.items(): - if "." 
in key: - outer_key, _, inner_key = key.partition(".") - if outer_key in res: - if not isinstance(res[outer_key], dict): - raise ValueError(f"Cannot combine {outer_key!r} with {res!r}") - res[outer_key][inner_key] = value - else: - res[outer_key] = {inner_key: value} - else: - if key in res: - raise ValueError(f"Cannot combine {key!r} with {res!r}") - res[key] = value - - return res - - def _isinstance_deep(val: Any, tp_chk: Type) -> bool: """Verifies if the given value is an instance of the tp_chk, allowing for typing extensions.""" if tp_chk is Any: return True elif typing.get_origin(tp_chk) is Literal: - (value,) = typing.get_args(val) - return val == value + values = typing.get_args(tp_chk) + return val in values elif typing.get_origin(tp_chk) is list: - (inner_type,) = typing.get_args(val) + (inner_type,) = typing.get_args(tp_chk) return isinstance(val, list) and all(_isinstance_deep(v, inner_type) for v in val) elif typing.get_origin(tp_chk) is tuple: - inner_types = typing.get_args(val) + inner_types = typing.get_args(tp_chk) if len(inner_types) == 2 and inner_types[1] == Ellipsis: return isinstance(val, tuple) and all(_isinstance_deep(v, inner_types[0]) for v in val) else: @@ -1145,7 +691,7 @@ def _isinstance_deep(val: Any, tp_chk: Type) -> bool: and all(_isinstance_deep(v, inner_type) for v, inner_type in zip(val, inner_types)) ) elif typing.get_origin(tp_chk) is dict: - key_type, value_type = typing.get_args(val) + key_type, value_type = typing.get_args(tp_chk) return isinstance(val, dict) and all( _isinstance_deep(k, key_type) and _isinstance_deep(v, value_type) for k, v in val.items() diff --git a/src/megatron/energon/watchdog.py b/src/megatron/energon/watchdog.py index 2561d6a0..9755dd21 100644 --- a/src/megatron/energon/watchdog.py +++ b/src/megatron/energon/watchdog.py @@ -67,7 +67,9 @@ def __init__( # Condition variable to manage state changes self._cv = threading.Condition() # Background thread (daemon) that monitors timeouts - self._worker_thread 
= threading.Thread(target=self._worker, daemon=True) + self._worker_thread = threading.Thread( + name=f"watchdog-{id(self)}", target=self._worker, daemon=True + ) self._worker_thread.start() def _get_next_timeout(self) -> float: diff --git a/src/megatron/energon/worker.py b/src/megatron/energon/worker.py index fca26369..ec8e2c9c 100644 --- a/src/megatron/energon/worker.py +++ b/src/megatron/energon/worker.py @@ -4,6 +4,7 @@ import hashlib import json import multiprocessing +import threading from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, ClassVar, Dict, List, Optional, TextIO, TypeVar @@ -20,6 +21,51 @@ T = TypeVar("T") +class ActiveWorkerState: + """ + Thread local state for the active worker config. + """ + + _thread_local: threading.local + + @property + def sample_index_stack(self) -> Optional[List[int]]: + """The current sample index stack for the worker.""" + return getattr(self._thread_local, "sample_index_stack", None) + + @property + def cache_pool(self) -> Optional[CachePool]: + """The current cache pool for the worker.""" + return getattr(self._thread_local, "cache_pool", None) + + @property + def worker_config(self) -> "WorkerConfig | None": + return getattr(self._thread_local, "worker_config", None) + + @sample_index_stack.setter + def sample_index_stack(self, value: List[int]): + self._thread_local.sample_index_stack = value + + @cache_pool.setter + def cache_pool(self, value: Optional[CachePool]): + self._thread_local.cache_pool = value + + @worker_config.setter + def worker_config(self, value: "WorkerConfig | None"): + self._thread_local.worker_config = value + + def __init__(self): + self._thread_local = threading.local() + + +class classproperty: + def __init__(self, getter): + self.getter = getter + + def __get__(self, instance, owner): + return self.getter(owner) + + @dataclass(slots=True, kw_only=True, eq=False) class WorkerConfig: """ @@ -72,61 +118,60 @@ class WorkerConfig: [Exception, Any | 
list[Any], Optional[list[SourceInfo]]], None ] = reraise_exception - #: The current sample index within the current iterating worker - _sample_index_stack: ClassVar[Optional[List[int]]] = None - #: The current worker config within the current iterating worker - active_worker_config: ClassVar[Optional["WorkerConfig"]] = None + _active_state: ClassVar[ActiveWorkerState] = ActiveWorkerState() - #: The global rank override for the worker. Required for restoring samples. - _worker_override_global_rank: ClassVar[Optional[List[int]]] = None - - #: The current cache pool for the worker. - _cache_pool: "ClassVar[Optional[CachePool]]" = None + @classproperty + def active_worker_config(cls) -> Optional["WorkerConfig"]: + """The current worker config within the current iterating worker""" + return cls._active_state.worker_config def worker_activate( self, sample_index: int, - override_global_rank: Optional[int] = None, cache_pool: "Optional[CachePool]" = None, ): """Activates the worker config for the current worker and sets it as actively iterating. Must be called before next() call on the datasets.""" - assert WorkerConfig.active_worker_config is None - WorkerConfig._sample_index_stack = [sample_index] - WorkerConfig.active_worker_config = self - WorkerConfig._worker_override_global_rank = override_global_rank - WorkerConfig._cache_pool = cache_pool + assert WorkerConfig._active_state.worker_config is None, ( + f"Worker config already active for thread={threading.get_ident()}" + ) + WorkerConfig._active_state.sample_index_stack = [sample_index] + WorkerConfig._active_state.worker_config = self + WorkerConfig._active_state.cache_pool = cache_pool def worker_push_sample_index(self, sample_index: int): """Pushes a new sample index to the sample index stack. 
Should be set by wrapping datasets before calling inners.""" - assert WorkerConfig.active_worker_config is not None - WorkerConfig._sample_index_stack.append(sample_index) + assert WorkerConfig._active_state.sample_index_stack is not None + WorkerConfig._active_state.sample_index_stack.append(sample_index) def worker_pop_sample_index(self): """Pushes a new sample index to the sample index stack. Should be set by wrapping datasets before calling inners.""" - assert WorkerConfig.active_worker_config is not None - return WorkerConfig._sample_index_stack.pop() + assert WorkerConfig._active_state.sample_index_stack is not None + return WorkerConfig._active_state.sample_index_stack.pop() def worker_deactivate(self): """Deactivates the worker config for the current worker and deactivates it for iterating. Must be called after next() call on the datasets.""" - if WorkerConfig.active_worker_config is not None: - assert len(WorkerConfig._sample_index_stack) == 1, ( - f"Sample index stack not empty: {WorkerConfig._sample_index_stack}" - ) - WorkerConfig._sample_index_stack = None - WorkerConfig.active_worker_config = None - WorkerConfig._worker_override_global_rank = None + assert WorkerConfig._active_state.worker_config is self, "Worker config mismatch" + assert WorkerConfig._active_state.sample_index_stack is not None + assert len(WorkerConfig._active_state.sample_index_stack) == 1, ( + f"Sample index stack not empty: {WorkerConfig._active_state.sample_index_stack}" + ) + WorkerConfig._active_state.sample_index_stack = None + WorkerConfig._active_state.worker_config = None + WorkerConfig._active_state.cache_pool = None @property def active_worker_sample_index(self) -> int: """Returns the current sample index for the actively iterating worker.""" # Internal sample index is for the local worker. If using multiple workers per rank, this # must be multiplied by the number of workers and offset by the local worker index. 
+ assert WorkerConfig._active_state.sample_index_stack is not None return ( - WorkerConfig._sample_index_stack[-1] * max(self.num_workers, 1) + self.rank_worker_id() + WorkerConfig._active_state.sample_index_stack[-1] * max(self.num_workers, 1) + + self.rank_worker_id() ) @property @@ -134,10 +179,22 @@ def active_worker_batch_index(self) -> int: """Returns the current batch index for the actively iterating worker.""" # Internal batch index is for the local worker. If using multiple workers per rank, this # must be multiplied by the number of workers and offset by the local worker index. + assert WorkerConfig._active_state.sample_index_stack is not None return ( - WorkerConfig._sample_index_stack[0] * max(self.num_workers, 1) + self.rank_worker_id() + WorkerConfig._active_state.sample_index_stack[0] * max(self.num_workers, 1) + + self.rank_worker_id() ) + @staticmethod + def active_worker_cache_pool() -> Optional[CachePool]: + """Returns the current cache pool for the actively iterating worker.""" + return WorkerConfig._active_state.cache_pool + + @property + def safe_num_workers(self) -> int: + """Returns the number of workers, but at least 1.""" + return max(self.num_workers, 1) + def global_rank(self) -> int: """Returns the global rank of this worker config but as a global rank, not as a rank within the data parallel group.""" @@ -179,22 +236,6 @@ def default_worker_config( data_parallel_group=data_parallel_group, ) - def rank_worker_id(self) -> int: - """Returns the self worker id within the current rank.""" - if self._worker_override_global_rank: - assert self.worker_id_offset == 0 - return self._worker_override_global_rank % self.num_workers - worker_info = torch.utils.data.get_worker_info() - if worker_info is None: - return self.worker_id_offset - assert worker_info.num_workers == self.num_workers - # Apply the worker_id_offset as a left rotation of the logical worker ids. 
- # This ensures that after restoring a checkpoint the first physical - # worker (id=0) corresponds to the logical worker that should emit the - # next sample. For example, if `worker_id_offset` is 1, logical worker - # 1 becomes the first to emit a sample, shifting the ordering forward. - return (worker_info.id + self.worker_id_offset) % worker_info.num_workers - def assert_worker(self): """Checks if the current process is a worker (if configured so), and that the workers are properly configured.""" @@ -208,23 +249,49 @@ def assert_worker(self): f"match the configured number of workers ({self.num_workers})" ) + def rank_worker_id(self, override_global_worker_id: Optional[int] = None) -> int: + """Returns the self worker id within the current rank. + Optionally computes the worker id from a global worker id. + + Args: + override_global_worker_id: The global worker id to compute the rank worker id from. + None means the current worker, which is the default. If not set, must be called + within the worker. + """ + if override_global_worker_id is not None: + assert ( + self.rank * self.safe_num_workers + <= override_global_worker_id + < (self.rank + 1) * self.safe_num_workers + ), f"Invalid global worker id: {override_global_worker_id}" + return override_global_worker_id - self.rank * self.safe_num_workers + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + return self.worker_id_offset + assert worker_info.num_workers == self.num_workers + # Apply the worker_id_offset as a left rotation of the logical worker ids. + # This ensures that after restoring a checkpoint the first physical + # worker (id=0) corresponds to the logical worker that should emit the + # next sample. For example, if `worker_id_offset` is 1, logical worker + # 1 becomes the first to emit a sample, shifting the ordering forward. 
+ return (worker_info.id + self.worker_id_offset) % max(worker_info.num_workers, 1) + def global_worker_id(self, override_local_worker_id: Optional[int] = None) -> int: """Returns the global worker index by multiplying the rank with the number of workers. Alternatively, you can override the local worker id. Args: - override_local_worker_id (int, optional): The local worker id to override. None means - the current worker, which is the default. + override_local_worker_id: The local worker id to override. None means + the current worker, which is the default. If not set, must be called + within the worker. """ - if self._worker_override_global_rank is not None: - assert override_local_worker_id is None - return self._worker_override_global_rank - if override_local_worker_id is not None: - return self.rank * self.num_workers + override_local_worker_id - else: - self.assert_worker() - return self.rank * self.num_workers + self.rank_worker_id() + assert 0 <= override_local_worker_id < self.safe_num_workers, ( + f"Invalid local worker id: {override_local_worker_id}" + ) + return self.rank * self.safe_num_workers + override_local_worker_id + self.assert_worker() + return self.rank * self.safe_num_workers + self.rank_worker_id() def worker_seed(self, override_local_worker_id: Optional[int] = None) -> int: """Returns the seed for the current worker (or a specified worker). 
diff --git a/src/megatron/energon/wrappers/__init__.py b/src/megatron/energon/wrappers/__init__.py index 808ca50e..221a270c 100644 --- a/src/megatron/energon/wrappers/__init__.py +++ b/src/megatron/energon/wrappers/__init__.py @@ -5,6 +5,7 @@ from megatron.energon.wrappers.blend_dataset import BlendDataset from megatron.energon.wrappers.concat_dataset import ConcatDataset from megatron.energon.wrappers.epochize_dataset import EpochizeDataset +from megatron.energon.wrappers.file_store_init_wrapper import FileStoreInitWrapper from megatron.energon.wrappers.filter_dataset import FilterDataset from megatron.energon.wrappers.gc_dataset import GcDataset from megatron.energon.wrappers.group_batch_dataset import GroupBatchDataset @@ -24,6 +25,7 @@ from megatron.energon.wrappers.skip import SkipSample __all__ = [ + "FileStoreInitWrapper", "BatchDataset", "BlendDataset", "ConcatDataset", diff --git a/src/megatron/energon/wrappers/base.py b/src/megatron/energon/wrappers/base.py index 2f4dba89..55c405bf 100644 --- a/src/megatron/energon/wrappers/base.py +++ b/src/megatron/energon/wrappers/base.py @@ -1,15 +1,17 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause +import dataclasses +import threading from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, Generator, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union from megatron.energon.flavors.base_dataset import ( FlexState, + RestoreKey, Sample, SavableDataset, - add_sample_restore_key, ) from megatron.energon.savable import Savable from megatron.energon.worker import WorkerConfig @@ -54,6 +56,9 @@ def dataset(self) -> SavableDataset: assert len(self.datasets) == 1 return self.datasets[0] + def len_worker(self, worker_idx: int | None = None) -> int: + return sum(ds.len_worker(worker_idx) for ds in self.datasets) + def can_restore_sample(self) -> bool: return all(ds.can_restore_sample() for ds in self.datasets) @@ -76,19 +81,19 @@ def _find_wrapped_dataset(self, cls: Type[SavableDataset]) -> Optional[SavableDa return res return None - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample_out: - if len(self.datasets) == 1: - return self.datasets[0].restore_sample(restore_key) - else: - id, ds_idx = restore_key[:2] - assert id == type(self).__name__ - restore_key = restore_key[2:] - assert isinstance(ds_idx, int) - return add_sample_restore_key( - self.datasets[ds_idx].restore_sample(restore_key), - ds_idx, - src=self, - ) + @abstractmethod + def reset_state_own(self) -> None: + """Resets the state of the dataset, excl. the inner datasets.""" + ... 
+ + def reset_state(self) -> None: + """Resets the state of the inner datasets and then the own state.""" + + for ds in self.datasets: + ds.reset_state() + + super().reset_state() + self.reset_state_own() def save_state(self) -> FlexState: own_state = super().save_state() @@ -102,21 +107,19 @@ def restore_state(self, state: FlexState) -> None: super().restore_state(state) - def reset_state_deep(self) -> None: - """Resets the state of the inner datasets and then the own state.""" + def restore_sample(self, restore_key: RestoreKey) -> T_sample_out: + assert len(self.datasets) == 1, "Must be implemented by subclass" + return self.dataset.restore_sample(restore_key) + def worker_close(self) -> None: for ds in self.datasets: - if isinstance(ds, BaseWrapperDataset): - ds.reset_state_deep() - else: - ds.reset_state_own() - - self.reset_state_own() + ds.worker_close() + super().worker_close() - @abstractmethod - def reset_state_own(self) -> None: - """Resets the state of the dataset, excl. the inner datasets.""" - ... 
+ def close(self) -> None: + for ds in self.datasets: + ds.close() + super().close() class SampleIndex(Savable): @@ -141,7 +144,9 @@ def get_next(self) -> int: def ctx(self, sample_idx: Optional[int] = None): if sample_idx is None: sample_idx = self.get_next() - assert WorkerConfig.active_worker_config is not None + assert WorkerConfig.active_worker_config is not None, ( + f"WorkerConfig.active_worker_config is None on thread {threading.get_ident()}" + ) WorkerConfig.active_worker_config.worker_push_sample_index(sample_idx) # print(" " * SampleIndex.actives + f"Activated from {type(self.src).__name__}({id(self.src)}) {sample_idx} -> {WorkerConfig.active_worker_config._sample_index_stack}") SampleIndex.actives += 1 @@ -182,7 +187,37 @@ def restore_state(self, state: Optional[int]) -> None: self.current_idx = state -def get_sample_restore_key(sample: Any) -> Optional[Union[str, int]]: +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class WrappedRestoreKey(RestoreKey): + inner: RestoreKey + + +def wrap_sample_restore_key( + sample: T_sample, t: "Type[WrappedRestoreKey]", fail_otherwise: bool = False, **kwargs +) -> T_sample: + """Adds a key to a sample. The sample must be a valid `Sample` or dict containing + __restore_key__, which is a tuple of keys that can be used to restore the inner sample. + This restore key is prepended with the `key`.""" + if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"): + try: + sample.__restore_key__ = t(inner=sample.__restore_key__, **kwargs) + except KeyError: + pass + elif isinstance(sample, dict) and "__restore_key__" in sample: + sample["__restore_key__"] = t(inner=sample["__restore_key__"], **kwargs) + elif fail_otherwise: + raise RuntimeError( + "Did not yield a sample with a restore key, but is marked stateless/deterministic." + ) + return sample + + +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class MultiWrappedRestoreKey(RestoreKey): + inner: tuple[RestoreKey | None, ...] 
+ + +def get_sample_restore_key(sample: Any) -> RestoreKey | None: """Gets the restore key from an arbitrary sample.""" if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"): return sample.__restore_key__ diff --git a/src/megatron/energon/wrappers/batch_dataset.py b/src/megatron/energon/wrappers/batch_dataset.py index 8cdfa8eb..5466aad1 100644 --- a/src/megatron/energon/wrappers/batch_dataset.py +++ b/src/megatron/energon/wrappers/batch_dataset.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import inspect +from dataclasses import dataclass from typing import ( Any, Callable, @@ -11,7 +12,6 @@ Iterator, List, Optional, - Tuple, TypeVar, Union, ) @@ -19,12 +19,28 @@ from megatron.energon.errors import ErrorContext, handle_restore_errors from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + MultiWrappedRestoreKey, + RestoreKey, + SampleIndex, + get_sample_restore_key, +) T_batch = TypeVar("T_batch", covariant=True) T_batch_sample = TypeVar("T_batch_sample", covariant=True) +@dataclass(kw_only=True, slots=True, frozen=True) +class BatchRestoreKey(MultiWrappedRestoreKey): + sample_idx: int + + +@dataclass(kw_only=True, slots=True, frozen=True) +class BatchGenRestoreKey(BatchRestoreKey): + gen_idx: int | None = None + + class BatchDataset(BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_sample, T_batch]): """This dataset wrapper transforms a dataset of samples into a dataset of batches.""" @@ -32,11 +48,12 @@ class BatchDataset(BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_ batcher: Callable[[List[T_batch_sample]], T_batch] drop_last: bool _sample_index: SampleIndex - _generator_sample_keys: Optional[Any] + _generator_sample_keys: Optional[list[Any]] _generator_offset: 
Optional[int] _batch_failure_handler: ErrorContext _savable_fields = ("_sample_index", "_generator_sample_keys", "_generator_offset") + _worker_local_fields = ("_last_batch_failures",) def __init__( self, @@ -78,8 +95,6 @@ def __init__( tolerance=failure_tolerance, ) - self.reset_state_own() - def reset_state_own(self) -> None: self._sample_index = SampleIndex(self.worker_config, src=self) self._generator_sample_keys = None @@ -94,7 +109,7 @@ def len_worker(self, worker_idx: int | None = None) -> int: def __iter__(self) -> Iterator[T_batch]: batch: List[T_batch_sample] = [] - sample_restore_keys = [] + sample_restore_keys: list[RestoreKey | None] = [] if self._generator_sample_keys is not None: sample_restore_keys = self._generator_sample_keys @@ -116,10 +131,11 @@ def __iter__(self) -> Iterator[T_batch]: self._generator_offset = batch_sub_idx + 1 yield set_sample_restore_key( inner_batch_sample, - sample_idx, - batch_sub_idx, - *sample_restore_keys, - src=self, + BatchGenRestoreKey( + sample_idx=sample_idx, + gen_idx=batch_sub_idx, + inner=tuple(sample_restore_keys), + ), ) self._generator_sample_keys = None self._generator_offset = None @@ -143,16 +159,20 @@ def flush() -> Generator[T_batch, None, None]: self._batch_failure_handler.reset() yield set_sample_restore_key( inner_batch_sample, - sample_idx, - batch_sub_idx, - *sample_restore_keys, - src=self, + BatchGenRestoreKey( + sample_idx=sample_idx, + gen_idx=batch_sub_idx, + inner=tuple(sample_restore_keys), + ), ) self._generator_sample_keys = None self._generator_offset = None else: self._batch_failure_handler.reset() - set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) + set_sample_restore_key( + batch_sample, + BatchRestoreKey(sample_idx=sample_idx, inner=tuple(sample_restore_keys)), + ) yield batch_sample sample_restore_keys.clear() @@ -176,42 +196,36 @@ def assert_can_restore(self) -> None: ) super().assert_can_restore() - def restore_sample(self, restore_key: 
Tuple[Union[str, int, tuple], ...]) -> T_batch: + def restore_sample(self, restore_key: RestoreKey) -> T_batch: # We need to store multiple indices to restore a batch. self.assert_can_restore() + assert isinstance(restore_key, BatchRestoreKey) if inspect.isgeneratorfunction(self.batcher): - id, sample_idx, batch_sub_idx, *samples_restore_keys = restore_key - assert id == type(self).__name__ - else: - id, sample_idx, *samples_restore_keys = restore_key - assert id == type(self).__name__ - batch = [self.dataset.restore_sample(inner_idx) for inner_idx in samples_restore_keys] - + assert isinstance(restore_key, BatchGenRestoreKey) + batch = [self.dataset.restore_sample(inner_idx) for inner_idx in restore_key.inner] with handle_restore_errors(self.worker_config.restore_error_handler, batch): - with self._sample_index.ctx(sample_idx): + with SampleIndex(self.worker_config, src=self).ctx(restore_key.sample_idx): batch_sample = self.batcher(batch) if isinstance(batch_sample, Generator): assert inspect.isgeneratorfunction(self.batcher), ( f"Generator in {self.batcher} but not marked as such." 
) + assert isinstance(restore_key, BatchGenRestoreKey) for cur_batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( - self._sample_index.iter_ctx(batch_sample, sample_idx) + SampleIndex(self.worker_config, src=self).iter_ctx( + batch_sample, restore_key.sample_idx + ) ): - if cur_batch_sub_idx == batch_sub_idx: + if cur_batch_sub_idx == restore_key.gen_idx: return set_sample_restore_key( inner_batch_sample, - sample_idx, - batch_sub_idx, - *samples_restore_keys, - src=self, + restore_key, ) - assert False, f"Batch sub-index {batch_sub_idx} not found in batch" + assert False, f"Batch sub-index {restore_key.gen_idx} not found in batch" else: return set_sample_restore_key( batch_sample, - sample_idx, - *samples_restore_keys, - src=self, + restore_key, ) def config(self) -> Dict[str, Any]: diff --git a/src/megatron/energon/wrappers/blend_dataset.py b/src/megatron/energon/wrappers/blend_dataset.py index e1ed5738..c8f49b5d 100644 --- a/src/megatron/energon/wrappers/blend_dataset.py +++ b/src/megatron/energon/wrappers/blend_dataset.py @@ -1,18 +1,29 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause +from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Sequence, Tuple, TypeVar import torch -from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key +from megatron.energon.flavors.base_dataset import SavableDataset from megatron.energon.rng import WorkerRng from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + RestoreKey, + WrappedRestoreKey, + wrap_sample_restore_key, +) T_sample = TypeVar("T_sample") +@dataclass(kw_only=True, slots=True, frozen=True) +class BlendRestoreKey(WrappedRestoreKey): + dataset_idx: int + + class BlendDataset(BaseWrapperDataset[T_sample, T_sample]): """ This dataset wrapper blends multiple iterable datasets together give a weighting. @@ -46,7 +57,6 @@ def __init__( super().__init__(self.datasets, worker_config=worker_config) self.dataset_weights = dataset_weights - self.reset_state_own() def reset_state_own(self) -> None: self._worker_rng = WorkerRng(self.worker_config) @@ -105,10 +115,14 @@ def __iter__(self) -> Iterator[T_sample]: if all(dataset_iter is None for dataset_iter in dataset_iters): break else: - yield add_sample_restore_key(sample, ds_idx, src=self) + yield wrap_sample_restore_key(sample, BlendRestoreKey, dataset_idx=ds_idx) self.exhausted = [False] * len(self.dataset_weights) + def restore_sample(self, restore_key: RestoreKey) -> T_sample: + assert isinstance(restore_key, BlendRestoreKey) + return self.datasets[restore_key.dataset_idx].restore_sample(restore_key.inner) + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/buffer.py b/src/megatron/energon/wrappers/buffer.py index 09da5739..0d2eb2b8 100644 --- a/src/megatron/energon/wrappers/buffer.py +++ b/src/megatron/energon/wrappers/buffer.py @@ -10,32 +10,37 @@ List, Optional, 
Sequence, - Tuple, TypeVar, Union, ) -from megatron.energon.flavors.base_dataset import FlexState, SavableDataset +from megatron.energon.edataclass import edataclass +from megatron.energon.flavors.base_dataset import RestoreKey, SavableDataset +from megatron.energon.savable import Savable from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset, get_sample_restore_key +from megatron.energon.wrappers.base import get_sample_restore_key T_sample = TypeVar("T_sample") -class SavableSampleBuffer(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]): - """A buffer of samples, savable.""" +@edataclass +class SavableSampleBufferState: + """State of a SavableSampleBuffer.""" - _buffer: List[T_sample] - _restore_keys: List[Tuple[Union[str, int, tuple], ...]] + restore_keys: list[RestoreKey | None] + + +class SavableSampleBuffer(Savable, Generic[T_sample]): + """A buffer of samples, savable. State is shared, create a state-local instance.""" + + _buffer: list[T_sample] + _restore_keys: list[RestoreKey | None] - _savable_fields = ("_restore_keys",) _restore_pending: bool = False def __init__(self, dataset: SavableDataset[T_sample], *, worker_config: WorkerConfig): - super().__init__(dataset, worker_config=worker_config) - self.reset_state_own() - - def reset_state_own(self) -> None: + self.dataset = dataset + self.worker_config = worker_config self._buffer = [] self._restore_keys = [] @@ -44,7 +49,7 @@ def worker_start(self) -> None: assert len(self._buffer) == 0 self._restore_pending = False for restore_key in self._restore_keys: - self._buffer.append(self.restore_sample(restore_key)) + self._buffer.append(self.dataset.restore_sample(restore_key)) assert len(self._buffer) == len(self._restore_keys) def append(self, sample: T_sample) -> T_sample: @@ -67,12 +72,12 @@ def pop(self, index: int) -> T_sample: self._restore_keys.pop(index) return self._buffer.pop(index) - def flush(self) -> Tuple[List[T_sample], Tuple[Any, 
...]]: + def flush(self) -> tuple[list[T_sample], tuple[RestoreKey | None, ...]]: buffer = list(self._buffer) - restore_key = tuple(self._restore_keys) + restore_keys = tuple(self._restore_keys) self._buffer.clear() self._restore_keys.clear() - return buffer, restore_key + return buffer, restore_keys @property def buffer(self) -> List[T_sample]: @@ -105,28 +110,27 @@ def len_worker(self, worker_idx: int | None = None) -> int: def len_rank(self) -> int: raise NotImplementedError("len_rank is not available for SavableSampleBuffer") - def save_state(self) -> FlexState: + def save_state(self) -> SavableSampleBufferState: # Don't call super().save_state() because we don't want to save the wrapped datasets # Just save the own state - return SavableDataset.save_state(self) + return SavableSampleBufferState(restore_keys=self._restore_keys.copy()) - def restore_state(self, state: FlexState) -> None: + def restore_state(self, state: SavableSampleBufferState) -> None: # Don't call super().restore_state() because we don't want to restore the wrapped datasets # Just restore the own state - SavableDataset.restore_state(self, state) - + self._restore_keys = state.restore_keys.copy() self._restore_pending = True - def restore_key(self) -> Tuple[Union[str, int], ...]: + def restore_key(self) -> tuple[RestoreKey | None, ...]: return tuple(self._restore_keys) def restore_samples( - self, index: Tuple[Union[str, int, tuple], ...] - ) -> Tuple[Tuple[Union[str, int, tuple], ...], List[T_sample]]: + self, index: tuple[RestoreKey | None, ...] 
+ ) -> tuple[tuple[RestoreKey | None, ...], list[T_sample]]: buffer = [] restore_keys = [] for sub_index in index: - sample = self.restore_sample(sub_index) + sample = self.dataset.restore_sample(sub_index) restore_keys.append(get_sample_restore_key(sample)) buffer.append(sample) return tuple(restore_keys), buffer diff --git a/src/megatron/energon/wrappers/concat_dataset.py b/src/megatron/energon/wrappers/concat_dataset.py index 83e35660..38c7df7d 100644 --- a/src/megatron/energon/wrappers/concat_dataset.py +++ b/src/megatron/energon/wrappers/concat_dataset.py @@ -1,15 +1,25 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause +from dataclasses import dataclass from typing import Any, Dict, Generic, Iterator, TypeVar -from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key +from megatron.energon.flavors.base_dataset import RestoreKey, SavableDataset from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + WrappedRestoreKey, + wrap_sample_restore_key, +) T_sample = TypeVar("T_sample") +@dataclass(kw_only=True, slots=True, frozen=True) +class ConcatRestoreKey(WrappedRestoreKey): + dataset_idx: int + + class ConcatDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]): """ This dataset wrapper concatenates multiple iterable datasets together. 
The datasets must be @@ -35,12 +45,16 @@ def len_worker(self, worker_idx: int | None = None) -> int: def __iter__(self) -> Iterator[T_sample]: for ds_idx, dataset in enumerate(self.datasets): for sample in dataset: - yield add_sample_restore_key( + yield wrap_sample_restore_key( sample, - ds_idx, - src=self, + ConcatRestoreKey, + dataset_idx=ds_idx, ) + def restore_sample(self, restore_key: RestoreKey) -> T_sample: + assert isinstance(restore_key, ConcatRestoreKey) + return self.datasets[restore_key.dataset_idx].restore_sample(restore_key.inner) + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/epochize_dataset.py b/src/megatron/energon/wrappers/epochize_dataset.py index 3f6f71a0..85fc7727 100644 --- a/src/megatron/energon/wrappers/epochize_dataset.py +++ b/src/megatron/energon/wrappers/epochize_dataset.py @@ -43,8 +43,6 @@ def __init__( self.length = length self._active_iter = None - self.reset_state_own() - def reset_state_own(self) -> None: self._offset = 0 diff --git a/src/megatron/energon/wrappers/file_store_init_wrapper.py b/src/megatron/energon/wrappers/file_store_init_wrapper.py new file mode 100644 index 00000000..d47aa951 --- /dev/null +++ b/src/megatron/energon/wrappers/file_store_init_wrapper.py @@ -0,0 +1,77 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: BSD-3-Clause + +from typing import Any, Dict, Generic, Iterator, Sequence, TypeVar + +from megatron.energon.cache.base import FileStore +from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.state import FlexState +from megatron.energon.worker import WorkerConfig +from megatron.energon.wrappers.base import BaseWrapperDataset + +T_sample = TypeVar("T_sample") + + +class FileStoreInitWrapper(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]): + """This dataset wraps a primary dataset, and additional auxiliary datasets for (de)initialization.""" + + auxiliary_datasets: Sequence[FileStore] + + def __init__( + self, + dataset: SavableDataset[T_sample], + *, + auxiliary_datasets: Sequence[FileStore] = (), + worker_config: WorkerConfig, + ): + """Construct the auxiliary datasets dataset, which wraps a primary dataset and additional + auxiliary datasets for initialization. + + Args: + dataset: The input dataset to wrap + auxiliary_datasets: The additional datasets to (de)initialize + worker_config: The worker configuration + """ + super().__init__(dataset, worker_config=worker_config) + self.auxiliary_datasets = auxiliary_datasets + + def reset_state_own(self) -> None: + for ds in self.auxiliary_datasets: + ds.worker_init() + cache_pool = self.worker_config.active_worker_cache_pool() + if cache_pool is not None: + cache_pool.worker_init() + + def worker_close(self) -> None: + for ds in self.auxiliary_datasets: + ds.worker_close() + cache_pool = self.worker_config.active_worker_cache_pool() + if cache_pool is not None: + cache_pool.worker_close() + super().worker_close() + + def close(self) -> None: + for ds in self.auxiliary_datasets: + ds.close() + cache_pool = self.worker_config.active_worker_cache_pool() + if cache_pool is not None: + cache_pool.close() + super().close() + + def __iter__(self) -> Iterator[T_sample]: + yield from self.dataset + + def save_state(self) -> FlexState: + # Just delegate, 
make self transparent + return self.dataset.save_state() + + def restore_state(self, state: FlexState): + # Just delegate, make self transparent + return self.dataset.restore_state(state) + + def config(self) -> Dict[str, Any]: + # Transparent logger, it won't change the samples + return self.dataset.config() + + def __str__(self): + return f"FileStoreInitWrapper(auxiliary_datasets={self.auxiliary_datasets}, dataset={self.dataset})" diff --git a/src/megatron/energon/wrappers/filter_dataset.py b/src/megatron/energon/wrappers/filter_dataset.py index f28d84f0..2302c104 100644 --- a/src/megatron/energon/wrappers/filter_dataset.py +++ b/src/megatron/energon/wrappers/filter_dataset.py @@ -42,14 +42,9 @@ def __init__( self.filter_fn = filter_fn self.filter_fn_config = filter_fn_config - self.reset_state_own() - def reset_state_own(self) -> None: self._sample_index = SampleIndex(self.worker_config, src=self) - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - def __iter__(self) -> Iterator[T_sample]: for sample in self.dataset: with self._sample_index.ctx(): diff --git a/src/megatron/energon/wrappers/gc_dataset.py b/src/megatron/energon/wrappers/gc_dataset.py index 70f31688..82fdc375 100644 --- a/src/megatron/energon/wrappers/gc_dataset.py +++ b/src/megatron/energon/wrappers/gc_dataset.py @@ -11,6 +11,7 @@ from torch.distributed.distributed_c10d import reduce_op from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.state import FlexState from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -97,9 +98,6 @@ def __init__( def reset_state_own(self) -> None: return - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - def __iter__(self) -> Iterator[T_sample]: in_worker = torch.utils.data.get_worker_info() is not None if in_worker and not _frozen_cuda_tensors_initialized: @@ 
-122,6 +120,14 @@ def __iter__(self) -> Iterator[T_sample]: if self.freeze: gc.unfreeze() + def save_state(self) -> FlexState: + # Just delegate, make self transparent + return self.dataset.save_state() + + def restore_state(self, state: FlexState): + # Just delegate, make self transparent + return self.dataset.restore_state(state) + def config(self) -> Dict[str, Any]: # This is transparent, no config to be saved (it does not affect the dataset) return self.dataset.config() diff --git a/src/megatron/energon/wrappers/group_batch_dataset.py b/src/megatron/energon/wrappers/group_batch_dataset.py index 41a45b20..d36032af 100644 --- a/src/megatron/energon/wrappers/group_batch_dataset.py +++ b/src/megatron/energon/wrappers/group_batch_dataset.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import inspect +from dataclasses import dataclass from typing import ( Any, Callable, @@ -20,34 +21,116 @@ from megatron.energon.edataclass import edataclass from megatron.energon.errors import ErrorContext, handle_restore_errors from megatron.energon.flavors.base_dataset import ( - FlexState, + RestoreKey, SavableDataset, set_sample_restore_key, ) from megatron.energon.savable import Savable from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex -from megatron.energon.wrappers.buffer import SavableSampleBuffer +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + MultiWrappedRestoreKey, + SampleIndex, +) +from megatron.energon.wrappers.buffer import SavableSampleBuffer, SavableSampleBufferState T_batch = TypeVar("T_batch", covariant=True) T_batch_sample = TypeVar("T_batch_sample", covariant=True) @edataclass -class Bucket(Savable, Generic[T_batch_sample]): +class BucketState: + """State of a bucket. This is used to save and restore the bucket.""" + batch_size: int + samples: SavableSampleBufferState + + +@edataclass +class BucketsState: + """State of the buckets. 
This is used to save and restore the buckets.""" + + buckets: Dict[Hashable, BucketState] + +@edataclass +class Bucket(Savable, Generic[T_batch_sample]): + """A bucket for a GroupBatchDataset. It contains the samples.""" + + batch_size: int samples: SavableSampleBuffer[T_batch_sample] - def save_state(self) -> FlexState: - return FlexState( + def save_state(self) -> BucketState: + return BucketState( batch_size=self.batch_size, samples=self.samples.save_state(), ) - def restore_state(self, state: FlexState): - self.batch_size = state["batch_size"] - self.samples.restore_state(state["samples"]) + def restore_state(self, state: BucketState): + self.batch_size = state.batch_size + self.samples.restore_state(state.samples) + + +class Buckets(Savable, Generic[T_batch_sample]): + """This class manages the buckets for a GroupBatchDataset. It is a savable object, which can be saved and restored.""" + + _dataset: SavableDataset[T_batch_sample] + _worker_config: WorkerConfig + + _buckets: Dict[Hashable, Bucket[T_batch_sample]] + + def __init__(self, dataset: SavableDataset[T_batch_sample], worker_config: WorkerConfig): + self._dataset = dataset + self._worker_config = worker_config + self._buckets = {} + + def save_state(self) -> BucketsState: + return BucketsState( + buckets={key: bucket.save_state() for key, bucket in self._buckets.items()} + ) + + def restore_state(self, state: BucketsState): + self._buckets = { + key: Bucket( + batch_size=-1, + samples=SavableSampleBuffer(self._dataset, worker_config=self._worker_config), + ) + for key, bucket in state.buckets.items() + } + for key, bucket in self._buckets.items(): + bucket.restore_state(state.buckets[key]) + + def get(self, key: Hashable, batch_size: int | None) -> Bucket[T_batch_sample]: + """Get a bucket for a given key. 
If the bucket does not exist, create it.""" + bucket = self._buckets.get(key) + if bucket is None: + assert batch_size is not None + self._buckets[key] = bucket = Bucket( + batch_size=batch_size, + samples=SavableSampleBuffer(self._dataset, worker_config=self._worker_config), + ) + else: + assert bucket.batch_size == batch_size, ( + f"Got different batch size for group {key}: {bucket.batch_size} != {batch_size}." + ) + return bucket + + def flush(self) -> Generator[Bucket[T_batch_sample], None, None]: + """Yield all buckets and clear afterwards.""" + yield from self._buckets.values() + self._buckets.clear() + + def clear(self): + self._buckets.clear() + + def worker_start(self): + for bucket in self._buckets.values(): + bucket.samples.worker_start() + + +@dataclass(kw_only=True, slots=True, frozen=True) +class GroupBatchRestoreKey(MultiWrappedRestoreKey): + sample_idx: int class GroupBatchDataset( @@ -66,10 +149,13 @@ class GroupBatchDataset( drop_last: bool _group_key_sample_index: SampleIndex _batch_sample_index: SampleIndex - _buckets: Dict[Hashable, Bucket[T_batch_sample]] + _buckets: Buckets _batch_failure_handler: ErrorContext _group_key_failure_handler: ErrorContext + _savable_fields = ("_group_key_sample_index", "_batch_sample_index", "_buckets") + _worker_local_fields = ("_last_batch_failures",) + def __init__( self, dataset: SavableDataset[T_batch_sample], @@ -114,8 +200,6 @@ def __init__( tolerance=failure_tolerance, ) - self.reset_state_own() - assert not inspect.isgeneratorfunction(batcher), ( f"Batcher {batcher} must not be a generator function for grouped batching." ) @@ -123,24 +207,18 @@ def __init__( def reset_state_own(self) -> None: self._group_key_sample_index = SampleIndex(self.worker_config, src=self) self._batch_sample_index = SampleIndex(self.worker_config, src=self) - self._buckets = {} + self._buckets = Buckets(self.dataset, self.worker_config) def len_worker(self, worker_idx: int | None = None) -> int: # Return an upper bound. 
This is for sure not correct. return self.dataset.len_worker(worker_idx) def __iter__(self) -> Iterator[T_batch]: - buckets = self._buckets - - if buckets is None: - buckets = self._buckets = dict() - # Load saved state if available - for bucket in buckets.values(): - bucket.samples.worker_start() + self._buckets.worker_start() # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial GroupBatchDataset state:\n", end="") - # for bucket_key, bucket in buckets.items(): + # for bucket_key, bucket in self._buckets._buckets.items(): # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{bucket_key}] (bs={bucket.batch_size}, len(samples)={len(bucket.samples)}):\n", end="") # bucket.samples.debug_print(" ") # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial done\n", end="") @@ -148,7 +226,7 @@ def __iter__(self) -> Iterator[T_batch]: def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: # Debug print the state # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flush GroupBatchDataset state:\n", end="") - # for dbg_bucket_key, dbg_bucket in buckets.items(): + # for dbg_bucket_key, dbg_bucket in self._buckets._buckets.items(): # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{dbg_bucket_key}{'*' if dbg_bucket_key == bucket_key else ''}] (bs={dbg_bucket.batch_size}, len(samples)={len(dbg_bucket.samples)}):\n", end="") # dbg_bucket.samples.debug_print(" ") batch_items, sample_restore_keys = bucket.samples.flush() @@ -160,7 +238,10 @@ def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: f"Batcher {self.batcher} returned a generator, which is not supported for grouped batching yet." 
) self._batch_failure_handler.reset() - set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) + set_sample_restore_key( + batch_sample, + GroupBatchRestoreKey(sample_idx=sample_idx, inner=sample_restore_keys), + ) yield batch_sample # Add samples to the buckets @@ -174,48 +255,18 @@ def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: ) if self.fixed_batch_size is not None: batch_size = self.fixed_batch_size - bucket = buckets.get(bucket_key) - if bucket is None: - assert batch_size is not None - buckets[bucket_key] = bucket = Bucket( - batch_size=batch_size, - samples=SavableSampleBuffer(self.dataset, worker_config=self.worker_config), - ) - else: - assert bucket.batch_size == batch_size, ( - f"Got different batch size for group {bucket_key}: {bucket.batch_size} != {batch_size}." - ) + bucket = self._buckets.get(bucket_key, batch_size) bucket.samples.append(sample) if bucket.samples.len_worker() >= bucket.batch_size: yield from flush(bucket) # Flush out last samples if not self.drop_last: - for bucket in buckets.values(): + for bucket in self._buckets.flush(): if bucket.samples.len_worker() > 0: yield from flush(bucket) # Clear the buckets self._buckets.clear() - def save_state(self) -> FlexState: - return FlexState( - bucket_sample_index=self._group_key_sample_index.save_state(), - batch_sample_index=self._batch_sample_index.save_state(), - buckets={key: bucket.save_state() for key, bucket in self._buckets.items()}, - **super().save_state(), - ) - - def restore_state(self, state: FlexState) -> None: - super().restore_state(state) - - self._group_key_sample_index.restore_state(state["bucket_sample_index"]) - self._batch_sample_index.restore_state(state["batch_sample_index"]) - for key, bucket_state in state["buckets"].items(): - self._buckets[key] = Bucket( - batch_size=-1, - samples=SavableSampleBuffer(self.dataset, worker_config=self.worker_config), - ) - self._buckets[key].restore_state(bucket_state) - def 
can_restore_sample(self) -> bool: return super().can_restore_sample() and self.batcher_stateless @@ -225,16 +276,15 @@ def assert_can_restore(self) -> None: ) super().assert_can_restore() - def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_batch: + def restore_sample(self, index: RestoreKey) -> T_batch: self.assert_can_restore() - id, sample_idx, *sample_restore_keys = index - assert id == type(self).__name__ - batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys] + assert isinstance(index, GroupBatchRestoreKey) + batch = [self.dataset.restore_sample(inner_idx) for inner_idx in index.inner] with handle_restore_errors(self.worker_config.restore_error_handler, batch): - with self._batch_sample_index.ctx(sample_idx): + with SampleIndex(self.worker_config, src=self).ctx(index.sample_idx): batch_sample = self.batcher(batch) - set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) + set_sample_restore_key(batch_sample, index) return batch_sample diff --git a/src/megatron/energon/wrappers/iter_map_dataset.py b/src/megatron/energon/wrappers/iter_map_dataset.py index a595e83b..40bab8b4 100644 --- a/src/megatron/energon/wrappers/iter_map_dataset.py +++ b/src/megatron/energon/wrappers/iter_map_dataset.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause +from dataclasses import dataclass from typing import ( Any, Callable, @@ -9,7 +10,6 @@ Generic, Iterator, Optional, - Tuple, TypeVar, Union, ) @@ -17,14 +17,25 @@ from torch.utils.data import IterableDataset from megatron.energon.errors import ErrorContext, handle_restore_errors -from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key +from megatron.energon.flavors.base_dataset import RestoreKey, SavableDataset, set_sample_restore_key from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + MultiWrappedRestoreKey, + SampleIndex, + get_sample_restore_key, +) T_sample = TypeVar("T_sample") T_sample_out = TypeVar("T_sample_out") +@dataclass(kw_only=True, slots=True, frozen=True) +class IterMapRestoreKey(MultiWrappedRestoreKey): + sample_idx: int + iter_idx: int + + class IterMapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T_sample_out]): """This dataset wrapper applies a custom function to transform the stream of samples and yield a new stream of samples. 
@@ -80,8 +91,6 @@ def __init__( handler=worker_config.global_error_handler, ) - self.reset_state_own() - def reset_state_own(self) -> None: self._sample_index = SampleIndex(self.worker_config, src=self) @@ -96,7 +105,7 @@ def __iter__(self) -> Iterator[T_sample_out]: # This is the sample index within the currently yielded sample iter_idx = 0 sample_idx = 0 - sample_restore_keys = [] + sample_restore_keys: list[RestoreKey | None] = [] def reset_idx_iter() -> Generator[T_sample, None, None]: # Resets the inner sample index @@ -115,10 +124,11 @@ def reset_idx_iter() -> Generator[T_sample, None, None]: for sample_idx, sample in self._sample_index.iter_ctx(self.iter_map_fn(ds_iter)): yield set_sample_restore_key( sample, - sample_idx, - iter_idx, - *sample_restore_keys, - src=self, + IterMapRestoreKey( + sample_idx=sample_idx, + iter_idx=iter_idx, + inner=tuple(sample_restore_keys), + ), ) sample_restore_keys.clear() iter_idx += 1 @@ -133,30 +143,26 @@ def assert_can_restore(self) -> None: ) super().assert_can_restore() - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample: + def restore_sample(self, restore_key: RestoreKey) -> T_sample: self.assert_can_restore() - id, sample_idx, iter_idx, *sample_restore_keys = restore_key - assert id == type(self).__name__ - assert isinstance(iter_idx, int) + assert isinstance(restore_key, IterMapRestoreKey) to_be_mapped = tuple( - self.dataset.restore_sample(inner_index) for inner_index in sample_restore_keys + self.dataset.restore_sample(inner_index) for inner_index in restore_key.inner ) with handle_restore_errors(self.worker_config.restore_error_handler, to_be_mapped): inner_iter = iter(self.iter_map_fn(iter(to_be_mapped))) try: + sample_index = SampleIndex(self.worker_config, src=self) # Skip inner yielded samples to get the correct sample - for skip_idx in range(iter_idx): - with self._sample_index.ctx(sample_idx - iter_idx + skip_idx): + for skip_idx in range(restore_key.iter_idx): + with 
sample_index.ctx(restore_key.sample_idx - restore_key.iter_idx + skip_idx): next(inner_iter) # This is the sample to restore - with self._sample_index.ctx(sample_idx): + with sample_index.ctx(restore_key.sample_idx): sample = next(inner_iter) return set_sample_restore_key( sample, - sample_idx, - iter_idx, - *sample_restore_keys, - src=self, + restore_key, ) except StopIteration: raise RuntimeError( diff --git a/src/megatron/energon/wrappers/limit_dataset.py b/src/megatron/energon/wrappers/limit_dataset.py index 9ee76200..62425ec9 100644 --- a/src/megatron/energon/wrappers/limit_dataset.py +++ b/src/megatron/energon/wrappers/limit_dataset.py @@ -38,7 +38,6 @@ def __init__( super().__init__(dataset, worker_config=worker_config) self.length = length self.reset_after_epoch = reset_after_epoch - self.reset_state_own() def reset_state_own(self) -> None: self.current_offset = 0 @@ -104,10 +103,9 @@ def __iter__(self) -> Iterator[T_sample]: ) # Reset the inner dataset - self.dataset.reset_state_deep() self.current_offset = 0 if self.reset_after_epoch: - self.dataset.reset_state_deep() + self.dataset.reset_state() def worker_has_samples(self) -> bool: return super().worker_has_samples() and self.length > 0 diff --git a/src/megatron/energon/wrappers/log_sample_dataset.py b/src/megatron/energon/wrappers/log_sample_dataset.py index d291fb87..de401357 100644 --- a/src/megatron/energon/wrappers/log_sample_dataset.py +++ b/src/megatron/energon/wrappers/log_sample_dataset.py @@ -5,6 +5,7 @@ from megatron.energon.flavors.base_dataset import SavableDataset from megatron.energon.sample_utils import default_get_batch_keys +from megatron.energon.state import FlexState from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -36,14 +37,9 @@ def __init__( self.get_keys_fn = get_keys_fn self.mode = mode - self.reset_state_own() - def reset_state_own(self) -> None: self._step = 0 - def len_worker(self, worker_idx: int | None = None) 
-> int: - return self.dataset.len_worker(worker_idx) - def _log(self, sample: T_sample) -> None: if self.worker_config.should_log(level=1): log_entry = { @@ -65,6 +61,14 @@ def __iter__(self) -> Iterator[T_sample]: self._step += 1 yield sample + def save_state(self) -> FlexState: + # Just delegate, make self transparent + return self.dataset.save_state() + + def restore_state(self, state: FlexState): + # Just delegate, make self transparent + return self.dataset.restore_state(state) + def config(self) -> Dict[str, Any]: # Transparent logger, it won't change the samples return self.dataset.config() diff --git a/src/megatron/energon/wrappers/map_dataset.py b/src/megatron/energon/wrappers/map_dataset.py index a4b4873b..0b7a4164 100644 --- a/src/megatron/energon/wrappers/map_dataset.py +++ b/src/megatron/energon/wrappers/map_dataset.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import inspect +from dataclasses import dataclass from typing import ( Any, Callable, @@ -10,7 +11,6 @@ Generic, Iterator, Optional, - Tuple, TypeVar, Union, ) @@ -19,14 +19,34 @@ ErrorContext, handle_restore_errors, ) -from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key +from megatron.energon.flavors.base_dataset import ( + RestoreKey, + SavableDataset, + set_sample_restore_key, +) from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + SampleIndex, + WrappedRestoreKey, + get_sample_restore_key, + wrap_sample_restore_key, +) T_sample = TypeVar("T_sample") T_sample_out = TypeVar("T_sample_out") +@dataclass(kw_only=True, slots=True, frozen=True) +class MapRestoreKey(WrappedRestoreKey): + sample_idx: int + + +@dataclass(kw_only=True, slots=True, frozen=True) +class MapGenRestoreKey(MapRestoreKey): + gen_idx: int + + class MapDataset(BaseWrapperDataset[T_sample, T_sample_out], 
Generic[T_sample, T_sample_out]): """This dataset wrapper applies a custom function to transform each sample.""" @@ -34,8 +54,8 @@ class MapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T stateless_map_fn: bool map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] _sample_index: SampleIndex - _generator_sample_key: Optional[Any] - _generator_offset: Optional[int] + _generator_sample_key: RestoreKey | None + _generator_offset: int | None _map_failure_handler: ErrorContext _savable_fields = ( @@ -44,6 +64,8 @@ class MapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T "_generator_offset", ) + _worker_local_fields = ("_last_map_failures",) + def __init__( self, dataset: SavableDataset[T_sample], @@ -89,9 +111,6 @@ def reset_state_own(self) -> None: self._generator_sample_key = None self._generator_offset = None - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - def __iter__(self) -> Iterator[T_sample_out]: if self._generator_sample_key is not None: assert self._generator_offset is not None @@ -111,11 +130,11 @@ def __iter__(self) -> Iterator[T_sample_out]: # Skip other samples if idx >= target_offset: self._generator_offset = idx + 1 - yield add_sample_restore_key( + yield wrap_sample_restore_key( inner_sample, - sample_idx, - idx, - src=self, + MapGenRestoreKey, + sample_idx=sample_idx, + gen_idx=idx, ) self._generator_sample_key = None self._generator_offset = None @@ -138,20 +157,20 @@ def __iter__(self) -> Iterator[T_sample_out]: ): self._generator_offset = idx + 1 self._map_failure_handler.reset() - yield add_sample_restore_key( + yield wrap_sample_restore_key( inner_sample, - sample_idx, - idx, - src=self, + MapGenRestoreKey, + sample_idx=sample_idx, + gen_idx=idx, ) self._generator_sample_key = None self._generator_offset = None else: self._map_failure_handler.reset() - yield add_sample_restore_key( + yield wrap_sample_restore_key( 
mapped_sample, - sample_idx, - src=self, + MapRestoreKey, + sample_idx=sample_idx, ) def can_restore_sample(self) -> bool: @@ -163,36 +182,37 @@ def assert_can_restore(self) -> None: ) super().assert_can_restore() - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample_out: + def restore_sample(self, restore_key: RestoreKey) -> T_sample_out: self.assert_can_restore() + assert isinstance(restore_key, MapRestoreKey), ( + f"Expected MapRestoreKey, got {type(restore_key)}" + ) if inspect.isgeneratorfunction(self.map_fn): - id, sample_idx, local_idx = restore_key[:3] - assert id == type(self).__name__ - restore_key = restore_key[3:] - assert isinstance(local_idx, int) - else: - id, sample_idx = restore_key[:2] - assert id == type(self).__name__ - restore_key = restore_key[2:] - inner_sample = self.dataset.restore_sample(restore_key) - + assert isinstance(restore_key, MapGenRestoreKey) + inner_sample = self.dataset.restore_sample(restore_key.inner) with handle_restore_errors(self.worker_config.restore_error_handler, inner_sample): - with self._sample_index.ctx(sample_idx): + with SampleIndex(self.worker_config, src=self).ctx(restore_key.sample_idx): mapped_sample = self.map_fn(inner_sample) if isinstance(mapped_sample, Generator): assert inspect.isgeneratorfunction(self.map_fn), ( f"Generator in {self.map_fn} but not marked as such." ) + assert isinstance(restore_key, MapGenRestoreKey) for idx, (sample_idx, res_sample) in enumerate( - self._sample_index.iter_ctx(mapped_sample, sample_idx) + SampleIndex(self.worker_config, src=self).iter_ctx( + mapped_sample, restore_key.sample_idx + ) ): - if idx == local_idx: - return add_sample_restore_key(res_sample, sample_idx, local_idx, src=self) + if idx == restore_key.gen_idx: + return set_sample_restore_key( + res_sample, + restore_key, + ) assert False, ( - "Generator did not yield enough samples, but is marked stateless/deterministic." 
+ f"Generator sub-index {restore_key.gen_idx} not yielded by {self.map_fn}" ) else: - return add_sample_restore_key(mapped_sample, sample_idx, src=self) + return set_sample_restore_key(mapped_sample, restore_key) def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/mix_batch_dataset.py b/src/megatron/energon/wrappers/mix_batch_dataset.py index 7b12aca2..eb6d9b69 100644 --- a/src/megatron/energon/wrappers/mix_batch_dataset.py +++ b/src/megatron/energon/wrappers/mix_batch_dataset.py @@ -117,9 +117,6 @@ def __init__( def reset_state_own(self) -> None: return - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - def __iter__(self) -> Iterator[T_batch]: yield from self.dataset diff --git a/src/megatron/energon/wrappers/packing_dataset.py b/src/megatron/energon/wrappers/packing_dataset.py index ee10e6c2..c52b8da1 100644 --- a/src/megatron/energon/wrappers/packing_dataset.py +++ b/src/megatron/energon/wrappers/packing_dataset.py @@ -3,6 +3,7 @@ import contextlib import inspect +from dataclasses import dataclass from typing import ( Any, Callable, @@ -18,12 +19,19 @@ from megatron.energon.errors import ErrorContext, handle_restore_errors from megatron.energon.flavors.base_dataset import ( + RestoreKey, SavableDataset, - add_sample_restore_key, set_sample_restore_key, ) from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + MultiWrappedRestoreKey, + SampleIndex, + WrappedRestoreKey, + get_sample_restore_key, + wrap_sample_restore_key, +) from megatron.energon.wrappers.buffer import SavableSampleBuffer T_sample = TypeVar("T_sample") @@ -31,6 +39,21 @@ T_batch_sample = TypeVar("T_batch_sample") +@dataclass(kw_only=True, slots=True, frozen=True) +class EncodePackRestoreKey(WrappedRestoreKey): + sample_idx: int + + 
+@dataclass(kw_only=True, slots=True, frozen=True) +class PackingRestoreKey(MultiWrappedRestoreKey): + pack_idx: int + + +@dataclass(kw_only=True, slots=True, frozen=True) +class PackingGenRestoreKey(PackingRestoreKey): + gen_idx: int + + class PackingDataset( BaseWrapperDataset[T_sample, T_batch_sample], Generic[T_sample, T_encoded_sample, T_batch_sample], @@ -82,6 +105,12 @@ class PackingDataset( "_final_packing_sample_index", ) + _worker_local_fields = ( + "_last_pre_pack_failures", + "_last_final_pack_failures", + "_last_sample_encoder_failures", + ) + def __init__( self, dataset: SavableDataset[T_sample], @@ -232,10 +261,10 @@ def encode_pack_samples(pack: List[T_sample]) -> List[T_encoded_sample]: assert not isinstance(encoded_sample, Generator), "Generator not supported" self._sample_encoder_failure_handler.reset() encoded_pack.append( - add_sample_restore_key( + wrap_sample_restore_key( encoded_sample, - encode_idx, - src=self, + EncodePackRestoreKey, + sample_idx=encode_idx, ) ) return encoded_pack @@ -290,18 +319,15 @@ def next_final_pack() -> Generator[T_batch_sample, None, None]: self._final_pack_failure_handler.reset() yield set_sample_restore_key( inner_batch_sample, - pack_idx, - pack_sub_idx, - *pack_restore_keys, - src=self, + PackingGenRestoreKey( + pack_idx=pack_sub_idx, gen_idx=pack_sub_idx, inner=pack_restore_keys + ), ) else: self._final_pack_failure_handler.reset() yield set_sample_restore_key( final_packed_sample, - pack_idx, - *pack_restore_keys, - src=self, + PackingRestoreKey(pack_idx=pack_idx, inner=pack_restore_keys), ) # Main loop: @@ -363,57 +389,48 @@ def assert_can_restore(self): ) super().assert_can_restore() - def restore_sample(self, restore_key: Any) -> T_sample: + def restore_sample(self, restore_key: RestoreKey) -> T_sample: # We need to store multiple indices to restore a batch. 
self.assert_can_restore() + assert isinstance(restore_key, PackingRestoreKey) if inspect.isgeneratorfunction(self.final_packer): - id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key - id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key - assert id == type(self).__name__ - else: - id, pack_idx, *pack_restore_keys = restore_key - id, pack_idx, *pack_restore_keys = restore_key - assert id == type(self).__name__ + assert isinstance(restore_key, PackingGenRestoreKey) pack = [] - for inner_idx in pack_restore_keys: + for inner_key in restore_key.inner: if self.sample_encoder is not None: - id, sample_idx, *inner_idx = inner_idx - assert id == type(self).__name__ - id, sample_idx, *inner_idx = inner_idx - assert id == type(self).__name__ - assert isinstance(sample_idx, int) - sample = self.dataset.restore_sample(inner_idx) + assert isinstance(inner_key, EncodePackRestoreKey) + encode_key = inner_key + inner_key = inner_key.inner + sample = self.dataset.restore_sample(inner_key) if self.sample_encoder is not None: with handle_restore_errors(self.worker_config.restore_error_handler, sample): - with self._sample_encoder_sample_index.ctx(sample_idx): + with SampleIndex(self.worker_config, src=self).ctx(encode_key.sample_idx): sample = self.sample_encoder(sample) assert not isinstance(sample, Generator), "Generator not supported" - sample = add_sample_restore_key(sample, sample_idx, src=self) - + sample = set_sample_restore_key(sample, encode_key) pack.append(sample) - with handle_restore_errors(self.worker_config.restore_error_handler, pack): - with self._final_packing_sample_index.ctx(pack_idx): + with SampleIndex(self.worker_config, src=self).ctx(restore_key.pack_idx): final_pack = self.final_packer(pack) if isinstance(final_pack, Generator): assert inspect.isgeneratorfunction(self.final_packer), ( f"Generator in {self.final_packer} but not marked as such." 
) - for cur_batch_sub_idx, (pack_idx, inner_batch_sample) in enumerate( - self._final_packing_sample_index.iter_ctx(final_pack, pack_idx) + assert isinstance(restore_key, PackingGenRestoreKey) + for pack_sub_idx, (pack_idx, inner_batch_sample) in enumerate( + SampleIndex(self.worker_config, src=self).iter_ctx( + final_pack, restore_key.pack_idx + ) ): - if cur_batch_sub_idx == pack_sub_idx: + if pack_sub_idx == restore_key.gen_idx: return set_sample_restore_key( inner_batch_sample, - pack_idx, - pack_sub_idx, - *pack_restore_keys, - src=self, + restore_key, ) assert False, f"Pack sub-index {pack_sub_idx} not found in pack" else: - return set_sample_restore_key(final_pack, pack_idx, *pack_restore_keys, src=self) + return set_sample_restore_key(final_pack, restore_key) def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/repeat_dataset.py b/src/megatron/energon/wrappers/repeat_dataset.py index eb2298cd..f63010e9 100644 --- a/src/megatron/energon/wrappers/repeat_dataset.py +++ b/src/megatron/energon/wrappers/repeat_dataset.py @@ -41,8 +41,6 @@ def __init__( self.repeats = repeats self.restart = restart - self.reset_state_own() - def reset_state_own(self) -> None: self._repetition = 0 self._index = 0 diff --git a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py index 4d7bbc24..ea85bda2 100644 --- a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py +++ b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py @@ -1,7 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Dict, Generic, Iterator, Tuple, TypeVar, Union +from typing import Any, Dict, Generic, Iterator, TypeVar from megatron.energon.flavors.base_dataset import SavableDataset from megatron.energon.rng import WorkerRng @@ -19,7 +19,7 @@ class ShuffleBufferDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sam _worker_rng: WorkerRng _active_buffer: SavableSampleBuffer[T_sample] - _savable_fields = ("_active_buffer", "_worker_rng") + _savable_fields = ("_worker_rng", "_active_buffer") def __init__( self, @@ -31,15 +31,11 @@ def __init__( """Create a shuffle buffer for the dataset.""" super().__init__(dataset, worker_config=worker_config) self.size = size - self.reset_state_own() def reset_state_own(self) -> None: self._worker_rng = WorkerRng(self.worker_config) self._active_buffer = SavableSampleBuffer(self.dataset, worker_config=self.worker_config) - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - def __iter__(self) -> Iterator[T_sample]: self._active_buffer.worker_start() it = iter(self._active_buffer.append_iter()) @@ -56,9 +52,6 @@ def __iter__(self) -> Iterator[T_sample]: pop_idx = self._worker_rng.randbelow(self._active_buffer.len_worker()) yield self._active_buffer.pop(pop_idx) - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample: - return self._active_buffer.restore_sample(restore_key) - def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/watchdog_dataset.py b/src/megatron/energon/wrappers/watchdog_dataset.py index 68ecffea..6bd65cb2 100644 --- a/src/megatron/energon/wrappers/watchdog_dataset.py +++ b/src/megatron/energon/wrappers/watchdog_dataset.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Generic, Iterator, Optional, TypeVar from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.state import 
FlexState from megatron.energon.watchdog import Watchdog from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -15,6 +16,10 @@ class WatchdogDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]): """This dataset wraps another dataset and watches the time it takes to yield samples.""" + timeout_seconds: Optional[float] + initial_timeout_seconds: Optional[float] + fail_on_timeout: bool + def __init__( self, dataset: SavableDataset[T_sample], @@ -41,9 +46,6 @@ def __init__( def reset_state_own(self) -> None: pass - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - def _watchdog_trigger(self) -> None: if self.fail_on_timeout: # Raising an exception here will kill the whole process @@ -68,6 +70,14 @@ def __iter__(self) -> Iterator[T_sample]: ) yield from watchdog.watch_iter(self.dataset) + def save_state(self) -> FlexState: + # Just delegate, make self transparent + return self.dataset.save_state() + + def restore_state(self, state: FlexState): + # Just delegate, make self transparent + return self.dataset.restore_state(state) + def config(self) -> Dict[str, Any]: # Watchdog is transparent, it won't change the samples return self.dataset.config() diff --git a/tests/test_av_decoder.py b/tests/test_av_decoder.py index 02677855..4b3996d2 100644 --- a/tests/test_av_decoder.py +++ b/tests/test_av_decoder.py @@ -9,11 +9,11 @@ import pickle import sys import time -import unittest from pathlib import Path import av import numpy as np +import pytest import torch import torchvision.transforms as transforms @@ -73,185 +73,170 @@ def tensors_close(tensor1: torch.Tensor, tensor2: torch.Tensor, tolerance: float return mae <= tolerance -class TestVideoDecode(unittest.TestCase): - """Test video decoding functionality.""" - - def setUp(self): - """Set up test fixtures.""" - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - 
self.decode_baseline_video_pyav() - self.loaders = [] # Keep track of loaders for cleanup - - def tearDown(self): - """Clean up test fixtures.""" - # Clean up any loaders - for loader in self.loaders: - if hasattr(loader, "_iterator"): - loader._iterator = None - if hasattr(loader, "_shutdown_workers"): - try: - loader._shutdown_workers() - except Exception: - pass - - def decode_baseline_video_pyav(self): - """Load the baseline video using PyAV directly.""" - self.complete_video_tensor = load_video_to_tensor("tests/data/sync_test.mp4") - - def test_decode_all_frames(self): - """Test decoding all frames from a video file.""" - av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) - av_data = av_decoder.get_frames() - video_tensor = av_data.video_clips[0] - - print(video_tensor.shape) - assert (video_tensor == self.complete_video_tensor).all(), ( - "Energon decoded video does not match baseline" +@pytest.fixture +def video_test_setup(): + """Set up test fixtures for video tests.""" + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + complete_video_tensor = load_video_to_tensor("tests/data/sync_test.mp4") + yield complete_video_tensor + + +def test_decode_all_frames(video_test_setup): + """Test decoding all frames from a video file.""" + av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) + av_data = av_decoder.get_frames() + video_tensor = av_data.video_clips[0] + + print(video_tensor.shape) + assert (video_tensor == video_test_setup).all(), "Energon decoded video does not match baseline" + + +def test_decode_video_metadata(video_test_setup): + """Test decoding metadata.""" + expected_metadata = [ + AVMetadata( + video_duration=63.054, + video_num_frames=1891, + video_fps=30.0, + video_width=192, + video_height=108, + audio_duration=63.103, + audio_channels=2, + audio_sample_rate=48000, + ), + AVMetadata( + video_duration=63.03333333333333, + video_num_frames=1891, + video_fps=30.0, + 
video_width=192, + video_height=108, + audio_duration=63.068, + audio_channels=2, + audio_sample_rate=48000, + ), + ] + for video_file, expected_metadata in zip( + ["tests/data/sync_test.mkv", "tests/data/sync_test.mp4"], expected_metadata + ): + av_decoder = AVDecoder(io.BytesIO(Path(video_file).read_bytes())) + assert av_decoder.get_metadata() == expected_metadata, ( + f"Metadata does not match expected metadata for {video_file}" ) - def test_decode_metadata(self): - """Test decoding metadata.""" - expected_metadata = [ - AVMetadata( - video_duration=63.054, - video_num_frames=1891, - video_fps=30.0, - video_width=192, - video_height=108, - audio_duration=63.103, - audio_channels=2, - audio_sample_rate=48000, - ), - AVMetadata( - video_duration=63.03333333333333, - video_num_frames=1891, - video_fps=30.0, - video_width=192, - video_height=108, - audio_duration=63.068, - audio_channels=2, - audio_sample_rate=48000, - ), - ] - for video_file, expected_metadata in zip( - ["tests/data/sync_test.mkv", "tests/data/sync_test.mp4"], expected_metadata - ): - av_decoder = AVDecoder(io.BytesIO(Path(video_file).read_bytes())) - assert av_decoder.get_metadata() == expected_metadata, ( - f"Metadata does not match expected metadata for {video_file}" - ) - - assert av_decoder.get_video_duration(get_frame_count=False) in ( - (expected_metadata.video_duration, None), - (expected_metadata.video_duration, expected_metadata.video_num_frames), - ) - assert av_decoder.get_video_duration(get_frame_count=True) == ( - expected_metadata.video_duration, - expected_metadata.video_num_frames, - ) - - assert av_decoder.get_audio_duration() == expected_metadata.audio_duration - assert av_decoder.get_video_fps() == expected_metadata.video_fps - assert av_decoder.get_audio_samples_per_second() == expected_metadata.audio_sample_rate - - def test_decode_strided_resized(self): - """Test decoding a subset of frames with resizing.""" - for video_file in ["tests/data/sync_test.mkv", 
"tests/data/sync_test.mp4"]: - print(f"================= Testing {video_file} ==================") - av_decoder = AVDecoder(io.BytesIO(Path(video_file).read_bytes())) - - video_tensor = get_single_frames_uniform( - av_decoder=av_decoder, - num_frames=64, - video_out_frame_size=(224, 224), - ) - - # Get strided frames from baseline complete video tensor - strided_baseline_tensor = self.complete_video_tensor[ - np.linspace(0, self.complete_video_tensor.shape[0] - 1, 64, dtype=int).tolist() - ] - # Now resize the baseline frames - resize = transforms.Resize((224, 224)) - strided_resized_baseline_tensor = resize(strided_baseline_tensor) - - # We allow small numerical differences due to different resize implementations - assert tensors_close(video_tensor, strided_resized_baseline_tensor, tolerance=0.01), ( - "Energon decoded video does not match baseline" - ) - - def test_video_audio_sync(self): - """Test decoding video frames and audio clips together.""" - av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) - - # Extract a single frame every 2 seconds and an audio clip (0.05 seconds long) at the same time. - # We extract the frames from the sync video that shows the full white circle on the left, - # when the click sound occurs. - # Note that the click sound is actually off by 0.022 secs in the original video, - # I verified this in Davinci Resolve. 
- av_data = av_decoder.get_clips( - video_clip_ranges=[(a * 2 + 1 / 30, a * 2 + 1 / 30) for a in range(65)], - audio_clip_ranges=[(a * 2 + 1 / 30, a * 2 + 1 / 30 + 0.05) for a in range(65)], - video_unit="seconds", - audio_unit="seconds", - video_out_frame_size=None, + assert av_decoder.get_video_duration(get_frame_count=False) in ( + (expected_metadata.video_duration, None), + (expected_metadata.video_duration, expected_metadata.video_num_frames), ) - - # We drop the first two extracted frames because the click sequence hasn't started yet - video_clips = av_data.video_clips[2:] - audio_clips = av_data.audio_clips[2:] - # Then we check that the first extracted frame is all white in the area (18, 18, 55, 55) - # Image.fromarray(video_clips[0][0, :, 18:55, 18:55].numpy().transpose(1,2,0)).save('circ.png') - assert (video_clips[0][0, :, 18:55, 18:55] > 250).all(), ( - "First extracted frame is not all white in the area (18, 18, 55, 55)" + assert av_decoder.get_video_duration(get_frame_count=True) == ( + expected_metadata.video_duration, + expected_metadata.video_num_frames, ) - # Check that all the video frames are the same (close value) - for video_clip in video_clips: - assert tensors_close(video_clip, video_clips[0], tolerance=0.01), ( - "All video frames are not the same" - ) + assert av_decoder.get_audio_duration() == expected_metadata.audio_duration + assert av_decoder.get_video_fps() == expected_metadata.video_fps + assert av_decoder.get_audio_samples_per_second() == expected_metadata.audio_sample_rate - # Check that the first audio clip has the click sound - assert (audio_clips[0] > 0.5).any(), "Audio click not found" - # Check that all the audio clips are the same (close value) - for audio_clip in audio_clips: - assert tensors_close(audio_clip, audio_clips[0], tolerance=0.01), ( - "All audio clips are not the same" - ) +def test_decode_strided_resized(video_test_setup): + """Test decoding a subset of frames with resizing.""" + for video_file in 
["tests/data/sync_test.mkv", "tests/data/sync_test.mp4"]: + print(f"================= Testing {video_file} ==================") + av_decoder = AVDecoder(io.BytesIO(Path(video_file).read_bytes())) - def test_pickle_decoder(self): - """Test AVDecoder on a video file can be pickled and unpickled.""" - av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) + video_tensor = get_single_frames_uniform( + av_decoder=av_decoder, + num_frames=64, + video_out_frame_size=(224, 224), + ) - # Get metadata from original decoder - original_metadata = av_decoder.get_metadata() + # Get strided frames from baseline complete video tensor + strided_baseline_tensor = video_test_setup[ + np.linspace(0, video_test_setup.shape[0] - 1, 64, dtype=int).tolist() + ] + # Now resize the baseline frames + resize = transforms.Resize((224, 224)) + strided_resized_baseline_tensor = resize(strided_baseline_tensor) - # Pickle the decoder - pickled_data = pickle.dumps(av_decoder) + # We allow small numerical differences due to different resize implementations + assert tensors_close(video_tensor, strided_resized_baseline_tensor, tolerance=0.01), ( + "Energon decoded video does not match baseline" + ) - # Unpickle the decoder - unpickled_decoder = pickle.loads(pickled_data) - # Verify metadata matches - unpickled_metadata = unpickled_decoder.get_metadata() - assert unpickled_metadata == original_metadata, ( - f"Unpickled metadata {unpickled_metadata} does not match original {original_metadata}" +def test_video_audio_sync(video_test_setup): + """Test decoding video frames and audio clips together.""" + av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) + + # Extract a single frame every 2 seconds and an audio clip (0.05 seconds long) at the same time. + # We extract the frames from the sync video that shows the full white circle on the left, + # when the click sound occurs. 
+ # Note that the click sound is actually off by 0.022 secs in the original video, + # I verified this in Davinci Resolve. + av_data = av_decoder.get_clips( + video_clip_ranges=[(a * 2 + 1 / 30, a * 2 + 1 / 30) for a in range(65)], + audio_clip_ranges=[(a * 2 + 1 / 30, a * 2 + 1 / 30 + 0.05) for a in range(65)], + video_unit="seconds", + audio_unit="seconds", + video_out_frame_size=None, + ) + + # We drop the first two extracted frames because the click sequence hasn't started yet + video_clips = av_data.video_clips[2:] + audio_clips = av_data.audio_clips[2:] + # Then we check that the first extracted frame is all white in the area (18, 18, 55, 55) + # Image.fromarray(video_clips[0][0, :, 18:55, 18:55].numpy().transpose(1,2,0)).save('circ.png') + assert (video_clips[0][0, :, 18:55, 18:55] > 250).all(), ( + "First extracted frame is not all white in the area (18, 18, 55, 55)" + ) + + # Check that all the video frames are the same (close value) + for video_clip in video_clips: + assert tensors_close(video_clip, video_clips[0], tolerance=0.01), ( + "All video frames are not the same" ) - # Verify we can still decode frames from the unpickled decoder - video_tensor = get_single_frames_uniform( - av_decoder=unpickled_decoder, - num_frames=16, - video_out_frame_size=(64, 64), - ) + # Check that the first audio clip has the click sound + assert (audio_clips[0] > 0.5).any(), "Audio click not found" - # Check that we got the expected shape - assert video_tensor.shape == (16, 3, 64, 64), ( - f"Expected shape (16, 3, 64, 64), got {video_tensor.shape}" + # Check that all the audio clips are the same (close value) + for audio_clip in audio_clips: + assert tensors_close(audio_clip, audio_clips[0], tolerance=0.01), ( + "All audio clips are not the same" ) +def test_pickle_decoder(video_test_setup): + """Test AVDecoder on a video file can be pickled and unpickled.""" + av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) + + # Get metadata from 
original decoder + original_metadata = av_decoder.get_metadata() + + # Pickle the decoder + pickled_data = pickle.dumps(av_decoder) + + # Unpickle the decoder + unpickled_decoder = pickle.loads(pickled_data) + + # Verify metadata matches + unpickled_metadata = unpickled_decoder.get_metadata() + assert unpickled_metadata == original_metadata, ( + f"Unpickled metadata {unpickled_metadata} does not match original {original_metadata}" + ) + + # Verify we can still decode frames from the unpickled decoder + video_tensor = get_single_frames_uniform( + av_decoder=unpickled_decoder, + num_frames=16, + video_out_frame_size=(64, 64), + ) + + # Check that we got the expected shape + assert video_tensor.shape == (16, 3, 64, 64), ( + f"Expected shape (16, 3, 64, 64), got {video_tensor.shape}" + ) + + def load_audio_to_tensor(audio_path: str) -> torch.Tensor: """Load an audio file into a tensor using PyAV directly. @@ -271,218 +256,195 @@ def load_audio_to_tensor(audio_path: str) -> torch.Tensor: return audio_tensor -class TestAudioDecode(unittest.TestCase): - """Test audio decoding functionality.""" - - def setUp(self): - """Set up test fixtures.""" - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - self.decode_baseline_audio_pyav() - self.loaders = [] # Keep track of loaders for cleanup - - def tearDown(self): - """Clean up test fixtures.""" - # Clean up any loaders - for loader in self.loaders: - if hasattr(loader, "_iterator"): - loader._iterator = None - if hasattr(loader, "_shutdown_workers"): - try: - loader._shutdown_workers() - except Exception: - pass - - def decode_baseline_audio_pyav(self): - """Load the baseline audio using PyAV directly.""" - self.complete_audio_tensor = load_audio_to_tensor("tests/data/test_audio.flac") - - def test_decode_all_samples(self): - """Test decoding all samples from an audio file.""" - with open("tests/data/test_audio.flac", "rb") as f: - raw_bytes = f.read() - stream = io.BytesIO(raw_bytes) - - av_decoder = 
AVDecoder(stream) - av_data = av_decoder.get_audio() - audio_tensor = av_data.audio_clips[0] - - assert (audio_tensor == self.complete_audio_tensor).all(), ( - "Energon decoded audio does not match baseline" - ) +@pytest.fixture +def audio_test_setup(): + """Set up test fixtures for audio tests.""" + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + complete_audio_tensor = load_audio_to_tensor("tests/data/test_audio.flac") + yield complete_audio_tensor + + +def test_decode_all_samples(audio_test_setup): + """Test decoding all samples from an audio file.""" + with open("tests/data/test_audio.flac", "rb") as f: + raw_bytes = f.read() + stream = io.BytesIO(raw_bytes) + + av_decoder = AVDecoder(stream) + av_data = av_decoder.get_audio() + audio_tensor = av_data.audio_clips[0] + + assert (audio_tensor == audio_test_setup).all(), "Energon decoded audio does not match baseline" + - def test_decode_clips(self): - """Test decoding multiple clips from an audio file.""" - with open("tests/data/test_audio.flac", "rb") as f: - raw_bytes = f.read() - stream = io.BytesIO(raw_bytes) - - av_decoder = AVDecoder(stream) - av_data = get_clips_uniform( - av_decoder=av_decoder, num_clips=5, clip_duration_seconds=3, request_audio=True - ) +def test_decode_clips(audio_test_setup): + """Test decoding multiple clips from an audio file.""" + with open("tests/data/test_audio.flac", "rb") as f: + raw_bytes = f.read() + stream = io.BytesIO(raw_bytes) + + av_decoder = AVDecoder(stream) + av_data = get_clips_uniform( + av_decoder=av_decoder, num_clips=5, clip_duration_seconds=3, request_audio=True + ) + audio_tensor = av_data.audio_clips[0] + audio_sps = av_decoder.get_audio_samples_per_second() + + # Check audio tensor shape (5 clips, channels, 3 seconds at original sample rate) + assert len(av_data.audio_clips) == 5 + assert len(av_data.audio_timestamps) == 5 + assert audio_tensor.shape[1] >= int(3 * audio_sps) + assert audio_tensor.shape[1] <= int(4 * audio_sps) + + +def 
test_decode_wav(audio_test_setup): + """Test decoding a WAV file.""" + # Skip WAV test if file doesn't exist + if not os.path.exists("tests/data/test_audio.wav"): + pytest.skip("WAV test file not found") + return + + with open("tests/data/test_audio.wav", "rb") as f: + raw_bytes = f.read() + stream = io.BytesIO(raw_bytes) + + av_decoder = AVDecoder(stream) + av_data = get_clips_uniform( + av_decoder=av_decoder, num_clips=3, clip_duration_seconds=3, request_audio=True + ) + audio_sps = av_decoder.get_audio_samples_per_second() + + # Check audio tensor shape (3 clips, 2 channels, samples) + expected_samples = int(3 * audio_sps) # 3 seconds at original sample rate + assert all( + audio_tensor.shape == torch.Size([2, expected_samples]) + for audio_tensor in av_data.audio_clips + ), "Energon decoded WAV file has wrong shape." + + +def test_decode_wav_same_shape(audio_test_setup): + """Test decoding a WAV file.""" + # Skip WAV test if file doesn't exist + if not os.path.exists("tests/data/test_audio.wav"): + pytest.skip("WAV test file not found") + return + + with open("tests/data/test_audio.wav", "rb") as f: + raw_bytes = f.read() + stream = io.BytesIO(raw_bytes) + + av_decoder = AVDecoder(stream) + av_data = get_clips_uniform( + av_decoder=av_decoder, + num_clips=10, + clip_duration_seconds=0.9954783485892385, + request_audio=True, + ) + audio_sps = av_decoder.get_audio_samples_per_second() + + print(f"SPS: {audio_sps}") + for audio_tensor in av_data.audio_clips: + print(audio_tensor.shape) + + assert all( + audio_tensor.shape == av_data.audio_clips[0].shape for audio_tensor in av_data.audio_clips + ), "Audio clips have different shapes" + + +def test_wav_decode_against_soundfile(audio_test_setup): + """Test decoding a WAV file against the soundfile library.""" + + try: + import soundfile + except ImportError: + pytest.skip("soundfile library not found") + + with open("tests/data/test_audio.wav", "rb") as f: + raw_bytes = f.read() + stream = io.BytesIO(raw_bytes) + + 
av_decoder = AVDecoder(stream) + av_data = av_decoder.get_clips(audio_clip_ranges=[(0, float("inf"))], audio_unit="samples") + audio_tensor = av_data.audio_clips[0] + + # Load the same audio file using soundfile + + audio_data, _ = soundfile.read("tests/data/test_audio.wav", dtype="int16") + audio_tensor_soundfile = torch.from_numpy(audio_data).transpose(0, 1) + + # Check that the two tensors are close + assert tensors_close(audio_tensor, audio_tensor_soundfile, tolerance=0.01), ( + "Energon decoded audio does not match baseline" + ) + + # Now check partial extraction in the middle of the audio + av_data = av_decoder.get_clips(audio_clip_ranges=[(0.5, 1.0)], audio_unit="seconds") + audio_tensor = av_data.audio_clips[0] + audio_sps = av_decoder.get_audio_samples_per_second() + audio_tensor_soundfile = torch.from_numpy( + audio_data[int(0.5 * audio_sps) : int(1.0 * audio_sps)] + ).transpose(0, 1) + + # Check that the two tensors are close + assert tensors_close(audio_tensor, audio_tensor_soundfile, tolerance=0.01), ( + "Energon decoded audio does not match baseline" + ) + + # Now compare the speed of the two implementations by repeatedly decoding the same audio + num_trials = 100 + + start_time = time.perf_counter() + for _ in range(num_trials): + av_data = av_decoder.get_clips(audio_clip_ranges=[(0, float("inf"))], audio_unit="samples") audio_tensor = av_data.audio_clips[0] - audio_sps = av_decoder.get_audio_samples_per_second() - - # Check audio tensor shape (5 clips, channels, 3 seconds at original sample rate) - assert len(av_data.audio_clips) == 5 - assert len(av_data.audio_timestamps) == 5 - assert audio_tensor.shape[1] >= int(3 * audio_sps) - assert audio_tensor.shape[1] <= int(4 * audio_sps) - - def test_decode_wav(self): - """Test decoding a WAV file.""" - # Skip WAV test if file doesn't exist - if not os.path.exists("tests/data/test_audio.wav"): - self.skipTest("WAV test file not found") - return - - with open("tests/data/test_audio.wav", "rb") as f: - 
raw_bytes = f.read() - stream = io.BytesIO(raw_bytes) - - av_decoder = AVDecoder(stream) - av_data = get_clips_uniform( - av_decoder=av_decoder, num_clips=3, clip_duration_seconds=3, request_audio=True - ) - audio_sps = av_decoder.get_audio_samples_per_second() - - # Check audio tensor shape (3 clips, 2 channels, samples) - expected_samples = int(3 * audio_sps) # 3 seconds at original sample rate - assert all( - audio_tensor.shape == torch.Size([2, expected_samples]) - for audio_tensor in av_data.audio_clips - ), "Energon decoded WAV file has wrong shape." - - def test_decode_wav_same_shape(self): - """Test decoding a WAV file.""" - # Skip WAV test if file doesn't exist - if not os.path.exists("tests/data/test_audio.wav"): - self.skipTest("WAV test file not found") - return - - with open("tests/data/test_audio.wav", "rb") as f: - raw_bytes = f.read() - stream = io.BytesIO(raw_bytes) - - av_decoder = AVDecoder(stream) - av_data = get_clips_uniform( - av_decoder=av_decoder, - num_clips=10, - clip_duration_seconds=0.9954783485892385, - request_audio=True, - ) - audio_sps = av_decoder.get_audio_samples_per_second() + end_time = time.perf_counter() + print(f"AVDecoder time: {end_time - start_time} seconds") - print(f"SPS: {audio_sps}") - for audio_tensor in av_data.audio_clips: - print(audio_tensor.shape) - - assert all( - audio_tensor.shape == av_data.audio_clips[0].shape - for audio_tensor in av_data.audio_clips - ), "Audio clips have different shapes" - - def test_wav_decode_against_soundfile(self): - """Test decoding a WAV file against the soundfile library.""" - - try: - import soundfile - except ImportError: - self.skipTest("soundfile library not found") - - with open("tests/data/test_audio.wav", "rb") as f: - raw_bytes = f.read() - stream = io.BytesIO(raw_bytes) + # Now do the same with soundfile + start_time = time.perf_counter() + for _ in range(num_trials): + audio_data, _ = soundfile.read("tests/data/test_audio.wav", dtype="int16") + audio_tensor_soundfile = 
torch.from_numpy(audio_data).transpose(0, 1) + end_time = time.perf_counter() + print(f"Soundfile time: {end_time - start_time} seconds") - av_decoder = AVDecoder(stream) + start_time = time.perf_counter() + for _ in range(num_trials): av_data = av_decoder.get_clips(audio_clip_ranges=[(0, float("inf"))], audio_unit="samples") audio_tensor = av_data.audio_clips[0] + end_time = time.perf_counter() + print(f"AVDecoder time: {end_time - start_time} seconds") - # Load the same audio file using soundfile - + # Now do the same with soundfile + start_time = time.perf_counter() + for _ in range(num_trials): audio_data, _ = soundfile.read("tests/data/test_audio.wav", dtype="int16") audio_tensor_soundfile = torch.from_numpy(audio_data).transpose(0, 1) - - # Check that the two tensors are close - assert tensors_close(audio_tensor, audio_tensor_soundfile, tolerance=0.01), ( - "Energon decoded audio does not match baseline" - ) - - # Now check partial extraction in the middle of the audio - av_data = av_decoder.get_clips(audio_clip_ranges=[(0.5, 1.0)], audio_unit="seconds") - audio_tensor = av_data.audio_clips[0] - audio_sps = av_decoder.get_audio_samples_per_second() - audio_tensor_soundfile = torch.from_numpy( - audio_data[int(0.5 * audio_sps) : int(1.0 * audio_sps)] - ).transpose(0, 1) - - # Check that the two tensors are close - assert tensors_close(audio_tensor, audio_tensor_soundfile, tolerance=0.01), ( - "Energon decoded audio does not match baseline" + end_time = time.perf_counter() + print(f"Soundfile time: {end_time - start_time} seconds") + + +def test_decode_audio_metadata(audio_test_setup): + """Test decoding metadata.""" + expected_metadata = [ + AVMetadata( + audio_duration=10.0, + audio_channels=1, + audio_sample_rate=32000, + ), + AVMetadata( + audio_duration=12.782585034013605, + audio_channels=2, + audio_sample_rate=44100, + ), + ] + for audio_file, expected_metadata in zip( + ["tests/data/test_audio.flac", "tests/data/test_audio.wav"], expected_metadata + ): 
+ av_decoder = AVDecoder(io.BytesIO(Path(audio_file).read_bytes())) + assert av_decoder.get_metadata() == expected_metadata, ( + f"Metadata does not match expected metadata for {audio_file}: {av_decoder.get_metadata()}" ) - # Now compare the speed of the two implementations by repeatedly decoding the same audio - num_trials = 100 - - start_time = time.perf_counter() - for _ in range(num_trials): - av_data = av_decoder.get_clips( - audio_clip_ranges=[(0, float("inf"))], audio_unit="samples" - ) - audio_tensor = av_data.audio_clips[0] - end_time = time.perf_counter() - print(f"AVDecoder time: {end_time - start_time} seconds") - - # Now do the same with soundfile - start_time = time.perf_counter() - for _ in range(num_trials): - audio_data, _ = soundfile.read("tests/data/test_audio.wav", dtype="int16") - audio_tensor_soundfile = torch.from_numpy(audio_data).transpose(0, 1) - end_time = time.perf_counter() - print(f"Soundfile time: {end_time - start_time} seconds") - - start_time = time.perf_counter() - for _ in range(num_trials): - av_data = av_decoder.get_clips( - audio_clip_ranges=[(0, float("inf"))], audio_unit="samples" - ) - audio_tensor = av_data.audio_clips[0] - end_time = time.perf_counter() - print(f"AVDecoder time: {end_time - start_time} seconds") - - # Now do the same with soundfile - start_time = time.perf_counter() - for _ in range(num_trials): - audio_data, _ = soundfile.read("tests/data/test_audio.wav", dtype="int16") - audio_tensor_soundfile = torch.from_numpy(audio_data).transpose(0, 1) - end_time = time.perf_counter() - print(f"Soundfile time: {end_time - start_time} seconds") - - def test_decode_metadata(self): - """Test decoding metadata.""" - expected_metadata = [ - AVMetadata( - audio_duration=10.0, - audio_channels=1, - audio_sample_rate=32000, - ), - AVMetadata( - audio_duration=12.782585034013605, - audio_channels=2, - audio_sample_rate=44100, - ), - ] - for audio_file, expected_metadata in zip( - ["tests/data/test_audio.flac", 
"tests/data/test_audio.wav"], expected_metadata - ): - av_decoder = AVDecoder(io.BytesIO(Path(audio_file).read_bytes())) - assert av_decoder.get_metadata() == expected_metadata, ( - f"Metadata does not match expected metadata for {audio_file}: {av_decoder.get_metadata()}" - ) - - assert av_decoder.get_audio_duration() == expected_metadata.audio_duration - assert av_decoder.get_audio_samples_per_second() == expected_metadata.audio_sample_rate - - -if __name__ == "__main__": - unittest.main() + assert av_decoder.get_audio_duration() == expected_metadata.audio_duration + assert av_decoder.get_audio_samples_per_second() == expected_metadata.audio_sample_rate diff --git a/tests/test_crudedataset.py b/tests/test_crudedataset.py index 15949c85..345423ec 100644 --- a/tests/test_crudedataset.py +++ b/tests/test_crudedataset.py @@ -11,11 +11,11 @@ import shutil import sys import tempfile -import unittest import warnings from pathlib import Path from typing import List +import pytest import torch import webdataset as wds @@ -149,6 +149,7 @@ class CookingTaskEncoder(DefaultTaskEncoder[TextSample, TextSample, TextBatch, T Cooker(cook_media_metadata, has_subflavors={"crude_type": "media_metadata"}), ] + @stateless def batch(self, samples: List[TextSample]) -> TextBatch: return TextBatch.from_samples( samples, @@ -217,6 +218,7 @@ def pack_selected_samples(self, samples: List[LazyTextSample]) -> TextSample: text=samples[0].txt + "|" + next_txt, ) + @stateless def batch(self, samples: List[TextSample]) -> TextBatch: return TextBatch.from_samples( samples, @@ -250,6 +252,7 @@ def pack_selected_samples(self, samples: List[TextSample]) -> TextSample: assert len(samples) == 1 return samples[0] + @stateless def batch(self, samples: List[TextSample]) -> TextBatch: return TextBatch.from_samples( samples, @@ -269,233 +272,234 @@ def batch(self, samples: List[TextSample]) -> TextBatch: ) -class TestDataset(unittest.TestCase): - # Set up the test fixture - def setUp(self): - 
logging.basicConfig(stream=sys.stderr, level=logging.INFO) - warnings.simplefilter("ignore", ResourceWarning) - - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.dataset_path = Path(self.temp_dir.name) - # self.dataset_path = Path("./test_dataset") - - self.dataset_path.mkdir(exist_ok=True, parents=True) - - (self.dataset_path / "ds1").mkdir(exist_ok=True, parents=True) - (self.dataset_path / "ds2").mkdir(exist_ok=True, parents=True) - - # Create a small dummy captioning dataset - self.create_crude_text_test_dataset(self.dataset_path / "ds1", 0) - self.create_crude_text_test_dataset(self.dataset_path / "ds2", 100) - - self.mds_path = self.dataset_path / "metadataset.yaml" - with open(self.mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: Metadataset", - "splits:", - " train:", - " datasets:", - " - weight: 1", - " path: ds1", - " subflavors:", - " source: metadataset.yaml", - " number: 43", - " mds: mds", - " crude_type: txtpkl", - " shuffle_over_epochs_multiplier: 3", - " - weight: 1", - " path: ds2", - " subflavors:", - " source: metadataset.yaml", - " number: 44", - " mds: mds", - " crude_type: otherpkl", - " val:", - " datasets:", - " - weight: 1", - " path: ds1", - " split_part: train", - " - weight: 1", - " path: ds2", - " split_part: train", - ] - ) +@pytest.fixture +def dataset_path(): + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + # Create a temporary directory + temp_dir = tempfile.TemporaryDirectory() + dataset_path = Path(temp_dir.name) + # dataset_path = Path("./test_dataset") + + dataset_path.mkdir(exist_ok=True, parents=True) + + (dataset_path / "ds1").mkdir(exist_ok=True, parents=True) + (dataset_path / "ds2").mkdir(exist_ok=True, parents=True) + + # Create a small dummy captioning dataset + create_crude_text_test_dataset(dataset_path / "ds1", 0) + create_crude_text_test_dataset(dataset_path / 
"ds2", 100) + + mds_path = dataset_path / "metadataset.yaml" + with open(mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: Metadataset", + "splits:", + " train:", + " datasets:", + " - weight: 1", + " path: ds1", + " subflavors:", + " source: metadataset.yaml", + " number: 43", + " mds: mds", + " crude_type: txtpkl", + " shuffle_over_epochs_multiplier: 3", + " - weight: 1", + " path: ds2", + " subflavors:", + " source: metadataset.yaml", + " number: 44", + " mds: mds", + " crude_type: otherpkl", + " val:", + " datasets:", + " - weight: 1", + " path: ds1", + " split_part: train", + " - weight: 1", + " path: ds2", + " split_part: train", + ] ) + ) - self.aux_mds_path = self.dataset_path / "aux_metadataset.yaml" - with open(self.aux_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " path: ds1", - " aux:", - " pkl_source: ds2", - " fs_source: filesystem://.", - " subflavors:", - " crude_type: aux_random_access", - ] - ) + aux_mds_path = dataset_path / "aux_metadataset.yaml" + with open(aux_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " path: ds1", + " aux:", + " pkl_source: ds2", + " fs_source: filesystem://.", + " subflavors:", + " crude_type: aux_random_access", + ] ) + ) - self.multimedia_wds_path = self.dataset_path / "multimedia_wds" - self.create_multimedia_webdataset(self.multimedia_wds_path) - - self.multimedia_fs_path = self.dataset_path / "multimedia_fs" - self.create_multimedia_filesystem_dataset(self.multimedia_fs_path) - - self.media_mds_path = self.dataset_path / "media_metadataset.yaml" - with open(self.media_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " path: multimedia_wds", - " aux:", - " media: filesystem://multimedia_fs", - 
" subflavors:", - " crude_type: media_metadata", - ] - ) + multimedia_wds_path = dataset_path / "multimedia_wds" + create_multimedia_webdataset(multimedia_wds_path) + + multimedia_fs_path = dataset_path / "multimedia_fs" + create_multimedia_filesystem_dataset(multimedia_fs_path) + + media_mds_path = dataset_path / "media_metadataset.yaml" + with open(media_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " path: multimedia_wds", + " aux:", + " media: filesystem://multimedia_fs", + " subflavors:", + " crude_type: media_metadata", + ] ) - - print(self.dataset_path) - - def tearDown(self): - # Remove all temporary files - gc.collect() - self.temp_dir.cleanup() - - @staticmethod - def create_crude_text_test_dataset(path: Path, offset: int): - """Creates a small dummy test dataset for testing purposes.""" - - # Create num_samples unique captions - (path / "parts").mkdir(exist_ok=True, parents=True) - - # Initialize the ShardWriter - with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: - for idx in range(55): - # Write individual files to shards - shard_writer.write( - { - "__key__": f"{idx + offset:06d}", - "txt": f"{idx + offset}".encode(), - "pkl": pickle.dumps({"idx": idx + offset}), - }, - ) - total_shards = shard_writer.shard - - from megatron.energon.flavors import BaseWebdatasetFactory - - BaseWebdatasetFactory.prepare_dataset( - path, - [f"parts/data-{{0..{total_shards - 1}}}.tar"], - split_parts_ratio=[("train", 1.0)], - shuffle_seed=None, - workers=1, - media_filter=None, ) - with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: CrudeWebdataset", - "subflavors:", - " dataset.yaml: true", - " number: 42", - ] - ) - ) + print(dataset_path) + + yield dataset_path + + # Remove all temporary files + gc.collect() + temp_dir.cleanup() - @staticmethod - def 
create_multimedia_webdataset(path: Path): - path.mkdir(exist_ok=True, parents=True) - (path / "parts").mkdir(exist_ok=True, parents=True) - - jpg_bytes = _noise_image_bytes((32, 16), "JPEG", seed=0) - png_bytes = _noise_image_bytes((24, 24), "PNG", seed=1) - video_bytes = Path("tests/data/sync_test.mp4").read_bytes() - audio_bytes = Path("tests/data/test_audio.flac").read_bytes() - - with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: - shard_writer.write({"__key__": "image000", "jpg": jpg_bytes}) - shard_writer.write({"__key__": "image001", "png": png_bytes}) - shard_writer.write({"__key__": "audio001", "flac": audio_bytes}) - shard_writer.write({"__key__": "video001", "mp4": video_bytes}) - total_shards = shard_writer.shard - - from megatron.energon.flavors import BaseWebdatasetFactory - - BaseWebdatasetFactory.prepare_dataset( - path, - [f"parts/data-{{0..{total_shards - 1}}}.tar"], - split_parts_ratio=[("train", 1.0)], - shuffle_seed=None, - workers=1, - media_filter=MediaFilterConfig(strategy=MediaFilterStrategy.EXTENSION), - ) - with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: CrudeWebdataset", - "subflavors:", - " crude_type: media_metadata", - ] - ) +def create_crude_text_test_dataset(path: Path, offset: int): + """Creates a small dummy test dataset for testing purposes.""" + + # Create num_samples unique captions + (path / "parts").mkdir(exist_ok=True, parents=True) + + # Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: + for idx in range(55): + # Write individual files to shards + shard_writer.write( + { + "__key__": f"{idx + offset:06d}", + "txt": f"{idx + offset}".encode(), + "pkl": pickle.dumps({"idx": idx + offset}), + }, ) + total_shards = shard_writer.shard - @staticmethod - def create_multimedia_filesystem_dataset(path: Path): - path.mkdir(exist_ok=True, parents=True) + 
from megatron.energon.flavors import BaseWebdatasetFactory - (path / "image000.jpg").write_bytes(_noise_image_bytes((32, 16), "JPEG", seed=0)) - (path / "image001.png").write_bytes(_noise_image_bytes((24, 24), "PNG", seed=1)) - shutil.copyfile("tests/data/sync_test.mp4", path / "video001.mp4") - shutil.copyfile("tests/data/test_audio.flac", path / "audio001.flac") + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + workers=1, + media_filter=None, + ) - prepare_filesystem_dataset( - EPath(path), MediaFilterConfig(strategy=MediaFilterStrategy.EXTENSION), progress=False + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: CrudeWebdataset", + "subflavors:", + " dataset.yaml: true", + " number: 42", + ] + ) ) - def test_metadataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - global_error_handler=reraise_exception, - ) - # Train mode dataset - torch.manual_seed(42) - train_dataset = get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=3, - task_encoder=CookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - loader = get_savable_loader( - train_dataset, +def create_multimedia_webdataset(path: Path): + path.mkdir(exist_ok=True, parents=True) + (path / "parts").mkdir(exist_ok=True, parents=True) + + jpg_bytes = _noise_image_bytes((32, 16), "JPEG", seed=0) + png_bytes = _noise_image_bytes((24, 24), "PNG", seed=1) + video_bytes = Path("tests/data/sync_test.mp4").read_bytes() + audio_bytes = Path("tests/data/test_audio.flac").read_bytes() + + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: + shard_writer.write({"__key__": "image000", "jpg": jpg_bytes}) + shard_writer.write({"__key__": "image001", "png": png_bytes}) + 
shard_writer.write({"__key__": "audio001", "flac": audio_bytes}) + shard_writer.write({"__key__": "video001", "mp4": video_bytes}) + total_shards = shard_writer.shard + + from megatron.energon.flavors import BaseWebdatasetFactory + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + workers=1, + media_filter=MediaFilterConfig(strategy=MediaFilterStrategy.EXTENSION), + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: CrudeWebdataset", + "subflavors:", + " crude_type: media_metadata", + ] + ) ) + +@staticmethod +def create_multimedia_filesystem_dataset(path: Path): + path.mkdir(exist_ok=True, parents=True) + + (path / "image000.jpg").write_bytes(_noise_image_bytes((32, 16), "JPEG", seed=0)) + (path / "image001.png").write_bytes(_noise_image_bytes((24, 24), "PNG", seed=1)) + shutil.copyfile("tests/data/sync_test.mp4", path / "video001.mp4") + shutil.copyfile("tests/data/test_audio.flac", path / "audio001.flac") + + prepare_filesystem_dataset( + EPath(path), MediaFilterConfig(strategy=MediaFilterStrategy.EXTENSION), progress=False + ) + + +def test_metadataset(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + global_error_handler=reraise_exception, + ) + + # Train mode dataset + torch.manual_seed(42) + train_dataset = get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=3, + task_encoder=CookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + with get_savable_loader( + train_dataset, + ) as loader: print(len(train_dataset)) # assert len(train_dataset) == 11 @@ -515,27 +519,26 @@ def test_metadataset(self): print(key, txt) - def test_loader(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, 
- num_workers=2, - ) - loader = get_savable_loader( - get_train_dataset( - self.mds_path, - batch_size=2, - worker_config=worker_config, - task_encoder=CookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) +def test_loader(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + ) + + with get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + batch_size=2, + worker_config=worker_config, + task_encoder=CookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + packing_buffer_size=2, + ), + ) as loader: samples = [s.__key__ for idx, s in zip(range(100), loader)] print(samples) @@ -545,51 +548,44 @@ def test_loader(self): samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] print(samples_after) - loader = get_savable_loader( - get_train_dataset( - self.mds_path, - batch_size=2, - worker_config=worker_config, - task_encoder=CookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) - - loader.restore_state_rank(state) - + with get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + batch_size=2, + worker_config=worker_config, + task_encoder=CookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + packing_buffer_size=2, + ), + ).with_restored_state_rank(state) as loader: samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] print(samples_restored) assert all([a == b for a, b in zip(samples_after, samples_restored)]) - def test_aux_random_access(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - ) - print("Initializing dataset") - - loader = get_savable_loader( - 
get_train_dataset( - self.aux_mds_path, - batch_size=2, - worker_config=worker_config, - task_encoder=CookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) +def test_aux_random_access(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + ) + + print("Initializing dataset") + with get_savable_loader( + get_train_dataset( + dataset_path / "aux_metadataset.yaml", + batch_size=2, + worker_config=worker_config, + task_encoder=CookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + packing_buffer_size=2, + ), + ) as loader: print("Iterating from dataset") samples = [s.txts for idx, s in zip(range(100), loader)] for idx, txts in enumerate(samples): @@ -605,55 +601,48 @@ def test_aux_random_access(self): samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] print(samples_after) - loader = get_savable_loader( - get_train_dataset( - self.aux_mds_path, - batch_size=2, - worker_config=worker_config, - task_encoder=CookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) - - loader.restore_state_rank(state) - + with get_savable_loader( + get_train_dataset( + dataset_path / "aux_metadataset.yaml", + batch_size=2, + worker_config=worker_config, + task_encoder=CookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + packing_buffer_size=2, + ), + ).with_restored_state_rank(state) as loader: samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] print(samples_restored) assert all([a == b for a, b in zip(samples_after, samples_restored)]) - def test_aux_random_access_with_cache(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - ) - 
print("Initializing dataset") - - loader = get_savable_loader( - get_train_dataset( - self.aux_mds_path, - batch_size=2, - worker_config=worker_config, - task_encoder=LazyCookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - cache_pool=FileStoreCachePool( - parent_cache_dir=self.dataset_path / "cache", - num_workers=1, - ), - ) +def test_aux_random_access_with_cache(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + ) + + print("Initializing dataset") + with get_savable_loader( + get_train_dataset( + dataset_path / "aux_metadataset.yaml", + batch_size=2, + worker_config=worker_config, + task_encoder=LazyCookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + packing_buffer_size=2, + ), + cache_pool=FileStoreCachePool( + parent_cache_dir=dataset_path / "cache", + num_workers=1, + ), + ) as loader: print("Iterating from dataset") samples = [s.txts for idx, s in zip(range(100), loader)] for idx, txts in enumerate(samples): @@ -670,59 +659,52 @@ def test_aux_random_access_with_cache(self): samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] print(samples_after) - loader = get_savable_loader( - get_train_dataset( - self.aux_mds_path, - batch_size=2, - worker_config=worker_config, - task_encoder=CookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - cache_pool=FileStoreCachePool( - parent_cache_dir=self.dataset_path / "cache", - num_workers=1, - ), - ) - - loader.restore_state_rank(state) - + with get_savable_loader( + get_train_dataset( + dataset_path / "aux_metadataset.yaml", + batch_size=2, + worker_config=worker_config, + task_encoder=CookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, 
+ packing_buffer_size=2, + ), + cache_pool=FileStoreCachePool( + parent_cache_dir=dataset_path / "cache", + num_workers=1, + ), + ).with_restored_state_rank(state) as loader: samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] print(samples_restored) assert all([a == b for a, b in zip(samples_after, samples_restored)]) - def test_aux_random_access_with_cache_and_postencode(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - ) - print("Initializing dataset") - - loader = get_savable_loader( - get_train_dataset( - self.aux_mds_path, - batch_size=2, - worker_config=worker_config, - task_encoder=LazyCookingTaskEncoderWithPostencode(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - cache_pool=FileStoreCachePool( - parent_cache_dir=self.dataset_path / "cache", - num_workers=1, - ), - ) +def test_aux_random_access_with_cache_and_postencode(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + ) + + print("Initializing dataset") + with get_savable_loader( + get_train_dataset( + dataset_path / "aux_metadataset.yaml", + batch_size=2, + worker_config=worker_config, + task_encoder=LazyCookingTaskEncoderWithPostencode(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + packing_buffer_size=2, + ), + cache_pool=FileStoreCachePool( + parent_cache_dir=dataset_path / "cache", + num_workers=1, + ), + ) as loader: print("Iterating from dataset") samples = [s.txts for idx, s in zip(range(100), loader)] for idx, txts in enumerate(samples): @@ -739,26 +721,21 @@ def test_aux_random_access_with_cache_and_postencode(self): samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] print(samples_after) - loader = get_savable_loader( - get_train_dataset( - self.aux_mds_path, - batch_size=2, - worker_config=worker_config, - 
task_encoder=LazyCookingTaskEncoderWithPostencode(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - cache_pool=FileStoreCachePool( - parent_cache_dir=self.dataset_path / "cache", - num_workers=1, - ), - ) - - loader.restore_state_rank(state) - + with get_savable_loader( + get_train_dataset( + dataset_path / "aux_metadataset.yaml", + batch_size=2, + worker_config=worker_config, + task_encoder=LazyCookingTaskEncoderWithPostencode(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + packing_buffer_size=2, + ), + cache_pool=FileStoreCachePool( + parent_cache_dir=dataset_path / "cache", + num_workers=1, + ), + ).with_restored_state_rank(state) as loader: samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] print(samples_restored) @@ -771,129 +748,125 @@ def test_aux_random_access_with_cache_and_postencode(self): assert sample_src_check == ( # Primary source for the sample, reading all source files SourceInfo( - dataset_path=EPath(self.dataset_path / "ds1"), + dataset_path=EPath(dataset_path / "ds1"), index=2, shard_name="parts/data-0.tar", file_names=("000002.pkl", "000002.txt"), ), # Auxiliary source for the sample, reading from ds2 SourceInfo( - dataset_path=EPath(self.dataset_path / "ds2"), + dataset_path=EPath(dataset_path / "ds2"), index="000102.txt", shard_name="parts/data-0.tar", file_names=("000102.txt",), ), # Auxiliary source for the sample, reading from ds1, but next sample SourceInfo( - dataset_path=EPath(self.dataset_path / "ds1"), + dataset_path=EPath(dataset_path / "ds1"), index="000003.txt", shard_name="parts/data-0.tar", file_names=("000003.txt",), ), SourceInfo( - dataset_path=EPath(self.dataset_path / "ds1"), + dataset_path=EPath(dataset_path / "ds1"), index=21, shard_name="parts/data-2.tar", file_names=("000021.pkl", "000021.txt"), ), SourceInfo( - dataset_path=EPath(self.dataset_path / "ds2"), + 
dataset_path=EPath(dataset_path / "ds2"), index="000121.txt", shard_name="parts/data-2.tar", file_names=("000121.txt",), ), SourceInfo( - dataset_path=EPath(self.dataset_path / "ds1"), + dataset_path=EPath(dataset_path / "ds1"), index="000022.txt", shard_name="parts/data-2.tar", file_names=("000022.txt",), ), ) - def test_aux_filesystem_reference(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - ) - loader = get_savable_loader( - get_train_dataset( - self.aux_mds_path, - batch_size=1, - worker_config=worker_config, - task_encoder=CookingTaskEncoderWithAuxFilesystemReference(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - ) +def test_aux_filesystem_reference(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + ) + with get_savable_loader( + get_train_dataset( + dataset_path / "aux_metadataset.yaml", + batch_size=1, + worker_config=worker_config, + task_encoder=CookingTaskEncoderWithAuxFilesystemReference(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ), + ) as loader: sample = next(iter(loader)) assert sample.txts[0].endswith("|aux|__module__: megatron.ener>") - def test_media_metadata_webdataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - ) - loader = get_savable_loader( - get_train_dataset( - self.media_mds_path, - batch_size=1, - worker_config=worker_config, - task_encoder=CookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - ) +def test_media_metadata_webdataset(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + ) - descriptions = [] - for _, batch in zip(range(4), loader): - descriptions.extend(batch.txts) - - # from pprint import pprint - # pprint(descriptions, indent=4) - - # The descriptions are like "A|B", where A is the format - # in 
the WebDataset and B is the format in the auxiliary dataset. - - assert descriptions == [ - "IMG-32x16-JPEG|IMG-32x16-JPEG", - "IMG-24x24-PNG|IMG-24x24-PNG", - "AUDIO-10.0s@32000Hz|AUDIO-10.0s@32000Hz", - "VIDEO-192x108@30.0fps-63.0s|VIDEO-192x108@30.0fps-63.0s", - ] - - def test_nomds(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, + loader = get_savable_loader( + get_train_dataset( + dataset_path / "media_metadataset.yaml", + batch_size=1, + worker_config=worker_config, + task_encoder=CookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, ) + ) - loader = get_savable_loader( - get_train_dataset( - self.dataset_path / "ds1", - batch_size=2, - worker_config=worker_config, - task_encoder=GenericCookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) + descriptions = [] + for _, batch in zip(range(4), loader): + descriptions.extend(batch.txts) + + # from pprint import pprint + # pprint(descriptions, indent=4) + + # The descriptions are like "A|B", where A is the format + # in the WebDataset and B is the format in the auxiliary dataset. 
+ + assert descriptions == [ + "IMG-32x16-JPEG|IMG-32x16-JPEG", + "IMG-24x24-PNG|IMG-24x24-PNG", + "AUDIO-10.0s@32000Hz|AUDIO-10.0s@32000Hz", + "VIDEO-192x108@30.0fps-63.0s|VIDEO-192x108@30.0fps-63.0s", + ] + + +def test_nomds(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + ) + + with get_savable_loader( + get_train_dataset( + dataset_path / "ds1", + batch_size=2, + worker_config=worker_config, + task_encoder=GenericCookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ), + ) as loader: samples = [s.__key__ for idx, s in zip(range(100), loader)] print(samples) assert len(samples) == 100 - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py new file mode 100644 index 00000000..1e134324 --- /dev/null +++ b/tests/test_dataloader.py @@ -0,0 +1,390 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause + +"""This module defines tests for meta datasets.""" + +import gc +import logging +import sys +import tempfile +import warnings +from collections import Counter +from pathlib import Path +from typing import Iterable + +import pytest +import torch +import webdataset as wds + +from megatron.energon import ( + WorkerConfig, + get_train_dataset, +) +from megatron.energon.dataloader import DataLoader, ForkDataLoaderWorker, ThreadDataLoaderWorker +from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME +from megatron.energon.task_encoder.base import DefaultTaskEncoder, stateless + + +class VerifyWorkerTaskEncoder(DefaultTaskEncoder): + def __init__(self, expected_num_workers: int): + self.expected_num_workers = expected_num_workers + super().__init__() + + @stateless + def encode_sample(self, sample): + sample = super().encode_sample(sample) + worker_info = torch.utils.data.get_worker_info() + if self.expected_num_workers > 0: + assert worker_info is not None + assert 
worker_info.num_workers == self.expected_num_workers + else: + assert worker_info is None + return sample + + +@pytest.fixture +def temp_dir(): + temp_dir = tempfile.TemporaryDirectory() + yield temp_dir + gc.collect() + temp_dir.cleanup() + + +@pytest.fixture +def dataset_path(temp_dir): + dataset_path = Path(temp_dir.name) + dataset_path.mkdir(exist_ok=True, parents=True) + return dataset_path + + +@pytest.fixture +def ds1_path(dataset_path): + ds1_path = dataset_path / "ds1" + ds1_path.mkdir(exist_ok=True, parents=True) + create_text_test_dataset(ds1_path, range(55), range(55)) + print(ds1_path) + return ds1_path + + +@pytest.fixture(autouse=True) +def setup_logging(): + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + +def create_text_test_dataset(path: Path, txt_range: Iterable[int], key_range: Iterable[int]): + """Creates a small dummy test dataset for testing purposes.""" + + # Create num_samples unique captions + (path / "parts").mkdir(exist_ok=True, parents=True) + + # Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: + for key, txt in zip(key_range, txt_range): + # Write individual files to shards + shard_writer.write( + { + "__key__": f"{key:06d}", + "txt": f"{txt}".encode(), + }, + ) + total_shards = shard_writer.shard + + from megatron.energon.flavors import BaseWebdatasetFactory + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: TextSample", + "field_map:", + " text: txt", + "subflavors:", + " source: dataset.yaml", + " dataset.yaml: true", + " number: 42", + ] + ) + ) + + +def test_dataloader_no_workers(ds1_path): + torch.manual_seed(42) + worker_config = 
WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Train mode dataset + with DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=0), + ), + ) as train_loader: + assert len(train_loader) == 6, len(train_loader) + + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(train_order1) == 55, len(train_order1) + assert len(Counter(train_order1)) == 55, Counter(train_order1) + assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) + + state1 = train_loader.save_state_rank() + + train_order2 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + + with DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=0), + ), + ).with_restored_state_rank(state1) as train_loader: + cmp_order2 = [text for idx, data in zip(range(55 * 10), train_loader) for text in data.text] + assert train_order2 == cmp_order2, (train_order1, cmp_order2) + + +def test_dataloader_fork(ds1_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + seed_offset=42, + ) + + # Train mode dataset + with DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=2), + ), + prefetch_factor=2, + worker_type=ForkDataLoaderWorker, + gc_collect_every_n_steps=10, + gc_freeze_at_start=True, + watchdog_timeout_seconds=60, + fail_on_timeout=True, 
+ ) as train_loader: + assert len(train_loader) == 6, len(train_loader) + + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(train_order1) == 55, len(train_order1) + assert len(Counter(train_order1)) == 55, Counter(train_order1) + assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) + + state1 = train_loader.save_state_rank() + + train_order2 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + + assert len(train_order1) == len(train_order2), (len(train_order1), len(train_order2)) + + with DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=worker_config.num_workers), + ), + prefetch_factor=2, + worker_type=ForkDataLoaderWorker, + gc_collect_every_n_steps=10, + gc_freeze_at_start=True, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ).with_restored_state_rank(state1) as train_loader: + cmp_order2 = [text for idx, data in zip(range(55 * 10), train_loader) for text in data.text] + assert train_order2 == cmp_order2, (train_order1, cmp_order2) + + +def test_dataloader_fork_multi_parallel(ds1_path): + torch.manual_seed(42) + worker_config_r0 = WorkerConfig( + rank=0, + world_size=2, + num_workers=2, + seed_offset=42, + ) + worker_config_r1 = WorkerConfig( + rank=1, + world_size=2, + num_workers=2, + seed_offset=42, + ) + + # Train mode dataset + train_loader_r0 = DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config_r0, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=worker_config_r0.num_workers), + ), + prefetch_factor=2, + worker_type=ForkDataLoaderWorker, + 
gc_collect_every_n_steps=10, + gc_freeze_at_start=True, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ) + assert len(train_loader_r0) == 4, len(train_loader_r0) + + train_order1_r0 = [ + text for idx, data in zip(range(55 * 10), train_loader_r0) for text in data.text + ] + print(train_order1_r0[:10]) + print(Counter(train_order1_r0)) + assert len(train_order1_r0) == 28, len(train_order1_r0) + assert len(Counter(train_order1_r0)) == 28, Counter(train_order1_r0) + assert all(v == 1 for v in Counter(train_order1_r0).values()), Counter(train_order1_r0) + + train_loader_r1 = DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config_r1, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=worker_config_r1.num_workers), + ), + prefetch_factor=2, + worker_type=ForkDataLoaderWorker, + gc_collect_every_n_steps=10, + gc_freeze_at_start=True, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ) + assert len(train_loader_r1) == 4, len(train_loader_r1) + + train_order1_r1 = [ + text for idx, data in zip(range(55 * 10), train_loader_r1) for text in data.text + ] + print(train_order1_r1[:10]) + print(Counter(train_order1_r1)) + assert len(train_order1_r1) == 27, len(train_order1_r1) + assert len(Counter(train_order1_r1)) == 27, Counter(train_order1_r1) + assert all(v == 1 for v in Counter(train_order1_r1).values()), Counter(train_order1_r1) + + train_loader_r1.save_state_rank() + + train_loader_r0.save_state_rank() + + train_order2_r0 = [ + text for idx, data in zip(range(55 * 10), train_loader_r0) for text in data.text + ] + assert len(train_order2_r0) == 28 + + train_order2_r1 = [ + text for idx, data in zip(range(55 * 10), train_loader_r1) for text in data.text + ] + assert len(train_order2_r1) == 27 + + train_loader_r0.shutdown() + train_loader_r1.shutdown() + + +def test_dataloader_thread(ds1_path): + torch.manual_seed(42) + worker_config = 
WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + seed_offset=42, + ) + + # Train mode dataset + with DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=worker_config.num_workers), + ), + prefetch_factor=2, + worker_type=ThreadDataLoaderWorker, + gc_collect_every_n_steps=0, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ) as train_loader: + assert len(train_loader) == 6, len(train_loader) + + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(train_order1) == 55, len(train_order1) + assert len(Counter(train_order1)) == 55, Counter(train_order1) + assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) + + state1 = train_loader.save_state_rank() + + train_order2 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + + with DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=worker_config.num_workers), + ), + prefetch_factor=2, + worker_type=ThreadDataLoaderWorker, + gc_collect_every_n_steps=0, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ).with_restored_state_rank(state1) as train_loader: + cmp_order2 = [text for idx, data in zip(range(55 * 10), train_loader) for text in data.text] + assert train_order2 == cmp_order2, (train_order1, cmp_order2) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 4c5992d0..bcde3d28 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -3,7 +3,6 @@ """This module defines tests for the dataset.""" -import dataclasses import gc import io import json @@ 
-12,7 +11,6 @@ import random import sys import tempfile -import unittest import warnings from collections import defaultdict from dataclasses import dataclass @@ -20,6 +18,7 @@ from typing import Hashable, List, Tuple, Type, Union import numpy as np +import pytest import torch import webdataset as wds from click.testing import CliRunner @@ -35,7 +34,6 @@ MapDataset, MixBatchDataset, Sample, - SavableDataLoader, TaskEncoder, WorkerConfig, generic_batch, @@ -45,6 +43,7 @@ get_val_dataset, homogeneous_concat_mix, ) +from megatron.energon.dataloader.dataloader import DataLoader from megatron.energon.dataset_config import get_dataset_from_config from megatron.energon.edataclass import edataclass from megatron.energon.flavors import BaseWebdatasetFactory @@ -97,221 +96,224 @@ class ShouldRaiseException(Exception): pass -class TestDataset(unittest.TestCase): - # Set up the test fixture - def setUp(self): - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - warnings.simplefilter("ignore", ResourceWarning) - - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.dataset_path = Path(self.temp_dir.name) - # self.dataset_path = Path("./test_dataset") - - self.dataset_path.mkdir(exist_ok=True, parents=True) - - # Create a small dummy captioning dataset - self.samples = self.create_captioning_test_dataset(self.dataset_path, DATASET_SIZE) - print(self.dataset_path) - - def tearDown(self): - # Remove all temporary files - gc.collect() - self.temp_dir.cleanup() - - @staticmethod - def create_captioning_test_dataset(path: Union[str, Path], num_samples: int = 50): - """Creates a small dummy captioning dataset for testing purposes.""" - path = Path(path) - - animals = ( - "ant bee beetle bug bumblebee butterfly caterpillar cicada cricket dragonfly earwig " - "firefly grasshopper honeybee hornet inchworm ladybug locust mantis mayfly mosquito " - "moth sawfly silkworm termite wasp woodlouse" - ).split() - adjectives = ( - "adorable affable 
amazing amiable attractive beautiful calm charming cherubic classic " - "classy convivial cordial cuddly curly cute debonair elegant famous fresh friendly " - "funny gorgeous graceful gregarious grinning handsome hilarious hot interesting kind " - "laughing lovely meek mellow merciful neat nifty notorious poetic pretty refined " - "refreshing sexy smiling sociable spiffy stylish sweet tactful whimsical" - ).split() - - # Set random seeds for numpy and torch - np.random.seed(42) - torch.manual_seed(42) - - entries = [] - - assert num_samples < len(animals) * len(adjectives), ( - "Cannot generate more samples than unique captions." - ) +@pytest.fixture +def temp_dir(): + temp_dir = tempfile.TemporaryDirectory() + yield temp_dir + gc.collect() + temp_dir.cleanup() - # Create num_samples unique captions - captions = set() - while len(captions) < num_samples: - # Create random description by sampling from adjectives and animals - adjective = np.random.choice(adjectives) - prefix = "An" if adjective[0] in "aeiou" else "A" - description = f"{prefix} {adjective} {np.random.choice(animals)}." 
- captions.add(description) - - (path / "parts").mkdir(exist_ok=True, parents=True) - - # Initialize the ShardWriter - with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=30) as shard_writer: - for idx in range(num_samples): - # Create a dummy image with random noise and save to disk - img_buf = io.BytesIO() - randimg = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) - image = Image.fromarray(randimg) - image.save(img_buf, format="PNG") - img_bytes = img_buf.getvalue() - - description = captions.pop() - - entries.append({"image": randimg, "caption": description}) - - # Write individual files to shards - shard_writer.write( - { - "__key__": f"{idx:06d}", - "png": img_bytes, - "txt": description.encode("utf-8"), - "json": json.dumps({"caption": description}), - }, - ) - total_shards = shard_writer.shard - BaseWebdatasetFactory.prepare_dataset( - path, - [f"parts/data-{{0..{total_shards - 1}}}.tar"], - split_parts_ratio=[("train", 1.0)], - ) +@pytest.fixture +def dataset_path(temp_dir): + dataset_path = Path(temp_dir.name) + dataset_path.mkdir(exist_ok=True, parents=True) + return dataset_path - with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: CaptioningSample", - "field_map:", - " image: png", - " caption: txt", - ] - ) + +@pytest.fixture +def samples(dataset_path): + return create_captioning_test_dataset(dataset_path, DATASET_SIZE) + + +@pytest.fixture(autouse=True) +def setup_logging(): + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + +def create_captioning_test_dataset(path: Union[str, Path], num_samples: int = 50): + """Creates a small dummy captioning dataset for testing purposes.""" + path = Path(path) + + animals = ( + "ant bee beetle bug bumblebee butterfly caterpillar cicada cricket dragonfly earwig " + "firefly grasshopper honeybee hornet inchworm ladybug locust 
mantis mayfly mosquito " + "moth sawfly silkworm termite wasp woodlouse" + ).split() + adjectives = ( + "adorable affable amazing amiable attractive beautiful calm charming cherubic classic " + "classy convivial cordial cuddly curly cute debonair elegant famous fresh friendly " + "funny gorgeous graceful gregarious grinning handsome hilarious hot interesting kind " + "laughing lovely meek mellow merciful neat nifty notorious poetic pretty refined " + "refreshing sexy smiling sociable spiffy stylish sweet tactful whimsical" + ).split() + + # Set random seeds for numpy and torch + np.random.seed(42) + torch.manual_seed(42) + + entries = [] + + assert num_samples < len(animals) * len(adjectives), ( + "Cannot generate more samples than unique captions." + ) + + # Create num_samples unique captions + captions = set() + while len(captions) < num_samples: + # Create random description by sampling from adjectives and animals + adjective = np.random.choice(adjectives) + prefix = "An" if adjective[0] in "aeiou" else "A" + description = f"{prefix} {adjective} {np.random.choice(animals)}." 
+ captions.add(description) + + (path / "parts").mkdir(exist_ok=True, parents=True) + + # Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=30) as shard_writer: + for idx in range(num_samples): + # Create a dummy image with random noise and save to disk + img_buf = io.BytesIO() + randimg = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + image = Image.fromarray(randimg) + image.save(img_buf, format="PNG") + img_bytes = img_buf.getvalue() + + description = captions.pop() + + entries.append({"image": randimg, "caption": description}) + + # Write individual files to shards + shard_writer.write( + { + "__key__": f"{idx:06d}", + "png": img_bytes, + "txt": description.encode("utf-8"), + "json": json.dumps({"caption": description}), + }, ) + total_shards = shard_writer.shard + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: CaptioningSample", + "field_map:", + " image: png", + " caption: txt", + ] + ) + ) - with open(path / MAIN_FOLDER_NAME / "dataset_field.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: CaptioningSample", - "field_map:", - " image: png", - " caption: json[caption]", - ] - ) + with open(path / MAIN_FOLDER_NAME / "dataset_field.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: CaptioningSample", + "field_map:", + " image: png", + " caption: json[caption]", + ] ) + ) - with open(path / MAIN_FOLDER_NAME / "dataset_sample_loader.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: CaptioningSample", - "sample_loader: sample_loader.py:sample_loader", - "part_filter: 
sample_loader.py:part_filter", - ] - ) + with open(path / MAIN_FOLDER_NAME / "dataset_sample_loader.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: CaptioningSample", + "sample_loader: sample_loader.py:sample_loader", + "part_filter: sample_loader.py:part_filter", + ] ) + ) - with open(path / MAIN_FOLDER_NAME / "dataset_sample_loader_key.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: CaptioningSample", - "sample_loader: sample_loader.py:sample_loader_key", - "part_filter: sample_loader.py:part_filter", - ] - ) + with open(path / MAIN_FOLDER_NAME / "dataset_sample_loader_key.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: CaptioningSample", + "sample_loader: sample_loader.py:sample_loader_key", + "part_filter: sample_loader.py:part_filter", + ] ) + ) - with open(path / MAIN_FOLDER_NAME / "sample_loader.py", "w") as f: - f.write( - "\n".join( - [ - "def sample_loader(raw: dict) -> dict:", - " assert 'txt' not in raw", - " return dict(", - ' image=raw["png"],', - ' caption="" + raw["json"]["caption"],', - " )", - "", - "def sample_loader_key(raw: dict) -> dict:", - " assert 'txt' not in raw", - " return dict(", - ' __key__="" + raw["__key__"],', - ' image=raw["png"],', - ' caption="" + raw["json"]["caption"],', - " )", - "", - "def part_filter(part: str) -> bool:", - ' return part in ["json", "png"]', - "", - ] - ) + with open(path / MAIN_FOLDER_NAME / "sample_loader.py", "w") as f: + f.write( + "\n".join( + [ + "def sample_loader(raw: dict) -> dict:", + " assert 'txt' not in raw", + " return dict(", + ' image=raw["png"],', + ' caption="" + raw["json"]["caption"],', + " )", + "", + "def sample_loader_key(raw: dict) -> dict:", + " assert 'txt' not in raw", + " return dict(", + ' __key__="" + raw["__key__"],', + ' image=raw["png"],', + ' caption="" + 
raw["json"]["caption"],', + " )", + "", + "def part_filter(part: str) -> bool:", + ' return part in ["json", "png"]', + "", + ] ) + ) - with open(path / MAIN_FOLDER_NAME / "dataset_exclude.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: CaptioningSample", - "field_map:", - " image: png", - " caption: txt", - "split_config: split2.yaml", - ] - ) + with open(path / MAIN_FOLDER_NAME / "dataset_exclude.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: CaptioningSample", + "field_map:", + " image: png", + " caption: txt", + "split_config: split2.yaml", + ] ) + ) - with open(path / MAIN_FOLDER_NAME / "split2.yaml", "w") as f: - with open(path / MAIN_FOLDER_NAME / "split.yaml", "r") as rf: - origsplit = rf.read() - f.write( - origsplit - + "\n" - + "\n".join( - [ - "exclude:", - " - parts/data-0.tar", - " - parts/data-1.tar/00003{5..9}", - ] - ) + with open(path / MAIN_FOLDER_NAME / "split2.yaml", "w") as f: + with open(path / MAIN_FOLDER_NAME / "split.yaml", "r") as rf: + origsplit = rf.read() + f.write( + origsplit + + "\n" + + "\n".join( + [ + "exclude:", + " - parts/data-0.tar", + " - parts/data-1.tar/00003{5..9}", + ] ) + ) - return entries + return entries - def test_captioning_dataset(self): - ds = get_dataset_from_config( - self.dataset_path, - split_part="train", - worker_config=no_worker_config, - training=False, - sample_type=CaptioningSample, - ) - ds = MapDataset( - ds.build(), +def test_captioning_dataset(dataset_path, samples): + def new_ds(): + return MapDataset( + get_dataset_from_config( + dataset_path, + split_part="train", + worker_config=no_worker_config, + training=False, + sample_type=CaptioningSample, + ).build(), lambda x: CaptioningSample( __key__=x.__key__, __restore_key__=x.__restore_key__, @@ -322,42 +324,43 @@ def test_captioning_dataset(self): worker_config=no_worker_config, ) - def get_ld(ds): - return 
get_loader(ds) - - # Check len operator - assert len(ds) == 50 - # Check if iterating returns the same - iter1 = list(get_ld(ds)) - iter2 = list(get_ld(ds)) - assert len(iter1) == 50 - assert len(iter2) == 50 - assert all(elem1.__key__ == elem2.__key__ for elem1, elem2 in zip(iter1, iter2)) - - # Check case when batch size is larger than dataset size - batch_sizes = [] - for wrapped_sample in get_ld( - BatchDataset( - ds, - batch_size=DATASET_SIZE * 2, - batcher=generic_batch, - worker_config=no_worker_config, - ) - ): + ds = new_ds() + # Check len operator + assert len(ds) == 50 + # Check if iterating returns the same + with get_loader(ds) as l1, get_loader(new_ds()) as l2: + iter1 = list(l1) + iter2 = list(l2) + assert len(iter1) == 50 + assert len(iter2) == 50 + assert all(elem1.__key__ == elem2.__key__ for elem1, elem2 in zip(iter1, iter2)) + + # Check case when batch size is larger than dataset size + batch_sizes = [] + with get_loader( + BatchDataset( + new_ds(), + batch_size=DATASET_SIZE * 2, + batcher=generic_batch, + worker_config=no_worker_config, + ) + ) as l: + for wrapped_sample in l: batch_sizes.append(wrapped_sample.image.shape[0]) - assert batch_sizes == [DATASET_SIZE] + assert batch_sizes == [DATASET_SIZE] - # Check returned dimensions and batch sizes if batch size is smaller than dataset size - batch_size = 4 - assert batch_size < DATASET_SIZE + # Check returned dimensions and batch sizes if batch size is smaller than dataset size + batch_size = 4 + assert batch_size < DATASET_SIZE - batched_ds = BatchDataset( - ds, batch_size=batch_size, batcher=generic_batch, worker_config=no_worker_config - ) + batched_ds = BatchDataset( + new_ds(), batch_size=batch_size, batcher=generic_batch, worker_config=no_worker_config + ) - cnt = 0 - expected_num_batches = math.ceil(DATASET_SIZE / batch_size) - for idx, wrapped_sample in enumerate(get_ld(batched_ds)): + cnt = 0 + expected_num_batches = math.ceil(DATASET_SIZE / batch_size) + with get_loader(batched_ds) as 
l: + for idx, wrapped_sample in enumerate(l): # Check batch sizes if idx < expected_num_batches - 1: assert wrapped_sample.image.shape[0] == batch_size @@ -375,14 +378,14 @@ def get_ld(ds): logging.info(f" {wrapped_sample.image.shape=}") logging.info(f" {wrapped_sample.caption.shape=}") - assert cnt == expected_num_batches + assert cnt == expected_num_batches - # Check if actual image and caption data are correct - loader = get_ld( - BatchDataset(ds, batch_size=9, batcher=generic_batch, worker_config=no_worker_config), - ) + # Check if actual image and caption data are correct + with get_loader( + BatchDataset(new_ds(), batch_size=9, batcher=generic_batch, worker_config=no_worker_config), + ) as loader: batch_sizes = [] - dataset_samples = {sample["caption"]: sample["image"] for sample in self.samples} + dataset_samples = {sample["caption"]: sample["image"] for sample in samples} for idx, sample in enumerate(loader): batch_sizes.append(sample.image.shape[0]) for bidx in range(sample.image.shape[0]): @@ -396,93 +399,101 @@ def get_ld(ds): assert len(dataset_samples) == 0 assert batch_sizes == [9, 9, 9, 9, 9, 5] - def test_field_access(self): - ds = get_dataset_from_config( - self.dataset_path, - dataset_config="dataset_field.yaml", - split_part="train", - worker_config=no_worker_config, - training=False, - sample_type=CaptioningSample, - ) - captions = set(sample["caption"] for sample in self.samples) - for sample in get_loader(ds.build()): - captions.remove(sample.caption) - assert len(captions) == 0 - def test_sample_loader(self): - ds = get_dataset_from_config( - self.dataset_path, - dataset_config="dataset_sample_loader.yaml", - split_part="train", - worker_config=no_worker_config, - training=False, - sample_type=CaptioningSample, - ) - captions = set(sample["caption"] for sample in self.samples) - for sample in get_loader(ds.build()): +def test_field_access(dataset_path, samples): + ds = get_dataset_from_config( + dataset_path, + 
dataset_config="dataset_field.yaml", + split_part="train", + worker_config=no_worker_config, + training=False, + sample_type=CaptioningSample, + ) + captions = set(sample["caption"] for sample in samples) + with get_loader(ds.build()) as loader: + for sample in loader: + captions.remove(sample.caption) + assert len(captions) == 0 + + +def test_sample_loader(dataset_path, samples): + ds = get_dataset_from_config( + dataset_path, + dataset_config="dataset_sample_loader.yaml", + split_part="train", + worker_config=no_worker_config, + training=False, + sample_type=CaptioningSample, + ) + captions = set(sample["caption"] for sample in samples) + with get_loader(ds.build()) as loader: + for sample in loader: assert sample.caption[:4] == "" captions.remove(sample.caption[4:]) - assert len(captions) == 0 - - def test_sample_loader_key(self): - ds = get_dataset_from_config( - self.dataset_path, - dataset_config="dataset_sample_loader_key.yaml", - split_part="train", - worker_config=no_worker_config, - training=False, - sample_type=CaptioningSample, - ) - captions = set(sample["caption"] for sample in self.samples) - keys = set(f"{idx:06d}" for idx in range(len(self.samples))) - for sample in get_loader(ds.build()): + assert len(captions) == 0 + + +def test_sample_loader_key(dataset_path, samples): + ds = get_dataset_from_config( + dataset_path, + dataset_config="dataset_sample_loader_key.yaml", + split_part="train", + worker_config=no_worker_config, + training=False, + sample_type=CaptioningSample, + ) + captions = set(sample["caption"] for sample in samples) + keys = set(f"{idx:06d}" for idx in range(len(samples))) + with get_loader(ds.build()) as loader: + for sample in loader: assert sample.caption[:4] == "" captions.remove(sample.caption[4:]) keys.remove(sample.__key__) - assert len(captions) == 0 - assert len(keys) == 0 + assert len(captions) == 0 + assert len(keys) == 0 - def test_exclusion(self): - ds = get_dataset_from_config( - self.dataset_path, - 
dataset_config="dataset_exclude.yaml", - split_part="train", - worker_config=no_worker_config, - training=False, - sample_type=CaptioningSample, - ) - keys = [entry.__key__ for entry in get_loader(ds.build())] - assert keys == [f"{i:06d}" for i in list(range(30, 35)) + list(range(40, 50))], keys +def test_exclusion(dataset_path, samples): + ds = get_dataset_from_config( + dataset_path, + dataset_config="dataset_exclude.yaml", + split_part="train", + worker_config=no_worker_config, + training=False, + sample_type=CaptioningSample, + ) - def test_loader(self): - torch.manual_seed(42) + with get_loader(ds.build()) as loader: + keys = [entry.__key__ for entry in loader] + assert keys == [f"{i:06d}" for i in list(range(30, 35)) + list(range(40, 50))], keys - class TestTaskEncoder(DefaultTaskEncoder): - def __init__(self): - super().__init__(raw_batch_type=CaptioningBatch) - def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: - return EncodedCaptioningSample.derive_from( - sample, - image=sample.image, - caption=torch.frombuffer(bytearray(sample.caption.encode()), dtype=torch.uint8), - ) +def test_loader(dataset_path, samples): + torch.manual_seed(42) - loader = get_loader( - get_train_dataset( - self.dataset_path, - batch_size=10, - worker_config=no_worker_config, - parallel_shard_iters=2, - virtual_epoch_length=2, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - task_encoder=TestTaskEncoder(), + class TestTaskEncoder(DefaultTaskEncoder): + def __init__(self): + super().__init__(raw_batch_type=CaptioningBatch) + + def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: + return EncodedCaptioningSample.derive_from( + sample, + image=sample.image, + caption=torch.frombuffer(bytearray(sample.caption.encode()), dtype=torch.uint8), ) - ) + with get_loader( + get_train_dataset( + dataset_path, + batch_size=10, + worker_config=no_worker_config, + parallel_shard_iters=2, + virtual_epoch_length=2, + 
shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=TestTaskEncoder(), + ) + ) as loader: assert len(loader) == 2 def hist(data): @@ -504,43 +515,44 @@ def hist(data): assert len(keyhist) == 50 assert all(v in (39, 40, 41) for v in keyhist.values()) - loader2 = get_loader( - get_val_dataset( - self.dataset_path, - split_part="train", - batch_size=10, - worker_config=no_worker_config, - task_encoder=TestTaskEncoder(), - ) + with get_loader( + get_val_dataset( + dataset_path, + split_part="train", + batch_size=10, + worker_config=no_worker_config, + task_encoder=TestTaskEncoder(), ) + ) as loader2: assert len(loader2) == 5 # The order in the split is shuffled this way assert list(key for batch in loader2 for key in batch.__key__) == [ f"{i:06d}" for i in range(30, 50) ] + [f"{i:06d}" for i in range(30)] - def test_default_dataset(self): - torch.manual_seed(42) - train_loader = get_loader( +def test_default_dataset(dataset_path, samples): + torch.manual_seed(42) + + with ( + get_loader( get_train_dataset( - self.dataset_path, + dataset_path, batch_size=10, worker_config=no_worker_config, shuffle_buffer_size=None, max_samples_per_sequence=None, ) - ) - - val_loader = get_loader( + ) as train_loader, + get_loader( get_val_dataset( - self.dataset_path, + dataset_path, split_part="train", batch_size=10, worker_config=no_worker_config, ) - ) - + ) as val_loader, + ): n_samples = 0 for i, sample in zip(range(100), train_loader): assert sample.image.shape == (10, 3, 100, 100) @@ -552,54 +564,53 @@ def test_default_dataset(self): n_samples += sample.image.shape[0] assert n_samples == 50 - def test_no_batching(self): - train_loader = get_loader( - get_train_dataset( - self.dataset_path, - batch_size=None, - worker_config=no_worker_config, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - ) +def test_no_batching(dataset_path, samples): + with get_loader( + get_train_dataset( + dataset_path, + batch_size=None, + 
worker_config=no_worker_config, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + ) as train_loader: one_sample = next(iter(train_loader)) # Single sample without batching assert isinstance(one_sample.image, torch.Tensor) assert isinstance(one_sample.caption, str) - def test_dataset_len(self): - torch.manual_seed(42) - worker_config = WorkerConfig(rank=0, world_size=1, num_workers=4) +def test_dataset_len(dataset_path, samples): + torch.manual_seed(42) - train_dataset = get_train_dataset( - self.dataset_path, - batch_size=11, - worker_config=worker_config, - virtual_epoch_length=12, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - train_loader = get_loader(train_dataset) + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=4) + train_dataset = get_train_dataset( + dataset_path, + batch_size=11, + worker_config=worker_config, + virtual_epoch_length=12, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + with get_loader(train_dataset) as train_loader: assert len(train_dataset) == 12 assert len(train_loader) == 12 assert len(list(train_loader)) == 12 val_dataset = get_val_dataset( - self.dataset_path, split_part="train", batch_size=1, worker_config=no_worker_config + dataset_path, split_part="train", batch_size=1, worker_config=no_worker_config ) - val_loader = get_loader(val_dataset) + with get_loader(val_dataset) as val_loader: assert len(val_loader) == 50 assert len(list(val_loader)) == 50 - val_dataset = get_val_dataset( - self.dataset_path, split_part="train", batch_size=11, worker_config=worker_config - ) - val_loader = get_loader(val_dataset) - + val_dataset = get_val_dataset( + dataset_path, split_part="train", batch_size=11, worker_config=worker_config + ) + with get_loader(val_dataset) as val_loader: # n samples: ceil(50 / 11) // 4 * 4 assert len(val_dataset) == 8 assert len(val_loader) == 8 @@ -607,40 +618,39 @@ def test_dataset_len(self): assert [len(entry.__key__) for entry in val_loader] == 
[11, 11, 11, 11, 2, 1, 2, 1] assert sum(len(entry.__key__) for entry in val_loader) == 50 - def test_multirank_dataset(self): - torch.manual_seed(42) - worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) - worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2) +def test_multirank_dataset(dataset_path, samples): + torch.manual_seed(42) - train_dataset = get_train_dataset( - self.dataset_path, - batch_size=11, - worker_config=worker_config_r0, - virtual_epoch_length=12, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - train_loader = get_loader(train_dataset) + worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) + worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2) + train_dataset = get_train_dataset( + dataset_path, + batch_size=11, + worker_config=worker_config_r0, + virtual_epoch_length=12, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + with get_loader(train_dataset) as train_loader: assert len(train_dataset) == 12 assert len(train_loader) == 12 assert len(list(train_loader)) == 12 - val_dataset0 = get_val_dataset( - self.dataset_path, split_part="train", batch_size=1, worker_config=worker_config_r0 - ) - val_loader0 = get_loader(val_dataset0) + val_dataset0 = get_val_dataset( + dataset_path, split_part="train", batch_size=1, worker_config=worker_config_r0 + ) + with get_loader(val_dataset0) as val_loader0: print(len(val_loader0)) assert len(val_loader0) == 25 keys0 = set(key for entry in val_loader0 for key in entry.__key__) assert len(keys0) == 25 - val_dataset0b11 = get_val_dataset( - self.dataset_path, split_part="train", batch_size=11, worker_config=worker_config_r0 - ) - val_loader0b11 = get_loader(val_dataset0b11) - + val_dataset0b11 = get_val_dataset( + dataset_path, split_part="train", batch_size=11, worker_config=worker_config_r0 + ) + with get_loader(val_dataset0b11) as val_loader0b11: assert len(val_dataset0b11) == 4 assert len(val_loader0b11) == 4 
assert len(list(val_loader0b11)) == 4 @@ -651,10 +661,10 @@ def test_multirank_dataset(self): assert keys0b11 == keys0 - val_dataset1 = get_val_dataset( - self.dataset_path, split_part="train", batch_size=1, worker_config=worker_config_r1 - ) - val_loader1 = get_loader(val_dataset1) + val_dataset1 = get_val_dataset( + dataset_path, split_part="train", batch_size=1, worker_config=worker_config_r1 + ) + with get_loader(val_dataset1) as val_loader1: print(len(val_loader1)) assert len(val_loader1) == 25 keys1 = set(key for entry in val_loader1 for key in entry.__key__) @@ -663,11 +673,10 @@ def test_multirank_dataset(self): print(sorted(keys0)) assert keys1.isdisjoint(keys0) - val_dataset1b11 = get_val_dataset( - self.dataset_path, split_part="train", batch_size=11, worker_config=worker_config_r1 - ) - val_loader1b11 = get_loader(val_dataset1b11) - + val_dataset1b11 = get_val_dataset( + dataset_path, split_part="train", batch_size=11, worker_config=worker_config_r1 + ) + with get_loader(val_dataset1b11) as val_loader1b11: assert len(val_dataset1b11) == 4 assert len(val_loader1b11) == 4 assert len(list(val_loader1b11)) == 4 @@ -679,71 +688,71 @@ def test_multirank_dataset(self): assert keys1b11 == keys1 - def test_weight_aug(self): - class WeightAugmentTaskEncoder(AugmentTaskEncoder): - def __init__(self, task_encoder: TaskEncoder, weight: float, target_data_class: type): - super().__init__(task_encoder) - self.weight = weight - self.target_data_class = target_data_class - def encode_sample(self, sample): - sample = super().encode_sample(sample) - return self.target_data_class(**dataclasses.asdict(sample), weight=self.weight) +def test_weight_aug(dataset_path, samples): + class WeightAugmentTaskEncoder(AugmentTaskEncoder): + def __init__(self, task_encoder: TaskEncoder, weight: float, target_data_class: type): + super().__init__(task_encoder) + self.weight = weight + self.target_data_class = target_data_class - torch.manual_seed(42) + def encode_sample(self, sample): + 
sample = super().encode_sample(sample) + return self.target_data_class.extend(sample, weight=self.weight) - @edataclass - class WeightedCaptioningBatch(Batch): - image: torch.Tensor - caption: List[str] - weight: float + torch.manual_seed(42) - loader = get_loader( - get_val_dataset( - self.dataset_path, - split_part="train", - batch_size=10, - worker_config=no_worker_config, - task_encoder=WeightAugmentTaskEncoder( - DefaultTaskEncoder(), - weight=0.8, - target_data_class=WeightedCaptioningBatch, - ), - ) - ) + @edataclass + class WeightedCaptioningBatch(Batch): + image: torch.Tensor + caption: List[str] + weight: float + with get_loader( + get_val_dataset( + dataset_path, + split_part="train", + batch_size=10, + worker_config=no_worker_config, + task_encoder=WeightAugmentTaskEncoder( + DefaultTaskEncoder(), + weight=0.8, + target_data_class=WeightedCaptioningBatch, + ), + ) + ) as loader: for data in loader: assert data.weight == [0.8] * 10 - def test_blending(self): - torch.manual_seed(42) - - loader = get_loader( - BlendDataset( - ( - get_train_dataset( - self.dataset_path, - batch_size=10, - worker_config=no_worker_config, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - 2, + +def test_blending(dataset_path, samples): + torch.manual_seed(42) + + with get_loader( + BlendDataset( + ( + get_train_dataset( + dataset_path, + batch_size=10, + worker_config=no_worker_config, + shuffle_buffer_size=None, + max_samples_per_sequence=None, ), - ( - get_train_dataset( - self.dataset_path, - batch_size=20, - worker_config=no_worker_config, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - 8, + 2, + ), + ( + get_train_dataset( + dataset_path, + batch_size=20, + worker_config=no_worker_config, + shuffle_buffer_size=None, + max_samples_per_sequence=None, ), - worker_config=no_worker_config, - ) + 8, + ), + worker_config=no_worker_config, ) - + ) as loader: bs_hist = {10: 0, 20: 0} for i, sample in zip(range(1000), loader): 
bs_hist[sample.image.shape[0]] += 1 @@ -751,50 +760,50 @@ def test_blending(self): assert 150 <= bs_hist[10] <= 250 assert 750 <= bs_hist[20] <= 850 - def test_mixing_homogeneous(self): - @dataclass - class TestBatch(Batch): - image: torch.Tensor - caption: List[str] - source: int - - class TestTaskEncoder(TaskEncoder): - def __init__(self, source: int): - self.source = source - - def encode_batch(self, batch): - return TestBatch(**dataclasses.asdict(batch), source=self.source) - - loader = get_loader( - MixBatchDataset( - ( - get_train_dataset( - self.dataset_path, - batch_size=1, - worker_config=no_worker_config, - task_encoder=TestTaskEncoder(source=0), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - 2, + +def test_mixing_homogeneous(dataset_path, samples): + @dataclass + class TestBatch(Batch): + image: torch.Tensor + caption: List[str] + source: int + + class TestTaskEncoder(TaskEncoder): + def __init__(self, source: int): + self.source = source + + def encode_batch(self, batch): + return TestBatch.extend(batch, source=self.source) + + with get_loader( + MixBatchDataset( + ( + get_train_dataset( + dataset_path, + batch_size=1, + worker_config=no_worker_config, + task_encoder=TestTaskEncoder(source=0), + shuffle_buffer_size=None, + max_samples_per_sequence=None, ), - ( - get_train_dataset( - self.dataset_path, - batch_size=1, - worker_config=no_worker_config, - task_encoder=TestTaskEncoder(source=1), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - 8, + 2, + ), + ( + get_train_dataset( + dataset_path, + batch_size=1, + worker_config=no_worker_config, + task_encoder=TestTaskEncoder(source=1), + shuffle_buffer_size=None, + max_samples_per_sequence=None, ), - batch_size=10, - batch_mix_fn=homogeneous_concat_mix, - worker_config=no_worker_config, - ) + 8, + ), + batch_size=10, + batch_mix_fn=homogeneous_concat_mix, + worker_config=no_worker_config, ) - + ) as loader: source_hist = {0: 0, 1: 0} for i, sample in zip(range(1000), 
loader): assert sample.image.shape == (10, 3, 100, 100) @@ -803,54 +812,54 @@ def encode_batch(self, batch): assert 1500 <= source_hist[0] <= 2500 assert 7500 <= source_hist[1] <= 8500 - def test_mixing_heterogeneous(self): - @dataclass - class TestBatch1(Batch): - image: torch.Tensor - caption: List[str] - source: int - - @dataclass - class TestBatch2(TestBatch1): - pass - - class TestTaskEncoder(TaskEncoder): - def __init__(self, source: int, batch_cls: Type[TestBatch1]): - self.source = source - self.batch_cls = batch_cls - - def encode_batch(self, batch): - return self.batch_cls(**dataclasses.asdict(batch), source=self.source) - - loader = get_loader( - MixBatchDataset( - ( - get_train_dataset( - self.dataset_path, - batch_size=1, - worker_config=no_worker_config, - task_encoder=TestTaskEncoder(source=0, batch_cls=TestBatch1), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - 2, + +def test_mixing_heterogeneous(dataset_path, samples): + @dataclass + class TestBatch1(Batch): + image: torch.Tensor + caption: List[str] + source: int + + @dataclass + class TestBatch2(TestBatch1): + pass + + class TestTaskEncoder(TaskEncoder): + def __init__(self, source: int, batch_cls: Type[TestBatch1]): + self.source = source + self.batch_cls = batch_cls + + def encode_batch(self, batch): + return self.batch_cls.extend(batch, source=self.source) + + with get_loader( + MixBatchDataset( + ( + get_train_dataset( + dataset_path, + batch_size=1, + worker_config=no_worker_config, + task_encoder=TestTaskEncoder(source=0, batch_cls=TestBatch1), + shuffle_buffer_size=None, + max_samples_per_sequence=None, ), - ( - get_train_dataset( - self.dataset_path, - batch_size=1, - worker_config=no_worker_config, - task_encoder=TestTaskEncoder(source=1, batch_cls=TestBatch2), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - 8, + 2, + ), + ( + get_train_dataset( + dataset_path, + batch_size=1, + worker_config=no_worker_config, + 
task_encoder=TestTaskEncoder(source=1, batch_cls=TestBatch2), + shuffle_buffer_size=None, + max_samples_per_sequence=None, ), - batch_size=10, - worker_config=no_worker_config, - ) + 8, + ), + batch_size=10, + worker_config=no_worker_config, ) - + ) as loader: source_hist = {0: 0, 1: 0} for i, samples in zip(range(1000), loader): assert len(samples) == 10 @@ -860,70 +869,76 @@ def encode_batch(self, batch): assert 1500 <= source_hist[0] <= 2500 assert 7500 <= source_hist[1] <= 8500 - def test_val_limit(self): - torch.manual_seed(42) - loader = get_loader( - get_val_dataset( - self.dataset_path, - split_part="train", - batch_size=2, - worker_config=no_worker_config, - limit=3, - ) - ) +def test_val_limit(dataset_path, samples): + torch.manual_seed(42) + with get_loader( + get_val_dataset( + dataset_path, + split_part="train", + batch_size=2, + worker_config=no_worker_config, + limit=3, + ) + ) as loader: assert len(loader) == 3 samples = [[batch.__key__ for batch in loader] for _ in range(10)] print(samples) + for s in samples: + print(" -", s) assert all(samples[0] == one_ep_samples for one_ep_samples in samples) - worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2) + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2) - loader = get_loader( - get_val_dataset( - self.dataset_path, - split_part="train", - batch_size=2, - worker_config=worker_config, - limit=3, - ) + with get_loader( + get_val_dataset( + dataset_path, + split_part="train", + batch_size=2, + worker_config=worker_config, + limit=3, ) - + ) as loader: assert len(loader) == 3 samples_wrk2 = [[batch.__key__ for batch in loader] for _ in range(10)] - print(samples) - assert all(samples_wrk2[0] == one_ep_samples for one_ep_samples in samples_wrk2) - - def test_current_batch_index(self): - # Tests if the get_current_batch_index works properly - torch.manual_seed(42) - - class TestTaskEncoder(TaskEncoder): - @stateless(restore_seeds=True) - def encode_sample(self, sample): - # 
print("si stack:", WorkerConfig._sample_index_stack) - return ExtendedCaptioningSample.extend( - sample, - batch_index=self.current_batch_index, - sample_index=self.current_sample_index, - rand_num=random.randint(0, 1000), - ) + print(samples_wrk2) + for s in samples_wrk2: + print(" -", s) + assert all( + all(a == b for a, b in zip(samples_wrk2[0], one_ep_samples)) + for one_ep_samples in samples_wrk2 + ) - # First, test simple single main-thread loader with accessing get_current_batch_index - loader = get_loader( - get_train_dataset( - self.dataset_path, - batch_size=2, - task_encoder=TestTaskEncoder(), - worker_config=no_worker_config, - shuffle_buffer_size=20, - max_samples_per_sequence=10, + +def test_current_batch_index(dataset_path, samples): + # Tests if the get_current_batch_index works properly + torch.manual_seed(42) + + class TestTaskEncoder(TaskEncoder): + @stateless(restore_seeds=True) + def encode_sample(self, sample): + # print("si stack:", WorkerConfig._sample_index_stack) + return ExtendedCaptioningSample.extend( + sample, + batch_index=self.current_batch_index, + sample_index=self.current_sample_index, + rand_num=random.randint(0, 1000), ) - ) + # First, test simple single main-thread loader with accessing get_current_batch_index + with get_loader( + get_train_dataset( + dataset_path, + batch_size=2, + task_encoder=TestTaskEncoder(), + worker_config=no_worker_config, + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ) + ) as loader: batches = list(zip(range(20), loader)) print("bi", [batch.batch_index for batch_idx, batch in batches]) assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) @@ -981,31 +996,32 @@ def encode_sample(self, sample): print("batch_rand_nums: ", batch_rand_nums) assert batch_rand_nums == ref_batch_rand_nums - # Now, test multi-worker loader with accessing get_current_batch_index - worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) - worker_config_r1 = 
WorkerConfig(rank=1, world_size=2, num_workers=2) + # Now, test multi-worker loader with accessing get_current_batch_index + worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) + worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2) - loader = get_loader( + with ( + get_loader( get_train_dataset( - self.dataset_path, + dataset_path, batch_size=2, task_encoder=TestTaskEncoder(), worker_config=worker_config_r0, shuffle_buffer_size=20, max_samples_per_sequence=10, ) - ) - loader_r1 = get_loader( + ) as loader, + get_loader( get_train_dataset( - self.dataset_path, + dataset_path, batch_size=2, task_encoder=TestTaskEncoder(), worker_config=worker_config_r1, shuffle_buffer_size=20, max_samples_per_sequence=10, ) - ) - + ) as loader_r1, + ): batches = list(zip(range(20), loader)) print("bir0", [batch.batch_index for batch_idx, batch in batches]) assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) @@ -1034,28 +1050,29 @@ def encode_sample(self, sample): for batch_idx, batch in batches_r1 ) - # Now, test multi-worker loader with accessing get_current_batch_index and save/restore state - loader = get_savable_loader( + # Now, test multi-worker loader with accessing get_current_batch_index and save/restore state + with ( + get_savable_loader( get_train_dataset( - self.dataset_path, + dataset_path, batch_size=2, task_encoder=TestTaskEncoder(), worker_config=worker_config_r0, shuffle_buffer_size=20, max_samples_per_sequence=10, ) - ) - loader_r1 = get_savable_loader( + ) as loader, + get_savable_loader( get_train_dataset( - self.dataset_path, + dataset_path, batch_size=2, task_encoder=TestTaskEncoder(), worker_config=worker_config_r1, shuffle_buffer_size=20, max_samples_per_sequence=10, ) - ) - + ) as loader_r1, + ): batches = list(zip(range(20), loader)) print([batch.batch_index for batch_idx, batch in batches]) assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) @@ 
-1083,19 +1100,17 @@ def encode_sample(self, sample): # Save and restore state state = loader.save_state_rank() - # Restore state and check if the batch index is restored correctly - loader = get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=2, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r0, - shuffle_buffer_size=20, - max_samples_per_sequence=10, - ) + # Restore state and check if the batch index is restored correctly + with get_savable_loader( + get_train_dataset( + dataset_path, + batch_size=2, + task_encoder=TestTaskEncoder(), + worker_config=worker_config_r0, + shuffle_buffer_size=20, + max_samples_per_sequence=10, ) - loader.restore_state_rank(state) - + ).with_restored_state_rank(state) as loader: batches = list(zip(range(20, 40), loader)) print([batch.batch_index for batch_idx, batch in batches]) print([batch.sample_index for batch_idx, batch in batches]) @@ -1108,40 +1123,40 @@ def encode_sample(self, sample): for batch_idx, batch in batches ) - def test_current_batch_index_generator(self): - # Tests if the get_current_batch_index works properly - torch.manual_seed(42) - - class TestTaskEncoder(TaskEncoder): - @stateless(restore_seeds=True) - def encode_sample(self, sample): - # print("si stack:", WorkerConfig._sample_index_stack) - yield ExtendedCaptioningSample.extend( - sample, - batch_index=self.current_batch_index, - sample_index=self.current_sample_index, - rand_num=random.randint(0, 1000) + 0, - ) - yield ExtendedCaptioningSample.extend( - sample, - batch_index=self.current_batch_index, - sample_index=self.current_sample_index, - rand_num=random.randint(0, 1000) + 1000, - ) +def test_current_batch_index_generator(dataset_path, samples): + # Tests if the get_current_batch_index works properly + torch.manual_seed(42) - # First, test simple single main-thread loader with accessing get_current_batch_index - loader = get_loader( - get_train_dataset( - self.dataset_path, - batch_size=3, - 
task_encoder=TestTaskEncoder(), - worker_config=no_worker_config, - shuffle_buffer_size=20, - max_samples_per_sequence=10, + class TestTaskEncoder(TaskEncoder): + @stateless(restore_seeds=True) + def encode_sample(self, sample): + # print("si stack:", WorkerConfig._sample_index_stack) + yield ExtendedCaptioningSample.extend( + sample, + batch_index=self.current_batch_index, + sample_index=self.current_sample_index, + rand_num=random.randint(0, 1000) + 0, ) - ) + yield ExtendedCaptioningSample.extend( + sample, + batch_index=self.current_batch_index, + sample_index=self.current_sample_index, + rand_num=random.randint(0, 1000) + 1000, + ) + + # First, test simple single main-thread loader with accessing get_current_batch_index + with get_loader( + get_train_dataset( + dataset_path, + batch_size=3, + task_encoder=TestTaskEncoder(), + worker_config=no_worker_config, + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ) + ) as loader: batches = list(zip(range(20), loader)) print("bi", [batch.batch_index for batch_idx, batch in batches]) assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) @@ -1197,31 +1212,32 @@ def encode_sample(self, sample): print("batch_rand_nums: ", batch_rand_nums) assert batch_rand_nums == ref_batch_rand_nums - # Now, test multi-worker loader with accessing get_current_batch_index - worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) - worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2) + # Now, test multi-worker loader with accessing get_current_batch_index + worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) + worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2) - loader = get_loader( + with ( + get_loader( get_train_dataset( - self.dataset_path, + dataset_path, batch_size=3, task_encoder=TestTaskEncoder(), worker_config=worker_config_r0, shuffle_buffer_size=20, max_samples_per_sequence=10, ) - ) - loader_r1 = get_loader( + ) as 
loader, + get_loader( get_train_dataset( - self.dataset_path, + dataset_path, batch_size=3, task_encoder=TestTaskEncoder(), worker_config=worker_config_r1, shuffle_buffer_size=20, max_samples_per_sequence=10, ) - ) - + ) as loader_r1, + ): batches = list(zip(range(20), loader)) print("bir0", [batch.batch_index for batch_idx, batch in batches]) assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) @@ -1250,30 +1266,29 @@ def encode_sample(self, sample): for batch_idx, batch in batches_r1 ) - # Now, test multi-worker loader with accessing get_current_batch_index and save/restore state - loader = get_savable_loader( + # Now, test multi-worker loader with accessing get_current_batch_index and save/restore state + with ( + get_savable_loader( get_train_dataset( - self.dataset_path, + dataset_path, batch_size=3, task_encoder=TestTaskEncoder(), worker_config=worker_config_r0, shuffle_buffer_size=20, max_samples_per_sequence=10, ), - worker_config=worker_config_r0, - ) - loader_r1 = get_savable_loader( + ) as loader, + get_savable_loader( get_train_dataset( - self.dataset_path, + dataset_path, batch_size=3, task_encoder=TestTaskEncoder(), worker_config=worker_config_r1, shuffle_buffer_size=20, max_samples_per_sequence=10, ), - worker_config=worker_config_r1, - ) - + ) as loader_r1, + ): batches = list(zip(range(20), loader)) print("bi:", [batch.batch_index for batch_idx, batch in batches]) print("si:", [batch.sample_index for batch_idx, batch in batches]) @@ -1318,20 +1333,17 @@ def encode_sample(self, sample): for batch_idx, batch in cmp_batches ) - # Restore state and check if the batch index is restored correctly - loader = get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=3, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r0, - shuffle_buffer_size=20, - max_samples_per_sequence=10, - ), + # Restore state and check if the batch index is restored correctly + with get_savable_loader( + 
get_train_dataset( + dataset_path, + batch_size=3, + task_encoder=TestTaskEncoder(), worker_config=worker_config_r0, - ) - loader.restore_state_rank(state) - + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ), + ).with_restored_state_rank(state) as loader: batches = list(zip(range(20, 40), loader)) print("bi:", [batch.batch_index for batch_idx, batch in batches]) print("si:", [batch.sample_index for batch_idx, batch in batches]) @@ -1349,51 +1361,51 @@ def encode_sample(self, sample): for (_b1idx, b1), (_b2idx, b2) in zip(batches, cmp_batches) ) - def test_packing(self): - torch.manual_seed(42) - class TestTaskEncoder(DefaultTaskEncoder): - def __init__(self): - super().__init__(raw_batch_type=CaptioningBatch) +def test_packing(dataset_path, samples): + torch.manual_seed(42) - @stateless - def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: - return EncodedCaptioningSample.derive_from( - sample, - image=sample.image, - caption=torch.frombuffer(sample.caption.encode(), dtype=torch.uint8), - ) + class TestTaskEncoder(DefaultTaskEncoder): + def __init__(self): + super().__init__(raw_batch_type=CaptioningBatch) - def select_samples_to_pack( - self, samples: List[EncodedCaptioningSample] - ) -> List[List[EncodedCaptioningSample]]: - assert len(samples) == 21 - return [samples[:1], samples[1 : 1 + 4], samples[1 + 4 : 1 + 4 + 16]] - - @stateless - def pack_selected_samples( - self, samples: List[EncodedCaptioningSample] - ) -> EncodedCaptioningSample: - return EncodedCaptioningSample( - __key__=",".join([sample.__key__ for sample in samples]), - __restore_key__=(), - image=torch.stack([sample.image for sample in samples]), - caption=torch.cat([sample.caption for sample in samples]), - ) + @stateless + def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: + return EncodedCaptioningSample.derive_from( + sample, + image=sample.image, + caption=torch.frombuffer(sample.caption.encode(), dtype=torch.uint8), + ) - 
loader = get_loader( - get_train_dataset( - self.dataset_path, - batch_size=2, - packing_buffer_size=21, - worker_config=no_worker_config, - virtual_epoch_length=6, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - task_encoder=TestTaskEncoder(), + def select_samples_to_pack( + self, samples: List[EncodedCaptioningSample] + ) -> List[List[EncodedCaptioningSample]]: + assert len(samples) == 21 + return [samples[:1], samples[1 : 1 + 4], samples[1 + 4 : 1 + 4 + 16]] + + @stateless + def pack_selected_samples( + self, samples: List[EncodedCaptioningSample] + ) -> EncodedCaptioningSample: + return EncodedCaptioningSample( + __key__=",".join([sample.__key__ for sample in samples]), + __restore_key__=None, + image=torch.stack([sample.image for sample in samples]), + caption=torch.cat([sample.caption for sample in samples]), ) - ) + with get_loader( + get_train_dataset( + dataset_path, + batch_size=2, + packing_buffer_size=21, + worker_config=no_worker_config, + virtual_epoch_length=6, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=TestTaskEncoder(), + ) + ) as loader: assert len(loader) == 6 samples = list(loader) @@ -1422,23 +1434,20 @@ def pack_selected_samples( assert restored_sample_1.__key__ == samples[1].__key__ assert restored_sample_1.__restore_key__ == samples[1].__restore_key__ - worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) - - loader_r0 = get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=2, - packing_buffer_size=21, - worker_config=worker_config_r0, - virtual_epoch_length=8, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - task_encoder=TestTaskEncoder(), - ), - checkpoint_every_min_n_samples=1, - checkpoint_every_sec=0, - ) + worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) + with get_savable_loader( + get_train_dataset( + dataset_path, + batch_size=2, + packing_buffer_size=21, + worker_config=worker_config_r0, + virtual_epoch_length=8, + 
shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=TestTaskEncoder(), + ), + ) as loader_r0: samples_r0 = list(loader_r0) assert [ [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples_r0 @@ -1454,23 +1463,18 @@ def pack_selected_samples( [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples_r0_cmp ] == [[16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4], [16, 1], [16, 1]] - loader_r0 = get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=2, - packing_buffer_size=21, - worker_config=worker_config_r0, - virtual_epoch_length=8, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - task_encoder=TestTaskEncoder(), - ), - checkpoint_every_min_n_samples=1, - checkpoint_every_sec=0, - ) - - loader_r0.restore_state_rank(rank_state_r0) - + with get_savable_loader( + get_train_dataset( + dataset_path, + batch_size=2, + packing_buffer_size=21, + worker_config=worker_config_r0, + virtual_epoch_length=8, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=TestTaskEncoder(), + ), + ).with_restored_state_rank(rank_state_r0) as loader_r0: samples_r0_restored = list(loader_r0) print("cmp", [batch.__key__ for batch in samples_r0_cmp]) print("rst", [batch.__key__ for batch in samples_r0_restored]) @@ -1481,56 +1485,56 @@ def pack_selected_samples( assert all(s0.__key__ == s1.__key__ for s0, s1 in zip(samples_r0_cmp, samples_r0_restored)) - def test_packing_val(self): - torch.manual_seed(42) - class TestTaskEncoder(DefaultTaskEncoder): - def __init__(self): - super().__init__(raw_batch_type=CaptioningBatch) +def test_packing_val(dataset_path, samples): + torch.manual_seed(42) - @stateless - def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: - return EncodedCaptioningSample.derive_from( - sample, - image=sample.image, - caption=torch.frombuffer(sample.caption.encode(), dtype=torch.uint8), - ) - - def select_samples_to_pack( - self, 
samples: List[EncodedCaptioningSample] - ) -> List[List[EncodedCaptioningSample]]: - assert len(samples) in (1 + 3 + 5 + 2, 50 % 11) - if len(samples) < 11: - return [] - return [ - samples[1 + 3 + 5 : 1 + 3 + 5 + 2], - samples[1 + 3 : 1 + 3 + 5], - samples[1 : 1 + 3], - samples[:1], - ] + class TestTaskEncoder(DefaultTaskEncoder): + def __init__(self): + super().__init__(raw_batch_type=CaptioningBatch) - @stateless - def pack_selected_samples( - self, samples: List[EncodedCaptioningSample] - ) -> EncodedCaptioningSample: - return EncodedCaptioningSample( - __key__=",".join([sample.__key__ for sample in samples]), - __restore_key__=(), - image=torch.stack([sample.image for sample in samples]), - caption=torch.cat([sample.caption for sample in samples]), - ) + @stateless + def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: + return EncodedCaptioningSample.derive_from( + sample, + image=sample.image, + caption=torch.frombuffer(sample.caption.encode(), dtype=torch.uint8), + ) - loader = get_loader( - get_val_dataset( - self.dataset_path, - batch_size=2, - packing_buffer_size=11, - worker_config=no_worker_config, - task_encoder=TestTaskEncoder(), - split_part="train", + def select_samples_to_pack( + self, samples: List[EncodedCaptioningSample] + ) -> List[List[EncodedCaptioningSample]]: + assert len(samples) in (1 + 3 + 5 + 2, 50 % 11) + if len(samples) < 11: + return [] + return [ + samples[1 + 3 + 5 : 1 + 3 + 5 + 2], + samples[1 + 3 : 1 + 3 + 5], + samples[1 : 1 + 3], + samples[:1], + ] + + @stateless + def pack_selected_samples( + self, samples: List[EncodedCaptioningSample] + ) -> EncodedCaptioningSample: + return EncodedCaptioningSample( + __key__=",".join([sample.__key__ for sample in samples]), + __restore_key__=None, + image=torch.stack([sample.image for sample in samples]), + caption=torch.cat([sample.caption for sample in samples]), ) - ) + with get_loader( + get_val_dataset( + dataset_path, + batch_size=2, + packing_buffer_size=11, 
+ worker_config=no_worker_config, + task_encoder=TestTaskEncoder(), + split_part="train", + ) + ) as loader: assert len(loader) == 25, f"len(loader) == {len(loader)}" samples = list(loader) @@ -1561,61 +1565,57 @@ def pack_selected_samples( assert restored_sample_1.__key__ == samples[1].__key__ assert restored_sample_1.__restore_key__ == samples[1].__restore_key__ - def test_group_batch(self): - class GroupingTaskEncoder( - TaskEncoder[CaptioningSample, CaptioningSample, CaptioningSample, CaptioningSample] - ): - @stateless - def encode_sample(self, sample: CaptioningSample) -> CaptioningSample: - sample.caption = sample.__sources__[0].shard_name.split("/")[-1] - return sample - - def batch_group_criterion(self, sample: CaptioningSample) -> Tuple[Hashable, int]: - if sample.caption == "data-0.tar": - return "shard1", 4 - elif sample.caption == "data-1.tar": - return "shard2", 8 - else: - assert False - - @stateless - def encode_batch(self, batch: CaptioningSample) -> CaptioningEncodedBatch: - return CaptioningEncodedBatch(**dataclasses.asdict(batch)) - - worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) - loader = get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=None, - worker_config=worker_config, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - task_encoder=GroupingTaskEncoder(), - ), - checkpoint_every_min_n_samples=1, - checkpoint_every_sec=0, - ) + +def test_group_batch(dataset_path, samples): + class GroupingTaskEncoder( + TaskEncoder[CaptioningSample, CaptioningSample, CaptioningSample, CaptioningSample] + ): + @stateless + def encode_sample(self, sample: CaptioningSample) -> CaptioningSample: + sample.caption = sample.__sources__[0].shard_name.split("/")[-1] + return sample + + def batch_group_criterion(self, sample: CaptioningSample) -> Tuple[Hashable, int]: + if sample.caption == "data-0.tar": + return "shard1", 4 + elif sample.caption == "data-1.tar": + return "shard2", 8 + else: + assert False + + 
@stateless + def encode_batch(self, batch: CaptioningSample) -> CaptioningEncodedBatch: + return CaptioningEncodedBatch.extend(batch) + + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) + with get_savable_loader( + get_train_dataset( + dataset_path, + batch_size=None, + worker_config=worker_config, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=GroupingTaskEncoder(), + ), + ) as loader: batches = list(zip(range(40), loader)) print([batch.__key__ for idx, batch in batches]) assert all(isinstance(batch, CaptioningEncodedBatch) for idx, batch in batches) assert all(all(key == batch.caption[0] for key in batch.caption) for idx, batch in batches) - worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) - - loader_r0 = get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=None, - worker_config=worker_config_r0, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - task_encoder=GroupingTaskEncoder(), - ), - checkpoint_every_min_n_samples=1, - checkpoint_every_sec=0, - ) + worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) + with get_savable_loader( + get_train_dataset( + dataset_path, + batch_size=None, + worker_config=worker_config_r0, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=GroupingTaskEncoder(), + ), + ) as loader_r0: batches = list(zip(range(40), loader_r0)) print([batch.__key__ for idx, batch in batches]) @@ -1628,20 +1628,16 @@ def encode_batch(self, batch: CaptioningSample) -> CaptioningEncodedBatch: cmp_samples = list(zip(range(40, 80), loader_r0)) print([batch.__key__ for idx, batch in cmp_samples]) - loader_r0 = get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=None, - worker_config=worker_config_r0, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - task_encoder=GroupingTaskEncoder(), - ), - checkpoint_every_min_n_samples=1, - checkpoint_every_sec=0, - ) - 
loader_r0.restore_state_rank(state) - + with get_savable_loader( + get_train_dataset( + dataset_path, + batch_size=None, + worker_config=worker_config_r0, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=GroupingTaskEncoder(), + ), + ).with_restored_state_rank(state) as loader_r0: cmp_samples_rest = list(zip(range(40, 80), loader_r0)) print([batch.__key__ for idx, batch in cmp_samples_rest]) @@ -1658,34 +1654,34 @@ def encode_batch(self, batch: CaptioningSample) -> CaptioningEncodedBatch: for (idx, cmp_sample), (idx, cmp_sample_rest) in zip(cmp_samples, cmp_samples_rest) ) - def test_debug_dataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - worker_log_level=3, - worker_debug_path=str(self.dataset_path) + "/worker_debug/{worker_id}.jsonl", - ) - # Reset this to 0 to make sure the test is deterministic - SavableDataLoader._next_id = 0 +def test_debug_dataset(dataset_path, samples): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + worker_log_level=3, + worker_debug_path=str(dataset_path) + "/worker_debug/{worker_id}.jsonl", + ) - loader = get_savable_loader( - get_val_dataset( - self.dataset_path, - split_part="train", - batch_size=5, - worker_config=worker_config, - ), - ) + # Reset this to 0 to make sure the test is deterministic + DataLoader._next_id = 0 + with get_savable_loader( + get_val_dataset( + dataset_path, + split_part="train", + batch_size=5, + worker_config=worker_config, + ), + ) as loader: assert len(loader) == 10 samples = [[batch.__key__ for batch in loader] for _ in range(2)] print(samples) - debug_log_path = self.dataset_path / "worker_debug" + debug_log_path = dataset_path / "worker_debug" assert (debug_log_path / "0.jsonl").is_file() assert (debug_log_path / "1.jsonl").is_file() assert (debug_log_path / "2.jsonl").is_file() @@ -1694,181 +1690,184 @@ def test_debug_dataset(self): with (debug_log_path / 
"0.jsonl").open() as rf: for line in rf: line_data = json.loads(line) - if line_data["t"] == "SavableDataLoader.yield": - print(line_data) + print(line_data) + if line_data["t"] == "DataLoader.epoch_iter.yield": for i in range(len(collected_keys_order)): - if collected_keys_order[i][line_data["idx"]] is None: - collected_keys_order[i][line_data["idx"]] = line_data["keys"] + if collected_keys_order[i][line_data["epoch_sample_idx"]] is None: + collected_keys_order[i][line_data["epoch_sample_idx"]] = line_data[ + "keys" + ] break else: assert False, "Too many entries for key" - print(collected_keys_order) - assert collected_keys_order == samples - - runner = CliRunner() - result = runner.invoke( - analyze_debug_command, - [ - str(debug_log_path), - "--include-modality", - "train,val", - "--heatmap-path", - str(self.dataset_path / "heatmap.png"), - ], - catch_exceptions=False, - ) - print(result.stdout) - assert result.exit_code == 0, "Debug analysis failed, see output" - assert "Analyzing 3 logs" in result.stdout - assert "Found 50 unique sample keys, 20 steps" in result.stdout - - def test_validate_captioning_dataset(self): - runner = CliRunner() - result = runner.invoke( - lint_command, - [str(self.dataset_path), "--split-parts=train"], - catch_exceptions=False, - ) - assert result.exit_code == 0, "Validation failed, see output" - - def test_prepare_dataset(self): - runner = CliRunner() - result = runner.invoke( - prepare_command, - [str(self.dataset_path)], - catch_exceptions=False, - input="y\n1,0,0\ny\n0\nY\npng\ntxt\n", - ) - assert result.exit_code == 0, "Prepare failed, see output" - assert "Done" in result.stdout, "Prepare failed, see output" - - def test_preview_captioning_dataset(self): - runner = CliRunner() - result = runner.invoke( - preview_command, - [str(self.dataset_path), "--split-parts=train"], - input="n\n", - catch_exceptions=False, - ) - # First sample! 
- assert "__key__ (): '000030'" in result.stdout - assert result.exit_code == 0, "Preview failed, see output" - - def test_info_captioning_dataset(self): - runner = CliRunner() - result = runner.invoke( - info_command, - [str(self.dataset_path)], - catch_exceptions=False, - ) - print(result.stdout) - assert "50 samples" in result.stdout - assert "2 shards" in result.stdout - assert str(self.dataset_path) in result.stdout - assert "train" in result.stdout - assert result.exit_code == 0, "Preview failed, see output" - - def test_custom_error_handler(self): - """Test that custom error handlers work correctly in TaskEncoder.""" - torch.manual_seed(42) - - # Track error handler calls - error_calls = [] - - class ErrorProneTaskEncoder(DefaultTaskEncoder): - def __init__(self): - super().__init__(raw_batch_type=CaptioningBatch) - - @stateless - def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: - # Intentionally raise an error for specific samples to test error handling - if "000035" in sample.__key__: - raise ValueError(f"Intentional error for {sample.__key__}") - return EncodedCaptioningSample.derive_from( - sample, - image=sample.image, - caption=torch.frombuffer(bytearray(sample.caption.encode()), dtype=torch.uint8), - ) - - # Test with custom error handler - - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - global_error_handler=lambda e, s, sources: error_calls.append( - { - "exception": e, - "sample_key": getattr(s, "__key__", None), - "exception_type": type(e).__name__, - } - ), - ) - - loader = get_loader( - get_train_dataset( - self.dataset_path, - batch_size=5, - worker_config=worker_config, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - virtual_epoch_length=50, - task_encoder=ErrorProneTaskEncoder(), + print(collected_keys_order) + assert collected_keys_order == samples + + runner = CliRunner() + result = runner.invoke( + analyze_debug_command, + [ + str(debug_log_path), + 
"--include-modality", + "train,val", + "--heatmap-path", + str(dataset_path / "heatmap.png"), + ], + catch_exceptions=False, + ) + print(result.stdout) + assert result.exit_code == 0, "Debug analysis failed, see output" + assert "Analyzing 3 logs" in result.stdout + assert "Found 50 unique sample keys, 20 steps" in result.stdout + + +def test_validate_captioning_dataset(dataset_path, samples): + runner = CliRunner() + result = runner.invoke( + lint_command, + [str(dataset_path), "--split-parts=train"], + catch_exceptions=False, + ) + assert result.exit_code == 0, "Validation failed, see output" + + +def test_prepare_dataset(dataset_path, samples): + runner = CliRunner() + result = runner.invoke( + prepare_command, + [str(dataset_path)], + catch_exceptions=False, + input="y\n1,0,0\ny\n0\nY\npng\ntxt\n", + ) + assert result.exit_code == 0, "Prepare failed, see output" + assert "Done" in result.stdout, "Prepare failed, see output" + + +def test_preview_captioning_dataset(dataset_path, samples): + runner = CliRunner() + result = runner.invoke( + preview_command, + [str(dataset_path), "--split-parts=train"], + input="n\n", + catch_exceptions=False, + ) + # First sample! 
+ assert "__key__ (): '000030'" in result.stdout + assert result.exit_code == 0, "Preview failed, see output" + + +def test_info_captioning_dataset(dataset_path, samples): + runner = CliRunner() + result = runner.invoke( + info_command, + [str(dataset_path)], + catch_exceptions=False, + ) + print(result.stdout) + assert "50 samples" in result.stdout + assert "2 shards" in result.stdout + assert str(dataset_path) in result.stdout + assert "train" in result.stdout + assert result.exit_code == 0, "Preview failed, see output" + + +def test_custom_error_handler(dataset_path, samples): + """Test that custom error handlers work correctly in TaskEncoder.""" + torch.manual_seed(42) + + # Track error handler calls + error_calls = [] + + class ErrorProneTaskEncoder(DefaultTaskEncoder): + def __init__(self): + super().__init__(raw_batch_type=CaptioningBatch) + + @stateless + def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: + # Intentionally raise an error for specific samples to test error handling + if "000035" in sample.__key__: + raise ValueError(f"Intentional error for {sample.__key__}") + return EncodedCaptioningSample.derive_from( + sample, + image=sample.image, + caption=torch.frombuffer(bytearray(sample.caption.encode()), dtype=torch.uint8), ) + + # Test with custom error handler + + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + global_error_handler=lambda e, s, sources: error_calls.append( + { + "exception": e, + "sample_key": getattr(s, "__key__", None), + "exception_type": type(e).__name__, + } + ), + ) + + loader = get_loader( + get_train_dataset( + dataset_path, + batch_size=5, + worker_config=worker_config, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + virtual_epoch_length=50, + task_encoder=ErrorProneTaskEncoder(), + ) + ) + + # Iterate through the loader - errors should be handled by custom handler + batches = [] + for i, batch in enumerate(loader): + batches.append(batch) + if i >= 9: # 
Get 10 batches (50 samples total) + break + + # Verify that the error handler was called + assert len(error_calls) > 0, "Error handler should have been called" + + # Verify that the error was for the right sample + assert any("000035" in call["sample_key"] for call in error_calls), ( + f"Error should have been for sample 000035, got: {error_calls}" + ) + + # Verify the exception type + assert all(call["exception_type"] == "ValueError" for call in error_calls), ( + "All errors should be ValueError" + ) + + print("Step 2: Reraise") + + def reraise(e, s, sources): + raise ShouldRaiseException() from e + + worker_config_r1 = WorkerConfig( + rank=0, + world_size=1, + num_workers=1, + global_error_handler=reraise, + ) + + loader = get_loader( + get_train_dataset( + dataset_path, + batch_size=5, + worker_config=worker_config_r1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + virtual_epoch_length=50, + task_encoder=ErrorProneTaskEncoder(), ) + ) - # Iterate through the loader - errors should be handled by custom handler + with pytest.raises(ShouldRaiseException): batches = [] for i, batch in enumerate(loader): batches.append(batch) if i >= 9: # Get 10 batches (50 samples total) break - - # Verify that the error handler was called - assert len(error_calls) > 0, "Error handler should have been called" - - # Verify that the error was for the right sample - assert any("000035" in call["sample_key"] for call in error_calls), ( - f"Error should have been for sample 000035, got: {error_calls}" - ) - - # Verify the exception type - assert all(call["exception_type"] == "ValueError" for call in error_calls), ( - "All errors should be ValueError" - ) - - print("Step 2: Reraise") - - def reraise(e, s, sources): - raise ShouldRaiseException() from e - - worker_config_r1 = WorkerConfig( - rank=0, - world_size=1, - num_workers=1, - global_error_handler=reraise, - ) - - loader = get_loader( - get_train_dataset( - self.dataset_path, - batch_size=5, - 
worker_config=worker_config_r1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - virtual_epoch_length=50, - task_encoder=ErrorProneTaskEncoder(), - ) - ) - - with self.assertRaises(ShouldRaiseException): - batches = [] - for i, batch in enumerate(loader): - batches.append(batch) - if i >= 9: # Get 10 batches (50 samples total) - break - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_dataset_det.py b/tests/test_dataset_det.py index 15650307..f17d5aad 100644 --- a/tests/test_dataset_det.py +++ b/tests/test_dataset_det.py @@ -8,11 +8,11 @@ import random import sys import tempfile -import unittest import warnings from collections import Counter from pathlib import Path +import pytest import torch import webdataset as wds import yaml @@ -23,12 +23,12 @@ TextSample, WorkerConfig, get_loader, + get_savable_loader, get_train_dataset, + stateless, ) from megatron.energon.dataset_config import get_dataset_from_config from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME -from megatron.energon.loader import get_savable_loader -from megatron.energon.task_encoder.base import stateless from megatron.energon.tools.checkpoint import command_redist # Speed up tests significantly by reducing the torch status check interval for broken worker shutdown @@ -58,116 +58,126 @@ def _norng_state(state): return state -class TestDataset(unittest.TestCase): - # Set up the test fixture - def setUp(self): - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - warnings.simplefilter("ignore", ResourceWarning) +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + temp_dir = tempfile.TemporaryDirectory() + yield temp_dir + gc.collect() + temp_dir.cleanup() - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.dataset_path = Path(self.temp_dir.name) - # self.dataset_path = Path("./test_dataset") - self.dataset_path.mkdir(exist_ok=True, parents=True) +@pytest.fixture 
+def dataset_path(temp_dir): + """Create the main dataset directory with test data.""" + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) - # Create a small dummy captioning dataset - self.create_text_test_dataset(self.dataset_path) + dataset_path = Path(temp_dir.name) + dataset_path.mkdir(exist_ok=True, parents=True) - # Create temporary directories for checkpoint files - self.checkpoint_dir = Path(self.temp_dir.name) / "checkpoints" - self.checkpoint_dir.mkdir(exist_ok=True, parents=True) + # Create a small dummy captioning dataset + create_text_test_dataset(dataset_path) - self.redist_dir = Path(self.temp_dir.name) / "redist_checkpoints" - self.redist_dir.mkdir(exist_ok=True, parents=True) + print(dataset_path) + return dataset_path - print(self.dataset_path) - def tearDown(self): - # Remove all temporary files - gc.collect() - self.temp_dir.cleanup() +@pytest.fixture +def checkpoint_dir(dataset_path): + """Create checkpoint directory for test files.""" + checkpoint_dir = dataset_path / "checkpoints" + checkpoint_dir.mkdir(exist_ok=True, parents=True) + return checkpoint_dir - @staticmethod - def create_text_test_dataset(path: Path): - """Creates a small dummy test dataset for testing purposes.""" - # Create num_samples unique captions - (path / "parts").mkdir(exist_ok=True, parents=True) +@pytest.fixture +def redist_dir(dataset_path): + """Create redistribution directory for test files.""" + redist_dir = dataset_path / "redist_checkpoints" + redist_dir.mkdir(exist_ok=True, parents=True) + return redist_dir - # Initialize the ShardWriter - with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=100) as shard_writer: - for idx in range(55): - # Write individual files to shards - shard_writer.write( - { - "__key__": f"{idx:06d}", - "txt": f"{idx}".encode(), - }, - ) - # Also create smaller shards, to verify distributions - if idx in (1, 3, 6, 10, 20, 30, 40, 50): - shard_writer.next_stream() - 
total_shards = shard_writer.shard - - from megatron.energon.flavors import BaseWebdatasetFactory - - BaseWebdatasetFactory.prepare_dataset( - path, - [f"parts/data-{{0..{total_shards - 1}}}.tar"], - split_parts_ratio=[("train", 1.0)], - shuffle_seed=None, - ) - with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: TextSample", - "field_map:", - " text: txt", - ] - ) - ) +def create_text_test_dataset(path: Path): + """Creates a small dummy test dataset for testing purposes.""" + + # Create num_samples unique captions + (path / "parts").mkdir(exist_ok=True, parents=True) - # Split with alternating train/val shards - with open(path / MAIN_FOLDER_NAME / "split2.yaml", "w") as f: - yaml.dump( + # Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=100) as shard_writer: + for idx in range(55): + # Write individual files to shards + shard_writer.write( { - "split_parts": { - "train": [ - "parts/data-4.tar", - "parts/data-0.tar", - "parts/data-2.tar", - ], - "val": [ - "parts/data-1.tar", - "parts/data-3.tar", - "parts/data-5.tar", - ], - } + "__key__": f"{idx:06d}", + "txt": f"{idx}".encode(), }, - f, ) + # Also create smaller shards, to verify distributions + if idx in (1, 3, 6, 10, 20, 30, 40, 50): + shard_writer.next_stream() + total_shards = shard_writer.shard + + from megatron.energon.flavors import BaseWebdatasetFactory + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: TextSample", + "field_map:", + " text: txt", + ] + ) + ) - def test_split_parts(self): - with open(self.dataset_path / MAIN_FOLDER_NAME / "split.yaml", "r") as f: - print(f.read()) - with 
open(self.dataset_path / MAIN_FOLDER_NAME / "split2.yaml", "r") as f: - print(f.read()) - - ds = get_dataset_from_config( - self.dataset_path, - split_config="split2.yaml", - split_part="train", - worker_config=WorkerConfig(rank=0, world_size=1, num_workers=0), - training=False, - sample_type=TextSample, + # Split with alternating train/val shards + with open(path / MAIN_FOLDER_NAME / "split2.yaml", "w") as f: + yaml.dump( + { + "split_parts": { + "train": [ + "parts/data-4.tar", + "parts/data-0.tar", + "parts/data-2.tar", + ], + "val": [ + "parts/data-1.tar", + "parts/data-3.tar", + "parts/data-5.tar", + ], + } + }, + f, ) - dl = get_loader(ds.build()) + +def test_split_parts(dataset_path): + with open(dataset_path / MAIN_FOLDER_NAME / "split.yaml", "r") as f: + print(f.read()) + with open(dataset_path / MAIN_FOLDER_NAME / "split2.yaml", "r") as f: + print(f.read()) + + ds = get_dataset_from_config( + dataset_path, + split_config="split2.yaml", + split_part="train", + worker_config=WorkerConfig(rank=0, world_size=1, num_workers=0), + training=False, + sample_type=TextSample, + ) + with get_loader(ds.build()) as dl: all_keys = [sample.__key__ for sample in dl] assert all_keys == [ "000011", # Shard 4 first @@ -187,97 +197,101 @@ def test_split_parts(self): "000006", ] - def test_text_dataset(self): - worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) - ds = get_dataset_from_config( - self.dataset_path, +def test_text_dataset(dataset_path): + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) + + def new_ds(): + return get_dataset_from_config( + dataset_path, split_part="train", training=False, sample_type=TextSample, worker_config=worker_config, ).build() - # Check len operator - assert len(ds) == 55 - # Check if iterating returns the same - iter1 = list(get_loader(ds)) - iter2 = list(get_loader(ds)) - assert len(iter1) == 55 - assert len(iter2) == 55 - assert all(elem1.__key__ == elem2.__key__ for elem1, elem2 in zip(iter1, iter2)) - 
assert all(f"{idx}" == x.text for idx, x in enumerate(get_loader(ds))) - - del ds - gc.collect() - - def test_epoch(self): - torch.manual_seed(42) - - worker_config = WorkerConfig(rank=0, world_size=1, num_workers=5) - - # Without shuffle buffer, should yield everything exactly once - ds3 = get_dataset_from_config( - self.dataset_path, - split_part="train", - training=True, - sample_type=TextSample, - worker_config=worker_config, - ) - loader5 = get_loader(ds3.build()) + ds = new_ds() + + # Check len operator + assert len(ds) == 55 + # Check if iterating returns the same + with get_loader(ds) as l1: + iter1 = list(l1) + with get_loader(new_ds()) as l2: + iter2 = list(l2) + assert len(iter1) == 55 + assert len(iter2) == 55 + assert all(elem1.__key__ == elem2.__key__ for elem1, elem2 in zip(iter1, iter2)) + with get_loader(new_ds()) as l3: + assert all(f"{idx}" == x.text for idx, x in enumerate(l3)) + + +def test_epoch(dataset_path): + torch.manual_seed(42) + + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=5) + + # Without shuffle buffer, should yield everything exactly once + ds3 = get_dataset_from_config( + dataset_path, + split_part="train", + training=True, + sample_type=TextSample, + worker_config=worker_config, + ) + with get_loader(ds3.build()) as loader5: order9 = [data.text for idx, data in zip(range(55), loader5)] print(order9) print(Counter(order9)) assert all(v == 1 for v in Counter(order9).values()) - def test_determinism(self): - worker_config2 = WorkerConfig(rank=0, world_size=1, num_workers=2) - worker_config2b = WorkerConfig(rank=0, world_size=1, num_workers=2, seed_offset=43) - worker_config4 = WorkerConfig(rank=0, world_size=1, num_workers=4) - - # This seed is used by the dataset to shuffle the data - torch.manual_seed(42) - ds1 = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config2, - batch_size=1, - shuffle_buffer_size=42, - max_samples_per_sequence=2, - ) - ds1b 
= get_train_dataset( # Same but different seed - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config2b, - batch_size=1, - shuffle_buffer_size=42, - max_samples_per_sequence=2, - ) - ds2 = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config2, - batch_size=1, - shuffle_buffer_size=42, - max_samples_per_sequence=2, - ) - ds3 = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config4, - batch_size=1, - shuffle_buffer_size=42, - max_samples_per_sequence=2, - ) - - # Fork the dataset twice - loader1 = get_loader(ds1) - loader2 = get_loader(ds1) +def test_determinism(dataset_path): + worker_config2 = WorkerConfig(rank=0, world_size=1, num_workers=2) + worker_config2b = WorkerConfig(rank=0, world_size=1, num_workers=2, seed_offset=43) + worker_config4 = WorkerConfig(rank=0, world_size=1, num_workers=4) + + # This seed is used by the dataset to shuffle the data + torch.manual_seed(42) + ds1 = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config2, + batch_size=1, + shuffle_buffer_size=42, + max_samples_per_sequence=2, + ) + ds1b = get_train_dataset( # Same but different seed + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config2b, + batch_size=1, + shuffle_buffer_size=42, + max_samples_per_sequence=2, + ) + ds2 = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config2, + batch_size=1, + shuffle_buffer_size=42, + max_samples_per_sequence=2, + ) + ds3 = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config4, + batch_size=1, + shuffle_buffer_size=42, + max_samples_per_sequence=2, + ) + + # Fork the dataset twice + with get_loader(ds1) as loader1, get_loader(ds2) as loader2: order4 = 
[data.text[0] for idx, data in zip(range(55 * 20), loader1)] order5 = [data.text[0] for idx, data in zip(range(55 * 20), loader1)] order6 = [data.text[0] for idx, data in zip(range(55 * 20), loader2)] @@ -289,172 +303,162 @@ def test_determinism(self): assert order4 != order5 assert order4 == order6 - loader3 = get_loader(ds1b) + with get_loader(ds1b) as loader3: order7 = [data.text[0] for idx, data in zip(range(55 * 20), loader3)] assert order6 != order7 - loader4 = get_loader(ds3) + with get_loader(ds3) as loader4: order8 = [data.text[0] for idx, data in zip(range(55 * 100), loader4)] assert order6 != order8[: len(order6)] print(Counter(order8)) assert all(90 <= v <= 110 for v in Counter(order8).values()) - # Delete all locals, otherwise loaders might be kept alive - locals().clear() - gc.collect() - - def test_determinism_taskencoder(self): - class TestTaskEncoder(DefaultTaskEncoder): - @stateless(restore_seeds=True) - def encode_sample(self, sample: TextSample) -> TextSample: - rand_str = f"_{torch.randint(0, 1000, (1,)).item()}_{random.randint(0, 1000)}" - return TextSample( - __key__=sample.__key__, - __restore_key__=sample.__restore_key__, - __subflavors__=sample.__subflavors__, - text=sample.text + rand_str, - ) - - for num_workers in [0, 1]: - worker_config1 = WorkerConfig(rank=0, world_size=1, num_workers=num_workers) - # This seed is used by the dataset to shuffle the data - torch.manual_seed(42) - ds1a = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config1, - batch_size=1, - shuffle_buffer_size=42, - max_samples_per_sequence=2, - task_encoder=TestTaskEncoder(), +def test_determinism_taskencoder(dataset_path): + class TestTaskEncoder(DefaultTaskEncoder): + @stateless(restore_seeds=True) + def encode_sample(self, sample: TextSample) -> TextSample: + rand_str = f"_{torch.randint(0, 1000, (1,)).item()}_{random.randint(0, 1000)}" + return TextSample( + __key__=sample.__key__, + 
__restore_key__=sample.__restore_key__, + __subflavors__=sample.__subflavors__, + text=sample.text + rand_str, ) - torch.manual_seed(44) - ds1b = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config1, - batch_size=1, - shuffle_buffer_size=42, - max_samples_per_sequence=2, - task_encoder=TestTaskEncoder(), - ) + for num_workers in [0, 1]: + worker_config1 = WorkerConfig(rank=0, world_size=1, num_workers=num_workers) + + # This seed is used by the dataset to shuffle the data + torch.manual_seed(42) + ds1a = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config1, + batch_size=1, + shuffle_buffer_size=42, + max_samples_per_sequence=2, + task_encoder=TestTaskEncoder(), + ) - # Fork the dataset twice - loader1a = get_loader(ds1a) - loader1b = get_loader(ds1b) + torch.manual_seed(44) + ds1b = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config1, + batch_size=1, + shuffle_buffer_size=42, + max_samples_per_sequence=2, + task_encoder=TestTaskEncoder(), + ) + # Fork the dataset twice + with get_loader(ds1a) as loader1a, get_loader(ds1b) as loader1b: order1a = [data.text[0] for idx, data in zip(range(55 * 20), loader1a)] order1b = [data.text[0] for idx, data in zip(range(55 * 20), loader1b)] assert order1a == order1b + assert order1a == order1b - # Delete all locals, otherwise loaders might be kept alive - locals().clear() - gc.collect() - - def test_determinism_taskencoder_save_restore(self): - class TestTaskEncoder(DefaultTaskEncoder): - @stateless(restore_seeds=True) - def encode_sample(self, sample: TextSample) -> TextSample: - rand_str = ( - f"_{torch.randint(0, 1000, (1,)).item()}_{random.randint(0, 1000)}" - + f"_{self.current_batch_index}_{self.current_sample_index}" - ) - - return TextSample( - __key__=sample.__key__, - __restore_key__=sample.__restore_key__, - 
__subflavors__=sample.__subflavors__, - text=sample.text + rand_str, - ) - - for num_workers in [1, 0]: - worker_config1 = WorkerConfig(rank=0, world_size=1, num_workers=num_workers) - # This seed is used by the dataset to shuffle the data - torch.manual_seed(42) - ds1a = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config1, - batch_size=1, - shuffle_buffer_size=42, - max_samples_per_sequence=2, - task_encoder=TestTaskEncoder(), +def test_determinism_taskencoder_save_restore(dataset_path): + class TestTaskEncoder(DefaultTaskEncoder): + @stateless(restore_seeds=True) + def encode_sample(self, sample: TextSample) -> TextSample: + rand_str = ( + f"_{torch.randint(0, 1000, (1,)).item()}_{random.randint(0, 1000)}" + + f"_{WorkerConfig.active_worker_config.worker_seed()}" + + f"_{self.current_batch_index}_{self.current_sample_index}" ) + print(f"For sample {sample.__restore_key__}: {sample.text}{rand_str}") - torch.manual_seed(44) - ds1b = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config1, - batch_size=1, - shuffle_buffer_size=42, - max_samples_per_sequence=2, - task_encoder=TestTaskEncoder(), + return TextSample( + __key__=sample.__key__, + __restore_key__=sample.__restore_key__, + __subflavors__=sample.__subflavors__, + text=sample.text + rand_str, ) - # Fork the dataset twice - loader1a = get_savable_loader(ds1a) - loader1b = get_savable_loader(ds1b) + for num_workers in [1, 0]: + worker_config1 = WorkerConfig(rank=0, world_size=1, num_workers=num_workers) + + # This seed is used by the dataset to shuffle the data + torch.manual_seed(42) + ds1a = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config1, + batch_size=1, + shuffle_buffer_size=42, + max_samples_per_sequence=2, + task_encoder=TestTaskEncoder(), + ) + + torch.manual_seed(44) + ds1b = get_train_dataset( + dataset_path, 
+ split_part="train", + sample_type=TextSample, + worker_config=worker_config1, + batch_size=1, + shuffle_buffer_size=42, + max_samples_per_sequence=2, + task_encoder=TestTaskEncoder(), + ) + # Fork the dataset twice + with get_savable_loader(ds1a) as loader1a: # Load 7 samples - data_pre = [data.text[0] for idx, data in zip(range(7), loader1a)] + _data_pre = [data.text[0] for idx, data in zip(range(7), loader1a)] # Then save state state = loader1a.save_state_rank() + print("iterating loader1a") # Load another 20 samples data_post = [data.text[0] for idx, data in zip(range(20), loader1a)] # Restore state - loader1b.restore_state_rank(state) - - # Load 20 samples again - data_restored = [data.text[0] for idx, data in zip(range(20), loader1b)] + with get_savable_loader(ds1b).with_restored_state_rank(state) as loader1b: + print("iterating loader1b") + # Load 20 samples again + data_restored = [data.text[0] for idx, data in zip(range(20), loader1b)] - print("Data post:", data_post) - print("Data restored:", data_restored) + print("Data post:", data_post) + print("Data restored:", data_restored) - assert data_post == data_restored + assert data_post == data_restored - # Delete all locals, otherwise loaders might be kept alive - locals().clear() - gc.collect() - def test_restore_state(self): - worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) +def test_restore_state(dataset_path): + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) - count1 = 55 * 20 - count2 = 55 * 20 - sbs = 42 - # count1 = 4 - # count2 = 2 - # sbs = None - psi = None + count1 = 55 * 20 + count2 = 55 * 20 + sbs = 42 + # count1 = 4 + # count2 = 2 + # sbs = None + psi = None - # This seed is used by the dataset to shuffle the data - torch.manual_seed(42) + # This seed is used by the dataset to shuffle the data + torch.manual_seed(42) - loader = get_savable_loader( - get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - 
worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, - ) + with get_savable_loader( + get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=sbs, + max_samples_per_sequence=2, + parallel_shard_iters=psi, ) - + ) as loader: # print("save state") state_0 = loader.save_state_global(global_dst_rank=0) # print("save state done") @@ -469,20 +473,19 @@ def test_restore_state(self): print("state0", state_0) print("state1", state_1) - torch.manual_seed(213) - loader = get_savable_loader( - get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, - ) + torch.manual_seed(213) + with get_savable_loader( + get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=sbs, + max_samples_per_sequence=2, + parallel_shard_iters=psi, ) - loader.restore_state_global(state_0, src_rank=None) + ).with_restored_state_global(state_0, src_rank=None) as loader: order_45 = [data.text[0] for idx, data in zip(range(count1 + count2), loader)] order_4 = order_45[:count1] order_5 = order_45[count1:] @@ -494,9 +497,9 @@ def test_restore_state(self): assert order_2 == order_5 torch.manual_seed(145) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( - self.dataset_path, + dataset_path, split_part="train", sample_type=TextSample, worker_config=worker_config, @@ -505,47 +508,44 @@ def test_restore_state(self): max_samples_per_sequence=2, parallel_shard_iters=psi, ) - ) - # print("restore state") - loader.restore_state_global(state_1, src_rank=None) - # print("restore state done") - order_3 = [data.text[0] for idx, data in zip(range(count2), loader)] - # 
print("order1", order_1) - # print("order2", order_2[:100]) - # print("order3", order_3[:100]) - assert order_2 == order_3 + ).with_restored_state_global(state_1, src_rank=None) as loader: + order_3 = [data.text[0] for idx, data in zip(range(count2), loader)] + # print("order1", order_1) + # print("order2", order_2[:100]) + # print("order3", order_3[:100]) + assert order_2 == order_3 - def test_restore_state_dist(self): - from multiprocessing import Manager, Process - import torch.distributed as dist +def test_restore_state_dist(dataset_path): + from multiprocessing import Manager, Process - world_size = 3 + import torch.distributed as dist - count1 = 55 * 20 - count2 = 55 * 20 - sbs = 42 - psi = None + world_size = 3 - def phase1(rank: int, world_size: int, shared_dict: dict): - worker_config = WorkerConfig(rank=rank, world_size=world_size, num_workers=0) + count1 = 55 * 20 + count2 = 55 * 20 + sbs = 42 + psi = None - # This seed is used by the dataset to shuffle the data - torch.manual_seed(42) + def phase1(rank: int, world_size: int, shared_dict: dict): + worker_config = WorkerConfig(rank=rank, world_size=world_size, num_workers=0) - loader = get_savable_loader( - get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, - ) - ) + # This seed is used by the dataset to shuffle the data + torch.manual_seed(42) + with get_savable_loader( + get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=sbs, + max_samples_per_sequence=2, + parallel_shard_iters=psi, + ) + ) as loader: state_0 = loader.save_state_global(global_dst_rank=0) order_1 = [data.text[0] for idx, data in zip(range(count1), loader)] assert len(order_1) == count1 @@ -563,34 +563,32 @@ def phase1(rank: int, world_size: int, shared_dict: dict): 
shared_dict["state_0"] = state_0 shared_dict["state_1"] = state_1 - def phase2(rank: int, world_size: int, shared_dict: dict): - order_1 = shared_dict[(rank, "order_1")] - order_2 = shared_dict[(rank, "order_2")] + def phase2(rank: int, world_size: int, shared_dict: dict): + order_1 = shared_dict[(rank, "order_1")] + order_2 = shared_dict[(rank, "order_2")] - if rank == 0: - state_0 = shared_dict["state_0"] - state_1 = shared_dict["state_1"] - else: - state_0 = None - state_1 = None + if rank == 0: + state_0 = shared_dict["state_0"] + state_1 = shared_dict["state_1"] + else: + state_0 = None + state_1 = None - worker_config = WorkerConfig(rank=rank, world_size=world_size, num_workers=0) + worker_config = WorkerConfig(rank=rank, world_size=world_size, num_workers=0) - torch.manual_seed(213) - loader = get_savable_loader( - get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, - ) + torch.manual_seed(213) + with get_savable_loader( + get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=sbs, + max_samples_per_sequence=2, + parallel_shard_iters=psi, ) - loader.restore_state_global(state_0, src_rank=0) - + ).with_restored_state_global(state_0, src_rank=0) as loader: order_45 = [data.text[0] for idx, data in zip(range(count1 + count2), loader)] order_4 = order_45[:count1] order_5 = order_45[count1:] @@ -600,81 +598,79 @@ def phase2(rank: int, world_size: int, shared_dict: dict): assert order_1 == order_4 assert order_2 == order_5 - torch.manual_seed(213) - loader = get_savable_loader( - get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, - ) + 
torch.manual_seed(213) + with get_savable_loader( + get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=sbs, + max_samples_per_sequence=2, + parallel_shard_iters=psi, ) - loader.restore_state_global(state_1, src_rank=0) + ).with_restored_state_global(state_1, src_rank=0) as loader: order_3 = [data.text[0] for idx, data in zip(range(count2), loader)] assert order_2 == order_3 - def init_process(rank, world_size, shared_dict, fn, backend="gloo"): - """Initializes the distributed environment.""" - dist.init_process_group( - backend=backend, - init_method="tcp://127.0.0.1:12355", - world_size=world_size, - rank=rank, - ) - fn(rank, world_size, shared_dict) - dist.destroy_process_group() - - with Manager() as manager: - shared_dict = manager.dict() - - # Phase 1 (save state) - processes = [] - for rank in range(world_size): - p = Process(target=init_process, args=(rank, world_size, shared_dict, phase1)) - p.start() - processes.append(p) - - for p in processes: - p.join() - - # Phase 2 (restore state) - processes = [] - for rank in range(world_size): - p = Process(target=init_process, args=(rank, world_size, shared_dict, phase2)) - p.start() - processes.append(p) - - for p in processes: - p.join() - - def test_restore_state_workers(self): - worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2) - - psi = 2 - sbs = 42 - n1 = 18 - n2 = 109 - n3 = 28 - ces = 0 - - # This seed is used by the dataset to shuffle the data - torch.manual_seed(42) - ds = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, + def init_process(rank, world_size, shared_dict, fn, backend="gloo"): + """Initializes the distributed environment.""" + dist.init_process_group( + backend=backend, + init_method="tcp://127.0.0.1:12355", 
+ world_size=world_size, + rank=rank, ) - loader = get_savable_loader(ds, checkpoint_every_sec=ces) - + fn(rank, world_size, shared_dict) + dist.destroy_process_group() + + with Manager() as manager: + shared_dict = manager.dict() + + # Phase 1 (save state) + processes = [] + for rank in range(world_size): + p = Process(target=init_process, args=(rank, world_size, shared_dict, phase1)) + p.start() + processes.append(p) + + for p in processes: + p.join() + + # Phase 2 (restore state) + processes = [] + for rank in range(world_size): + p = Process(target=init_process, args=(rank, world_size, shared_dict, phase2)) + p.start() + processes.append(p) + + for p in processes: + p.join() + + +def test_restore_state_workers(dataset_path): + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2) + + psi = 2 + sbs = 42 + n1 = 18 + n2 = 109 + n3 = 28 + + # This seed is used by the dataset to shuffle the data + torch.manual_seed(42) + ds = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=sbs, + max_samples_per_sequence=2, + parallel_shard_iters=psi, + ) + with get_savable_loader(ds) as loader: # print("save state") state_0 = loader.save_state_rank() it1 = iter(loader) @@ -696,136 +692,133 @@ def test_restore_state_workers(self): print("state1", state_1) print("state2", state_2) - # Restoring the state of a new dataset should also yield the same - torch.manual_seed(42) - ds = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, - ) - loader = get_savable_loader(ds) - loader.restore_state_rank(state_0) + # Restoring the state of a new dataset should also yield the same + torch.manual_seed(42) + ds = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + 
worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=sbs, + max_samples_per_sequence=2, + parallel_shard_iters=psi, + ) + with get_savable_loader(ds).with_restored_state_rank(state_0) as loader: order_6 = [data.text[0] for idx, data in zip(range(n1), loader)] print("order1", order_1) print("order6", order_6) assert order_6 == order_1 - # Restoring the state of a new dataset should also yield the same - torch.manual_seed(42) - ds = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, - ) - loader = get_savable_loader(ds) - loader.restore_state_rank(state_1) + # Restoring the state of a new dataset should also yield the same + torch.manual_seed(42) + ds = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=sbs, + max_samples_per_sequence=2, + parallel_shard_iters=psi, + ) + with get_savable_loader(ds).with_restored_state_rank(state_1) as loader: order_7 = [data.text[0] for idx, data in zip(range(n2), loader)] print("order2", order_2[:100]) print("order7", order_7[:100]) assert order_7 == order_2 - # Restoring the state of a new dataset should also yield the same - torch.manual_seed(42) - ds = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - max_samples_per_sequence=2, - shuffle_buffer_size=sbs, - parallel_shard_iters=psi, - ) - loader = get_savable_loader(ds) - loader.restore_state_rank(state_2) + # Restoring the state of a new dataset should also yield the same + torch.manual_seed(42) + ds = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + max_samples_per_sequence=2, + shuffle_buffer_size=sbs, + parallel_shard_iters=psi, + 
) + with get_savable_loader(ds).with_restored_state_rank(state_2) as loader: order_8 = [data.text[0] for idx, data in zip(range(n3), loader)] print("order3", order_3) print("order8", order_8) assert order_8 == order_3 - def test_invariance_global_samples(self): - # We'd like to ensure that the user can keep the same global batches - # (deterministic pseudo random order) when changing the number of ranks (world size). - - # This can be achieved by obeying a few constraints: - # - Global batch size must stay the same across runs - # - Global batch size must be a multiple of (micro-batch size * world_size * num_workers) - # - Global batch size = micro-batch size * world_size * num_workers * gradient_accum_steps - # - world_size * num_workers must stay the same across runs - # Set the same torch.manual_seed(...) on each rank before constructing the dataset and the data loader - - scenarios = [ - dict( - configs=(WorkerConfig(rank=0, world_size=1, num_workers=4),), - micro_batch_size=2, - global_batch_size=8, - ), - dict( - configs=( - WorkerConfig(rank=0, world_size=2, num_workers=2), - WorkerConfig(rank=1, world_size=2, num_workers=2), - ), - micro_batch_size=2, - global_batch_size=8, + +def test_invariance_global_samples(dataset_path): + # We'd like to ensure that the user can keep the same global batches + # (deterministic pseudo random order) when changing the number of ranks (world size). + + # This can be achieved by obeying a few constraints: + # - Global batch size must stay the same across runs + # - Global batch size must be a multiple of (micro-batch size * world_size * num_workers) + # - Global batch size = micro-batch size * world_size * num_workers * gradient_accum_steps + # - world_size * num_workers must stay the same across runs + # Set the same torch.manual_seed(...) 
on each rank before constructing the dataset and the data loader + + scenarios = [ + dict( + configs=(WorkerConfig(rank=0, world_size=1, num_workers=4),), + micro_batch_size=2, + global_batch_size=8, + ), + dict( + configs=( + WorkerConfig(rank=0, world_size=2, num_workers=2), + WorkerConfig(rank=1, world_size=2, num_workers=2), ), - dict( - configs=( - WorkerConfig(rank=0, world_size=4, num_workers=1), - WorkerConfig(rank=1, world_size=4, num_workers=1), - WorkerConfig(rank=2, world_size=4, num_workers=1), - WorkerConfig(rank=3, world_size=4, num_workers=1), - ), - micro_batch_size=2, - global_batch_size=8, + micro_batch_size=2, + global_batch_size=8, + ), + dict( + configs=( + WorkerConfig(rank=0, world_size=4, num_workers=1), + WorkerConfig(rank=1, world_size=4, num_workers=1), + WorkerConfig(rank=2, world_size=4, num_workers=1), + WorkerConfig(rank=3, world_size=4, num_workers=1), ), - dict( - configs=( - WorkerConfig(rank=0, world_size=2, num_workers=2), - WorkerConfig(rank=1, world_size=2, num_workers=2), - ), - micro_batch_size=1, # Micro-batch 1, more accum - global_batch_size=8, + micro_batch_size=2, + global_batch_size=8, + ), + dict( + configs=( + WorkerConfig(rank=0, world_size=2, num_workers=2), + WorkerConfig(rank=1, world_size=2, num_workers=2), ), - ] + micro_batch_size=1, # Micro-batch 1, more accum + global_batch_size=8, + ), + ] - # Constraints to user: + # Constraints to user: - global_batches_per_scenario = [] - for scenario in scenarios: - assert scenario["global_batch_size"] % scenario["micro_batch_size"] == 0, ( - "Global batch size must be a multiple of the micro-batch size." - ) - - world_size = len(scenario["configs"]) - gradient_accum_steps = scenario["global_batch_size"] // ( - scenario["micro_batch_size"] * world_size - ) + global_batches_per_scenario = [] + for scenario in scenarios: + assert scenario["global_batch_size"] % scenario["micro_batch_size"] == 0, ( + "Global batch size must be a multiple of the micro-batch size." 
+ ) - batches_per_rank = [] + world_size = len(scenario["configs"]) + gradient_accum_steps = scenario["global_batch_size"] // ( + scenario["micro_batch_size"] * world_size + ) - for rank_config in scenario["configs"]: - torch.manual_seed(42) - ds = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=rank_config, - batch_size=scenario["micro_batch_size"], - shuffle_buffer_size=42, - max_samples_per_sequence=2, - ) - loader = get_loader(ds) + batches_per_rank = [] + for rank_config in scenario["configs"]: + torch.manual_seed(42) + ds = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=rank_config, + batch_size=scenario["micro_batch_size"], + shuffle_buffer_size=42, + max_samples_per_sequence=2, + ) + with get_loader(ds) as loader: micro_batches = [ data.text for idx, data in zip( @@ -834,122 +827,118 @@ def test_invariance_global_samples(self): ] batches_per_rank.append(micro_batches) - # Compose global batches - global_batches_cur_rank = [] - batch_index = 0 - while batch_index < len(batches_per_rank[0]): - global_batch = [] - for _ in range(gradient_accum_steps): - for rank_batches in batches_per_rank: - global_batch.extend(rank_batches[batch_index]) - batch_index += 1 - if batch_index >= len(batches_per_rank[0]): - # last global batch may be smaller - break - global_batches_cur_rank.append(sorted(global_batch)) - - global_batches_per_scenario.append(global_batches_cur_rank) - - # Check that the global batches are the same - - # Assert that all scenarios produced the same number of global batches - assert all( - len(global_batches) == len(global_batches_per_scenario[0]) - for global_batches in global_batches_per_scenario - ), "Number of global batches per scenario does not match." 
- - for global_batches in global_batches_per_scenario: - print("= Global batches per scenario") - for global_batch in global_batches: - print(" Global batch: ", global_batch) - - # Assert that all global batches are the same - for i in range(len(global_batches_per_scenario[0])): - for scenerio_idx, global_batches in enumerate(global_batches_per_scenario): - assert global_batches[i] == global_batches_per_scenario[0][i], ( - f"Global batch {i} of scenario {scenerio_idx} does not match." - ) + # Compose global batches + global_batches_cur_rank = [] + batch_index = 0 + while batch_index < len(batches_per_rank[0]): + global_batch = [] + for _ in range(gradient_accum_steps): + for rank_batches in batches_per_rank: + global_batch.extend(rank_batches[batch_index]) + batch_index += 1 + if batch_index >= len(batches_per_rank[0]): + # last global batch may be smaller + break + global_batches_cur_rank.append(sorted(global_batch)) - # Delete all locals, otherwise loaders might be kept alive - locals().clear() - gc.collect() - - def test_redist(self): - scenarios = [ - dict( - configs=( - WorkerConfig(rank=0, world_size=2, num_workers=2), - WorkerConfig(rank=1, world_size=2, num_workers=2), - ), - micro_batch_size=2, - global_batch_size=8, - ), - dict( - configs=(WorkerConfig(rank=0, world_size=1, num_workers=4),), - micro_batch_size=2, - global_batch_size=8, + global_batches_per_scenario.append(global_batches_cur_rank) + + # Check that the global batches are the same + + # Assert that all scenarios produced the same number of global batches + assert all( + len(global_batches) == len(global_batches_per_scenario[0]) + for global_batches in global_batches_per_scenario + ), "Number of global batches per scenario does not match." 
+ + for global_batches in global_batches_per_scenario: + print("= Global batches per scenario") + for global_batch in global_batches: + print(" Global batch: ", global_batch) + + # Assert that all global batches are the same + for i in range(len(global_batches_per_scenario[0])): + for scenerio_idx, global_batches in enumerate(global_batches_per_scenario): + assert global_batches[i] == global_batches_per_scenario[0][i], ( + f"Global batch {i} of scenario {scenerio_idx} does not match." + ) + + +def test_redist(dataset_path, checkpoint_dir, redist_dir): + scenarios = [ + dict( + configs=( + WorkerConfig(rank=0, world_size=2, num_workers=2), + WorkerConfig(rank=1, world_size=2, num_workers=2), ), - dict( - configs=( - WorkerConfig(rank=0, world_size=4, num_workers=1), - WorkerConfig(rank=1, world_size=4, num_workers=1), - WorkerConfig(rank=2, world_size=4, num_workers=1), - WorkerConfig(rank=3, world_size=4, num_workers=1), - ), - micro_batch_size=2, - global_batch_size=8, + micro_batch_size=2, + global_batch_size=8, + ), + dict( + configs=(WorkerConfig(rank=0, world_size=1, num_workers=4),), + micro_batch_size=2, + global_batch_size=8, + ), + dict( + configs=( + WorkerConfig(rank=0, world_size=4, num_workers=1), + WorkerConfig(rank=1, world_size=4, num_workers=1), + WorkerConfig(rank=2, world_size=4, num_workers=1), + WorkerConfig(rank=3, world_size=4, num_workers=1), ), - dict( - configs=( - WorkerConfig(rank=0, world_size=2, num_workers=2), - WorkerConfig(rank=1, world_size=2, num_workers=2), - ), - micro_batch_size=1, # Micro-batch 1, more accum - global_batch_size=8, + micro_batch_size=2, + global_batch_size=8, + ), + dict( + configs=( + WorkerConfig(rank=0, world_size=2, num_workers=2), + WorkerConfig(rank=1, world_size=2, num_workers=2), ), - dict( # Same as original - configs=( - WorkerConfig(rank=0, world_size=2, num_workers=2), - WorkerConfig(rank=1, world_size=2, num_workers=2), - ), - micro_batch_size=2, - global_batch_size=8, + micro_batch_size=1, # 
Micro-batch 1, more accum + global_batch_size=8, + ), + dict( # Same as original + configs=( + WorkerConfig(rank=0, world_size=2, num_workers=2), + WorkerConfig(rank=1, world_size=2, num_workers=2), ), - ] + micro_batch_size=2, + global_batch_size=8, + ), + ] - # === Stage 1 first generate a saved state using scenario 0 - checkpoint_files = [] + # === Stage 1 first generate a saved state using scenario 0 + checkpoint_files = [] - global_batches_per_scenario = [] - scenario = scenarios[0] + global_batches_per_scenario = [] + scenario = scenarios[0] - world_size = len(scenario["configs"]) - gradient_accum_steps = scenario["global_batch_size"] // ( - scenario["micro_batch_size"] * world_size - ) + world_size = len(scenario["configs"]) + gradient_accum_steps = scenario["global_batch_size"] // ( + scenario["micro_batch_size"] * world_size + ) - batches_per_rank = [] + batches_per_rank = [] - for rank_config in scenario["configs"]: - loader = get_savable_loader( - get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=rank_config, - batch_size=scenario["micro_batch_size"], - shuffle_buffer_size=42, - max_samples_per_sequence=2, - ) + for rank_config in scenario["configs"]: + with get_savable_loader( + get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=rank_config, + batch_size=scenario["micro_batch_size"], + shuffle_buffer_size=42, + max_samples_per_sequence=2, ) - + ) as loader: # Throw away some samples to advance the loader state num_pre_samples = 20 for _ in zip(range(num_pre_samples), loader): pass # Save the state to a file - checkpoint_file = self.checkpoint_dir / f"state_rank{rank_config.rank}.pt" + checkpoint_file = checkpoint_dir / f"state_rank{rank_config.rank}.pt" state = loader.save_state_rank() torch.save(state, str(checkpoint_file)) checkpoint_files.append(checkpoint_file) @@ -963,6 +952,81 @@ def test_redist(self): ] batches_per_rank.append(micro_batches) + # 
Compose global batches + global_batches_cur_rank = [] + batch_index = 0 + while batch_index < len(batches_per_rank[0]): + global_batch = [] + for _ in range(gradient_accum_steps): + for rank_batches in batches_per_rank: + global_batch.extend(rank_batches[batch_index]) + batch_index += 1 + if batch_index >= len(batches_per_rank[0]): + # last global batch may be smaller + break + global_batches_cur_rank.append(sorted(global_batch)) + + global_batches_per_scenario.append(global_batches_cur_rank) + + # === Stage 2: Now check that the global batches are the same after redistribution + + for scenario in scenarios[1:]: + print(f"\n\nRunning scenario {scenario}") + # Redistribute the saved state + runner = CliRunner() + result = runner.invoke( + command_redist, + [ + "--new-world-size", + str(len(scenario["configs"])), + "--new-micro-batch-size", + str(scenario["micro_batch_size"]), + *[str(cpt) for cpt in checkpoint_files], + str(redist_dir), + ], + ) + print(result.output) + if result.exception is not None: + raise result.exception + assert result.exception is None, result.exception + assert result.exit_code == 0, "Redistribution failed" + + # Load state and check that the global batches are the same + assert scenario["global_batch_size"] % scenario["micro_batch_size"] == 0, ( + "Global batch size must be a multiple of the micro-batch size." 
+ ) + + world_size = len(scenario["configs"]) + gradient_accum_steps = scenario["global_batch_size"] // ( + scenario["micro_batch_size"] * world_size + ) + + batches_per_rank = [] + + for rank_config in scenario["configs"]: + state = torch.load( + str(redist_dir / f"state_rank{rank_config.rank}.pt"), weights_only=False + ) + + with get_savable_loader( + get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=rank_config, + batch_size=scenario["micro_batch_size"], + shuffle_buffer_size=42, + max_samples_per_sequence=2, + ) + ).with_restored_state_rank(state) as loader: + micro_batches = [ + data.text + for idx, data in zip( + range(55 * 8 // (world_size * scenario["micro_batch_size"])), loader + ) + ] + batches_per_rank.append(micro_batches) + # Compose global batches global_batches_cur_rank = [] batch_index = 0 @@ -979,104 +1043,24 @@ def test_redist(self): global_batches_per_scenario.append(global_batches_cur_rank) - # === Stage 2: Now check that the global batches are the same after redistribution + # Check that the global batches are the same - for scenario in scenarios[1:]: - # Redistribute the saved state - runner = CliRunner() - result = runner.invoke( - command_redist, - [ - "--new-world-size", - str(len(scenario["configs"])), - *[str(cpt) for cpt in checkpoint_files], - str(self.redist_dir), - ], - ) - print(result.output) - assert result.exception is None, result.exception - assert result.exit_code == 0, "Redistribution failed" - - # Load state and check that the global batches are the same - assert scenario["global_batch_size"] % scenario["micro_batch_size"] == 0, ( - "Global batch size must be a multiple of the micro-batch size." 
- ) - - world_size = len(scenario["configs"]) - gradient_accum_steps = scenario["global_batch_size"] // ( - scenario["micro_batch_size"] * world_size - ) - - batches_per_rank = [] - - for rank_config in scenario["configs"]: - loader = get_savable_loader( - get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=rank_config, - batch_size=scenario["micro_batch_size"], - shuffle_buffer_size=42, - max_samples_per_sequence=2, - ) - ) - - state = torch.load( - str(self.redist_dir / f"state_rank{rank_config.rank}.pt"), weights_only=False - ) - loader.restore_state_rank(state) + print() - micro_batches = [ - data.text - for idx, data in zip( - range(55 * 8 // (world_size * scenario["micro_batch_size"])), loader - ) - ] - batches_per_rank.append(micro_batches) - - # Compose global batches - global_batches_cur_rank = [] - batch_index = 0 - while batch_index < len(batches_per_rank[0]): - global_batch = [] - for _ in range(gradient_accum_steps): - for rank_batches in batches_per_rank: - global_batch.extend(rank_batches[batch_index]) - batch_index += 1 - if batch_index >= len(batches_per_rank[0]): - # last global batch may be smaller - break - global_batches_cur_rank.append(sorted(global_batch)) - - global_batches_per_scenario.append(global_batches_cur_rank) - - # Check that the global batches are the same - - print() - - # Assert that all scenarios produced the same global batches - assert all( - len(global_batches) == len(global_batches_per_scenario[0]) - for global_batches in global_batches_per_scenario - ), "Number of global batches per scenario does not match." 
- - for global_batches in global_batches_per_scenario: - print("= Global batches per scenario") - for global_batch in global_batches: - print(" Global batch: ", global_batch) - - # Assert that all global batches are the same - for i in range(len(global_batches_per_scenario[0])): - for scenerio_idx, global_batches in enumerate(global_batches_per_scenario): - assert global_batches[i] == global_batches_per_scenario[0][i], ( - f"Global batch {i} of scenario {scenerio_idx} does not match." - ) + # Assert that all scenarios produced the same global batches + assert all( + len(global_batches) == len(global_batches_per_scenario[0]) + for global_batches in global_batches_per_scenario + ), "Number of global batches per scenario does not match." - # Delete all locals, otherwise loaders might be kept alive - locals().clear() - gc.collect() + for idx, (global_batches, scenario) in enumerate(zip(global_batches_per_scenario, scenarios)): + print(f"= Global batches per scenario {idx} {scenario}") + for global_batch in global_batches: + print(" Global batch: ", global_batch) - -if __name__ == "__main__": - unittest.main() + # Assert that all global batches are the same + for i in range(len(global_batches_per_scenario[0])): + for scenerio_idx, global_batches in enumerate(global_batches_per_scenario): + assert global_batches[i] == global_batches_per_scenario[0][i], ( + f"Global batch {i} of scenario {scenerio_idx} does not match." 
+ ) diff --git a/tests/test_epathlib.py b/tests/test_epathlib.py index bb00467c..7f3696c6 100644 --- a/tests/test_epathlib.py +++ b/tests/test_epathlib.py @@ -9,275 +9,282 @@ import pickle import struct import sys -import unittest + +import pytest from megatron.energon.epathlib import EPath from tests.epath_s3_emulator import setup_s3_emulator -class TestEPath(unittest.TestCase): - def setUp(self): - logging.basicConfig(stream=sys.stderr, level=logging.INFO) +@pytest.fixture(autouse=True) +def setup_logging(): + logging.basicConfig(stream=sys.stderr, level=logging.INFO) - def tearDown(self): - pass - def test_basic(self): - """Some basic functionality tests""" +def test_basic(): + """Some basic functionality tests""" - p_rel = EPath("./subdir") - p_abs = EPath("/tmp") + p_rel = EPath("./subdir") + p_abs = EPath("/tmp") - p_comb = p_abs / p_rel - # Those should not raise: - assert not p_comb.is_file() - assert not p_abs.is_file() - assert p_abs.is_dir() + p_comb = p_abs / p_rel + # Those should not raise: + assert not p_comb.is_file() + assert not p_abs.is_file() + assert p_abs.is_dir() - def test_contextman(self): - """Test the context manager""" - tmp_file_path = "/tmp/testfile.bin" - # First create a file - with open(tmp_file_path, "wb") as f: - f.write(struct.pack("H10s", 1337, b"1234567890")) +def test_contextman(): + """Test the context manager""" - # Test context manager reading - p = EPath(tmp_file_path).open("rb") - print(p) - with p: - b = p.read() - assert isinstance(b, bytes) + tmp_file_path = "/tmp/testfile.bin" + # First create a file + with open(tmp_file_path, "wb") as f: + f.write(struct.pack("H10s", 1337, b"1234567890")) - num, data = struct.unpack("H10s", b) - logging.info(f"num: {num}") - assert num == 1337 - assert data == b"1234567890" + # Test context manager reading + p = EPath(tmp_file_path).open("rb") + print(p) + with p: + b = p.read() + assert isinstance(b, bytes) - # Test context manager writing - tmp_file_path2 = "/tmp/testfile2.bin" - 
with EPath(tmp_file_path2).open("wb") as p: - p.write(struct.pack("H10s", 1337, b"1234567890")) + num, data = struct.unpack("H10s", b) + logging.info(f"num: {num}") + assert num == 1337 + assert data == b"1234567890" - def test_localfs(self): - """Test the local filesystem""" - p = EPath("/tmp/testfile.bin") - with p.open("wb") as f: - f.write(b"dummycontent") + # Test context manager writing + tmp_file_path2 = "/tmp/testfile2.bin" + with EPath(tmp_file_path2).open("wb") as p: + p.write(struct.pack("H10s", 1337, b"1234567890")) + + +def test_localfs(): + """Test the local filesystem""" + p = EPath("/tmp/testfile.bin") + with p.open("wb") as f: + f.write(b"dummycontent") + assert p.is_file() + assert p.size() == 12 + with p.open("rb") as f: + assert f.read() == b"dummycontent" + + # Test relative paths + revert_dir = os.getcwd() + try: + os.chdir("/tmp") + p = EPath("testfile.bin") + assert str(p) == "/tmp/testfile.bin" assert p.is_file() assert p.size() == 12 with p.open("rb") as f: assert f.read() == b"dummycontent" - # Test relative paths - revert_dir = os.getcwd() - try: - os.chdir("/tmp") - p = EPath("testfile.bin") - assert str(p) == "/tmp/testfile.bin" - assert p.is_file() - assert p.size() == 12 - with p.open("rb") as f: - assert f.read() == b"dummycontent" - - p = EPath("nonexisting/../testfile.bin") - assert str(p) == "/tmp/testfile.bin" - - p = EPath("../tmp/testfile.bin") - assert str(p) == "/tmp/testfile.bin" - finally: - os.chdir(revert_dir) + p = EPath("nonexisting/../testfile.bin") + assert str(p) == "/tmp/testfile.bin" - p.unlink() - assert p.is_file() is False + p = EPath("../tmp/testfile.bin") + assert str(p) == "/tmp/testfile.bin" + finally: + os.chdir(revert_dir) - def test_glob(self): - """Test the glob functionality""" + p.unlink() + assert p.is_file() is False - # First create some files - for i in range(10): - with open(f"/tmp/epathtestfile_{i}.bin", "wb") as f: - f.write(b"dummycontent") - # Test globbing - p = 
EPath("/tmp").glob("epathtestfile_*.bin") - - logging.info(f"p: {p}, type of p: {type(p)}") - elems = list(p) - assert len(elems) == 10 - for i, e in enumerate(elems): - logging.info(f"glob_result[{i}]: {e}") - assert isinstance(e, EPath) - assert e.is_file() - - # Test globbing with a pattern - p = EPath("/tmp").glob("epathtestfile_[0-3].bin") - assert len(list(p)) == 4 - - def test_s3_path_resolution(self): - """Test s3 path resolution""" - rclone_config_path = EPath("/tmp/XDG_CONFIG_HOME/.config/rclone/rclone.conf") - with rclone_config_path.open("w") as f: - f.write( - "\n".join( - [ - "[s3]", - "type = s3", - "env_auth = false", - "access_key_id = dummy", - "secret_access_key = dummy", - "region = dummy", - "endpoint = https://localhost", - ] - ) +def test_glob(): + """Test the glob functionality""" + + # First create some files + for i in range(10): + with open(f"/tmp/epathtestfile_{i}.bin", "wb") as f: + f.write(b"dummycontent") + + # Test globbing + p = EPath("/tmp").glob("epathtestfile_*.bin") + + logging.info(f"p: {p}, type of p: {type(p)}") + elems = list(p) + assert len(elems) == 10 + for i, e in enumerate(elems): + logging.info(f"glob_result[{i}]: {e}") + assert isinstance(e, EPath) + assert e.is_file() + + # Test globbing with a pattern + p = EPath("/tmp").glob("epathtestfile_[0-3].bin") + assert len(list(p)) == 4 + + +def test_s3_path_resolution(): + """Test s3 path resolution""" + rclone_config_path = EPath("/tmp/XDG_CONFIG_HOME/.config/rclone/rclone.conf") + with rclone_config_path.open("w") as f: + f.write( + "\n".join( + [ + "[s3]", + "type = s3", + "env_auth = false", + "access_key_id = dummy", + "secret_access_key = dummy", + "region = dummy", + "endpoint = https://localhost", + ] ) + ) - orig_xdg_config_home = os.environ.get("XDG_CONFIG_HOME") - os.environ["XDG_CONFIG_HOME"] = "/tmp/XDG_CONFIG_HOME/.config" - os.environ["HOME"] = "/tmp/XDG_CONFIG_HOME" - # Hack to clear the cache of the rclone config for msc to get the "s3" profile - from 
multistorageclient.rclone import read_rclone_config - - read_rclone_config.cache_clear() - try: - # Test globbing - p = EPath("msc://s3/tmp/path/subpath.txt") - assert str(p) == "msc://s3/tmp/path/subpath.txt", str(p) - - p2 = p / ".." / "subpath2.txt" - assert str(p2) == "msc://s3/tmp/path/subpath2.txt", str(p2) - - p3 = EPath("msc://s3/tmp/path/.././subpath.txt") - assert str(p3) == "msc://s3/tmp/subpath.txt", str(p3) - - p4 = p3.parent / "../bla/bla/bla/../../../no/../subpath2.txt" - assert str(p4) == "msc://s3/subpath2.txt", str(p4) - - # Test warning for deprecated rclone protocol - with self.assertWarns((DeprecationWarning, FutureWarning)) as warning: - # Test rclone backwards compatibility - pr = EPath("rclone://s3/tmp/path/.././subpath.txt") - assert str(pr) == "msc://s3/tmp/subpath.txt", str(pr) - assert "deprecated" in str(warning.warnings[0].message) - - # Test pickle / unpickle - p4serialized = pickle.dumps(p4) - # No secret must be serialized - assert b"dummy" not in p4serialized - finally: - if orig_xdg_config_home is not None: - os.environ["XDG_CONFIG_HOME"] = orig_xdg_config_home - else: - del os.environ["XDG_CONFIG_HOME"] - rclone_config_path.unlink() - - def test_multi_storage_client(self): - """Test the Multi-Storage Client integration""" - # Test path handling - p = EPath("msc://default/etc/resolv.conf") - assert str(p) == "/etc/resolv.conf", str(p) - assert p.is_file() + orig_xdg_config_home = os.environ.get("XDG_CONFIG_HOME") + os.environ["XDG_CONFIG_HOME"] = "/tmp/XDG_CONFIG_HOME/.config" + os.environ["HOME"] = "/tmp/XDG_CONFIG_HOME" + # Hack to clear the cache of the rclone config for msc to get the "s3" profile + from multistorageclient.rclone import read_rclone_config - p2 = p / ".." 
/ "hosts" - assert str(p2) == "/etc/hosts", str(p2) + read_rclone_config.cache_clear() + try: + # Test globbing + p = EPath("msc://s3/tmp/path/subpath.txt") + assert str(p) == "msc://s3/tmp/path/subpath.txt", str(p) - # Test glob - p3 = EPath("msc://default/etc/") - assert p3.is_dir() - for i in p3.glob("*.conf"): - assert str(i).endswith(".conf") + p2 = p / ".." / "subpath2.txt" + assert str(p2) == "msc://s3/tmp/path/subpath2.txt", str(p2) - # Test open file - assert p.size() > 0 - with p.open("r") as fp: - assert len(fp.read()) > 0 - - # Test move and delete - p4 = EPath("msc://default/tmp/random_file_0001") - if p4.is_file(): - p4.unlink() - with p4.open("w") as fp: - fp.write("*****") - assert p4.is_file() - p5 = EPath("msc://default/tmp/random_file_0002") - if p5.is_file(): - p5.unlink() - assert p5.is_file() is False - p4.move(p5) - assert p5.is_file() - assert p4.is_file() is False - p5.unlink() - assert p5.is_file() is False + p3 = EPath("msc://s3/tmp/path/.././subpath.txt") + assert str(p3) == "msc://s3/tmp/subpath.txt", str(p3) + + p4 = p3.parent / "../bla/bla/bla/../../../no/../subpath2.txt" + assert str(p4) == "msc://s3/subpath2.txt", str(p4) + + # Test warning for deprecated rclone protocol + with pytest.warns((DeprecationWarning, FutureWarning)) as warning: + # Test rclone backwards compatibility + pr = EPath("rclone://s3/tmp/path/.././subpath.txt") + assert str(pr) == "msc://s3/tmp/subpath.txt", str(pr) + assert "deprecated" in str(warning[0].message) # Test pickle / unpickle - p5serialized = pickle.dumps(p5) - p5unserialized = pickle.loads(p5serialized) - assert p5unserialized == p5 - assert str(p5unserialized) == str(p5) - - def test_multiprocessing(self): - """Test EPath in multiprocessing context""" - p = EPath("/tmp/path/subpath.txt") - - orig_start_method = multiprocessing.get_start_method() - try: - multiprocessing.set_start_method("spawn", force=True) - - proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) - 
proc.start() - proc.join() - assert proc.exitcode == 0 - - multiprocessing.set_start_method("fork", force=True) - - proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) - proc.start() - proc.join() - assert proc.exitcode == 0 - finally: - multiprocessing.set_start_method(orig_start_method, force=True) - - def test_multiprocessing_msc(self): - """Test EPath in multiprocessing context""" - p = EPath("msc://default/tmp/random_file_0001") - with p.open("w") as fp: - fp.write("*****") - - orig_start_method = multiprocessing.get_start_method() - try: - multiprocessing.set_start_method("spawn", force=True) - - proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) - proc.start() - proc.join() - assert proc.exitcode == 0 - - multiprocessing.set_start_method("fork", force=True) - - proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) - proc.start() - proc.join() - assert proc.exitcode == 0 - finally: - multiprocessing.set_start_method(orig_start_method, force=True) - p.unlink() - - def test_msc_s3(self): - # Test S3 with MSC - with setup_s3_emulator(profile_name="s3test_msc"): - p = EPath("msc://s3test_msc/test/dir/file.txt") - assert not p.is_file() - p.write_text("dummy") - assert p.is_file() - assert p.size() > 0 - assert p.read_text() == "dummy" - # TODO: Fix when fixed in MSC. 
- # assert EPath("msc://s3test_msc/test").is_dir() - assert EPath("msc://s3test_msc/test/dir").is_dir() - p.unlink() - assert not p.is_file() - # assert not EPath("msc://s3test_msc/test").is_dir() - assert not EPath("msc://s3test_msc/test/dir").is_dir() + p4serialized = pickle.dumps(p4) + # No secret must be serialized + assert b"dummy" not in p4serialized + finally: + if orig_xdg_config_home is not None: + os.environ["XDG_CONFIG_HOME"] = orig_xdg_config_home + else: + del os.environ["XDG_CONFIG_HOME"] + rclone_config_path.unlink() + + +def test_multi_storage_client(): + """Test the Multi-Storage Client integration""" + # Test path handling + p = EPath("msc://default/etc/resolv.conf") + assert str(p) == "/etc/resolv.conf", str(p) + assert p.is_file() + + p2 = p / ".." / "hosts" + assert str(p2) == "/etc/hosts", str(p2) + + # Test glob + p3 = EPath("msc://default/etc/") + assert p3.is_dir() + for i in p3.glob("*.conf"): + assert str(i).endswith(".conf") + + # Test open file + assert p.size() > 0 + with p.open("r") as fp: + assert len(fp.read()) > 0 + + # Test move and delete + p4 = EPath("msc://default/tmp/random_file_0001") + if p4.is_file(): + p4.unlink() + with p4.open("w") as fp: + fp.write("*****") + assert p4.is_file() + p5 = EPath("msc://default/tmp/random_file_0002") + if p5.is_file(): + p5.unlink() + assert p5.is_file() is False + p4.move(p5) + assert p5.is_file() + assert p4.is_file() is False + p5.unlink() + assert p5.is_file() is False + + # Test pickle / unpickle + p5serialized = pickle.dumps(p5) + p5unserialized = pickle.loads(p5serialized) + assert p5unserialized == p5 + assert str(p5unserialized) == str(p5) + + +def test_multiprocessing(): + """Test EPath in multiprocessing context""" + p = EPath("/tmp/path/subpath.txt") + + orig_start_method = multiprocessing.get_start_method() + try: + multiprocessing.set_start_method("spawn", force=True) + + proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) + proc.start() + proc.join() + 
assert proc.exitcode == 0 + + multiprocessing.set_start_method("fork", force=True) + + proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) + proc.start() + proc.join() + assert proc.exitcode == 0 + finally: + multiprocessing.set_start_method(orig_start_method, force=True) + + +def test_multiprocessing_msc(): + """Test EPath in multiprocessing context""" + p = EPath("msc://default/tmp/random_file_0001") + with p.open("w") as fp: + fp.write("*****") + + orig_start_method = multiprocessing.get_start_method() + try: + multiprocessing.set_start_method("spawn", force=True) + + proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) + proc.start() + proc.join() + assert proc.exitcode == 0 + + multiprocessing.set_start_method("fork", force=True) + + proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) + proc.start() + proc.join() + assert proc.exitcode == 0 + finally: + multiprocessing.set_start_method(orig_start_method, force=True) + p.unlink() + + +def test_msc_s3(): + # Test S3 with MSC + with setup_s3_emulator(profile_name="s3test_msc"): + p = EPath("msc://s3test_msc/test/dir/file.txt") + assert not p.is_file() + p.write_text("dummy") + assert p.is_file() + assert p.size() > 0 + assert p.read_text() == "dummy" + # TODO: Fix when fixed in MSC. 
+ # assert EPath("msc://s3test_msc/test").is_dir() + assert EPath("msc://s3test_msc/test/dir").is_dir() + p.unlink() + assert not p.is_file() + # assert not EPath("msc://s3test_msc/test").is_dir() + assert not EPath("msc://s3test_msc/test/dir").is_dir() def _multiproc_test_func(p: EPath, test_function: bool): @@ -285,7 +292,3 @@ def _multiproc_test_func(p: EPath, test_function: bool): print(f"str: {str(p)}") if test_function: print(f"is_file: {p.is_file()}") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_file_cache_pool.py b/tests/test_file_cache_pool.py index 5f835987..016f6e1b 100644 --- a/tests/test_file_cache_pool.py +++ b/tests/test_file_cache_pool.py @@ -5,10 +5,11 @@ import tempfile import threading import time -import unittest from pathlib import Path from typing import Any, Dict, Optional +import pytest + from megatron.energon.cache import DecodeFileStore, FileCacheLazy, FileStore, FileStoreCachePool from megatron.energon.flavors.webdataset.sample_decoder import SampleDecoder from megatron.energon.source_info import SourceInfo @@ -32,6 +33,15 @@ def __getitem__(self, key: str) -> tuple[Any, SourceInfo]: def get_path(self) -> str: return self._path + def worker_init(self) -> None: + pass + + def worker_close(self) -> None: + pass + + def close(self) -> None: + pass + class MockDecoder(SampleDecoder): """Mock decoder for DecodeFileStore""" @@ -40,642 +50,643 @@ def decode(self, fname: str, raw: bytes) -> Any: return f"{fname}: {raw.decode()}" -class TestFileStoreCachePool(unittest.TestCase): - """Test cases for FileStoreCachePool""" - - def setUp(self): - """Setup test environment before each test""" - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.temp_path = Path(self.temp_dir.name) - - def tearDown(self): - """Clean up after each test""" - self.temp_dir.cleanup() - - def test_get_method(self): - """Test the synchronous get method""" - # Create mock file stores - mock_raw_file_store = 
MockFileStore( - { - "file1": b"test data 1", - "file2": b"test data 2", - "file3": b"test data 3", - } +"""Test cases for FileStoreCachePool""" + + +@pytest.fixture +def temp_dir(): + """Setup test environment before each test""" + # Create a temporary directory + temp_dir = tempfile.TemporaryDirectory() + temp_path = Path(temp_dir.name) + yield temp_path + temp_dir.cleanup() + + +def test_get_method(temp_dir): + """Test the synchronous get method""" + # Create mock file stores + mock_raw_file_store = MockFileStore( + { + "file1": b"test data 1", + "file2": b"test data 2", + "file3": b"test data 3", + } + ) + + mock_decode_file_store = DecodeFileStore( + inner=mock_raw_file_store, + decoder=MockDecoder(), + ) + pool = FileStoreCachePool(parent_cache_dir=temp_dir) + try: + # get should directly read from the dataset without caching + sample_for_source_info = {"__sources__": []} + result = pool.get(mock_raw_file_store, "file1", sample_for_source_info) + assert result == b"test data 1" + assert len(sample_for_source_info["__sources__"]) == 1 + assert ( + sample_for_source_info["__sources__"][0].dataset_path == mock_raw_file_store.get_path() ) - - mock_decode_file_store = DecodeFileStore( - inner=mock_raw_file_store, - decoder=MockDecoder(), + assert sample_for_source_info["__sources__"][0].index is None + assert sample_for_source_info["__sources__"][0].shard_name is None + assert sample_for_source_info["__sources__"][0].file_names == ("file1",) + + # get should directly read from the dataset without caching + sample_for_source_info = {"__sources__": []} + result = pool.get(mock_decode_file_store, "file1", sample_for_source_info) + assert result == "file1: test data 1" + assert len(sample_for_source_info["__sources__"]) == 1 + assert ( + sample_for_source_info["__sources__"][0].dataset_path + == mock_decode_file_store.get_path() ) - pool = FileStoreCachePool(parent_cache_dir=self.temp_path) - try: - # get should directly read from the dataset without caching - 
sample_for_source_info = {"__sources__": []} - result = pool.get(mock_raw_file_store, "file1", sample_for_source_info) - assert result == b"test data 1" - assert len(sample_for_source_info["__sources__"]) == 1 - assert ( - sample_for_source_info["__sources__"][0].dataset_path - == mock_raw_file_store.get_path() - ) - assert sample_for_source_info["__sources__"][0].index is None - assert sample_for_source_info["__sources__"][0].shard_name is None - assert sample_for_source_info["__sources__"][0].file_names == ("file1",) - - # get should directly read from the dataset without caching - sample_for_source_info = {"__sources__": []} - result = pool.get(mock_decode_file_store, "file1", sample_for_source_info) - assert result == "file1: test data 1" - assert len(sample_for_source_info["__sources__"]) == 1 - assert ( - sample_for_source_info["__sources__"][0].dataset_path - == mock_decode_file_store.get_path() - ) - assert sample_for_source_info["__sources__"][0].index is None - assert sample_for_source_info["__sources__"][0].shard_name is None - assert sample_for_source_info["__sources__"][0].file_names == ("file1",) - finally: - pool.close() - - def test_get_lazy_method(self): - """Test the lazy get method for background prefetching""" - pool = FileStoreCachePool(parent_cache_dir=self.temp_path) - # Create mock file stores - mock_raw_file_store = MockFileStore( - { - "file1": b"test data 1", - } - ) - try: - # Request lazy loading - lazy_ref = pool.get_lazy(mock_raw_file_store, "file1") - - # Verify the return type - assert isinstance(lazy_ref, FileCacheLazy) - - # Wait for the background task - lazy_ref.entry.send_to_cache_future.result() - - # Check that the file exists in the cache directory - cache_files = list(pool.cache_dir.glob("*")) - assert len(cache_files) == 1 - - # Get the data - result = lazy_ref.get() - assert result == b"test data 1" - finally: - pool.close() - - def test_shared_references(self): - """Test that multiple references share the same background 
task""" - pool = FileStoreCachePool(parent_cache_dir=self.temp_path) - # Create mock file stores - mock_raw_file_store = MockFileStore( - { - "file1": b"test data 1", - } - ) - try: - # Check that the file exists in the cache directory - cache_files = list(pool.cache_dir.rglob("*")) - assert len(cache_files) == 0 - - # Request lazy loading for the same file twice - lazy_ref1 = pool.get_lazy(mock_raw_file_store, "file1") - lazy_ref2 = pool.get_lazy(mock_raw_file_store, "file1") - - # Check that they share the same entry - assert lazy_ref1.entry is lazy_ref2.entry - - # Check that refcount is 2 - assert lazy_ref1.entry.refcount == 2 - - # Wait for the background task - lazy_ref1.entry.send_to_cache_future.result() - - # Check that the file exists in the cache directory - cache_files = list(pool.cache_dir.rglob("*")) - assert len(cache_files) == 1, cache_files - - # Get data from both references - sample_with_source_info = {"__sources__": []} - result1 = lazy_ref1.get(sample_with_source_info) - assert lazy_ref1.entry.refcount == 1 - sample_with_source_info2 = {"__sources__": []} - result2 = lazy_ref2.get(sample_with_source_info2) - assert lazy_ref1.entry.refcount == 0 - - # Check that the file exists in the cache directory - cache_files = list(pool.cache_dir.rglob("*")) - assert len(cache_files) == 0 - - assert result1 == b"test data 1" - assert result2 == b"test data 1" - assert ( - sample_with_source_info["__sources__"][0].dataset_path - == sample_with_source_info2["__sources__"][0].dataset_path - ) - assert sample_with_source_info["__sources__"][0].index is None - assert sample_with_source_info["__sources__"][0].shard_name is None - assert ( - sample_with_source_info["__sources__"][0].file_names - == sample_with_source_info2["__sources__"][0].file_names - ) - finally: - pool.close() - - def test_cache_size_management(self): - """Test that the cache respects size limits and evicts files""" - # Create a cache pool with strict limits - pool = FileStoreCachePool( - 
parent_cache_dir=self.temp_path, - max_cache_size_gbytes=0.0001, # ~100KB - max_cache_count=2, - num_workers=1, + assert sample_for_source_info["__sources__"][0].index is None + assert sample_for_source_info["__sources__"][0].shard_name is None + assert sample_for_source_info["__sources__"][0].file_names == ("file1",) + finally: + pool.close() + + +def test_get_lazy_method(temp_dir): + """Test the lazy get method for background prefetching""" + pool = FileStoreCachePool(parent_cache_dir=temp_dir) + # Create mock file stores + mock_raw_file_store = MockFileStore( + { + "file1": b"test data 1", + } + ) + try: + # Request lazy loading + lazy_ref = pool.get_lazy(mock_raw_file_store, "file1") + + # Verify the return type + assert isinstance(lazy_ref, FileCacheLazy) + + # Wait for the background task + lazy_ref.entry.send_to_cache_future.result() + + # Check that the file exists in the cache directory + cache_files = list(pool.cache_dir.glob("*")) + assert len(cache_files) == 1 + + # Get the data + result = lazy_ref.get() + assert result == b"test data 1" + finally: + pool.close() + + +def test_shared_references(temp_dir): + """Test that multiple references share the same background task""" + pool = FileStoreCachePool(parent_cache_dir=temp_dir) + # Create mock file stores + mock_raw_file_store = MockFileStore( + { + "file1": b"test data 1", + } + ) + try: + # Check that the file exists in the cache directory + cache_files = list(pool.cache_dir.rglob("*")) + assert len(cache_files) == 0 + + # Request lazy loading for the same file twice + lazy_ref1 = pool.get_lazy(mock_raw_file_store, "file1") + lazy_ref2 = pool.get_lazy(mock_raw_file_store, "file1") + + # Check that they share the same entry + assert lazy_ref1.entry is lazy_ref2.entry + + # Check that refcount is 2 + assert lazy_ref1.entry.refcount == 2 + + # Wait for the background task + lazy_ref1.entry.send_to_cache_future.result() + + # Check that the file exists in the cache directory + cache_files = 
list(pool.cache_dir.rglob("*")) + assert len(cache_files) == 1, cache_files + + # Get data from both references + sample_with_source_info = {"__sources__": []} + result1 = lazy_ref1.get(sample_with_source_info) + assert lazy_ref1.entry.refcount == 1 + sample_with_source_info2 = {"__sources__": []} + result2 = lazy_ref2.get(sample_with_source_info2) + assert lazy_ref1.entry.refcount == 0 + + # Check that the file exists in the cache directory + cache_files = list(pool.cache_dir.rglob("*")) + assert len(cache_files) == 0 + + assert result1 == b"test data 1" + assert result2 == b"test data 1" + assert ( + sample_with_source_info["__sources__"][0].dataset_path + == sample_with_source_info2["__sources__"][0].dataset_path ) - # Set to a safe byte size - pool.max_cache_size = 75_000 - - mock_raw_file_store = MockFileStore( - { - "large_file1": b"a" * 50_000, - "large_file2": b"b" * 50_000, - "large_file3": b"c" * 50_000, - "large_file4": b"d" * 25_000, - "large_file5": b"e" * 25_000, - "large_file6": b"f" * 25_000, - } + assert sample_with_source_info["__sources__"][0].index is None + assert sample_with_source_info["__sources__"][0].shard_name is None + assert ( + sample_with_source_info["__sources__"][0].file_names + == sample_with_source_info2["__sources__"][0].file_names ) - - try: - # Enqueue all fetches - lazy1 = pool.get_lazy(mock_raw_file_store, "large_file1") - lazy2 = pool.get_lazy(mock_raw_file_store, "large_file2") - lazy3 = pool.get_lazy(mock_raw_file_store, "large_file3") - lazy4 = pool.get_lazy(mock_raw_file_store, "large_file4") - lazy2_2 = pool.get_lazy(mock_raw_file_store, "large_file2") - lazy2_3 = pool.get_lazy(mock_raw_file_store, "large_file2") - lazy3_2 = pool.get_lazy(mock_raw_file_store, "large_file3") - lazy5 = pool.get_lazy(mock_raw_file_store, "large_file5") - lazy6 = pool.get_lazy(mock_raw_file_store, "large_file6") - lazy6_2 = pool.get_lazy(mock_raw_file_store, "large_file6") - - def status(): - return [ - ( - name, - lazy.entry.refcount, - 
"consumed" - if lazy._data - else ("cached" if lazy.entry.send_to_cache_future.done() else "pending"), + finally: + pool.close() + + +def test_cache_size_management(temp_dir): + """Test that the cache respects size limits and evicts files""" + # Create a cache pool with strict limits + pool = FileStoreCachePool( + parent_cache_dir=temp_dir, + max_cache_size_gbytes=0.0001, # ~100KB + max_cache_count=2, + num_workers=1, + ) + # Set to a safe byte size + pool.max_cache_size = 75_000 + + mock_raw_file_store = MockFileStore( + { + "large_file1": b"a" * 50_000, + "large_file2": b"b" * 50_000, + "large_file3": b"c" * 50_000, + "large_file4": b"d" * 25_000, + "large_file5": b"e" * 25_000, + "large_file6": b"f" * 25_000, + } + ) + + try: + # Enqueue all fetches + lazy1 = pool.get_lazy(mock_raw_file_store, "large_file1") + lazy2 = pool.get_lazy(mock_raw_file_store, "large_file2") + lazy3 = pool.get_lazy(mock_raw_file_store, "large_file3") + lazy4 = pool.get_lazy(mock_raw_file_store, "large_file4") + lazy2_2 = pool.get_lazy(mock_raw_file_store, "large_file2") + lazy2_3 = pool.get_lazy(mock_raw_file_store, "large_file2") + lazy3_2 = pool.get_lazy(mock_raw_file_store, "large_file3") + lazy5 = pool.get_lazy(mock_raw_file_store, "large_file5") + lazy6 = pool.get_lazy(mock_raw_file_store, "large_file6") + lazy6_2 = pool.get_lazy(mock_raw_file_store, "large_file6") + + def status(): + return [ + ( + name, + lazy.entry.refcount, + "consumed" + if lazy._data + else ("cached" if lazy.entry.send_to_cache_future.done() else "pending"), + ) + for lazy, name in ( + [ + (lazy1, "1"), + (lazy2, "2"), + (lazy2_2, "2_2"), + (lazy2_3, "2_3"), + (lazy3, "3"), + (lazy3_2, "3_2"), + (lazy4, "4"), + (lazy5, "5"), + (lazy6, "6"), + ] + + ([(lazy6_2, "6_2")] if lazy6_2 is not None else []) + ) + ] + + def txt_status(): + out = [] + for lazy in [ + lazy1, + lazy2, + lazy2_2, + lazy2_3, + lazy3, + lazy3_2, + lazy4, + lazy5, + lazy6, + ] + ([lazy6_2] if lazy6_2 is not None else []): + if lazy._data is 
not None: + out.append( + f" - {lazy.fname} [{lazy.entry.data_size}b, {lazy.entry.refcount}refs] consumed" ) - for lazy, name in ( - [ - (lazy1, "1"), - (lazy2, "2"), - (lazy2_2, "2_2"), - (lazy2_3, "2_3"), - (lazy3, "3"), - (lazy3_2, "3_2"), - (lazy4, "4"), - (lazy5, "5"), - (lazy6, "6"), - ] - + ([(lazy6_2, "6_2")] if lazy6_2 is not None else []) + elif lazy.entry.send_to_cache_future.done(): + out.append( + f" - {lazy.fname} [{lazy.entry.data_size}b, {lazy.entry.refcount}refs] cached" ) - ] - - def txt_status(): - out = [] - for lazy in [ - lazy1, - lazy2, - lazy2_2, - lazy2_3, - lazy3, - lazy3_2, - lazy4, - lazy5, - lazy6, - ] + ([lazy6_2] if lazy6_2 is not None else []): - if lazy._data is not None: - out.append( - f" - {lazy.fname} [{lazy.entry.data_size}b, {lazy.entry.refcount}refs] consumed" - ) - elif lazy.entry.send_to_cache_future.done(): - out.append( - f" - {lazy.fname} [{lazy.entry.data_size}b, {lazy.entry.refcount}refs] cached" - ) - else: - out.append( - f" - {lazy.fname} [{lazy.entry.data_size}b, {lazy.entry.refcount}refs] pending" - ) - return ( - f"Cached Count: {pool.current_cache_count}, Cache size: {pool.current_cache_size}\n" - + "\n".join(out) - ) - - # lazy2_2 and lazy2_3 should share the same entry as lazy2 - assert lazy2_2.entry is lazy2.entry - assert lazy2_3.entry is lazy2.entry - - lazy1.entry.send_to_cache_future.result(timeout=1) - # Wait for the background tasks to finish - time.sleep(0.5) - - print("Checking cache status") - # They should not be able to finish, because the cache is full - # Queue state: [2<50>, 3<50>, 4<25>, 5<25>, 6<25>], cached out: [1<50>], removed: [] - assert status() == [ - ("1", 1, "cached"), - ("2", 3, "pending"), - ("2_2", 3, "pending"), - ("2_3", 3, "pending"), - ("3", 2, "pending"), - ("3_2", 2, "pending"), - ("4", 1, "pending"), - ("5", 1, "pending"), - ("6", 2, "pending"), - ("6_2", 2, "pending"), - ], txt_status() - - # Check cache count and size before second file - assert pool.current_cache_count == 
1, pool.current_cache_count - assert pool.current_cache_size == 50_000, pool.current_cache_size - - print("Fetching lazy2_3") - # Now, fetching the second file should still work directly and ignore the caching - # But it will requeue fetching the second file to the background thread for the remaining lazies. - result2_3 = lazy2_3.get() - assert result2_3 == b"b" * 50_000 - - # They should not be able to finish, because the cache is full - # Queue state: [3<50>, 4<25>, 5<25>, 6<25>, 2<50>], cached out: [1<50>], removed: [] - assert status() == [ - ("1", 1, "cached"), - ("2", 2, "pending"), - ("2_2", 2, "pending"), - ("2_3", 2, "consumed"), - ("3", 2, "pending"), - ("3_2", 2, "pending"), - ("4", 1, "pending"), - ("5", 1, "pending"), - ("6", 2, "pending"), - ("6_2", 2, "pending"), - ], txt_status() - - # Fetch - result1 = lazy1.get() - assert result1 == b"a" * 50_000 - - lazy3.entry.send_to_cache_future.result(timeout=1) - - time.sleep(0.5) - - # Second file is now queued at the end. - # File 3 and 4 should now be cached. 
- # Queue state: [5<25>, 6<25>, 2<50>], cached out: [3<50>, 4<25>], removed: [1<50>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 2, "pending"), - ("2_2", 2, "pending"), - ("2_3", 2, "consumed"), - ("3", 2, "cached"), - ("3_2", 2, "cached"), - ("4", 1, "cached"), - ("5", 1, "pending"), - ("6", 2, "pending"), - ("6_2", 2, "pending"), - ], txt_status() - assert pool.current_cache_count == 2 - assert pool.current_cache_size == 75_000 - - result3 = lazy3.get() - assert result3 == b"c" * 50_000 - - time.sleep(0.5) - - # Space by large_file3 is still occupied in cache - # Queue state: [5<25>, 6<25>, 2<50>], cached out: [3<50>, 4<25>], removed: [1<50>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 2, "pending"), - ("2_2", 2, "pending"), - ("2_3", 2, "consumed"), - ("3", 1, "consumed"), - ("3_2", 1, "cached"), - ("4", 1, "cached"), - ("5", 1, "pending"), - ("6", 2, "pending"), - ("6_2", 2, "pending"), - ], txt_status() - assert pool.current_cache_count == 2 - assert pool.current_cache_size == 75_000 - - result3_2 = lazy3_2.get() - assert result3_2 == b"c" * 50_000 - - time.sleep(0.5) - - # Space by large_file3 was freed now, 4, 5, and 6 should fit now, large_file2 not yet - # Queue state: [6<25>, 2<50>], cached out: [5<25>, 4<25>], removed: [1<50>, 3<50>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 2, "pending"), - ("2_2", 2, "pending"), - ("2_3", 2, "consumed"), - ("3", 0, "consumed"), - ("3_2", 0, "consumed"), - ("4", 1, "cached"), - ("5", 1, "cached"), - ("6", 2, "pending"), - ("6_2", 2, "pending"), - ], txt_status() - assert pool.current_cache_count == 2 - assert pool.current_cache_size == 50_000 - - result4 = lazy4.get() - assert result4 == b"d" * 25_000 - - time.sleep(0.5) - - # Nothing changed, no space for large_file2 still - # Queue state: [6<25>, 2<50>], cached out: [5<25>, 4<25>], removed: [1<50>, 3<50>, 4<25>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 2, "pending"), - ("2_2", 2, "pending"), - ("2_3", 2, "consumed"), - 
("3", 0, "consumed"), - ("3_2", 0, "consumed"), - ("4", 0, "consumed"), - ("5", 1, "cached"), - ("6", 2, "cached"), - ("6_2", 2, "cached"), - ], txt_status() - assert pool.current_cache_count == 2 - assert pool.current_cache_size == 50_000 - - result5 = lazy5.get() - assert result5 == b"e" * 25_000 - - time.sleep(0.5) - - # Now large_file2 can be cached - # Queue state: [], cached out: [6<25>, 2<50>], removed: [1<50>, 3<50>, 4<25>, 5<25>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 2, "cached"), - ("2_2", 2, "cached"), - ("2_3", 2, "consumed"), - ("3", 0, "consumed"), - ("3_2", 0, "consumed"), - ("4", 0, "consumed"), - ("5", 0, "consumed"), - ("6", 2, "cached"), - ("6_2", 2, "cached"), - ], txt_status() - assert pool.current_cache_count == 2 - assert pool.current_cache_size == 75_000 - - result6 = lazy6.get() - assert result6 == b"f" * 25_000 - - # Queue state: [], cached out: [6<25>, 2<50>], removed: [1<50>, 3<50>, 4<25>, 5<25>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 2, "cached"), - ("2_2", 2, "cached"), - ("2_3", 2, "consumed"), - ("3", 0, "consumed"), - ("3_2", 0, "consumed"), - ("4", 0, "consumed"), - ("5", 0, "consumed"), - ("6", 1, "consumed"), - ("6_2", 1, "cached"), - ], txt_status() - assert pool.current_cache_count == 2 - assert pool.current_cache_size == 75_000 - - result2 = lazy2.get() - assert result2 == b"b" * 50_000 - - # Queue state: [], cached out: [6<25>, 2<50>], removed: [1<50>, 3<50>, 4<25>, 5<25>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 1, "consumed"), - ("2_2", 1, "cached"), - ("2_3", 1, "consumed"), - ("3", 0, "consumed"), - ("3_2", 0, "consumed"), - ("4", 0, "consumed"), - ("5", 0, "consumed"), - ("6", 1, "consumed"), - ("6_2", 1, "cached"), - ], txt_status() - assert pool.current_cache_count == 2 - assert pool.current_cache_size == 75_000 - - result2_2 = lazy2_2.get() - assert result2_2 == b"b" * 50_000 - - # Cache should only contain large_file6 now - # Queue state: [], cached out: [6<25>], 
removed: [1<50>, 3<50>, 4<25>, 5<25>, 2<50>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 0, "consumed"), - ("2_2", 0, "consumed"), - ("2_3", 0, "consumed"), - ("3", 0, "consumed"), - ("3_2", 0, "consumed"), - ("4", 0, "consumed"), - ("5", 0, "consumed"), - ("6", 1, "consumed"), - ("6_2", 1, "cached"), - ], txt_status() - assert pool.current_cache_count == 1, txt_status() - assert pool.current_cache_size == 25_000 - - # Delete the last reference to large_file6, it should be removed from the cache - lazy6_2 = None - gc.collect() - - # Cache should be empty now - # Queue state: [], cached out: [], removed: [1<50>, 3<50>, 4<25>, 5<25>, 6<25>, 2<50>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 0, "consumed"), - ("2_2", 0, "consumed"), - ("2_3", 0, "consumed"), - ("3", 0, "consumed"), - ("3_2", 0, "consumed"), - ("4", 0, "consumed"), - ("5", 0, "consumed"), - ("6", 0, "consumed"), - ], txt_status() - assert pool.current_cache_count == 0, txt_status() - assert pool.current_cache_size == 0 - # Check that the cache directory is empty - assert not list(pool.cache_dir.glob("*")) - finally: - pool.close() - - def test_raw_method(self): - """Test the 'raw' caching method with DecodeFileStore""" - pool = FileStoreCachePool(parent_cache_dir=self.temp_path, method="raw") - mock_raw_file_store = MockFileStore( - { - "file1": b"test data 1", - } - ) - mock_decode_file_store = DecodeFileStore( - inner=mock_raw_file_store, - decoder=MockDecoder(), - ) - try: - # Request lazy loading - lazy_ref = pool.get_lazy(mock_decode_file_store, "file1") - - # Wait for background task - time.sleep(0.5) - - # Get the data - should be decoded - sample_with_source_info = {"__sources__": []} - result = lazy_ref.get(sample_with_source_info) - assert result == "file1: test data 1" - assert ( - sample_with_source_info["__sources__"][0].dataset_path - == mock_decode_file_store.get_path() + else: + out.append( + f" - {lazy.fname} [{lazy.entry.data_size}b, {lazy.entry.refcount}refs] 
pending" + ) + return ( + f"Cached Count: {pool.current_cache_count}, Cache size: {pool.current_cache_size}\n" + + "\n".join(out) ) - assert sample_with_source_info["__sources__"][0].index is None - assert sample_with_source_info["__sources__"][0].shard_name is None - assert sample_with_source_info["__sources__"][0].file_names == ("file1",) - finally: - pool.close() - - def test_pickle_method(self): - """Test the 'pickle' caching method""" - pool = FileStoreCachePool(parent_cache_dir=self.temp_path, method="pickle") - mock_raw_file_store = MockFileStore( - { - "file1": b"test data 1", - } + + # lazy2_2 and lazy2_3 should share the same entry as lazy2 + assert lazy2_2.entry is lazy2.entry + assert lazy2_3.entry is lazy2.entry + + lazy1.entry.send_to_cache_future.result(timeout=1) + # Wait for the background tasks to finish + time.sleep(0.5) + + print("Checking cache status") + # They should not be able to finish, because the cache is full + # Queue state: [2<50>, 3<50>, 4<25>, 5<25>, 6<25>], cached out: [1<50>], removed: [] + assert status() == [ + ("1", 1, "cached"), + ("2", 3, "pending"), + ("2_2", 3, "pending"), + ("2_3", 3, "pending"), + ("3", 2, "pending"), + ("3_2", 2, "pending"), + ("4", 1, "pending"), + ("5", 1, "pending"), + ("6", 2, "pending"), + ("6_2", 2, "pending"), + ], txt_status() + + # Check cache count and size before second file + assert pool.current_cache_count == 1, pool.current_cache_count + assert pool.current_cache_size == 50_000, pool.current_cache_size + + print("Fetching lazy2_3") + # Now, fetching the second file should still work directly and ignore the caching + # But it will requeue fetching the second file to the background thread for the remaining lazies. 
+ result2_3 = lazy2_3.get() + assert result2_3 == b"b" * 50_000 + + # They should not be able to finish, because the cache is full + # Queue state: [3<50>, 4<25>, 5<25>, 6<25>, 2<50>], cached out: [1<50>], removed: [] + assert status() == [ + ("1", 1, "cached"), + ("2", 2, "pending"), + ("2_2", 2, "pending"), + ("2_3", 2, "consumed"), + ("3", 2, "pending"), + ("3_2", 2, "pending"), + ("4", 1, "pending"), + ("5", 1, "pending"), + ("6", 2, "pending"), + ("6_2", 2, "pending"), + ], txt_status() + + # Fetch + result1 = lazy1.get() + assert result1 == b"a" * 50_000 + + lazy3.entry.send_to_cache_future.result(timeout=1) + + time.sleep(0.5) + + # Second file is now queued at the end. + # File 3 and 4 should now be cached. + # Queue state: [5<25>, 6<25>, 2<50>], cached out: [3<50>, 4<25>], removed: [1<50>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 2, "pending"), + ("2_2", 2, "pending"), + ("2_3", 2, "consumed"), + ("3", 2, "cached"), + ("3_2", 2, "cached"), + ("4", 1, "cached"), + ("5", 1, "pending"), + ("6", 2, "pending"), + ("6_2", 2, "pending"), + ], txt_status() + assert pool.current_cache_count == 2 + assert pool.current_cache_size == 75_000 + + result3 = lazy3.get() + assert result3 == b"c" * 50_000 + + time.sleep(0.5) + + # Space by large_file3 is still occupied in cache + # Queue state: [5<25>, 6<25>, 2<50>], cached out: [3<50>, 4<25>], removed: [1<50>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 2, "pending"), + ("2_2", 2, "pending"), + ("2_3", 2, "consumed"), + ("3", 1, "consumed"), + ("3_2", 1, "cached"), + ("4", 1, "cached"), + ("5", 1, "pending"), + ("6", 2, "pending"), + ("6_2", 2, "pending"), + ], txt_status() + assert pool.current_cache_count == 2 + assert pool.current_cache_size == 75_000 + + result3_2 = lazy3_2.get() + assert result3_2 == b"c" * 50_000 + + time.sleep(0.5) + + # Space by large_file3 was freed now, 4, 5, and 6 should fit now, large_file2 not yet + # Queue state: [6<25>, 2<50>], cached out: [5<25>, 4<25>], removed: 
[1<50>, 3<50>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 2, "pending"), + ("2_2", 2, "pending"), + ("2_3", 2, "consumed"), + ("3", 0, "consumed"), + ("3_2", 0, "consumed"), + ("4", 1, "cached"), + ("5", 1, "cached"), + ("6", 2, "pending"), + ("6_2", 2, "pending"), + ], txt_status() + assert pool.current_cache_count == 2 + assert pool.current_cache_size == 50_000 + + result4 = lazy4.get() + assert result4 == b"d" * 25_000 + + time.sleep(0.5) + + # Nothing changed, no space for large_file2 still + # Queue state: [6<25>, 2<50>], cached out: [5<25>, 4<25>], removed: [1<50>, 3<50>, 4<25>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 2, "pending"), + ("2_2", 2, "pending"), + ("2_3", 2, "consumed"), + ("3", 0, "consumed"), + ("3_2", 0, "consumed"), + ("4", 0, "consumed"), + ("5", 1, "cached"), + ("6", 2, "cached"), + ("6_2", 2, "cached"), + ], txt_status() + assert pool.current_cache_count == 2 + assert pool.current_cache_size == 50_000 + + result5 = lazy5.get() + assert result5 == b"e" * 25_000 + + time.sleep(0.5) + + # Now large_file2 can be cached + # Queue state: [], cached out: [6<25>, 2<50>], removed: [1<50>, 3<50>, 4<25>, 5<25>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 2, "cached"), + ("2_2", 2, "cached"), + ("2_3", 2, "consumed"), + ("3", 0, "consumed"), + ("3_2", 0, "consumed"), + ("4", 0, "consumed"), + ("5", 0, "consumed"), + ("6", 2, "cached"), + ("6_2", 2, "cached"), + ], txt_status() + assert pool.current_cache_count == 2 + assert pool.current_cache_size == 75_000 + + result6 = lazy6.get() + assert result6 == b"f" * 25_000 + + # Queue state: [], cached out: [6<25>, 2<50>], removed: [1<50>, 3<50>, 4<25>, 5<25>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 2, "cached"), + ("2_2", 2, "cached"), + ("2_3", 2, "consumed"), + ("3", 0, "consumed"), + ("3_2", 0, "consumed"), + ("4", 0, "consumed"), + ("5", 0, "consumed"), + ("6", 1, "consumed"), + ("6_2", 1, "cached"), + ], txt_status() + assert pool.current_cache_count 
== 2 + assert pool.current_cache_size == 75_000 + + result2 = lazy2.get() + assert result2 == b"b" * 50_000 + + # Queue state: [], cached out: [6<25>, 2<50>], removed: [1<50>, 3<50>, 4<25>, 5<25>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 1, "consumed"), + ("2_2", 1, "cached"), + ("2_3", 1, "consumed"), + ("3", 0, "consumed"), + ("3_2", 0, "consumed"), + ("4", 0, "consumed"), + ("5", 0, "consumed"), + ("6", 1, "consumed"), + ("6_2", 1, "cached"), + ], txt_status() + assert pool.current_cache_count == 2 + assert pool.current_cache_size == 75_000 + + result2_2 = lazy2_2.get() + assert result2_2 == b"b" * 50_000 + + # Cache should only contain large_file6 now + # Queue state: [], cached out: [6<25>], removed: [1<50>, 3<50>, 4<25>, 5<25>, 2<50>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 0, "consumed"), + ("2_2", 0, "consumed"), + ("2_3", 0, "consumed"), + ("3", 0, "consumed"), + ("3_2", 0, "consumed"), + ("4", 0, "consumed"), + ("5", 0, "consumed"), + ("6", 1, "consumed"), + ("6_2", 1, "cached"), + ], txt_status() + assert pool.current_cache_count == 1, txt_status() + assert pool.current_cache_size == 25_000 + + # Delete the last reference to large_file6, it should be removed from the cache + lazy6_2 = None + gc.collect() + + # Cache should be empty now + # Queue state: [], cached out: [], removed: [1<50>, 3<50>, 4<25>, 5<25>, 6<25>, 2<50>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 0, "consumed"), + ("2_2", 0, "consumed"), + ("2_3", 0, "consumed"), + ("3", 0, "consumed"), + ("3_2", 0, "consumed"), + ("4", 0, "consumed"), + ("5", 0, "consumed"), + ("6", 0, "consumed"), + ], txt_status() + assert pool.current_cache_count == 0, txt_status() + assert pool.current_cache_size == 0 + # Check that the cache directory is empty + assert not list(pool.cache_dir.glob("*")) + finally: + pool.close() + + +def test_raw_method(temp_dir): + """Test the 'raw' caching method with DecodeFileStore""" + pool = 
FileStoreCachePool(parent_cache_dir=temp_dir, method="raw") + mock_raw_file_store = MockFileStore( + { + "file1": b"test data 1", + } + ) + mock_decode_file_store = DecodeFileStore( + inner=mock_raw_file_store, + decoder=MockDecoder(), + ) + try: + # Request lazy loading + lazy_ref = pool.get_lazy(mock_decode_file_store, "file1") + + # Wait for background task + time.sleep(0.5) + + # Get the data - should be decoded + sample_with_source_info = {"__sources__": []} + result = lazy_ref.get(sample_with_source_info) + assert result == "file1: test data 1" + assert ( + sample_with_source_info["__sources__"][0].dataset_path + == mock_decode_file_store.get_path() ) - mock_decode_file_store = DecodeFileStore( - inner=mock_raw_file_store, - decoder=MockDecoder(), + assert sample_with_source_info["__sources__"][0].index is None + assert sample_with_source_info["__sources__"][0].shard_name is None + assert sample_with_source_info["__sources__"][0].file_names == ("file1",) + finally: + pool.close() + + +def test_pickle_method(temp_dir): + """Test the 'pickle' caching method""" + pool = FileStoreCachePool(parent_cache_dir=temp_dir, method="pickle") + mock_raw_file_store = MockFileStore( + { + "file1": b"test data 1", + } + ) + mock_decode_file_store = DecodeFileStore( + inner=mock_raw_file_store, + decoder=MockDecoder(), + ) + try: + # Request lazy loading + lazy_ref = pool.get_lazy(mock_decode_file_store, "file1") + + # Wait for background task + lazy_ref.entry.send_to_cache_future.result() + + # Get the data - should be unpickled correctly + sample_with_source_info = {"__sources__": []} + result = lazy_ref.get(sample_with_source_info) + assert result == "file1: test data 1" + assert ( + sample_with_source_info["__sources__"][0].dataset_path + == mock_decode_file_store.get_path() ) - try: - # Request lazy loading - lazy_ref = pool.get_lazy(mock_decode_file_store, "file1") - - # Wait for background task - lazy_ref.entry.send_to_cache_future.result() - - # Get the data - should 
be unpickled correctly - sample_with_source_info = {"__sources__": []} - result = lazy_ref.get(sample_with_source_info) - assert result == "file1: test data 1" - assert ( - sample_with_source_info["__sources__"][0].dataset_path - == mock_decode_file_store.get_path() - ) - assert sample_with_source_info["__sources__"][0].index is None - assert sample_with_source_info["__sources__"][0].shard_name is None - assert sample_with_source_info["__sources__"][0].file_names == ("file1",) - - # Request lazy loading - lazy_ref = pool.get_lazy(mock_raw_file_store, "file1") - - # Wait for background task - lazy_ref.entry.send_to_cache_future.result() - - # Get the data - should be unpickled correctly - sample_with_source_info = {"__sources__": []} - result = lazy_ref.get(sample_with_source_info) - assert result == b"test data 1" - assert ( - sample_with_source_info["__sources__"][0].dataset_path - == mock_raw_file_store.get_path() - ) - assert sample_with_source_info["__sources__"][0].index is None - assert sample_with_source_info["__sources__"][0].shard_name is None - assert sample_with_source_info["__sources__"][0].file_names == ("file1",) - finally: - pool.close() - - def test_concurrent_access(self): - """Test concurrent access to the cache pool""" - pool = FileStoreCachePool(parent_cache_dir=self.temp_path) - mock_raw_file_store = MockFileStore( - { - "file1": b"test data 1", - } + assert sample_with_source_info["__sources__"][0].index is None + assert sample_with_source_info["__sources__"][0].shard_name is None + assert sample_with_source_info["__sources__"][0].file_names == ("file1",) + + # Request lazy loading + lazy_ref = pool.get_lazy(mock_raw_file_store, "file1") + + # Wait for background task + lazy_ref.entry.send_to_cache_future.result() + + # Get the data - should be unpickled correctly + sample_with_source_info = {"__sources__": []} + result = lazy_ref.get(sample_with_source_info) + assert result == b"test data 1" + assert ( + 
sample_with_source_info["__sources__"][0].dataset_path == mock_raw_file_store.get_path() ) - results = [] - - def worker(filename): - lazy_ref = pool.get_lazy(mock_raw_file_store, filename) - result, source_info = lazy_ref.get() - results.append(result) - assert source_info.dataset_path == mock_raw_file_store.get_path() - assert source_info.index is None - assert source_info.shard_name is None - assert source_info.file_names == (filename,) - - try: - # Start multiple threads accessing the same file - threads = [] - for i in range(5): - t = threading.Thread(target=worker, args=("file1",)) - threads.append(t) - t.start() - - # Wait for all threads to complete - for t in threads: - t.join() - - # All threads should get the correct result - for r in results: - assert r == b"test data 1" - finally: - pool.close() - - def test_to_cache(self): - """Test that the cache out method works""" - pool = FileStoreCachePool(parent_cache_dir=self.temp_path) - try: - # Get the data - should be pickled / unpickled correctly - result = pool.to_cache((1, "some_data", 2), "file1") - - cache_path = result.cache_path - - # Check that the cache file exists - assert cache_path is not None - assert cache_path.is_file() - assert pool.cache_dir == cache_path.parent - - # Verify that the data is read correctly, also two times. - assert result.get() == (1, "some_data", 2) - assert result.get() == (1, "some_data", 2) - - # Verify that the cache file is deleted now that we've read the data. - assert result.cache_path is None - assert not cache_path.is_file() - - # Verify that the cache file is deleted when the object is deleted before reading the file. 
- result2 = pool.to_cache((1, "some_data", 2), "file2") - assert result2.cache_path is not None - assert result2.cache_path.is_file() - assert result2.cache_path != cache_path - cache_path = result2.cache_path - del result2 - gc.collect() - assert not cache_path.is_file() - finally: - pool.close() - - -if __name__ == "__main__": - unittest.main() + assert sample_with_source_info["__sources__"][0].index is None + assert sample_with_source_info["__sources__"][0].shard_name is None + assert sample_with_source_info["__sources__"][0].file_names == ("file1",) + finally: + pool.close() + + +def test_concurrent_access(temp_dir): + """Test concurrent access to the cache pool""" + pool = FileStoreCachePool(parent_cache_dir=temp_dir) + mock_raw_file_store = MockFileStore( + { + "file1": b"test data 1", + } + ) + results = [] + + def worker(filename): + lazy_ref = pool.get_lazy(mock_raw_file_store, filename) + result, source_info = lazy_ref.get() + results.append(result) + assert source_info.dataset_path == mock_raw_file_store.get_path() + assert source_info.index is None + assert source_info.shard_name is None + assert source_info.file_names == (filename,) + + try: + # Start multiple threads accessing the same file + threads = [] + for i in range(5): + t = threading.Thread(target=worker, args=("file1",)) + threads.append(t) + t.start() + + # Wait for all threads to complete + for t in threads: + t.join() + + # All threads should get the correct result + for r in results: + assert r == b"test data 1" + finally: + pool.close() + + +def test_to_cache(temp_dir): + """Test that the cache out method works""" + pool = FileStoreCachePool(parent_cache_dir=temp_dir) + try: + # Get the data - should be pickled / unpickled correctly + result = pool.to_cache((1, "some_data", 2), "file1") + + cache_path = result.cache_path + + # Check that the cache file exists + assert cache_path is not None + assert cache_path.is_file() + assert pool.cache_dir == cache_path.parent + + # Verify that the 
data is read correctly, also two times. + assert result.get() == (1, "some_data", 2) + assert result.get() == (1, "some_data", 2) + + # Verify that the cache file is deleted now that we've read the data. + assert result.cache_path is None + assert not cache_path.is_file() + + # Verify that the cache file is deleted when the object is deleted before reading the file. + result2 = pool.to_cache((1, "some_data", 2), "file2") + assert result2.cache_path is not None + assert result2.cache_path.is_file() + assert result2.cache_path != cache_path + cache_path = result2.cache_path + del result2 + gc.collect() + assert not cache_path.is_file() + finally: + pool.close() diff --git a/tests/test_jsonl_dataset.py b/tests/test_jsonl_dataset.py index 0ffe3192..d2a51270 100644 --- a/tests/test_jsonl_dataset.py +++ b/tests/test_jsonl_dataset.py @@ -9,12 +9,12 @@ import random import sys import tempfile -import unittest import warnings from collections import Counter from pathlib import Path from typing import Iterable +import pytest import torch from click.testing import CliRunner @@ -58,95 +58,94 @@ class SimpleCookingTaskEncoder(DefaultTaskEncoder): cookers = [Cooker(cook=cook_text)] -class TestJsonlDataset(unittest.TestCase): - # Set up the test fixture - def setUp(self): - random.seed(42) - - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - warnings.simplefilter("ignore", ResourceWarning) - - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.dataset_path = Path(self.temp_dir.name) - # self.dataset_path = Path("./test_dataset") - - self.dataset_path.mkdir(exist_ok=True, parents=True) - - # Create a small dummy datasets - self.create_text_test_dataset(self.dataset_path / "ds1.jsonl", range(55), range(55)) - self.create_text_test_dataset( - self.dataset_path / "ds2.jsonl", range(100, 155), range(100, 155) - ) - self.create_text_test_dataset(self.dataset_path / "ds3.jsonl", range(200, 255), range(55)) - - self.mds_all_path = 
self.dataset_path / "metadataset_all.yaml" - with open(self.mds_all_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - path: ds1.jsonl", - " subflavors:", - " ds: ds1", - " - path: ds2.jsonl", - " subflavors:", - " ds: ds2", - " - path: ds3.jsonl", - " subflavors:", - " ds: ds3", - ] - ) +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + temp_dir = tempfile.TemporaryDirectory() + yield temp_dir + gc.collect() + temp_dir.cleanup() + + +@pytest.fixture +def dataset_path(temp_dir): + """Create dataset path and setup test data.""" + random.seed(42) + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + dataset_path = Path(temp_dir.name) + dataset_path.mkdir(exist_ok=True, parents=True) + + # Create a small dummy datasets + create_text_test_dataset(dataset_path / "ds1.jsonl", range(55), range(55)) + create_text_test_dataset(dataset_path / "ds2.jsonl", range(100, 155), range(100, 155)) + create_text_test_dataset(dataset_path / "ds3.jsonl", range(200, 255), range(55)) + + mds_all_path = dataset_path / "metadataset_all.yaml" + with open(mds_all_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - path: ds1.jsonl", + " subflavors:", + " ds: ds1", + " - path: ds2.jsonl", + " subflavors:", + " ds: ds2", + " - path: ds3.jsonl", + " subflavors:", + " ds: ds3", + ] ) + ) - def tearDown(self): - # Remove all temporary files - gc.collect() - self.temp_dir.cleanup() + return dataset_path - @staticmethod - def create_text_test_dataset( - path: Path, txt_range: Iterable[int], key_range: Iterable[int], prefix: str = "" - ): - """Creates a small dummy test dataset for testing purposes.""" - # Write jsonl file - with open(path, "w") as wf: - for key, txt in zip(key_range, txt_range): - # 
Write JSON entries to the file, one per line. - wf.write(json.dumps({"idx": key, "txt": f"{prefix}{txt}"}) + "\n") +def create_text_test_dataset( + path: Path, txt_range: Iterable[int], key_range: Iterable[int], prefix: str = "" +): + """Creates a small dummy test dataset for testing purposes.""" - from megatron.energon.flavors import CrudeJsonlDatasetFactory + # Write jsonl file + with open(path, "w") as wf: + for key, txt in zip(key_range, txt_range): + # Write JSON entries to the file, one per line. + wf.write(json.dumps({"idx": key, "txt": f"{prefix}{txt}"}) + "\n") - CrudeJsonlDatasetFactory.prepare_dataset(path) + from megatron.energon.flavors import CrudeJsonlDatasetFactory - def test_dataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) + CrudeJsonlDatasetFactory.prepare_dataset(path) - # Train mode dataset - train_dataset = get_train_dataset( - self.dataset_path / "ds1.jsonl", - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - task_encoder=SimpleCookingTaskEncoder(), - ) - print(len(train_dataset)) - assert len(train_dataset) == 55, f"Expected 55 samples, got {len(train_dataset)}" - train_loader1 = get_loader(train_dataset) +def test_dataset(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "ds1.jsonl", + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=SimpleCookingTaskEncoder(), + ) + print(len(train_dataset)) + assert len(train_dataset) == 55, f"Expected 55 samples, got {len(train_dataset)}" + with get_loader(train_dataset) as train_loader1: train_order1 = [ text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text ] @@ -155,29 +154,29 @@ def test_dataset(self): 
assert len(Counter(train_order1)) == 55 assert all(v == 10 for v in Counter(train_order1).values()) - def test_metadataset_all(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - seed_offset=42, - ) - # Train mode dataset - train_dataset = get_train_dataset( - self.mds_all_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - task_encoder=SimpleCookingTaskEncoder(), - ) - print(len(train_dataset)) - assert len(train_dataset) == 55 * 3, f"Expected 55 * 3 samples, got {len(train_dataset)}" +def test_metadataset_all(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + seed_offset=42, + ) - train_loader1 = get_loader(train_dataset) + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "metadataset_all.yaml", + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=SimpleCookingTaskEncoder(), + ) + print(len(train_dataset)) + assert len(train_dataset) == 55 * 3, f"Expected 55 * 3 samples, got {len(train_dataset)}" + with get_loader(train_dataset) as train_loader1: train_order1 = [ text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text ] @@ -186,124 +185,125 @@ def test_metadataset_all(self): assert len(Counter(train_order1)) == 55 * 3 assert all(2 <= v <= 5 for v in Counter(train_order1).values()) - def test_metadataset_multirank(self): - torch.manual_seed(42) - sample_counts = Counter() - expected_lens = [19, 19, 17] +def test_metadataset_multirank(dataset_path): + torch.manual_seed(42) - for cur_rank in range(3): - worker_config = WorkerConfig( - rank=cur_rank, - world_size=3, - num_workers=5, - seed_offset=42, - ) + sample_counts = Counter() + expected_lens = [19, 19, 17] - # Train mode dataset - train_dataset = get_train_dataset( - self.dataset_path / "ds1.jsonl", - 
worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - task_encoder=SimpleCookingTaskEncoder(), - repeat=False, - ) - print(len(train_dataset)) - assert len(train_dataset) == expected_lens[cur_rank], ( - f"Expected {expected_lens[cur_rank]} samples, got {len(train_dataset)}" - ) + for cur_rank in range(3): + worker_config = WorkerConfig( + rank=cur_rank, + world_size=3, + num_workers=5, + seed_offset=42, + ) - train_loader1 = get_loader(train_dataset) + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "ds1.jsonl", + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=SimpleCookingTaskEncoder(), + repeat=False, + ) + print(len(train_dataset)) + assert len(train_dataset) == expected_lens[cur_rank], ( + f"Expected {expected_lens[cur_rank]} samples, got {len(train_dataset)}" + ) + with get_loader(train_dataset) as train_loader1: for data in train_loader1: sample_counts[int(data.text[0])] += 1 - for i in range(55): - assert sample_counts[i] == 1, ( - f"Sample {i} should have been seen exactly once, but was seen {sample_counts[i]} times." - ) + for i in range(55): + assert sample_counts[i] == 1, ( + f"Sample {i} should have been seen exactly once, but was seen {sample_counts[i]} times." 
+ ) - def test_s3(self): - # Create a joined dataset configuration - mixed_mds_path = self.dataset_path / "metadataset_mixed.yaml" - with open(mixed_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " path: msc://s3test_jsonl_dataset/test/dataset/metadataset_all.yaml", - ] - ) - ) - with setup_s3_emulator(profile_name="s3test_jsonl_dataset") as emu: - # Upload the dataset to the S3 emulator - # EPath(self.dataset_path).copy(EPath("msc://s3/test/dataset")) - emu.add_file(self.dataset_path, "test/dataset") - - train_dataset = get_loader( - get_train_dataset( - mixed_mds_path, - worker_config=WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - ), - batch_size=1, - shuffle_buffer_size=10, - max_samples_per_sequence=None, - virtual_epoch_length=55 * 10, - task_encoder=SimpleCookingTaskEncoder(), - ) +def test_s3(dataset_path): + # Create a joined dataset configuration + mixed_mds_path = dataset_path / "metadataset_mixed.yaml" + with open(mixed_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " path: msc://s3test_jsonl_dataset/test/dataset/metadataset_all.yaml", + ] ) - - data = list(enumerate(train_dataset)) - assert len(data) == 55 * 10, len(data) - cnt = Counter(t for _, entry in data for t in entry.text) - assert len(cnt) == 55 * 3 - assert all(2 <= v <= 5 for v in cnt.values()) - - def test_prepare(self): - print("Creating new dataset") - with open(self.dataset_path / "ds_prep.jsonl", "w") as f: - for i in range(10): - f.write(json.dumps({"idx": i, "txt": f"{i}"}) + "\n\n") - - runner = CliRunner() - result = runner.invoke( - prepare_command, - [str(self.dataset_path / "ds_prep.jsonl")], - catch_exceptions=False, ) - print(result.stdout) - assert result.exit_code == 0, "Prepare failed, see output" - assert "Done" in result.stdout, "Prepare failed, see output" - assert "Found 10 
samples" in result.stdout, "Prepare failed, see output" - assert (self.dataset_path / "ds_prep.jsonl.idx").exists() - torch.manual_seed(42) + with setup_s3_emulator(profile_name="s3test_jsonl_dataset") as emu: + # Upload the dataset to the S3 emulator + # EPath(dataset_path).copy(EPath("msc://s3/test/dataset")) + emu.add_file(dataset_path, "test/dataset") - # Train mode dataset - train_loader = get_loader( + with get_loader( get_train_dataset( - self.dataset_path / "ds_prep.jsonl", + mixed_mds_path, worker_config=WorkerConfig( rank=0, world_size=1, - num_workers=0, - seed_offset=42, + num_workers=2, ), batch_size=1, - shuffle_buffer_size=None, + shuffle_buffer_size=10, max_samples_per_sequence=None, + virtual_epoch_length=55 * 10, task_encoder=SimpleCookingTaskEncoder(), ) + ) as train_dataset: + data = list(enumerate(train_dataset)) + assert len(data) == 55 * 10, len(data) + cnt = Counter(t for _, entry in data for t in entry.text) + assert len(cnt) == 55 * 3 + assert all(2 <= v <= 5 for v in cnt.values()) + + +def test_prepare(dataset_path): + print("Creating new dataset") + with open(dataset_path / "ds_prep.jsonl", "w") as f: + for i in range(10): + f.write(json.dumps({"idx": i, "txt": f"{i}"}) + "\n\n") + + runner = CliRunner() + result = runner.invoke( + prepare_command, + [str(dataset_path / "ds_prep.jsonl")], + catch_exceptions=False, + ) + print(result.stdout) + assert result.exit_code == 0, "Prepare failed, see output" + assert "Done" in result.stdout, "Prepare failed, see output" + assert "Found 10 samples" in result.stdout, "Prepare failed, see output" + assert (dataset_path / "ds_prep.jsonl.idx").exists() + + torch.manual_seed(42) + + # Train mode dataset + with get_loader( + get_train_dataset( + dataset_path / "ds_prep.jsonl", + worker_config=WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ), + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=SimpleCookingTaskEncoder(), ) + ) as 
train_loader: assert len(train_loader) == 10, f"Expected 10 samples, got {len(train_loader)}" train_order1 = [text for _, data in zip(range(50), train_loader) for text in data.text] @@ -311,7 +311,3 @@ def test_prepare(self): print(Counter(train_order1)) assert len(Counter(train_order1)) == 10 assert all(v == 5 for v in Counter(train_order1).values()) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_metadataset.py b/tests/test_metadataset.py index 4126a295..4ea9d323 100644 --- a/tests/test_metadataset.py +++ b/tests/test_metadataset.py @@ -3,17 +3,19 @@ """This module defines tests for meta datasets.""" +import dataclasses import gc import logging import sys import tempfile import time -import unittest import warnings from collections import Counter from pathlib import Path from typing import Any, Iterable +import numpy as np +import pytest import torch import webdataset as wds @@ -81,14 +83,13 @@ def assert_nested_equal(a: Any, b: Any, path: str = "") -> None: Raises: AssertionError: If a mismatch is found. 
""" - # Check if types differ if type(a) is not type(b): + # Check if types differ mismatch_details = f"Type mismatch at {path or ''}: {type(a)} != {type(b)}" print(mismatch_details) raise AssertionError(mismatch_details) - - # If they are both dictionaries, compare each key and value if isinstance(a, dict): + # If they are both dictionaries, compare each key and value # Check if they have the same keys a_keys = set(a.keys()) b_keys = set(b.keys()) @@ -109,9 +110,8 @@ def assert_nested_equal(a: Any, b: Any, path: str = "") -> None: for key in a: sub_path = f"{path}['{key}']" if path else f"['{key}']" assert_nested_equal(a[key], b[key], sub_path) - - # If they are lists (or tuples), compare elements in order elif isinstance(a, (list, tuple)): + # If they are lists (or tuples), compare elements in order if len(a) != len(b): mismatch_details = f"Length mismatch at {path or ''}: {len(a)} != {len(b)}" print(mismatch_details) @@ -119,169 +119,192 @@ def assert_nested_equal(a: Any, b: Any, path: str = "") -> None: for index, (item_a, item_b) in enumerate(zip(a, b)): sub_path = f"{path}[{index}]" if path else f"[{index}]" assert_nested_equal(item_a, item_b, sub_path) - - # Otherwise, compare values directly + elif isinstance(a, torch.Tensor): + if a.shape != b.shape: + mismatch_details = f"Shape mismatch at {path or ''}: {a.shape} != {b.shape}" + print(mismatch_details) + raise AssertionError(mismatch_details) + if not torch.all(a == b): + mismatch_details = f"Value mismatch at {path or ''}: {repr(a)} != {repr(b)}" + print(mismatch_details) + raise AssertionError(mismatch_details) + elif isinstance(a, np.ndarray): + if a.shape != b.shape: + mismatch_details = f"Shape mismatch at {path or ''}: {a.shape} != {b.shape}" + print(mismatch_details) + raise AssertionError(mismatch_details) + if not np.all(a == b): + mismatch_details = f"Value mismatch at {path or ''}: {repr(a)} != {repr(b)}" + print(mismatch_details) + raise AssertionError(mismatch_details) + elif 
dataclasses.is_dataclass(a): + for field in dataclasses.fields(a): + assert_nested_equal( + getattr(a, field.name), getattr(b, field.name), f"{path}.{field.name}" + ) else: + # Otherwise, compare values directly if a != b: mismatch_details = f"Value mismatch at {path or ''}: {repr(a)} != {repr(b)}" print(mismatch_details) raise AssertionError(mismatch_details) -class TestDataset(unittest.TestCase): - # Set up the test fixture - def setUp(self): - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - warnings.simplefilter("ignore", ResourceWarning) - - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.dataset_path = Path(self.temp_dir.name) - # self.dataset_path = Path("./test_dataset") - - self.dataset_path.mkdir(exist_ok=True, parents=True) - - (self.dataset_path / "ds1").mkdir(exist_ok=True, parents=True) - (self.dataset_path / "ds2").mkdir(exist_ok=True, parents=True) - - # Create a small dummy captioning dataset - self.create_text_test_dataset(self.dataset_path / "ds1", range(55), range(55)) - self.create_text_test_dataset(self.dataset_path / "ds2", range(100, 155), range(100, 155)) - self.create_text_test_dataset(self.dataset_path / "ds3", range(200, 255), range(0, 55)) - - self.mds_path = self.dataset_path / "metadataset.yaml" - with open(self.mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: Metadataset", - "splits:", - " train:", - " datasets:", - " - weight: 1", - " path: ds1", - " subflavor: ds1", - " subflavors:", - " source: metadataset.yaml", - " number: 43", - " mds: mds", - " shuffle_over_epochs_multiplier: 3", - " - weight: 1", - " path: ds2", - " subflavor: ds2", - " subflavors:", - " source: metadataset.yaml", - " number: 44", - " mds: mds", - " val:", - " datasets:", - " - weight: 1", - " path: ds1", - " split_part: train", - " - weight: 1", - " path: ds2", - " split_part: train", - ] - ) +@pytest.fixture +def temp_dir(): + """Create a temporary directory 
for test files.""" + temp_dir = tempfile.TemporaryDirectory() + yield temp_dir + gc.collect() + temp_dir.cleanup() + + +@pytest.fixture +def dataset_path(temp_dir): + """Create dataset path and setup test data.""" + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + dataset_path = Path(temp_dir.name) + dataset_path.mkdir(exist_ok=True, parents=True) + + (dataset_path / "ds1").mkdir(exist_ok=True, parents=True) + (dataset_path / "ds2").mkdir(exist_ok=True, parents=True) + + # Create a small dummy captioning dataset + create_text_test_dataset(dataset_path / "ds1", range(55), range(55)) + create_text_test_dataset(dataset_path / "ds2", range(100, 155), range(100, 155)) + create_text_test_dataset(dataset_path / "ds3", range(200, 255), range(0, 55)) + + mds_path = dataset_path / "metadataset.yaml" + with open(mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: Metadataset", + "splits:", + " train:", + " datasets:", + " - weight: 1", + " path: ds1", + " subflavor: ds1", + " subflavors:", + " source: metadataset.yaml", + " number: 43", + " mds: mds", + " shuffle_over_epochs_multiplier: 3", + " - weight: 1", + " path: ds2", + " subflavor: ds2", + " subflavors:", + " source: metadataset.yaml", + " number: 44", + " mds: mds", + " val:", + " datasets:", + " - weight: 1", + " path: ds1", + " split_part: train", + " - weight: 1", + " path: ds2", + " split_part: train", + ] ) - self.nested_mds_path = self.dataset_path / "nested_metadataset.yaml" - with open(self.nested_mds_path, "w") as f: - f.write( - "\n".join( - [ - "splits:", - " train:", - " datasets:", - " - weight: 4", - " path: ./metadataset.yaml", - " split_part: train", - " subflavor: train", - " subflavors:", - " source: nested_metadataset.yaml", - " mds: nested_train", - " - path: ./metadataset.yaml", - " split_part: val", - " subflavors:", - " source: nested_metadataset.yaml", - " mds: nested_val", - ] - ) + ) 
+ nested_mds_path = dataset_path / "nested_metadataset.yaml" + with open(nested_mds_path, "w") as f: + f.write( + "\n".join( + [ + "splits:", + " train:", + " datasets:", + " - weight: 4", + " path: ./metadataset.yaml", + " split_part: train", + " subflavor: train", + " subflavors:", + " source: nested_metadataset.yaml", + " mds: nested_train", + " - path: ./metadataset.yaml", + " split_part: val", + " subflavors:", + " source: nested_metadataset.yaml", + " mds: nested_val", + ] ) - print(self.dataset_path) - - def tearDown(self): - # Remove all temporary files - gc.collect() - self.temp_dir.cleanup() - - @staticmethod - def create_text_test_dataset(path: Path, txt_range: Iterable[int], key_range: Iterable[int]): - """Creates a small dummy test dataset for testing purposes.""" - - # Create num_samples unique captions - (path / "parts").mkdir(exist_ok=True, parents=True) - - # Initialize the ShardWriter - with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: - for key, txt in zip(key_range, txt_range): - # Write individual files to shards - shard_writer.write( - { - "__key__": f"{key:06d}", - "txt": f"{txt}".encode(), - }, - ) - total_shards = shard_writer.shard - - from megatron.energon.flavors import BaseWebdatasetFactory - - BaseWebdatasetFactory.prepare_dataset( - path, - [f"parts/data-{{0..{total_shards - 1}}}.tar"], - split_parts_ratio=[("train", 1.0)], - shuffle_seed=None, ) + print(dataset_path) + return dataset_path - with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: TextSample", - "field_map:", - " text: txt", - "subflavors:", - " source: dataset.yaml", - " dataset.yaml: true", - " number: 42", - ] - ) - ) - def test_metadataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) +def create_text_test_dataset(path: Path, txt_range: Iterable[int], 
key_range: Iterable[int]): + """Creates a small dummy test dataset for testing purposes.""" - # Train mode dataset - train_dataset = get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=None, - max_samples_per_sequence=None, + # Create num_samples unique captions + (path / "parts").mkdir(exist_ok=True, parents=True) + + # Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: + for key, txt in zip(key_range, txt_range): + # Write individual files to shards + shard_writer.write( + { + "__key__": f"{key:06d}", + "txt": f"{txt}".encode(), + }, + ) + total_shards = shard_writer.shard + + from megatron.energon.flavors import BaseWebdatasetFactory + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: TextSample", + "field_map:", + " text: txt", + "subflavors:", + " source: dataset.yaml", + " dataset.yaml: true", + " number: 42", + ] + ) ) - print(len(train_dataset)) - assert len(train_dataset) == 11 - train_loader1 = get_loader(train_dataset) +def test_metadataset(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 11 + + with get_loader(train_dataset) as train_loader1: train_order1 = [ text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text ] @@ -295,24 +318,23 @@ def test_metadataset(self): for idx, data in zip(range(55), 
train_loader1) for subflavor in data.__subflavors__ ] - print(train_subflavors[:10]) - print(Counter(train_subflavors)) + print("train_subflavors[:10]", train_subflavors[:10]) + print("Counter(train_subflavors)", Counter(train_subflavors)) assert len(Counter(train_subflavors)) == 2 assert all(250 <= v <= 300 for v in Counter(train_subflavors).values()) - # Train mode dataset - train_dataset = get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=25, - max_samples_per_sequence=25, - ) - print(len(train_dataset)) - assert len(train_dataset) == 11 - - train_loader1 = get_loader(train_dataset) - + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=25, + max_samples_per_sequence=25, + ) + print(len(train_dataset)) + assert len(train_dataset) == 11 + + with get_loader(train_dataset) as train_loader1: train_order1 = [ text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text ] @@ -321,82 +343,83 @@ def test_metadataset(self): assert len(Counter(train_order1)) == 110 assert all(48 <= v <= 52 for v in Counter(train_order1).values()) - # Val mode dataset - val_dataset = get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10) - print(len(val_dataset)) - assert len(val_dataset) == 11 - - val_loader1 = get_loader(val_dataset) + # Val mode dataset + val_dataset = get_val_dataset( + dataset_path / "metadataset.yaml", worker_config=worker_config, batch_size=10 + ) + print(len(val_dataset)) + assert len(val_dataset) == 11 + with get_loader(val_dataset) as val_loader1: val_order1 = [text for data in val_loader1 for text in data.text] assert len(val_order1) == 110 print(Counter(val_order1)) assert all(v == 1 for v in Counter(val_order1).values()) - def test_nested_metadataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - 
) - - dataset = load_dataset(self.nested_mds_path) - - raw_datasets = dataset.get_datasets( - training=False, split_part="train", worker_config=worker_config - ) - assert raw_datasets.blend_mode == DatasetBlendMode.DATASET_WEIGHT - assert [raw_dataset.weight for raw_dataset in raw_datasets.datasets] == [0.4, 0.4, 0.1, 0.1] - assert [raw_dataset.dataset.paths[0].name for raw_dataset in raw_datasets.datasets] == [ - "ds1", - "ds2", - "ds1", - "ds2", - ] - print([raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets]) - assert [raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets] == [ - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 43, - "mds": "nested_train", - "__subflavor__": "train", - }, - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 44, - "mds": "nested_train", - "__subflavor__": "train", - }, - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 42, - "mds": "nested_val", - }, - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 42, - "mds": "nested_val", - }, - ] - - # Train mode dataset - train_dataset = get_train_dataset( - self.nested_mds_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - print(len(train_dataset)) - assert len(train_dataset) == 22 - - train_loader1 = get_loader(train_dataset) +def test_nested_metadataset(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + ) + + dataset = load_dataset(dataset_path / "nested_metadataset.yaml") + + raw_datasets = dataset.get_datasets( + training=False, split_part="train", worker_config=worker_config + ) + assert raw_datasets.blend_mode == DatasetBlendMode.DATASET_WEIGHT + assert [raw_dataset.weight for raw_dataset in raw_datasets.datasets] == [0.4, 0.4, 0.1, 0.1] + assert [raw_dataset.dataset.paths[0].name for raw_dataset in 
raw_datasets.datasets] == [ + "ds1", + "ds2", + "ds1", + "ds2", + ] + print([raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets]) + assert [raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets] == [ + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 43, + "mds": "nested_train", + "__subflavor__": "train", + }, + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 44, + "mds": "nested_train", + "__subflavor__": "train", + }, + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 42, + "mds": "nested_val", + }, + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 42, + "mds": "nested_val", + }, + ] + + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "nested_metadataset.yaml", + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 22 + + with get_loader(train_dataset) as train_loader1: train_order1 = [ text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text ] @@ -466,19 +489,18 @@ def test_nested_metadataset(self): < avg * 1 + 20 ) - # Train mode dataset - train_dataset = get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=25, - max_samples_per_sequence=25, - ) - print(len(train_dataset)) - assert len(train_dataset) == 11 - - train_loader1 = get_loader(train_dataset) - + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=25, + max_samples_per_sequence=25, + ) + print(len(train_dataset)) + assert len(train_dataset) == 11 + + with get_loader(train_dataset) as train_loader1: train_order1 = [ text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text ] @@ -487,167 +509,170 @@ def 
test_nested_metadataset(self): assert len(Counter(train_order1)) == 110 assert all(48 <= v <= 52 for v in Counter(train_order1).values()) - # Val mode dataset - val_dataset = get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10) - print(len(val_dataset)) - assert len(val_dataset) == 11 - - val_loader1 = get_loader(val_dataset) + # Val mode dataset + val_dataset = get_val_dataset( + dataset_path / "metadataset.yaml", worker_config=worker_config, batch_size=10 + ) + print(len(val_dataset)) + assert len(val_dataset) == 11 + with get_loader(val_dataset) as val_loader1: val_order1 = [text for data in val_loader1 for text in data.text] assert len(val_order1) == 110 print(Counter(val_order1)) assert all(v == 1 for v in Counter(val_order1).values()) - def test_worker_sample_balance(self): - torch.manual_seed(42) - for num_workers in [6, 30]: - samples_per_global_worker = Counter() +def test_worker_sample_balance(dataset_path): + torch.manual_seed(42) - for rank in range(2): - wc = WorkerConfig( - rank=rank, - world_size=2, - num_workers=num_workers, - ) + for num_workers in [6, 30]: + samples_per_global_worker = Counter() - train_dataset = get_train_dataset( - self.nested_mds_path, - worker_config=wc, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) + for rank in range(2): + wc = WorkerConfig( + rank=rank, + world_size=2, + num_workers=num_workers, + ) - blend_dataset = get_blend_dataset(train_dataset) - assert isinstance(blend_dataset, BlendDataset) + train_dataset = get_train_dataset( + dataset_path / "nested_metadataset.yaml", + worker_config=wc, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) - ds_weights = blend_dataset.dataset_weights - assert len(ds_weights) == 4 # 4 datasets + blend_dataset = get_blend_dataset(train_dataset) + assert isinstance(blend_dataset, BlendDataset) - # We are now going to count the number of samples that was assigned to each - # globally unique worker. 
This corresponds to the shard_ranges that energon - # prints out when the dataset is built. + ds_weights = blend_dataset.dataset_weights + assert len(ds_weights) == 4 # 4 datasets - for ds, w in ds_weights: - worker_slice_offsets = ds.dataset.dataset.workers_slice_offsets - assert len(worker_slice_offsets) == num_workers + # We are now going to count the number of samples that was assigned to each + # globally unique worker. This corresponds to the shard_ranges that energon + # prints out when the dataset is built. - for worker_idx, slice_offsets in enumerate(worker_slice_offsets): - samples_per_global_worker[(rank, worker_idx)] += ( - slice_offsets[-1] - slice_offsets[0] - ) - print(samples_per_global_worker) - - # Check the sample assignnent is balanced across all global workers - if num_workers == 6: - assert list(samples_per_global_worker.values()) == [ - 19, # rank 0 - 18, - 18, - 19, - 18, - 18, - 19, # rank 1 - 18, - 18, - 19, - 18, - 18, - ] - elif num_workers == 30: - # This should match the pattern of the first 40 items of a generalized bit - # reversal sequence of length 60. 
- # Given 4 * 55 = 220 samples modulo 60 workers, is 40 remaining samples - assert list(samples_per_global_worker.values()) == [ - 4, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 3, - 4, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 3, - 4, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 3, - 4, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 3, - ] + for ds, w in ds_weights: + worker_slice_offsets = ds.dataset.dataset.workers_slice_offsets + assert len(worker_slice_offsets) == num_workers + + for worker_idx, slice_offsets in enumerate(worker_slice_offsets): + samples_per_global_worker[(rank, worker_idx)] += ( + slice_offsets[-1] - slice_offsets[0] + ) + print(samples_per_global_worker) + + # Check the sample assignnent is balanced across all global workers + if num_workers == 6: + assert list(samples_per_global_worker.values()) == [ + 19, # rank 0 + 18, + 18, + 19, + 18, + 18, + 19, # rank 1 + 18, + 18, + 19, + 18, + 18, + ] + elif num_workers == 30: + # This should match the pattern of the first 40 items of a generalized bit + # reversal sequence of length 60. 
+ # Given 4 * 55 = 220 samples modulo 60 workers, is 40 remaining samples + assert list(samples_per_global_worker.values()) == [ + 4, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 3, + 4, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 3, + 4, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 3, + 4, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 3, + ] - def test_save_restore_state_train(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) +def test_save_restore_state_train(dataset_path): + torch.manual_seed(42) - def new_loader(): - return get_savable_loader( - get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=10, - parallel_shard_iters=2, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - shuffle_over_epochs_multiplier=2, - ), - ) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + def new_loader(): + return get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=10, + parallel_shard_iters=2, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + shuffle_over_epochs_multiplier=2, + ), + ) - # Train mode dataset - loader = new_loader() + # Train mode dataset + with new_loader() as loader: state_0 = loader.save_state_rank() order_0 = [data.text for idx, data in zip(range(10), loader)] state_1 = loader.save_state_rank() @@ -682,315 +707,264 @@ def new_loader(): # Iterated 55 samples, afterwards 75 samples. 
Checkpoint should be around that order_6 = [data.text for idx, data in zip(range(70), loader)] - loader = new_loader() + with new_loader().with_restored_state_rank(state_1) as loader: print("state_1:", _norng_state(state_1)) - loader.restore_state_rank(state_1) order_1_rest = [data.text for idx, data in zip(range(len(order_1)), loader)] assert order_1 == order_1_rest - loader = new_loader() - loader.restore_state_rank(state_0) + with new_loader().with_restored_state_rank(state_0) as loader: order_0_rest = [data.text for idx, data in zip(range(len(order_0)), loader)] assert order_0 == order_0_rest - loader = new_loader() + with new_loader().with_restored_state_rank(state_2) as loader: print("state_2:", _norng_state(state_2)) - loader.restore_state_rank(state_2) order_2_rest = [data.text for idx, data in zip(range(len(order_2)), loader)] print("order_2:", order_2) print("order_2_rest:", order_2_rest) assert order_2 == order_2_rest - loader = new_loader() + with new_loader().with_restored_state_rank(state_3) as loader: print("state_3:", _norng_state(state_3)) - loader.restore_state_rank(state_3) order_3_rest = [data.text for idx, data in zip(range(len(order_3)), loader)] print("order_3:", order_3) print("order_3_rest:", order_3_rest) assert order_3 == order_3_rest - loader = new_loader() + with new_loader().with_restored_state_rank(state_4) as loader: print("state_4:", _norng_state(state_4)) - loader.restore_state_rank(state_4) order_4_rest = [data.text for idx, data in zip(range(len(order_4)), loader)] print("order_4:", order_4) print("order_4_rest:", order_4_rest) assert order_4 == order_4_rest - loader = new_loader() + with new_loader().with_restored_state_rank(state_5) as loader: print("state_5:", _norng_state(state_5)) - loader.restore_state_rank(state_5) order_5_rest = [data.text for idx, data in zip(range(len(order_5)), loader)] print("order_5:", order_5) print("order_5_rest:", order_5_rest) assert order_5 == order_5_rest - loader = new_loader() + with 
new_loader().with_restored_state_rank(state_6) as loader: print("state_6:", _norng_state(state_6)) - loader.restore_state_rank(state_6) order_6_rest = [data.text for idx, data in zip(range(len(order_6)), loader)] print("order_6:", order_6) print("order_6_rest:", order_6_rest) assert order_6 == order_6_rest - wrk_cfg = worker_config.config() - assert wrk_cfg == { - "rank": 0, - "world_size": 1, - "num_workers": 0, - "data_parallel_group": None, - } - print("loader.config():") - print(loader.config()) - print() - reference_config = { - "type": "SavableDataLoader", - "num_workers": 0, - "persistent_workers": False, - "pin_memory": True, - "prefetch_factor": None, + wrk_cfg = worker_config.config() + assert wrk_cfg == { + "rank": 0, + "world_size": 1, + "num_workers": 0, + "data_parallel_group": None, + } + print("loader.config():") + print(loader.config()) + print() + reference_config = { + "type": "MapDataset", + "dataset": { + "type": "BatchDataset", + "batch_size": 10, + "batcher": "megatron.energon.task_encoder.base.DefaultTaskEncoder.batch", + "batcher_stateless": True, + "drop_last": False, + "worker_config": wrk_cfg, "dataset": { "type": "MapDataset", "dataset": { - "type": "BatchDataset", - "batch_size": 10, - "batcher": "megatron.energon.task_encoder.base.DefaultTaskEncoder.batch", - "batcher_stateless": True, - "drop_last": False, - "worker_config": wrk_cfg, - "dataset": { - "type": "MapDataset", - "dataset": { - "type": "BlendDataset", - "dataset_weights": [ - ( - { - "type": "RepeatDataset", - "dataset": { - "type": "MapDataset", - "dataset": { - "type": "WebdatasetSampleLoaderDataset", - "joins": 1, - "len": 55, - "slice_offsets": [[0, 10, 20, 30, 40, 50, 55]], - "worker_config": wrk_cfg, - "shuffle_over_epochs": 6, - "parallel_slice_iters": 2, + "type": "BlendDataset", + "dataset_weights": [ + ( + { + "type": "RepeatDataset", + "dataset": { + "type": "MapDataset", + "dataset": { + "type": "WebdatasetSampleLoaderDataset", + "joins": 1, + "len": 55, + 
"slice_offsets": [[0, 10, 20, 30, 40, 50, 55]], + "worker_config": wrk_cfg, + "shuffle_over_epochs": 6, + "parallel_slice_iters": 2, + }, + "map_fn": "megatron.energon.flavors.webdataset.base_webdataset.BaseWebdatasetFactory._load_sample_raw", + "map_fn_config": { + "type": "StandardWebdatasetFactory", + "training": True, + "_path": str(dataset_path / "ds1"), + "shards": [ + { + "name": "parts/data-0.tar", + "count": 10, + "_path": str(dataset_path / "ds1/parts/data-0.tar"), + }, + { + "name": "parts/data-1.tar", + "count": 10, + "_path": str(dataset_path / "ds1/parts/data-1.tar"), + }, + { + "name": "parts/data-2.tar", + "count": 10, + "_path": str(dataset_path / "ds1/parts/data-2.tar"), + }, + { + "name": "parts/data-3.tar", + "count": 10, + "_path": str(dataset_path / "ds1/parts/data-3.tar"), }, - "map_fn": "megatron.energon.flavors.webdataset.base_webdataset.BaseWebdatasetFactory._load_sample_raw", - "map_fn_config": { - "type": "StandardWebdatasetFactory", - "training": True, - "_path": str(self.dataset_path / "ds1"), - "shards": [ - { - "name": "parts/data-0.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds1/parts/data-0.tar" - ), - }, - { - "name": "parts/data-1.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds1/parts/data-1.tar" - ), - }, - { - "name": "parts/data-2.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds1/parts/data-2.tar" - ), - }, - { - "name": "parts/data-3.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds1/parts/data-3.tar" - ), - }, - { - "name": "parts/data-4.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds1/parts/data-4.tar" - ), - }, - { - "name": "parts/data-5.tar", - "count": 5, - "_path": str( - self.dataset_path - / "ds1/parts/data-5.tar" - ), - }, - ], - "sample_excludes": [], - "shuffle_over_epochs": 6, - "parallel_shard_iters": 2, - "max_samples_per_sequence": None, - "subset": None, - "subflavors": { - "source": "metadataset.yaml", - "dataset.yaml": True, 
- "number": 43, - "mds": "mds", - "__subflavor__": "ds1", - }, - "sample_loader": "megatron.energon.flavors.webdataset.default_generic_webdataset.DefaultGenericWebdatasetFactory.__init__..", - "image_decode": "torchrgb", - "av_decode": "AVDecoder", - "video_decode_audio": False, - "guess_content": False, + { + "name": "parts/data-4.tar", + "count": 10, + "_path": str(dataset_path / "ds1/parts/data-4.tar"), }, - "map_fn_stateless": True, + { + "name": "parts/data-5.tar", + "count": 5, + "_path": str(dataset_path / "ds1/parts/data-5.tar"), + }, + ], + "sample_excludes": [], + "shuffle_over_epochs": 6, + "parallel_shard_iters": 2, + "max_samples_per_sequence": None, + "subset": None, + "subflavors": { + "source": "metadataset.yaml", + "dataset.yaml": True, + "number": 43, + "mds": "mds", + "__subflavor__": "ds1", }, - "repeats": None, + "sample_loader": "megatron.energon.flavors.webdataset.default_generic_webdataset.DefaultGenericWebdatasetFactory.__init__..", + "image_decode": "torchrgb", + "av_decode": "AVDecoder", + "video_decode_audio": False, + "guess_content": False, + }, + "map_fn_stateless": True, + }, + "repeats": None, + "worker_config": wrk_cfg, + }, + 0.5, + ), + ( + { + "type": "RepeatDataset", + "dataset": { + "type": "MapDataset", + "dataset": { + "type": "WebdatasetSampleLoaderDataset", + "joins": 1, + "len": 55, + "slice_offsets": [[0, 10, 20, 30, 40, 50, 55]], "worker_config": wrk_cfg, + "shuffle_over_epochs": 2, + "parallel_slice_iters": 2, }, - 0.5, - ), - ( - { - "type": "RepeatDataset", - "dataset": { - "type": "MapDataset", - "dataset": { - "type": "WebdatasetSampleLoaderDataset", - "joins": 1, - "len": 55, - "slice_offsets": [[0, 10, 20, 30, 40, 50, 55]], - "worker_config": wrk_cfg, - "shuffle_over_epochs": 2, - "parallel_slice_iters": 2, + "map_fn": "megatron.energon.flavors.webdataset.base_webdataset.BaseWebdatasetFactory._load_sample_raw", + "map_fn_config": { + "type": "StandardWebdatasetFactory", + "training": True, + "_path": 
str(dataset_path / "ds2"), + "shards": [ + { + "name": "parts/data-0.tar", + "count": 10, + "_path": str(dataset_path / "ds2/parts/data-0.tar"), + }, + { + "name": "parts/data-1.tar", + "count": 10, + "_path": str(dataset_path / "ds2/parts/data-1.tar"), + }, + { + "name": "parts/data-2.tar", + "count": 10, + "_path": str(dataset_path / "ds2/parts/data-2.tar"), + }, + { + "name": "parts/data-3.tar", + "count": 10, + "_path": str(dataset_path / "ds2/parts/data-3.tar"), + }, + { + "name": "parts/data-4.tar", + "count": 10, + "_path": str(dataset_path / "ds2/parts/data-4.tar"), }, - "map_fn": "megatron.energon.flavors.webdataset.base_webdataset.BaseWebdatasetFactory._load_sample_raw", - "map_fn_config": { - "type": "StandardWebdatasetFactory", - "training": True, - "_path": str(self.dataset_path / "ds2"), - "shards": [ - { - "name": "parts/data-0.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds2/parts/data-0.tar" - ), - }, - { - "name": "parts/data-1.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds2/parts/data-1.tar" - ), - }, - { - "name": "parts/data-2.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds2/parts/data-2.tar" - ), - }, - { - "name": "parts/data-3.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds2/parts/data-3.tar" - ), - }, - { - "name": "parts/data-4.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds2/parts/data-4.tar" - ), - }, - { - "name": "parts/data-5.tar", - "count": 5, - "_path": str( - self.dataset_path - / "ds2/parts/data-5.tar" - ), - }, - ], - "sample_excludes": [], - "shuffle_over_epochs": 2, - "parallel_shard_iters": 2, - "max_samples_per_sequence": None, - "subset": None, - "subflavors": { - "source": "metadataset.yaml", - "dataset.yaml": True, - "number": 44, - "mds": "mds", - "__subflavor__": "ds2", - }, - "sample_loader": "megatron.energon.flavors.webdataset.default_generic_webdataset.DefaultGenericWebdatasetFactory.__init__..", - "image_decode": "torchrgb", - 
"av_decode": "AVDecoder", - "video_decode_audio": False, - "guess_content": False, + { + "name": "parts/data-5.tar", + "count": 5, + "_path": str(dataset_path / "ds2/parts/data-5.tar"), }, - "map_fn_stateless": True, + ], + "sample_excludes": [], + "shuffle_over_epochs": 2, + "parallel_shard_iters": 2, + "max_samples_per_sequence": None, + "subset": None, + "subflavors": { + "source": "metadataset.yaml", + "dataset.yaml": True, + "number": 44, + "mds": "mds", + "__subflavor__": "ds2", }, - "repeats": None, - "worker_config": wrk_cfg, + "sample_loader": "megatron.energon.flavors.webdataset.default_generic_webdataset.DefaultGenericWebdatasetFactory.__init__..", + "image_decode": "torchrgb", + "av_decode": "AVDecoder", + "video_decode_audio": False, + "guess_content": False, }, - 0.5, - ), - ], - "worker_config": wrk_cfg, - }, - "map_fn": "megatron.energon.task_encoder.base.DefaultTaskEncoder.encode_sample", - "map_fn_stateless": True, - }, + "map_fn_stateless": True, + }, + "repeats": None, + "worker_config": wrk_cfg, + }, + 0.5, + ), + ], + "worker_config": wrk_cfg, }, - "map_fn": "megatron.energon.task_encoder.base.DefaultTaskEncoder.encode_batch", + "map_fn": "megatron.energon.task_encoder.base.DefaultTaskEncoder.encode_sample", "map_fn_stateless": True, }, - } - print("Comparing dataset configs in test_save_restore_state_train.") - assert_nested_equal(loader.config(), reference_config) - - def test_save_restore_state_train_workers(self): - torch.manual_seed(42) - - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=1, - seed_offset=42, + }, + "map_fn": "megatron.energon.task_encoder.base.DefaultTaskEncoder.encode_batch", + "map_fn_stateless": True, + } + print("Comparing dataset configs in test_save_restore_state_train.") + assert_nested_equal(loader.config(), reference_config) + + +def test_save_restore_state_train_workers(dataset_path): + torch.manual_seed(42) + + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=1, + 
seed_offset=42, + ) + + def new_loader(): + return get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=10, + parallel_shard_iters=2, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ), ) - def new_loader(): - return get_savable_loader( - get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=10, - parallel_shard_iters=2, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - checkpoint_every_sec=0.5, - checkpoint_every_min_n_samples=1, - ) - - # Train mode dataset - loader = new_loader() + # Train mode dataset + with new_loader() as loader: state_0 = loader.save_state_rank() order_0 = [data.text for idx, data in zip(range(10), loader)] time.sleep(0.5) @@ -1038,85 +1012,79 @@ def new_loader(): # Iterated 1 samples, afterwards 55 samples. Checkpoint should be around that order_6 = [data.text for idx, data in zip(range(10), loader)] - loader = new_loader() + with new_loader().with_restored_state_rank(state_1) as loader: print("state_1:", _norng_state(state_1)) - loader.restore_state_rank(state_1) order_1_rest = [data.text for idx, data in zip(range(len(order_1)), loader)] print("order_1:", order_1) print("order_1_rest:", order_1_rest) assert order_1 == order_1_rest - loader = new_loader() - loader.restore_state_rank(state_0) + with new_loader().with_restored_state_rank(state_0) as loader: order_0_rest = [data.text for idx, data in zip(range(len(order_0)), loader)] assert order_0 == order_0_rest - loader = new_loader() + with new_loader().with_restored_state_rank(state_2) as loader: print("state_2:", _norng_state(state_2)) - loader.restore_state_rank(state_2) order_2_rest = [data.text for idx, data in zip(range(len(order_2)), loader)] print("order_2:", order_2) print("order_2_rest:", order_2_rest) assert order_2 == order_2_rest - loader = new_loader() + with new_loader().with_restored_state_rank(state_3) as loader: print("state_3:", 
_norng_state(state_3)) - loader.restore_state_rank(state_3) order_3_rest = [data.text for idx, data in zip(range(len(order_3)), loader)] print("order_3:", order_3) print("order_3_rest:", order_3_rest) assert order_3 == order_3_rest - loader = new_loader() + with new_loader().with_restored_state_rank(state_4) as loader: print("state_4:", _norng_state(state_4)) - loader.restore_state_rank(state_4) order_4_rest = [data.text for idx, data in zip(range(len(order_4)), loader)] print("order_4:", order_4) print("order_4_rest:", order_4_rest) assert order_4 == order_4_rest - loader = new_loader() + with new_loader().with_restored_state_rank(state_5) as loader: print("state_5:", _norng_state(state_5)) - loader.restore_state_rank(state_5) order_5_rest = [data.text for idx, data in zip(range(len(order_5)), loader)] print("order_5:", order_5) print("order_5_rest:", order_5_rest) assert order_5 == order_5_rest - loader = new_loader() + with new_loader().with_restored_state_rank(state_6) as loader: print("state_6:", _norng_state(state_6)) - loader.restore_state_rank(state_6) order_6_rest = [data.text for idx, data in zip(range(len(order_6)), loader)] print("order_6:", order_6) print("order_6_rest:", order_6_rest) assert order_6 == order_6_rest - def test_save_restore_state_train_epochize_workers(self): - torch.manual_seed(42) - psi = 2 - vel = 19 - sbs = 10 - - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - seed_offset=42, - ) - # Train mode dataset - torch.manual_seed(42) - loader = get_savable_loader( - get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=1, - parallel_shard_iters=psi, - virtual_epoch_length=vel, - shuffle_buffer_size=sbs, - max_samples_per_sequence=sbs, - ), - ) +def test_save_restore_state_train_epochize_workers(dataset_path): + torch.manual_seed(42) + psi = 2 + vel = 19 + sbs = 10 + + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + seed_offset=42, + ) + + # Train mode dataset + 
torch.manual_seed(42) + with get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=1, + parallel_shard_iters=psi, + virtual_epoch_length=vel, + shuffle_buffer_size=sbs, + max_samples_per_sequence=sbs, + ), + ) as loader: state_0 = loader.save_state_rank() order_1 = [data.text[0] for data in loader] state_1 = loader.save_state_rank() @@ -1124,196 +1092,194 @@ def test_save_restore_state_train_epochize_workers(self): state_2 = loader.save_state_rank() order_3 = [data.text[0] for idx, data in zip(range(17), loader)] - torch.manual_seed(42) - loader = get_savable_loader( - get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=1, - parallel_shard_iters=psi, - virtual_epoch_length=vel, - shuffle_buffer_size=sbs, - max_samples_per_sequence=sbs, - ), - ) + torch.manual_seed(42) + with get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=1, + parallel_shard_iters=psi, + virtual_epoch_length=vel, + shuffle_buffer_size=sbs, + max_samples_per_sequence=sbs, + ), + ).with_restored_state_rank(state_0) as loader: print("state_0:", _norng_state(state_0)) - loader.restore_state_rank(state_0) order_5 = [data.text[0] for data in loader] print("order_1:", order_1) print("order_5:", order_5) assert order_1 == order_5 - torch.manual_seed(42) - loader = get_savable_loader( - get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=1, - parallel_shard_iters=psi, - virtual_epoch_length=vel, - shuffle_buffer_size=sbs, - max_samples_per_sequence=sbs, - ), - ) + torch.manual_seed(42) + with get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=1, + parallel_shard_iters=psi, + virtual_epoch_length=vel, + shuffle_buffer_size=sbs, + max_samples_per_sequence=sbs, + ), + ).with_restored_state_rank(state_1) as loader: print("state_1:", 
_norng_state(state_1)) - loader.restore_state_rank(state_1) order_6 = [data.text[0] for data in loader] print("order_2:", order_2) print("order_6:", order_6) assert order_2 == order_6 - torch.manual_seed(42) - loader = get_savable_loader( - get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=1, - parallel_shard_iters=psi, - virtual_epoch_length=vel, - shuffle_buffer_size=sbs, - max_samples_per_sequence=sbs, - ), - ) + torch.manual_seed(42) + with get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=1, + parallel_shard_iters=psi, + virtual_epoch_length=vel, + shuffle_buffer_size=sbs, + max_samples_per_sequence=sbs, + ), + ).with_restored_state_rank(state_2) as loader: print("state_2:", _norng_state(state_2)) - loader.restore_state_rank(state_2) order_7 = [data.text[0] for idx, data in zip(range(17), loader)] print("order_3:", order_3) print("order_7:", order_7) assert order_3 == order_7 - def test_save_restore_state_val(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) +def test_save_restore_state_val(dataset_path): + torch.manual_seed(42) - # Train mode dataset - loader = get_savable_loader( - get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10), - ) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Train mode dataset + with get_savable_loader( + get_val_dataset( + dataset_path / "metadataset.yaml", worker_config=worker_config, batch_size=10 + ), + ) as loader: state_0 = loader.save_state_rank() order_1 = [data.text for idx, data in zip(range(55 * 20), loader)] state_1 = loader.save_state_rank() # print("save state done") order_2 = [data.text for idx, data in zip(range(55 * 20), loader)] - loader = get_savable_loader( - get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10), - ) - 
loader.restore_state_rank(state_1) + with get_savable_loader( + get_val_dataset( + dataset_path / "metadataset.yaml", worker_config=worker_config, batch_size=10 + ), + ).with_restored_state_rank(state_1) as loader: order_3 = [data.text for idx, data in zip(range(55 * 20), loader)] assert order_2 == order_3 - loader = get_savable_loader( - get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10), - ) - loader.restore_state_rank(state_0) + with get_savable_loader( + get_val_dataset( + dataset_path / "metadataset.yaml", worker_config=worker_config, batch_size=10 + ), + ).with_restored_state_rank(state_0) as loader: order_4 = [data.text for idx, data in zip(range(55 * 20), loader)] assert order_1 == order_4 - def test_blending_randomness(self): - import random - import numpy +def test_blending_randomness(dataset_path): + import random - for num_workers in [0, 1, 2]: # Especially also check the num_workers=0 case - world_size = 4 - micro_batch_size = 1 - seed = 42 + import numpy - configs = ( - WorkerConfig(rank=0, world_size=world_size, num_workers=num_workers), - WorkerConfig(rank=1, world_size=world_size, num_workers=num_workers), - WorkerConfig(rank=2, world_size=world_size, num_workers=num_workers), - ) + for num_workers in [0, 1, 2]: # Especially also check the num_workers=0 case + world_size = 4 + micro_batch_size = 1 + seed = 42 - all_ranks_subflavors = [] - for rank_config in configs: - torch.manual_seed(seed) - numpy.random.seed(seed) - random.seed(seed) - - ds = get_train_dataset( - self.mds_path, - split_part="train", - worker_config=rank_config, - batch_size=micro_batch_size, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - loader = get_loader(ds) + configs = ( + WorkerConfig(rank=0, world_size=world_size, num_workers=num_workers), + WorkerConfig(rank=1, world_size=world_size, num_workers=num_workers), + WorkerConfig(rank=2, world_size=world_size, num_workers=num_workers), + ) + all_ranks_subflavors = [] + for rank_config 
in configs: + torch.manual_seed(seed) + numpy.random.seed(seed) + random.seed(seed) + + ds = get_train_dataset( + dataset_path / "metadataset.yaml", + split_part="train", + worker_config=rank_config, + batch_size=micro_batch_size, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + with get_loader(ds) as loader: subflavors = [ data.__subflavors__[0].get("__subflavor__") for idx, data in zip(range(25), loader) ] - all_ranks_subflavors.append(subflavors) + all_ranks_subflavors.append(subflavors) - print(f"Subflavors for rank {rank_config.rank}:", subflavors) + print(f"Subflavors for rank {rank_config.rank}:", subflavors) - # Assert that all ranks got different data - for i in range(len(all_ranks_subflavors)): - for j in range(i + 1, len(all_ranks_subflavors)): - assert all_ranks_subflavors[i] != all_ranks_subflavors[j], ( - f"Rank {i} and rank {j} got the same subflavors." - ) - - # Delete all locals, otherwise loaders might be kept alive - locals().clear() - gc.collect() - - def test_slice_iter_shuffle_over_epochs(self): - torch.manual_seed(42) - - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) - - def new_loader(): - return get_savable_loader( - get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=10, - parallel_shard_iters=2, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - shuffle_over_epochs_multiplier=-1, - ), - ) + # Assert that all ranks got different data + for i in range(len(all_ranks_subflavors)): + for j in range(i + 1, len(all_ranks_subflavors)): + assert all_ranks_subflavors[i] != all_ranks_subflavors[j], ( + f"Rank {i} and rank {j} got the same subflavors." 
+ ) - # Train mode dataset - loader = new_loader() - _ = [data.text for idx, data in zip(range(1000), loader)] - def test_save_restore_next(self): - torch.manual_seed(42) +def test_slice_iter_shuffle_over_epochs(dataset_path): + torch.manual_seed(42) - wc = WorkerConfig( - rank=0, - world_size=1, - num_workers=6, - ) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) - initial_loader = get_savable_loader( + def new_loader(): + return get_savable_loader( get_train_dataset( - self.nested_mds_path, - worker_config=wc, - batch_size=1, + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=10, + parallel_shard_iters=2, shuffle_buffer_size=None, max_samples_per_sequence=None, + shuffle_over_epochs_multiplier=-1, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=0, ) + + # Train mode dataset + with new_loader() as loader: + _ = [data.text for idx, data in zip(range(1000), loader)] + + +def test_save_restore_next(dataset_path): + torch.manual_seed(42) + + wc = WorkerConfig( + rank=0, + world_size=1, + num_workers=6, + ) + + with get_savable_loader( + get_train_dataset( + dataset_path / "nested_metadataset.yaml", + worker_config=wc, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ), + ) as initial_loader: skip_initial = 9 previous_cp = initial_loader.save_state_rank() @@ -1321,24 +1287,23 @@ def test_save_restore_next(self): for i, sample in zip(range(skip_initial), initial_loader): print(f"sample[@{i}]: {sample.text}") print("previous_cp:", previous_cp) - rst_loader = get_savable_loader( + with get_savable_loader( get_train_dataset( - self.nested_mds_path, + dataset_path / "nested_metadataset.yaml", worker_config=wc, batch_size=1, shuffle_buffer_size=None, max_samples_per_sequence=None, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=0, - ) - rst_loader.restore_state_rank(previous_cp) - for i, rst_sample in zip(range(1), rst_loader): - 
print(f"rst_sample[@{i}]: {rst_sample.text}") - assert sample.text == rst_sample.text, f"{sample} != {rst_sample}" - assert sample.__key__ == rst_sample.__key__, f"{sample} != {rst_sample}" - assert sample.__restore_key__ == rst_sample.__restore_key__, f"{sample} != {rst_sample}" - previous_cp = initial_loader.save_state_rank() + ).with_restored_state_rank(previous_cp) as rst_loader: + for i, rst_sample in zip(range(1), rst_loader): + print(f"rst_sample[@{i}]: {rst_sample.text}") + assert sample.text == rst_sample.text, f"{sample} != {rst_sample}" + assert sample.__key__ == rst_sample.__key__, f"{sample} != {rst_sample}" + assert sample.__restore_key__ == rst_sample.__restore_key__, ( + f"{sample} != {rst_sample}" + ) + previous_cp = initial_loader.save_state_rank() # Iterate 10 samples, the save state and store the next 10 samples for reference. state_initial = initial_loader.save_state_rank() @@ -1352,26 +1317,29 @@ def test_save_restore_next(self): ) ) - del initial_loader - gc.collect() - - second_loader = get_savable_loader( - get_train_dataset( - self.nested_mds_path, - worker_config=wc, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=0, - ) - second_loader.restore_state_rank(state_initial) - + second_loader = get_savable_loader( + get_train_dataset( + dataset_path / "nested_metadataset.yaml", + worker_config=wc, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ), + ) + second_loader.restore_state_rank(state_initial) + # Save the state again, to check that it is the same as the just restored state + same_state = second_loader.save_state_rank() + print("same_state:", same_state) + assert_nested_equal(same_state, state_initial) + assert same_state is state_initial + + # This will propagate the state to the workers. 
+ second_loader.start() + try: # Save the state again, to check that it is the same as the just restored state same_state = second_loader.save_state_rank() print("same_state:", same_state) - assert same_state == state_initial + assert_nested_equal(same_state, state_initial) for offset in range(10): try: @@ -1409,44 +1377,40 @@ def test_save_restore_next(self): raise ValueError(f"Failed to iterate @{offset + skip_initial} samples") from e # Restore state in a new loader - ref_loader = get_savable_loader( + with get_savable_loader( get_train_dataset( - self.nested_mds_path, + dataset_path / "nested_metadataset.yaml", worker_config=wc, batch_size=1, shuffle_buffer_size=None, max_samples_per_sequence=None, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=0, - ) - ref_loader.restore_state_rank(state_offset) - - # Get 1 sample from the restored loader - next_loader_samples = [sample for _, sample in zip(range(6), ref_loader)] - assert len(next_loader_samples) == 6 - next_loader_sample = next_loader_samples[0] - print( - "next_loader_samples:" - + f"\n [@{offset + skip_initial}] {sample.text}" - + "".join( - f"\n [@{idx}] {sample}" - for idx, sample in zip( - range(skip_initial + offset, skip_initial + offset + 6), - next_loader_samples, + ).with_restored_state_rank(state_offset) as ref_loader: + # Get 1 sample from the restored loader + next_loader_samples = [sample for _, sample in zip(range(6), ref_loader)] + assert len(next_loader_samples) == 6 + next_loader_sample = next_loader_samples[0] + print( + "next_loader_samples:" + + f"\n [@{offset + skip_initial}] {sample.text}" + + "".join( + f"\n [@{idx}] {sample}" + for idx, sample in zip( + range(skip_initial + offset, skip_initial + offset + 6), + next_loader_samples, + ) ) ) - ) - assert next_loader_sample.text == sample.text, f"{next_loader_sample} != {sample}" - assert next_loader_sample.__key__ == sample.__key__, ( - f"{next_loader_sample} != {sample}" - ) - assert next_loader_sample.__restore_key__ == 
sample.__restore_key__, ( - f"{next_loader_sample} != {sample}" - ) + assert next_loader_sample.text == sample.text, ( + f"{next_loader_sample} != {sample}" + ) + assert next_loader_sample.__key__ == sample.__key__, ( + f"{next_loader_sample} != {sample}" + ) + assert next_loader_sample.__restore_key__ == sample.__restore_key__, ( + f"{next_loader_sample} != {sample}" + ) except Exception as e: raise ValueError(f"Failed to iterate @{skip_initial}+{offset} samples") from e - - -if __name__ == "__main__": - unittest.main() + finally: + second_loader.shutdown() diff --git a/tests/test_metadataset_fewsamp.py b/tests/test_metadataset_fewsamp.py index e8918237..ee746e2b 100644 --- a/tests/test_metadataset_fewsamp.py +++ b/tests/test_metadataset_fewsamp.py @@ -7,11 +7,11 @@ import logging import sys import tempfile -import unittest import warnings from pathlib import Path from typing import Iterable +import pytest import torch import webdataset as wds @@ -61,130 +61,129 @@ def get_blend_dataset(ds: SavableDataset): raise ValueError("No blend dataset found") -class TestDataset(unittest.TestCase): - # Set up the test fixture - def setUp(self): - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - warnings.simplefilter("ignore", ResourceWarning) - - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.dataset_path = Path(self.temp_dir.name) - # self.dataset_path = Path("./test_dataset") - - self.dataset_path.mkdir(exist_ok=True, parents=True) - - (self.dataset_path / "ds1").mkdir(exist_ok=True, parents=True) - (self.dataset_path / "ds2").mkdir(exist_ok=True, parents=True) - (self.dataset_path / "ds3").mkdir(exist_ok=True, parents=True) - - # Create a small dummy captioning dataset - self.create_text_test_dataset(self.dataset_path / "ds1", range(55), range(55)) - self.create_text_test_dataset(self.dataset_path / "ds2", range(100, 107), range(100, 107)) - self.create_text_test_dataset(self.dataset_path / "ds3", range(200, 255), 
range(0, 55)) - - self.mds_path = self.dataset_path / "metadataset_v2.yaml" - with open(self.mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - weight: 1", - " path: ds1", - " - weight: 1", - " path: ds2", - " - weight: 1", - " path: ds3", - ] - ) +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + temp_dir = tempfile.TemporaryDirectory() + yield temp_dir + gc.collect() + temp_dir.cleanup() + + +@pytest.fixture +def dataset_path(temp_dir): + """Create dataset path and setup test data.""" + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + dataset_path = Path(temp_dir.name) + dataset_path.mkdir(exist_ok=True, parents=True) + + (dataset_path / "ds1").mkdir(exist_ok=True, parents=True) + (dataset_path / "ds2").mkdir(exist_ok=True, parents=True) + (dataset_path / "ds3").mkdir(exist_ok=True, parents=True) + + # Create a small dummy captioning dataset + create_text_test_dataset(dataset_path / "ds1", range(55), range(55)) + create_text_test_dataset(dataset_path / "ds2", range(100, 107), range(100, 107)) + create_text_test_dataset(dataset_path / "ds3", range(200, 255), range(0, 55)) + + mds_path = dataset_path / "metadataset_v2.yaml" + with open(mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 1", + " path: ds1", + " - weight: 1", + " path: ds2", + " - weight: 1", + " path: ds3", + ] ) - - print(self.dataset_path) - - def tearDown(self): - # Remove all temporary files - gc.collect() - self.temp_dir.cleanup() - - @staticmethod - def create_text_test_dataset(path: Path, txt_range: Iterable[int], key_range: Iterable[int]): - """Creates a small dummy test dataset for testing purposes.""" - - # Create num_samples unique captions - (path / 
"parts").mkdir(exist_ok=True, parents=True) - - # Initialize the ShardWriter - with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: - for key, txt in zip(key_range, txt_range): - # Write individual files to shards - shard_writer.write( - { - "__key__": f"{key:06d}", - "txt": f"{txt}".encode(), - }, - ) - total_shards = shard_writer.shard - - from megatron.energon.flavors import BaseWebdatasetFactory - - BaseWebdatasetFactory.prepare_dataset( - path, - [f"parts/data-{{0..{total_shards - 1}}}.tar"], - split_parts_ratio=[("train", 1.0)], - shuffle_seed=None, ) - with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: TextWebdataset", - "field_map:", - " text: txt", - "subflavors:", - " source: dataset.yaml", - " dataset.yaml: true", - " number: 42", - ] - ) - ) + print(dataset_path) + return dataset_path - def test_metadataset_few_samples_save_restore(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=32, - num_workers=1, - seed_offset=42, - ) - # Train mode dataset - train_dataset = get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=100, - max_samples_per_sequence=None, - ) - print(len(train_dataset)) - assert len(train_dataset) == 4 - - # The middle dataset should have 0 samples assigned to this rank - blend_ds = get_blend_dataset(train_dataset) - assert len(blend_ds.dataset_weights[1][0].dataset.dataset.workers_slice_offsets[0]) == 1 - assert len(blend_ds.dataset_weights[1][0].dataset.dataset) == 0 - - train_loader = get_savable_loader( - train_dataset, - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, +def create_text_test_dataset(path: Path, txt_range: Iterable[int], key_range: Iterable[int]): + """Creates a small dummy test dataset for testing purposes.""" + + # Create num_samples unique captions + (path / "parts").mkdir(exist_ok=True, parents=True) + + # 
Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: + for key, txt in zip(key_range, txt_range): + # Write individual files to shards + shard_writer.write( + { + "__key__": f"{key:06d}", + "txt": f"{txt}".encode(), + }, + ) + total_shards = shard_writer.shard + + from megatron.energon.flavors import BaseWebdatasetFactory + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: TextWebdataset", + "field_map:", + " text: txt", + "subflavors:", + " source: dataset.yaml", + " dataset.yaml: true", + " number: 42", + ] + ) ) + +def test_metadataset_few_samples_save_restore(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=32, + num_workers=1, + seed_offset=42, + ) + + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "metadataset_v2.yaml", + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=100, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 4 + + # The middle dataset should have 0 samples assigned to this rank + blend_ds = get_blend_dataset(train_dataset) + assert len(blend_ds.dataset_weights[1][0].dataset.dataset.workers_slice_offsets[0]) == 1 + assert len(blend_ds.dataset_weights[1][0].dataset.dataset) == 0 + + with get_savable_loader( + train_dataset, + ) as train_loader: # Load 3 samples list(zip(train_loader, range(3))) @@ -194,19 +193,16 @@ def test_metadataset_few_samples_save_restore(self): # Load 5 samples data1b = list(zip(train_loader, range(5))) - # Restore state - train_loader = get_savable_loader( - get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=100, - 
max_samples_per_sequence=None, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) - train_loader.restore_state_rank(state1) + # Restore state + with get_savable_loader( + get_train_dataset( + dataset_path / "metadataset_v2.yaml", + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=100, + max_samples_per_sequence=None, + ), + ).with_restored_state_rank(state1) as train_loader: # Load 5 samples data2_restore = list(zip(train_loader, range(5))) @@ -221,23 +217,22 @@ def test_metadataset_few_samples_save_restore(self): assert order1b == order2, "The restored state does not match the original state." - def test_too_few_samples(self): - # Will only give a single sample, as there are 117 samples in total, and 100 ranks - ws = 100 - lens = [] - for i_rank in range(ws): - worker_config = WorkerConfig(rank=i_rank, world_size=ws, num_workers=0) - loader = get_savable_loader( - get_train_dataset( - self.mds_path, - batch_size=1, - worker_config=worker_config, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - checkpoint_every_min_n_samples=1, - checkpoint_every_sec=0, - ) + +def test_too_few_samples(dataset_path): + # Will only give a single sample, as there are 117 samples in total, and 100 ranks + ws = 100 + lens = [] + for i_rank in range(ws): + worker_config = WorkerConfig(rank=i_rank, world_size=ws, num_workers=0) + with get_savable_loader( + get_train_dataset( + dataset_path / "metadataset_v2.yaml", + batch_size=1, + worker_config=worker_config, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ), + ) as loader: lens.append(len(loader)) txts = [] @@ -249,15 +244,7 @@ def test_too_few_samples(self): f"Rank {i_rank} should have exactly {len(loader)} sample, but got {txts}" ) - assert lens == [ - 2 if i in [0, 3, 6, 12, 18, 25, 31, 37, 43, 50, 56, 62, 68, 75, 81, 87, 93] else 1 - for i in range(100) - ] - - -if __name__ == "__main__": - # unittest.main() - ds = TestDataset() - ds.setUp() - 
ds.test_metadataset_few_samples_save_restore() - ds.tearDown() + assert lens == [ + 2 if i in [0, 3, 6, 12, 18, 25, 31, 37, 43, 50, 56, 62, 68, 75, 81, 87, 93] else 1 + for i in range(100) + ] diff --git a/tests/test_metadataset_v2.py b/tests/test_metadataset_v2.py index e5ff6444..e728ba61 100644 --- a/tests/test_metadataset_v2.py +++ b/tests/test_metadataset_v2.py @@ -8,13 +8,13 @@ import random import sys import tempfile -import unittest import warnings from collections import Counter from pathlib import Path from typing import Iterable from unittest.mock import patch +import pytest import torch import webdataset as wds @@ -64,191 +64,191 @@ def _norng_state(state): @edataclass -class TestJoinedSample(Sample): +class JoinedSample(Sample): text1: torch.Tensor text2: torch.Tensor @staticmethod - def from_joined(ds1: TextSample, ds2: TextSample) -> "TestJoinedSample": - return TestJoinedSample.derive_from( + def from_joined(ds1: TextSample, ds2: TextSample) -> "JoinedSample": + return JoinedSample.derive_from( ds1, text1=ds1.text, text2=ds2.text, ) -def test_joiner(text1: TextSample, text2: TextSample) -> TestJoinedSample: - return TestJoinedSample.derive_from(text1, text1=f"j{text1.text}", text2=f"j{text2.text}") - - -class TestDataset(unittest.TestCase): - # Set up the test fixture - def setUp(self): - random.seed(42) - - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - warnings.simplefilter("ignore", ResourceWarning) - - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.dataset_path = Path(self.temp_dir.name) - # self.dataset_path = Path("./test_dataset") - - self.dataset_path.mkdir(exist_ok=True, parents=True) - - # Create a small dummy datasets - self.create_text_test_dataset(self.dataset_path / "ds1", range(55), range(55)) - self.create_text_test_dataset(self.dataset_path / "ds2", range(100, 155), range(100, 155)) - self.create_text_test_dataset(self.dataset_path / "ds3", range(200, 255), range(55)) - - # Create 
a shuffled dataset for joining with the ds1. It has overlap but includes more samples - shuffled_range_100 = list(range(100)) - random.shuffle(shuffled_range_100) - - self.create_text_test_dataset( - self.dataset_path / "ds1b", shuffled_range_100, shuffled_range_100, prefix="B" - ) - - shuffled_range_100 = list(range(100)) - random.shuffle(shuffled_range_100) - self.create_text_test_dataset( - self.dataset_path / "ds1c", shuffled_range_100, shuffled_range_100, prefix="C" - ) - - self.mds_path = self.dataset_path / "metadataset_v2.yaml" - with open(self.mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - weight: 1", - " path: ds1", - " subflavors:", - " source: metadataset_v2.yaml", - " number: 43", - " mds: mds", - " shuffle_over_epochs_multiplier: 3", - " - weight: 1", - " path: ds2", - " subflavors:", - " source: metadataset_v2.yaml", - " number: 44", - " mds: mds", - " val:", - " blend:", - " - weight: 1", - " path: ds1", - " split_part: train", - " - weight: 1", - " path: ds2", - " split_part: train", - ] - ) - ) - self.nested_mds_path = self.dataset_path / "nested_metadataset_v2.yaml" - with open(self.nested_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - weight: 4", - " path: ./metadataset_v2.yaml", - " split_part: train", - " subflavors:", - " source: nested_metadataset.yaml", - " mds: nested_train", - " - path: ./metadataset_v2.yaml", - " split_part: val", - " subflavors:", - " source: nested_metadataset.yaml", - " mds: nested_val", - ] - ) +def my_joiner(text1: TextSample, text2: TextSample) -> JoinedSample: + return JoinedSample.derive_from(text1, text1=f"j{text1.text}", text2=f"j{text2.text}") + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + temp_dir = tempfile.TemporaryDirectory() + yield temp_dir 
+ gc.collect() + temp_dir.cleanup() + + +@pytest.fixture +def dataset_path(temp_dir): + """Create dataset path and setup test data.""" + random.seed(42) + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + dataset_path = Path(temp_dir.name) + dataset_path.mkdir(exist_ok=True, parents=True) + + # Create a small dummy datasets + create_text_test_dataset(dataset_path / "ds1", range(55), range(55)) + create_text_test_dataset(dataset_path / "ds2", range(100, 155), range(100, 155)) + create_text_test_dataset(dataset_path / "ds3", range(200, 255), range(55)) + + # Create a shuffled dataset for joining with the ds1. It has overlap but includes more samples + shuffled_range_100 = list(range(100)) + random.shuffle(shuffled_range_100) + + create_text_test_dataset( + dataset_path / "ds1b", shuffled_range_100, shuffled_range_100, prefix="B" + ) + + shuffled_range_100 = list(range(100)) + random.shuffle(shuffled_range_100) + create_text_test_dataset( + dataset_path / "ds1c", shuffled_range_100, shuffled_range_100, prefix="C" + ) + + mds_path = dataset_path / "metadataset_v2.yaml" + with open(mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 1", + " path: ds1", + " subflavors:", + " source: metadataset_v2.yaml", + " number: 43", + " mds: mds", + " shuffle_over_epochs_multiplier: 3", + " - weight: 1", + " path: ds2", + " subflavors:", + " source: metadataset_v2.yaml", + " number: 44", + " mds: mds", + " val:", + " blend:", + " - weight: 1", + " path: ds1", + " split_part: train", + " - weight: 1", + " path: ds2", + " split_part: train", + ] ) - print(self.dataset_path) - - def tearDown(self): - # Remove all temporary files - gc.collect() - self.temp_dir.cleanup() - - @staticmethod - def create_text_test_dataset( - path: Path, txt_range: Iterable[int], key_range: Iterable[int], prefix: str = "" - ): - 
"""Creates a small dummy test dataset for testing purposes.""" - - # Create num_samples unique captions - (path / "parts").mkdir(exist_ok=True, parents=True) - - # Initialize the ShardWriter - with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: - for key, txt in zip(key_range, txt_range): - # Write individual files to shards - shard_writer.write( - { - "__key__": f"{key:06d}", - "txt": f"{prefix}{txt}".encode(), - }, - ) - total_shards = shard_writer.shard - - from megatron.energon.flavors import BaseWebdatasetFactory - - BaseWebdatasetFactory.prepare_dataset( - path, - [f"parts/data-{{0..{total_shards - 1}}}.tar"], - split_parts_ratio=[("train", 1.0)], - shuffle_seed=None, ) - - with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: TextSample", - "field_map:", - " text: txt", - "subflavors:", - " source: dataset.yaml", - " dataset.yaml: true", - " number: 42", - ] - ) + nested_mds_path = dataset_path / "nested_metadataset_v2.yaml" + with open(nested_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 4", + " path: ./metadataset_v2.yaml", + " split_part: train", + " subflavors:", + " source: nested_metadataset.yaml", + " mds: nested_train", + " - path: ./metadataset_v2.yaml", + " split_part: val", + " subflavors:", + " source: nested_metadataset.yaml", + " mds: nested_val", + ] ) - - def test_metadataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, ) - - # Train mode dataset - train_dataset = get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=None, - max_samples_per_sequence=None, + print(dataset_path) + return dataset_path + + +def create_text_test_dataset( + path: Path, txt_range: 
Iterable[int], key_range: Iterable[int], prefix: str = "" +): + """Creates a small dummy test dataset for testing purposes.""" + + # Create num_samples unique captions + (path / "parts").mkdir(exist_ok=True, parents=True) + + # Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: + for key, txt in zip(key_range, txt_range): + # Write individual files to shards + shard_writer.write( + { + "__key__": f"{key:06d}", + "txt": f"{prefix}{txt}".encode(), + }, + ) + total_shards = shard_writer.shard + + from megatron.energon.flavors import BaseWebdatasetFactory + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: TextSample", + "field_map:", + " text: txt", + "subflavors:", + " source: dataset.yaml", + " dataset.yaml: true", + " number: 42", + ] + ) ) - print(len(train_dataset)) - assert len(train_dataset) == 11 - train_loader1 = get_loader(train_dataset) +def test_metadataset(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "metadataset_v2.yaml", + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 11 + + with get_loader(train_dataset) as train_loader1: train_order1 = [ text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text ] @@ -257,115 +257,114 @@ def test_metadataset(self): assert len(Counter(train_order1)) == 110 assert all(48 <= v <= 52 for v in Counter(train_order1).values()) - def test_nested_metadataset(self): - torch.manual_seed(42) - 
worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - ) - dataset = load_dataset(self.nested_mds_path) - - raw_datasets = dataset.get_datasets( - training=False, split_part="train", worker_config=worker_config - ) - assert raw_datasets.blend_mode == DatasetBlendMode.DATASET_WEIGHT - assert [raw_dataset.weight for raw_dataset in raw_datasets.datasets] == [ - 0.4, - 0.4, - 0.1, - 0.1, - ], [raw_dataset.weight for raw_dataset in raw_datasets.datasets] - assert [raw_dataset.dataset.paths[0].name for raw_dataset in raw_datasets.datasets] == [ - "ds1", - "ds2", - "ds1", - "ds2", - ] - print([raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets]) - assert [raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets] == [ - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 43, - "mds": "nested_train", - }, - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 44, - "mds": "nested_train", - }, - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 42, - "mds": "nested_val", - }, - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 42, - "mds": "nested_val", - }, - ] - - def test_joined_metadataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) - - # Create a joined dataset configuration - joined_mds_path = self.dataset_path / "joined_metadataset_v2.yaml" - with open(joined_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " join:", - " ds1:", - " path: ds1", - " subflavors:", - " source1: ds1", - " number: 43", - " ds2:", - " path: ds3", - " subflavors:", - " source2: ds3", - " number: 44", - " joiner:", - f" __module__: {TestJoinedSample.__module__}", - f" __class__: {TestJoinedSample.__name__}", - ] - ) +def test_nested_metadataset(dataset_path): + 
torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + ) + + dataset = load_dataset(dataset_path / "nested_metadataset_v2.yaml") + + raw_datasets = dataset.get_datasets( + training=False, split_part="train", worker_config=worker_config + ) + assert raw_datasets.blend_mode == DatasetBlendMode.DATASET_WEIGHT + assert [raw_dataset.weight for raw_dataset in raw_datasets.datasets] == [ + 0.4, + 0.4, + 0.1, + 0.1, + ], [raw_dataset.weight for raw_dataset in raw_datasets.datasets] + assert [raw_dataset.dataset.paths[0].name for raw_dataset in raw_datasets.datasets] == [ + "ds1", + "ds2", + "ds1", + "ds2", + ] + print([raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets]) + assert [raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets] == [ + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 43, + "mds": "nested_train", + }, + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 44, + "mds": "nested_train", + }, + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 42, + "mds": "nested_val", + }, + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 42, + "mds": "nested_val", + }, + ] + + +def test_joined_metadataset(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Create a joined dataset configuration + joined_mds_path = dataset_path / "joined_metadataset_v2.yaml" + with open(joined_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " join:", + " ds1:", + " path: ds1", + " subflavors:", + " source1: ds1", + " number: 43", + " ds2:", + " path: ds3", + " subflavors:", + " source2: ds3", + " number: 44", + " joiner:", + f" __module__: {JoinedSample.__module__}", + f" __class__: {JoinedSample.__name__}", + ] ) - 
prepare_metadataset(EPath(joined_mds_path)) - - # Train mode dataset - train_dataset = get_train_dataset( - joined_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, ) - print(len(train_dataset)) - assert len(train_dataset) == 55 - - train_loader = get_savable_loader( - train_dataset, - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) - + prepare_metadataset(EPath(joined_mds_path)) + + # Train mode dataset + train_dataset = get_train_dataset( + joined_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 55 + + with get_savable_loader( + train_dataset, + ) as train_loader: data = list(zip(range(2 * 55), train_loader)) txt1_order = [data.text1[0] for idx, data in data] txt2_order = [data.text2[0] for idx, data in data] @@ -394,21 +393,16 @@ def test_joined_metadataset(self): txt2_order = [data.text2 for idx, data in data] key_order = [data.__key__ for idx, data in data] - # Restore state - train_loader = get_savable_loader( - get_train_dataset( - joined_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) - - train_loader.restore_state_rank(state) - + # Restore state + with get_savable_loader( + get_train_dataset( + joined_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ), + ).with_restored_state_rank(state) as train_loader: # Iterate 360 more items data = list(zip(range(60), train_loader)) txt1_order_rest = [data.text1 for idx, data in data] @@ -420,63 +414,61 @@ def test_joined_metadataset(self): assert txt2_order == txt2_order_rest assert key_order == key_order_rest - def test_joined_metadataset_joiner(self): - torch.manual_seed(42) - worker_config = 
WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) - # Create a joined dataset configuration - joined_mds_path = self.dataset_path / "joined_metadataset_joiner.yaml" - with open(joined_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - weight: 1", - " join:", - " text1:", - " path: ds1", - " subflavors:", - " source1: ds1", - " number: 43", - " text2:", - " path: ds3", - " subflavors:", - " source2: ds3", - " number: 44", - " joiner:", - f" __module__: {test_joiner.__module__}", - f" __function__: {test_joiner.__name__}", - ] - ) +def test_joined_metadataset_joiner(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Create a joined dataset configuration + joined_mds_path = dataset_path / "joined_metadataset_joiner.yaml" + with open(joined_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 1", + " join:", + " text1:", + " path: ds1", + " subflavors:", + " source1: ds1", + " number: 43", + " text2:", + " path: ds3", + " subflavors:", + " source2: ds3", + " number: 44", + " joiner:", + f" __module__: {my_joiner.__module__}", + f" __function__: {my_joiner.__name__}", + ] ) - prepare_metadataset(EPath(joined_mds_path)) - - # Train mode dataset - train_dataset = get_train_dataset( - joined_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, ) - print(len(train_dataset)) - assert len(train_dataset) == 55 - - train_loader = get_savable_loader( - train_dataset, - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) - + prepare_metadataset(EPath(joined_mds_path)) + + # Train mode dataset + train_dataset = get_train_dataset( + joined_mds_path, + 
worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 55 + + with get_savable_loader( + train_dataset, + ) as train_loader: data = list(zip(range(2 * 55), train_loader)) txt1_order = [data.text1[0] for idx, data in data] txt2_order = [data.text2[0] for idx, data in data] @@ -499,64 +491,62 @@ def test_joined_metadataset_joiner(self): # Every item must occurr 2 times (2*55). assert Counter(txt1_order).most_common(1)[0][1] == 2 - def test_left_join(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) - # Create a joined dataset configuration - joined_mds_path = self.dataset_path / "left_join.yaml" - with open(joined_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - weight: 1", - " join:", - " text1:", - " path: ds1", - " subflavors:", - " source1: ds1", - " number: 43", - " text2:", - " path: ds1b", - " nonmatch: skip", - " subflavors:", - " source2: ds1b", - " number: 44", - " joiner:", - f" __module__: {test_joiner.__module__}", - f" __function__: {test_joiner.__name__}", - ] - ) +def test_left_join(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Create a joined dataset configuration + joined_mds_path = dataset_path / "left_join.yaml" + with open(joined_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 1", + " join:", + " text1:", + " path: ds1", + " subflavors:", + " source1: ds1", + " number: 43", + " text2:", + " path: ds1b", + " nonmatch: skip", + " subflavors:", + " source2: ds1b", + " number: 44", + " joiner:", + f" __module__: {my_joiner.__module__}", + f" 
__function__: {my_joiner.__name__}", + ] ) - prepare_metadataset(EPath(joined_mds_path)) - - # Train mode dataset - train_dataset = get_train_dataset( - joined_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - print(len(train_dataset)) - assert len(train_dataset) == 55, len(train_dataset) - - train_loader = get_savable_loader( - train_dataset, - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) - + prepare_metadataset(EPath(joined_mds_path)) + + # Train mode dataset + train_dataset = get_train_dataset( + joined_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 55, len(train_dataset) + + with get_savable_loader( + train_dataset, + ) as train_loader: data = list(zip(range(2 * 55), train_loader)) txt1_order = [data.text1[0] for idx, data in data] txt2_order = [data.text2[0] for idx, data in data] @@ -577,128 +567,48 @@ def test_left_join(self): # Every item must occurr 2 times (2*55). 
assert Counter(txt1_order).most_common(1)[0][1] == 2 - # Test that changing the file works as expected - with open(joined_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - weight: 1", - " join:", - " text1:", - " path: ds1c", - " subflavors:", - " source1: ds1c", - " number: 43", - " text2:", - " path: ds1b", - " nonmatch: skip", - " subflavors:", - " source2: ds1b", - " number: 44", - " joiner:", - f" __module__: {test_joiner.__module__}", - f" __function__: {test_joiner.__name__}", - " - weight: 1", - " join:", - " text1:", - " path: ds1b", - " text2:", - " path: ds1", - " nonmatch: skip", - " joiner:", - f" __module__: {test_joiner.__module__}", - f" __function__: {test_joiner.__name__}", - ] - ) + # Test that changing the file works as expected + with open(joined_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 1", + " join:", + " text1:", + " path: ds1c", + " subflavors:", + " source1: ds1c", + " number: 43", + " text2:", + " path: ds1b", + " nonmatch: skip", + " subflavors:", + " source2: ds1b", + " number: 44", + " joiner:", + f" __module__: {my_joiner.__module__}", + f" __function__: {my_joiner.__name__}", + " - weight: 1", + " join:", + " text1:", + " path: ds1b", + " text2:", + " path: ds1", + " nonmatch: skip", + " joiner:", + f" __module__: {my_joiner.__module__}", + f" __function__: {my_joiner.__name__}", + ] ) - - # Expect this to fail. Preparation does not match! 
- with self.assertRaises(Exception): - # Train mode dataset - train_dataset = get_train_dataset( - joined_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - - # Shall succeed after preparation - prepare_metadataset(EPath(joined_mds_path)) - train_dataset = get_train_dataset( - joined_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - # Check that there are no remainder files - cache_folder = joined_mds_path.with_name(joined_mds_path.name + ".cache") - assert sum(1 for f in cache_folder.iterdir() if f.is_file()) == 2, list( - cache_folder.iterdir() - ) - - def test_left_join_exclude(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, ) - # Create a joined dataset configuration - orig_split_path = self.dataset_path / "ds1" / ".nv-meta" / "split.yaml" - exclude_split_path = self.dataset_path / "ds1" / ".nv-meta" / "exclude_split.yaml" - with open(exclude_split_path, "w") as f: - f.write( - "\n".join( - [ - orig_split_path.read_text(), - "exclude:", - ' - "parts/data-0.tar/000000"', - ' - "parts/data-0.tar/000001"', - ' - "parts/data-0.tar/000002"', - ' - "parts/data-0.tar/000003"', - ' - "parts/data-0.tar/000004"', - ' - "parts/data-1.tar"', - ' - "parts/data-2.tar/000029"', - ] - ) - ) - - # Create a joined dataset configuration - joined_mds_path = self.dataset_path / "left_join.yaml" - with open(joined_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - weight: 1", - " join:", - " text1:", - " path: ds1", - " split_config: exclude_split.yaml", - " text2:", - " path: ds1b", - " nonmatch: skip", - " joiner:", - f" __module__: {test_joiner.__module__}", - f" __function__: {test_joiner.__name__}", - ] - ) - ) - 
prepare_metadataset(EPath(joined_mds_path)) - + # Expect this to fail. Preparation does not match! + with pytest.raises(Exception): # Train mode dataset train_dataset = get_train_dataset( joined_mds_path, @@ -707,15 +617,91 @@ def test_left_join_exclude(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ) - print(len(train_dataset)) - assert len(train_dataset) == 55 - 16, len(train_dataset) - train_loader = get_savable_loader( - train_dataset, - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, + # Shall succeed after preparation + prepare_metadataset(EPath(joined_mds_path)) + train_dataset = get_train_dataset( + joined_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + # Check that there are no remainder files + cache_folder = joined_mds_path.with_name(joined_mds_path.name + ".cache") + assert sum(1 for f in cache_folder.iterdir() if f.is_file()) == 2, list(cache_folder.iterdir()) + + +def test_left_join_exclude(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Create a joined dataset configuration + orig_split_path = dataset_path / "ds1" / ".nv-meta" / "split.yaml" + exclude_split_path = dataset_path / "ds1" / ".nv-meta" / "exclude_split.yaml" + with open(exclude_split_path, "w") as f: + f.write( + "\n".join( + [ + orig_split_path.read_text(), + "exclude:", + ' - "parts/data-0.tar/000000"', + ' - "parts/data-0.tar/000001"', + ' - "parts/data-0.tar/000002"', + ' - "parts/data-0.tar/000003"', + ' - "parts/data-0.tar/000004"', + ' - "parts/data-1.tar"', + ' - "parts/data-2.tar/000029"', + ] + ) ) + # Create a joined dataset configuration + joined_mds_path = dataset_path / "left_join.yaml" + with open(joined_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 1", + " 
join:", + " text1:", + " path: ds1", + " split_config: exclude_split.yaml", + " text2:", + " path: ds1b", + " nonmatch: skip", + " joiner:", + f" __module__: {my_joiner.__module__}", + f" __function__: {my_joiner.__name__}", + ] + ) + ) + prepare_metadataset(EPath(joined_mds_path)) + + # Train mode dataset + train_dataset = get_train_dataset( + joined_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 55 - 16, len(train_dataset) + + with get_savable_loader( + train_dataset, + ) as train_loader: data = list(zip(range(2 * 55), train_loader)) txt1_order = [data.text1[0] for idx, data in data] txt2_order = [data.text2[0] for idx, data in data] @@ -735,103 +721,102 @@ def test_left_join_exclude(self): assert set(txt1_order) == set(f"j{i}" for i in set_filtered_nums) assert set(txt2_order) == set(f"jB{i}" for i in set_filtered_nums) - def test_joined_metadataset_prepare_mock(self): - torch.manual_seed(42) - - # Create a joined dataset configuration - joined_mds_path = self.dataset_path / "joined_metadataset_prepare_mock.yaml" - with open(joined_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " join:", - " - path: ds1", - " - path: ds3", - " joiner:", - " __module__: __main__", - " __class__: NonExistantSample", - ] - ) - ) - prepare_metadataset(EPath(joined_mds_path)) - - # Create a joined dataset configuration - joined_mds_path = self.dataset_path / "joined_metadataset_prepare_mock2.yaml" - with open(joined_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " join:", - " - path: ds1", - " - path: ds3", - " joiner:", - " __module__: non_existant_module", - " __class__: MyCaptioningSample", - ] - ) - ) - prepare_metadataset(EPath(joined_mds_path)) - def 
test_metadataset_fixed_epochs(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, +def test_joined_metadataset_prepare_mock(dataset_path): + torch.manual_seed(42) + + # Create a joined dataset configuration + joined_mds_path = dataset_path / "joined_metadataset_prepare_mock.yaml" + with open(joined_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " join:", + " - path: ds1", + " - path: ds3", + " joiner:", + " __module__: __main__", + " __class__: NonExistantSample", + ] + ) ) - - # Create a joined dataset configuration - fixed_epochs_mds_path = self.dataset_path / "metadataset_fixed_epochs.yaml" - with open(fixed_epochs_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend_epochized:", - " - repetitions: 2", - " path: ds1", - " subflavors:", - " source: ds1", - " number: 43", - " - repetitions: 3", - " path: ds2", - " subflavors:", - " source: ds2", - " number: 42", - ] - ) + prepare_metadataset(EPath(joined_mds_path)) + + # Create a joined dataset configuration + joined_mds_path = dataset_path / "joined_metadataset_prepare_mock2.yaml" + with open(joined_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " join:", + " - path: ds1", + " - path: ds3", + " joiner:", + " __module__: non_existant_module", + " __class__: MyCaptioningSample", + ] ) - - # Train mode dataset - train_dataset = get_train_dataset( - fixed_epochs_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - repeat=False, ) - print(len(train_dataset)) - assert len(train_dataset) == 5 * 55, len(train_dataset) - - train_loader = get_savable_loader( - train_dataset, - checkpoint_every_sec=0, - 
checkpoint_every_min_n_samples=1, + prepare_metadataset(EPath(joined_mds_path)) + + +def test_metadataset_fixed_epochs(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Create a joined dataset configuration + fixed_epochs_mds_path = dataset_path / "metadataset_fixed_epochs.yaml" + with open(fixed_epochs_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend_epochized:", + " - repetitions: 2", + " path: ds1", + " subflavors:", + " source: ds1", + " number: 43", + " - repetitions: 3", + " path: ds2", + " subflavors:", + " source: ds2", + " number: 42", + ] + ) ) + # Train mode dataset + train_dataset = get_train_dataset( + fixed_epochs_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + ) + print(len(train_dataset)) + assert len(train_dataset) == 5 * 55, len(train_dataset) + + with get_savable_loader( + train_dataset, + ) as train_loader: data = list(enumerate(train_loader)) txt_order = [data.text[0] for idx, data in data] key_order = [data.__subflavors__[0]["source"] + "/" + data.__key__[0] for idx, data in data] @@ -879,20 +864,17 @@ def test_metadataset_fixed_epochs(self): assert all(ds2_key_cnt[key] == 3 for key in ds2_keys) assert all(txt_cnt[key] in (2, 3) for key in txt_order) - # Restore state - train_loader = get_savable_loader( - get_train_dataset( - fixed_epochs_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - repeat=False, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) - train_loader.restore_state_rank(state1) + # Restore state + with get_savable_loader( + get_train_dataset( + fixed_epochs_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + 
max_samples_per_sequence=None, + repeat=False, + ), + ).with_restored_state_rank(state1) as train_loader: data2_restore = list(enumerate(train_loader)) assert len(data2_restore) == 2 * 55 txt_order_rst = [data.text[0] for idx, data in data1 + data2_restore] @@ -914,60 +896,58 @@ def test_metadataset_fixed_epochs(self): assert all(ds2_key_cnt_rst[key] == 3 for key in ds2_keys_rst) assert all(txt_cnt_rst[key] in (2, 3) for key in txt_order_rst) - def test_metadataset_fixed_fractional_epochs(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) - # Create a joined dataset configuration - fixed_epochs_mds_path = self.dataset_path / "metadataset_fixed_epochs.yaml" - with open(fixed_epochs_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend_epochized:", - " - repetitions: 0.7", - " path: ds1", - " subflavors:", - " source: ds1", - " number: 43", - " - repetitions: 1.5", - " path: ds2", - " subflavors:", - " source: ds2", - " number: 42", - ] - ) +def test_metadataset_fixed_fractional_epochs(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Create a joined dataset configuration + fixed_epochs_mds_path = dataset_path / "metadataset_fixed_epochs.yaml" + with open(fixed_epochs_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend_epochized:", + " - repetitions: 0.7", + " path: ds1", + " subflavors:", + " source: ds1", + " number: 43", + " - repetitions: 1.5", + " path: ds2", + " subflavors:", + " source: ds2", + " number: 42", + ] ) - - # ===== Part 1: Verify fractions ===== - - # Train mode dataset - train_dataset = get_train_dataset( - fixed_epochs_mds_path, - worker_config=worker_config, - batch_size=1, - 
shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, - ) - - train_loader = get_savable_loader( - train_dataset, - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) + # ===== Part 1: Verify fractions ===== + + # Train mode dataset + train_dataset = get_train_dataset( + fixed_epochs_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ) + + with get_savable_loader( + train_dataset, + ) as train_loader: assert len(train_loader) == 38 + 55 + 27, len(train_loader) data = list(enumerate(train_loader)) @@ -987,43 +967,37 @@ def test_metadataset_fixed_fractional_epochs(self): # The remaining samples from ds2 (127 to incl. 154) should be repeated only once assert all(sample_counts[sample] == 1 for sample in range(127, 155)) - # ===== Part 2: Save and restore state ===== - - # Now let's check if the state is stored and restored correctly + # ===== Part 2: Save and restore state ===== - train_loader = get_savable_loader( - get_train_dataset( - fixed_epochs_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) + # Now let's check if the state is stored and restored correctly + with get_savable_loader( + get_train_dataset( + fixed_epochs_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ), + ) as train_loader: data1 = list(zip(range(95), train_loader)) state1 = train_loader.save_state_rank() - train_loader = get_savable_loader( - get_train_dataset( - fixed_epochs_mds_path, - 
worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) - train_loader.restore_state_rank(state1) + with get_savable_loader( + get_train_dataset( + fixed_epochs_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ), + ).with_restored_state_rank(state1) as train_loader: data2_restore = list(enumerate(train_loader)) total_samples_save_restore = len(data1) + len(data2_restore) @@ -1040,25 +1014,22 @@ def test_metadataset_fixed_fractional_epochs(self): "Sample counts do not match when using save/restore" ) - # ===== Part 3: Check if the state is restored correctly when saving right at the end of a dataset ===== + # ===== Part 3: Check if the state is restored correctly when saving right at the end of a dataset ===== - torch.manual_seed(42) - - train_loader = get_savable_loader( - get_train_dataset( - fixed_epochs_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) + torch.manual_seed(42) + with get_savable_loader( + get_train_dataset( + fixed_epochs_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ), + ) as train_loader: ds1_counter = 0 data1 = [] for idx, sample in enumerate(train_loader): @@ -1071,21 +1042,18 @@ def test_metadataset_fixed_fractional_epochs(self): state1 = train_loader.save_state_rank() - train_loader = get_savable_loader( - get_train_dataset( 
- fixed_epochs_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) - train_loader.restore_state_rank(state1) + with get_savable_loader( + get_train_dataset( + fixed_epochs_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ), + ).with_restored_state_rank(state1) as train_loader: data2_restore = list(enumerate(train_loader)) total_samples_save_restore = len(data1) + len(data2_restore) @@ -1102,22 +1070,19 @@ def test_metadataset_fixed_fractional_epochs(self): "Sample counts do not match when using save/restore" ) - # Try in repeat mode - # Train mode dataset - train_loader = get_savable_loader( - get_train_dataset( - fixed_epochs_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) - + # Try in repeat mode + # Train mode dataset + with get_savable_loader( + get_train_dataset( + fixed_epochs_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + ), + ) as train_loader: data = list(zip(range(200), train_loader)) assert len(train_loader) == 38 + 55 + 27, len(train_loader) @@ -1125,31 +1090,28 @@ def test_metadataset_fixed_fractional_epochs(self): # Should be 0.7*len(ds1) + 1.5*len(ds2) = 38 + 55 + 27 (floor rounding) assert len(data) == 200, len(data) - # ===== Part 4: Test count for multiple workers ===== + # ===== Part 4: Test count for multiple workers ===== - worker_config = 
WorkerConfig( - rank=0, - world_size=2, - num_workers=2, - seed_offset=42, - ) - - # Train mode dataset - train_loader = get_savable_loader( - get_train_dataset( - fixed_epochs_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, - ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, - ) + worker_config = WorkerConfig( + rank=0, + world_size=2, + num_workers=2, + seed_offset=42, + ) + # Train mode dataset + with get_savable_loader( + get_train_dataset( + fixed_epochs_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ), + ) as train_loader: # TODO: This should be exactly 60. There is a corresponding TODO in the repeat_dataset.py assert len(train_loader) == 58, len(train_loader) @@ -1160,8 +1122,10 @@ def test_metadataset_fixed_fractional_epochs(self): # TODO: This should be exactly 60. 
There is a corresponding TODO in the repeat_dataset.py assert len(data) == 58, len(data) - @patch.object(WatchdogDataset, "_watchdog_trigger") - def test_watchdog_dataset(self, mock_watchdog_trigger): + +def test_watchdog_dataset(dataset_path): + with patch.object(WatchdogDataset, "_watchdog_trigger") as mock_watchdog_trigger: + class TestTaskEncoder(DefaultTaskEncoder): def __init__(self): super().__init__() @@ -1187,7 +1151,7 @@ def encode_sample(self, sample: TextSample) -> TextSample: # Train mode dataset train_dataset = get_train_dataset( - self.mds_path, + dataset_path / "metadataset_v2.yaml", worker_config=worker_config, batch_size=1, shuffle_buffer_size=None, @@ -1195,99 +1159,53 @@ def encode_sample(self, sample: TextSample) -> TextSample: task_encoder=TestTaskEncoder(), ) - train_loader = get_loader( + with get_loader( train_dataset, watchdog_timeout_seconds=3, fail_on_timeout=False, - ) - - for idx, data in enumerate(train_loader): - print(idx, data.text[0]) - if idx > 255: - break + ) as train_loader: + for idx, data in enumerate(train_loader): + print(idx, data.text[0]) + if idx > 255: + break mock_watchdog_trigger.assert_called() - def test_dataset_absolute_nested_subset_fail(self): - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) - ratio_mds_path = self.dataset_path / "metadataset_ratio.yaml" - with open(ratio_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - # Absolute range on outer level should fail - " subset: {range: [50, 55]}", - " blend_epochized:", - " - path: ds1", - " subflavors:", - " source: ds1", - " number: 43", - " - repetitions: 2", - " path: ds2", - " subflavors:", - " source: ds2", - " number: 42", - ] - ) - ) - try: - get_loader( - get_train_dataset( - ratio_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - 
parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, - ) +def test_dataset_absolute_nested_subset_fail(dataset_path): + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + ratio_mds_path = dataset_path / "metadataset_ratio.yaml" + with open(ratio_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + # Absolute range on outer level should fail + " subset: {range: [50, 55]}", + " blend_epochized:", + " - path: ds1", + " subflavors:", + " source: ds1", + " number: 43", + " - repetitions: 2", + " path: ds2", + " subflavors:", + " source: ds2", + " number: 42", + ] ) - assert False, "Should have failed" - except Exception as e: - assert "only allowed for a leaf dataset" in str( - e - ) or "only use absolute subset ranges for a leaf dataset" in str(e), str(e) - return - - def test_dataset_with_subset_end_keyword(self): - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, ) - ratio_mds_path = self.dataset_path / "metadataset_ratio.yaml" - with open(ratio_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - # Absolute range: [50, end] - # I.e. 
corresponds to sample range: [50, 55] (end is not included, so up to 54) - " subset: {range: [50, end]}", - " path: ds1", - " subflavors:", - " source: ds1", - " number: 43", - ] - ) - ) - loader = get_loader( + try: + with get_loader( get_train_dataset( ratio_mds_path, worker_config=worker_config, @@ -1298,58 +1216,104 @@ def test_dataset_with_subset_end_keyword(self): max_samples_per_sequence=None, repeat=False, ) + ): + assert False, "Should have failed" + except Exception as e: + assert "only allowed for a leaf dataset" in str( + e + ) or "only use absolute subset ranges for a leaf dataset" in str(e), str(e) + return + + +def test_dataset_with_subset_end_keyword(dataset_path): + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + ratio_mds_path = dataset_path / "metadataset_ratio.yaml" + with open(ratio_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + # Absolute range: [50, end] + # I.e. 
corresponds to sample range: [50, 55] (end is not included, so up to 54) + " subset: {range: [50, end]}", + " path: ds1", + " subflavors:", + " source: ds1", + " number: 43", + ] + ) ) + with get_loader( + get_train_dataset( + ratio_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ) + ) as loader: all_numbers = [int(s.text[0]) for s in loader] assert all_numbers == [50, 51, 52, 53, 54], "Subset range [50, end] should be [50, 55]" - def test_dataset_with_subset_ratio(self): - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) - ratio_mds_path = self.dataset_path / "metadataset_ratio.yaml" - with open(ratio_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - # 20% of the dataset will be from ds1, 80% from ds2 - # I.e. sample range: [0.2*55, 0.8*55] = [11, 44] - " subset: {range: [20%, 80%]}", - " blend_epochized:", - " - path: ds1", - " subflavors:", - " source: ds1", - " number: 43", - " - repetitions: 2", - " path: ds2", - " subflavors:", - " source: ds2", - " number: 42", - ] - ) - ) - loader = get_loader( - get_train_dataset( - ratio_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, +def test_dataset_with_subset_ratio(dataset_path): + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + ratio_mds_path = dataset_path / "metadataset_ratio.yaml" + with open(ratio_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + # 20% of the dataset will be from ds1, 80% from ds2 + # I.e. 
sample range: [0.2*55, 0.8*55] = [11, 44] + " subset: {range: [20%, 80%]}", + " blend_epochized:", + " - path: ds1", + " subflavors:", + " source: ds1", + " number: 43", + " - repetitions: 2", + " path: ds2", + " subflavors:", + " source: ds2", + " number: 42", + ] ) ) + with get_loader( + get_train_dataset( + ratio_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ) + ) as loader: data = list(enumerate(loader)) assert len(data) == 33 + 33 * 2, len(data) @@ -1362,49 +1326,48 @@ def test_dataset_with_subset_ratio(self): assert all(sample_counts[sample] == 0 for sample in range(144, 155)), sample_counts assert sample_counts.total() == 33 + 33 * 2, sample_counts.total() - # Combine with subset_samples - - ratio2_mds_path = self.dataset_path / "metadataset_ratio2.yaml" - with open(ratio2_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - # take [10, 30] from ds1, [20, 40] from ds2 and then only [20%, 80%] - # I.e. 
sample range: [14, 26], 2 * [124, 136] - " subset: {range: [20%, 80%]}", - " blend_epochized:", - " - path: ds1", - " subset: {range: [10, 30]}", - " subflavors:", - " source: ds1", - " number: 43", - " - repetitions: 2", - " subset: {range: [20, 40]}", - " path: ds2", - " subflavors:", - " source: ds2", - " number: 42", - ] - ) - ) - - loader = get_loader( - get_train_dataset( - ratio2_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, + # Combine with subset_samples + + ratio2_mds_path = dataset_path / "metadataset_ratio2.yaml" + with open(ratio2_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + # take [10, 30] from ds1, [20, 40] from ds2 and then only [20%, 80%] + # I.e. sample range: [14, 26], 2 * [124, 136] + " subset: {range: [20%, 80%]}", + " blend_epochized:", + " - path: ds1", + " subset: {range: [10, 30]}", + " subflavors:", + " source: ds1", + " number: 43", + " - repetitions: 2", + " subset: {range: [20, 40]}", + " path: ds2", + " subflavors:", + " source: ds2", + " number: 42", + ] ) ) + with get_loader( + get_train_dataset( + ratio2_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ) + ) as loader: data = list(enumerate(loader)) assert len(data) == 12 + 12 * 2, len(data) @@ -1417,48 +1380,47 @@ def test_dataset_with_subset_ratio(self): assert all(sample_counts[sample] == 0 for sample in range(136, 155)), sample_counts assert sample_counts.total() == 12 + 12 * 2, sample_counts.total() - # Combine with subset_ratio and subset_samples and nested metadataset - nested_mds_path = self.dataset_path / "metadataset_nested_subset.yaml" - with open(nested_mds_path, "w") as f: - 
f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " subset: {range: [0%, 50%]}", - " blend_epochized:", - " - path: ds3", - # take [30, 50] from ds3, then first 50%, resulting in samples [230, 240] - " subset: {range: [30, 50]}", - " subflavors:", - " source: ds3", - " number: 45", - " - repetitions: 2", - # Inner sample range: [14, 26], 2 * [124, 136], total=12*3=36 - # Applying subset ratio 25%-75%: [17, 23], 2*[127, 133], total=3*6=18 - # Applying outer 50%: [17, 20], 2*[127, 130], total=3*3=9 - # Applying repetition: 2*[17, 20], 4*[127, 130], total=2*9=18 - " subset: {range: [25%, 75%]}", - " path: metadataset_ratio2.yaml", - ] - ) - ) - - loader = get_loader( - get_train_dataset( - nested_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, + # Combine with subset_ratio and subset_samples and nested metadataset + nested_mds_path = dataset_path / "metadataset_nested_subset.yaml" + with open(nested_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " subset: {range: [0%, 50%]}", + " blend_epochized:", + " - path: ds3", + # take [30, 50] from ds3, then first 50%, resulting in samples [230, 240] + " subset: {range: [30, 50]}", + " subflavors:", + " source: ds3", + " number: 45", + " - repetitions: 2", + # Inner sample range: [14, 26], 2 * [124, 136], total=12*3=36 + # Applying subset ratio 25%-75%: [17, 23], 2*[127, 133], total=3*6=18 + # Applying outer 50%: [17, 20], 2*[127, 130], total=3*3=9 + # Applying repetition: 2*[17, 20], 4*[127, 130], total=2*9=18 + " subset: {range: [25%, 75%]}", + " path: metadataset_ratio2.yaml", + ] ) ) + with get_loader( + get_train_dataset( + nested_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + 
shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ) + ) as loader: data = list(enumerate(loader)) assert len(data) == 10 + 9 * 2, len(data) sample_counts = Counter([int(s[1].text[0]) for s in data]) @@ -1473,45 +1435,41 @@ def test_dataset_with_subset_ratio(self): assert all(sample_counts[sample] == 0 for sample in range(240, 255)), sample_counts assert sample_counts.total() == 10 + 9 * 2, sample_counts.total() - def test_s3(self): - # Create a joined dataset configuration - mixed_mds_path = self.dataset_path / "metadataset_mixed.yaml" - with open(mixed_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " path: msc://s3test_metadataset/test/dataset/nested_metadataset_v2.yaml", - ] - ) - ) - with setup_s3_emulator(profile_name="s3test_metadataset") as emu: - # Upload the dataset to the S3 emulator - # EPath(self.dataset_path).copy(EPath("msc://s3/test/dataset")) - emu.add_file(self.dataset_path, "test/dataset") - - train_dataset = get_loader( - get_train_dataset( - mixed_mds_path, - worker_config=WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - ), - batch_size=1, - shuffle_buffer_size=10, - max_samples_per_sequence=None, - virtual_epoch_length=10, - ) +def test_s3(dataset_path): + # Create a joined dataset configuration + mixed_mds_path = dataset_path / "metadataset_mixed.yaml" + with open(mixed_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " path: msc://s3test_metadataset/test/dataset/nested_metadataset_v2.yaml", + ] ) + ) + + with setup_s3_emulator(profile_name="s3test_metadataset") as emu: + # Upload the dataset to the S3 emulator + # EPath(dataset_path).copy(EPath("msc://s3/test/dataset")) + emu.add_file(dataset_path, "test/dataset") + with get_loader( + get_train_dataset( + mixed_mds_path, + 
worker_config=WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + ), + batch_size=1, + shuffle_buffer_size=10, + max_samples_per_sequence=None, + virtual_epoch_length=10, + ) + ) as train_dataset: data = list(enumerate(train_dataset)) assert len(data) == 10, len(data) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_transforms.py b/tests/test_transforms.py index c76bd087..ddca5718 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -6,10 +6,10 @@ import logging import random import sys -import unittest from typing import Any, Tuple import numpy as np +import pytest import torch from PIL import Image from torchvision import transforms as T @@ -18,293 +18,287 @@ from megatron.energon.transforms import CustomTransform, MergeTransform -class TestTransforms(unittest.TestCase): - # Set up the test fixture - def setUp(self): - logging.basicConfig(stream=sys.stderr, level=logging.INFO) +@pytest.fixture(autouse=True) +def setup_logging(): + logging.basicConfig(stream=sys.stderr, level=logging.INFO) - def tearDown(self): - pass - @staticmethod - def reset_rng(seed: int = 42): - """Creates a PIL image with random noise.""" - np.random.seed(seed) - torch.manual_seed(seed) - random.seed(seed) +def reset_rng(seed: int = 42): + """Creates a PIL image with random noise.""" + np.random.seed(seed) + torch.manual_seed(seed) + random.seed(seed) - @staticmethod - def get_test_image(width: int, height: int): - """Creates a PIL image with random noise.""" - arr = np.zeros((width, height, 3), dtype=np.uint8) +def get_test_image(width: int, height: int): + """Creates a PIL image with random noise.""" - # Some colorful borders - arr[0, :, :] = [255, 0, 0] - arr[:, 0, :] = [255, 255, 0] - arr[-1, :, :] = [255, 255, 255] - arr[:, -1, :] = [0, 255, 0] + arr = np.zeros((width, height, 3), dtype=np.uint8) - # A single white pixel - if width > 3 and height > 3: - arr[3, 3, :] = [255, 255, 255] + # Some colorful borders + arr[0, :, :] = [255, 0, 
0] + arr[:, 0, :] = [255, 255, 0] + arr[-1, :, :] = [255, 255, 255] + arr[:, -1, :] = [0, 255, 0] - # And in the middle some noise - if width > 10 and height > 10: - arr[5:-5, 5:-5, :] = np.random.randint(0, 255, (width - 10, height - 10, 3)) + # A single white pixel + if width > 3 and height > 3: + arr[3, 3, :] = [255, 255, 255] - return Image.fromarray(arr) + # And in the middle some noise + if width > 10 and height > 10: + arr[5:-5, 5:-5, :] = np.random.randint(0, 255, (width - 10, height - 10, 3)) - @staticmethod - def get_test_image_soft(width: int, height: int): - """Creates a PIL image smooth content""" + return Image.fromarray(arr) - arr = np.zeros((width, height, 3), dtype=np.uint8) - # Fill red channel the image with a smooth gradient from left to right. - arr[:, :, 0] = np.arange(width)[:, None] / width * 255 - # The same for green from top to bottom: - arr[:, :, 1] = np.arange(height)[None, :] / height * 255 +def get_test_image_soft(width: int, height: int): + """Creates a PIL image smooth content""" - return Image.fromarray(arr) + arr = np.zeros((width, height, 3), dtype=np.uint8) - def _apply_and_compare( - self, testable_transform, img, atol=2, seed=42, msg=None, only_nonblack=False - ): - # Then transform using our method - merge_transform = MergeTransform([testable_transform]) + # Fill red channel the image with a smooth gradient from left to right. 
+ arr[:, :, 0] = np.arange(width)[:, None] / width * 255 + # The same for green from top to bottom: + arr[:, :, 1] = np.arange(height)[None, :] / height * 255 - self.reset_rng(seed=seed) - test_result = merge_transform(img) + return Image.fromarray(arr) - # And also transform using torchvision directly - self.reset_rng(seed=seed) - ref_result = testable_transform(img) - # Then compare the sizes and the images contents - self.assertEqual(test_result.size, ref_result.size) +def _apply_and_compare(testable_transform, img, atol=2, seed=42, msg=None, only_nonblack=False): + # Then transform using our method + merge_transform = MergeTransform([testable_transform]) - # Check that image contents are close - np_test = np.array(test_result) - np_ref = np.array(ref_result) + reset_rng(seed=seed) + test_result = merge_transform(img) - if only_nonblack: - nonblack_mask = (np_test > 0) & (np_ref > 0) - np_test = np_test[nonblack_mask] - np_ref = np_ref[nonblack_mask] + # And also transform using torchvision directly + reset_rng(seed=seed) + ref_result = testable_transform(img) - # The maximum allowed difference between pixel values is 2 (uint8) - self.assertTrue(np.allclose(np_test, np_ref, atol=atol), msg=msg) + # Then compare the sizes and the images contents + assert test_result.size == ref_result.size - def test_resize(self): - """Tests ResizeMapper""" + # Check that image contents are close + np_test = np.array(test_result) + np_ref = np.array(ref_result) - MAX_SIZE = 150 - # These are the different setups we test. 
Each entry is a tuple of - # (source size, resize_kwargs) + if only_nonblack: + nonblack_mask = (np_test > 0) & (np_ref > 0) + np_test = np_test[nonblack_mask] + np_ref = np_ref[nonblack_mask] - size_list = [ # source size (w, h), resize_kwargs - [(100, 100), {"size": (100, 100)}], - [(200, 50), {"size": (100, 100)}], - [(50, 50), {"size": (100, 100)}], - [(500, 500), {"size": (10, 10)}], - [(1, 2), {"size": (1, 3)}], # Scale width by 1.5x - [(50, 100), {"size": 100, "max_size": MAX_SIZE}], # Test max_size - ] + # The maximum allowed difference between pixel values is 2 (uint8) + assert np.allclose(np_test, np_ref, atol=atol), msg - for source_size, resize_kwargs in size_list: - logging.info( - f"Testing Resize with source size {source_size} and resize_kwargs {resize_kwargs}" - ) - # Create a test image of the given source size - img = TestTransforms.get_test_image(*source_size) - transform = T.Resize(**resize_kwargs, interpolation=InterpolationMode.NEAREST) +def test_resize(): + """Tests ResizeMapper""" - self._apply_and_compare( - transform, - img, - msg=f"Resize: source_size={source_size}, resize_kwargs={resize_kwargs}", - ) + MAX_SIZE = 150 + # These are the different setups we test. 
Each entry is a tuple of + # (source size, resize_kwargs) - def test_random_resized_crop(self): - """Tests RandomResizedCropMapper""" + size_list = [ # source size (w, h), resize_kwargs + [(100, 100), {"size": (100, 100)}], + [(200, 50), {"size": (100, 100)}], + [(50, 50), {"size": (100, 100)}], + [(500, 500), {"size": (10, 10)}], + [(1, 2), {"size": (1, 3)}], # Scale width by 1.5x + [(50, 100), {"size": 100, "max_size": MAX_SIZE}], # Test max_size + ] - randcrop = T.RandomResizedCrop( - 90, scale=(0.3, 0.7), ratio=(0.75, 1.3), interpolation=InterpolationMode.BILINEAR + for source_size, resize_kwargs in size_list: + logging.info( + f"Testing Resize with source size {source_size} and resize_kwargs {resize_kwargs}" ) - source_size = (50, 60) - - logging.info(f"Testing RandomResizedCrop with source size {source_size}") # Create a test image of the given source size - img = TestTransforms.get_test_image_soft(*source_size) + img = get_test_image(*source_size) + transform = T.Resize(**resize_kwargs, interpolation=InterpolationMode.NEAREST) - self._apply_and_compare(randcrop, img, msg="RandomResizedCrop") + _apply_and_compare( + transform, + img, + msg=f"Resize: source_size={source_size}, resize_kwargs={resize_kwargs}", + ) - def test_random_flip(self): - source_size = (55, 33) - img = TestTransforms.get_test_image(*source_size) - logging.info("Testing RandomHorizontalFlip 5 times") - for idx in range(5): - randhflip = T.RandomHorizontalFlip(p=0.8) - self._apply_and_compare(randhflip, img, seed=idx, msg="RandomHorizontalFlip") +def test_random_resized_crop(): + """Tests RandomResizedCropMapper""" - logging.info("Testing RandomVerticalFlip 5 times") - for idx in range(5): - randvflip = T.RandomVerticalFlip(p=0.8) - self._apply_and_compare(randvflip, img, seed=idx, msg="RandomVerticalFlip") + randcrop = T.RandomResizedCrop( + 90, scale=(0.3, 0.7), ratio=(0.75, 1.3), interpolation=InterpolationMode.BILINEAR + ) + source_size = (50, 60) - def test_random_rotation(self): - 
source_size = (55, 33) - img = TestTransforms.get_test_image_soft(*source_size) + logging.info(f"Testing RandomResizedCrop with source size {source_size}") - logging.info("Testing RandomRotation without expand") - for idx in range(5): - randrot = T.RandomRotation((-90, 269), interpolation=InterpolationMode.BILINEAR) - self._apply_and_compare( - randrot, - img, - seed=idx, - msg="RandomRotation without expand", - ) + # Create a test image of the given source size + img = get_test_image_soft(*source_size) - logging.info("Testing RandomRotation with expand") - for idx in range(5): - randrot = T.RandomRotation( - (-180, 269), interpolation=InterpolationMode.BILINEAR, expand=True - ) - self._apply_and_compare( - randrot, - img, - seed=idx, - msg="RandomRotation with expand", - ) + _apply_and_compare(randcrop, img, msg="RandomResizedCrop") - def test_random_crop(self): - source_size = (155, 120) - img = TestTransforms.get_test_image(*source_size) - - size_list = [ # crop size (w, h) - (155, 120), # Same size - (100, 50), - 3, # Single int as size - 120, - (155, 8), # One dimension same size - ] - - logging.info("Testing RandomCrop") - for idx, size in enumerate(size_list): - randcrop = T.RandomCrop(size) - self._apply_and_compare( - randcrop, - img, - seed=idx, - msg=f"RandomCrop: crop size={size}", - ) - # Test `pad_if_needed` (Crop size larger than image size) - randcrop = T.RandomCrop((500, 500), pad_if_needed=True) - self._apply_and_compare(randcrop, img) +def test_random_flip(): + source_size = (55, 33) + img = get_test_image(*source_size) - def test_random_perspective(self): - source_size = (128, 133) - img = TestTransforms.get_test_image_soft(*source_size) + logging.info("Testing RandomHorizontalFlip 5 times") + for idx in range(5): + randhflip = T.RandomHorizontalFlip(p=0.8) + _apply_and_compare(randhflip, img, seed=idx, msg="RandomHorizontalFlip") - logging.info("Testing RandomPerspective") - for idx in range(5): - randpersp = 
T.RandomPerspective(interpolation=InterpolationMode.BILINEAR) - self._apply_and_compare( - randpersp, - img, - seed=idx, - msg=f"RandomPerspective: source_size={source_size}", - only_nonblack=True, # Sometimes one pixel is off - ) + logging.info("Testing RandomVerticalFlip 5 times") + for idx in range(5): + randvflip = T.RandomVerticalFlip(p=0.8) + _apply_and_compare(randvflip, img, seed=idx, msg="RandomVerticalFlip") - def test_center_crop(self): - source_size_list = [ # source size (w, h) - (155, 120), - (154, 119), - ] - crop_size_list = [ # crop size (w, h) - (155, 120), # Same size - (100, 50), - 3, # Single int as size - 120, - (200, 50), # Large than image in x direction - (50, 200), # Large than image in y direction - (200, 200), # Large than image in both directions - ] +def test_random_rotation(): + source_size = (55, 33) + img = get_test_image_soft(*source_size) - logging.info("Testing CenterCrop") + logging.info("Testing RandomRotation without expand") + for idx in range(5): + randrot = T.RandomRotation((-90, 269), interpolation=InterpolationMode.BILINEAR) + _apply_and_compare( + randrot, + img, + seed=idx, + msg="RandomRotation without expand", + ) - for source_size in source_size_list: - img = TestTransforms.get_test_image(*source_size) + logging.info("Testing RandomRotation with expand") + for idx in range(5): + randrot = T.RandomRotation( + (-180, 269), interpolation=InterpolationMode.BILINEAR, expand=True + ) + _apply_and_compare( + randrot, + img, + seed=idx, + msg="RandomRotation with expand", + ) - for idx, crop_size in enumerate(crop_size_list): - centcrop = T.CenterCrop(crop_size) - self._apply_and_compare( - centcrop, - img, - seed=idx, - msg=f"CenterCrop: source_size={source_size}, crop_size={crop_size}", - ) - def test_custom(self): - """Tests if a custom transform works""" +def test_random_crop(): + source_size = (155, 120) + img = get_test_image(*source_size) + + size_list = [ # crop size (w, h) + (155, 120), # Same size + (100, 50), + 3, 
# Single int as size + 120, + (155, 8), # One dimension same size + ] + + logging.info("Testing RandomCrop") + for idx, size in enumerate(size_list): + randcrop = T.RandomCrop(size) + _apply_and_compare( + randcrop, + img, + seed=idx, + msg=f"RandomCrop: crop size={size}", + ) - source_size = (128, 133) + # Test `pad_if_needed` (Crop size larger than image size) + randcrop = T.RandomCrop((500, 500), pad_if_needed=True) + _apply_and_compare(randcrop, img) - class FixedTranslate(CustomTransform): - """Translates the image by 5 pixels in both x and y direction""" - def __init__(self): - pass +def test_random_perspective(): + source_size = (128, 133) + img = get_test_image_soft(*source_size) - def apply_transform( - self, matrix: np.ndarray, dst_size: np.ndarray - ) -> Tuple[Any, Any, Any]: - matrix = self.translate(5, 5) @ matrix - return matrix, dst_size, (self.__class__.__name__, (5, 5)) + logging.info("Testing RandomPerspective") + for idx in range(5): + randpersp = T.RandomPerspective(interpolation=InterpolationMode.BILINEAR) + _apply_and_compare( + randpersp, + img, + seed=idx, + msg=f"RandomPerspective: source_size={source_size}", + only_nonblack=True, # Sometimes one pixel is off + ) - img = TestTransforms.get_test_image(*source_size) - merge_transform = MergeTransform([FixedTranslate()]) - test_result = merge_transform(img) +def test_center_crop(): + source_size_list = [ # source size (w, h) + (155, 120), + (154, 119), + ] - reference_img = Image.new(img.mode, img.size, (0, 0, 0)) - reference_img.paste(img, (5, 5)) + crop_size_list = [ # crop size (w, h) + (155, 120), # Same size + (100, 50), + 3, # Single int as size + 120, + (200, 50), # Large than image in x direction + (50, 200), # Large than image in y direction + (200, 200), # Large than image in both directions + ] - self.assertTrue( - np.allclose(np.array(test_result), np.array(reference_img), atol=1), - msg="FixedTranslate", - ) + logging.info("Testing CenterCrop") + + for source_size in 
source_size_list: + img = get_test_image(*source_size) - def test_merge(self): - """Tests if two merged transforms yield the same result. - Merging RandomCrop and RandomPerspective.""" + for idx, crop_size in enumerate(crop_size_list): + centcrop = T.CenterCrop(crop_size) + _apply_and_compare( + centcrop, + img, + seed=idx, + msg=f"CenterCrop: source_size={source_size}, crop_size={crop_size}", + ) - source_size = (128, 133) - img = TestTransforms.get_test_image_soft(*source_size) - randcrop = T.RandomCrop((70, 70)) - randrot = T.RandomRotation((45, 269), interpolation=InterpolationMode.BILINEAR) +def test_custom(): + """Tests if a custom transform works""" - merge_transform = MergeTransform([randrot, randcrop]) - self.reset_rng(1) - test_result = merge_transform(img) + source_size = (128, 133) - self.reset_rng(1) - ref_result = randcrop(randrot(img)) + class FixedTranslate(CustomTransform): + """Translates the image by 5 pixels in both x and y direction""" - self.assertTrue( - np.allclose(np.array(test_result), np.array(ref_result), atol=1), - msg="MergeTransform of RandomRotation and RandomCrop", - ) + def __init__(self): + pass + + def apply_transform(self, matrix: np.ndarray, dst_size: np.ndarray) -> Tuple[Any, Any, Any]: + matrix = self.translate(5, 5) @ matrix + return matrix, dst_size, (self.__class__.__name__, (5, 5)) + + img = get_test_image(*source_size) + + merge_transform = MergeTransform([FixedTranslate()]) + test_result = merge_transform(img) + + reference_img = Image.new(img.mode, img.size, (0, 0, 0)) + reference_img.paste(img, (5, 5)) + + assert np.allclose(np.array(test_result), np.array(reference_img), atol=1), "FixedTranslate" + + +def test_merge(): + """Tests if two merged transforms yield the same result. 
+ Merging RandomCrop and RandomPerspective.""" + + source_size = (128, 133) + img = get_test_image_soft(*source_size) + + randcrop = T.RandomCrop((70, 70)) + randrot = T.RandomRotation((45, 269), interpolation=InterpolationMode.BILINEAR) + + merge_transform = MergeTransform([randrot, randcrop]) + reset_rng(1) + test_result = merge_transform(img) + reset_rng(1) + ref_result = randcrop(randrot(img)) -if __name__ == "__main__": - unittest.main() + assert np.allclose(np.array(test_result), np.array(ref_result), atol=1), ( + "MergeTransform of RandomRotation and RandomCrop" + ) diff --git a/tests/test_typed_converter.py b/tests/test_typed_converter.py new file mode 100644 index 00000000..74f2ee73 --- /dev/null +++ b/tests/test_typed_converter.py @@ -0,0 +1,465 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import dataclasses +from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Set, Tuple, Type, Union + +import pytest + +from megatron.energon import typed_converter + + +class MyNamedTuple(NamedTuple): + x: int + y: str + + +@dataclasses.dataclass +class MyDataClass: + a: int + b: str = "default" + + +class ComplexNamedTuple(NamedTuple): + id: int + name: str + active: bool + + +@dataclasses.dataclass +class NestedDataClass: + value: float + metadata: Dict[str, Any] + tags: List[str] + + +@dataclasses.dataclass +class ExtendedNestedDataClass(NestedDataClass): + extra: str + + +def my_call(a: int, b: str) -> NestedDataClass: + return NestedDataClass(value=a, metadata={"b": b}, tags=[b]) + + +def my_call_extended(a: int, b: str) -> ExtendedNestedDataClass: + return ExtendedNestedDataClass(value=a, metadata={"b": b}, tags=[b], extra=f"extra_{b}") + + +@dataclasses.dataclass +class ComprehensiveDataClass: + # Primitive types + string_field: str + int_field: int + float_field: float + bool_field: bool + + # Optional types + optional_string: Optional[str] = None + optional_int: Optional[int] = None + + # Union 
types + union_field: Union[str, int] = "default" + union_optional: Union[str, None] = None + + # List types + string_list: List[str] = dataclasses.field(default_factory=list) + int_list: List[int] = dataclasses.field(default_factory=list) + nested_list: List[List[str]] = dataclasses.field(default_factory=list) + + # Dict types + string_dict: Dict[str, str] = dataclasses.field(default_factory=dict) + mixed_dict: Dict[str, Any] = dataclasses.field(default_factory=dict) + nested_dict: Dict[str, Dict[str, int]] = dataclasses.field(default_factory=dict) + + # Tuple types + fixed_tuple: Tuple[str, int, bool] = ("default", 0, False) + variable_tuple: Tuple[str, ...] = ("single",) + + # Set types + set_field: Set[int] = dataclasses.field(default_factory=set) + + # Literal types + status: Literal["active", "inactive", "pending"] = "pending" + priority: Literal[1, 2, 3, 4, 5] = 3 + + # Nested dataclass + nested: Optional[NestedDataClass] = None + + # Referencing a type + type_ref: Type[NestedDataClass] = NestedDataClass + + # Referencing a function + function_ref: Callable[[int, str], NestedDataClass] = my_call + + # NamedTuple + named_tuple: Optional[ComplexNamedTuple] = None + + # Any type + any_field: Any = None + + +def test_raw_to_typed_namedtuple(): + parser = typed_converter.JsonParser() + raw = {"x": 42, "y": "foo"} + result = parser.raw_to_typed(raw, MyNamedTuple) + assert isinstance(result, MyNamedTuple) + assert result.x == 42 + assert result.y == "foo" + + +def test_raw_to_typed_dataclass(): + parser = typed_converter.JsonParser() + raw = {"a": 7, "b": "bar"} + result = parser.raw_to_typed(raw, MyDataClass) + assert isinstance(result, MyDataClass) + assert result.a == 7 + assert result.b == "bar" + + +def test_raw_to_typed_dataclass_default(): + parser = typed_converter.JsonParser() + raw = {"a": 5} + result = parser.raw_to_typed(raw, MyDataClass) + assert result.a == 5 + assert result.b == "default" + + +def test_raw_to_typed_union(): + parser = 
typed_converter.JsonParser() + assert parser.raw_to_typed(123, Union[int, str]) == 123 + assert parser.raw_to_typed("abc", Union[int, str]) == "abc" + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(1.5, Union[int, str]) + + +def test_raw_to_typed_optional(): + parser = typed_converter.JsonParser() + assert parser.raw_to_typed(None, Optional[int]) is None + assert parser.raw_to_typed(10, Optional[int]) == 10 + + +def test_raw_to_typed_list(): + parser = typed_converter.JsonParser() + raw = [1, 2, 3] + result = parser.raw_to_typed(raw, List[int]) + assert result == [1, 2, 3] + + +def test_raw_to_typed_dict(): + parser = typed_converter.JsonParser() + raw = {"foo": 1, "bar": 2} + result = parser.raw_to_typed(raw, Dict[str, int]) + assert result == {"foo": 1, "bar": 2} + + +def test_raw_to_typed_set(): + parser = typed_converter.JsonParser() + raw = [1, 2, 3] + result = parser.raw_to_typed(raw, Set[int]) + assert result == {1, 2, 3} + + +def test_raw_to_typed_literal(): + parser = typed_converter.JsonParser() + assert parser.raw_to_typed("yes", Literal["yes", "no"]) == "yes" + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed("maybe", Literal["yes", "no"]) + + +def test_to_json_object_namedtuple(): + obj = MyNamedTuple(x=1, y="abc") + json_obj = typed_converter.to_json_object(obj) + assert json_obj == {"x": 1, "y": "abc"} + + +def test_to_json_object_dataclass(): + obj = MyDataClass(a=2, b="xyz") + json_obj = typed_converter.to_json_object(obj) + assert json_obj == {"a": 2, "b": "xyz"} + + +def test_to_json_object_list(): + obj = [1, 2, 3] + json_obj = typed_converter.to_json_object(obj) + assert json_obj == [1, 2, 3] + + +def test_to_json_object_dict(): + obj = {"foo": 1, "bar": 2} + json_obj = typed_converter.to_json_object(obj) + assert json_obj == {"foo": 1, "bar": 2} + + +def test_isinstance_deep(): + assert typed_converter._isinstance_deep(1, int) + assert not typed_converter._isinstance_deep(1, str) + assert not 
typed_converter._isinstance_deep(1, float) + assert not typed_converter._isinstance_deep("1", int) + assert not typed_converter._isinstance_deep("1", float) + assert typed_converter._isinstance_deep([1, 2], List[int]) + assert not typed_converter._isinstance_deep([1, "a"], List[int]) + assert typed_converter._isinstance_deep({"a": 1}, Dict[str, int]) + assert not typed_converter._isinstance_deep({"a": "b"}, Dict[str, int]) + + +def test_missing_value_error(): + parser = typed_converter.JsonParser() + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(typed_converter._missing_value, int) + + +def test_strict_extra_keys(): + parser = typed_converter.JsonParser(strict=True) + raw = {"a": 1, "b": "foo", "extra": 123} + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(raw, MyDataClass) + + +def test_non_strict_extra_keys(): + parser = typed_converter.JsonParser(strict=False) + raw = {"a": 1, "b": "foo", "extra": 123} + result = parser.raw_to_typed(raw, MyDataClass) + assert result.a == 1 + assert result.b == "foo" + + +def test_comprehensive_dataclass(): + """Test a complex dataclass with all supported types.""" + parser = typed_converter.JsonParser() + + # Create comprehensive raw data + raw_data = { + "string_field": "test_string", + "int_field": 42, + "float_field": 3.14159, + "bool_field": True, + "optional_string": "optional_value", + "optional_int": 100, + "union_field": 123, # Using int instead of string + "union_optional": "union_string", + "string_list": ["item1", "item2", "item3"], + "int_list": [1, 2, 3, 4, 5], + "nested_list": [["a", "b"], ["c", "d"]], + "string_dict": {"key1": "value1", "key2": "value2"}, + "mixed_dict": {"str_key": "string", "int_key": 42, "bool_key": True}, + "nested_dict": {"outer1": {"inner1": 1, "inner2": 2}, "outer2": {"inner3": 3}}, + "fixed_tuple": ["tuple_string", 99, True], + "variable_tuple": ["var1", "var2", "var3"], + "set_field": [1, 2, 3], + "status": "active", + "priority": 5, 
+ "nested": { + "value": 2.71828, + "metadata": {"nested_key": "nested_value", "count": 42}, + "tags": ["tag1", "tag2"], + }, + "named_tuple": {"id": 123, "name": "test_name", "active": False}, + "any_field": {"arbitrary": "data", "number": 999}, + } + + # Convert raw data to typed object + result = parser.raw_to_typed(raw_data, ComprehensiveDataClass) + + # Verify all fields + assert result.string_field == "test_string" + assert result.int_field == 42 + assert result.float_field == 3.14159 + assert result.bool_field is True + assert result.optional_string == "optional_value" + assert result.optional_int == 100 + assert result.union_field == 123 + assert result.union_optional == "union_string" + assert result.string_list == ["item1", "item2", "item3"] + assert result.int_list == [1, 2, 3, 4, 5] + assert result.nested_list == [["a", "b"], ["c", "d"]] + assert result.string_dict == {"key1": "value1", "key2": "value2"} + assert result.mixed_dict == {"str_key": "string", "int_key": 42, "bool_key": True} + assert result.nested_dict == {"outer1": {"inner1": 1, "inner2": 2}, "outer2": {"inner3": 3}} + assert result.fixed_tuple == ("tuple_string", 99, True) + assert result.variable_tuple == ("var1", "var2", "var3") + assert result.set_field == {1, 2, 3} + assert result.status == "active" + assert result.priority == 5 + + # Verify nested dataclass + assert isinstance(result.nested, NestedDataClass) + assert result.nested.value == 2.71828 + assert result.nested.metadata == {"nested_key": "nested_value", "count": 42} + assert result.nested.tags == ["tag1", "tag2"] + + # Verify NamedTuple + assert isinstance(result.named_tuple, ComplexNamedTuple) + assert result.named_tuple.id == 123 + assert result.named_tuple.name == "test_name" + assert result.named_tuple.active is False + + # Verify Any field + assert result.any_field == {"arbitrary": "data", "number": 999} + + # Test conversion back to JSON + json_obj = typed_converter.to_json_object(result) + + # Verify JSON conversion 
preserves data + assert json_obj["string_field"] == "test_string" + assert json_obj["int_field"] == 42 + assert json_obj["float_field"] == 3.14159 + assert json_obj["bool_field"] is True + assert json_obj["optional_string"] == "optional_value" + assert json_obj["optional_int"] == 100 + assert json_obj["union_field"] == 123 + assert json_obj["union_optional"] == "union_string" + assert json_obj["string_list"] == ["item1", "item2", "item3"] + assert json_obj["int_list"] == [1, 2, 3, 4, 5] + assert json_obj["nested_list"] == [["a", "b"], ["c", "d"]] + assert json_obj["string_dict"] == {"key1": "value1", "key2": "value2"} + assert json_obj["mixed_dict"] == {"str_key": "string", "int_key": 42, "bool_key": True} + assert json_obj["nested_dict"] == { + "outer1": {"inner1": 1, "inner2": 2}, + "outer2": {"inner3": 3}, + } + assert json_obj["fixed_tuple"] == ["tuple_string", 99, True] + assert json_obj["variable_tuple"] == ["var1", "var2", "var3"] + assert json_obj["set_field"] == [1, 2, 3] + assert json_obj["status"] == "active" + assert json_obj["priority"] == 5 + assert json_obj["nested"]["value"] == 2.71828 + assert json_obj["nested"]["metadata"] == {"nested_key": "nested_value", "count": 42} + assert json_obj["nested"]["tags"] == ["tag1", "tag2"] + assert json_obj["named_tuple"]["id"] == 123 + assert json_obj["named_tuple"]["name"] == "test_name" + assert json_obj["named_tuple"]["active"] is False + assert json_obj["any_field"] == {"arbitrary": "data", "number": 999} + assert json_obj["function_ref"]["__module__"] == my_call.__module__ + assert json_obj["function_ref"]["__function__"] == my_call.__name__ + assert json_obj["type_ref"]["__module__"] == NestedDataClass.__module__ + assert json_obj["type_ref"]["__class__"] == NestedDataClass.__name__ + + +def test_comprehensive_dataclass_with_defaults(): + """Test comprehensive dataclass with minimal data using defaults.""" + parser = typed_converter.JsonParser() + + # Minimal raw data - only required fields + raw_data = 
{"string_field": "minimal", "int_field": 1, "float_field": 1.0, "bool_field": False} + + result = parser.raw_to_typed(raw_data, ComprehensiveDataClass) + + # Verify required fields + assert result.string_field == "minimal" + assert result.int_field == 1 + assert result.float_field == 1.0 + assert result.bool_field is False + + # Verify defaults + assert result.optional_string is None + assert result.optional_int is None + assert result.union_field == "default" + assert result.union_optional is None + assert result.string_list == [] + assert result.int_list == [] + assert result.nested_list == [] + assert result.string_dict == {} + assert result.mixed_dict == {} + assert result.nested_dict == {} + assert result.fixed_tuple == ("default", 0, False) + assert result.variable_tuple == ("single",) + assert result.set_field == set() + assert result.status == "pending" + assert result.priority == 3 + assert result.nested is None + assert result.named_tuple is None + assert result.any_field is None + + +def test_comprehensive_dataclass_error_cases(): + """Test error cases for comprehensive dataclass.""" + parser = typed_converter.JsonParser() + + # Test invalid literal value + raw_data = { + "string_field": "test", + "int_field": 1, + "float_field": 1.0, + "bool_field": False, + "status": "invalid_status", # Should fail + } + + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(raw_data, ComprehensiveDataClass) + + # Test invalid union type + raw_data = { + "string_field": "test", + "int_field": 1, + "float_field": 1.0, + "bool_field": False, + "union_field": 1.5, # Should fail - not str or int + } + + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(raw_data, ComprehensiveDataClass) + + # Test invalid list element type + raw_data = { + "string_field": "test", + "int_field": 1, + "float_field": 1.0, + "bool_field": False, + "string_list": ["valid", 123, "also_valid"], # Should fail - int in string list + } + + with 
pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(raw_data, ComprehensiveDataClass) + + # Test invalid tuple length + raw_data = { + "string_field": "test", + "int_field": 1, + "float_field": 1.0, + "bool_field": False, + "fixed_tuple": ["only", "two"], # Should fail - needs 3 elements + } + + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(raw_data, ComprehensiveDataClass) + + +def test_comprehensive_dataclass_strict_mode(): + """Test comprehensive dataclass in strict mode with extra keys.""" + parser = typed_converter.JsonParser(strict=True) + + raw_data = { + "string_field": "test", + "int_field": 1, + "float_field": 1.0, + "bool_field": False, + "extra_field": "should_fail", # Should fail in strict mode + "nested": { + "__module__": my_call_extended.__module__, + "__function__": my_call_extended.__name__, + "a": 42, + "b": "Hello", + }, + "function_ref": { + "__module__": my_call_extended.__module__, + "__function__": my_call_extended.__name__, + }, + "type_ref": { + "__module__": ExtendedNestedDataClass.__module__, + "__class__": ExtendedNestedDataClass.__name__, + }, + } + + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(raw_data, ComprehensiveDataClass) + + # Test non-strict mode allows extra keys + parser_non_strict = typed_converter.JsonParser(strict=False) + result = parser_non_strict.raw_to_typed(raw_data, ComprehensiveDataClass) + assert result.string_field == "test" + assert result.int_field == 1 + assert result.nested == ExtendedNestedDataClass( + value=42, metadata={"b": "Hello"}, tags=["Hello"], extra="extra_Hello" + ) + assert result.function_ref == my_call_extended + assert result.type_ref == ExtendedNestedDataClass diff --git a/tests/test_weakref.py b/tests/test_weakref.py index 7d104b86..69e4aadb 100644 --- a/tests/test_weakref.py +++ b/tests/test_weakref.py @@ -1,7 +1,6 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause import multiprocessing -import unittest from dataclasses import dataclass from megatron.energon.fork_hook import ( @@ -14,187 +13,170 @@ ) -class TestWeakref(unittest.TestCase): - def test_weakcallbacks(self): - # Just test the internal WeakCallbacks class. - a_called = 0 - fn_called = 0 +def test_weakcallbacks(): + # Just test the internal WeakCallbacks class. + a_called = 0 + fn_called = 0 - class A: - def method(self): - nonlocal a_called - a_called += 1 + class A: + def method(self): + nonlocal a_called + a_called += 1 - def fn(): - nonlocal fn_called - fn_called += 1 + def fn(): + nonlocal fn_called + fn_called += 1 - a = A() + a = A() - registry = WeakCallbacks() + registry = WeakCallbacks() - registry.add_hook(a.method) - registry.add_hook(fn) - registry.add_hook(a.method) + registry.add_hook(a.method) + registry.add_hook(fn) + registry.add_hook(a.method) - registry.run() + registry.run() - assert a_called == 1, a_called - assert fn_called == 1, fn_called + assert a_called == 1, a_called + assert fn_called == 1, fn_called - assert len(registry._hooks) == 2, len(registry._hooks) + assert len(registry._hooks) == 2, len(registry._hooks) - del a + del a - assert len(registry._hooks) == 1, len(registry._hooks) + assert len(registry._hooks) == 1, len(registry._hooks) - registry.run() + registry.run() - assert a_called == 1, a_called - assert fn_called == 2, fn_called + assert a_called == 1, a_called + assert fn_called == 2, fn_called - del fn + del fn - assert len(registry._hooks) == 0, len(registry._hooks) + assert len(registry._hooks) == 0, len(registry._hooks) - registry.run() + registry.run() - assert a_called == 1, a_called - assert fn_called == 2, fn_called + assert a_called == 1, a_called + assert fn_called == 2, fn_called - assert len(registry._hooks) == 0, len(registry._hooks) + assert len(registry._hooks) == 0, len(registry._hooks) - def test_fork_weakref(self): - # Verify that the fork hooks are called correctly, 
and that gc works correctly. - _a_before_fork_called = 0 - _a_after_in_child_fork_called = 0 - _a_after_in_parent_fork_called = 0 +def test_fork_weakref(): + # Verify that the fork hooks are called correctly, and that gc works correctly. - class A(ForkMixin): - def __before_fork__(self): - nonlocal _a_before_fork_called - _a_before_fork_called += 1 + _a_before_fork_called = 0 + _a_after_in_child_fork_called = 0 + _a_after_in_parent_fork_called = 0 - def __after_in_child_fork__(self): - nonlocal _a_after_in_child_fork_called - _a_after_in_child_fork_called += 1 + class A(ForkMixin): + def __before_fork__(self): + nonlocal _a_before_fork_called + _a_before_fork_called += 1 - def __after_in_parent_fork__(self): - nonlocal _a_after_in_parent_fork_called - _a_after_in_parent_fork_called += 1 + def __after_in_child_fork__(self): + nonlocal _a_after_in_child_fork_called + _a_after_in_child_fork_called += 1 - _b_before_fork_called = 0 - _b_after_in_child_fork_called = 0 - _b_after_in_parent_fork_called = 0 + def __after_in_parent_fork__(self): + nonlocal _a_after_in_parent_fork_called + _a_after_in_parent_fork_called += 1 - @dataclass - class B(DataclassForkMixin): - def __before_fork__(self): - nonlocal _b_before_fork_called - _b_before_fork_called += 1 + _b_before_fork_called = 0 + _b_after_in_child_fork_called = 0 + _b_after_in_parent_fork_called = 0 - def __after_in_child_fork__(self): - nonlocal _b_after_in_child_fork_called - _b_after_in_child_fork_called += 1 + @dataclass + class B(DataclassForkMixin): + def __before_fork__(self): + nonlocal _b_before_fork_called + _b_before_fork_called += 1 - def __after_in_parent_fork__(self): - nonlocal _b_after_in_parent_fork_called - _b_after_in_parent_fork_called += 1 + def __after_in_child_fork__(self): + nonlocal _b_after_in_child_fork_called + _b_after_in_child_fork_called += 1 - a = A() - b = B() + def __after_in_parent_fork__(self): + nonlocal _b_after_in_parent_fork_called + _b_after_in_parent_fork_called += 1 - 
_before_fork_called = 0 - _after_in_child_fork_called = 0 - _after_in_parent_fork_called = 0 + a = A() + b = B() - def before_fork(): - nonlocal _before_fork_called - _before_fork_called += 1 + _before_fork_called = 0 + _after_in_child_fork_called = 0 + _after_in_parent_fork_called = 0 - def after_in_child_fork(): - nonlocal _after_in_child_fork_called - _after_in_child_fork_called += 1 + def before_fork(): + nonlocal _before_fork_called + _before_fork_called += 1 - def after_in_parent_fork(): - nonlocal _after_in_parent_fork_called - _after_in_parent_fork_called += 1 + def after_in_child_fork(): + nonlocal _after_in_child_fork_called + _after_in_child_fork_called += 1 - before_fork_hook(before_fork) - after_in_child_fork_hook(after_in_child_fork) - after_in_parent_fork_hook(after_in_parent_fork) + def after_in_parent_fork(): + nonlocal _after_in_parent_fork_called + _after_in_parent_fork_called += 1 - multiprocessing.set_start_method("fork", force=True) + before_fork_hook(before_fork) + after_in_child_fork_hook(after_in_child_fork) + after_in_parent_fork_hook(after_in_parent_fork) - def process_verify_fork_hooks_1(): - # Verify in the process that the fork hooks were called - assert _before_fork_called == 1, _before_fork_called - assert _after_in_child_fork_called == 1, _after_in_child_fork_called - # This was not called in the child process - assert _after_in_parent_fork_called == 0, _after_in_parent_fork_called - - assert _a_before_fork_called == 1, _a_before_fork_called - assert _a_after_in_child_fork_called == 1, _a_after_in_child_fork_called - assert _a_after_in_parent_fork_called == 0, _a_after_in_parent_fork_called - - assert _b_before_fork_called == 1, _b_before_fork_called - assert _b_after_in_child_fork_called == 1, _b_after_in_child_fork_called - assert _b_after_in_parent_fork_called == 0, _b_after_in_parent_fork_called - - p1 = multiprocessing.Process(target=process_verify_fork_hooks_1) - p1.start() - p1.join() - assert p1.exitcode == 0, p1.exitcode + 
multiprocessing.set_start_method("fork", force=True) + def process_verify_fork_hooks_1(): + # Verify in the process that the fork hooks were called assert _before_fork_called == 1, _before_fork_called - assert _after_in_child_fork_called == 0, _after_in_child_fork_called - assert _after_in_parent_fork_called == 1, _after_in_parent_fork_called + assert _after_in_child_fork_called == 1, _after_in_child_fork_called + # This was not called in the child process + assert _after_in_parent_fork_called == 0, _after_in_parent_fork_called assert _a_before_fork_called == 1, _a_before_fork_called - assert _a_after_in_child_fork_called == 0, _a_after_in_child_fork_called - assert _a_after_in_parent_fork_called == 1, _a_after_in_parent_fork_called + assert _a_after_in_child_fork_called == 1, _a_after_in_child_fork_called + assert _a_after_in_parent_fork_called == 0, _a_after_in_parent_fork_called assert _b_before_fork_called == 1, _b_before_fork_called - assert _b_after_in_child_fork_called == 0, _b_after_in_child_fork_called - assert _b_after_in_parent_fork_called == 1, _b_after_in_parent_fork_called + assert _b_after_in_child_fork_called == 1, _b_after_in_child_fork_called + assert _b_after_in_parent_fork_called == 0, _b_after_in_parent_fork_called - _a_before_fork_called = 0 - _a_after_in_child_fork_called = 0 - _a_after_in_parent_fork_called = 0 + p1 = multiprocessing.Process(target=process_verify_fork_hooks_1) + p1.start() + p1.join() + assert p1.exitcode == 0, p1.exitcode - _b_before_fork_called = 0 - _b_after_in_child_fork_called = 0 - _b_after_in_parent_fork_called = 0 + assert _before_fork_called == 1, _before_fork_called + assert _after_in_child_fork_called == 0, _after_in_child_fork_called + assert _after_in_parent_fork_called == 1, _after_in_parent_fork_called - _before_fork_called = 0 - _after_in_child_fork_called = 0 - _after_in_parent_fork_called = 0 + assert _a_before_fork_called == 1, _a_before_fork_called + assert _a_after_in_child_fork_called == 0, 
_a_after_in_child_fork_called + assert _a_after_in_parent_fork_called == 1, _a_after_in_parent_fork_called - del a - del b - del before_fork - del after_in_child_fork - del after_in_parent_fork + assert _b_before_fork_called == 1, _b_before_fork_called + assert _b_after_in_child_fork_called == 0, _b_after_in_child_fork_called + assert _b_after_in_parent_fork_called == 1, _b_after_in_parent_fork_called - def process_verify_fork_hooks_2(): - assert _before_fork_called == 0, _before_fork_called - assert _after_in_child_fork_called == 0, _after_in_child_fork_called - assert _after_in_parent_fork_called == 0, _after_in_parent_fork_called + _a_before_fork_called = 0 + _a_after_in_child_fork_called = 0 + _a_after_in_parent_fork_called = 0 - assert _a_before_fork_called == 0, _a_before_fork_called - assert _a_after_in_child_fork_called == 0, _a_after_in_child_fork_called - assert _a_after_in_parent_fork_called == 0, _a_after_in_parent_fork_called + _b_before_fork_called = 0 + _b_after_in_child_fork_called = 0 + _b_after_in_parent_fork_called = 0 - assert _b_before_fork_called == 0, _b_before_fork_called - assert _b_after_in_child_fork_called == 0, _b_after_in_child_fork_called - assert _b_after_in_parent_fork_called == 0, _b_after_in_parent_fork_called + _before_fork_called = 0 + _after_in_child_fork_called = 0 + _after_in_parent_fork_called = 0 - p2 = multiprocessing.Process(target=process_verify_fork_hooks_2) - p2.start() - p2.join() - assert p2.exitcode == 0, p2.exitcode + del a + del b + del before_fork + del after_in_child_fork + del after_in_parent_fork + def process_verify_fork_hooks_2(): assert _before_fork_called == 0, _before_fork_called assert _after_in_child_fork_called == 0, _after_in_child_fork_called assert _after_in_parent_fork_called == 0, _after_in_parent_fork_called @@ -206,3 +188,20 @@ def process_verify_fork_hooks_2(): assert _b_before_fork_called == 0, _b_before_fork_called assert _b_after_in_child_fork_called == 0, _b_after_in_child_fork_called 
assert _b_after_in_parent_fork_called == 0, _b_after_in_parent_fork_called + + p2 = multiprocessing.Process(target=process_verify_fork_hooks_2) + p2.start() + p2.join() + assert p2.exitcode == 0, p2.exitcode + + assert _before_fork_called == 0, _before_fork_called + assert _after_in_child_fork_called == 0, _after_in_child_fork_called + assert _after_in_parent_fork_called == 0, _after_in_parent_fork_called + + assert _a_before_fork_called == 0, _a_before_fork_called + assert _a_after_in_child_fork_called == 0, _a_after_in_child_fork_called + assert _a_after_in_parent_fork_called == 0, _a_after_in_parent_fork_called + + assert _b_before_fork_called == 0, _b_before_fork_called + assert _b_after_in_child_fork_called == 0, _b_after_in_child_fork_called + assert _b_after_in_parent_fork_called == 0, _b_after_in_parent_fork_called diff --git a/uv.lock b/uv.lock index e5b6238f..a6a0ca80 100644 --- a/uv.lock +++ b/uv.lock @@ -144,7 +144,7 @@ wheels = [ [[package]] name = "aistore" -version = "1.17.0" +version = "1.18.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "braceexpand" }, @@ -161,9 +161,9 @@ dependencies = [ { name = "urllib3" }, { name = "xxhash" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0a/06/de08fef1a17ed1245603774edf24235a6ba1f6df46d5a53f8fe09b1cd69c/aistore-1.17.0.tar.gz", hash = "sha256:619bafd48ca179aa055f1043594062b63e9797f47fcf150946297790858fc96b", size = 123210, upload-time = "2025-10-16T19:04:52.493Z" } +sdist = { url = "https://files.pythonhosted.org/packages/99/fd/53269da4e1d48c07826ff955989eb69be58c74f5e471eb396e9ebf5110a1/aistore-1.18.0.tar.gz", hash = "sha256:281244d4321b4059f0ad19737fa69fae5f3ca40db2392e5da8916131c74b66b6", size = 126859, upload-time = "2025-12-05T23:00:31.27Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0c/96/ffa0fd5af63cb28c0e7ea0f01709aaf0a19154726a6f608cb52f7d99f665/aistore-1.17.0-py3-none-any.whl", hash = 
"sha256:50a52428d679267cf967cdc232f27ab5a93b6d6f14fe0d1af18bb095eb695527", size = 179298, upload-time = "2025-10-16T19:04:51.138Z" }, + { url = "https://files.pythonhosted.org/packages/ac/bf/45bebb0df1237b24e8237fa66c0195e7337be7d52f7c351bf3bfb98c591e/aistore-1.18.0-py3-none-any.whl", hash = "sha256:ac9949699fa69ff8b4d90199acd16c0f17fde300a299d7e55369bd96fcf1427a", size = 183122, upload-time = "2025-12-05T23:00:29.715Z" }, ] [[package]] @@ -584,6 +584,91 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "coverage" +version = "7.10.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/14/70/025b179c993f019105b79575ac6edb5e084fb0f0e63f15cdebef4e454fb5/coverage-7.10.6.tar.gz", hash = "sha256:f644a3ae5933a552a29dbb9aa2f90c677a875f80ebea028e5a52a4f429044b90", size = 823736, upload-time = "2025-08-29T15:35:16.668Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/1d/2e64b43d978b5bd184e0756a41415597dfef30fcbd90b747474bd749d45f/coverage-7.10.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:70e7bfbd57126b5554aa482691145f798d7df77489a177a6bef80de78860a356", size = 217025, upload-time = "2025-08-29T15:32:57.169Z" }, + { url = "https://files.pythonhosted.org/packages/23/62/b1e0f513417c02cc10ef735c3ee5186df55f190f70498b3702d516aad06f/coverage-7.10.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e41be6f0f19da64af13403e52f2dec38bbc2937af54df8ecef10850ff8d35301", size = 217419, upload-time = "2025-08-29T15:32:59.908Z" }, + { url = 
"https://files.pythonhosted.org/packages/e7/16/b800640b7a43e7c538429e4d7223e0a94fd72453a1a048f70bf766f12e96/coverage-7.10.6-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:c61fc91ab80b23f5fddbee342d19662f3d3328173229caded831aa0bd7595460", size = 244180, upload-time = "2025-08-29T15:33:01.608Z" }, + { url = "https://files.pythonhosted.org/packages/fb/6f/5e03631c3305cad187eaf76af0b559fff88af9a0b0c180d006fb02413d7a/coverage-7.10.6-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:10356fdd33a7cc06e8051413140bbdc6f972137508a3572e3f59f805cd2832fd", size = 245992, upload-time = "2025-08-29T15:33:03.239Z" }, + { url = "https://files.pythonhosted.org/packages/eb/a1/f30ea0fb400b080730125b490771ec62b3375789f90af0bb68bfb8a921d7/coverage-7.10.6-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:80b1695cf7c5ebe7b44bf2521221b9bb8cdf69b1f24231149a7e3eb1ae5fa2fb", size = 247851, upload-time = "2025-08-29T15:33:04.603Z" }, + { url = "https://files.pythonhosted.org/packages/02/8e/cfa8fee8e8ef9a6bb76c7bef039f3302f44e615d2194161a21d3d83ac2e9/coverage-7.10.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2e4c33e6378b9d52d3454bd08847a8651f4ed23ddbb4a0520227bd346382bbc6", size = 245891, upload-time = "2025-08-29T15:33:06.176Z" }, + { url = "https://files.pythonhosted.org/packages/93/a9/51be09b75c55c4f6c16d8d73a6a1d46ad764acca0eab48fa2ffaef5958fe/coverage-7.10.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:c8a3ec16e34ef980a46f60dc6ad86ec60f763c3f2fa0db6d261e6e754f72e945", size = 243909, upload-time = "2025-08-29T15:33:07.74Z" }, + { url = "https://files.pythonhosted.org/packages/e9/a6/ba188b376529ce36483b2d585ca7bdac64aacbe5aa10da5978029a9c94db/coverage-7.10.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7d79dabc0a56f5af990cc6da9ad1e40766e82773c075f09cc571e2076fef882e", size = 244786, upload-time = "2025-08-29T15:33:08.965Z" }, + { 
url = "https://files.pythonhosted.org/packages/d0/4c/37ed872374a21813e0d3215256180c9a382c3f5ced6f2e5da0102fc2fd3e/coverage-7.10.6-cp310-cp310-win32.whl", hash = "sha256:86b9b59f2b16e981906e9d6383eb6446d5b46c278460ae2c36487667717eccf1", size = 219521, upload-time = "2025-08-29T15:33:10.599Z" }, + { url = "https://files.pythonhosted.org/packages/8e/36/9311352fdc551dec5b973b61f4e453227ce482985a9368305880af4f85dd/coverage-7.10.6-cp310-cp310-win_amd64.whl", hash = "sha256:e132b9152749bd33534e5bd8565c7576f135f157b4029b975e15ee184325f528", size = 220417, upload-time = "2025-08-29T15:33:11.907Z" }, + { url = "https://files.pythonhosted.org/packages/d4/16/2bea27e212c4980753d6d563a0803c150edeaaddb0771a50d2afc410a261/coverage-7.10.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c706db3cabb7ceef779de68270150665e710b46d56372455cd741184f3868d8f", size = 217129, upload-time = "2025-08-29T15:33:13.575Z" }, + { url = "https://files.pythonhosted.org/packages/2a/51/e7159e068831ab37e31aac0969d47b8c5ee25b7d307b51e310ec34869315/coverage-7.10.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e0c38dc289e0508ef68ec95834cb5d2e96fdbe792eaccaa1bccac3966bbadcc", size = 217532, upload-time = "2025-08-29T15:33:14.872Z" }, + { url = "https://files.pythonhosted.org/packages/e7/c0/246ccbea53d6099325d25cd208df94ea435cd55f0db38099dd721efc7a1f/coverage-7.10.6-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:752a3005a1ded28f2f3a6e8787e24f28d6abe176ca64677bcd8d53d6fe2ec08a", size = 247931, upload-time = "2025-08-29T15:33:16.142Z" }, + { url = "https://files.pythonhosted.org/packages/7d/fb/7435ef8ab9b2594a6e3f58505cc30e98ae8b33265d844007737946c59389/coverage-7.10.6-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:689920ecfd60f992cafca4f5477d55720466ad2c7fa29bb56ac8d44a1ac2b47a", size = 249864, upload-time = "2025-08-29T15:33:17.434Z" }, + { url = 
"https://files.pythonhosted.org/packages/51/f8/d9d64e8da7bcddb094d511154824038833c81e3a039020a9d6539bf303e9/coverage-7.10.6-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ec98435796d2624d6905820a42f82149ee9fc4f2d45c2c5bc5a44481cc50db62", size = 251969, upload-time = "2025-08-29T15:33:18.822Z" }, + { url = "https://files.pythonhosted.org/packages/43/28/c43ba0ef19f446d6463c751315140d8f2a521e04c3e79e5c5fe211bfa430/coverage-7.10.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b37201ce4a458c7a758ecc4efa92fa8ed783c66e0fa3c42ae19fc454a0792153", size = 249659, upload-time = "2025-08-29T15:33:20.407Z" }, + { url = "https://files.pythonhosted.org/packages/79/3e/53635bd0b72beaacf265784508a0b386defc9ab7fad99ff95f79ce9db555/coverage-7.10.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:2904271c80898663c810a6b067920a61dd8d38341244a3605bd31ab55250dad5", size = 247714, upload-time = "2025-08-29T15:33:21.751Z" }, + { url = "https://files.pythonhosted.org/packages/4c/55/0964aa87126624e8c159e32b0bc4e84edef78c89a1a4b924d28dd8265625/coverage-7.10.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:5aea98383463d6e1fa4e95416d8de66f2d0cb588774ee20ae1b28df826bcb619", size = 248351, upload-time = "2025-08-29T15:33:23.105Z" }, + { url = "https://files.pythonhosted.org/packages/eb/ab/6cfa9dc518c6c8e14a691c54e53a9433ba67336c760607e299bfcf520cb1/coverage-7.10.6-cp311-cp311-win32.whl", hash = "sha256:e3fb1fa01d3598002777dd259c0c2e6d9d5e10e7222976fc8e03992f972a2cba", size = 219562, upload-time = "2025-08-29T15:33:24.717Z" }, + { url = "https://files.pythonhosted.org/packages/5b/18/99b25346690cbc55922e7cfef06d755d4abee803ef335baff0014268eff4/coverage-7.10.6-cp311-cp311-win_amd64.whl", hash = "sha256:f35ed9d945bece26553d5b4c8630453169672bea0050a564456eb88bdffd927e", size = 220453, upload-time = "2025-08-29T15:33:26.482Z" }, + { url = 
"https://files.pythonhosted.org/packages/d8/ed/81d86648a07ccb124a5cf1f1a7788712b8d7216b593562683cd5c9b0d2c1/coverage-7.10.6-cp311-cp311-win_arm64.whl", hash = "sha256:99e1a305c7765631d74b98bf7dbf54eeea931f975e80f115437d23848ee8c27c", size = 219127, upload-time = "2025-08-29T15:33:27.777Z" }, + { url = "https://files.pythonhosted.org/packages/26/06/263f3305c97ad78aab066d116b52250dd316e74fcc20c197b61e07eb391a/coverage-7.10.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5b2dd6059938063a2c9fee1af729d4f2af28fd1a545e9b7652861f0d752ebcea", size = 217324, upload-time = "2025-08-29T15:33:29.06Z" }, + { url = "https://files.pythonhosted.org/packages/e9/60/1e1ded9a4fe80d843d7d53b3e395c1db3ff32d6c301e501f393b2e6c1c1f/coverage-7.10.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:388d80e56191bf846c485c14ae2bc8898aa3124d9d35903fef7d907780477634", size = 217560, upload-time = "2025-08-29T15:33:30.748Z" }, + { url = "https://files.pythonhosted.org/packages/b8/25/52136173c14e26dfed8b106ed725811bb53c30b896d04d28d74cb64318b3/coverage-7.10.6-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:90cb5b1a4670662719591aa92d0095bb41714970c0b065b02a2610172dbf0af6", size = 249053, upload-time = "2025-08-29T15:33:32.041Z" }, + { url = "https://files.pythonhosted.org/packages/cb/1d/ae25a7dc58fcce8b172d42ffe5313fc267afe61c97fa872b80ee72d9515a/coverage-7.10.6-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:961834e2f2b863a0e14260a9a273aff07ff7818ab6e66d2addf5628590c628f9", size = 251802, upload-time = "2025-08-29T15:33:33.625Z" }, + { url = "https://files.pythonhosted.org/packages/f5/7a/1f561d47743710fe996957ed7c124b421320f150f1d38523d8d9102d3e2a/coverage-7.10.6-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bf9a19f5012dab774628491659646335b1928cfc931bf8d97b0d5918dd58033c", size = 252935, upload-time = "2025-08-29T15:33:34.909Z" }, + { url = 
"https://files.pythonhosted.org/packages/6c/ad/8b97cd5d28aecdfde792dcbf646bac141167a5cacae2cd775998b45fabb5/coverage-7.10.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:99c4283e2a0e147b9c9cc6bc9c96124de9419d6044837e9799763a0e29a7321a", size = 250855, upload-time = "2025-08-29T15:33:36.922Z" }, + { url = "https://files.pythonhosted.org/packages/33/6a/95c32b558d9a61858ff9d79580d3877df3eb5bc9eed0941b1f187c89e143/coverage-7.10.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:282b1b20f45df57cc508c1e033403f02283adfb67d4c9c35a90281d81e5c52c5", size = 248974, upload-time = "2025-08-29T15:33:38.175Z" }, + { url = "https://files.pythonhosted.org/packages/0d/9c/8ce95dee640a38e760d5b747c10913e7a06554704d60b41e73fdea6a1ffd/coverage-7.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8cdbe264f11afd69841bd8c0d83ca10b5b32853263ee62e6ac6a0ab63895f972", size = 250409, upload-time = "2025-08-29T15:33:39.447Z" }, + { url = "https://files.pythonhosted.org/packages/04/12/7a55b0bdde78a98e2eb2356771fd2dcddb96579e8342bb52aa5bc52e96f0/coverage-7.10.6-cp312-cp312-win32.whl", hash = "sha256:a517feaf3a0a3eca1ee985d8373135cfdedfbba3882a5eab4362bda7c7cf518d", size = 219724, upload-time = "2025-08-29T15:33:41.172Z" }, + { url = "https://files.pythonhosted.org/packages/36/4a/32b185b8b8e327802c9efce3d3108d2fe2d9d31f153a0f7ecfd59c773705/coverage-7.10.6-cp312-cp312-win_amd64.whl", hash = "sha256:856986eadf41f52b214176d894a7de05331117f6035a28ac0016c0f63d887629", size = 220536, upload-time = "2025-08-29T15:33:42.524Z" }, + { url = "https://files.pythonhosted.org/packages/08/3a/d5d8dc703e4998038c3099eaf77adddb00536a3cec08c8dcd556a36a3eb4/coverage-7.10.6-cp312-cp312-win_arm64.whl", hash = "sha256:acf36b8268785aad739443fa2780c16260ee3fa09d12b3a70f772ef100939d80", size = 219171, upload-time = "2025-08-29T15:33:43.974Z" }, + { url = 
"https://files.pythonhosted.org/packages/bd/e7/917e5953ea29a28c1057729c1d5af9084ab6d9c66217523fd0e10f14d8f6/coverage-7.10.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ffea0575345e9ee0144dfe5701aa17f3ba546f8c3bb48db62ae101afb740e7d6", size = 217351, upload-time = "2025-08-29T15:33:45.438Z" }, + { url = "https://files.pythonhosted.org/packages/eb/86/2e161b93a4f11d0ea93f9bebb6a53f113d5d6e416d7561ca41bb0a29996b/coverage-7.10.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:95d91d7317cde40a1c249d6b7382750b7e6d86fad9d8eaf4fa3f8f44cf171e80", size = 217600, upload-time = "2025-08-29T15:33:47.269Z" }, + { url = "https://files.pythonhosted.org/packages/0e/66/d03348fdd8df262b3a7fb4ee5727e6e4936e39e2f3a842e803196946f200/coverage-7.10.6-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3e23dd5408fe71a356b41baa82892772a4cefcf758f2ca3383d2aa39e1b7a003", size = 248600, upload-time = "2025-08-29T15:33:48.953Z" }, + { url = "https://files.pythonhosted.org/packages/73/dd/508420fb47d09d904d962f123221bc249f64b5e56aa93d5f5f7603be475f/coverage-7.10.6-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:0f3f56e4cb573755e96a16501a98bf211f100463d70275759e73f3cbc00d4f27", size = 251206, upload-time = "2025-08-29T15:33:50.697Z" }, + { url = "https://files.pythonhosted.org/packages/e9/1f/9020135734184f439da85c70ea78194c2730e56c2d18aee6e8ff1719d50d/coverage-7.10.6-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:db4a1d897bbbe7339946ffa2fe60c10cc81c43fab8b062d3fcb84188688174a4", size = 252478, upload-time = "2025-08-29T15:33:52.303Z" }, + { url = "https://files.pythonhosted.org/packages/a4/a4/3d228f3942bb5a2051fde28c136eea23a761177dc4ff4ef54533164ce255/coverage-7.10.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d8fd7879082953c156d5b13c74aa6cca37f6a6f4747b39538504c3f9c63d043d", size = 250637, upload-time = "2025-08-29T15:33:53.67Z" }, + { url 
= "https://files.pythonhosted.org/packages/36/e3/293dce8cdb9a83de971637afc59b7190faad60603b40e32635cbd15fbf61/coverage-7.10.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:28395ca3f71cd103b8c116333fa9db867f3a3e1ad6a084aa3725ae002b6583bc", size = 248529, upload-time = "2025-08-29T15:33:55.022Z" }, + { url = "https://files.pythonhosted.org/packages/90/26/64eecfa214e80dd1d101e420cab2901827de0e49631d666543d0e53cf597/coverage-7.10.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:61c950fc33d29c91b9e18540e1aed7d9f6787cc870a3e4032493bbbe641d12fc", size = 250143, upload-time = "2025-08-29T15:33:56.386Z" }, + { url = "https://files.pythonhosted.org/packages/3e/70/bd80588338f65ea5b0d97e424b820fb4068b9cfb9597fbd91963086e004b/coverage-7.10.6-cp313-cp313-win32.whl", hash = "sha256:160c00a5e6b6bdf4e5984b0ef21fc860bc94416c41b7df4d63f536d17c38902e", size = 219770, upload-time = "2025-08-29T15:33:58.063Z" }, + { url = "https://files.pythonhosted.org/packages/a7/14/0b831122305abcc1060c008f6c97bbdc0a913ab47d65070a01dc50293c2b/coverage-7.10.6-cp313-cp313-win_amd64.whl", hash = "sha256:628055297f3e2aa181464c3808402887643405573eb3d9de060d81531fa79d32", size = 220566, upload-time = "2025-08-29T15:33:59.766Z" }, + { url = "https://files.pythonhosted.org/packages/83/c6/81a83778c1f83f1a4a168ed6673eeedc205afb562d8500175292ca64b94e/coverage-7.10.6-cp313-cp313-win_arm64.whl", hash = "sha256:df4ec1f8540b0bcbe26ca7dd0f541847cc8a108b35596f9f91f59f0c060bfdd2", size = 219195, upload-time = "2025-08-29T15:34:01.191Z" }, + { url = "https://files.pythonhosted.org/packages/d7/1c/ccccf4bf116f9517275fa85047495515add43e41dfe8e0bef6e333c6b344/coverage-7.10.6-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:c9a8b7a34a4de3ed987f636f71881cd3b8339f61118b1aa311fbda12741bff0b", size = 218059, upload-time = "2025-08-29T15:34:02.91Z" }, + { url = 
"https://files.pythonhosted.org/packages/92/97/8a3ceff833d27c7492af4f39d5da6761e9ff624831db9e9f25b3886ddbca/coverage-7.10.6-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8dd5af36092430c2b075cee966719898f2ae87b636cefb85a653f1d0ba5d5393", size = 218287, upload-time = "2025-08-29T15:34:05.106Z" }, + { url = "https://files.pythonhosted.org/packages/92/d8/50b4a32580cf41ff0423777a2791aaf3269ab60c840b62009aec12d3970d/coverage-7.10.6-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:b0353b0f0850d49ada66fdd7d0c7cdb0f86b900bb9e367024fd14a60cecc1e27", size = 259625, upload-time = "2025-08-29T15:34:06.575Z" }, + { url = "https://files.pythonhosted.org/packages/7e/7e/6a7df5a6fb440a0179d94a348eb6616ed4745e7df26bf2a02bc4db72c421/coverage-7.10.6-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d6b9ae13d5d3e8aeca9ca94198aa7b3ebbc5acfada557d724f2a1f03d2c0b0df", size = 261801, upload-time = "2025-08-29T15:34:08.006Z" }, + { url = "https://files.pythonhosted.org/packages/3a/4c/a270a414f4ed5d196b9d3d67922968e768cd971d1b251e1b4f75e9362f75/coverage-7.10.6-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:675824a363cc05781b1527b39dc2587b8984965834a748177ee3c37b64ffeafb", size = 264027, upload-time = "2025-08-29T15:34:09.806Z" }, + { url = "https://files.pythonhosted.org/packages/9c/8b/3210d663d594926c12f373c5370bf1e7c5c3a427519a8afa65b561b9a55c/coverage-7.10.6-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:692d70ea725f471a547c305f0d0fc6a73480c62fb0da726370c088ab21aed282", size = 261576, upload-time = "2025-08-29T15:34:11.585Z" }, + { url = "https://files.pythonhosted.org/packages/72/d0/e1961eff67e9e1dba3fc5eb7a4caf726b35a5b03776892da8d79ec895775/coverage-7.10.6-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:851430a9a361c7a8484a36126d1d0ff8d529d97385eacc8dfdc9bfc8c2d2cbe4", size = 259341, upload-time = "2025-08-29T15:34:13.159Z" }, + 
{ url = "https://files.pythonhosted.org/packages/3a/06/d6478d152cd189b33eac691cba27a40704990ba95de49771285f34a5861e/coverage-7.10.6-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:d9369a23186d189b2fc95cc08b8160ba242057e887d766864f7adf3c46b2df21", size = 260468, upload-time = "2025-08-29T15:34:14.571Z" }, + { url = "https://files.pythonhosted.org/packages/ed/73/737440247c914a332f0b47f7598535b29965bf305e19bbc22d4c39615d2b/coverage-7.10.6-cp313-cp313t-win32.whl", hash = "sha256:92be86fcb125e9bda0da7806afd29a3fd33fdf58fba5d60318399adf40bf37d0", size = 220429, upload-time = "2025-08-29T15:34:16.394Z" }, + { url = "https://files.pythonhosted.org/packages/bd/76/b92d3214740f2357ef4a27c75a526eb6c28f79c402e9f20a922c295c05e2/coverage-7.10.6-cp313-cp313t-win_amd64.whl", hash = "sha256:6b3039e2ca459a70c79523d39347d83b73f2f06af5624905eba7ec34d64d80b5", size = 221493, upload-time = "2025-08-29T15:34:17.835Z" }, + { url = "https://files.pythonhosted.org/packages/fc/8e/6dcb29c599c8a1f654ec6cb68d76644fe635513af16e932d2d4ad1e5ac6e/coverage-7.10.6-cp313-cp313t-win_arm64.whl", hash = "sha256:3fb99d0786fe17b228eab663d16bee2288e8724d26a199c29325aac4b0319b9b", size = 219757, upload-time = "2025-08-29T15:34:19.248Z" }, + { url = "https://files.pythonhosted.org/packages/d3/aa/76cf0b5ec00619ef208da4689281d48b57f2c7fde883d14bf9441b74d59f/coverage-7.10.6-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:6008a021907be8c4c02f37cdc3ffb258493bdebfeaf9a839f9e71dfdc47b018e", size = 217331, upload-time = "2025-08-29T15:34:20.846Z" }, + { url = "https://files.pythonhosted.org/packages/65/91/8e41b8c7c505d398d7730206f3cbb4a875a35ca1041efc518051bfce0f6b/coverage-7.10.6-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:5e75e37f23eb144e78940b40395b42f2321951206a4f50e23cfd6e8a198d3ceb", size = 217607, upload-time = "2025-08-29T15:34:22.433Z" }, + { url = 
"https://files.pythonhosted.org/packages/87/7f/f718e732a423d442e6616580a951b8d1ec3575ea48bcd0e2228386805e79/coverage-7.10.6-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0f7cb359a448e043c576f0da00aa8bfd796a01b06aa610ca453d4dde09cc1034", size = 248663, upload-time = "2025-08-29T15:34:24.425Z" }, + { url = "https://files.pythonhosted.org/packages/e6/52/c1106120e6d801ac03e12b5285e971e758e925b6f82ee9b86db3aa10045d/coverage-7.10.6-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:c68018e4fc4e14b5668f1353b41ccf4bc83ba355f0e1b3836861c6f042d89ac1", size = 251197, upload-time = "2025-08-29T15:34:25.906Z" }, + { url = "https://files.pythonhosted.org/packages/3d/ec/3a8645b1bb40e36acde9c0609f08942852a4af91a937fe2c129a38f2d3f5/coverage-7.10.6-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cd4b2b0707fc55afa160cd5fc33b27ccbf75ca11d81f4ec9863d5793fc6df56a", size = 252551, upload-time = "2025-08-29T15:34:27.337Z" }, + { url = "https://files.pythonhosted.org/packages/a1/70/09ecb68eeb1155b28a1d16525fd3a9b65fbe75337311a99830df935d62b6/coverage-7.10.6-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:4cec13817a651f8804a86e4f79d815b3b28472c910e099e4d5a0e8a3b6a1d4cb", size = 250553, upload-time = "2025-08-29T15:34:29.065Z" }, + { url = "https://files.pythonhosted.org/packages/c6/80/47df374b893fa812e953b5bc93dcb1427a7b3d7a1a7d2db33043d17f74b9/coverage-7.10.6-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:f2a6a8e06bbda06f78739f40bfb56c45d14eb8249d0f0ea6d4b3d48e1f7c695d", size = 248486, upload-time = "2025-08-29T15:34:30.897Z" }, + { url = "https://files.pythonhosted.org/packages/4a/65/9f98640979ecee1b0d1a7164b589de720ddf8100d1747d9bbdb84be0c0fb/coverage-7.10.6-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:081b98395ced0d9bcf60ada7661a0b75f36b78b9d7e39ea0790bb4ed8da14747", size = 249981, upload-time = "2025-08-29T15:34:32.365Z" }, + { 
url = "https://files.pythonhosted.org/packages/1f/55/eeb6603371e6629037f47bd25bef300387257ed53a3c5fdb159b7ac8c651/coverage-7.10.6-cp314-cp314-win32.whl", hash = "sha256:6937347c5d7d069ee776b2bf4e1212f912a9f1f141a429c475e6089462fcecc5", size = 220054, upload-time = "2025-08-29T15:34:34.124Z" }, + { url = "https://files.pythonhosted.org/packages/15/d1/a0912b7611bc35412e919a2cd59ae98e7ea3b475e562668040a43fb27897/coverage-7.10.6-cp314-cp314-win_amd64.whl", hash = "sha256:adec1d980fa07e60b6ef865f9e5410ba760e4e1d26f60f7e5772c73b9a5b0713", size = 220851, upload-time = "2025-08-29T15:34:35.651Z" }, + { url = "https://files.pythonhosted.org/packages/ef/2d/11880bb8ef80a45338e0b3e0725e4c2d73ffbb4822c29d987078224fd6a5/coverage-7.10.6-cp314-cp314-win_arm64.whl", hash = "sha256:a80f7aef9535442bdcf562e5a0d5a5538ce8abe6bb209cfbf170c462ac2c2a32", size = 219429, upload-time = "2025-08-29T15:34:37.16Z" }, + { url = "https://files.pythonhosted.org/packages/83/c0/1f00caad775c03a700146f55536ecd097a881ff08d310a58b353a1421be0/coverage-7.10.6-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:0de434f4fbbe5af4fa7989521c655c8c779afb61c53ab561b64dcee6149e4c65", size = 218080, upload-time = "2025-08-29T15:34:38.919Z" }, + { url = "https://files.pythonhosted.org/packages/a9/c4/b1c5d2bd7cc412cbeb035e257fd06ed4e3e139ac871d16a07434e145d18d/coverage-7.10.6-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:6e31b8155150c57e5ac43ccd289d079eb3f825187d7c66e755a055d2c85794c6", size = 218293, upload-time = "2025-08-29T15:34:40.425Z" }, + { url = "https://files.pythonhosted.org/packages/3f/07/4468d37c94724bf6ec354e4ec2f205fda194343e3e85fd2e59cec57e6a54/coverage-7.10.6-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:98cede73eb83c31e2118ae8d379c12e3e42736903a8afcca92a7218e1f2903b0", size = 259800, upload-time = "2025-08-29T15:34:41.996Z" }, + { url = 
"https://files.pythonhosted.org/packages/82/d8/f8fb351be5fee31690cd8da768fd62f1cfab33c31d9f7baba6cd8960f6b8/coverage-7.10.6-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f863c08f4ff6b64fa8045b1e3da480f5374779ef187f07b82e0538c68cb4ff8e", size = 261965, upload-time = "2025-08-29T15:34:43.61Z" }, + { url = "https://files.pythonhosted.org/packages/e8/70/65d4d7cfc75c5c6eb2fed3ee5cdf420fd8ae09c4808723a89a81d5b1b9c3/coverage-7.10.6-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2b38261034fda87be356f2c3f42221fdb4171c3ce7658066ae449241485390d5", size = 264220, upload-time = "2025-08-29T15:34:45.387Z" }, + { url = "https://files.pythonhosted.org/packages/98/3c/069df106d19024324cde10e4ec379fe2fb978017d25e97ebee23002fbadf/coverage-7.10.6-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:0e93b1476b79eae849dc3872faeb0bf7948fd9ea34869590bc16a2a00b9c82a7", size = 261660, upload-time = "2025-08-29T15:34:47.288Z" }, + { url = "https://files.pythonhosted.org/packages/fc/8a/2974d53904080c5dc91af798b3a54a4ccb99a45595cc0dcec6eb9616a57d/coverage-7.10.6-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:ff8a991f70f4c0cf53088abf1e3886edcc87d53004c7bb94e78650b4d3dac3b5", size = 259417, upload-time = "2025-08-29T15:34:48.779Z" }, + { url = "https://files.pythonhosted.org/packages/30/38/9616a6b49c686394b318974d7f6e08f38b8af2270ce7488e879888d1e5db/coverage-7.10.6-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ac765b026c9f33044419cbba1da913cfb82cca1b60598ac1c7a5ed6aac4621a0", size = 260567, upload-time = "2025-08-29T15:34:50.718Z" }, + { url = "https://files.pythonhosted.org/packages/76/16/3ed2d6312b371a8cf804abf4e14895b70e4c3491c6e53536d63fd0958a8d/coverage-7.10.6-cp314-cp314t-win32.whl", hash = "sha256:441c357d55f4936875636ef2cfb3bee36e466dcf50df9afbd398ce79dba1ebb7", size = 220831, upload-time = "2025-08-29T15:34:52.653Z" }, + { url = 
"https://files.pythonhosted.org/packages/d5/e5/d38d0cb830abede2adb8b147770d2a3d0e7fecc7228245b9b1ae6c24930a/coverage-7.10.6-cp314-cp314t-win_amd64.whl", hash = "sha256:073711de3181b2e204e4870ac83a7c4853115b42e9cd4d145f2231e12d670930", size = 221950, upload-time = "2025-08-29T15:34:54.212Z" }, + { url = "https://files.pythonhosted.org/packages/f4/51/e48e550f6279349895b0ffcd6d2a690e3131ba3a7f4eafccc141966d4dea/coverage-7.10.6-cp314-cp314t-win_arm64.whl", hash = "sha256:137921f2bac5559334ba66122b753db6dc5d1cf01eb7b64eb412bb0d064ef35b", size = 219969, upload-time = "2025-08-29T15:34:55.83Z" }, + { url = "https://files.pythonhosted.org/packages/44/0c/50db5379b615854b5cf89146f8f5bd1d5a9693d7f3a987e269693521c404/coverage-7.10.6-py3-none-any.whl", hash = "sha256:92c4ecf6bf11b2e85fd4d8204814dc26e6a19f0c9d938c207c5cb0eadfcabbe3", size = 208986, upload-time = "2025-08-29T15:35:14.506Z" }, +] + [[package]] name = "cryptography" version = "3.4.7" @@ -644,6 +729,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/52/cf/9dfc5616f103648f483c1595f05d2ac96df2dfec915351f507f7a500a38d/ebmlite-3.3.1-py3-none-any.whl", hash = "sha256:59285c472de1a6b92a4caf758b2b634a72a1468a94f12ebdb003202a07f01edf", size = 92152, upload-time = "2022-12-13T22:29:51.41Z" }, ] +[[package]] +name = "exceptiongroup" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = 
"sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674, upload-time = "2025-05-10T17:42:49.33Z" }, +] + [[package]] name = "filelock" version = "3.18.0" @@ -767,7 +864,7 @@ wheels = [ [[package]] name = "google-api-core" -version = "2.24.2" +version = "2.28.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-auth" }, @@ -776,9 +873,9 @@ dependencies = [ { name = "protobuf" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/09/5c/085bcb872556934bb119e5e09de54daa07873f6866b8f0303c49e72287f7/google_api_core-2.24.2.tar.gz", hash = "sha256:81718493daf06d96d6bc76a91c23874dbf2fac0adbbf542831b805ee6e974696", size = 163516, upload-time = "2025-03-10T15:55:26.201Z" } +sdist = { url = "https://files.pythonhosted.org/packages/61/da/83d7043169ac2c8c7469f0e375610d78ae2160134bf1b80634c482fa079c/google_api_core-2.28.1.tar.gz", hash = "sha256:2b405df02d68e68ce0fbc138559e6036559e685159d148ae5861013dc201baf8", size = 176759, upload-time = "2025-10-28T21:34:51.529Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/46/95/f472d85adab6e538da2025dfca9e976a0d125cc0af2301f190e77b76e51c/google_api_core-2.24.2-py3-none-any.whl", hash = "sha256:810a63ac95f3c441b7c0e43d344e372887f62ce9071ba972eacf32672e072de9", size = 160061, upload-time = "2025-03-10T15:55:24.386Z" }, + { url = "https://files.pythonhosted.org/packages/ed/d4/90197b416cb61cefd316964fd9e7bd8324bcbafabf40eef14a9f20b81974/google_api_core-2.28.1-py3-none-any.whl", hash = "sha256:4021b0f8ceb77a6fb4de6fde4502cecab45062e66ff4f2895169e0b35bc9466c", size = 173706, upload-time = "2025-10-28T21:34:50.151Z" }, ] [[package]] @@ -810,7 +907,7 @@ wheels = [ [[package]] name = "google-cloud-storage" -version = "3.4.0" +version = "3.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -820,9 +917,9 @@ dependencies = [ { name = "google-resumable-media" }, { name = "requests" }, ] 
-sdist = { url = "https://files.pythonhosted.org/packages/4e/a6/6e0a318f70975a3c048c0e1a18aee4f7b6d7dac1e798fdc5353c5248d418/google_cloud_storage-3.4.0.tar.gz", hash = "sha256:4c77ec00c98ccc6428e4c39404926f41e2152f48809b02af29d5116645c3c317", size = 17226847, upload-time = "2025-09-15T10:40:05.045Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f4/cd/7e112cf025b2b591067b599e4bfe965df0c12b0cc0afdb5556469bff126d/google_cloud_storage-3.6.0.tar.gz", hash = "sha256:29cc6b9a6c0fc9cdad071e375d540a5a50fbc9a7fad8300fa02fb904f6fe2ca2", size = 17251072, upload-time = "2025-11-17T10:18:29.81Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/16/12/164a90e4692423ed5532274928b0e19c8cae345ae1aa413d78c6b688231b/google_cloud_storage-3.4.0-py3-none-any.whl", hash = "sha256:16eeca305e4747a6871f8f7627eef3b862fdd365b872ca74d4a89e9841d0f8e8", size = 278423, upload-time = "2025-09-15T10:40:03.349Z" }, + { url = "https://files.pythonhosted.org/packages/ae/ef/3b57bf617ee0c79450c1ff211d1eb888db8fc1050ac74b3e52cc6ed86e63/google_cloud_storage-3.6.0-py3-none-any.whl", hash = "sha256:5decbdddd63b7d1fc3e266a393ad6453d2e27d172bd982b1e2f15481668db097", size = 299039, upload-time = "2025-11-17T10:18:27.66Z" }, ] [[package]] @@ -918,22 +1015,36 @@ wheels = [ [[package]] name = "hf-xet" -version = "1.1.10" +version = "1.2.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/74/31/feeddfce1748c4a233ec1aa5b7396161c07ae1aa9b7bdbc9a72c3c7dd768/hf_xet-1.1.10.tar.gz", hash = "sha256:408aef343800a2102374a883f283ff29068055c111f003ff840733d3b715bb97", size = 487910, upload-time = "2025-09-12T20:10:27.12Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/a2/343e6d05de96908366bdc0081f2d8607d61200be2ac802769c4284cc65bd/hf_xet-1.1.10-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:686083aca1a6669bc85c21c0563551cbcdaa5cf7876a91f3d074a030b577231d", size = 2761466, upload-time = "2025-09-12T20:10:22.836Z" }, 
- { url = "https://files.pythonhosted.org/packages/31/f9/6215f948ac8f17566ee27af6430ea72045e0418ce757260248b483f4183b/hf_xet-1.1.10-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:71081925383b66b24eedff3013f8e6bbd41215c3338be4b94ba75fd75b21513b", size = 2623807, upload-time = "2025-09-12T20:10:21.118Z" }, - { url = "https://files.pythonhosted.org/packages/15/07/86397573efefff941e100367bbda0b21496ffcdb34db7ab51912994c32a2/hf_xet-1.1.10-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b6bceb6361c80c1cc42b5a7b4e3efd90e64630bcf11224dcac50ef30a47e435", size = 3186960, upload-time = "2025-09-12T20:10:19.336Z" }, - { url = "https://files.pythonhosted.org/packages/01/a7/0b2e242b918cc30e1f91980f3c4b026ff2eedaf1e2ad96933bca164b2869/hf_xet-1.1.10-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:eae7c1fc8a664e54753ffc235e11427ca61f4b0477d757cc4eb9ae374b69f09c", size = 3087167, upload-time = "2025-09-12T20:10:17.255Z" }, - { url = "https://files.pythonhosted.org/packages/4a/25/3e32ab61cc7145b11eee9d745988e2f0f4fafda81b25980eebf97d8cff15/hf_xet-1.1.10-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0a0005fd08f002180f7a12d4e13b22be277725bc23ed0529f8add5c7a6309c06", size = 3248612, upload-time = "2025-09-12T20:10:24.093Z" }, - { url = "https://files.pythonhosted.org/packages/2c/3d/ab7109e607ed321afaa690f557a9ada6d6d164ec852fd6bf9979665dc3d6/hf_xet-1.1.10-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f900481cf6e362a6c549c61ff77468bd59d6dd082f3170a36acfef2eb6a6793f", size = 3353360, upload-time = "2025-09-12T20:10:25.563Z" }, - { url = "https://files.pythonhosted.org/packages/ee/0e/471f0a21db36e71a2f1752767ad77e92d8cde24e974e03d662931b1305ec/hf_xet-1.1.10-cp37-abi3-win_amd64.whl", hash = "sha256:5f54b19cc347c13235ae7ee98b330c26dd65ef1df47e5316ffb1e87713ca7045", size = 2804691, upload-time = "2025-09-12T20:10:28.433Z" }, +sdist = { url = 
"https://files.pythonhosted.org/packages/5e/6e/0f11bacf08a67f7fb5ee09740f2ca54163863b07b70d579356e9222ce5d8/hf_xet-1.2.0.tar.gz", hash = "sha256:a8c27070ca547293b6890c4bf389f713f80e8c478631432962bb7f4bc0bd7d7f", size = 506020, upload-time = "2025-10-24T19:04:32.129Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/a5/85ef910a0aa034a2abcfadc360ab5ac6f6bc4e9112349bd40ca97551cff0/hf_xet-1.2.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:ceeefcd1b7aed4956ae8499e2199607765fbd1c60510752003b6cc0b8413b649", size = 2861870, upload-time = "2025-10-24T19:04:11.422Z" }, + { url = "https://files.pythonhosted.org/packages/ea/40/e2e0a7eb9a51fe8828ba2d47fe22a7e74914ea8a0db68a18c3aa7449c767/hf_xet-1.2.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b70218dd548e9840224df5638fdc94bd033552963cfa97f9170829381179c813", size = 2717584, upload-time = "2025-10-24T19:04:09.586Z" }, + { url = "https://files.pythonhosted.org/packages/a5/7d/daf7f8bc4594fdd59a8a596f9e3886133fdc68e675292218a5e4c1b7e834/hf_xet-1.2.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d40b18769bb9a8bc82a9ede575ce1a44c75eb80e7375a01d76259089529b5dc", size = 3315004, upload-time = "2025-10-24T19:04:00.314Z" }, + { url = "https://files.pythonhosted.org/packages/b1/ba/45ea2f605fbf6d81c8b21e4d970b168b18a53515923010c312c06cd83164/hf_xet-1.2.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd3a6027d59cfb60177c12d6424e31f4b5ff13d8e3a1247b3a584bf8977e6df5", size = 3222636, upload-time = "2025-10-24T19:03:58.111Z" }, + { url = "https://files.pythonhosted.org/packages/4a/1d/04513e3cab8f29ab8c109d309ddd21a2705afab9d52f2ba1151e0c14f086/hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6de1fc44f58f6dd937956c8d304d8c2dea264c80680bcfa61ca4a15e7b76780f", size = 3408448, upload-time = "2025-10-24T19:04:20.951Z" }, + { url = 
"https://files.pythonhosted.org/packages/f0/7c/60a2756d7feec7387db3a1176c632357632fbe7849fce576c5559d4520c7/hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f182f264ed2acd566c514e45da9f2119110e48a87a327ca271027904c70c5832", size = 3503401, upload-time = "2025-10-24T19:04:22.549Z" }, + { url = "https://files.pythonhosted.org/packages/4e/64/48fffbd67fb418ab07451e4ce641a70de1c40c10a13e25325e24858ebe5a/hf_xet-1.2.0-cp313-cp313t-win_amd64.whl", hash = "sha256:293a7a3787e5c95d7be1857358a9130694a9c6021de3f27fa233f37267174382", size = 2900866, upload-time = "2025-10-24T19:04:33.461Z" }, + { url = "https://files.pythonhosted.org/packages/e2/51/f7e2caae42f80af886db414d4e9885fac959330509089f97cccb339c6b87/hf_xet-1.2.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:10bfab528b968c70e062607f663e21e34e2bba349e8038db546646875495179e", size = 2861861, upload-time = "2025-10-24T19:04:19.01Z" }, + { url = "https://files.pythonhosted.org/packages/6e/1d/a641a88b69994f9371bd347f1dd35e5d1e2e2460a2e350c8d5165fc62005/hf_xet-1.2.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:2a212e842647b02eb6a911187dc878e79c4aa0aa397e88dd3b26761676e8c1f8", size = 2717699, upload-time = "2025-10-24T19:04:17.306Z" }, + { url = "https://files.pythonhosted.org/packages/df/e0/e5e9bba7d15f0318955f7ec3f4af13f92e773fbb368c0b8008a5acbcb12f/hf_xet-1.2.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30e06daccb3a7d4c065f34fc26c14c74f4653069bb2b194e7f18f17cbe9939c0", size = 3314885, upload-time = "2025-10-24T19:04:07.642Z" }, + { url = "https://files.pythonhosted.org/packages/21/90/b7fe5ff6f2b7b8cbdf1bd56145f863c90a5807d9758a549bf3d916aa4dec/hf_xet-1.2.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:29c8fc913a529ec0a91867ce3d119ac1aac966e098cf49501800c870328cc090", size = 3221550, upload-time = "2025-10-24T19:04:05.55Z" }, + { url = 
"https://files.pythonhosted.org/packages/6f/cb/73f276f0a7ce46cc6a6ec7d6c7d61cbfe5f2e107123d9bbd0193c355f106/hf_xet-1.2.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:66e159cbfcfbb29f920db2c09ed8b660eb894640d284f102ada929b6e3dc410a", size = 3408010, upload-time = "2025-10-24T19:04:28.598Z" }, + { url = "https://files.pythonhosted.org/packages/b8/1e/d642a12caa78171f4be64f7cd9c40e3ca5279d055d0873188a58c0f5fbb9/hf_xet-1.2.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9c91d5ae931510107f148874e9e2de8a16052b6f1b3ca3c1b12f15ccb491390f", size = 3503264, upload-time = "2025-10-24T19:04:30.397Z" }, + { url = "https://files.pythonhosted.org/packages/17/b5/33764714923fa1ff922770f7ed18c2daae034d21ae6e10dbf4347c854154/hf_xet-1.2.0-cp314-cp314t-win_amd64.whl", hash = "sha256:210d577732b519ac6ede149d2f2f34049d44e8622bf14eb3d63bbcd2d4b332dc", size = 2901071, upload-time = "2025-10-24T19:04:37.463Z" }, + { url = "https://files.pythonhosted.org/packages/96/2d/22338486473df5923a9ab7107d375dbef9173c338ebef5098ef593d2b560/hf_xet-1.2.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:46740d4ac024a7ca9b22bebf77460ff43332868b661186a8e46c227fdae01848", size = 2866099, upload-time = "2025-10-24T19:04:15.366Z" }, + { url = "https://files.pythonhosted.org/packages/7f/8c/c5becfa53234299bc2210ba314eaaae36c2875e0045809b82e40a9544f0c/hf_xet-1.2.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:27df617a076420d8845bea087f59303da8be17ed7ec0cd7ee3b9b9f579dff0e4", size = 2722178, upload-time = "2025-10-24T19:04:13.695Z" }, + { url = "https://files.pythonhosted.org/packages/9a/92/cf3ab0b652b082e66876d08da57fcc6fa2f0e6c70dfbbafbd470bb73eb47/hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3651fd5bfe0281951b988c0facbe726aa5e347b103a675f49a3fa8144c7968fd", size = 3320214, upload-time = "2025-10-24T19:04:03.596Z" }, + { url = 
"https://files.pythonhosted.org/packages/46/92/3f7ec4a1b6a65bf45b059b6d4a5d38988f63e193056de2f420137e3c3244/hf_xet-1.2.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d06fa97c8562fb3ee7a378dd9b51e343bc5bc8190254202c9771029152f5e08c", size = 3229054, upload-time = "2025-10-24T19:04:01.949Z" }, + { url = "https://files.pythonhosted.org/packages/0b/dd/7ac658d54b9fb7999a0ccb07ad863b413cbaf5cf172f48ebcd9497ec7263/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:4c1428c9ae73ec0939410ec73023c4f842927f39db09b063b9482dac5a3bb737", size = 3413812, upload-time = "2025-10-24T19:04:24.585Z" }, + { url = "https://files.pythonhosted.org/packages/92/68/89ac4e5b12a9ff6286a12174c8538a5930e2ed662091dd2572bbe0a18c8a/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a55558084c16b09b5ed32ab9ed38421e2d87cf3f1f89815764d1177081b99865", size = 3508920, upload-time = "2025-10-24T19:04:26.927Z" }, + { url = "https://files.pythonhosted.org/packages/cb/44/870d44b30e1dcfb6a65932e3e1506c103a8a5aea9103c337e7a53180322c/hf_xet-1.2.0-cp37-abi3-win_amd64.whl", hash = "sha256:e6584a52253f72c9f52f9e549d5895ca7a471608495c4ecaa6cc73dba2b24d69", size = 2905735, upload-time = "2025-10-24T19:04:35.928Z" }, ] [[package]] name = "huggingface-hub" -version = "0.35.3" +version = "0.36.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -945,9 +1056,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/10/7e/a0a97de7c73671863ca6b3f61fa12518caf35db37825e43d63a70956738c/huggingface_hub-0.35.3.tar.gz", hash = "sha256:350932eaa5cc6a4747efae85126ee220e4ef1b54e29d31c3b45c5612ddf0b32a", size = 461798, upload-time = "2025-09-29T14:29:58.625Z" } +sdist = { url = "https://files.pythonhosted.org/packages/98/63/4910c5fa9128fdadf6a9c5ac138e8b1b6cee4ca44bf7915bbfbce4e355ee/huggingface_hub-0.36.0.tar.gz", hash = 
"sha256:47b3f0e2539c39bf5cde015d63b72ec49baff67b6931c3d97f3f84532e2b8d25", size = 463358, upload-time = "2025-10-23T12:12:01.413Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/31/a0/651f93d154cb72323358bf2bbae3e642bdb5d2f1bfc874d096f7cb159fa0/huggingface_hub-0.35.3-py3-none-any.whl", hash = "sha256:0e3a01829c19d86d03793e4577816fe3bdfc1602ac62c7fb220d593d351224ba", size = 564262, upload-time = "2025-09-29T14:29:55.813Z" }, + { url = "https://files.pythonhosted.org/packages/cb/bd/1a875e0d592d447cbc02805fd3fe0f497714d6a2583f59d14fa9ebad96eb/huggingface_hub-0.36.0-py3-none-any.whl", hash = "sha256:7bcc9ad17d5b3f07b57c78e79d527102d08313caa278a641993acddcb894548d", size = 566094, upload-time = "2025-10-23T12:11:59.557Z" }, ] [[package]] @@ -989,6 +1100,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/9d/0fb148dc4d6fa4a7dd1d8378168d9b4cd8d4560a6fbf6f0121c5fc34eb68/importlib_metadata-8.6.1-py3-none-any.whl", hash = "sha256:02a89390c1e15fdfdc0d7c6b25cb3e62650d0494005c97d6f148bf5b9787525e", size = 26971, upload-time = "2025-01-20T22:21:29.177Z" }, ] +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, +] + [[package]] name = "isodate" version = "0.7.2" @@ -1048,11 +1168,11 @@ wheels = [ [[package]] name = "lark" -version = "1.3.0" +version = "1.3.1" source = { registry = "https://pypi.org/simple" } -sdist = 
{ url = "https://files.pythonhosted.org/packages/1d/37/a13baf0135f348af608c667633cbe5d13aa2c5c15a56ae9ad3e6cba45ae3/lark-1.3.0.tar.gz", hash = "sha256:9a3839d0ca5e1faf7cfa3460e420e859b66bcbde05b634e73c369c8244c5fa48", size = 259551, upload-time = "2025-09-22T13:45:05.072Z" } +sdist = { url = "https://files.pythonhosted.org/packages/da/34/28fff3ab31ccff1fd4f6c7c7b0ceb2b6968d8ea4950663eadcb5720591a0/lark-1.3.1.tar.gz", hash = "sha256:b426a7a6d6d53189d318f2b6236ab5d6429eaf09259f1ca33eb716eed10d2905", size = 382732, upload-time = "2025-10-27T18:25:56.653Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a8/3e/1c6b43277de64fc3c0333b0e72ab7b52ddaaea205210d60d9b9f83c3d0c7/lark-1.3.0-py3-none-any.whl", hash = "sha256:80661f261fb2584a9828a097a2432efd575af27d20be0fd35d17f0fe37253831", size = 113002, upload-time = "2025-09-22T13:45:03.747Z" }, + { url = "https://files.pythonhosted.org/packages/82/3d/14ce75ef66813643812f3093ab17e46d3a206942ce7376d31ec2d36229e7/lark-1.3.1-py3-none-any.whl", hash = "sha256:c629b661023a014c37da873b4ff58a817398d12635d3bbb2c5a03be7fe5d1e12", size = 113151, upload-time = "2025-10-27T18:25:54.882Z" }, ] [[package]] @@ -1207,7 +1327,9 @@ azure-storage-blob = [ { name = "multi-storage-client", extra = ["azure-storage-blob"] }, ] dev = [ + { name = "coverage" }, { name = "myst-parser" }, + { name = "pytest" }, { name = "ruff" }, { name = "soundfile" }, { name = "sphinx", version = "8.1.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, @@ -1244,6 +1366,7 @@ requires-dist = [ { name = "bitstring", marker = "extra == 'av-decode'", specifier = ">=4.2.3" }, { name = "braceexpand" }, { name = "click" }, + { name = "coverage", marker = "extra == 'dev'" }, { name = "ebmlite", marker = "extra == 'av-decode'", specifier = ">=3.3.1" }, { name = "filetype", marker = "extra == 'av-decode'", specifier = ">=1.2.0" }, { name = "filetype", marker = "extra == 'guess-content'", specifier = ">=1.0.0" }, @@ 
-1259,6 +1382,7 @@ requires-dist = [ { name = "numba", marker = "extra == 'tar-patcher'" }, { name = "numpy" }, { name = "pillow", specifier = ">=10.0.1" }, + { name = "pytest", marker = "extra == 'dev'" }, { name = "pyyaml" }, { name = "rapidyaml", specifier = ">=0.10.0" }, { name = "ruff", marker = "extra == 'dev'" }, @@ -1332,7 +1456,7 @@ wheels = [ [[package]] name = "multi-storage-client" -version = "0.33.0" +version = "0.37.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -1345,26 +1469,23 @@ dependencies = [ { name = "python-dateutil" }, { name = "pyyaml" }, { name = "tqdm" }, + { name = "tzdata" }, { name = "wcmatch" }, { name = "xattr" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/5c/c4/6279fb7d4b8b0a7af060047d592f00f8d49c547adfebe50bcd8d0d2dc8a5/multi_storage_client-0.33.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:df52b3040ef5698c6388fa589bd63812ae0d2f967d358a792abcad5638686590", size = 5282006, upload-time = "2025-10-23T03:45:37.761Z" }, - { url = "https://files.pythonhosted.org/packages/22/3b/23d8beccd73b887c4552bf884275611255b5028388fa3317365cd56c2a93/multi_storage_client-0.33.0-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:370da04b1e56a601ba505a29d42fcabc19b583e10d725a37bc0c11ba3573d211", size = 5403083, upload-time = "2025-10-23T03:53:11.998Z" }, - { url = "https://files.pythonhosted.org/packages/b0/ad/dc355d05fd369da0d800e5f7de24da0393f542c5a6f775f6bcee7edcacb1/multi_storage_client-0.33.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c57749a28ec5d49440f465fd73e4e2feaab18ece9b6e57c73395308b41950f66", size = 3178432, upload-time = "2025-10-23T04:07:00.543Z" }, - { url = "https://files.pythonhosted.org/packages/e0/ad/97b54419d8a58f696b85504568391a627641152f80650d7d2697fc2702ed/multi_storage_client-0.33.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:c7d95f5fe094aab00a240bf6aa11dfe85bec293b76b3688ec3a9c33d86c751d2", size = 3351102, upload-time = "2025-10-23T03:47:47.622Z" }, - { url = "https://files.pythonhosted.org/packages/52/28/1038a68b9df1b179a61967ce9f7d2e80b9954cdb289801afecde5f7660db/multi_storage_client-0.33.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4b5a0f5a0b7684835be20ae6782070884982a86665e9bab317375a56a20294d1", size = 5281523, upload-time = "2025-10-23T04:06:36.671Z" }, - { url = "https://files.pythonhosted.org/packages/6c/c5/e18de5e2a2671efdc0a12383b8d63f523044ca453525725b3450d0179c0e/multi_storage_client-0.33.0-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:0db694311f90f44ee8f6f7734a14a0857738a467f2ae201649218a3ecf1f6ab2", size = 5403353, upload-time = "2025-10-23T04:07:25.941Z" }, - { url = "https://files.pythonhosted.org/packages/7e/c9/d9f65eb2370151dbbb06925f4216ee017e6cdbf7657263fd98e60944e52b/multi_storage_client-0.33.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cbe3a0b856f0b968f9fc693670a521b5a995b625351241ca008f866fdfff62a", size = 3180052, upload-time = "2025-10-23T03:57:32.797Z" }, - { url = "https://files.pythonhosted.org/packages/e7/38/08b9d84c93b19ae87caf542ae77f17dfa44a85281ba09de660ffcf3a7718/multi_storage_client-0.33.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:018e7e82255feeff973ff02563f11a30f5e507e4cbc87a2167a9568740144ef2", size = 3351389, upload-time = "2025-10-23T04:02:07.348Z" }, - { url = "https://files.pythonhosted.org/packages/6a/31/c95634a27723b5ba9d2d74158444cc5e40b151b51ae59ca196fc9993f039/multi_storage_client-0.33.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:030b3a592c6352605e9ebdb8d9303dd42daf5d171ffa684f3283d4a5c6e2edfe", size = 5273976, upload-time = "2025-10-23T04:04:35.99Z" }, - { url = "https://files.pythonhosted.org/packages/8c/cf/82d1778d73c3baaec331da4ae8d01fa7934bcd73336aa88a08d86d080347/multi_storage_client-0.33.0-cp312-cp312-macosx_11_0_x86_64.whl", hash = 
"sha256:14dc0ace16d3830917427d6376d14ef62bd053fb2509f893998555ca1e9c4dcb", size = 5400735, upload-time = "2025-10-23T03:58:37.149Z" }, - { url = "https://files.pythonhosted.org/packages/fc/34/a6194ec725ef80c02de58b5ed3520bb1711807df75a27f7214effd22df34/multi_storage_client-0.33.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2821765d5c6de365b5b1dcdc7cf2ebba719ff4061fd02975639629f8aa319f6", size = 3182623, upload-time = "2025-10-23T04:03:29.551Z" }, - { url = "https://files.pythonhosted.org/packages/8f/36/7ec85178fd1dd69c278407a82acaccfb806449deda13f3dbd41f653d73bd/multi_storage_client-0.33.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f92f89480c58067fa53c178785b86e7650e16f277a61a732a8a7019173b16129", size = 3352104, upload-time = "2025-10-23T04:08:51.005Z" }, - { url = "https://files.pythonhosted.org/packages/88/ef/f2eb2efefb0e0588b29ed573b8354ecd72c38e6143da7ed5ecf53e859bf8/multi_storage_client-0.33.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ed9af7e77e3cbac1f614816062b36975dcbc610bd3f8c86741d48aa18c718781", size = 5272154, upload-time = "2025-10-23T04:07:49.572Z" }, - { url = "https://files.pythonhosted.org/packages/1e/49/050aa4fccb2579d2ef5bd0d27169ec98fe85c92bba7a2c31154c491a4f75/multi_storage_client-0.33.0-cp313-cp313-macosx_11_0_x86_64.whl", hash = "sha256:c9d75e95a266ee858cf20c88ed255021552de67a40af9c8884d2fc22037dcd2b", size = 5399474, upload-time = "2025-10-23T04:09:14.545Z" }, - { url = "https://files.pythonhosted.org/packages/f6/4b/70c2df3b60c28360f185188d351e9c3958b702614963a09ffb1dc251c1ca/multi_storage_client-0.33.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48195a2ab9e6e9a2763bde17184cad2bdef82684353e210d0d325f20cea18869", size = 3181788, upload-time = "2025-10-23T04:03:10.404Z" }, - { url = 
"https://files.pythonhosted.org/packages/9b/96/5008852677fdad10eb9d8dd08a6ea58c6f7e820199a3b2c56607186ac6d5/multi_storage_client-0.33.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd64403efdcee2a6efcf7bfdb01422dd174c146014563b09f44590346fd835e6", size = 3351269, upload-time = "2025-10-23T04:00:34.714Z" }, + { url = "https://files.pythonhosted.org/packages/39/8e/f4930451c13c62d3d756bf56a00075f7988361cf0d54287d5d11bcb29451/multi_storage_client-0.37.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:9726e0f53098d4efddd942c23b3efa5a21526ee96c36413b84a50da0266d0084", size = 9770930, upload-time = "2025-12-06T00:03:25.851Z" }, + { url = "https://files.pythonhosted.org/packages/43/37/3841a49bfe67c49e14eff90fe17fa7b1c578c628edbf7e9ca60a01170fc8/multi_storage_client-0.37.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d665cb4504f513dbf358c3a20e8075af4fec4423f4b9e72195012e0603480f63", size = 6029659, upload-time = "2025-12-06T00:00:16.808Z" }, + { url = "https://files.pythonhosted.org/packages/c6/ae/398bc3ffcfe9bffaec040c74d8d6444c0cee74fd0edb83aba6f57161340c/multi_storage_client-0.37.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c105f855be997f8fd1a00c3cc5f9aa2b3019b32c9c44340ce6b4759f9092e263", size = 6436878, upload-time = "2025-12-06T00:01:28.221Z" }, + { url = "https://files.pythonhosted.org/packages/45/04/6cbc509fdb133bbf83b686f85181e9cf2d4ae3dbaa9f807a4bbb0eca50dc/multi_storage_client-0.37.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:2074a351d468061cbe628313caaed89a40d723bb1195367ff8855adaa1ea41bd", size = 9774184, upload-time = "2025-12-06T00:03:53.089Z" }, + { url = "https://files.pythonhosted.org/packages/4c/ce/4dc3b2802e222da2fac91f9ca399948bfda5950738402045de12b9520e58/multi_storage_client-0.37.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bcf5c28e0cfbed1589e6f347257a000299d6555ece1f595e9a2b4b02f49e44bd", size = 6029579, 
upload-time = "2025-12-05T23:58:42.569Z" }, + { url = "https://files.pythonhosted.org/packages/9c/3b/09f3f4c491df6125748a3457ad5e347f96a20bb53a01fbe2eb57337295ea/multi_storage_client-0.37.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61502ea72726cc8389f864010ec2b78c437de7fe364d345351d520851ae4bd1d", size = 6437770, upload-time = "2025-12-05T23:58:00.62Z" }, + { url = "https://files.pythonhosted.org/packages/f5/ae/3e159e5a9cb48032b875869fd39bb7b463df2cb66606fbbc5688918e85ae/multi_storage_client-0.37.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:44e5a36dfb991b3abfac39727912ca919e53dee5d929c8202f273de592286a60", size = 9764143, upload-time = "2025-12-06T00:01:52.312Z" }, + { url = "https://files.pythonhosted.org/packages/3e/a3/e9cbaccbe586c6ae7b173429205e72ea7dd7e05eb96ec3787586f61c264a/multi_storage_client-0.37.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08e076800ea2fd17b90f97b67886e743f502abe729b6087cb77238c9b93cbc6b", size = 6026251, upload-time = "2025-12-05T23:59:49.329Z" }, + { url = "https://files.pythonhosted.org/packages/78/03/86fc443834e591fa3222733a0fd4933f09cd5174fb8e2d5579a45f4da64b/multi_storage_client-0.37.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c181c1ac2af338b623ac158423ae15c2f572a5a4eec0e588ba59e2957f8571bf", size = 6438956, upload-time = "2025-12-05T23:56:25.918Z" }, + { url = "https://files.pythonhosted.org/packages/61/5d/85eafdd0ec69d14894689cc33721eb0b8ead8af7d726565c781380c98306/multi_storage_client-0.37.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:2d08b1caf2e3c1ad693c91ee870096fbfe9f4a6373606dce3901b5f2a66e6836", size = 9762517, upload-time = "2025-12-06T00:06:56.352Z" }, + { url = "https://files.pythonhosted.org/packages/7f/87/518d0645b40e05d58d089e3b1ea551d0bece09e7884aaf241435d620331f/multi_storage_client-0.37.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:9f2bcd46aca73495b6f79ed82ad744f2566e5aacbd80b4fa1fb964bf2c4a1db9", size = 6026238, upload-time = "2025-12-06T00:02:23.012Z" }, + { url = "https://files.pythonhosted.org/packages/ba/90/0957386155cb3e71b827bb13ad197238fd899a4a6f2f5f9694798fab95cd/multi_storage_client-0.37.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89b83b97e9e0f4d8580456643efc426ae640c70330bbba46d3d3d0e4d19816e7", size = 6438506, upload-time = "2025-12-06T00:05:06.91Z" }, ] [package.optional-dependencies] @@ -1741,7 +1862,7 @@ wheels = [ [[package]] name = "oci" -version = "2.161.0" +version = "2.164.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -1751,9 +1872,9 @@ dependencies = [ { name = "python-dateutil" }, { name = "pytz" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0b/a2/0295ef211f8687b85505fb79ab3833ba8d56bb7aaaf2c0568ab289d2edec/oci-2.161.0.tar.gz", hash = "sha256:1322069822babf472feba130da131bce114e9070f95f7c5bf96d034520470c7e", size = 15836650, upload-time = "2025-10-07T06:01:02.165Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1e/d6/c8a0857d882a72e335d3eda7d3955f302a5b862f687c54a2554449cb595d/oci-2.164.0.tar.gz", hash = "sha256:fac58e1d29b36418cf1761826b31e2d152450bfec3c322e7a1d197327faf8bbf", size = 16144855, upload-time = "2025-11-18T06:34:08.556Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b5/7d/1a19fb91620d8dc82529860ec5f40730277749c4967c67cd3c91cb23e247/oci-2.161.0-py3-none-any.whl", hash = "sha256:e189272f165d2ae32d2839ce300f50ad8376a861500cf93e8295a10b51172b94", size = 32331958, upload-time = "2025-10-07T06:00:54.045Z" }, + { url = "https://files.pythonhosted.org/packages/6b/af/70c1475f3f15f2a4ef961a29238aabf5d953a3a5e7552c886d14d0987076/oci-2.164.0-py3-none-any.whl", hash = "sha256:3eb055f4b472655067fbcee99c01aa1b2c73d89f415fa1034a4479333a7581ae", size = 32970538, upload-time = "2025-11-18T06:33:59.657Z" }, ] [[package]] @@ -1864,6 
+1985,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/21/2c/5e05f58658cf49b6667762cca03d6e7d85cededde2caf2ab37b81f80e574/pillow-11.2.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:208653868d5c9ecc2b327f9b9ef34e0e42a4cdd172c2988fd81d62d2bc9bc044", size = 2674751, upload-time = "2025-04-12T17:49:59.628Z" }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + [[package]] name = "pockets" version = "0.9.1" @@ -1878,14 +2008,14 @@ wheels = [ [[package]] name = "prettytable" -version = "3.16.0" +version = "3.17.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "wcwidth" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/99/b1/85e18ac92afd08c533603e3393977b6bc1443043115a47bb094f3b98f94f/prettytable-3.16.0.tar.gz", hash = "sha256:3c64b31719d961bf69c9a7e03d0c1e477320906a98da63952bc6698d6164ff57", size = 66276, upload-time = "2025-03-24T19:39:04.008Z" } +sdist = { url = "https://files.pythonhosted.org/packages/79/45/b0847d88d6cfeb4413566738c8bbf1e1995fad3d42515327ff32cc1eb578/prettytable-3.17.0.tar.gz", hash = "sha256:59f2590776527f3c9e8cf9fe7b66dd215837cca96a9c39567414cbc632e8ddb0", size = 67892, upload-time = "2025-11-14T17:33:20.212Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/02/c7/5613524e606ea1688b3bdbf48aa64bafb6d0a4ac3750274c43b6158a390f/prettytable-3.16.0-py3-none-any.whl", hash = "sha256:b5eccfabb82222f5aa46b798ff02a8452cf530a352c31bddfa29be41242863aa", size = 33863, upload-time = "2025-03-24T19:39:02.359Z" }, + { url = "https://files.pythonhosted.org/packages/ee/8c/83087ebc47ab0396ce092363001fa37c17153119ee282700c0713a195853/prettytable-3.17.0-py3-none-any.whl", hash = "sha256:aad69b294ddbe3e1f95ef8886a060ed1666a0b83018bbf56295f6f226c43d287", size = 34433, upload-time = "2025-11-14T17:33:19.093Z" }, ] [[package]] @@ -2005,18 +2135,28 @@ wheels = [ [[package]] name = "psutil" -version = "7.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b3/31/4723d756b59344b643542936e37a31d1d3204bcdc42a7daa8ee9eb06fb50/psutil-7.1.0.tar.gz", hash = "sha256:655708b3c069387c8b77b072fc429a57d0e214221d01c0a772df7dfedcb3bcd2", size = 497660, upload-time = "2025-09-17T20:14:52.902Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/46/62/ce4051019ee20ce0ed74432dd73a5bb087a6704284a470bb8adff69a0932/psutil-7.1.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:76168cef4397494250e9f4e73eb3752b146de1dd950040b29186d0cce1d5ca13", size = 245242, upload-time = "2025-09-17T20:14:56.126Z" }, - { url = "https://files.pythonhosted.org/packages/38/61/f76959fba841bf5b61123fbf4b650886dc4094c6858008b5bf73d9057216/psutil-7.1.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:5d007560c8c372efdff9e4579c2846d71de737e4605f611437255e81efcca2c5", size = 246682, upload-time = "2025-09-17T20:14:58.25Z" }, - { url = "https://files.pythonhosted.org/packages/88/7a/37c99d2e77ec30d63398ffa6a660450b8a62517cabe44b3e9bae97696e8d/psutil-7.1.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22e4454970b32472ce7deaa45d045b34d3648ce478e26a04c7e858a0a6e75ff3", size = 287994, upload-time = 
"2025-09-17T20:14:59.901Z" }, - { url = "https://files.pythonhosted.org/packages/9d/de/04c8c61232f7244aa0a4b9a9fbd63a89d5aeaf94b2fc9d1d16e2faa5cbb0/psutil-7.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c70e113920d51e89f212dd7be06219a9b88014e63a4cec69b684c327bc474e3", size = 291163, upload-time = "2025-09-17T20:15:01.481Z" }, - { url = "https://files.pythonhosted.org/packages/f4/58/c4f976234bf6d4737bc8c02a81192f045c307b72cf39c9e5c5a2d78927f6/psutil-7.1.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d4a113425c037300de3ac8b331637293da9be9713855c4fc9d2d97436d7259d", size = 293625, upload-time = "2025-09-17T20:15:04.492Z" }, - { url = "https://files.pythonhosted.org/packages/79/87/157c8e7959ec39ced1b11cc93c730c4fb7f9d408569a6c59dbd92ceb35db/psutil-7.1.0-cp37-abi3-win32.whl", hash = "sha256:09ad740870c8d219ed8daae0ad3b726d3bf9a028a198e7f3080f6a1888b99bca", size = 244812, upload-time = "2025-09-17T20:15:07.462Z" }, - { url = "https://files.pythonhosted.org/packages/bf/e9/b44c4f697276a7a95b8e94d0e320a7bf7f3318521b23de69035540b39838/psutil-7.1.0-cp37-abi3-win_amd64.whl", hash = "sha256:57f5e987c36d3146c0dd2528cd42151cf96cd359b9d67cfff836995cc5df9a3d", size = 247965, upload-time = "2025-09-17T20:15:09.673Z" }, - { url = "https://files.pythonhosted.org/packages/26/65/1070a6e3c036f39142c2820c4b52e9243246fcfc3f96239ac84472ba361e/psutil-7.1.0-cp37-abi3-win_arm64.whl", hash = "sha256:6937cb68133e7c97b6cc9649a570c9a18ba0efebed46d8c5dae4c07fa1b67a07", size = 244971, upload-time = "2025-09-17T20:15:12.262Z" }, +version = "7.1.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e1/88/bdd0a41e5857d5d703287598cbf08dad90aed56774ea52ae071bae9071b6/psutil-7.1.3.tar.gz", hash = "sha256:6c86281738d77335af7aec228328e944b30930899ea760ecf33a4dba66be5e74", size = 489059, upload-time = "2025-11-02T12:25:54.619Z" } +wheels = [ + { 
url = "https://files.pythonhosted.org/packages/bd/93/0c49e776b8734fef56ec9c5c57f923922f2cf0497d62e0f419465f28f3d0/psutil-7.1.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0005da714eee687b4b8decd3d6cc7c6db36215c9e74e5ad2264b90c3df7d92dc", size = 239751, upload-time = "2025-11-02T12:25:58.161Z" }, + { url = "https://files.pythonhosted.org/packages/6f/8d/b31e39c769e70780f007969815195a55c81a63efebdd4dbe9e7a113adb2f/psutil-7.1.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:19644c85dcb987e35eeeaefdc3915d059dac7bd1167cdcdbf27e0ce2df0c08c0", size = 240368, upload-time = "2025-11-02T12:26:00.491Z" }, + { url = "https://files.pythonhosted.org/packages/62/61/23fd4acc3c9eebbf6b6c78bcd89e5d020cfde4acf0a9233e9d4e3fa698b4/psutil-7.1.3-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:95ef04cf2e5ba0ab9eaafc4a11eaae91b44f4ef5541acd2ee91d9108d00d59a7", size = 287134, upload-time = "2025-11-02T12:26:02.613Z" }, + { url = "https://files.pythonhosted.org/packages/30/1c/f921a009ea9ceb51aa355cb0cc118f68d354db36eae18174bab63affb3e6/psutil-7.1.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1068c303be3a72f8e18e412c5b2a8f6d31750fb152f9cb106b54090296c9d251", size = 289904, upload-time = "2025-11-02T12:26:05.207Z" }, + { url = "https://files.pythonhosted.org/packages/a6/82/62d68066e13e46a5116df187d319d1724b3f437ddd0f958756fc052677f4/psutil-7.1.3-cp313-cp313t-win_amd64.whl", hash = "sha256:18349c5c24b06ac5612c0428ec2a0331c26443d259e2a0144a9b24b4395b58fa", size = 249642, upload-time = "2025-11-02T12:26:07.447Z" }, + { url = "https://files.pythonhosted.org/packages/df/ad/c1cd5fe965c14a0392112f68362cfceb5230819dbb5b1888950d18a11d9f/psutil-7.1.3-cp313-cp313t-win_arm64.whl", hash = "sha256:c525ffa774fe4496282fb0b1187725793de3e7c6b29e41562733cae9ada151ee", size = 245518, upload-time = "2025-11-02T12:26:09.719Z" }, + { url = 
"https://files.pythonhosted.org/packages/2e/bb/6670bded3e3236eb4287c7bcdc167e9fae6e1e9286e437f7111caed2f909/psutil-7.1.3-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b403da1df4d6d43973dc004d19cee3b848e998ae3154cc8097d139b77156c353", size = 239843, upload-time = "2025-11-02T12:26:11.968Z" }, + { url = "https://files.pythonhosted.org/packages/b8/66/853d50e75a38c9a7370ddbeefabdd3d3116b9c31ef94dc92c6729bc36bec/psutil-7.1.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:ad81425efc5e75da3f39b3e636293360ad8d0b49bed7df824c79764fb4ba9b8b", size = 240369, upload-time = "2025-11-02T12:26:14.358Z" }, + { url = "https://files.pythonhosted.org/packages/41/bd/313aba97cb5bfb26916dc29cf0646cbe4dd6a89ca69e8c6edce654876d39/psutil-7.1.3-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8f33a3702e167783a9213db10ad29650ebf383946e91bc77f28a5eb083496bc9", size = 288210, upload-time = "2025-11-02T12:26:16.699Z" }, + { url = "https://files.pythonhosted.org/packages/c2/fa/76e3c06e760927a0cfb5705eb38164254de34e9bd86db656d4dbaa228b04/psutil-7.1.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fac9cd332c67f4422504297889da5ab7e05fd11e3c4392140f7370f4208ded1f", size = 291182, upload-time = "2025-11-02T12:26:18.848Z" }, + { url = "https://files.pythonhosted.org/packages/0f/1d/5774a91607035ee5078b8fd747686ebec28a962f178712de100d00b78a32/psutil-7.1.3-cp314-cp314t-win_amd64.whl", hash = "sha256:3792983e23b69843aea49c8f5b8f115572c5ab64c153bada5270086a2123c7e7", size = 250466, upload-time = "2025-11-02T12:26:21.183Z" }, + { url = "https://files.pythonhosted.org/packages/00/ca/e426584bacb43a5cb1ac91fae1937f478cd8fbe5e4ff96574e698a2c77cd/psutil-7.1.3-cp314-cp314t-win_arm64.whl", hash = "sha256:31d77fcedb7529f27bb3a0472bea9334349f9a04160e8e6e5020f22c59893264", size = 245756, upload-time = "2025-11-02T12:26:23.148Z" }, + { url = 
"https://files.pythonhosted.org/packages/ef/94/46b9154a800253e7ecff5aaacdf8ebf43db99de4a2dfa18575b02548654e/psutil-7.1.3-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2bdbcd0e58ca14996a42adf3621a6244f1bb2e2e528886959c72cf1e326677ab", size = 238359, upload-time = "2025-11-02T12:26:25.284Z" }, + { url = "https://files.pythonhosted.org/packages/68/3a/9f93cff5c025029a36d9a92fef47220ab4692ee7f2be0fba9f92813d0cb8/psutil-7.1.3-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:bc31fa00f1fbc3c3802141eede66f3a2d51d89716a194bf2cd6fc68310a19880", size = 239171, upload-time = "2025-11-02T12:26:27.23Z" }, + { url = "https://files.pythonhosted.org/packages/ce/b1/5f49af514f76431ba4eea935b8ad3725cdeb397e9245ab919dbc1d1dc20f/psutil-7.1.3-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3bb428f9f05c1225a558f53e30ccbad9930b11c3fc206836242de1091d3e7dd3", size = 263261, upload-time = "2025-11-02T12:26:29.48Z" }, + { url = "https://files.pythonhosted.org/packages/e0/95/992c8816a74016eb095e73585d747e0a8ea21a061ed3689474fabb29a395/psutil-7.1.3-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:56d974e02ca2c8eb4812c3f76c30e28836fffc311d55d979f1465c1feeb2b68b", size = 264635, upload-time = "2025-11-02T12:26:31.74Z" }, + { url = "https://files.pythonhosted.org/packages/55/4c/c3ed1a622b6ae2fd3c945a366e64eb35247a31e4db16cf5095e269e8eb3c/psutil-7.1.3-cp37-abi3-win_amd64.whl", hash = "sha256:f39c2c19fe824b47484b96f9692932248a54c43799a84282cfe58d05a6449efd", size = 247633, upload-time = "2025-11-02T12:26:33.887Z" }, + { url = "https://files.pythonhosted.org/packages/c9/ad/33b2ccec09bf96c2b2ef3f9a6f66baac8253d7565d8839e024a6b905d45d/psutil-7.1.3-cp37-abi3-win_arm64.whl", hash = "sha256:bd0d69cee829226a761e92f28140bec9a5ee9d5b4fb4b0cc589068dbfff559b1", size = 244608, upload-time = "2025-11-02T12:26:36.136Z" }, ] [[package]] @@ -2173,6 +2313,24 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/9e/de/f8342b68fa9e981d348039954657bdf681b2ab93de27443be51865ffa310/pyOpenSSL-19.1.0-py2.py3-none-any.whl", hash = "sha256:621880965a720b8ece2f1b2f54ea2071966ab00e2970ad2ce11d596102063504", size = 53749, upload-time = "2019-11-18T04:59:37.93Z" }, ] +[[package]] +name = "pytest" +version = "8.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -2900,6 +3058,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/17/69/cd203477f944c353c31bade965f880aa1061fd6bf05ded0726ca845b6ff7/typing_inspection-0.4.1-py3-none-any.whl", hash = "sha256:389055682238f53b04f7badcb49b989835495a96700ced5dab2d8feae4b26f51", size = 14552, upload-time = "2025-05-21T18:55:22.152Z" }, ] +[[package]] +name = "tzdata" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/32/1a225d6164441be760d75c2c42e2780dc0873fe382da3e98a2e1e48361e5/tzdata-2025.2.tar.gz", hash = 
"sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9", size = 196380, upload-time = "2025-03-23T13:54:43.652Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839, upload-time = "2025-03-23T13:54:41.845Z" }, +] + [[package]] name = "urllib3" version = "2.4.0" @@ -3010,54 +3177,55 @@ wheels = [ [[package]] name = "xattr" -version = "1.2.0" +version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/50/65/14438ae55acf7f8fc396ee8340d740a3e1d6ef382bf25bf24156cfb83563/xattr-1.2.0.tar.gz", hash = "sha256:a64c8e21eff1be143accf80fd3b8fde3e28a478c37da298742af647ac3e5e0a7", size = 17293, upload-time = "2025-07-14T03:15:44.884Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/08/cd/a7db5dc24e03074f02457c76ddcd35f721db2fe9945bafa058b8796056dc/xattr-1.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3df4d8d91e2996c3c72a390ec82e8544acdcb6c7df67b954f1736ff37ea4293e", size = 24248, upload-time = "2025-07-14T03:14:23.279Z" }, - { url = "https://files.pythonhosted.org/packages/5a/6c/236b7be6afe3f2fae6a0834f3ddca3d1cd7695d76247312069a7247f8a5a/xattr-1.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f5eec248976bbfa6c23df25d4995413df57dccf4161f6cbae36f643e99dbc397", size = 19213, upload-time = "2025-07-14T03:14:24.472Z" }, - { url = "https://files.pythonhosted.org/packages/4a/db/776dc933799addf692a8e1a2094f87f5615a5b7de3a4ec83a264a1a23783/xattr-1.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fafecfdedf7e8d455443bec2c3edab8a93d64672619cd1a4ee043a806152e19c", size = 19547, upload-time = "2025-07-14T03:14:25.619Z" }, - { url = 
"https://files.pythonhosted.org/packages/df/51/6e40331e5effd8f592cab3a6001eb91c9f023ab0c2c1f54cc076e90eee36/xattr-1.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c229e245c6c9a85d2fd7d07531498f837dd34670e556b552f73350f11edf000c", size = 39433, upload-time = "2025-07-14T03:14:27.143Z" }, - { url = "https://files.pythonhosted.org/packages/5e/0d/7e072a6d30434e93c0046ef1267229162445f15485a1a133dcc9efde3b60/xattr-1.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:376631e2383918fbc3dc9bcaeb9a533e319322d2cff1c119635849edf74e1126", size = 37315, upload-time = "2025-07-14T03:14:28.274Z" }, - { url = "https://files.pythonhosted.org/packages/51/5b/be272ba051442fb308494675a8e49b69c04cb97123d257eac810cfabe0ba/xattr-1.2.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fbae24ab22afe078d549645501ecacaa17229e0b7769c8418fad69b51ad37c9", size = 39222, upload-time = "2025-07-14T03:14:29.676Z" }, - { url = "https://files.pythonhosted.org/packages/48/50/5e0e900461ada1628d7909da5a21189087fd2ae80d313983d4cd55631d70/xattr-1.2.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a161160211081d765ac41fa056f4f9b1051f027f08188730fbc9782d0dce623e", size = 38679, upload-time = "2025-07-14T03:14:31.061Z" }, - { url = "https://files.pythonhosted.org/packages/1e/6c/e76b0fb90934fbc991efd5f4c0d1f1e41e8ed9d53f2a141f1626eae0f101/xattr-1.2.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a542acf6c4e8221664b51b35e0160c44bd0ed1f2fd80019476f7698f4911e560", size = 37069, upload-time = "2025-07-14T03:14:32.456Z" }, - { url = "https://files.pythonhosted.org/packages/8f/1a/ea62d888abf8850baba65ebea887f70de486c10a7b854e87091a15c0939f/xattr-1.2.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:034f075fc5a9391a1597a6c9a21cb57b688680f0f18ecf73b2efc22b8d330cff", size = 38276, upload-time = "2025-07-14T03:14:33.852Z" }, - { url = 
"https://files.pythonhosted.org/packages/5d/e2/bf74df197a415f25e07378bfa301788e3bf2ac029c3a6c7bd56a900934ff/xattr-1.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:00c26c14c90058338993bb2d3e1cebf562e94ec516cafba64a8f34f74b9d18b4", size = 24246, upload-time = "2025-07-14T03:14:34.873Z" }, - { url = "https://files.pythonhosted.org/packages/a5/51/922df424556ff35b20ca043da5e4dcf0f99cbcb674f59046d08ceff3ebc7/xattr-1.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b4f43dc644db87d5eb9484a9518c34a864cb2e588db34cffc42139bf55302a1c", size = 19212, upload-time = "2025-07-14T03:14:35.905Z" }, - { url = "https://files.pythonhosted.org/packages/7c/72/1ed37812e8285c8002b8834395c53cc89a2d83aa088db642b217be439017/xattr-1.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c7602583fc643ca76576498e2319c7cef0b72aef1936701678589da6371b731b", size = 19546, upload-time = "2025-07-14T03:14:37.242Z" }, - { url = "https://files.pythonhosted.org/packages/d4/b8/ec75db23d81beec68e3be20ea176c11f125697d3bbb5e118b9de9ea7a9ab/xattr-1.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90c3ad4a9205cceb64ec54616aa90aa42d140c8ae3b9710a0aaa2843a6f1aca7", size = 39426, upload-time = "2025-07-14T03:14:38.264Z" }, - { url = "https://files.pythonhosted.org/packages/d4/9f/c24950641b138072eda7f34d86966dd15cfe3af9a111b5e77b85ee55f99c/xattr-1.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:83d87cfe19cd606fc0709d45a4d6efc276900797deced99e239566926a5afedf", size = 37311, upload-time = "2025-07-14T03:14:39.347Z" }, - { url = "https://files.pythonhosted.org/packages/d0/d5/3b7e0dab706d09c6cdb2f05384610e6c5693c72e3794d54a4cad8c838373/xattr-1.2.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c67dabd9ddc04ead63fbc85aed459c9afcc24abfc5bb3217fff7ec9a466faacb", size = 39222, upload-time = "2025-07-14T03:14:40.768Z" }, - { url = 
"https://files.pythonhosted.org/packages/0e/16/80cf8ec7d92d20b2860c96a1eca18d25e27fa4770f32c9e8250ff32e7386/xattr-1.2.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9a18ee82d8ba2c17f1e8414bfeb421fa763e0fb4acbc1e124988ca1584ad32d5", size = 38694, upload-time = "2025-07-14T03:14:41.93Z" }, - { url = "https://files.pythonhosted.org/packages/38/c0/b154b254e6e4596aed3210dd48b2e82d958b16d9a7f65346b9154968d2d0/xattr-1.2.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:38de598c47b85185e745986a061094d2e706e9c2d9022210d2c738066990fe91", size = 37055, upload-time = "2025-07-14T03:14:43.435Z" }, - { url = "https://files.pythonhosted.org/packages/dc/1d/3a615660849ef9bdf46d04f9c6d40ee082f7427678013ff85452ed9497db/xattr-1.2.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:15e754e854bdaac366ad3f1c8fbf77f6668e8858266b4246e8c5f487eeaf1179", size = 38275, upload-time = "2025-07-14T03:14:45.18Z" }, - { url = "https://files.pythonhosted.org/packages/37/e5/b048a5f6c5a489915026b70b9133242a2a368383ddab24e4e3a5bdba7ebd/xattr-1.2.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:daff0c1f5c5e4eaf758c56259c4f72631fa9619875e7a25554b6077dc73da964", size = 24240, upload-time = "2025-07-14T03:14:46.173Z" }, - { url = "https://files.pythonhosted.org/packages/cc/f5/d795774f719a0be6137041d4833ca00b178f816e538948548dff79530f34/xattr-1.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:109b11fb3f73a0d4e199962f11230ab5f462e85a8021874f96c1732aa61148d5", size = 19218, upload-time = "2025-07-14T03:14:47.412Z" }, - { url = "https://files.pythonhosted.org/packages/cb/8b/65f3bed09ca9ced27bbba8d4a3326f14a58b98ac102163d85b545f81d9c2/xattr-1.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7c7c12968ce0bf798d8ba90194cef65de768bee9f51a684e022c74cab4218305", size = 19539, upload-time = "2025-07-14T03:14:48.413Z" }, - { url = 
"https://files.pythonhosted.org/packages/96/2d/01ecfdf41ce70f7e29c8a21e730de3c157fb1cb84391923581af81a44c45/xattr-1.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d37989dabf25ff18773e4aaeebcb65604b9528f8645f43e02bebaa363e3ae958", size = 39631, upload-time = "2025-07-14T03:14:49.665Z" }, - { url = "https://files.pythonhosted.org/packages/c9/e9/15cbf9c59cf1117e3c45dd429c52f9dab25d95e65ac245c5ad9532986bec/xattr-1.2.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:165de92b0f2adafb336f936931d044619b9840e35ba01079f4dd288747b73714", size = 37552, upload-time = "2025-07-14T03:14:50.718Z" }, - { url = "https://files.pythonhosted.org/packages/9d/f5/cb4dad87843fe79d605cf5d10caad80e2c338a06f0363f1443449185f489/xattr-1.2.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82191c006ae4c609b22b9aea5f38f68fff022dc6884c4c0e1dba329effd4b288", size = 39472, upload-time = "2025-07-14T03:14:51.74Z" }, - { url = "https://files.pythonhosted.org/packages/5a/d9/012df7b814cc4a0ae41afb59ac31d0469227397b29f58c1377e8db0f34ba/xattr-1.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2b2e9c87dc643b09d86befad218e921f6e65b59a4668d6262b85308de5dbd1dd", size = 38802, upload-time = "2025-07-14T03:14:52.801Z" }, - { url = "https://files.pythonhosted.org/packages/d8/08/e107a5d294a816586f274c33aea480fe740fd446276efc84c067e6c82de2/xattr-1.2.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:14edd5d47d0bb92b23222c0bb6379abbddab01fb776b2170758e666035ecf3aa", size = 37125, upload-time = "2025-07-14T03:14:54.313Z" }, - { url = "https://files.pythonhosted.org/packages/3e/6c/a6f9152e10543af67ea277caae7c5a6400a581e407c42156ffce71dd8242/xattr-1.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:12183d5eb104d4da787638c7dadf63b718472d92fec6dbe12994ea5d094d7863", size = 38456, upload-time = "2025-07-14T03:14:55.383Z" }, - { url = 
"https://files.pythonhosted.org/packages/b6/f9/6c98102949691f7e9caf9a31118be6e46720a23049f417dcf77cc689d06e/xattr-1.2.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:c385ea93a18aeb6443a719eb6a6b1d7f7b143a4d1f2b08bc4fadfc429209e629", size = 24242, upload-time = "2025-07-14T03:14:56.392Z" }, - { url = "https://files.pythonhosted.org/packages/22/6a/130f6cd5cbb0ea0e470c9b366a21b9474eb607288fd17256d60e50f05d0b/xattr-1.2.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2d39d7b36842c67ab3040bead7eb6d601e35fa0d6214ed20a43df4ec30b6f9f9", size = 19219, upload-time = "2025-07-14T03:14:57.367Z" }, - { url = "https://files.pythonhosted.org/packages/3d/40/93f2dd033544028e7b9512b8b9fb6872ec74a804fbb686e62b83fdf72e21/xattr-1.2.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:320ef856bb817f4c40213b6de956dc440d0f23cdc62da3ea02239eb5147093f8", size = 19538, upload-time = "2025-07-14T03:14:58.434Z" }, - { url = "https://files.pythonhosted.org/packages/13/d5/7e301840afb7e3d3ad07b95af1815c7b674373d1f7d95cb6f2ecc794fdb1/xattr-1.2.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26d306bfb3b5641726f2ee0da6f63a2656aa7fdcfd15de61c476e3ca6bc3277e", size = 39544, upload-time = "2025-07-14T03:14:59.66Z" }, - { url = "https://files.pythonhosted.org/packages/50/19/64a1b02d237126c3198257ebd7c643374d928915a86d36db7ad4da0a4f28/xattr-1.2.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c67e70d5d8136d328ad13f85b887ffa97690422f1a11fb29ab2f702cf66e825a", size = 37468, upload-time = "2025-07-14T03:15:01.096Z" }, - { url = "https://files.pythonhosted.org/packages/59/53/f794e3630cf16840e199f086520aca6a59a30f9428b1423a8581bc9cee9d/xattr-1.2.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8904d3539afe1a84fc0b7f02fa91da60d2505adf2d5951dc855bf9e75fe322b2", size = 39378, upload-time = "2025-07-14T03:15:02.149Z" }, - { url = 
"https://files.pythonhosted.org/packages/f0/a2/ee2d1cdba5e5273886b9f157cb7ef5ba6d83b177d0c17a203d3ac11ee7f7/xattr-1.2.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:2520516c1d058895eae00b2b2f10833514caea6dc6802eef1e431c474b5317ad", size = 38797, upload-time = "2025-07-14T03:15:03.206Z" }, - { url = "https://files.pythonhosted.org/packages/73/28/9216ba5a4485561cf628ea8f7a0753f246e7f0df31656a1cf363c1b7bed4/xattr-1.2.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:29d06abbef4024b7469fcd0d4ade6d2290582350a4df95fcc48fa48b2e83246b", size = 37142, upload-time = "2025-07-14T03:15:04.249Z" }, - { url = "https://files.pythonhosted.org/packages/fd/20/dee2ec6153323592e33f2b82c8c0f0946b9d1989e3c521a9f3d6daac47e5/xattr-1.2.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:093c75f7d9190be355b8e86da3f460b9bfe3d6a176f92852d44dcc3289aa10dc", size = 38462, upload-time = "2025-07-14T03:15:05.387Z" }, - { url = "https://files.pythonhosted.org/packages/d2/aa/5ea6dd94d0ea7affdd57a6eeb88a9e62a6b600e76aff03d32e89474b7c2c/xattr-1.2.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:29ae44247d46e63671311bf7e700826a97921278e2c0c04c2d11741888db41b8", size = 15938, upload-time = "2025-07-14T03:15:27.426Z" }, - { url = "https://files.pythonhosted.org/packages/24/a4/5bab900c0b715b96bfdd16f0b9d160ae8f7e2065d3ff74e9497087d21828/xattr-1.2.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:629c42c1dd813442d90f281f69b88ef0c9625f604989bef8411428671f70f43e", size = 16428, upload-time = "2025-07-14T03:15:28.439Z" }, - { url = "https://files.pythonhosted.org/packages/dd/14/70d531b536d6aea9032b1ed4fd241be6a59301a86082564c6bbd7bbdc80c/xattr-1.2.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:549f8fbda5da48cafc81ba6ab7bb8e8e14c4b0748c37963dc504bcae505474b7", size = 18286, upload-time = "2025-07-14T03:15:29.653Z" }, - { url = 
"https://files.pythonhosted.org/packages/26/a4/1b2e04ea684fc081183eca6faff485da5ab87b25b4dcfcc4164ae87865a1/xattr-1.2.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aa83e677b5f92a3c5c86eaf875e9d3abbc43887ff1767178def865fa9f12a3a0", size = 17997, upload-time = "2025-07-14T03:15:30.997Z" }, - { url = "https://files.pythonhosted.org/packages/1f/03/75a399549e82b6a20ff84d71ee9e777caf6bc687e8004d8b3699565a6aad/xattr-1.2.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb669f01627962ce2bc556f19d421162247bc2cad0d4625d6ea5eb32af4cf29b", size = 17908, upload-time = "2025-07-14T03:15:32.335Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/08/d5/25f7b19af3a2cb4000cac4f9e5525a40bec79f4f5d0ac9b517c0544586a0/xattr-1.3.0.tar.gz", hash = "sha256:30439fabd7de0787b27e9a6e1d569c5959854cb322f64ce7380fedbfa5035036", size = 17148, upload-time = "2025-10-13T22:16:47.353Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/11/bbb25ab921e02efb789efcab5b7d03581b5d28f71d829f21e4ea6aba09fb/xattr-1.3.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a80c4617e08670cdc3ba71f1dbb275c1627744c5c3641280879cb3bc95a07237", size = 23453, upload-time = "2025-10-13T22:15:50.753Z" }, + { url = "https://files.pythonhosted.org/packages/be/88/66021fdfbb2037a94fc5b16c1dce1894b8e9da7a1829e4be0b491b3f24ff/xattr-1.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:51cdaa359f5cd2861178ae01ea3647b56dbdfd98e724a8aa3c04f77123b78217", size = 18551, upload-time = "2025-10-13T22:15:51.961Z" }, + { url = "https://files.pythonhosted.org/packages/be/f7/5dd21fcfc48487a59fcec33ffe02eb671f256424869e9aef87e33c65d95b/xattr-1.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2fea070768d7d2d25797817bea93bf0a6fda6449e88cfee8bb3d75de9ed11c7b", size = 18852, upload-time = "2025-10-13T22:15:53.104Z" }, + { url = 
"https://files.pythonhosted.org/packages/af/2a/e29753ac17a92aadf27b9e16b1d600584d9f10acd0b399d2c06f47af2dff/xattr-1.3.0-cp310-cp310-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:69bca34be2d7a928389aff4e32f27857e1c62d04c91ec7c1519b1636870bd58f", size = 38547, upload-time = "2025-10-13T22:15:54.385Z" }, + { url = "https://files.pythonhosted.org/packages/f4/46/b2c9185d24b93542e4307ce30cd3d4eb6af8efdc843d98ff9f07fcb048d9/xattr-1.3.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:05f8e068409742d246babba60cff8310b2c577745491f498b08bf068e0c867a3", size = 38755, upload-time = "2025-10-13T22:15:55.738Z" }, + { url = "https://files.pythonhosted.org/packages/c0/0a/93cf1f03536bf38e8fd3fe57eb04124e4dfe2e16c0c5ced589d3360a1858/xattr-1.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:bbd06987102bc11f5cbd08b15d1029832b862cf5bc61780573fc0828812f01ca", size = 38052, upload-time = "2025-10-13T22:15:57.031Z" }, + { url = "https://files.pythonhosted.org/packages/55/ad/60e43f7e1037cee671e14c2a283e3e7168b756c9938eba62f0616e6599aa/xattr-1.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b8589744116d2c37928b771c50383cb281675cd6dcfd740abfab6883e3d4af85", size = 37560, upload-time = "2025-10-13T22:15:58.295Z" }, + { url = "https://files.pythonhosted.org/packages/8a/64/292426ad5653e72c6e1325bbff22868a20077290d967cebb9c0624ad08b6/xattr-1.3.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:331a51bf8f20c27822f44054b0d760588462d3ed472d5e52ba135cf0bea510e8", size = 23448, upload-time = "2025-10-13T22:15:59.229Z" }, + { url = "https://files.pythonhosted.org/packages/63/84/6539fbe620da8e5927406e76b9c8abad8953025d5f578d792747c38a8c0e/xattr-1.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:196360f068b74fa0132a8c6001ce1333f095364b8f43b6fd8cdaf2f18741ef89", size = 18553, upload-time = "2025-10-13T22:16:00.151Z" }, + { url = 
"https://files.pythonhosted.org/packages/cc/bb/c1c2e24a49f8d13ff878fb85aabc42ea1b2f98ce08d8205b9661d517a9cc/xattr-1.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:405d2e4911d37f2b9400fa501acd920fe0c97fe2b2ec252cb23df4b59c000811", size = 18848, upload-time = "2025-10-13T22:16:01.046Z" }, + { url = "https://files.pythonhosted.org/packages/02/c2/a60aad150322b217dfe33695d8d9f32bc01e8f300641b6ba4b73f4b3c03f/xattr-1.3.0-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:4ae3a66ae1effd40994f64defeeaa97da369406485e60bfb421f2d781be3b75d", size = 38547, upload-time = "2025-10-13T22:16:01.973Z" }, + { url = "https://files.pythonhosted.org/packages/c6/58/2eca142bad4ea0a2be6b58d3122d0acce310c4e53fa7defd168202772178/xattr-1.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:69cd3bfe779f7ba87abe6473fdfa428460cf9e78aeb7e390cfd737b784edf1b5", size = 38753, upload-time = "2025-10-13T22:16:03.244Z" }, + { url = "https://files.pythonhosted.org/packages/2b/50/d032e5254c2c27d36bdb02abdf2735db6768a441f0e3d0f139e0f9f56638/xattr-1.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c5742ca61761a99ae0c522f90a39d5fb8139280f27b254e3128482296d1df2db", size = 38054, upload-time = "2025-10-13T22:16:04.656Z" }, + { url = "https://files.pythonhosted.org/packages/04/24/458a306439aabe0083ca0a7b14c3e6a800ab9782b5ec0bdcec4ec9f3dc6c/xattr-1.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4a04ada131e9bdfd32db3ab1efa9f852646f4f7c9d6fde0596c3825c67161be3", size = 37562, upload-time = "2025-10-13T22:16:05.97Z" }, + { url = "https://files.pythonhosted.org/packages/bf/78/00bdc9290066173e53e1e734d8d8e1a84a6faa9c66aee9df81e4d9aeec1c/xattr-1.3.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:dd4e63614722d183e81842cb237fd1cc978d43384166f9fe22368bfcb187ebe5", size = 23476, upload-time = "2025-10-13T22:16:06.942Z" }, + { url = 
"https://files.pythonhosted.org/packages/53/16/5243722294eb982514fa7b6b87a29dfb7b29b8e5e1486500c5babaf6e4b3/xattr-1.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:995843ef374af73e3370b0c107319611f3cdcdb6d151d629449efecad36be4c4", size = 18556, upload-time = "2025-10-13T22:16:08.209Z" }, + { url = "https://files.pythonhosted.org/packages/d6/5c/d7ab0e547bea885b55f097206459bd612cefb652c5fc1f747130cbc0d42c/xattr-1.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fa23a25220e29d956cedf75746e3df6cc824cc1553326d6516479967c540e386", size = 18869, upload-time = "2025-10-13T22:16:10.319Z" }, + { url = "https://files.pythonhosted.org/packages/98/25/25cc7d64f07de644b7e9057842227adf61017e5bcfe59a79df79f768874c/xattr-1.3.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b4345387087fffcd28f709eb45aae113d911e1a1f4f0f70d46b43ba81e69ccdd", size = 38797, upload-time = "2025-10-13T22:16:11.624Z" }, + { url = "https://files.pythonhosted.org/packages/a9/24/cc350bcdbed006dfcc6ade0ac817693b8b3d4b2787f20e427fd0697042e4/xattr-1.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fe92bb05eb849ab468fe13e942be0f8d7123f15d074f3aba5223fad0c4b484de", size = 38956, upload-time = "2025-10-13T22:16:13.121Z" }, + { url = "https://files.pythonhosted.org/packages/9b/b2/9416317ac89e2ed759a861857cda0d5e284c3691e6f460d36cc2bd5ce4d1/xattr-1.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6c42ef5bdac3febbe28d3db14d3a8a159d84ba5daca2b13deae6f9f1fc0d4092", size = 38214, upload-time = "2025-10-13T22:16:14.389Z" }, + { url = "https://files.pythonhosted.org/packages/38/63/188f7cb41ab35d795558325d5cc8ab552171d5498cfb178fd14409651e18/xattr-1.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2aaa5d66af6523332189108f34e966ca120ff816dfa077ca34b31e6263f8a236", size = 37754, upload-time = "2025-10-13T22:16:15.306Z" }, + { url = 
"https://files.pythonhosted.org/packages/27/d3/6a1731a339842afcbb2643bc93628d4ab9c52d1bf26a7b085ca8f35bba6e/xattr-1.3.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:937d8c91f6f372788aff8cc0984c4be3f0928584839aaa15ff1c95d64562071c", size = 23474, upload-time = "2025-10-13T22:16:16.33Z" }, + { url = "https://files.pythonhosted.org/packages/1b/25/6741ed3d4371eaa2fae70b259d17a580d858ebff8af0042a59e11bb6385f/xattr-1.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e470b3f15e9c3e263662506ff26e73b3027e1c9beac2cbe9ab89cad9c70c0495", size = 18558, upload-time = "2025-10-13T22:16:17.251Z" }, + { url = "https://files.pythonhosted.org/packages/ba/84/cc450688abeb8647aa93a62c1435bb532db11313abfeb9d43b28b4751503/xattr-1.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f2238b2a973fcbf5fefa1137db97c296d27f4721f7b7243a1fac51514565e9ec", size = 18869, upload-time = "2025-10-13T22:16:18.607Z" }, + { url = "https://files.pythonhosted.org/packages/b9/49/0e2315225ba7557e9801f9f0168a0195a7e13a3223088081eb32d2760533/xattr-1.3.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f32bb00395371f4a3bed87080ae315b19171ba114e8a5aa403a2c8508998ce78", size = 38702, upload-time = "2025-10-13T22:16:19.539Z" }, + { url = "https://files.pythonhosted.org/packages/7e/8c/de4f4441c318ac38a5d3d7d4b8b940305a667e9320c34a45e57f6eb6b0e8/xattr-1.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:78df56bfe3dd4912548561ed880225437d6d49ef082fe6ccd45670810fa53cfe", size = 38869, upload-time = "2025-10-13T22:16:20.554Z" }, + { url = "https://files.pythonhosted.org/packages/ef/2a/38e0498c22aa733a9b5265f4929af4613e5b967659cf3e5f2f933b3ba118/xattr-1.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:864c34c14728f21c3ef89a9f276d75ae5e31dd34f48064e0d37e4bf0f671fc6e", size = 38210, upload-time = "2025-10-13T22:16:22.212Z" }, + { url = 
"https://files.pythonhosted.org/packages/62/21/49b386eb8dcf42ac8e3ff55b6e8ea0a1e8b6b799571599c795265d2dc1b5/xattr-1.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1fd185b3f01121bd172c98b943f9341ca3b9ea6c6d3eb7fe7074723614d959ff", size = 37753, upload-time = "2025-10-13T22:16:23.959Z" }, + { url = "https://files.pythonhosted.org/packages/24/49/b8bc589427696d67bc2b0992c188e576f70242c586a379f97698772c0c3d/xattr-1.3.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:630c85020282bd0bcb72c3d031491c4e91d7f29bb4c094ebdfb9db51375c5b07", size = 23543, upload-time = "2025-10-13T22:16:25.242Z" }, + { url = "https://files.pythonhosted.org/packages/9d/0a/03192e78071cfb86e6d8ceae0e5dcec4bacf0fd734755263aabd01532e50/xattr-1.3.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:95f1e14a4d9ca160b4b78c527bf2bac6addbeb0fd9882c405fc0b5e3073a8752", size = 18673, upload-time = "2025-10-13T22:16:26.224Z" }, + { url = "https://files.pythonhosted.org/packages/3d/36/9ab4f0b5c3d10df3aceaecf7e395cabe7fb7c7c004b2dc3f3cff0ef70fc3/xattr-1.3.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:88557c0769f64b1d014aada916c9630cfefa38b0be6c247eae20740d2d8f7b47", size = 18877, upload-time = "2025-10-13T22:16:27.164Z" }, + { url = "https://files.pythonhosted.org/packages/1c/1c/ab905d19a1349e847e37e02933316d17adfd1dd70b64d366885ab0bd959d/xattr-1.3.0-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:c6992eb5da32c0a1375a9eeacfab15c66eebc8bd34be63ebd1eae80cc2f8bf03", size = 38782, upload-time = "2025-10-13T22:16:28.157Z" }, + { url = "https://files.pythonhosted.org/packages/83/a7/f615a6e5d48d47e9febbe5a62b94ffa0d8bfc6d325b899873281abac10c4/xattr-1.3.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:da5954424099ca9d402933eaf6112c29ddde26e6da59b32f0bf5a4e35eec0b28", size = 38936, upload-time = "2025-10-13T22:16:29.291Z" }, + { url = 
"https://files.pythonhosted.org/packages/9f/6c/a8221567a7cbc00ac305a4842318562f90bb1fdd16636e1379361133f1f4/xattr-1.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:726b4d0b66724759132cacdcd84a5b19e00b0cdf704f4c2cf96d0c08dc5eaeb5", size = 38268, upload-time = "2025-10-13T22:16:30.238Z" }, + { url = "https://files.pythonhosted.org/packages/3e/4d/38a98df630e19360d98df8d98ec4a2560612840823f0bf55f81e0e84c866/xattr-1.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:928c49ceb0c70fc04732e46fa236d7c8281bfc3db1b40875e5f548bb14d2668c", size = 37825, upload-time = "2025-10-13T22:16:31.557Z" }, + { url = "https://files.pythonhosted.org/packages/97/3f/6d50237645edd83e9dc6bf6521e4e28335845b674cabefd69f12bc4db59a/xattr-1.3.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:f3bef26fd2d5d7b17488f4cc4424a69894c5a8ed71dd5f657fbbf69f77f68a51", size = 23788, upload-time = "2025-10-13T22:16:32.465Z" }, + { url = "https://files.pythonhosted.org/packages/f4/8b/3efd48c85e08d1bfcbd46f87368b155d3d3de78bb660b408fbaff7623572/xattr-1.3.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:64f1fb511f8463851e0d97294eb0e0fde54b059150da90582327fb43baa1bb92", size = 18825, upload-time = "2025-10-13T22:16:33.442Z" }, + { url = "https://files.pythonhosted.org/packages/fd/19/4b4e3e2ea5fa213ff4220e84450628fecde042b0961e7b4e6d845e555ade/xattr-1.3.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:1e6c216927b16fd4b72df655d5124b69b2a406cb3132b5231179021182f0f0d1", size = 19023, upload-time = "2025-10-13T22:16:34.395Z" }, + { url = "https://files.pythonhosted.org/packages/6f/4a/6460befb22ce8d43abdb22d2bf5aa63b8311507c75dc50ad402681b4b095/xattr-1.3.0-cp314-cp314t-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:c0d9ab346cdd20539afddf2f9e123efee0fe8d54254d9fc580b4e2b4e6d77351", size = 43732, upload-time = "2025-10-13T22:16:35.41Z" }, + { url = 
"https://files.pythonhosted.org/packages/15/a8/3fa83e9f91dc868d764b2ca3758bf449945c4b1511e137e33a6210609b58/xattr-1.3.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2c5e7ba0e893042deef4e8638db7a497680f587ac7bd6d68925f29af633dfa6b", size = 43851, upload-time = "2025-10-13T22:16:36.416Z" }, + { url = "https://files.pythonhosted.org/packages/28/b3/06bf7f691c3f35e94a37e097ae1868fbaa916cc174b1b916fb7aeca441e4/xattr-1.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:1e0dabb39596d8d7b83d6f9f7fa30be68cf15bfb135cb633e2aad9887d308a32", size = 43274, upload-time = "2025-10-13T22:16:37.805Z" }, + { url = "https://files.pythonhosted.org/packages/df/41/d6298c95513eabe091a6851bff5e7928fab49ffd9143808feaaf7721cf33/xattr-1.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5eeaa944516b7507ec51456751334b4880e421de169bbd067c4f32242670d606", size = 42864, upload-time = "2025-10-13T22:16:38.811Z" }, ] [[package]]