From 398a1dec51a60a71a9ded506b5f0642ea22f8285 Mon Sep 17 00:00:00 2001 From: Grain Team Date: Mon, 1 Dec 2025 14:26:12 -0800 Subject: [PATCH] Internal PiperOrigin-RevId: 838933101 --- grain/_src/python/BUILD | 43 - grain/_src/python/data_loader.py | 5 +- grain/_src/python/dataset/BUILD | 1 - grain/_src/python/dataset/dataset.py | 9 +- .../_src/python/dataset/transformations/BUILD | 2 +- .../dataset/transformations/prefetch.py | 439 +------ .../dataset/transformations/prefetch_test.py | 1154 ++++++----------- .../transformations/process_prefetch.py | 22 - .../transformations/process_prefetch_test.py | 18 - grain/_src/python/grain_pool.py | 834 ------------ grain/_src/python/grain_pool_test.py | 471 ------- grain/python/experimental.py | 1 - 12 files changed, 437 insertions(+), 2562 deletions(-) delete mode 100644 grain/_src/python/grain_pool.py delete mode 100644 grain/_src/python/grain_pool_test.py diff --git a/grain/_src/python/BUILD b/grain/_src/python/BUILD index cb57e0116..ac1980520 100644 --- a/grain/_src/python/BUILD +++ b/grain/_src/python/BUILD @@ -208,49 +208,6 @@ py_test( ], ) -py_library( - name = "grain_pool", - srcs = ["grain_pool.py"], - srcs_version = "PY3", - target_compatible_with = select({ - "@platforms//os:windows": ["@platforms//:incompatible"], - "//conditions:default": [], - }), - deps = [ - ":grain_logging", - ":multiprocessing_common", - ":options", - ":record", - ":shared_memory_array", - "//grain/_src/core:config", - "//grain/_src/core:monitoring", - "//grain/_src/core:parallel", - "//grain/_src/core:tree_lib", - "@abseil-py//absl/flags", - "@abseil-py//absl/logging", - "@pypi//cloudpickle:pkg", - ], -) - -py_test( - name = "grain_pool_test", - srcs = ["grain_pool_test.py"], - shard_count = 20, - srcs_version = "PY3", - tags = ["not_run:arm"], - deps = [ - ":data_sources", - ":grain_pool", - ":options", - ":record", - "//grain/_src/core:config", - "//grain/_src/core:monitoring", - "@abseil-py//absl/flags", - "@abseil-py//absl/testing:absltest", - "@abseil-py//absl/testing:parameterized", - ], -) - py_library( name = "checkpoint_handlers", srcs = ["checkpoint_handlers.py"], diff --git a/grain/_src/python/data_loader.py b/grain/_src/python/data_loader.py index 2aa48a3d6..c22aaf117 100644 --- a/grain/_src/python/data_loader.py +++ b/grain/_src/python/data_loader.py @@ -39,7 +39,6 @@ from grain._src.python.dataset import dataset from grain._src.python.dataset.transformations import batch as batch_ds from grain._src.python.dataset.transformations import flatmap -from grain._src.python.dataset.transformations import prefetch from grain._src.python.operations import BatchOperation from grain._src.python.operations import Operation from grain._src.python.samplers import Sampler @@ -462,10 +461,8 @@ def _create_dataset(self) -> dataset.IterDataset: ds = _apply_transform_to_dataset(operation, ds) ds = ds.map(lambda r: r.data) if self.multiprocessing_options.num_workers > 0: - ds = prefetch.MultiprocessPrefetchIterDataset( - ds, + ds = ds.mp_prefetch( self.multiprocessing_options, - always_report_worker_state=True, ) if not self._use_native_dataset_checkpointing: ds = _DataLoaderStateIterDataset( diff --git a/grain/_src/python/dataset/BUILD b/grain/_src/python/dataset/BUILD index f93d41dd2..2cc131e84 100644 --- a/grain/_src/python/dataset/BUILD +++ b/grain/_src/python/dataset/BUILD @@ -53,7 +53,6 @@ py_library( "//grain/_src/core:tree_lib", "//grain/_src/python:checkpointing", "//grain/_src/python:grain_logging", - "//grain/_src/python:grain_pool", 
"//grain/_src/python:options", "//grain/_src/python:shared_memory_array", "//grain/proto:execution_summary_py_pb2", diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py index bee9d1c78..a2f9b3c28 100644 --- a/grain/_src/python/dataset/dataset.py +++ b/grain/_src/python/dataset/dataset.py @@ -1324,13 +1324,14 @@ def mp_prefetch( A dataset prefetching input elements in separate processes. """ options = options or grain_options.MultiprocessingOptions(num_workers=10) - # Loaded lazily due to a circular dependency (dataset <-> prefetch). + # Loaded lazily due to a circular dependency (dataset <-> process_prefetch). # pylint: disable=g-import-not-at-top - from grain._src.python.dataset.transformations import prefetch + from grain._src.python.dataset.transformations import process_prefetch # pylint: enable=g-import-not-at-top - return prefetch.MultiprocessPrefetchIterDataset( + return process_prefetch.multiprocess_prefetch( self, - multiprocessing_options=options, + num_workers=options.num_workers, + buffer_size=options.per_worker_buffer_size, worker_init_fn=worker_init_fn, sequential_slice=sequential_slice, ) diff --git a/grain/_src/python/dataset/transformations/BUILD b/grain/_src/python/dataset/transformations/BUILD index 29734f519..68a6a4c54 100644 --- a/grain/_src/python/dataset/transformations/BUILD +++ b/grain/_src/python/dataset/transformations/BUILD @@ -53,7 +53,7 @@ py_test( py_test( name = "prefetch_test", - timeout = "long", + timeout = "eternal", srcs = ["prefetch_test.py"], shard_count = 50, srcs_version = "PY3", diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index 03001eb47..e205a3430 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -16,34 +16,23 @@ from __future__ import annotations import collections -from collections.abc import Callable, Iterator, Sequence -import contextlib +from collections.abc import Iterator, Sequence import copy import functools -import math from multiprocessing import queues -from multiprocessing import synchronize import queue -import sys import threading -import time import typing -from typing import Any, Generic, Optional, Protocol, TypeVar +from typing import Any, Optional, Protocol, TypeVar -import cloudpickle from concurrent import futures -from grain._src.core import tree_lib -import multiprocessing as mp -from grain._src.python import grain_pool from grain._src.python import options as grain_options -from grain._src.python import shared_memory_array from grain._src.python.dataset import base from grain._src.python.dataset import dataset from grain._src.python.dataset import stats as dataset_stats from grain._src.python.dataset.transformations import filter as filter_dataset from grain._src.python.dataset.transformations import interleave from grain._src.python.dataset.transformations import source -import numpy as np T = TypeVar("T") @@ -324,127 +313,6 @@ def close(self) -> None: future.cancel() -def _iterator_with_context( - iterator: contextlib.AbstractContextManager[Iterator[T]], -) -> Iterator[T]: - with iterator as it: - yield from it - - -def _validate_no_double_prefetch( - parent: dataset.MapDataset | dataset.IterDataset, -) -> None: - """Checks that there are no multiple levels of parallelization.""" - to_check: list[dataset.MapDataset | dataset.IterDataset] = [parent] - while to_check: - ds = to_check.pop(0) - if isinstance(ds, 
MultiprocessPrefetchIterDataset): - raise ValueError( - "Nesting multiprocessing or multithreading is not allowed." - ) - to_check.extend(ds.parents) - - -class MultiprocessPrefetchIterDataset(dataset.IterDataset[T]): - """Uses a pool of processes to prefetch elements ahead of time. - - It usually makes sense to add this transformation in the end of the pipeline - since it will execute the parent IterDataset in multiple processes. - """ - - def __init__( - self, - parent: dataset.IterDataset[T], - multiprocessing_options: grain_options.MultiprocessingOptions, - worker_init_fn: Callable[[int, int], None] | None = None, - sequential_slice: bool = False, - always_report_worker_state: bool = False, - ): - if multiprocessing_options.num_workers < 0: - raise ValueError( - "`num_workers` must be greater than or equal to 0, got " - f"{multiprocessing_options.num_workers}." - ) - super().__init__(parent) - self._multiprocessing_options = multiprocessing_options - self._worker_init_fn = worker_init_fn - self._sequential_slice = sequential_slice - _validate_no_double_prefetch(self._parent) - self._always_report_worker_state = always_report_worker_state - - def __str__(self) -> str: - return ( - "MultiprocessPrefetchIterDataset(" - f"multiprocessing_options={self._multiprocessing_options})" - ) - - def __iter__(self) -> dataset.DatasetIterator[T]: - if self._multiprocessing_options.num_workers == 0: - return self._parent.__iter__() - return _MultiprocessPrefetchDatasetIterator( - self._parent, - self._multiprocessing_options, - self._worker_init_fn, - self._sequential_slice, - self._always_report_worker_state, - ) - - -# Keys in `MultiprocessPrefetchDatasetIterator` checkpoints. -_WORKERS_STATE = "workers_state" -_ITERATIONS_TO_SKIP = "iterations_to_skip" -_LAST_WORKER_INDEX = "last_worker_index" - -# Minimal interval (in seconds) between consecutive state recordings in worker -# processes of `MultiprocessPrefetchDatasetIterator`. We record the state -# periodically to reduce the overhead of sending the state from workers. -# Note that this is also an approximate upper bound on how long it is going to -# take to recover from a checkpointed state. Larger values will decrease the -# overhead of sending the updated state but will also make recovery from a -# checkpoint longer on average. 
-_RECORD_STATE_INTERVAL_S = 3 - - -def _copy_leaf_to_shm(leaf: Any, min_size: int = 0) -> Any: - """Copies `leaf` to shared memory if it's a big enough numpy array.""" - if isinstance(leaf, shared_memory_array.SharedMemoryArray): - return leaf.metadata - if ( - not isinstance(leaf, np.ndarray) - or leaf.dtype.hasobject - or not leaf.flags.c_contiguous - or math.prod(leaf.shape) == 0 - or leaf.nbytes < min_size - ): - return leaf - - shared_memory_arr = shared_memory_array.SharedMemoryArray( - leaf.shape, leaf.dtype - ) - np.copyto(shared_memory_arr, leaf, casting="no") - return shared_memory_arr.metadata - - -def _copy_struct_to_shm(struct: Any, min_size: int = 0) -> Any: - """Copies leaf ndarrays of the structure to shared memory.""" - return tree_lib.map_structure( - functools.partial(_copy_leaf_to_shm, min_size=min_size), struct - ) - - -def _open_leaf_from_shm(leaf: Any) -> Any: - """Recovers `leaf` from shared memory if it's a numpy array metadata.""" - if isinstance(leaf, shared_memory_array.SharedMemoryArrayMetadata): - leaf = shared_memory_array.SharedMemoryArray.from_metadata(leaf) - leaf.unlink_on_del() - return leaf - - -def _open_struct_from_shm(struct: Any) -> Any: - """Recovers leaf ndarrays of the structure from shared memory.""" - return tree_lib.map_structure(_open_leaf_from_shm, struct) - - def _set_slice_iter_dataset( ds: dataset.IterDataset, sl: slice, @@ -501,120 +369,6 @@ def _set_slice_map_dataset( _set_slice_iter_dataset(parent, sl, sequential_slice) -def _check_picklable( - ds: dataset.IterDataset | dataset.MapDataset, -): - """Detects the first unpickle-able dataset in post-order. - - Args: - ds: IterDataset or MapDataset to check whether it is picklable. - - NOTE: This function's time complexity is O(n^2) where n is the number of - Grain dataset operations because `cloudpickle.dumps(ds)` will trigger - pickling into all the datasets. If this naive O(n^2) algorithm takes too - much time, we could consider doing copying `ds`, delete its parents and then - do `cloudpickle.dumps(new_ds)` to reduce the time complexity to O(n). - """ - - # Traverses the graph in post-order to find the first unpickle-able subtree - for parent in ds.parents: - _check_picklable(parent) - - try: - cloudpickle.dumps(ds) - except Exception as e: # pylint: disable=broad-exception-caught - if sys.version_info >= (3, 11): - e.add_note( - f"Dataset: {ds} cannot be pickled!" - ) - raise e - - -class GetElementProducerFn(grain_pool.GetElementProducerFn, Generic[T]): - """Implements `GetElementProducerFn` for `grain_pool.MultiProcessIterator`. - - This class implements `GetElementProducerFn` with `serialize` being overridden - to generate better error messages if user-provided dataset is not pickle-able. 
- """ - - def __init__( - self, - state: dict[str, dict[str, Any] | int], - ds: dataset.IterDataset[T], - sequential_slice: bool = False, - always_report_worker_state: bool = False, - ): - self._state = state - self._ds = ds - self._sequential_slice = sequential_slice - self._always_report_worker_state = always_report_worker_state - - def __call__( - self, - *, - worker_index: int, - worker_count: int, - start_profiling_event: synchronize.Event | None = None, - stop_profiling_event: synchronize.Event | None = None, - stats_out_queue: queues.Queue | None = None, - ) -> Iterator[tuple[T, Optional[dict[str, Any]]]]: - if worker_count > 1: - _set_slice_iter_dataset( - self._ds, - slice(worker_index, None, worker_count), - self._sequential_slice, - ) - it = self._ds.__iter__() - it._ctx.mp_context = base.MultiprocessingContext( - process_index=worker_index, process_count=worker_count - ) - min_shm_size = it._ctx.dataset_options.min_shm_size - # Recover from the last recorded state for the given worker. - worker_state = self._state[_WORKERS_STATE][str(worker_index)] - if worker_state is not None: - it.set_state(worker_state) - # Set the stats queue in worker process to send stats to the main process. - it._stats._config.stats_out_queue = stats_out_queue # pytype: disable=attribute-error - # Skip the required number of iterations after the last recorded state. - for _ in range(self._state[_ITERATIONS_TO_SKIP][str(worker_index)]): - _ = next(it) - last_recorded_state_time = time.time() - for element in it: - now = time.time() - element = _copy_struct_to_shm(element, min_size=min_shm_size) - # If the node is prefetch, we already record the bytes produced in it's - # __next__ method. - if not it._stats._config.is_prefetch: - it._stats.record_bytes_produced(element) - if ( - self._always_report_worker_state - or now - last_recorded_state_time >= _RECORD_STATE_INTERVAL_S - ): - last_recorded_state_time = now - yield (element, it.get_state()) # pytype: disable=attribute-error - else: - yield (element, None) - - def serialize(self) -> bytes: - """Overrides the default implementation to generate better error messages.""" - - try: - return cloudpickle.dumps(self) - except Exception as e: # pylint: disable=broad-except - # Calls `_check_picklable` to generate useful pickle errors - # - # Note: No need to check `self._state` because it should not generate - # unpicklable errors and it is controlled by us, not from user's code - # in most cases. Except for the case when users try to implement their own - # `MapDataset` and `IterDataset` with custom pickle-ing logic that - # contains unpickle-able objects. 
- _check_picklable(self._ds) - - # If somehow we cannot find the dataset that is causing the pickle - # issues, just raise the original error - raise e - - def _get_dataset_options(ds: dataset.IterDataset) -> base.DatasetOptions: result = base.DatasetOptions() to_visit = [ds] @@ -626,172 +380,6 @@ def _get_dataset_options(ds: dataset.IterDataset) -> base.DatasetOptions: return result -class _MultiprocessPrefetchDatasetIterator(dataset.DatasetIterator[T]): - """Iterator that performs prefetching using a multiprocessing pool.""" - - def __init__( - self, - parent: dataset.IterDataset[T], - multiprocessing_options: grain_options.MultiprocessingOptions, - worker_init_fn: Callable[[int, int], None] | None = None, - sequential_slice: bool = False, - always_report_worker_state: bool = False, - ): - super().__init__() - self._iter_parent = parent - # Since the parent iterator is going to be created in each subprocess, and - # the options are propagated during iterator creation, we need to manually - # propagate them. - self._ctx.dataset_options = _get_dataset_options(parent) - self._multiprocessing_options = multiprocessing_options - self._worker_init_fn = worker_init_fn - self._sequential_slice = sequential_slice - # The underlying iterator producing elements and workers state. - self._iterator = None - # Raw reference to the underlying iterator that can be used to determine the - # last worker index. - self._raw_iterator = None - # Create initial state. We record state of each worker periodically together - # with the number of iterations without the recorded state and index of the - # last worker. - iterations_to_skip: dict[str, int] = { - str(i): 0 for i in range(multiprocessing_options.num_workers) - } - workers_state: dict[str, Any] = { - str(i): None for i in range(multiprocessing_options.num_workers) - } - self._stats_in_queues = tuple( - mp.get_context("spawn").Queue(maxsize=5) - for _ in range(multiprocessing_options.num_workers) - ) - self._start_profiling_event = mp.get_context("spawn").Event() - self._stop_profiling_event = mp.get_context("spawn").Event() - - self._state: dict[str, dict[str, Any] | int] = { - _WORKERS_STATE: workers_state, - _ITERATIONS_TO_SKIP: iterations_to_skip, - _LAST_WORKER_INDEX: -1, - } - - self._always_report_worker_state = always_report_worker_state - - def _initialize_stats( - self, execution_tracking_mode: base.ExecutionTrackingMode - ): - self._stats = _initialize_prefetch_stats( - self, - execution_tracking_mode, - parent_stats=[], - stats_in_queues=self._stats_in_queues, - ) - return self._stats - - @functools.cached_property - def _stats(self): - return self._initialize_stats( - self._ctx.dataset_options.execution_tracking_mode - ) - - def __iter__(self) -> dataset.DatasetIterator[T]: - return self - - @dataset_stats.record_next_duration_if_output - def __next__(self) -> T: - self._assert_not_closed() - self._ensure_iterator_initialized() - # The time recorded here is the time spent in prefetch node to return an - # element, including the time spent in parent node. 
- timer = dataset_stats.Timer() - result, state = next(self._iterator) - with self._stats.record_self_time(offset_ns=timer.value()): - worker_index = self._raw_iterator.get_last_worker_index() # pytype: disable=attribute-error - - # pytype: disable=annotation-type-mismatch - iterations_to_skip: dict[str, Any] = self._state[_ITERATIONS_TO_SKIP] - worker_state: dict[str, Any] = self._state[_WORKERS_STATE] - # pytype: enable=annotation-type-mismatch - - self._state[_LAST_WORKER_INDEX] = worker_index - worker_index_str = str(worker_index) - if state is None: - iterations_to_skip[worker_index_str] += 1 - else: - iterations_to_skip[worker_index_str] = 0 - worker_state[worker_index_str] = state - result = self._stats.record_bytes_produced(result) - return _open_struct_from_shm(result) - - def start_prefetch(self) -> None: - """Prefetches elements from the iterator. - - This will run background processes for prefetching. To make sure to clean up - the resources, it should be followed by at least one `next` call. - """ - self._ensure_iterator_initialized() - - def set_state(self, state: dict[str, dict[str, Any] | int]) -> None: - self._state = state - self._raw_iterator = None - self._iterator = None - - def get_state(self) -> dict[str, Any]: - result = copy.deepcopy(self._state) - workers_state: dict[str, Any] = result[_WORKERS_STATE] # pytype: disable=annotation-type-mismatch - parent_state = None - for worker_index, worker_state in workers_state.items(): - # Create initial state from the parent iterator. This is to make sure the - # spec of the produced iterator does not change. - if worker_state is None: - parent_state = parent_state or self._iter_parent.__iter__().get_state() - workers_state[worker_index] = copy.deepcopy(parent_state) - return result - - def _ensure_iterator_initialized(self) -> None: - if self._iterator is None: - self._raw_iterator = self._create_iterator_context() - self._raw_iterator.start_prefetch() - self._iterator = _iterator_with_context(self._raw_iterator) - - def _create_iterator_context(self) -> grain_pool.MultiProcessIterator[T]: - """Creates a `MultiProcessIterator`.""" - # Apply the latest options to the subprocess dataset. We delay this until - # starting subprocesses because child iterators may update them. - ds = dataset.WithOptionsIterDataset( - self._iter_parent, self._ctx.dataset_options - ) - get_element_producer_fn = GetElementProducerFn( - self._state, - ds, - self._sequential_slice, - self._always_report_worker_state, - ) - - return grain_pool.MultiProcessIterator( - get_element_producer_fn, - self._multiprocessing_options, - (self._state[_LAST_WORKER_INDEX] + 1) - % self._multiprocessing_options.num_workers, - self._worker_init_fn, - self._start_profiling_event, - self._stop_profiling_event, - self._stats_in_queues, - ) - - def __str__(self) -> str: - return ( - "MultiprocessPrefetchDatasetIterator(" - f"multiprocessing_options={self._multiprocessing_options})" - ) - - def close(self) -> None: - """Shuts down the prefetching threads and multiprocessing pool.""" - if self._closed: - return - self._closed = True - if self._raw_iterator is not None: - self._raw_iterator.stop_prefetch() - - class ThreadPrefetchIterDataset(dataset.IterDataset[T]): """Iterable dataset that uses a synchronized queue for prefetching. 
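Editor's note (not part of the patch): with `MultiprocessPrefetchIterDataset` and its iterator removed above, callers reach multiprocess prefetching through `IterDataset.mp_prefetch`, which — per the dataset.py hunk earlier in this patch — now delegates to `process_prefetch.multiprocess_prefetch`. A minimal usage sketch, using the internal module paths that appear in this diff:

from grain._src.python import options
from grain._src.python.dataset import dataset

ds = dataset.MapDataset.range(20).to_iter_dataset()
# Replaces the removed MultiprocessPrefetchIterDataset construction;
# mp_prefetch forwards num_workers and per_worker_buffer_size to
# process_prefetch.multiprocess_prefetch.
ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=2))
elements = list(ds)  # Iterating runs the pipeline in worker processes.
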
@@ -1017,8 +605,8 @@ def __str__(self) -> str: def multithread_prefetch( ds: dataset.IterDataset[T], - num_threads: int, - buffer_size: int, + num_threads: int = 0, + buffer_size: int = 1, sequential_slice: bool = False, ) -> dataset.IterDataset[T]: """Uses a pool of threads to prefetch elements ahead of time. @@ -1043,14 +631,17 @@ def multithread_prefetch( if num_threads == 0: return ds - _validate_no_double_prefetch(ds) + dataset_options = _get_dataset_options(ds) shards = [] for i in range(num_threads): - worker_ds = copy.deepcopy(ds) - _set_slice_iter_dataset( - worker_ds, slice(i, None, num_threads), sequential_slice - ) + if num_threads == 1: + worker_ds = ds + else: + worker_ds = copy.deepcopy(ds) + _set_slice_iter_dataset( + worker_ds, slice(i, None, num_threads), sequential_slice + ) shards.append( _MpContextIterDataset( worker_ds, @@ -1061,6 +652,10 @@ def multithread_prefetch( ) ) - return interleave.InterleaveIterDataset( + ds = interleave.InterleaveIterDataset( shards, cycle_length=num_threads, iter_buffer_size=buffer_size ) + # Apply options from parent dataset because interleave dataset does not + # propagate options. + ds = dataset.WithOptionsIterDataset(ds, dataset_options) + return ds diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index 597d8565b..e8333002a 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -13,13 +13,11 @@ # limitations under the License. from concurrent import futures import dataclasses -import logging as std_logging import platform import sys import threading import time from typing import TypeVar, cast -from unittest import mock from absl import logging from absl.testing import absltest @@ -433,324 +431,444 @@ def test_element_spec(self): self.assertEqual(spec.dtype, np.int64) -class MultiprocessPrefetchIterDatasetTest(parameterized.TestCase): +class ThreadPrefetchIterDatasetTest(parameterized.TestCase): def setUp(self): super().setUp() - ds = dataset.MapDataset.range(20) - ds = prefetch.PrefetchIterDataset(ds, read_options=options.ReadOptions()) - self.iter_ds = ds.filter(FilterKeepingOddElementsOnly()) + self.ds = ( + dataset.MapDataset.range(20) + .to_iter_dataset() + .filter(FilterKeepingOddElementsOnly()) + ) @parameterized.named_parameters( dict( - testcase_name='0_workers', - num_workers=0, - per_worker_buffer_size=1, + testcase_name='no_prefetch', + prefetch_buffer_size=0, + warm_start=False, ), dict( - testcase_name='1_worker', - num_workers=1, - per_worker_buffer_size=1, + testcase_name='no_prefetch_with_warm_start', + prefetch_buffer_size=0, + warm_start=True, ), dict( - testcase_name='1_worker_large_buffer', - num_workers=1, - per_worker_buffer_size=20, + testcase_name='thread', + prefetch_buffer_size=1, + warm_start=True, ), dict( - testcase_name='10_workers', - num_workers=10, - per_worker_buffer_size=1, + testcase_name='thread_large_buffer', + prefetch_buffer_size=20, + warm_start=False, ), dict( - testcase_name='10_workers_large_buffer', - num_workers=10, - per_worker_buffer_size=20, + testcase_name='thread_huge_buffer', + prefetch_buffer_size=200, + warm_start=True, ), ) - def test_prefetch_data(self, num_workers: int, per_worker_buffer_size: int): - prefetch_lazy_iter_ds = prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds, - options.MultiprocessingOptions(num_workers, per_worker_buffer_size), + def test_prefetch_data(self, prefetch_buffer_size: int, 
warm_start: bool): + prefetch_lazy_iter_ds = prefetch.ThreadPrefetchIterDataset( + self.ds, prefetch_buffer_size=prefetch_buffer_size ) - actual = list(prefetch_lazy_iter_ds) + ds = prefetch_lazy_iter_ds.__iter__() + if warm_start: + ds.start_prefetch() + actual = list(ds) expected = list(range(1, 20, 2)) self.assertSequenceEqual(actual, expected) - def test_prefetch_size_zero_data(self): - ds = dataset.MapDataset.source( - [np.zeros(shape=(0,), dtype=np.int64)] - ).repeat(3) - iter_ds = ds.to_iter_dataset() - prefetch_lazy_iter_ds = prefetch.MultiprocessPrefetchIterDataset( - iter_ds, - options.MultiprocessingOptions(num_workers=1), + @parameterized.parameters([False, True]) + def test_checkpoint(self, warm_start: bool): + ds = prefetch.ThreadPrefetchIterDataset( + self.ds, + prefetch_buffer_size=500, ) - actual = list(prefetch_lazy_iter_ds) - expected = [np.zeros(shape=(0,), dtype=np.int64)] * 3 - self.assertLen(actual, 3) - self.assertLen(expected, 3) - for i in range(3): - np.testing.assert_array_equal(actual[i], expected[i]) - - @parameterized.product( - ( - dict( - num_workers=0, - record_state_interval=prefetch._RECORD_STATE_INTERVAL_S, - ), - dict( - num_workers=1, - record_state_interval=prefetch._RECORD_STATE_INTERVAL_S, - ), - dict( - num_workers=10, - record_state_interval=prefetch._RECORD_STATE_INTERVAL_S, - ), - dict( - num_workers=10, - record_state_interval=0, - ), - ), - step_index=[0, 3, 8], - ) - def test_checkpoint( - self, num_workers: int, record_state_interval: int, step_index: int - ): - with mock.patch.object( - prefetch, '_RECORD_STATE_INTERVAL_S', record_state_interval - ): - ds = prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds, - options.MultiprocessingOptions(num_workers), - ) - ds_iter = ds.__iter__() + ds_iter = ds.__iter__() + if warm_start: + ds_iter.start_prefetch() - max_steps = 10 - values_without_interruption = [] - checkpoints = [] - for _ in range(max_steps): - checkpoints.append(ds_iter.get_state()) - values_without_interruption.append(next(ds_iter)) + max_steps = 10 + values_without_interruption = [] + checkpoints = [] + for _ in range(max_steps): + checkpoints.append(ds_iter.get_state()) + values_without_interruption.append(next(ds_iter)) - ds_iter.set_state(checkpoints[step_index]) - for i in range(step_index, max_steps): + for starting_step in range(9): + ds_iter.set_state(checkpoints[starting_step]) + for i in range(starting_step, max_steps): value = next(ds_iter) self.assertEqual(value, values_without_interruption[i]) - def test_set_state_twice(self): - with mock.patch.object(prefetch, '_RECORD_STATE_INTERVAL_S', 0): - ds = prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds, - options.MultiprocessingOptions(2), - ) - ds_iter = ds.__iter__() - - max_steps = 10 - values_without_interruption = [] - checkpoints = [] - for _ in range(max_steps): - checkpoints.append(ds_iter.get_state()) - values_without_interruption.append(next(ds_iter)) - - for starting_step in [0, 3, 8]: - ds_iter.set_state(checkpoints[starting_step]) - for i in range(starting_step, max_steps): - value = next(ds_iter) - self.assertEqual(value, values_without_interruption[i]) - - def test_fails_with_negative_num_workers(self): - with self.assertRaisesRegex( - ValueError, '`num_workers` must be greater than or equal to 0' - ): - prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds, - options.MultiprocessingOptions(num_workers=-1), - ) - - def test_fails_with_multiple_prefetches(self): - ds = prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds, - 
options.MultiprocessingOptions(num_workers=10), + def test_set_state_on_fresh_iterator(self): + ds = prefetch.ThreadPrefetchIterDataset( + self.ds, + prefetch_buffer_size=2, ) - with self.assertRaisesRegex( - ValueError, - 'Nesting multiprocessing or multithreading is not allowed.', - ): - _ = prefetch.MultiprocessPrefetchIterDataset( - ds, - options.MultiprocessingOptions(num_workers=1), - ) + ds_iter = ds.__iter__() - def test_works_with_iter_source_single_worker(self): - # Even though a pure IterDataset cannot be sliced, we should still be able - # to multiprocess-prefetch it with a single worker, since that doesn't - # require any slicing. - ds = prefetch.MultiprocessPrefetchIterDataset( - RepeatedIntSourceIterDataset().map(lambda x: x + 1), - options.MultiprocessingOptions(num_workers=1), - ) - ds_iter = iter(ds) - self.assertEqual(next(ds_iter), 2) + max_steps = 10 + values_without_interruption = [] + checkpoints = [] + for _ in range(max_steps): + checkpoints.append(ds_iter.get_state()) + values_without_interruption.append(next(ds_iter)) - def test_fails_with_iter_source_multiple_workers(self): - ds = prefetch.MultiprocessPrefetchIterDataset( - RepeatedIntSourceIterDataset().map(lambda x: x + 1), - options.MultiprocessingOptions(num_workers=2), - ) - ds_iter = iter(ds) + for starting_step in range(9): + ds_iter = ds.__iter__() + ds_iter.set_state(checkpoints[starting_step]) + for i in range(starting_step, max_steps): + value = next(ds_iter) + self.assertEqual(value, values_without_interruption[i]) - with self.assertRaisesRegex( - Exception, - 'Cannot slice `IterDataset` source.', - ): - next(ds_iter) + def test_get_state_doesnt_start_prefetch(self): + event = threading.Event() - def test_propagates_transform_error(self): - error_msg = 'I shall fail!' 
+ def f(x): + event.set() + return x - def failing_transform(element): - del element - raise ValueError(error_msg) + ds = dataset.MapDataset.source([1, 2, 3]).map(f).to_iter_dataset() + ds = prefetch.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=10) + it = ds.__iter__() + it.get_state() + time.sleep(1) + self.assertFalse(event.is_set()) - ds = prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds.map(failing_transform), - options.MultiprocessingOptions(num_workers=1), - ) - with self.assertRaisesRegex(Exception, error_msg): - list(ds) + def test_parent_dataset_modifies_state(self): + class TestIterator(dataset.DatasetIterator): - def test_reports_worker_crash(self): - def failing_transform(element): - del element - sys.exit(123) + def __next__(self): + return 1 - ds = prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds.map(failing_transform), - options.MultiprocessingOptions(num_workers=1), - ) - with self.assertRaisesRegex( - RuntimeError, 'was terminated unexpectedly with exit code 123' - ): - list(ds) + def get_state(self): + return {'test': 1} - def test_reports_unpicklable_transform(self): - class UnpicklableObject: + def set_state(self, state): + pass - def __getstate__(self): - raise ValueError('UnpicklableObject is not picklable') + class TestDataset(dataset.IterDataset): - local_state = UnpicklableObject() + def __iter__(self): + return TestIterator() - ds = prefetch.MultiprocessPrefetchIterDataset( - self.iter_ds.map(lambda _: 1 if local_state is None else 2), - options.MultiprocessingOptions(num_workers=1), - ) + parent = TestDataset() + ds = prefetch.ThreadPrefetchIterDataset(parent, prefetch_buffer_size=1) + ds_iter = ds.__iter__() + ds_iter.set_state({'test': 2}) + self.assertEqual(ds_iter.get_state(), {'test': 1}) + + def test_fails_with_negative_prefetch_buffer_size(self): with self.assertRaisesRegex( - ValueError, 'UnpicklableObject is not picklable' - ) as context_manager: - list(ds) + ValueError, '`prefetch_buffer_size` must be greater than or equal to 0' + ): + prefetch.ThreadPrefetchIterDataset(self.ds, prefetch_buffer_size=-1) - if sys.version_info >= (3, 11): - self.assertRegex( - ''.join(context_manager.exception.__notes__), - r'Dataset: MapIterDataset.* cannot be pickled!', - ) + def test_start_prefetch_with_mp_prefetch_but_no_read(self): + ds = dataset.MapDataset.source([1, 2, 3]).repeat().to_iter_dataset() + ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=2)) + ds = prefetch.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=10) + it = ds.__iter__() + it.start_prefetch() + del it - def test_reports_first_unpicklable_dataset_when_with_multiple_parents(self): - class UnpicklableObject: + def test_does_not_create_reference_to_itself(self): + ds = dataset.MapDataset.source([1, 2, 3]).repeat(100).to_iter_dataset() + ds = prefetch.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=10) + it = ds.__iter__() + refcount_before_iteration = sys.getrefcount(it) + _ = next(it) + refcount_after_iteration = sys.getrefcount(it) + self.assertEqual(refcount_before_iteration, refcount_after_iteration) - def __getstate__(self): - raise ValueError('UnpicklableObject is not picklable') + def test_does_not_hang_after_stop_iteration(self): + ds = dataset.MapDataset.source([1, 2, 3]).repeat(100).to_iter_dataset() + ds = prefetch.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=10) + it = ds.__iter__() + _ = list(it) + self.assertEmpty(list(it)) - local_unpicklable_obj = UnpicklableObject() + def test_nonnative_iterator(self): - class 
LeftTransform(transforms.MapTransform): + class TestIterator: - def map(self, x): - return x if local_unpicklable_obj else x + def __init__(self): + self._counter = 0 - class RightTransform(transforms.MapTransform): + def __iter__(self): + return self - def map(self, x): - return x if local_unpicklable_obj else x + def __next__(self) -> int: + self._counter += 1 + if self._counter > 10: + raise StopIteration + return self._counter - ds_left = dataset.MapDataset.range(0, 10) - ds_left = ds_left.map(LeftTransform()) - ds_right = dataset.MapDataset.range(10, 20) - ds_right = ds_right.map(RightTransform()) + def get_state(self): + return {'counter': self._counter} - ds = dataset.MapDataset.mix([ds_left, ds_right], [1.0, 1.0]) + def set_state(self, state): + self._counter = state['counter'] - iter_ds = ds.to_iter_dataset( - read_options=options.ReadOptions(prefetch_buffer_size=0) + test_iterator = TestIterator() + it = prefetch.ThreadPrefetchDatasetIterator( + test_iterator, prefetch_buffer_size=10 ) - iter_ds = iter_ds.mp_prefetch() - - with self.assertRaisesRegex( - ValueError, - r'UnpicklableObject is not picklable', - ) as context_manager: - list(iter_ds) - - if sys.version_info >= (3, 11): - self.assertRegex( - ''.join(context_manager.exception.__notes__), - r'Dataset: MapMapDataset\(transform=LeftTransform\) cannot be' - r' pickled!', - ) - - def test_reports_unpicklable_issue_when_only_one_parent_unpicklable(self): - class UnpicklableObject: - - def __getstate__(self): - raise ValueError('UnpicklableObject is not picklable') - - class PickleableTransform(transforms.MapTransform): - - def map(self, x): - return x - - local_unpicklable_obj = UnpicklableObject() - - class RightTransform(transforms.MapTransform): - - def map(self, x): - return x if local_unpicklable_obj else x - - ds_left = dataset.MapDataset.range(0, 10) - ds_left = ds_left.map(PickleableTransform()) - ds_right = dataset.MapDataset.range(10, 20) - ds_right = ds_right.map(RightTransform()) - - ds = dataset.MapDataset.mix([ds_left, ds_right], [1.0, 1.0]) + elements = [] + checkpoint_step = 5 + for _ in range(checkpoint_step): + elements.append(next(it)) + checkpoint = it.get_state() + elements.extend(it) + self.assertEqual(elements, list(range(1, 11))) + it.set_state(checkpoint) + self.assertEqual(list(it), elements[checkpoint_step:]) - iter_ds = ds.to_iter_dataset( - read_options=options.ReadOptions(prefetch_buffer_size=0) + def test_no_mem_leak(self): + ds = ( + dataset.MapDataset.range(1000) + .repeat() + .map(lambda x: x * np.ones((1000, 1000), dtype=np.int64)) + .to_iter_dataset(options.ReadOptions(prefetch_buffer_size=0)) ) - iter_ds = iter_ds.mp_prefetch() + ds = prefetch.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=10) + # If buffered elements are not cleaned up when the iterator is gc'ed, this + # test will OOM. 
+ for _ in range(1000): + it = ds.__iter__() + for _ in range(5): + _ = next(it) - with self.assertRaisesRegex( - ValueError, 'UnpicklableObject is not picklable' - ) as context_manager: - list(iter_ds) - - if sys.version_info >= (3, 11): - self.assertRegex( - ''.join(context_manager.exception.__notes__), - r'Dataset: MapMapDataset\(transform=RightTransform\) cannot be' - r' pickled!', - ) + @parameterized.parameters([True, False]) + def test_no_mem_leak_with_double_prefetch(self, close: bool): + ds = ( + dataset.MapDataset.range(1000) + .repeat() + .map(lambda x: x * np.ones((1000, 1000), dtype=np.int64)) + .to_iter_dataset(options.ReadOptions(prefetch_buffer_size=0)) + ) + ds = prefetch.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=10) + ds = ds.map(lambda x: x + 1) + ds = prefetch.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=10) + # If buffered elements are not cleaned up when the iterator is gc'ed, this + # test will OOM. + for _ in range(1000): + it = ds.__iter__() + for _ in range(5): + _ = next(it) + if close: + it.close() # pytype: disable=attribute-error + + @absltest.skipIf(platform.system() == 'Darwin', 'Fails on macos-14 runner.') + @parameterized.parameters([True, False]) + def test_early_break_continues_prefetching(self, close: bool): + count = 0 + count_lock = threading.Lock() + + class SlowCountingSource: + + def __len__(self): + return 16 + + def __getitem__(self, index): + nonlocal count + time.sleep(0.1) + with count_lock: + count += 1 + return index + + read_options = options.ReadOptions(num_threads=2) + ds = dataset.MapDataset.source(SlowCountingSource()).to_iter_dataset( + read_options + ) + iterator = ds.__iter__() + + assert count == 0 + if close: + next(iterator) + self.assertGreater(count, 0) + iterator.close() + time.sleep(1) + self.assertLess(count, 8) + else: + next(iterator) + self.assertGreater(count, 0) + time.sleep(1) + self.assertGreater(count, 8) + + +class _MpContextCheckIterDataset(dataset.IterDataset[_T]): + + def __iter__(self) -> dataset.DatasetIterator[_T]: + return _MpContextCheckIterator(self._parent.__iter__()) + + +class _MpContextCheckIterator(dataset.DatasetIterator[_T]): + + def __next__(self) -> tuple[_T, base.MultiprocessingContext]: + element = next(self._parent) + return (element, self._ctx.mp_context) + + def get_state(self): + return self._parent.get_state() + + def set_state(self, state): + self._parent.set_state(state) + + +class MultithreadPrefetchTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + ds = dataset.MapDataset.range(20) + self.iter_ds = ds.to_iter_dataset().filter(FilterKeepingOddElementsOnly()) + + @parameterized.named_parameters( + dict( + testcase_name='0_workers', + num_threads=0, + per_worker_buffer_size=1, + ), + dict( + testcase_name='1_worker', + num_threads=1, + per_worker_buffer_size=1, + ), + dict( + testcase_name='1_worker_large_buffer', + num_threads=1, + per_worker_buffer_size=20, + ), + dict( + testcase_name='10_workers', + num_threads=10, + per_worker_buffer_size=1, + ), + dict( + testcase_name='10_workers_large_buffer', + num_threads=10, + per_worker_buffer_size=20, + ), + ) + def test_prefetch_data(self, num_threads: int, per_worker_buffer_size: int): + prefetch_lazy_iter_ds = prefetch.multithread_prefetch( + self.iter_ds, + num_threads=num_threads, + buffer_size=per_worker_buffer_size, + ) + actual = list(prefetch_lazy_iter_ds) + expected = list(range(1, 20, 2)) + self.assertSequenceEqual(actual, expected) + + def test_prefetch_size_zero_data(self): + ds = 
dataset.MapDataset.source( + [np.zeros(shape=(0,), dtype=np.int64)] + ).repeat(3) + iter_ds = ds.to_iter_dataset() + prefetch_lazy_iter_ds = prefetch.multithread_prefetch( + iter_ds, + num_threads=1, + ) + actual = list(prefetch_lazy_iter_ds) + expected = [np.zeros(shape=(0,), dtype=np.int64)] * 3 + self.assertLen(actual, 3) + self.assertLen(expected, 3) + for i in range(3): + np.testing.assert_array_equal(actual[i], expected[i]) + + @parameterized.product( + ( + dict(num_threads=0), + dict(num_threads=1), + dict(num_threads=10), + ), + step_index=[0, 3, 8], + ) + def test_checkpoint(self, num_threads: int, step_index: int): + ds = prefetch.multithread_prefetch( + self.iter_ds, + num_threads=num_threads, + ) + ds_iter = ds.__iter__() + + max_steps = 10 + values_without_interruption = [] + checkpoints = [] + for _ in range(max_steps): + checkpoints.append(ds_iter.get_state()) + values_without_interruption.append(next(ds_iter)) + + ds_iter.set_state(checkpoints[step_index]) + for i in range(step_index, max_steps): + value = next(ds_iter) + self.assertEqual(value, values_without_interruption[i]) + + def test_set_state_twice(self): + ds = prefetch.multithread_prefetch( + self.iter_ds, + num_threads=2, + ) + ds_iter = ds.__iter__() + + max_steps = 10 + values_without_interruption = [] + checkpoints = [] + for _ in range(max_steps): + checkpoints.append(ds_iter.get_state()) + values_without_interruption.append(next(ds_iter)) + + for starting_step in [0, 3, 8]: + ds_iter.set_state(checkpoints[starting_step]) + for i in range(starting_step, max_steps): + value = next(ds_iter) + self.assertEqual(value, values_without_interruption[i]) + + def test_works_with_iter_source_single_worker(self): + # Even though a pure IterDataset cannot be sliced, we should still be able + # to multiprocess-prefetch it with a single worker, since that doesn't + # require any slicing. + ds = prefetch.multithread_prefetch( + RepeatedIntSourceIterDataset().map(lambda x: x + 1), + num_threads=1, + ) + ds_iter = iter(ds) + self.assertEqual(next(ds_iter), 2) + + def test_fails_with_iter_source_multiple_workers(self): + with self.assertRaisesRegex( + ValueError, + 'Cannot slice `IterDataset` source.', + ): + prefetch.multithread_prefetch( + RepeatedIntSourceIterDataset().map(lambda x: x + 1), + num_threads=2, + ) + + def test_propagates_transform_error(self): + error_msg = 'I shall fail!' 
+ + def failing_transform(element): + del element + raise ValueError(error_msg) + + ds = prefetch.multithread_prefetch( + self.iter_ds.map(failing_transform), + num_threads=1, + ) + with self.assertRaisesRegex(Exception, error_msg): + list(ds) @parameterized.product( start_prefetch_calls=[0, 1, 10], - num_workers=[6], + num_threads=[6], per_worker_buffer_size=[1, 20], ) def test_start_prefetch( self, start_prefetch_calls: int, - num_workers: int, + num_threads: int, per_worker_buffer_size: int, ): class _SleepTransform(transforms.MapTransform): @@ -761,10 +879,11 @@ def map(self, features): ds = dataset.MapDataset.range(10) ds = ds.map(_SleepTransform()) - ds = prefetch.PrefetchIterDataset(ds, read_options=options.ReadOptions()) - ds = prefetch.MultiprocessPrefetchIterDataset( + ds = ds.to_iter_dataset() + ds = prefetch.multithread_prefetch( ds, - options.MultiprocessingOptions(num_workers, per_worker_buffer_size), + num_threads=num_threads, + buffer_size=per_worker_buffer_size, ) it = ds.__iter__() @@ -792,7 +911,7 @@ def test_prefetch_but_no_read(self, sleep_s): ds = dataset.MapDataset.source([1, 2, 3]).repeat() ds = ds.filter(lambda x: x > 3) ds = ds.to_iter_dataset() - ds = ds.mp_prefetch() + ds = prefetch.multithread_prefetch(ds, num_threads=1) it = ds.__iter__() it.start_prefetch() time.sleep(sleep_s) @@ -801,9 +920,9 @@ def test_prefetch_but_no_read(self, sleep_s): def test_prefetch_with_random_map(self): ds = dataset.MapDataset.source([0]).repeat(100).to_iter_dataset() ds = ds.random_map(lambda x, rng: x + rng.integers(sys.maxsize), seed=42) - ds = prefetch.MultiprocessPrefetchIterDataset( + ds = prefetch.multithread_prefetch( ds, - options.MultiprocessingOptions(num_workers=5), + num_threads=5, ) # Make sure that sliced datasets on workers are seeded differently and thus # produce different random elements. 
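Editor's note (not part of the patch): the converted tests above now exercise the thread-based `multithread_prefetch` path instead of the deleted multiprocess dataset. A minimal sketch of that call, with the internal module paths used in this diff:

from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import prefetch

ds = dataset.MapDataset.source(range(10)).to_iter_dataset()
# num_threads=0 returns ds unchanged; with more than one thread each shard is
# a deep copy of ds sliced with slice(i, None, num_threads), and the shards
# are interleaved with cycle_length=num_threads (see the multithread_prefetch
# hunk earlier in this patch).
ds = prefetch.multithread_prefetch(ds, num_threads=3, buffer_size=1)
elements = list(ds)
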
@@ -817,7 +936,7 @@ def test_concurrent_start_prefetch(self): def make_iter(i): ds = dataset.MapDataset.source([i]) ds = ds.to_iter_dataset() - ds = ds.mp_prefetch(options=options.MultiprocessingOptions(num_workers=1)) + ds = prefetch.multithread_prefetch(ds, num_threads=1) return ds.__iter__() iters = [make_iter(i) for i in range(num_iters)] @@ -832,30 +951,33 @@ def test_options_before_prefetch(self): ds = ds.to_iter_dataset() ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1) ds = dataset.WithOptionsIterDataset(ds, ds_options) - ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=1)) + ds = prefetch.multithread_prefetch(ds, num_threads=1) ds = ds.filter(lambda x: x > 2) with self.assertRaises(Exception): list(ds) def test_multiprocess_prefetch_with_sequential_slice(self): ds = dataset.MapDataset.source(range(10)).to_iter_dataset() - ds = prefetch.MultiprocessPrefetchIterDataset( + ds = prefetch.multithread_prefetch( ds, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + num_threads=3, + buffer_size=1, sequential_slice=True, ) self.assertEqual(list(ds), [0, 4, 7, 1, 5, 8, 2, 6, 9, 3]) def test_multiprocess_prefetch_with_default_slice_non_sequential(self): ds = dataset.MapDataset.source(range(10)).to_iter_dataset() - ds_sequential_off = prefetch.MultiprocessPrefetchIterDataset( + ds_sequential_off = prefetch.multithread_prefetch( ds, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + num_threads=3, + buffer_size=1, sequential_slice=False, ) - ds_sequential_default = prefetch.MultiprocessPrefetchIterDataset( + ds_sequential_default = prefetch.multithread_prefetch( ds, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + num_threads=3, + buffer_size=1, ) elements_sequential_off = list(ds_sequential_off) elements_sequential_default = list(ds_sequential_default) @@ -870,9 +992,10 @@ def test_multiprocess_prefetch_with_default_slice_non_sequential(self): def test_multiprocess_prefetch_sequential_slice_order_from_source(self): ds = dataset.MapDataset.source(range(10)).to_iter_dataset() - ds_sequential_on = prefetch.MultiprocessPrefetchIterDataset( + ds_sequential_on = prefetch.multithread_prefetch( ds, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + num_threads=3, + buffer_size=1, sequential_slice=True, ) elements_sequential_on = list(ds_sequential_on) @@ -880,9 +1003,10 @@ def test_multiprocess_prefetch_sequential_slice_order_from_source(self): def test_multiprocess_prefetch_sequential_slice_order_from_range(self): ds_range = dataset.MapDataset.range(10).to_iter_dataset() - ds_range_sequential_on = prefetch.MultiprocessPrefetchIterDataset( + ds_range_sequential_on = prefetch.multithread_prefetch( ds_range, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + num_threads=3, + buffer_size=1, sequential_slice=True, ) elements_range_sequential_on = list(ds_range_sequential_on) @@ -895,9 +1019,10 @@ def test_multiprocess_prefetch_sequential_slice_order_from_range_slice(self): ds_range = dataset.MapDataset.range( start=2, stop=21, step=3 ).to_iter_dataset() - ds_range_sequential_on = prefetch.MultiprocessPrefetchIterDataset( + ds_range_sequential_on = prefetch.multithread_prefetch( ds_range, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + num_threads=3, + buffer_size=1, sequential_slice=True, ) elements_range_sequential_on = list(ds_range_sequential_on) @@ -909,14 +1034,16 @@ def 
test_multiprocess_prefetch_sequential_slice_order_from_range_slice(self): def test_multiprocess_prefetch_sequential_slice_order_same(self): ds_source = dataset.MapDataset.source(range(10)).to_iter_dataset() ds_range = dataset.MapDataset.range(10).to_iter_dataset() - ds_source_mp = prefetch.MultiprocessPrefetchIterDataset( + ds_source_mp = prefetch.multithread_prefetch( ds_source, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + num_threads=3, + buffer_size=1, sequential_slice=True, ) - ds_range_mp = prefetch.MultiprocessPrefetchIterDataset( + ds_range_mp = prefetch.multithread_prefetch( ds_range, - options.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=1), + num_threads=3, + buffer_size=1, sequential_slice=True, ) elements_source = list(ds_source_mp) @@ -927,475 +1054,20 @@ def test_options_after_prefetch(self): ds = dataset.MapDataset.source([1, 2, 3]).repeat(1000) ds = ds.filter(lambda x: x > 2) ds = ds.to_iter_dataset() - ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=1)) + ds = prefetch.multithread_prefetch(ds, num_threads=1) ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1) ds = dataset.WithOptionsIterDataset(ds, ds_options) with self.assertRaises(Exception): list(ds) - def test_worker_init_fn(self): - def set_worker_index_and_count(worker_index: int, worker_count: int): - log_formatter = std_logging.Formatter( - f'[Worker {worker_index} out of {worker_count}] %(message)s' - ) - logging.get_absl_handler().setFormatter(log_formatter) - - def map_fn(x): - # absl logging from workers is not propagated to the main process in unit - # tests. Therefore, we manually pass the formatted log message. - record = logging.get_absl_logger().makeRecord( - 'grain', - logging.INFO, - 'grain_pool_test', - 123, - f'processing element {x}', - (), - None, - ) - return logging.get_absl_handler().format(record) - - ds = dataset.MapDataset.range(2).map(map_fn) - ds = ds.to_iter_dataset() - ds = ds.mp_prefetch( - options.MultiprocessingOptions(num_workers=2), - worker_init_fn=set_worker_index_and_count, - ) - self.assertEqual( - list(ds), - [ - '[Worker 0 out of 2] processing element 0', - '[Worker 1 out of 2] processing element 1', - ], - ) - - -class ThreadPrefetchIterDatasetTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.ds = ( - dataset.MapDataset.range(20) - .to_iter_dataset() - .filter(FilterKeepingOddElementsOnly()) - ) - - @parameterized.named_parameters( - dict( - testcase_name='no_prefetch', - prefetch_buffer_size=0, - warm_start=False, - ), - dict( - testcase_name='no_prefetch_with_warm_start', - prefetch_buffer_size=0, - warm_start=True, - ), - dict( - testcase_name='thread', - prefetch_buffer_size=1, - warm_start=True, - ), - dict( - testcase_name='thread_large_buffer', - prefetch_buffer_size=20, - warm_start=False, - ), - dict( - testcase_name='thread_huge_buffer', - prefetch_buffer_size=200, - warm_start=True, - ), - ) - def test_prefetch_data(self, prefetch_buffer_size: int, warm_start: bool): - prefetch_lazy_iter_ds = prefetch.ThreadPrefetchIterDataset( - self.ds, prefetch_buffer_size=prefetch_buffer_size - ) - ds = prefetch_lazy_iter_ds.__iter__() - if warm_start: - ds.start_prefetch() - actual = list(ds) - expected = list(range(1, 20, 2)) - self.assertSequenceEqual(actual, expected) - - @parameterized.parameters([False, True]) - def test_checkpoint(self, warm_start: bool): - ds = prefetch.ThreadPrefetchIterDataset( - self.ds, - prefetch_buffer_size=500, - ) - ds_iter = ds.__iter__() - if 
warm_start: - ds_iter.start_prefetch() - - max_steps = 10 - values_without_interruption = [] - checkpoints = [] - for _ in range(max_steps): - checkpoints.append(ds_iter.get_state()) - values_without_interruption.append(next(ds_iter)) - - for starting_step in range(9): - ds_iter.set_state(checkpoints[starting_step]) - for i in range(starting_step, max_steps): - value = next(ds_iter) - self.assertEqual(value, values_without_interruption[i]) - - def test_set_state_on_fresh_iterator(self): - ds = prefetch.ThreadPrefetchIterDataset( - self.ds, - prefetch_buffer_size=2, - ) - ds_iter = ds.__iter__() - - max_steps = 10 - values_without_interruption = [] - checkpoints = [] - for _ in range(max_steps): - checkpoints.append(ds_iter.get_state()) - values_without_interruption.append(next(ds_iter)) - - for starting_step in range(9): - ds_iter = ds.__iter__() - ds_iter.set_state(checkpoints[starting_step]) - for i in range(starting_step, max_steps): - value = next(ds_iter) - self.assertEqual(value, values_without_interruption[i]) - - def test_get_state_doesnt_start_prefetch(self): - event = threading.Event() - - def f(x): - event.set() - return x - - ds = dataset.MapDataset.source([1, 2, 3]).map(f).to_iter_dataset() - ds = prefetch.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=10) - it = ds.__iter__() - it.get_state() - time.sleep(1) - self.assertFalse(event.is_set()) - - def test_parent_dataset_modifies_state(self): - class TestIterator(dataset.DatasetIterator): - - def __next__(self): - return 1 - - def get_state(self): - return {'test': 1} - - def set_state(self, state): - pass - - class TestDataset(dataset.IterDataset): - - def __iter__(self): - return TestIterator() - - parent = TestDataset() - ds = prefetch.ThreadPrefetchIterDataset(parent, prefetch_buffer_size=1) - ds_iter = ds.__iter__() - ds_iter.set_state({'test': 2}) - self.assertEqual(ds_iter.get_state(), {'test': 1}) - - def test_fails_with_negative_prefetch_buffer_size(self): - with self.assertRaisesRegex( - ValueError, '`prefetch_buffer_size` must be greater than or equal to 0' - ): - prefetch.ThreadPrefetchIterDataset(self.ds, prefetch_buffer_size=-1) - - def test_start_prefetch_with_mp_prefetch_but_no_read(self): - ds = dataset.MapDataset.source([1, 2, 3]).repeat().to_iter_dataset() - ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=2)) - ds = prefetch.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=10) - it = ds.__iter__() - it.start_prefetch() - del it - - def test_does_not_create_reference_to_itself(self): - ds = dataset.MapDataset.source([1, 2, 3]).repeat(100).to_iter_dataset() - ds = prefetch.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=10) - it = ds.__iter__() - refcount_before_iteration = sys.getrefcount(it) - _ = next(it) - refcount_after_iteration = sys.getrefcount(it) - self.assertEqual(refcount_before_iteration, refcount_after_iteration) - - def test_does_not_hang_after_stop_iteration(self): - ds = dataset.MapDataset.source([1, 2, 3]).repeat(100).to_iter_dataset() - ds = prefetch.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=10) - it = ds.__iter__() - _ = list(it) - self.assertEmpty(list(it)) - - def test_nonnative_iterator(self): - - class TestIterator: - - def __init__(self): - self._counter = 0 - - def __iter__(self): - return self - - def __next__(self) -> int: - self._counter += 1 - if self._counter > 10: - raise StopIteration - return self._counter - - def get_state(self): - return {'counter': self._counter} - - def set_state(self, state): - self._counter = state['counter'] - - 
test_iterator = TestIterator() - it = prefetch.ThreadPrefetchDatasetIterator( - test_iterator, prefetch_buffer_size=10 - ) - elements = [] - checkpoint_step = 5 - for _ in range(checkpoint_step): - elements.append(next(it)) - checkpoint = it.get_state() - elements.extend(it) - self.assertEqual(elements, list(range(1, 11))) - it.set_state(checkpoint) - self.assertEqual(list(it), elements[checkpoint_step:]) - - def test_no_mem_leak(self): - ds = ( - dataset.MapDataset.range(1000) - .repeat() - .map(lambda x: x * np.ones((1000, 1000), dtype=np.int64)) - .to_iter_dataset(options.ReadOptions(prefetch_buffer_size=0)) - ) - ds = prefetch.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=10) - # If buffered elements are not cleaned up when the iterator is gc'ed, this - # test will OOM. - for _ in range(1000): - it = ds.__iter__() - for _ in range(5): - _ = next(it) - - @parameterized.parameters([True, False]) - def test_no_mem_leak_with_double_prefetch(self, close: bool): - ds = ( - dataset.MapDataset.range(1000) - .repeat() - .map(lambda x: x * np.ones((1000, 1000), dtype=np.int64)) - .to_iter_dataset(options.ReadOptions(prefetch_buffer_size=0)) - ) - ds = prefetch.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=10) - ds = ds.map(lambda x: x + 1) - ds = prefetch.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=10) - # If buffered elements are not cleaned up when the iterator is gc'ed, this - # test will OOM. - for _ in range(1000): - it = ds.__iter__() - for _ in range(5): - _ = next(it) - if close: - it.close() # pytype: disable=attribute-error - - @absltest.skipIf(platform.system() == 'Darwin', 'Fails on macos-14 runner.') - @parameterized.parameters([True, False]) - def test_early_break_continues_prefetching(self, close: bool): - count = 0 - count_lock = threading.Lock() - - class SlowCountingSource: - - def __len__(self): - return 16 - - def __getitem__(self, index): - nonlocal count - time.sleep(0.1) - with count_lock: - count += 1 - return index - - read_options = options.ReadOptions(num_threads=2) - ds = dataset.MapDataset.source(SlowCountingSource()).to_iter_dataset( - read_options - ) - iterator = ds.__iter__() - - assert count == 0 - if close: - next(iterator) - self.assertGreater(count, 0) - iterator.close() - time.sleep(1) - self.assertLess(count, 8) - else: - next(iterator) - self.assertGreater(count, 0) - time.sleep(1) - self.assertGreater(count, 8) - - -class _MpContextCheckIterDataset(dataset.IterDataset[_T]): - - def __iter__(self) -> dataset.DatasetIterator[_T]: - return _MpContextCheckIterator(self._parent.__iter__()) - - -class _MpContextCheckIterator(dataset.DatasetIterator[_T]): - - def __next__(self) -> tuple[_T, base.MultiprocessingContext]: - element = next(self._parent) - return (element, self._ctx.mp_context) - - def get_state(self): - return self._parent.get_state() - - def set_state(self, state): - self._parent.set_state(state) - - -class MultithreadPrefetchIterDatasetTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.ds = dataset.MapDataset.range(20).to_iter_dataset() - - @parameterized.named_parameters( - dict( - testcase_name='no_prefetch', - num_workers=0, - per_worker_buffer_size=0, - ), - dict( - testcase_name='thread', - num_workers=1, - per_worker_buffer_size=1, - ), - dict( - testcase_name='2_threads_large_buffer', - num_workers=2, - per_worker_buffer_size=20, - ), - dict( - testcase_name='4_threads_huge_buffer', - num_workers=4, - per_worker_buffer_size=200, - ), - ) - def test_prefetch_data(self, num_workers: int, 
per_worker_buffer_size: int): - prefetch_lazy_iter_ds = prefetch.multithread_prefetch( - self.ds, - num_threads=num_workers, - buffer_size=per_worker_buffer_size, - ) - ds_iter = prefetch_lazy_iter_ds.__iter__() - if num_workers > 0: - ds_iter.start_prefetch() - actual = list(ds_iter) - expected = list(range(20)) - self.assertSequenceEqual(actual, expected) - - def test_checkpoint(self): - ds = prefetch.multithread_prefetch( - self.ds, - num_threads=2, - buffer_size=5, - ) - ds_iter = ds.__iter__() - ds_iter.start_prefetch() - - max_steps = 20 - values_without_interruption = [] - checkpoints = [] - for _ in range(max_steps): - checkpoints.append(ds_iter.get_state()) - values_without_interruption.append(next(ds_iter)) - - for starting_step in [0, 5, 13, 19]: - ds_iter.set_state(checkpoints[starting_step]) - ds_iter.start_prefetch() - for i in range(starting_step, max_steps): - value = next(ds_iter) - print(value) - self.assertEqual(value, values_without_interruption[i]) - - def test_set_state_on_fresh_iterator(self): - ds = prefetch.multithread_prefetch( - self.ds, - num_threads=2, - buffer_size=2, - ) - ds_iter = ds.__iter__() - ds_iter.start_prefetch() - - max_steps = 20 - values_without_interruption = [] - checkpoints = [] - for _ in range(max_steps): - checkpoints.append(ds_iter.get_state()) - values_without_interruption.append(next(ds_iter)) - - for starting_step in [0, 5, 13, 19]: - ds_iter = ds.__iter__() - ds_iter.set_state(checkpoints[starting_step]) - ds_iter.start_prefetch() - for i in range(starting_step, max_steps): - value = next(ds_iter) - self.assertEqual(value, values_without_interruption[i]) - - def test_get_state_doesnt_start_prefetch(self): - event = threading.Event() - - def f(x): - event.set() - return x - - ds = dataset.MapDataset.source([1, 2, 3]).map(f).to_iter_dataset() - ds = prefetch.multithread_prefetch( - ds, - num_threads=2, - buffer_size=10, - ) - it = ds.__iter__() - it.get_state() - time.sleep(1) - self.assertFalse(event.is_set()) - - def test_does_not_hang_after_stop_iteration(self): - ds = dataset.MapDataset.source([1, 2, 3]).repeat(100).to_iter_dataset() - ds = prefetch.multithread_prefetch( - ds, - num_threads=2, - buffer_size=10, - ) - it = ds.__iter__() - it.start_prefetch() - - def test_fails_with_multiprocess_prefetch_parent(self): - ds = prefetch.MultiprocessPrefetchIterDataset( - self.ds, - options.MultiprocessingOptions(num_workers=2), - ) - with self.assertRaisesRegex( - ValueError, - 'Nesting multiprocessing or multithreading is not allowed.', - ): - _ = prefetch.multithread_prefetch( - ds, - num_threads=1, - buffer_size=1, - ) - def test_mp_context_is_set_correctly(self): - num_workers = 4 + num_threads = 4 ds = dataset.MapDataset.range(20).to_iter_dataset() ds = _MpContextCheckIterDataset(ds) ds = ds.map(lambda x: x) ds = prefetch.multithread_prefetch( ds, - num_threads=num_workers, + num_threads=num_threads, buffer_size=1, ) @@ -1408,8 +1080,8 @@ def test_mp_context_is_set_correctly(self): # Check mp_context. 
for i, (_, context) in enumerate(results): - self.assertEqual(context.process_index, i % num_workers) - self.assertEqual(context.process_count, num_workers) + self.assertEqual(context.process_index, i % num_threads) + self.assertEqual(context.process_count, num_threads) if __name__ == '__main__': diff --git a/grain/_src/python/dataset/transformations/process_prefetch.py b/grain/_src/python/dataset/transformations/process_prefetch.py index 39d3f46ee..32187f3d2 100644 --- a/grain/_src/python/dataset/transformations/process_prefetch.py +++ b/grain/_src/python/dataset/transformations/process_prefetch.py @@ -90,27 +90,6 @@ def _get_dataset_options(ds: dataset.IterDataset) -> base.DatasetOptions: return result -def _validate_no_nested_process_prefetch( - ds: dataset.MapDataset | dataset.IterDataset, -): - """Checks that there are no nested process prefetch nodes.""" - to_check: list[dataset.MapDataset | dataset.IterDataset] = [ds] - while to_check: - d = to_check.pop(0) - if isinstance( - d, - ( - ProcessPrefetchIterDataset, - prefetch.MultiprocessPrefetchIterDataset, - ), - ): - raise ValueError( - "Nesting prefetching with processes is not allowed, but found " - f"{type(d).__name__} under a ProcessPrefetchIterDataset." - ) - to_check.extend(d.parents) - - def _check_picklable( ds: dataset.IterDataset | dataset.MapDataset, ): @@ -165,7 +144,6 @@ def __init__( super().__init__(parent) self._buffer_size = buffer_size self._worker_init_fn = worker_init_fn - _validate_no_nested_process_prefetch(self._parent) def __str__(self) -> str: return f"ProcessPrefetchIterDataset(buffer_size={self._buffer_size})" diff --git a/grain/_src/python/dataset/transformations/process_prefetch_test.py b/grain/_src/python/dataset/transformations/process_prefetch_test.py index 9e46a9839..56b0747d6 100644 --- a/grain/_src/python/dataset/transformations/process_prefetch_test.py +++ b/grain/_src/python/dataset/transformations/process_prefetch_test.py @@ -28,7 +28,6 @@ from grain._src.python import options from grain._src.python.dataset import base from grain._src.python.dataset import dataset -from grain._src.python.dataset.transformations import prefetch from grain._src.python.dataset.transformations import process_prefetch import numpy as np @@ -265,23 +264,6 @@ def __getstate__(self): ): list(ds) - def test_fails_with_nested_prefetch(self): - ds1 = process_prefetch.ProcessPrefetchIterDataset(self.ds, buffer_size=1) - with self.assertRaisesRegex( - ValueError, - 'Nesting prefetching with processes is not allowed', - ): - process_prefetch.ProcessPrefetchIterDataset(ds1, buffer_size=1) - - ds2 = prefetch.MultiprocessPrefetchIterDataset( - self.ds, options.MultiprocessingOptions(num_workers=1) - ) - with self.assertRaisesRegex( - ValueError, - 'Nesting prefetching with processes is not allowed', - ): - process_prefetch.ProcessPrefetchIterDataset(ds2, buffer_size=1) - def test_reports_worker_crash(self): def failing_transform(element): del element diff --git a/grain/_src/python/grain_pool.py b/grain/_src/python/grain_pool.py deleted file mode 100644 index 0bcf381e3..000000000 --- a/grain/_src/python/grain_pool.py +++ /dev/null @@ -1,834 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""This module provides a way to distribute processing across multiple workers. - -In the context of Grain we use the term "process" similar to JAX, where usually -each machine runs one Python process (identified by `jax.process_index()`). -In Grain each "process" can create additional Python child processes that we -call "workers". - -GrainPool manages a set of Python processes. It's similar to -`multiprocessing.Pool` but optimises communication between the processes to -enable high throughput data pipelines. -The GrainPool works as follows: -* Parent process launches a set of "num_workers" child processes. -* Each child process produces elements by reading data and transforming it. The - resulting elements are added to a queue (each child process has its queue). -* Parent process reads data from the children queues in a strict round-robin - fashion. - -Shutdown logic considerations: -* Child processes are launched as Daemon processes. In case of (unexpected) - parent termination, child processes will be terminated by OS. -* System uses a multiprocessing event ("termination_event") for termination. - Parent and child processes continuously check if the "termination_event" and - if set, they break from what they are doing. -* We never block indefinitely when calling get() or put() on a queue. This - ensures parent and child processes continue to check the termination_event. - -MultiProcessIterator wraps GrainPool adding lifecycle management, checkpointing -support and multithreaded elements read. -""" - -from __future__ import annotations - -from collections.abc import Iterator -import cProfile -import dataclasses -from multiprocessing import context -from multiprocessing import pool -from multiprocessing import queues -from multiprocessing import synchronize -import pstats -import queue -import sys -import threading -import traceback -from typing import Any, Callable, Protocol, Type, TypeVar, Union, runtime_checkable - -from absl import flags -from absl import logging -import cloudpickle -from grain._src.core import monitoring as grain_monitoring -from grain._src.core import parallel -from grain._src.core import tree_lib -from grain._src.core.config import config -import multiprocessing as mp -from grain._src.python import grain_logging -from grain._src.python import multiprocessing_common -from grain._src.python import record -from grain._src.python import shared_memory_array -from grain._src.python.options import MultiprocessingOptions # pylint: disable=g-importing-member - - -T = TypeVar("T") - -# Maximum number of threads for starting and stopping processes. -_PROCESS_MANAGEMENT_MAX_THREADS = 64 -_PROCESS_JOIN_TIMEOUT = 25 -_QUEUE_WAIT_TIMEOUT = 1 -# Input queues contain small structures (record metadata), thus they are safe -# to have a big size. 
-_INPUT_QUEUE_MAX_SIZE = 10000 - - -@dataclasses.dataclass -class _ProcessingComplete: - """Indicates child process finished processing.""" - - -_PROCESSING_COMPLETE = _ProcessingComplete() - - -@dataclasses.dataclass(slots=True, frozen=True) -class GrainPoolElement: - """Wrapper for output records emited by Grain Pool.""" - - record: Any - worker_index: Any - - -@dataclasses.dataclass(slots=True, frozen=True) -class RemoteWorkerError: - """Grain worker exception that can be pickled and sent over a queue.""" - error_cls: Type[Exception] - error: str - worker_index: int - - @property - def original_error(self) -> Exception: - msg = ( - f"Grain worker {self.worker_index} failed with the following" - f" error:\n\n{self.error}" - ) - # Custom exception classes can have different c'tor arguments. - try: - return self.error_cls(msg) - except Exception: # pylint: disable=broad-except - return RuntimeError(msg) - - -def _print_profile(preamble: str, profile: cProfile.Profile): - """Prints output of cProfile, sorted by cumulative time.""" - print(preamble) - stats = pstats.Stats(profile).sort_stats(pstats.SortKey.CUMULATIVE) - stats.print_stats() - - -@runtime_checkable -class GetElementProducerFn(Protocol[T]): - """A callable class able to generate elements with serialization support.""" - - def __call__( - self, - *, - worker_index: int, - worker_count: int, - start_profiling_event: synchronize.Event | None = None, - stop_profiling_event: synchronize.Event | None = None, - stats_out_queue: queues.Queue | None = None, - ) -> Iterator[T]: - """Returns a generator of elements.""" - - def serialize(self) -> bytes: - """Serializes itself and the result will be used by `deserialize`. - - If a class inherits from this class, it should make sure `deserialize` - is compatible with this `serialize` function. - i.e. `GetElementProducerFn.deserialize(obj.serialize())` should return the - same object as `obj: GetElementProducerFn`. - - Returns: - a serialized string of myself. - """ - return cloudpickle.dumps(self) - - @classmethod - def deserialize(cls, serialized: bytes) -> GetElementProducerFn[T]: - """Deserializes the result from `serialize`.""" - del cls - - obj = cloudpickle.loads(serialized) - if not isinstance(obj, GetElementProducerFn): - raise ValueError( - "`serialized` should be deserialized into `GetElementProducerFn`." 
- ) - - return obj - - -def parse_debug_flags(debug_flags: dict[str, Any]): - """Parses debug flags.""" - - flags.FLAGS["grain_py_debug_mode"].present = True - flags.FLAGS["grain_py_dataset_visualization_output_dir"].present = True - config.update("py_debug_mode", debug_flags["grain_py_debug_mode"]) - config.update( - "py_dataset_visualization_output_dir", - debug_flags["grain_py_dataset_visualization_output_dir"], - ) - - -def _initialize_and_get_element_producer( - args_queue: queues.Queue, - *, - debug_flags: dict[str, Any], - worker_index: int, - worker_count: int, - start_profiling_event: synchronize.Event, - stop_profiling_event: synchronize.Event, - stats_out_queue: queues.Queue, -) -> Iterator[Any]: - """Unpickles the element producer from the args queue and closes the queue.""" - ( - serialized_flag_parse_fn, - serialized_init_fns, - serialized_element_producer_fn, - ) = args_queue.get() - flag_parse_fn: Callable[[Any], None] = cloudpickle.loads( - serialized_flag_parse_fn - ) - flag_parse_fn(debug_flags) - init_fns: list[Callable[[int, int], None]] = cloudpickle.loads( - serialized_init_fns - ) - for init_fn in init_fns: - init_fn(worker_index, worker_count) - element_producer_fn: GetElementProducerFn[Any] = ( - GetElementProducerFn.deserialize(serialized_element_producer_fn) - ) - - element_producer = element_producer_fn( - worker_index=worker_index, - worker_count=worker_count, - start_profiling_event=start_profiling_event, - stop_profiling_event=stop_profiling_event, - stats_out_queue=stats_out_queue, - ) - # args_queue has only a single argument and thus can be safely closed. - args_queue.close() - return element_producer - - -def _worker_loop( - *, - args_queue: queues.Queue, - errors_queue: queues.Queue, - output_queue: queues.Queue, - termination_event: synchronize.Event, - start_profiling_event: synchronize.Event, - stop_profiling_event: synchronize.Event, - worker_index: int, - worker_count: int, - enable_profiling: bool, - debug_flags: dict[str, Any], - stats_out_queue: queues.Queue, -): - """Code to be run on each child process.""" - out_of_elements = False - try: - worker_index_suffix = "" if worker_count == 1 else f" {worker_index}" - grain_logging.set_process_identifier_prefix( - f"PyGrain Worker{worker_index_suffix}" - ) - logging.info("Starting work.") - element_producer = _initialize_and_get_element_producer( - args_queue, - debug_flags=debug_flags, - worker_index=worker_index, - worker_count=worker_count, - start_profiling_event=start_profiling_event, - stop_profiling_event=stop_profiling_event, - stats_out_queue=stats_out_queue, - ) - profiling_enabled = enable_profiling and worker_index == 0 - if profiling_enabled: - profile = cProfile.Profile() - profile.enable() - # If termination event is set, we terminate and discard remaining elements. - while not termination_event.is_set(): - try: - next_element = next(element_producer) - if not multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types - next_element, output_queue, termination_event.is_set - ): - # We failed to put the element into the output queue because the - # termination event was set. The element may contain a shared memory - # block reference that has to be cleaned up. 
- _unlink_shm_in_structure(next_element) - except StopIteration: - out_of_elements = True - multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types - _ProcessingComplete(), output_queue, termination_event.is_set - ) - break - if profiling_enabled: - profile.disable() - _print_profile(f"PROFILE OF PROCESS WITH IDX {worker_index}.", profile) - - except Exception as e: # pylint: disable=broad-except - logging.exception( - "Error occurred in child process with worker_index: %i", worker_index - ) - remote_error = RemoteWorkerError( - error_cls=e.__class__, - error="".join( - traceback.format_exception(e.__class__, e, e.__traceback__) - ), - worker_index=worker_index, - ) - try: - errors_queue.put(remote_error, timeout=_QUEUE_WAIT_TIMEOUT) - except queue.Full: - logging.error("Couldn't send exception from child process. Queue full!") - - logging.info( - "Setting termination event in process with worker_index: %i", - worker_index, - ) - termination_event.set() - - if termination_event.is_set(): - if not out_of_elements: - # Since the termination event is set the consumer will not get any more - # elements from the output queue. The elements may contain reference to - # shared memory blocks that have to be cleaned up. - while not output_queue.empty(): - _unlink_shm_in_structure(output_queue.get_nowait()) - # When adding elements to the queue, element is put in a buffer and a - # background thread flushes the elements through the pipe. The process that - # writes to the queue joins that thread automatically on exit. We call - # cancel_join_thread when system terminates to prevent deadlocks. - output_queue.cancel_join_thread() - output_queue.close() - logging.info("Process %i exiting.", worker_index) - - -def _unlink_shm_if_metadata(obj: Any): - if isinstance(obj, shared_memory_array.SharedMemoryArrayMetadata): - obj.close_and_unlink_shm() - - -def _unlink_shm_in_structure(structure: Any): - if isinstance(structure, record.Record): - _unlink_shm_in_structure(structure.data) - else: - tree_lib.map_structure(_unlink_shm_if_metadata, structure) - - -class GrainPool(Iterator[T]): - """Pool to parallelize processing of Grain pipelines among a set of processes.""" - - def __init__( - self, - ctx: context.BaseContext, - *, - get_element_producer_fn: GetElementProducerFn[T], - worker_index_to_start_reading: int = 0, - termination_event: threading.Event | None = None, - start_profiling_event: synchronize.Event | None = None, - stop_profiling_event: synchronize.Event | None = None, - options: MultiprocessingOptions, - worker_init_fn: Callable[[int, int], None] | None = None, - stats_in_queues: tuple[queues.Queue, ...] | None = None, - ): - """Initialise a Grain Pool. - - Args: - ctx: Context to make multiprocessing primitives work. - get_element_producer_fn: Callable that returns an iterator over the - elements given the process index and process count. - worker_index_to_start_reading: index of worker to start reading output - batches from (needed for checkpointing support). - termination_event: Setting this event will terminate the pool. Otherwise, - the pool will terminate when either one of the workers failed or when - all workers are done processing data. GrainPool will not set this event. - start_profiling_event: Event to start prism profiling. - stop_profiling_event: Event to stop prism profiling. - options: Options for multiprocessing. See MultiprocessingOptions. - worker_init_fn: Function to run in each worker process before the element - producer. 
The function takes two arguments: the current worker index and - the total worker count. - stats_in_queues: Queue to propagate execution summary from child processes - to the parent. - """ - self.num_processes = options.num_workers - logging.info("Grain pool will use %i processes.", self.num_processes) - self.worker_args_queues = [] - self.worker_output_queues = [] - self.processes = [] - # Reader termination should always result in worker termination. However, - # worker termination should not shut down the reader: workers are terminated - # when they finished processing data, but the reader may still need to read - # the remaining output from the shared queues. That is why we use two - # separate events. - self._reader_termination_event = termination_event or threading.Event() - self._workers_termination_event = ctx.Event() - self._worker_init_fn = worker_init_fn - self.completed_processes = set() - # Queue to propagate errors from child processes to the parent. Note that - # this queue is shared by all child processes. - self.worker_error_queue = ctx.Queue(self.num_processes) - self.stats_in_queues = stats_in_queues - - try: - get_element_producer_fn = get_element_producer_fn.serialize() - except Exception as e: - if sys.version_info >= (3, 11): - e.add_note( - "\nCould not serialize transformation function passed to Grain " - "workers. This likely means that your data source, sampler or one " - "of your transformations cannot be serialized. Please make sure " - "that the objects work with cloudpickle.dumps()." - ) - raise e - - for worker_index in range(self.num_processes): - worker_args_queue = ctx.Queue(1) - worker_output_queue = ctx.Queue(options.per_worker_buffer_size) - process_kwargs = dict( - args_queue=worker_args_queue, - errors_queue=self.worker_error_queue, - output_queue=worker_output_queue, - stats_out_queue=( - self.stats_in_queues[worker_index] - if self.stats_in_queues - else None - ), - termination_event=self._workers_termination_event, - start_profiling_event=start_profiling_event, - stop_profiling_event=stop_profiling_event, - worker_index=worker_index, - worker_count=options.num_workers, - enable_profiling=options.enable_profiling, - debug_flags=dict( - grain_py_debug_mode=config.get_or_default("py_debug_mode"), - grain_py_dataset_visualization_output_dir=( - config.get_or_default("py_dataset_visualization_output_dir") - ), - ), - ) - # The process kwargs must all be pickable and will be unpickle before - # absl.app.run() is called. We send arguments via a queue to ensure that - # they are unpickled after absl.app.run() was called in the child - # processes. 
- worker_init_fns = [self._worker_init_fn] if self._worker_init_fn else [] - parse_debug_flags_fn = parse_debug_flags - worker_init_fns = cloudpickle.dumps(worker_init_fns) - parse_debug_flags_fn = cloudpickle.dumps(parse_debug_flags_fn) - worker_args_queue.put( - (parse_debug_flags_fn, worker_init_fns, get_element_producer_fn) - ) - process = ctx.Process( # pytype: disable=attribute-error # re-none - target=_worker_loop, kwargs=process_kwargs, daemon=True - ) - self.worker_args_queues.append(worker_args_queue) - self.worker_output_queues.append(worker_output_queue) - self.processes.append(process) - - logging.info("Grain pool will start child processes.") - parallel.run_in_parallel( - function=lambda child_process: child_process.start(), - list_of_kwargs_to_function=[ - {"child_process": p} for p in self.processes - ], - num_workers=min(_PROCESS_MANAGEMENT_MAX_THREADS, self.num_processes), - ) - logging.info("Grain pool started all child processes.") - self._next_worker_index = worker_index_to_start_reading - - def __iter__(self) -> GrainPool: - return self - - def _process_failed(self, worker_index: int) -> bool: - exit_code = self.processes[worker_index].exitcode - return exit_code is not None and exit_code != 0 - - def _processing_completed(self) -> bool: - return all(p.exitcode == 0 for p in self.processes) - - def _update_next_worker_index(self) -> None: - self._next_worker_index = (self._next_worker_index + 1) % self.num_processes - - def __next__(self) -> GrainPoolElement: - processing_failed = False - while ( - not self._workers_termination_event.is_set() - and len(self.completed_processes) < self.num_processes - ): - # If the reader was shut down, e.g. due to iterator deletion, we should - # shut down the workers. - if self._reader_termination_event.is_set(): - self._shutdown() - # Since the reader is shut down it doesn't matter what we return here. - # We should not raise an exception because it is common to iterate over - # infinite datasets and delete the iterator before processing is - # complete. - return GrainPoolElement( - "Grain worker pool reader was terminated, shutting down workers.", - -1, - ) - if self._next_worker_index in self.completed_processes: - self._update_next_worker_index() - continue - try: - element_worker_index = self._next_worker_index - element = self.worker_output_queues[self._next_worker_index].get( - timeout=_QUEUE_WAIT_TIMEOUT - ) - logging.debug("Read element from process: %s", self._next_worker_index) - if element == _PROCESSING_COMPLETE: - logging.info( - "Processing complete for process with worker_index %i", - self._next_worker_index, - ) - self.completed_processes.add(self._next_worker_index) - self._update_next_worker_index() - else: - self._update_next_worker_index() - return GrainPoolElement(element, element_worker_index) - except queue.Empty: - logging.debug("Got no element from process %s", self._next_worker_index) - if self._process_failed(self._next_worker_index): - processing_failed = True - logging.info( - "Process with idx %i Failed (Exitcode: %s).", - self._next_worker_index, - self.processes[self._next_worker_index].exitcode, - ) - break - - if processing_failed or self._workers_termination_event.is_set(): - logging.error("Processing Failed. Shutting down.") - self._shutdown() - - try: - remote_error = self.worker_error_queue.get(timeout=_QUEUE_WAIT_TIMEOUT) - raise remote_error.original_error - except queue.Empty: - # Worker did not report any error. 
This means that either an exception - # was raised outside of the worker loop (e.g. during flag parsing) or - # the worker process was forcefully terminated. Unfortunately, there is - # no debugging info available in the main process at this point apart - # from the exit code. The crash logs, however, should've been produced. - raise RuntimeError( - f"Grain worker process {self._next_worker_index} was terminated" - " unexpectedly with exit code " - f"{self.processes[self._next_worker_index].exitcode}. Search the " - "logs above for the source of the crash." - ) from None - - # Processing successfully completed. - raise StopIteration - - def __del__(self): - self._shutdown() - - def __enter__(self) -> GrainPool: - return self - - def __exit__(self, exc_type, exc_value, exc_traceback): - logging.info("Grain pool is exiting.") - self._shutdown() - - def _shutdown(self) -> None: - """Gracefully shutdown the multiprocessing system.""" - logging.info("Shutting down multiprocessing system.") - try: - self._workers_termination_event.set() - # There is a chance that shutdown was triggered before the worker - # processes fully initialized and read from the arg queues. The arg - # queues will block the main process until their elements are flushed - # through the pipes, which will never happen since the workers were shut - # down. Here we avoid blocking the main process, see - # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Queue.cancel_join_thread - for q in self.worker_args_queues: - q.cancel_join_thread() - q.close() - # Not joining here will cause the children to be zombie after they finish. - # Need to join or call active_children. - for process in self.processes: - process.join(timeout=_PROCESS_JOIN_TIMEOUT) - finally: - for process in self.processes: - # In case all our attempts to terminate the system fails, we forcefully - # kill the child processes. 
- if process.is_alive(): - logging.info("Killing worker process with pid %i", process.pid) - process.kill() - - -@dataclasses.dataclass(slots=True, frozen=True) -class _ReaderQueueElement: - """Element to be added to the reader queue.""" - - async_result: pool.AsyncResult[Any] - # index of worker producing the element in [0, worker_count] - worker_index: int - - -@dataclasses.dataclass(frozen=True) -class _GrainPoolProcessingComplete: - """Indicates processing of grain pool is complete.""" - - -_GRAIN_POOL_PROCESSING_COMPLETE = _GrainPoolProcessingComplete() -_QueueElement = Union[ - _ReaderQueueElement, _GrainPoolProcessingComplete, Exception -] - - -def _open_shared_memory_for_leaf(element: Any) -> Any: - if isinstance(element, shared_memory_array.SharedMemoryArrayMetadata): - element = shared_memory_array.SharedMemoryArray.from_metadata(element) - element.unlink_on_del() - return element - - -def _open_shared_memory_for_structure(structure: Any) -> Any: - if isinstance(structure, record.Record): - structure.data = tree_lib.map_structure( - _open_shared_memory_for_leaf, structure.data - ) - return structure - return tree_lib.map_structure(_open_shared_memory_for_leaf, structure) - - -def _process_elements_in_grain_pool( - *, - get_element_producer_fn: GetElementProducerFn, - multiprocessing_options: MultiprocessingOptions, - reader_queue: queue.Queue[_QueueElement], - thread_pool: pool.ThreadPool, - termination_event: threading.Event, - start_profiling_event: synchronize.Event | None, - stop_profiling_event: synchronize.Event | None, - worker_index_to_start_reading: int, - worker_init_fn: Callable[[int, int], None] | None, - stats_in_queues: tuple[queues.Queue, ...] | None, -) -> None: - """Processes elements in grain worker pool asynchronously.""" - - def read_thread_should_stop(): - return termination_event.is_set() or not threading.main_thread().is_alive() - - ctx = mp.get_context("spawn") - - try: - with GrainPool( - ctx=ctx, - get_element_producer_fn=get_element_producer_fn, - worker_index_to_start_reading=worker_index_to_start_reading, - termination_event=termination_event, - start_profiling_event=start_profiling_event, - stop_profiling_event=stop_profiling_event, - options=multiprocessing_options, - worker_init_fn=worker_init_fn, - stats_in_queues=stats_in_queues, - ) as g_pool: - for element in g_pool: - if read_thread_should_stop(): - break - # Note: We use a thread pool for opening the shared memory because - # in some cases the calls to `shm_open` can actually become the - # bottleneck for a single thread. - async_result = thread_pool.apply_async( - _open_shared_memory_for_structure, - args=(element.record,), - ) - multiprocessing_common.add_element_to_queue( - _ReaderQueueElement( - async_result, - element.worker_index, - ), - reader_queue, - read_thread_should_stop, - ) - # This exception could arise from user-provide code. Propagating it to - # the main thread to re-raise it as is. - except Exception as e: # pylint: disable=broad-except - multiprocessing_common.add_element_to_queue( - e, reader_queue, read_thread_should_stop - ) - return - multiprocessing_common.add_element_to_queue( - _GrainPoolProcessingComplete(), - reader_queue, - read_thread_should_stop, - ) - - -class MultiProcessIteratorInvalidStateError(Exception): - """Raised when iterator is an invalid state and can't be iterated on.""" - - -class MultiProcessIterator(Iterator[T]): - """Runs iterators returned by `get_element_producer_fn` in child processes. 
- - Note: MultiProcessIterator implements the Context Manager protocol to clean - resources. As such, it must be used within a "with" statement. - - Wraps `GrainPool` adding lifecycle management, multithreaded elements read and - recording the last worker index useful for checkpointing. - """ - - def __init__( - self, - get_element_producer_fn: GetElementProducerFn, - multiprocessing_options: MultiprocessingOptions, - worker_index_to_start_reading: int, - worker_init_fn: Callable[[int, int], None] | None = None, - start_profiling_event: synchronize.Event | None = None, - stop_profiling_event: synchronize.Event | None = None, - stats_in_queues: tuple[queues.Queue, ...] | None = None, - ): - """Initializes MultiProcessIterator. - - Args: - get_element_producer_fn: factory making record iterators for each child - process. - multiprocessing_options: options for distributing the record iterators. - worker_index_to_start_reading: Index of the next worker to read from. This - is useful for recovering from a checkpoint. - worker_init_fn: Function to run in each worker process before the element - producer. The function takes two arguments: the current worker index and - the total worker count. - start_profiling_event: Event to start prism profiling. - stop_profiling_event: Event to stop prism profiling. - stats_in_queues: Queues to send execution summaries from worker processes - to the main process. - """ - self._get_element_producer_fn = get_element_producer_fn - self._multiprocessing_options = multiprocessing_options - self._last_worker_index = worker_index_to_start_reading - 1 - self._worker_init_fn = worker_init_fn - self._reader_queue = None - self._reader_thread_pool = None - self._termination_event = None - self._reader_thread = None - self._stats_in_queues = stats_in_queues - self._start_profiling_event = start_profiling_event - self._stop_profiling_event = stop_profiling_event - - def __del__(self): - if self._reader_thread: - logging.info("Destroying multiprocess iterator.") - self.stop_prefetch() - - def start_prefetch(self) -> None: - """Starts the prefetching threads.""" - - if self._reader_thread: - return - - max_buffered_elements = ( - self._multiprocessing_options.num_workers - * self._multiprocessing_options.per_worker_buffer_size - ) - self._reader_queue = queue.Queue(maxsize=max_buffered_elements) - self._reader_thread_pool = pool.ThreadPool(max_buffered_elements) - self._termination_event = threading.Event() - self._reader_thread = threading.Thread( - target=_process_elements_in_grain_pool, - kwargs=dict( - get_element_producer_fn=self._get_element_producer_fn, - multiprocessing_options=self._multiprocessing_options, - reader_queue=self._reader_queue, - thread_pool=self._reader_thread_pool, - termination_event=self._termination_event, - start_profiling_event=self._start_profiling_event, - stop_profiling_event=self._stop_profiling_event, - worker_index_to_start_reading=self._last_worker_index + 1, - worker_init_fn=self._worker_init_fn, - stats_in_queues=self._stats_in_queues, - ), - ) - self._reader_thread.start() - shared_memory_array.SharedMemoryArray.enable_async_del( - self._multiprocessing_options.num_workers - ) - - def stop_prefetch(self) -> None: - """Cleans up prefetching threads.""" - - if not self._reader_thread: - return - - # pytype: disable=attribute-error - self._termination_event.set() - self._reader_thread_pool.close() - self._reader_thread.join() - self._reader_thread_pool.join() - # pytype: enable=attribute-error - self._termination_event = None - 
self._reader_thread_pool = None - self._reader_thread = None - self._reader_queue = None - - def __enter__(self): - self.start_prefetch() - return self - - def __exit__(self, exc_type, exc_value, tb): - self.stop_prefetch() - - def _can_iterate(self): - """Checks whether the object is in a state where it can be iterated on.""" - return ( - self._reader_queue is not None - and self._termination_event is not None - and self._reader_thread_pool is not None - and self._reader_thread is not None - ) - - def __iter__(self): - if not self._can_iterate(): - raise MultiProcessIteratorInvalidStateError( - "MultiProcessIterator is in an invalid state. Note that" - " MultiProcessIterator should be used with a 'with' statement." - ) - return self - - def get_last_worker_index(self): - return self._last_worker_index - - def __next__(self): - if not self._can_iterate(): - raise MultiProcessIteratorInvalidStateError( - "MultiProcessIterator is in an invalid state. Note that" - " MultiProcessIterator should be used with a 'with' statement." - ) - element = multiprocessing_common.get_element_from_queue( - self._reader_queue, self._termination_event.is_set # pytype: disable=attribute-error - ) - if isinstance(element, Exception): - raise element - if ( - element == _GRAIN_POOL_PROCESSING_COMPLETE - or element == multiprocessing_common.SYSTEM_TERMINATED - ): - raise StopIteration - - if not isinstance(element, _ReaderQueueElement): - raise ValueError( - f"Got invalid element type from GrainPool: {type(element)}" - ) - - result = multiprocessing_common.get_async_result( - element.async_result, self._termination_event.is_set - ) - if isinstance(result, multiprocessing_common._SystemTerminated): # pylint: disable=protected-access - raise StopIteration - self._last_worker_index = element.worker_index - return result diff --git a/grain/_src/python/grain_pool_test.py b/grain/_src/python/grain_pool_test.py deleted file mode 100644 index 5aa87795b..000000000 --- a/grain/_src/python/grain_pool_test.py +++ /dev/null @@ -1,471 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for GrainPool.""" - -from collections.abc import Iterator -import multiprocessing -import os -import platform -import signal -import sys -from typing import Any -from absl import flags -from absl.testing import absltest -from absl.testing import parameterized -from grain._src.core import config -from grain._src.core import monitoring as grain_monitoring -import multiprocessing as mp -from grain._src.python import data_sources -from grain._src.python import grain_pool as gp -from grain._src.python import record -from grain._src.python.options import MultiprocessingOptions # pylint: disable=g-importing-member - - -class GrainPoolTest(absltest.TestCase): - - def _join_and_assert_process_exitcode(self, process: multiprocessing.Process): - # The process can be potentially terminated forcibly and needs a moment to - # finalize and update the exitcode. 
- process.join(timeout=gp._PROCESS_JOIN_TIMEOUT) - self.assertIn(process.exitcode, {0, -signal.SIGTERM}) - - def test_pool_with_flags_not_parsed(self): - class GetElementProducerFn(gp.GetElementProducerFn): - - def __call__(self, *, worker_index: int, worker_count: int, **kwargs): - del self - return iter(range(worker_index, 14, worker_count)) - - get_element_producer_fn = GetElementProducerFn() - # unparse the flags explicitly - flags.FLAGS.unparse_flags() - - _ = gp.GrainPool( - ctx=mp.get_context("spawn"), - get_element_producer_fn=get_element_producer_fn, - options=MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1), - ) - - def test_pool_equal_split_in_memory_data_source(self): - in_memory_ds = data_sources.SharedMemoryDataSource(range(12)) - - # 12 elements in the `in_memory_ds` are divided - # equally among 4 processes. - class GetElementProducerFn(gp.GetElementProducerFn): - - def __call__(self, *, worker_index: int, worker_count: int, **kwargs): - del self - return iter(range(worker_index, 12, worker_count)) - - get_element_producer_fn = GetElementProducerFn() - - output_elements = [] - with gp.GrainPool( - ctx=mp.get_context("spawn"), - get_element_producer_fn=get_element_producer_fn, - options=MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1), - ) as grain_pool: - for element in grain_pool: - output_elements.append(element) - # turn each element in `in_memory_ds` to their negatives. - in_memory_ds[element.record] = -in_memory_ds[element.record] - - self.assertEqual( - output_elements, [gp.GrainPoolElement(x, x % 4) for x in range(12)] - ) - - self.assertEqual(list(iter(in_memory_ds)), [-x for x in range(12)]) - - def test_pool_equal_split(self): - ctx = mp.get_context("spawn") - - # 16 elements divide equally among 4 processes - class GetElementProducerFn(gp.GetElementProducerFn): - - def __call__(self, *, worker_index: int, worker_count: int, **kwargs): - del self - return iter(range(worker_index, 16, worker_count)) - - get_element_producer_fn = GetElementProducerFn() - - options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) - output_elements = [] - with gp.GrainPool( - ctx=ctx, - get_element_producer_fn=get_element_producer_fn, - options=options, - ) as grain_pool: - for element in grain_pool: - output_elements.append(element) - expected_elements = list( - map( - lambda x: gp.GrainPoolElement(x, x % options.num_workers), range(16) - ) - ) - self.assertEqual(expected_elements, output_elements) - # Make sure num_processes processes were launched. - self.assertLen(grain_pool.processes, options.num_workers) - # Make sure all child processes exited successfully. 
- for child_process in grain_pool.processes: - self._join_and_assert_process_exitcode(child_process) - - def test_pool_non_equal_split(self): - ctx = mp.get_context("spawn") - - # 14 elements do not divide equally among 4 processes - class GetElementProducerFn(gp.GetElementProducerFn): - - def __call__(self, *, worker_index: int, worker_count: int, **kwargs): - del self - return iter(range(worker_index, 14, worker_count)) - - get_element_producer_fn = GetElementProducerFn() - - options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) - output_elements = [] - with gp.GrainPool( - ctx=ctx, - get_element_producer_fn=get_element_producer_fn, - options=options, - ) as grain_pool: - for element in grain_pool: - output_elements.append(element) - expected_elements = list( - map( - lambda x: gp.GrainPoolElement(x, x % options.num_workers), range(14) - ) - ) - self.assertEqual(expected_elements, output_elements) - # Make sure all child processes exited successfully. - for child_process in grain_pool.processes: - self._join_and_assert_process_exitcode(child_process) - - @absltest.skipIf( - platform.system() == "Windows", "SIGKILL signal not available on Windows." - ) - def test_pool_kill_child(self): - ctx = mp.get_context("spawn") - - class GetElementProducerFn(gp.GetElementProducerFn): - - def __call__(self, *, worker_index: int, worker_count: int, **kwargs): - del self - return iter(range(worker_index, 14, worker_count)) - - get_element_producer_fn = GetElementProducerFn() - - options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) - with gp.GrainPool( - ctx=ctx, - get_element_producer_fn=get_element_producer_fn, - options=options, - ) as grain_pool: - child_pid = grain_pool.processes[0].pid - os.kill(child_pid, signal.SIGKILL) - - self.assertEqual( - grain_pool.processes[0].exitcode, -1 * signal.SIGKILL.value - ) - for child_process in grain_pool.processes[1:]: - self._join_and_assert_process_exitcode(child_process) - - def test_pool_object_deletion(self): - ctx = mp.get_context("spawn") - - class GetElementProducerFn(gp.GetElementProducerFn): - - def __call__(self, *, worker_index: int, worker_count: int, **kwargs): - del self - return iter(range(worker_index, 14, worker_count)) - - get_element_producer_fn = GetElementProducerFn() - - options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) - - # Users should generally use the with statement, here we test if GrainPool - # was created without the "with statement", that object deletion would - # have child processes gracefully exited. 
- grain_pool = gp.GrainPool( - ctx=ctx, - get_element_producer_fn=get_element_producer_fn, - options=options, - ) - - child_processes = grain_pool.processes - grain_pool.__del__() - - for child_process in child_processes: - self._join_and_assert_process_exitcode(child_process) - - -def _make_uniform_element_producer_fn( - last_seen_index: int = -1, -) -> gp.GetElementProducerFn: - - class _RoundrobinElementProducerFn(gp.GetElementProducerFn): - - def __call__( - self, *, worker_index: int, worker_count: int, **kwargs - ) -> Iterator[int]: - del self - yield from range(10)[last_seen_index + 1 + worker_index :: worker_count] - - return _RoundrobinElementProducerFn() - - -class RoundrobinRecordElementProducerFn(gp.GetElementProducerFn): - - def __call__( - self, *, worker_index: int, worker_count: int, **kwargs - ) -> Iterator[record.Record[int]]: - del self - for i in range(5)[worker_index::worker_count]: - yield record.Record(record.RecordMetadata(i), i) - - -class NonUniformElementProducerFn(gp.GetElementProducerFn): - - def __call__( - self, *, worker_index: int, worker_count: int, **kwargs - ) -> Iterator[int]: - del self, worker_count - for _ in range(worker_index * 3): - yield worker_index - - -class MultiProcessIteratorTest(parameterized.TestCase): - - @parameterized.named_parameters( - dict( - testcase_name="two_workers", - get_element_producer_fn=_make_uniform_element_producer_fn(), - multiprocessing_options=MultiprocessingOptions(num_workers=2), - worker_index_to_start_reading=0, - expected=list(range(10)), - ), - dict( - testcase_name="five_workers", - get_element_producer_fn=_make_uniform_element_producer_fn(), - multiprocessing_options=MultiprocessingOptions(num_workers=5), - worker_index_to_start_reading=0, - expected=list(range(10)), - ), - dict( - testcase_name="from_checkpoint", - get_element_producer_fn=_make_uniform_element_producer_fn(5), - multiprocessing_options=MultiprocessingOptions(num_workers=2), - worker_index_to_start_reading=1, - expected=[7, 6, 9, 8], - ), - dict( - testcase_name="non_uniform", - get_element_producer_fn=NonUniformElementProducerFn(), - multiprocessing_options=MultiprocessingOptions(num_workers=3), - worker_index_to_start_reading=0, - expected=[1, 2, 1, 2, 1, 2, 2, 2, 2], - ), - dict( - testcase_name="record_producer_fn", - get_element_producer_fn=RoundrobinRecordElementProducerFn(), - multiprocessing_options=MultiprocessingOptions(num_workers=3), - worker_index_to_start_reading=0, - expected=[ - record.Record(record.RecordMetadata(i), i) for i in range(5) - ], - ), - ) - def test_produces_correct_data( - self, - get_element_producer_fn: gp.GetElementProducerFn, - multiprocessing_options: MultiprocessingOptions, - worker_index_to_start_reading: int, - expected: Any, - ): - with gp.MultiProcessIterator( - get_element_producer_fn, - multiprocessing_options, - worker_index_to_start_reading, - ) as iterator: - actual = list(iterator) - self.assertEqual(actual, expected) - - @parameterized.named_parameters( - dict( - testcase_name="two_workers", - get_element_producer_fn=_make_uniform_element_producer_fn(), - multiprocessing_options=MultiprocessingOptions(num_workers=2), - worker_index_to_start_reading=1, - num_iters=5, - expected_last_worker_index=1, - ), - dict( - testcase_name="five_workers", - get_element_producer_fn=_make_uniform_element_producer_fn(), - multiprocessing_options=MultiprocessingOptions(num_workers=5), - worker_index_to_start_reading=0, - num_iters=7, - expected_last_worker_index=1, - ), - dict( - 
testcase_name="five_workers_incomplete_round", - get_element_producer_fn=_make_uniform_element_producer_fn(), - multiprocessing_options=MultiprocessingOptions(num_workers=5), - worker_index_to_start_reading=0, - num_iters=3, - expected_last_worker_index=2, - ), - dict( - testcase_name="from_checkpoint", - get_element_producer_fn=_make_uniform_element_producer_fn(5), - multiprocessing_options=MultiprocessingOptions(num_workers=2), - worker_index_to_start_reading=0, - num_iters=3, - expected_last_worker_index=0, - ), - dict( - testcase_name="non_uniform_record_producer_fn", - get_element_producer_fn=NonUniformElementProducerFn(), - multiprocessing_options=MultiprocessingOptions(num_workers=3), - worker_index_to_start_reading=0, - num_iters=6, - expected_last_worker_index=2, - ), - ) - def test_get_state( - self, - get_element_producer_fn: gp.GetElementProducerFn, - multiprocessing_options: MultiprocessingOptions, - worker_index_to_start_reading: int, - num_iters: int, - expected_last_worker_index: int, - ): - with gp.MultiProcessIterator( - get_element_producer_fn, - multiprocessing_options, - worker_index_to_start_reading, - ) as iterator: - for _ in range(num_iters): - _ = next(iterator) - actual_last_worker_index = iterator.get_last_worker_index() - self.assertEqual(actual_last_worker_index, expected_last_worker_index) - - def test_fails_with_zero_workers(self): - with self.assertRaisesRegex( - ValueError, "Number of processes must be at least 1" - ): - with gp.MultiProcessIterator( - _make_uniform_element_producer_fn(), - MultiprocessingOptions(num_workers=0), - 0, - ) as iterator: - list(iterator) - - def test_propagates_error(self): - error_msg = "very unique error" - - class FailingGetElementProducerFn(gp.GetElementProducerFn): - - def __call__( - self, *, worker_index: int, worker_count: int, **kwargs - ) -> Iterator[int]: - del self, worker_index, worker_count - raise ValueError(error_msg) - - failing_get_element_producer_fn = FailingGetElementProducerFn() - - with gp.MultiProcessIterator( - failing_get_element_producer_fn, - MultiprocessingOptions(num_workers=2), - 0, - ) as iterator: - with self.assertRaisesRegex(ValueError, error_msg): - list(iterator) - - def test_reports_worker_crash(self): - - class FailingGetElementProducerFn(gp.GetElementProducerFn): - - def __call__( - self, *, worker_index: int, worker_count: int, **kwargs - ) -> Iterator[int]: - del self, worker_index, worker_count - sys.exit(12) - - failing_get_element_producer_fn = FailingGetElementProducerFn() - - with gp.MultiProcessIterator( - failing_get_element_producer_fn, - MultiprocessingOptions(num_workers=2), - 0, - ) as iterator: - with self.assertRaisesRegex( - RuntimeError, "was terminated unexpectedly with exit code 12" - ): - list(iterator) - - def test_reports_unpicklable_element_producer_fn(self): - error_msg = "UnpicklableObject is not picklable" - - class UnpicklableObject: - - def __getstate__(self): - raise ValueError(error_msg) - - local_state = UnpicklableObject() - - class GetElementProducerFnWithUnpicklableClosure(gp.GetElementProducerFn): - - def __call__( - self, *, worker_index: int, worker_count: int, **kwargs - ) -> Iterator[int]: - del self, worker_index, worker_count - yield 1 if local_state is None else 2 - - get_element_producer_fn_with_unpicklable_closure = ( - GetElementProducerFnWithUnpicklableClosure() - ) - - with gp.MultiProcessIterator( - get_element_producer_fn_with_unpicklable_closure, - MultiprocessingOptions(num_workers=2), - 0, - ) as iterator: - with 
self.assertRaisesRegex(ValueError, error_msg): - list(iterator) - - def test_worker_init_fn(self): - - def _set_worker_index_and_count(worker_index: int, worker_count: int): - gp.monkey_patched_index_and_count = (worker_index, worker_count) - - class GetElementProducerFnReturningGlobal(gp.GetElementProducerFn): - - def __call__( - self, *, worker_index: int, worker_count: int, **kwargs - ) -> Iterator[tuple[int, int]]: - del self, worker_index, worker_count - yield gp.monkey_patched_index_and_count # pytype: disable=module-attr - - with gp.MultiProcessIterator( - GetElementProducerFnReturningGlobal(), - MultiprocessingOptions(num_workers=2), - 0, - worker_init_fn=_set_worker_index_and_count, - ) as iterator: - result = list(iterator) - self.assertEqual(result, [(0, 2), (1, 2)]) - - -if __name__ == "__main__": - absltest.main() diff --git a/grain/python/experimental.py b/grain/python/experimental.py index c213ae4d3..1bdbe4fbf 100644 --- a/grain/python/experimental.py +++ b/grain/python/experimental.py @@ -48,7 +48,6 @@ ConcatThenSplitIterDataset, ) from grain._src.python.dataset.transformations.prefetch import ( - MultiprocessPrefetchIterDataset, ThreadPrefetchIterDataset, ThreadPrefetchDatasetIterator, )
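
For readers of this patch, a minimal usage sketch of the thread-prefetch API exercised by the tests above. It is illustrative only and not part of the change; it assumes the internal module paths used throughout this diff (grain._src.python.dataset.dataset and grain._src.python.dataset.transformations.prefetch), which may differ from the public grain.python surface.

# Illustrative sketch (not part of the patch): thread prefetching with
# checkpointing, mirroring the behaviour covered by the tests above.
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import prefetch

ds = dataset.MapDataset.range(20).to_iter_dataset()
ds = prefetch.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=10)

it = ds.__iter__()
it.start_prefetch()           # optional warm start before iterating
checkpoint = it.get_state()   # taken before consuming the next element
first = next(it)

# Restoring the checkpoint on a fresh iterator replays the same element.
it2 = ds.__iter__()
it2.set_state(checkpoint)
assert next(it2) == first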
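
Similarly, a hedged sketch of the prefetch entry points the tests in this diff exercise: IterDataset.mp_prefetch for process-based prefetching and prefetch.multithread_prefetch for thread-based fan-out. Parameter names and import paths follow the test code above and should be treated as assumptions rather than a definitive API reference.

# Illustrative sketch (not part of the patch): process- and thread-based
# prefetching as exercised by the tests in this diff.
from grain._src.python import options
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import prefetch

ds = dataset.MapDataset.range(20).to_iter_dataset()

# Process-based prefetching via the IterDataset method.
mp_ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=2))

# Thread-based prefetching with an explicit thread count and buffer size.
mt_ds = prefetch.multithread_prefetch(ds, num_threads=2, buffer_size=5)

for element in mt_ds:
  pass  # consume prefetched elements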