-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwrappers.py
320 lines (252 loc) · 10.8 KB
/
wrappers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
"""Contains functions related to processes and loky."""
from __future__ import annotations
import functools
import os
import sys
import warnings
from contextlib import redirect_stderr
from contextlib import redirect_stdout
from contextlib import suppress
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from attrs import define
from pytask import PNode
from pytask import PPathNode
from pytask import PTask
from pytask import PythonNode
from pytask import Traceback
from pytask import WarningReport
from pytask import console
from pytask import parse_warning_filter
from pytask import warning_record_to_str
from pytask.tree_util import PyTree
from pytask.tree_util import tree_map
from pytask.tree_util import tree_map_with_path
from pytask.tree_util import tree_structure
from pytask_parallel.nodes import RemotePathNode
from pytask_parallel.typing import CarryOverPath
from pytask_parallel.typing import is_local_path
from pytask_parallel.utils import CoiledFunction
if TYPE_CHECKING:
from types import TracebackType
from pytask import Mark
from rich.console import ConsoleOptions
# Names exported by a star import of this module.
__all__ = ["wrap_task_in_process", "wrap_task_in_thread"]
@define(kw_only=True)
class WrapperResult:
    """Result of executing a task through one of the wrappers in this module.

    Both :func:`wrap_task_in_thread` and :func:`wrap_task_in_process` return an
    instance of this class so the caller can treat both execution modes alike.
    """

    # Values to transfer back to the caller, with the structure of ``task.produces``
    # (see ``_handle_function_products``); ``None`` when nothing is carried over.
    carry_over_products: PyTree[CarryOverPath | PythonNode | None]
    # Warnings recorded while the task ran (empty for threads).
    warning_reports: list[WarningReport]
    # ``None`` on success. Otherwise the exception type and value plus either the
    # traceback object or a traceback already rendered to a string.
    exc_info: (
        tuple[type[BaseException], BaseException, TracebackType | str | None] | None
    )
    # Captured standard output and standard error (empty for threads).
    stdout: str
    stderr: str
def wrap_task_in_thread(
    task: PTask, *, remote: bool, shared_memory: dict[str, bool] | None, **kwargs: Any
) -> WrapperResult:
    """Execute a task in a thread and return the same result type as for processes.

    :func:`wrap_task_in_process` also returns warning reports and captured output.
    With threads, these are collected by the main process and not by the worker, so
    they are returned as empty placeholders here, and no products are carried over.

    While the task runs, its signature is set to ``True`` in ``shared_memory`` (if
    given). The entry is removed afterwards, even when handling the products raises.

    Args:
        task: The task to execute; ``task.function`` is called with ``kwargs``.
        remote: Forwarded to :func:`_handle_function_products`.
        shared_memory: Optional mapping used to flag the task as being executed.
        **kwargs: Keyword arguments passed to ``task.function``.

    Returns:
        A :class:`WrapperResult` whose ``exc_info`` holds :func:`sys.exc_info` when
        the task raised an :class:`Exception`, and ``None`` otherwise.

    Raises:
        Exceptions from :func:`_handle_function_products` (e.g. a ``ValueError`` when
        the return value does not match the return annotation) and exceptions not
        derived from :class:`Exception` propagate to the caller.
    """
    __tracebackhide__ = True

    # Add task to shared memory to indicate that it is currently being executed.
    if shared_memory is not None:
        shared_memory[task.signature] = True

    try:
        out = task.function(**kwargs)
    except Exception:  # noqa: BLE001
        exc_info = sys.exc_info()
    else:
        _handle_function_products(task, out, remote=remote)
        exc_info = None  # type: ignore[assignment]
    finally:
        # Remove the task from shared memory however execution ended, so the
        # "being executed" flag never outlives the task.
        if shared_memory is not None:
            shared_memory.pop(task.signature, None)

    return WrapperResult(
        carry_over_products=None,  # type: ignore[arg-type]
        warning_reports=[],
        exc_info=exc_info,  # type: ignore[arg-type]
        stdout="",
        stderr="",
    )
