6969from grain ._src .python import multiprocessing_common
7070from grain ._src .python import record
7171from grain ._src .python import shared_memory_array
72+ from grain ._src .python import variable_size_queue
7273from grain ._src .python .options import MultiprocessingOptions # pylint: disable=g-importing-member
7374
7475
@@ -225,7 +226,7 @@ def _worker_loop(
225226 * ,
226227 args_queue : queues .Queue ,
227228 errors_queue : queues .Queue ,
228- output_queue : queues . Queue ,
229+ output_queue : variable_size_queue . VariableSizeMultiprocessingQueue ,
229230 termination_event : synchronize .Event ,
230231 start_profiling_event : synchronize .Event ,
231232 stop_profiling_event : synchronize .Event ,
@@ -342,6 +343,9 @@ def __init__(
342343 options : MultiprocessingOptions ,
343344 worker_init_fn : Callable [[int , int ], None ] | None = None ,
344345 stats_in_queues : tuple [queues .Queue , ...] | None = None ,
346+ worker_output_queues : list [
347+ variable_size_queue .VariableSizeMultiprocessingQueue
348+ ],
345349 ):
346350 """Initialise a Grain Pool.
347351
@@ -362,11 +366,13 @@ def __init__(
362366 the total worker count.
363367 stats_in_queues: Queue to propagate execution summary from child processes
364368 to the parent.
369+ worker_output_queues: list of queues for each worker to output elements
370+ to.
365371 """
366372 self .num_processes = options .num_workers
367373 logging .info ("Grain pool will use %i processes." , self .num_processes )
368374 self .worker_args_queues = []
369- self .worker_output_queues = []
375+ self .worker_output_queues = worker_output_queues
370376 self .processes = []
371377 # Reader termination should always result in worker termination. However,
372378 # worker termination should not shut down the reader: workers are terminated
@@ -396,11 +402,10 @@ def __init__(
396402
397403 for worker_index in range (self .num_processes ):
398404 worker_args_queue = ctx .Queue (1 )
399- worker_output_queue = ctx .Queue (options .per_worker_buffer_size )
400405 process_kwargs = dict (
401406 args_queue = worker_args_queue ,
402407 errors_queue = self .worker_error_queue ,
403- output_queue = worker_output_queue ,
408+ output_queue = self . worker_output_queues [ worker_index ] ,
404409 stats_out_queue = (
405410 self .stats_in_queues [worker_index ]
406411 if self .stats_in_queues
@@ -434,7 +439,6 @@ def __init__(
434439 target = _worker_loop , kwargs = process_kwargs , daemon = True
435440 )
436441 self .worker_args_queues .append (worker_args_queue )
437- self .worker_output_queues .append (worker_output_queue )
438442 self .processes .append (process )
439443
440444 logging .info ("Grain pool will start child processes." )
@@ -589,6 +593,27 @@ class _GrainPoolProcessingComplete:
589593]
590594
591595
class _ThreadPoolContainer:
    """Container for a `pool.ThreadPool` that allows replacing the pool.

    The container forwards `apply_async`, `close` and `join` to the pool it
    currently holds. `replace_pool` swaps that pool for a freshly created one
    and closes the previous pool.

    One thread may submit work through `apply_async` while another thread
    calls `replace_pool`. A lock therefore serialises "look up the current
    pool and submit to it" against "swap the pool", so a task can never be
    submitted to a pool that has already been closed (which would raise
    `ValueError` from `ThreadPool.apply_async`).
    """

    def __init__(self, processes: int):
        """Creates the container with an initial pool of `processes` threads."""
        # Guards every read and write of `self.pool`.
        self._lock = threading.Lock()
        self.pool = pool.ThreadPool(processes)

    def apply_async(self, *args, **kwargs):
        """Forwards to `apply_async` of the currently held pool."""
        # Submitting is non-blocking, so holding the lock here is cheap.
        with self._lock:
            return self.pool.apply_async(*args, **kwargs)

    def close(self):
        """Closes the currently held pool."""
        with self._lock:
            self.pool.close()

    def join(self):
        """Joins the currently held pool."""
        with self._lock:
            current_pool = self.pool
        # Join outside the lock: joining blocks until the pool's threads exit
        # and must not stall concurrent `apply_async` / `replace_pool` calls.
        current_pool.join()

    def replace_pool(self, num_threads: int):
        """Replaces the held pool with a new one of `num_threads` threads.

        The previous pool is closed (not joined), so it accepts no new tasks.

        Args:
          num_threads: number of threads of the new pool.
        """
        # Build the replacement first: if construction raises, the container
        # keeps its current, still usable pool.
        new_pool = pool.ThreadPool(num_threads)
        with self._lock:
            old_pool, self.pool = self.pool, new_pool
        # Once the swap is done no new submission can reach `old_pool`, so it
        # is safe to close it without holding the lock.
        old_pool.close()
616+
592617def _open_shared_memory_for_leaf (element : Any ) -> Any :
593618 if isinstance (element , shared_memory_array .SharedMemoryArrayMetadata ):
594619 element = shared_memory_array .SharedMemoryArray .from_metadata (element )
@@ -610,6 +635,9 @@ def _process_elements_in_grain_pool(
610635 get_element_producer_fn : GetElementProducerFn ,
611636 multiprocessing_options : MultiprocessingOptions ,
612637 reader_queue : queue .Queue [_QueueElement ],
638+ worker_output_queues : list [
639+ variable_size_queue .VariableSizeMultiprocessingQueue
640+ ],
613641 thread_pool : pool .ThreadPool ,
614642 termination_event : threading .Event ,
615643 start_profiling_event : synchronize .Event | None ,
@@ -636,6 +664,7 @@ def read_thread_should_stop():
636664 options = multiprocessing_options ,
637665 worker_init_fn = worker_init_fn ,
638666 stats_in_queues = stats_in_queues ,
667+ worker_output_queues = worker_output_queues ,
639668 ) as g_pool :
640669 for element in g_pool :
641670 if read_thread_should_stop ():
@@ -714,6 +743,7 @@ def __init__(
714743 self ._last_worker_index = worker_index_to_start_reading - 1
715744 self ._worker_init_fn = worker_init_fn
716745 self ._reader_queue = None
746+ self ._worker_output_queues = None
717747 self ._reader_thread_pool = None
718748 self ._termination_event = None
719749 self ._reader_thread = None
@@ -736,15 +766,26 @@ def start_prefetch(self) -> None:
736766 self ._multiprocessing_options .num_workers
737767 * self ._multiprocessing_options .per_worker_buffer_size
738768 )
739- self ._reader_queue = queue .Queue (maxsize = max_buffered_elements )
740- self ._reader_thread_pool = pool .ThreadPool (max_buffered_elements )
769+ self ._reader_queue = variable_size_queue .VariableSizeQueue (
770+ max_buffered_elements
771+ )
772+ self ._reader_thread_pool = _ThreadPoolContainer (max_buffered_elements )
741773 self ._termination_event = threading .Event ()
774+ ctx = mp .get_context ("spawn" )
775+ self ._worker_output_queues = []
776+ for _ in range (self ._multiprocessing_options .num_workers ):
777+ self ._worker_output_queues .append (
778+ variable_size_queue .VariableSizeMultiprocessingQueue (
779+ self ._multiprocessing_options .per_worker_buffer_size , ctx
780+ )
781+ )
742782 self ._reader_thread = threading .Thread (
743783 target = _process_elements_in_grain_pool ,
744784 kwargs = dict (
745785 get_element_producer_fn = self ._get_element_producer_fn ,
746786 multiprocessing_options = self ._multiprocessing_options ,
747787 reader_queue = self ._reader_queue ,
788+ worker_output_queues = self ._worker_output_queues ,
748789 thread_pool = self ._reader_thread_pool ,
749790 termination_event = self ._termination_event ,
750791 start_profiling_event = self ._start_profiling_event ,
@@ -775,6 +816,7 @@ def stop_prefetch(self) -> None:
775816 self ._reader_thread_pool = None
776817 self ._reader_thread = None
777818 self ._reader_queue = None
819+ self ._worker_output_queues = None
778820
779821 def __enter__ (self ):
780822 self .start_prefetch ()
@@ -809,7 +851,7 @@ def __next__(self):
809851 "MultiProcessIterator is in an invalid state. Note that"
810852 " MultiProcessIterator should be used with a 'with' statement."
811853 )
812- element = multiprocessing_common .get_element_from_queue (
854+ element = multiprocessing_common .get_element_from_queue ( # pytype: disable=wrong-arg-types
813855 self ._reader_queue , self ._termination_event .is_set # pytype: disable=attribute-error
814856 )
815857 if isinstance (element , Exception ):
@@ -826,9 +868,31 @@ def __next__(self):
826868 )
827869
828870 result = multiprocessing_common .get_async_result (
829- element .async_result , self ._termination_event .is_set
871+ element .async_result , self ._termination_event .is_set # pytype: disable=attribute-error
830872 )
831873 if isinstance (result , multiprocessing_common ._SystemTerminated ): # pylint: disable=protected-access
832874 raise StopIteration
833875 self ._last_worker_index = element .worker_index
834876 return result
877+
def set_per_worker_buffer_size(self, per_worker_buffer_size: int):
    """Sets the per worker buffer size.

    Resizes every worker output queue to `per_worker_buffer_size`, resizes the
    reader queue to `per_worker_buffer_size * num_workers`, stores the new
    value in the multiprocessing options and replaces the reader thread pool
    with one that has `num_workers * per_worker_buffer_size` threads.

    Args:
      per_worker_buffer_size: new buffer size of each worker; must be >= 1.

    Raises:
      ValueError: if the queues have not been created yet, or if
        `per_worker_buffer_size` is smaller than 1.
    """
    if self._worker_output_queues is None or self._reader_queue is None:
        raise ValueError(
            "Cannot change per worker buffer size before the iterator has been"
            " initialized."
        )
    # Reject invalid sizes before touching any state. Without this check a
    # value below 1 resizes the queues and overwrites the options, and only
    # then fails when the zero-thread pool is created, leaving the iterator
    # half-updated.
    if per_worker_buffer_size < 1:
        raise ValueError(
            "per_worker_buffer_size must be at least 1, got"
            f" {per_worker_buffer_size}."
        )
    for q in self._worker_output_queues:
        q.set_max_size(per_worker_buffer_size)
    self._reader_queue.set_max_size(
        per_worker_buffer_size * self._multiprocessing_options.num_workers
    )
    self._multiprocessing_options = dataclasses.replace(
        self._multiprocessing_options,
        per_worker_buffer_size=per_worker_buffer_size,
    )
    # One reader thread per element that may be buffered across all workers.
    new_thread_count = (
        self._multiprocessing_options.num_workers
        * self._multiprocessing_options.per_worker_buffer_size
    )
    self._reader_thread_pool.replace_pool(new_thread_count)  # pytype: disable=attribute-error
0 commit comments