Commit 25f80c8

Grain Team authored and copybara-github committed

Internal

PiperOrigin-RevId: 823568742

1 parent 104ba7f commit 25f80c8

File tree

6 files changed: +647 −9 lines changed

grain/_src/python/BUILD

Lines changed: 17 additions & 0 deletions
@@ -224,6 +224,7 @@ py_library(
         ":options",
         ":record",
         ":shared_memory_array",
+        ":variable_size_queue",
         "//grain/_src/core:config",
         "//grain/_src/core:monitoring",
         "//grain/_src/core:parallel",
@@ -370,3 +371,19 @@ py_library(
         "@pypi//etils:pkg",
     ],
 )
+
+py_library(
+    name = "variable_size_queue",
+    srcs = ["variable_size_queue.py"],
+    srcs_version = "PY3",
+)
+
+py_test(
+    name = "variable_size_queue_test",
+    srcs = ["variable_size_queue_test.py"],
+    srcs_version = "PY3",
+    deps = [
+        ":variable_size_queue",
+        "@abseil-py//absl/testing:absltest",
+    ],
+)
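
Two of the six changed files, the new variable_size_queue.py and its test, are not expanded on this page. Judging only from the call sites visible in grain_pool.py below — VariableSizeQueue(max_buffered_elements) with set_max_size() and qsize(), and VariableSizeMultiprocessingQueue(per_worker_buffer_size, ctx) — a minimal sketch of the module could look like the following. The class and method names come from this commit; everything else is an illustrative assumption, not the committed implementation.

import queue


class VariableSizeQueue:
  """Sketch: a queue.Queue whose maximum size can change after construction."""

  def __init__(self, max_size: int):
    self._queue = queue.Queue(maxsize=max_size)

  def set_max_size(self, max_size: int) -> None:
    # queue.Queue re-reads self.maxsize inside put()'s wait loop, so updating
    # it under the queue's own mutex and waking blocked producers is enough
    # to grow or shrink the bound for future puts.
    with self._queue.mutex:
      self._queue.maxsize = max_size
      self._queue.not_full.notify_all()

  def put(self, item, block=True, timeout=None):
    self._queue.put(item, block, timeout)

  def get(self, block=True, timeout=None):
    return self._queue.get(block, timeout)

  def qsize(self) -> int:
    return self._queue.qsize()


class VariableSizeMultiprocessingQueue:
  """Sketch: a cross-process queue with an adjustable bound.

  multiprocessing.Queue fixes its capacity with an internal semaphore at
  construction time, so this wraps an unbounded queue and enforces the bound
  with a shared counter guarded by a multiprocessing Condition.
  """

  def __init__(self, max_size: int, ctx):
    self._queue = ctx.Queue()
    self._max_size = ctx.Value("i", max_size, lock=False)
    self._size = ctx.Value("i", 0, lock=False)
    self._cond = ctx.Condition()

  def set_max_size(self, max_size: int) -> None:
    with self._cond:
      self._max_size.value = max_size
      self._cond.notify_all()  # Wake producers blocked in put().

  def put(self, item) -> None:
    with self._cond:
      while self._size.value >= self._max_size.value:
        self._cond.wait()
      self._size.value += 1
    self._queue.put(item)

  def get(self):
    item = self._queue.get()
    with self._cond:
      self._size.value -= 1
      self._cond.notify_all()
    return item

The thread-local variant can lean on queue.Queue directly because its put() loop re-checks maxsize on every wakeup; the cross-process variant cannot, which is why the sketch tracks the size in shared memory instead.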

grain/_src/python/dataset/transformations/prefetch.py

Lines changed: 8 additions & 0 deletions
@@ -760,6 +760,14 @@ def __str__(self) -> str:
         f"multiprocessing_options={self._multiprocessing_options})"
     )
 
+  def set_per_worker_buffer_size(self, per_worker_buffer_size: int):
+    if self._raw_iterator is None:
+      raise ValueError(
+          "Cannot change per worker buffer size before the iterator has been"
+          " initialized."
+      )
+    self._raw_iterator.set_per_worker_buffer_size(per_worker_buffer_size)
+
 
 class ThreadPrefetchIterDataset(dataset.IterDataset[T]):
   """Iterable dataset that uses a synchronized queue for prefetching.

grain/_src/python/dataset/transformations/prefetch_test.py

Lines changed: 71 additions & 0 deletions
@@ -917,6 +917,77 @@ def map_fn(x):
         ],
     )
 
+  def test_set_per_worker_buffer_size_increase(self):
+    ds = dataset.MapDataset.range(10).map(lambda x: x + 1).to_iter_dataset()
+    mp_options = options.MultiprocessingOptions(
+        num_workers=1, per_worker_buffer_size=1
+    )
+    ds = prefetch.MultiprocessPrefetchIterDataset(
+        ds,
+        mp_options,
+    )
+    it = cast(prefetch._MultiprocessPrefetchDatasetIterator, ds.__iter__())
+    self.assertEqual(next(it), 1)
+    time.sleep(1)
+    self.assertEqual(
+        it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 1  # pytype: disable=attribute-error
+    )
+    it.set_per_worker_buffer_size(2)
+    self.assertEqual(
+        it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 2  # pytype: disable=attribute-error
+    )
+    self.assertEqual(next(it), 2)
+    self.assertEqual(list(it), list(range(3, 11)))
+
+  def test_set_per_worker_buffer_size_decrease(self):
+    ds = dataset.MapDataset.range(10).map(lambda x: x + 1).to_iter_dataset()
+    mp_options = options.MultiprocessingOptions(
+        num_workers=1, per_worker_buffer_size=2
+    )
+    ds = prefetch.MultiprocessPrefetchIterDataset(
+        ds,
+        mp_options,
+    )
+    it = cast(prefetch._MultiprocessPrefetchDatasetIterator, ds.__iter__())
+    self.assertEqual(next(it), 1)
+    time.sleep(1)
+    self.assertEqual(
+        it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 2  # pytype: disable=attribute-error
+    )
+    it.set_per_worker_buffer_size(1)
+    self.assertEqual(
+        it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 1  # pytype: disable=attribute-error
+    )
+    self.assertEqual(next(it), 2)
+    self.assertEqual(list(it), list(range(3, 11)))
+
+  def test_set_per_worker_buffer_size_to_trigger_error(self):
+    def f(x):
+      if x >= 5:
+        raise ValueError(f'x={x} is too large')
+      return x
+
+    ds = (
+        dataset.MapDataset.range(10)
+        .map(f)
+        .to_iter_dataset(
+            read_options=options.ReadOptions(prefetch_buffer_size=0)
+        )
+    )
+    mp_options = options.MultiprocessingOptions(
+        num_workers=1, per_worker_buffer_size=1
+    )
+    it = prefetch.MultiprocessPrefetchIterDataset(ds, mp_options).__iter__()
+    it = cast(prefetch._MultiprocessPrefetchDatasetIterator, it)
+    self.assertEqual(next(it), 0)
+    it.set_per_worker_buffer_size(10)
+    next(it)
+    time.sleep(3)
+    q = it._raw_iterator._reader_queue  # pytype: disable=attribute-error
+    # Prefetching will end once an error is put into the reader queue. The
+    # elements 2, 3, 4 will be in the queue along with the error for 5.
+    self.assertEqual(q.qsize(), 4)
+
 
 class ThreadPrefetchIterDatasetTest(parameterized.TestCase):

grain/_src/python/grain_pool.py

Lines changed: 73 additions & 9 deletions
@@ -69,6 +69,7 @@
 from grain._src.python import multiprocessing_common
 from grain._src.python import record
 from grain._src.python import shared_memory_array
+from grain._src.python import variable_size_queue
 from grain._src.python.options import MultiprocessingOptions  # pylint: disable=g-importing-member
 
 
@@ -225,7 +226,7 @@ def _worker_loop(
     *,
     args_queue: queues.Queue,
     errors_queue: queues.Queue,
-    output_queue: queues.Queue,
+    output_queue: variable_size_queue.VariableSizeMultiprocessingQueue,
     termination_event: synchronize.Event,
     start_profiling_event: synchronize.Event,
     stop_profiling_event: synchronize.Event,
@@ -342,6 +343,9 @@ def __init__(
       options: MultiprocessingOptions,
       worker_init_fn: Callable[[int, int], None] | None = None,
      stats_in_queues: tuple[queues.Queue, ...] | None = None,
+      worker_output_queues: list[
+          variable_size_queue.VariableSizeMultiprocessingQueue
+      ],
   ):
     """Initialise a Grain Pool.
 
@@ -362,11 +366,13 @@ def __init__(
         the total worker count.
       stats_in_queues: Queue to propagate execution summary from child processes
         to the parent.
+      worker_output_queues: list of queues for each worker to output elements
+        to.
     """
     self.num_processes = options.num_workers
     logging.info("Grain pool will use %i processes.", self.num_processes)
     self.worker_args_queues = []
-    self.worker_output_queues = []
+    self.worker_output_queues = worker_output_queues
     self.processes = []
     # Reader termination should always result in worker termination. However,
     # worker termination should not shut down the reader: workers are terminated
@@ -396,11 +402,10 @@ def __init__(
 
     for worker_index in range(self.num_processes):
       worker_args_queue = ctx.Queue(1)
-      worker_output_queue = ctx.Queue(options.per_worker_buffer_size)
       process_kwargs = dict(
           args_queue=worker_args_queue,
           errors_queue=self.worker_error_queue,
-          output_queue=worker_output_queue,
+          output_queue=self.worker_output_queues[worker_index],
          stats_out_queue=(
              self.stats_in_queues[worker_index]
              if self.stats_in_queues
@@ -434,7 +439,6 @@ def __init__(
           target=_worker_loop, kwargs=process_kwargs, daemon=True
       )
       self.worker_args_queues.append(worker_args_queue)
-      self.worker_output_queues.append(worker_output_queue)
       self.processes.append(process)
 
     logging.info("Grain pool will start child processes.")
@@ -589,6 +593,27 @@ class _GrainPoolProcessingComplete:
 ]
 
 
+class _ThreadPoolContainer:
+  """Container for ThreadPool to allow replacing it."""
+
+  def __init__(self, processes: int):
+    self.pool = pool.ThreadPool(processes)
+
+  def apply_async(self, *args, **kwargs):
+    return self.pool.apply_async(*args, **kwargs)
+
+  def close(self):
+    self.pool.close()
+
+  def join(self):
+    self.pool.join()
+
+  def replace_pool(self, num_threads: int):
+    old_pool = self.pool
+    self.pool = pool.ThreadPool(num_threads)
+    old_pool.close()
+
+
 def _open_shared_memory_for_leaf(element: Any) -> Any:
   if isinstance(element, shared_memory_array.SharedMemoryArrayMetadata):
     element = shared_memory_array.SharedMemoryArray.from_metadata(element)
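
The indirection in _ThreadPoolContainer exists because multiprocessing.pool.ThreadPool cannot be resized in place. The reader sizes its pool to the total buffer capacity (num_workers * per_worker_buffer_size, see start_prefetch below), so a buffer-size change implies a pool with a different thread count. replace_pool() installs the new pool first and then calls close() on the old one; close() stops new submissions but lets already-submitted apply_async tasks finish, so in-flight AsyncResults remain valid across the swap.
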
@@ -610,6 +635,9 @@ def _process_elements_in_grain_pool(
     get_element_producer_fn: GetElementProducerFn,
     multiprocessing_options: MultiprocessingOptions,
     reader_queue: queue.Queue[_QueueElement],
+    worker_output_queues: list[
+        variable_size_queue.VariableSizeMultiprocessingQueue
+    ],
     thread_pool: pool.ThreadPool,
     termination_event: threading.Event,
     start_profiling_event: synchronize.Event | None,
@@ -636,6 +664,7 @@ def read_thread_should_stop():
       options=multiprocessing_options,
       worker_init_fn=worker_init_fn,
       stats_in_queues=stats_in_queues,
+      worker_output_queues=worker_output_queues,
   ) as g_pool:
     for element in g_pool:
       if read_thread_should_stop():
@@ -714,6 +743,7 @@ def __init__(
     self._last_worker_index = worker_index_to_start_reading - 1
     self._worker_init_fn = worker_init_fn
     self._reader_queue = None
+    self._worker_output_queues = None
     self._reader_thread_pool = None
     self._termination_event = None
     self._reader_thread = None
@@ -736,15 +766,26 @@ def start_prefetch(self) -> None:
         self._multiprocessing_options.num_workers
         * self._multiprocessing_options.per_worker_buffer_size
     )
-    self._reader_queue = queue.Queue(maxsize=max_buffered_elements)
-    self._reader_thread_pool = pool.ThreadPool(max_buffered_elements)
+    self._reader_queue = variable_size_queue.VariableSizeQueue(
+        max_buffered_elements
+    )
+    self._reader_thread_pool = _ThreadPoolContainer(max_buffered_elements)
     self._termination_event = threading.Event()
+    ctx = mp.get_context("spawn")
+    self._worker_output_queues = []
+    for _ in range(self._multiprocessing_options.num_workers):
+      self._worker_output_queues.append(
+          variable_size_queue.VariableSizeMultiprocessingQueue(
+              self._multiprocessing_options.per_worker_buffer_size, ctx
+          )
+      )
     self._reader_thread = threading.Thread(
         target=_process_elements_in_grain_pool,
         kwargs=dict(
             get_element_producer_fn=self._get_element_producer_fn,
             multiprocessing_options=self._multiprocessing_options,
             reader_queue=self._reader_queue,
+            worker_output_queues=self._worker_output_queues,
            thread_pool=self._reader_thread_pool,
             termination_event=self._termination_event,
             start_profiling_event=self._start_profiling_event,
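
Note the ownership change in start_prefetch: the per-worker output queues were previously created inside GrainPool.__init__ (see the removed ctx.Queue(options.per_worker_buffer_size) line above) and are now created by the iterator in the parent process and passed in. That is what makes resizing possible later — the iterator retains handles to the very queue objects the worker processes write to, so a subsequent set_max_size() takes effect on live workers.
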
@@ -775,6 +816,7 @@ def stop_prefetch(self) -> None:
     self._reader_thread_pool = None
     self._reader_thread = None
     self._reader_queue = None
+    self._worker_output_queues = None
 
   def __enter__(self):
     self.start_prefetch()
@@ -809,7 +851,7 @@ def __next__(self):
           "MultiProcessIterator is in an invalid state. Note that"
           " MultiProcessIterator should be used with a 'with' statement."
       )
-    element = multiprocessing_common.get_element_from_queue(
+    element = multiprocessing_common.get_element_from_queue(  # pytype: disable=wrong-arg-types
        self._reader_queue, self._termination_event.is_set  # pytype: disable=attribute-error
     )
     if isinstance(element, Exception):
@@ -826,9 +868,31 @@ def __next__(self):
       )
 
     result = multiprocessing_common.get_async_result(
-        element.async_result, self._termination_event.is_set
+        element.async_result, self._termination_event.is_set  # pytype: disable=attribute-error
     )
     if isinstance(result, multiprocessing_common._SystemTerminated):  # pylint: disable=protected-access
       raise StopIteration
     self._last_worker_index = element.worker_index
     return result
+
+  def set_per_worker_buffer_size(self, per_worker_buffer_size: int):
+    """Sets the per worker buffer size."""
+    if self._worker_output_queues is None or self._reader_queue is None:
+      raise ValueError(
+          "Cannot change per worker buffer size before the iterator has been"
+          " initialized."
+      )
+    for q in self._worker_output_queues:
+      q.set_max_size(per_worker_buffer_size)
+    self._reader_queue.set_max_size(
+        per_worker_buffer_size * self._multiprocessing_options.num_workers
+    )
+    self._multiprocessing_options = dataclasses.replace(
+        self._multiprocessing_options,
+        per_worker_buffer_size=per_worker_buffer_size,
+    )
+    new_thread_count = (
+        self._multiprocessing_options.num_workers
+        * self._multiprocessing_options.per_worker_buffer_size
+    )
+    self._reader_thread_pool.replace_pool(new_thread_count)  # pytype: disable=attribute-error
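
The resize adjusts three things, in order: each worker's output queue bound, the reader queue bound (scaled by the worker count), and the reader thread pool. For example, with num_workers=2, a call to set_per_worker_buffer_size(4) caps each worker queue at 4 elements, caps the reader queue at 2 * 4 = 8, and swaps in a thread pool of 8 reader threads; the stored MultiprocessingOptions is also rebuilt via dataclasses.replace so that later reads of per_worker_buffer_size reflect the new value.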
