Skip to content

Commit e23cbb7

Browse files
Grain Team authored and copybara-github committed
Internal
PiperOrigin-RevId: 838846430
1 parent 31b9760 commit e23cbb7

File tree

3 files changed

+301
-84
lines changed

3 files changed

+301
-84
lines changed

grain/_src/python/dataset/transformations/BUILD

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ py_test(
5353

5454
py_test(
5555
name = "prefetch_test",
56-
timeout = "long",
56+
timeout = "eternal",
5757
srcs = ["prefetch_test.py"],
5858
shard_count = 50,
5959
srcs_version = "PY3",

grain/_src/python/dataset/transformations/prefetch.py

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1017,8 +1017,8 @@ def __str__(self) -> str:
10171017

10181018
def multithread_prefetch(
10191019
ds: dataset.IterDataset[T],
1020-
num_threads: int,
1021-
buffer_size: int,
1020+
num_threads: int = 0,
1021+
buffer_size: int = 1,
10221022
sequential_slice: bool = False,
10231023
) -> dataset.IterDataset[T]:
10241024
"""Uses a pool of threads to prefetch elements ahead of time.
@@ -1043,14 +1043,17 @@ def multithread_prefetch(
10431043
if num_threads == 0:
10441044
return ds
10451045

1046-
_validate_no_double_prefetch(ds)
1046+
dataset_options = _get_dataset_options(ds)
10471047

10481048
shards = []
10491049
for i in range(num_threads):
1050-
worker_ds = copy.deepcopy(ds)
1051-
_set_slice_iter_dataset(
1052-
worker_ds, slice(i, None, num_threads), sequential_slice
1053-
)
1050+
if num_threads == 1:
1051+
worker_ds = ds
1052+
else:
1053+
worker_ds = copy.deepcopy(ds)
1054+
_set_slice_iter_dataset(
1055+
worker_ds, slice(i, None, num_threads), sequential_slice
1056+
)
10541057
shards.append(
10551058
_MpContextIterDataset(
10561059
worker_ds,
@@ -1061,6 +1064,10 @@ def multithread_prefetch(
10611064
)
10621065
)
10631066

1064-
return interleave.InterleaveIterDataset(
1067+
ds = interleave.InterleaveIterDataset(
10651068
shards, cycle_length=num_threads, iter_buffer_size=buffer_size
10661069
)
1070+
# Apply options from parent dataset because interleave dataset does not
1071+
# propagate options.
1072+
ds = dataset.WithOptionsIterDataset(ds, dataset_options)
1073+
return ds

0 commit comments

Comments
 (0)