Skip to content

Commit 65641fc

Browse files
Grain Team authored and copybara-github committed
Internal
PiperOrigin-RevId: 838933101
1 parent 3a5796b commit 65641fc

File tree

12 files changed

+437
-2562
lines changed

12 files changed

+437
-2562
lines changed

grain/_src/python/BUILD

Lines changed: 0 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -208,49 +208,6 @@ py_test(
208208
],
209209
)
210210

211-
py_library(
212-
name = "grain_pool",
213-
srcs = ["grain_pool.py"],
214-
srcs_version = "PY3",
215-
target_compatible_with = select({
216-
"@platforms//os:windows": ["@platforms//:incompatible"],
217-
"//conditions:default": [],
218-
}),
219-
deps = [
220-
":grain_logging",
221-
":multiprocessing_common",
222-
":options",
223-
":record",
224-
":shared_memory_array",
225-
"//grain/_src/core:config",
226-
"//grain/_src/core:monitoring",
227-
"//grain/_src/core:parallel",
228-
"//grain/_src/core:tree_lib",
229-
"@abseil-py//absl/flags",
230-
"@abseil-py//absl/logging",
231-
"@pypi//cloudpickle:pkg",
232-
],
233-
)
234-
235-
py_test(
236-
name = "grain_pool_test",
237-
srcs = ["grain_pool_test.py"],
238-
shard_count = 20,
239-
srcs_version = "PY3",
240-
tags = ["not_run:arm"],
241-
deps = [
242-
":data_sources",
243-
":grain_pool",
244-
":options",
245-
":record",
246-
"//grain/_src/core:config",
247-
"//grain/_src/core:monitoring",
248-
"@abseil-py//absl/flags",
249-
"@abseil-py//absl/testing:absltest",
250-
"@abseil-py//absl/testing:parameterized",
251-
],
252-
)
253-
254211
py_library(
255212
name = "checkpoint_handlers",
256213
srcs = ["checkpoint_handlers.py"],

grain/_src/python/data_loader.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,6 @@
3939
from grain._src.python.dataset import dataset
4040
from grain._src.python.dataset.transformations import batch as batch_ds
4141
from grain._src.python.dataset.transformations import flatmap
42-
from grain._src.python.dataset.transformations import prefetch
4342
from grain._src.python.operations import BatchOperation
4443
from grain._src.python.operations import Operation
4544
from grain._src.python.samplers import Sampler
@@ -462,10 +461,8 @@ def _create_dataset(self) -> dataset.IterDataset:
462461
ds = _apply_transform_to_dataset(operation, ds)
463462
ds = ds.map(lambda r: r.data)
464463
if self.multiprocessing_options.num_workers > 0:
465-
ds = prefetch.MultiprocessPrefetchIterDataset(
466-
ds,
464+
ds = ds.mp_prefetch(
467465
self.multiprocessing_options,
468-
always_report_worker_state=True,
469466
)
470467
if not self._use_native_dataset_checkpointing:
471468
ds = _DataLoaderStateIterDataset(

grain/_src/python/dataset/BUILD

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,6 @@ py_library(
5353
"//grain/_src/core:tree_lib",
5454
"//grain/_src/python:checkpointing",
5555
"//grain/_src/python:grain_logging",
56-
"//grain/_src/python:grain_pool",
5756
"//grain/_src/python:options",
5857
"//grain/_src/python:shared_memory_array",
5958
"//grain/proto:execution_summary_py_pb2",

grain/_src/python/dataset/dataset.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1324,13 +1324,14 @@ def mp_prefetch(
13241324
A dataset prefetching input elements in separate processes.
13251325
"""
13261326
options = options or grain_options.MultiprocessingOptions(num_workers=10)
1327-
# Loaded lazily due to a circular dependency (dataset <-> prefetch).
1327+
# Loaded lazily due to a circular dependency (dataset <-> process_prefetch).
13281328
# pylint: disable=g-import-not-at-top
1329-
from grain._src.python.dataset.transformations import prefetch
1329+
from grain._src.python.dataset.transformations import process_prefetch
13301330
# pylint: enable=g-import-not-at-top
1331-
return prefetch.MultiprocessPrefetchIterDataset(
1331+
return process_prefetch.multiprocess_prefetch(
13321332
self,
1333-
multiprocessing_options=options,
1333+
num_workers=options.num_workers,
1334+
buffer_size=options.per_worker_buffer_size,
13341335
worker_init_fn=worker_init_fn,
13351336
sequential_slice=sequential_slice,
13361337
)

grain/_src/python/dataset/transformations/BUILD

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ py_test(
5353

5454
py_test(
5555
name = "prefetch_test",
56-
timeout = "long",
56+
timeout = "eternal",
5757
srcs = ["prefetch_test.py"],
5858
shard_count = 50,
5959
srcs_version = "PY3",

0 commit comments

Comments
 (0)