def wrap_task_in_process(  # noqa: PLR0913
    task: PTask,
    *,
    console_options: ConsoleOptions,
    kwargs: dict[str, Any],
    remote: bool,
    session_filterwarnings: tuple[str, ...],
    shared_memory: dict[str, bool] | None,
    show_locals: bool,
    task_filterwarnings: tuple[Mark, ...],
) -> WrapperResult:
    """Execute a task in a spawned process and collect its results.

    The task runs with ``breakpoint()``/``pdb.set_trace()`` replaced by an error, with
    stdout and stderr captured, and with warnings recorded under the session-wide and
    the task's ``filterwarnings`` settings. An exception raised by the task is caught
    and its traceback is rendered to a string.

    Args:
        task: The task to execute via ``task.execute``.
        console_options: Options used to render a traceback to text.
        kwargs: Task arguments; :class:`RemotePathNode` leaves are resolved before
            execution and cleaned up afterwards.
        remote: Forwarded to :func:`_handle_function_products`.
        session_filterwarnings: Warning filter strings applied for every task.
        shared_memory: Optional mapping used to flag the task as being executed.
        show_locals: Whether the rendered traceback includes local variables.
        task_filterwarnings: ``filterwarnings`` marks of the task.

    Returns:
        A :class:`WrapperResult` with carried-over products, warning reports, the
        processed exception info (``None`` on success), and the captured output.

    Raises:
        Exceptions from :func:`_handle_function_products` (e.g. a ``ValueError`` when
        the return value does not match the return annotation) propagate. The
        temporary files and the ``shared_memory`` entry are cleaned up first.
    """
    # Hide this function from tracebacks.
    __tracebackhide__ = True

    # Add task to shared memory to indicate that it is currently being executed.
    if shared_memory is not None:
        shared_memory[task.signature] = True

    try:
        # Patch set_trace and breakpoint to show a better error message.
        _patch_set_trace_and_breakpoint()

        captured_stdout_buffer = StringIO()
        captured_stderr_buffer = StringIO()

        # Catch warnings and store them in a list.
        with warnings.catch_warnings(record=True) as log, redirect_stdout(
            captured_stdout_buffer
        ), redirect_stderr(captured_stderr_buffer):
            # Apply global filterwarnings.
            for arg in session_filterwarnings:
                warnings.filterwarnings(*parse_warning_filter(arg, escape=False))

            # Apply filters from "filterwarnings" marks.
            for mark in task_filterwarnings:
                for arg in mark.args:
                    warnings.filterwarnings(*parse_warning_filter(arg, escape=False))

            processed_exc_info: tuple[type[BaseException], BaseException, str] | None
            try:
                resolved_kwargs = _write_local_files_to_remote(kwargs)
                out = task.execute(**resolved_kwargs)
            except Exception:  # noqa: BLE001
                processed_exc_info = _render_traceback_to_string(
                    sys.exc_info(),  # type: ignore[arg-type]
                    show_locals,
                    console_options,
                )
                products = None
            else:
                products = _handle_function_products(task, out, remote=remote)
                processed_exc_info = None
            finally:
                # Delete the temporary copies of local files even when handling the
                # products raised; otherwise they would be left behind.
                _delete_local_files_on_remote(kwargs)

        task_display_name = getattr(task, "display_name", task.name)
        warning_reports = [
            WarningReport(
                message=warning_record_to_str(warning_message),
                fs_location=(warning_message.filename, warning_message.lineno),
                id_=task_display_name,
            )
            for warning_message in log
        ]

        captured_stdout = captured_stdout_buffer.getvalue()
        captured_stderr = captured_stderr_buffer.getvalue()
    finally:
        # Remove task from shared memory to indicate that it is no longer being
        # executed, regardless of how the execution ended.
        if shared_memory is not None:
            shared_memory.pop(task.signature, None)

    return WrapperResult(
        carry_over_products=products,  # type: ignore[arg-type]
        warning_reports=warning_reports,
        exc_info=processed_exc_info,
        stdout=captured_stdout,
        stderr=captured_stderr,
    )
def rewrap_task_with_coiled_function(task: PTask) -> CoiledFunction:
    """Wrap :func:`wrap_task_in_process` in a :class:`CoiledFunction` for ``task``.

    The keyword arguments for :class:`CoiledFunction` are taken from
    ``task.attributes["coiled_kwargs"]``. The wrapper's metadata (name, docstring,
    ...) is copied from :func:`wrap_task_in_process`.
    """
    coiled_kwargs = task.attributes["coiled_kwargs"]
    coiled_function = CoiledFunction(wrap_task_in_process, **coiled_kwargs)
    return functools.update_wrapper(coiled_function, wrap_task_in_process)
def _raise_exception_on_breakpoint(*args: Any, **kwargs: Any) -> None: # noqa: ARG001
msg = (
"You cannot use 'breakpoint()' or 'pdb.set_trace()' while parallelizing the "
"execution of tasks with pytask-parallel. Please, remove the breakpoint or run "
"the task without parallelization to debug it."
)
raise RuntimeError(msg)
def _patch_set_trace_and_breakpoint() -> None:
    """Disable :func:`pdb.set_trace` and :func:`breakpoint` in the current process.

    Both :func:`pdb.set_trace` and :data:`sys.breakpointhook` (which ``breakpoint()``
    calls) are replaced by :func:`_raise_exception_on_breakpoint`. A debugger call
    inside a task therefore raises a :class:`RuntimeError` with an explanatory message.
    """
    import pdb  # noqa: T100

    replacement = _raise_exception_on_breakpoint
    pdb.set_trace = replacement
    # ``sys`` is the module imported at the top of this file.
    sys.breakpointhook = replacement
