5
5
import inspect
6
6
import sys
7
7
import warnings
8
+ from contextlib import redirect_stderr
9
+ from contextlib import redirect_stdout
8
10
from functools import partial
11
+ from io import StringIO
9
12
from typing import TYPE_CHECKING
10
13
from typing import Any
11
14
from typing import Callable
@@ -98,6 +101,8 @@ def _execute_task( # noqa: PLR0913
98
101
PyTree [PythonNode | None ],
99
102
list [WarningReport ],
100
103
tuple [type [BaseException ], BaseException , str ] | None ,
104
+ str ,
105
+ str ,
101
106
]:
102
107
"""Unserialize and execute task.
103
108
@@ -111,8 +116,13 @@ def _execute_task( # noqa: PLR0913
111
116
# Patch set_trace and breakpoint to show a better error message.
112
117
_patch_set_trace_and_breakpoint ()
113
118
119
+ captured_stdout_buffer = StringIO ()
120
+ captured_stderr_buffer = StringIO ()
121
+
114
122
# Catch warnings and store them in a list.
115
- with warnings .catch_warnings (record = True ) as log :
123
+ with warnings .catch_warnings (record = True ) as log , redirect_stdout (
124
+ captured_stdout_buffer
125
+ ), redirect_stderr (captured_stderr_buffer ):
116
126
# Apply global filterwarnings.
117
127
for arg in session_filterwarnings :
118
128
warnings .filterwarnings (* parse_warning_filter (arg , escape = False ))
@@ -146,12 +156,25 @@ def _execute_task( # noqa: PLR0913
146
156
)
147
157
)
148
158
159
+ captured_stdout_buffer .seek (0 )
160
+ captured_stderr_buffer .seek (0 )
161
+ captured_stdout = captured_stdout_buffer .read ()
162
+ captured_stderr = captured_stderr_buffer .read ()
163
+ captured_stdout_buffer .close ()
164
+ captured_stderr_buffer .close ()
165
+
149
166
# Collect all PythonNodes that are products to pass values back to the main process.
150
167
python_nodes = tree_map (
151
168
lambda x : x if isinstance (x , PythonNode ) else None , task .produces
152
169
)
153
170
154
- return python_nodes , warning_reports , processed_exc_info
171
+ return (
172
+ python_nodes ,
173
+ warning_reports ,
174
+ processed_exc_info ,
175
+ captured_stdout ,
176
+ captured_stderr ,
177
+ )
155
178
156
179
157
180
def _process_exception (
0 commit comments