Skip to content

Commit a731964

Browse files
Edward2 Teamedward-bot
Edward2 Team
authored andcommitted
Internal change
PiperOrigin-RevId: 653739078
1 parent 6fc98b8 commit a731964

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

edward2/maps.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,20 @@ def robust_map(
7070
...
7171

7272

73+
@overload
74+
def robust_map(
75+
fn: Callable[[T], U],
76+
inputs: Sequence[T],
77+
error_output: V = ...,
78+
index_to_output: dict[int, U | V] | None = ...,
79+
max_retries: int | None = ...,
80+
max_workers: int | None = ...,
81+
raise_error: bool = ...,
82+
progress_desc: str = ...,
83+
) -> list[U | V]:
84+
...
85+
86+
7387
# TODO(trandustin): Support nested structure inputs like jax.tree.map.
7488
def robust_map(
7589
fn: Callable[[T], U],
@@ -80,6 +94,7 @@ def robust_map(
8094
max_workers: int | None = None,
8195
raise_error: bool = False,
8296
retry_exception_types: list[type[Exception]] | None = None,
97+
progress_desc: str = 'robust_map',
8398
) -> list[U | V]:
8499
"""Maps a function to inputs using a threadpool.
85100
@@ -103,6 +118,7 @@ def robust_map(
103118
Will override any setting of `error_output`.
104119
retry_exception_types: Exception types to retry on. Defaults to retrying
105120
only on grpc's RPC exceptions.
121+
progress_desc: A string to display in the progress bar.
106122
107123
Returns:
108124
A list of items each of type U. They are the outputs of `fn` applied to
@@ -139,7 +155,7 @@ def robust_map(
139155
num_existing = len(index_to_output)
140156
num_inputs = len(inputs)
141157
logging.info('Found %s/%s existing examples.', num_existing, num_inputs)
142-
progress_bar = tqdm.tqdm(total=num_inputs - num_existing, desc='robust_map')
158+
progress_bar = tqdm.tqdm(total=num_inputs - num_existing, desc=progress_desc)
143159
indices = [i for i in range(num_inputs) if i not in index_to_output.keys()]
144160
with concurrent.futures.ThreadPoolExecutor(
145161
max_workers=max_workers

0 commit comments

Comments
 (0)