Skip to content

Commit d3a7bdc

Browse files
authored
Redirect stdout and stderr. (#92)
1 parent f1049a5 commit d3a7bdc

File tree

6 files changed

+112
-10
lines changed

6 files changed

+112
-10
lines changed

CHANGES.md

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
1010
- {pull}`85` simplifies code since loky is a dependency.
1111
- {pull}`88` updates handling `Traceback`.
1212
- {pull}`89` restructures the package.
13+
- {pull}`92` redirects stdout and stderr from processes and loky and shows them in error
14+
reports.
1315

1416
## 0.4.1 - 2024-01-12
1517

src/pytask_parallel/execute.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,24 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
9090
future = running_tasks[task_name]
9191

9292
if future.done():
93-
python_nodes, warnings_reports, exc_info = parse_future_result(
94-
future
95-
)
93+
(
94+
python_nodes,
95+
warnings_reports,
96+
exc_info,
97+
captured_stdout,
98+
captured_stderr,
99+
) = parse_future_result(future)
96100
session.warnings.extend(warnings_reports)
97101

102+
if captured_stdout:
103+
task.report_sections.append(
104+
("call", "stdout", captured_stdout)
105+
)
106+
if captured_stderr:
107+
task.report_sections.append(
108+
("call", "stderr", captured_stderr)
109+
)
110+
98111
if exc_info is not None:
99112
task = session.dag.nodes[task_name]["task"]
100113
newly_collected_reports.append(

src/pytask_parallel/processes.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
import inspect
66
import sys
77
import warnings
8+
from contextlib import redirect_stderr
9+
from contextlib import redirect_stdout
810
from functools import partial
11+
from io import StringIO
912
from typing import TYPE_CHECKING
1013
from typing import Any
1114
from typing import Callable
@@ -98,6 +101,8 @@ def _execute_task( # noqa: PLR0913
98101
PyTree[PythonNode | None],
99102
list[WarningReport],
100103
tuple[type[BaseException], BaseException, str] | None,
104+
str,
105+
str,
101106
]:
102107
"""Unserialize and execute task.
103108
@@ -111,8 +116,13 @@ def _execute_task( # noqa: PLR0913
111116
# Patch set_trace and breakpoint to show a better error message.
112117
_patch_set_trace_and_breakpoint()
113118

119+
captured_stdout_buffer = StringIO()
120+
captured_stderr_buffer = StringIO()
121+
114122
# Catch warnings and store them in a list.
115-
with warnings.catch_warnings(record=True) as log:
123+
with warnings.catch_warnings(record=True) as log, redirect_stdout(
124+
captured_stdout_buffer
125+
), redirect_stderr(captured_stderr_buffer):
116126
# Apply global filterwarnings.
117127
for arg in session_filterwarnings:
118128
warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
@@ -146,12 +156,25 @@ def _execute_task( # noqa: PLR0913
146156
)
147157
)
148158

159+
captured_stdout_buffer.seek(0)
160+
captured_stderr_buffer.seek(0)
161+
captured_stdout = captured_stdout_buffer.read()
162+
captured_stderr = captured_stderr_buffer.read()
163+
captured_stdout_buffer.close()
164+
captured_stderr_buffer.close()
165+
149166
# Collect all PythonNodes that are products to pass values back to the main process.
150167
python_nodes = tree_map(
151168
lambda x: x if isinstance(x, PythonNode) else None, task.produces
152169
)
153170

154-
return python_nodes, warning_reports, processed_exc_info
171+
return (
172+
python_nodes,
173+
warning_reports,
174+
processed_exc_info,
175+
captured_stdout,
176+
captured_stderr,
177+
)
155178

156179

157180
def _process_exception(

src/pytask_parallel/threads.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[Any]:
3535
def _mock_processes_for_threads(
3636
task: PTask, **kwargs: Any
3737
) -> tuple[
38-
None, list[Any], tuple[type[BaseException], BaseException, TracebackType] | None
38+
None,
39+
list[Any],
40+
tuple[type[BaseException], BaseException, TracebackType] | None,
41+
str,
42+
str,
3943
]:
4044
"""Mock execution function such that it returns the same as for processes.
4145
@@ -52,4 +56,4 @@ def _mock_processes_for_threads(
5256
else:
5357
handle_task_function_return(task, out)
5458
exc_info = None
55-
return None, [], exc_info
59+
return None, [], exc_info, "", ""

src/pytask_parallel/utils.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,22 @@ def parse_future_result(
2626
dict[str, PyTree[PythonNode | None]] | None,
2727
list[WarningReport],
2828
tuple[type[BaseException], BaseException, TracebackType] | None,
29+
str,
30+
str,
2931
]:
3032
"""Parse the result of a future."""
3133
# An exception was raised before the task was executed.
3234
future_exception = future.exception()
3335
if future_exception is not None:
3436
exc_info = _parse_future_exception(future_exception)
35-
return None, [], exc_info
37+
return None, [], exc_info, "", ""
3638

3739
out = future.result()
38-
if isinstance(out, tuple) and len(out) == 3: # noqa: PLR2004
40+
if isinstance(out, tuple) and len(out) == 5: # noqa: PLR2004
3941
return out
4042

4143
if out is None:
42-
return None, [], None
44+
return None, [], None, "", ""
4345

4446
# What to do when the output does not match?
4547
msg = (

tests/test_capture.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import textwrap
2+
3+
import pytest
4+
from pytask import ExitCode
5+
from pytask import cli
6+
from pytask_parallel import ParallelBackend
7+
8+
9+
@pytest.mark.end_to_end()
10+
@pytest.mark.parametrize(
11+
"parallel_backend", [ParallelBackend.PROCESSES, ParallelBackend.LOKY]
12+
)
13+
@pytest.mark.parametrize("show_capture", ["no", "stdout", "stderr", "all"])
14+
def test_show_capture(tmp_path, runner, parallel_backend, show_capture):
15+
source = """
16+
import sys
17+
18+
def task_show_capture():
19+
sys.stdout.write("xxxx")
20+
sys.stderr.write("zzzz")
21+
raise Exception
22+
"""
23+
tmp_path.joinpath("task_show_capture.py").write_text(textwrap.dedent(source))
24+
25+
cmd_arg = "-s" if show_capture == "s" else f"--show-capture={show_capture}"
26+
result = runner.invoke(
27+
cli,
28+
[
29+
tmp_path.as_posix(),
30+
cmd_arg,
31+
"--parallel-backend",
32+
parallel_backend,
33+
"-n",
34+
"2",
35+
],
36+
)
37+
38+
assert result.exit_code == ExitCode.FAILED
39+
40+
if show_capture in ("no", "s"):
41+
assert "Captured" not in result.output
42+
elif show_capture == "stdout":
43+
assert "Captured stdout" in result.output
44+
assert "xxxx" in result.output
45+
assert "Captured stderr" not in result.output
46+
# assert "zzzz" not in result.output
47+
elif show_capture == "stderr":
48+
assert "Captured stdout" not in result.output
49+
# assert "xxxx" not in result.output
50+
assert "Captured stderr" in result.output
51+
assert "zzzz" in result.output
52+
elif show_capture == "all":
53+
assert "Captured stdout" in result.output
54+
assert "xxxx" in result.output
55+
assert "Captured stderr" in result.output
56+
assert "zzzz" in result.output
57+
else: # pragma: no cover
58+
raise NotImplementedError

0 commit comments

Comments
 (0)