forked from pytask-dev/pytask
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtask_utils.py
310 lines (256 loc) · 10.9 KB
/
task_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
"""Contains utilities related to the ``@pytask.mark.task`` decorator."""
from __future__ import annotations
import inspect
from collections import defaultdict
from typing import Any
from typing import Callable
import attrs
from _pytask.mark import Mark
from _pytask.models import CollectionMetadata
from _pytask.shared import find_duplicates
from _pytask.typing import is_task_function
from pathlib import Path
# Names exported as the public interface of this module.
__all__ = [
    "COLLECTED_TASKS",
    "parse_collected_tasks_with_task_marker",
    "parse_keyword_arguments_from_signature_defaults",
    "task",
]

# Filled by the wrapper created in ``task``: every decorated function is appended under
# the path of the module that applied the decorator. ``defaultdict(list)`` lets the
# wrapper append without first checking whether the path is already a key.
COLLECTED_TASKS: dict[Path, list[Callable[..., Any]]] = defaultdict(list)
"""A container for collecting tasks.
Tasks marked by the ``@pytask.mark.task`` decorator can be generated in a loop where one
iteration overwrites the previous task. To retrieve the tasks later, use this dictionary
mapping from paths of modules to a list of tasks per module.
"""
def task(
    name: str | None = None,
    *,
    after: str | Callable[..., Any] | list[Callable[..., Any]] | None = None,
    id: str | None = None,  # noqa: A002
    kwargs: dict[Any, Any] | None = None,
    produces: Any | None = None,
) -> Callable[..., Callable[..., Any]]:
    """Decorate a task function.

    This decorator declares every callable as a pytask task.

    The function also attaches some metadata to the function like parsed kwargs and
    markers.

    Parameters
    ----------
    name
        Use it to override the name of the task that is, by default, the name of the
        task function. Read :ref:`customize-task-names` for more information.
    after
        An expression or a task function or a list of task functions that need to be
        executed before this task can be executed. See :ref:`after` for more
        information.
    id
        An id for the task if it is part of a repetition. Otherwise, an automatic id
        will be generated. See :ref:`how-to-repeat-a-task-with-different-inputs-the-id`
        for more information.
    kwargs
        Use a dictionary to pass any keyword arguments to the task function which can be
        dependencies or products of the task. Read :ref:`task-kwargs` for more
        information.
    produces
        Use this argument if you want to parse the return of the task function as a
        product, but you cannot annotate the return of the function. See :doc:`this
        how-to guide <../how_to_guides/using_task_returns>` or :ref:`task-produces` for
        more information.

    Returns
    -------
    Callable[..., Callable[..., Any]]
        The decorator. When the decorator is used without parentheses (``@task``), the
        decorated function is returned directly.

    Raises
    ------
    ValueError
        If ``name`` or ``id`` is neither a :obj:`str` nor ``None``.
    RuntimeError
        If the file of the module applying the decorator cannot be determined.

    Examples
    --------
    To mark a function without the ``task_`` prefix as a task, attach the decorator.

    .. code-block:: python

        from pathlib import Path
        from typing import Annotated

        from pytask import task


        @task
        def create_text_file() -> Annotated[str, Path("file.txt")]:
            return "Hello, World!"

    """

    def wrapper(
        func: Callable[..., Any], *, _frames_up: int = 1
    ) -> Callable[..., Any]:
        # ``_frames_up`` is the number of frames between ``wrapper`` and the code that
        # applied the decorator: 1 for ``@task(...)``, and 2 for a bare ``@task``
        # because ``task`` itself calls ``wrapper`` in that case.
        for arg, arg_name in ((name, "name"), (id, "id")):
            if not (isinstance(arg, str) or arg is None):
                msg = (
                    f"Argument {arg_name!r} of @pytask.mark.task must be a str, but it "
                    f"is {arg!r}."
                )
                raise ValueError(msg)

        parsed_kwargs = {} if kwargs is None else kwargs
        parsed_name = name if isinstance(name, str) else func.__name__
        parsed_after = _parse_after(after)

        unwrapped = inspect.unwrap(func)
        if hasattr(unwrapped, "pytask_meta"):
            # Update existing metadata in place, e.g. metadata that was attached
            # because this function was referenced in another task's ``after``.
            unwrapped.pytask_meta.name = parsed_name
            unwrapped.pytask_meta.kwargs = parsed_kwargs
            unwrapped.pytask_meta.markers.append(Mark("task", (), {}))
            unwrapped.pytask_meta.id_ = id
            unwrapped.pytask_meta.produces = produces
            unwrapped.pytask_meta.after = parsed_after
        else:
            unwrapped.pytask_meta = CollectionMetadata(
                name=parsed_name,
                kwargs=parsed_kwargs,
                markers=[Mark("task", (), {})],
                id_=id,
                produces=produces,
                after=parsed_after,
            )

        # Store it in the global variable ``COLLECTED_TASKS`` to avoid garbage
        # collection when the function definition is overwritten in a loop. The key is
        # the file of the module whose code applied the decorator.
        # Based on https://stackoverflow.com/questions/1095543/get-name-of-calling-functions-module-in-python  # noqa: E501
        # ``inspect.currentframe`` is used instead of ``inspect.stack`` because the
        # latter reads source context for every frame on the stack.
        frame = inspect.currentframe()
        try:
            for _ in range(_frames_up):
                frame = frame.f_back  # type: ignore[union-attr]
            task_module = inspect.getmodule(frame)
            raw_path = getattr(task_module, "__file__", None)
            if raw_path is None:
                # Fall back to the frame's globals, e.g. for code run via ``exec``.
                raw_path = frame.f_globals.get("__file__")  # type: ignore[union-attr]
        finally:
            # Frames keep their locals alive; drop the reference to avoid cycles.
            del frame

        if raw_path is None:
            msg = (
                "Could not determine the module in which @pytask.mark.task was applied "
                f"to {func!r}."
            )
            raise RuntimeError(msg)

        COLLECTED_TASKS[Path(raw_path)].append(unwrapped)
        return unwrapped

    # In case the decorator is used without parentheses, wrap the function which is
    # passed as the first argument with the default arguments. ``wrapper`` is then
    # called from this frame, so it must look one frame further up to find the
    # user's module instead of this one.
    if is_task_function(name) and kwargs is None:
        return task()(name, _frames_up=2)
    return wrapper