def _render_traceback_to_string(
    exc_info: tuple[type[BaseException], BaseException, TracebackType | None],
    show_locals: bool,  # noqa: FBT001
    console_options: ConsoleOptions,
) -> tuple[type[BaseException], BaseException, str]:
    """Replace the traceback in ``exc_info`` by its rendered plain text.

    The traceback is formatted with pytask's :class:`Traceback` and rendered by the
    shared ``console`` using ``console_options``. Only the text of the rendered
    segments is kept. Presumably this lets the result cross process boundaries, since
    traceback objects are not picklable.

    Returns:
        ``(exception type, exception value, rendered traceback text)``.
    """
    rich_traceback = Traceback(exc_info, show_locals=show_locals)
    rendered_text = "".join(
        segment.text
        for segment in console.render(rich_traceback, options=console_options)
    )
    exc_type, exc_value = exc_info[0], exc_info[1]
    return exc_type, exc_value, rendered_text
def _handle_function_products(
    task: PTask, out: Any, *, remote: bool = False
) -> PyTree[CarryOverPath | PythonNode | None]:
    """Store a task's return value in its product nodes and collect carry-overs.

    First, if the task declares a ``"return"`` product, the structure of that
    annotation must be a prefix of the structure of ``out``. Otherwise a
    :class:`ValueError` is raised. Then every product node is visited.

    Products that are not part of the return annotation are not saved here:

    - :class:`PythonNode` products are passed through unchanged.
    - Local :class:`PPathNode` products are read as bytes and wrapped in a
      :class:`CarryOverPath` when running remotely.
    - Everything else maps to ``None``.

    Products of the return annotation receive the matching part of ``out``:

    - :class:`PythonNode` products are saved and passed through.
    - Local :class:`PPathNode` products on a remote worker are wrapped in a new
      :class:`PythonNode`, so the main process can save the value later.
    - All other nodes (e.g. a remote path on S3) are saved directly and map to
      ``None``.

    Returns:
        A pytree with the structure of ``task.produces`` holding the values to carry
        over to the main process.
    """
    produces = task.produces

    if "return" in produces:
        out_structure = tree_structure(out)
        annotation_structure = tree_structure(produces["return"])
        # ``strict=False`` is required because ``None`` is treated as a leaf.
        if not annotation_structure.is_prefix(out_structure, strict=False):
            msg = (
                "The structure of the return annotation is not a subtree of "
                "the structure of the function return.\n\nFunction return: "
                f"{out_structure}\n\nReturn annotation: {annotation_structure}"
            )
            raise ValueError(msg)

    def _is_local_path_on_remote(node: PNode) -> bool:
        """Return whether ``node`` is a local path product while running remotely."""
        return isinstance(node, PPathNode) and is_local_path(node.path) and remote

    def _save_and_carry_over_product(
        path: tuple[Any, ...], node: PNode
    ) -> CarryOverPath | PythonNode | None:
        if path[0] == "return":
            # Descend into the returned value along the remaining keys of the path.
            value = out
            for key in path[1:]:
                value = value[key]

            if isinstance(node, PythonNode):
                node.save(value=value)
                return node
            if _is_local_path_on_remote(node):
                # Hand the value to the main process, which saves it later.
                return PythonNode(value=value)
            node.save(value)
            return None

        if isinstance(node, PythonNode):
            return node
        if _is_local_path_on_remote(node):
            # Ship the file content written by the task back to the main process.
            return CarryOverPath(content=node.path.read_bytes())
        return None

    return tree_map_with_path(_save_and_carry_over_product, produces)
def _write_local_files_to_remote(
    kwargs: dict[str, PyTree[Any]],
) -> dict[str, PyTree[Any]]:
    """Resolve the :class:`RemotePathNode` leaves in the task's keyword arguments.

    The main process may send kwargs containing RemotePathNodes that stand in for
    local files. Each of them is replaced by the result of its ``load()`` method; all
    other leaves are returned unchanged.
    """

    def _resolve(leaf: Any) -> Any:
        if isinstance(leaf, RemotePathNode):
            return leaf.load()
        return leaf

    return tree_map(_resolve, kwargs)  # type: ignore[return-value]
def _delete_local_files_on_remote(kwargs: dict[str, PyTree[Any]]) -> None:
    """Delete the temporary copies of local files created on the remote.

    Local files were copied over to the remote via RemotePathNodes. For each such node
    in ``kwargs``, the node's file descriptor is closed and the file at
    ``remote_path`` is removed once the task has executed.

    Cleanup is best-effort: an ``OSError`` while closing a descriptor or removing a
    file is suppressed. Otherwise one failing file would stop the traversal, leave the
    remaining files behind, and make an otherwise finished task fail.
    """

    def _delete(potential_node: Any) -> None:
        if not isinstance(potential_node, RemotePathNode):
            return
        with suppress(OSError):
            os.close(potential_node.fd)
        with suppress(OSError):
            Path(potential_node.remote_path).unlink(missing_ok=True)

    tree_map(_delete, kwargs)