@@ -70,6 +70,20 @@ def robust_map(
70
70
...
71
71
72
72
73
+ @overload
74
+ def robust_map (
75
+ fn : Callable [[T ], U ],
76
+ inputs : Sequence [T ],
77
+ error_output : V = ...,
78
+ index_to_output : dict [int , U | V ] | None = ...,
79
+ max_retries : int | None = ...,
80
+ max_workers : int | None = ...,
81
+ raise_error : bool = ...,
82
+ progress_desc : str = ...,
83
+ ) -> list [U | V ]:
84
+ ...
85
+
86
+
73
87
# TODO(trandustin): Support nested structure inputs like jax.tree.map.
74
88
def robust_map (
75
89
fn : Callable [[T ], U ],
@@ -80,6 +94,7 @@ def robust_map(
80
94
max_workers : int | None = None ,
81
95
raise_error : bool = False ,
82
96
retry_exception_types : list [type [Exception ]] | None = None ,
97
+ progress_desc : str = 'robust_map' ,
83
98
) -> list [U | V ]:
84
99
"""Maps a function to inputs using a threadpool.
85
100
@@ -103,6 +118,7 @@ def robust_map(
103
118
Will override any setting of `error_output`.
104
119
retry_exception_types: Exception types to retry on. Defaults to retrying
105
120
only on grpc's RPC exceptions.
121
+ progress_desc: A string to display in the progress bar.
106
122
107
123
Returns:
108
124
A list of items each of type U. They are the outputs of `fn` applied to
@@ -139,7 +155,7 @@ def robust_map(
139
155
num_existing = len (index_to_output )
140
156
num_inputs = len (inputs )
141
157
logging .info ('Found %s/%s existing examples.' , num_existing , num_inputs )
142
- progress_bar = tqdm .tqdm (total = num_inputs - num_existing , desc = 'robust_map' )
158
+ progress_bar = tqdm .tqdm (total = num_inputs - num_existing , desc = progress_desc )
143
159
indices = [i for i in range (num_inputs ) if i not in index_to_output .keys ()]
144
160
with concurrent .futures .ThreadPoolExecutor (
145
161
max_workers = max_workers
0 commit comments