def _parse_after(
after: str | Callable[..., Any] | list[Callable[..., Any]] | None
) -> str | list[Callable[..., Any]]:
if not after:
return []
if isinstance(after, str):
return after
if callable(after):
if not hasattr(after, "pytask_meta"):
after.pytask_meta = CollectionMetadata() # type: ignore[attr-defined]
return [after.pytask_meta._id] # type: ignore[attr-defined]
if isinstance(after, list):
new_after = []
for func in after:
if not hasattr(func, "pytask_meta"):
func.pytask_meta = CollectionMetadata() # type: ignore[attr-defined]
new_after.append(func.pytask_meta._id) # type: ignore[attr-defined]
msg = (
"'after' should be an expression string, a task, or a list of class. Got "
f"{after}, instead."
)
raise TypeError(msg)
def parse_collected_tasks_with_task_marker(
    tasks: list[Callable[..., Any]],
) -> dict[str, Callable[..., Any]]:
    """Parse collected tasks with a task marker.

    Tasks whose preliminary name is unique keep that name. Tasks sharing a preliminary
    name receive unique ids generated by :func:`_generate_ids_for_tasks`.

    Parameters
    ----------
    tasks
        Task functions carrying a ``pytask_meta`` attribute.

    Returns
    -------
    dict[str, Callable[..., Any]]
        Mapping from final task names to task functions, ordered by the first
        appearance of each preliminary name in ``tasks``.
    """
    parsed_tasks = _parse_tasks_with_preliminary_names(tasks)
    duplicated_names = find_duplicates([name for name, _ in parsed_tasks])

    # Group tasks by preliminary name in a single pass. A dict keeps the order of
    # first appearance, whereas iterating over a set of names would depend on string
    # hash randomization and make the result order vary between runs.
    tasks_by_name: dict[str, list[tuple[str, Callable[..., Any]]]] = {}
    for name, function in parsed_tasks:
        tasks_by_name.setdefault(name, []).append((name, function))

    collected_tasks: dict[str, Callable[..., Any]] = {}
    for name, selected_tasks in tasks_by_name.items():
        if name in duplicated_names:
            collected_tasks.update(_generate_ids_for_tasks(selected_tasks))
        else:
            collected_tasks[name] = selected_tasks[0][1]
    return collected_tasks
def _parse_tasks_with_preliminary_names(
    tasks: list[Callable[..., Any]],
) -> list[tuple[str, Callable[..., Any]]]:
    """Pair each task with its preliminary name, keeping the input order.

    The names are only preliminary because several tasks may share one; such names
    are extended to properly parametrized ids later. Each task goes through
    :func:`_parse_task`, which also rewrites its ``pytask_meta.kwargs``.
    """
    return [_parse_task(single_task) for single_task in tasks]
def _parse_task(task: Callable[..., Any]) -> tuple[str, Callable[..., Any]]:
    """Return the preliminary name of ``task`` together with ``task`` itself.

    The name is ``pytask_meta.name`` or, when that is ``None``, the function name.
    As a side effect, ``pytask_meta.kwargs`` is replaced by the defaults from the
    signature of ``task`` merged with the explicitly passed kwargs, where explicit
    values take precedence.

    Raises
    ------
    ValueError
        If no name was given and the function is called ``_``, or if the passed
        kwargs have an unsupported type.
    """
    meta = task.pytask_meta  # type: ignore[attr-defined]

    if meta.name is not None:
        preliminary_name = meta.name
    elif task.__name__ == "_":
        msg = (
            "A task function either needs 'name' passed by the ``@pytask.mark.task`` "
            "decorator or the function name of the task function must not be '_'."
        )
        raise ValueError(msg)
    else:
        preliminary_name = task.__name__

    explicit_kwargs = _parse_task_kwargs(meta.kwargs)
    default_kwargs = parse_keyword_arguments_from_signature_defaults(task)
    # Later unpacking wins, so explicitly passed kwargs override signature defaults.
    meta.kwargs = {**default_kwargs, **explicit_kwargs}

    return preliminary_name, task
def _parse_task_kwargs(kwargs: Any) -> dict[str, Any]:
"""Parse task kwargs."""
if isinstance(kwargs, dict):
return kwargs
# Handle namedtuples.
if callable(getattr(kwargs, "_asdict", None)):
return kwargs._asdict()
if attrs.has(type(kwargs)):
return attrs.asdict(kwargs)
msg = (
"'@pytask.mark.task(kwargs=...) needs to be a dictionary, namedtuple or an "
"instance of an attrs class."
)
raise ValueError(msg)
def parse_keyword_arguments_from_signature_defaults(
    task: Callable[..., Any]
) -> dict[str, Any]:
    """Collect the default values declared in the signature of ``task``.

    Parameters
    ----------
    task
        A callable that :func:`inspect.signature` can introspect.

    Returns
    -------
    dict[str, Any]
        Mapping from parameter name to default value, in signature order. Parameters
        without a default are left out.
    """
    signature = inspect.signature(task)
    return {
        param_name: param.default
        for param_name, param in signature.parameters.items()
        if param.default is not inspect.Parameter.empty
    }
def _generate_ids_for_tasks(
tasks: list[tuple[str, Callable[..., Any]]]
) -> dict[str, Callable[..., Any]]:
"""Generate unique ids for parametrized tasks."""
parameters = inspect.signature(tasks[0][1]).parameters
out = {}
for i, (name, task) in enumerate(tasks):
if task.pytask_meta.id_ is not None: # type: ignore[attr-defined]
id_ = f"{name}[{task.pytask_meta.id_}]" # type: ignore[attr-defined]
elif not parameters:
id_ = f"{name}[{i}]"
else:
stringified_args = [
_arg_value_to_id_component(
arg_name=parameter,
arg_value=task.pytask_meta.kwargs.get( # type: ignore[attr-defined]
parameter
),
i=i,
id_func=None,
)
for parameter in parameters
]
id_ = "-".join(stringified_args)
id_ = f"{name}[{id_}]"
out[id_] = task
return out
def _arg_value_to_id_component(
arg_name: str, arg_value: Any, i: int, id_func: Callable[..., Any] | None
) -> str:
"""Create id component from the name and value of the argument.
First, transform the value of the argument with a user-defined function if given.
Otherwise, take the original value. Then, if the value is a :obj:`bool`,
:obj:`float`, :obj:`int`, or :obj:`str`, cast it to a string. Otherwise, define a
placeholder value from the name of the argument and the iteration.
Parameters
----------
arg_name : str
Name of the parametrized function argument.
arg_value : Any
Value of the argument.
i : int
The ith iteration of the parametrization.
id_func : Union[Callable[..., Any], None]
A callable which maps argument values to :obj:`bool`, :obj:`float`, :obj:`int`,
or :obj:`str` or anything else. Any object with a different dtype than the first
will be mapped to an auto-generated id component.
Returns
-------
id_component : str
A part of the final parametrized id.
"""
id_component = id_func(arg_value) if id_func is not None else None
if isinstance(id_component, (bool, float, int, str)):
id_component = str(id_component)
elif isinstance(arg_value, (bool, float, int, str)):
id_component = str(arg_value)
else:
id_component = arg_name + str(i)
return id_component