1
1
"""Collect tasks."""
2
2
from __future__ import annotations
3
3
4
- import functools
5
4
import subprocess
5
+ import warnings
6
6
from pathlib import Path
7
- from types import FunctionType
8
7
from typing import Any
9
8
10
- from pytask import depends_on
11
9
from pytask import has_mark
12
10
from pytask import hookimpl
11
+ from pytask import is_task_function
13
12
from pytask import Mark
14
- from pytask import parse_nodes
15
- from pytask import produces
13
+ from pytask import NodeInfo
14
+ from pytask import parse_dependencies_from_task_function
15
+ from pytask import parse_products_from_task_function
16
+ from pytask import PathNode
17
+ from pytask import PTask
18
+ from pytask import PythonNode
16
19
from pytask import remove_marks
17
20
from pytask import Session
18
21
from pytask import Task
22
+ from pytask import TaskWithoutPath
23
+ from pytask_r .serialization import create_path_to_serialized
19
24
from pytask_r .serialization import SERIALIZERS
20
25
from pytask_r .shared import r
21
- from pytask_r .shared import R_SCRIPT_KEY
22
26
23
27
24
- def run_r_script (script : Path , options : list [str ], serialized : Path ) -> None :
28
+ def run_r_script (
29
+ _script : Path , _options : list [str ], _serialized : Path , ** kwargs : Any # noqa: ARG001
30
+ ) -> None :
25
31
"""Run an R script."""
26
- cmd = ["Rscript" , script .as_posix (), * options , str (serialized )]
32
+ cmd = ["Rscript" , _script .as_posix (), * _options , str (_serialized )]
27
33
print ("Executing " + " " .join (cmd ) + "." ) # noqa: T201
28
34
subprocess .run (cmd , check = True ) # noqa: S603
29
35
30
36
31
37
@hookimpl
32
38
def pytask_collect_task (
33
- session : Session , path : Path , name : str , obj : Any
34
- ) -> Task | None :
39
+ session : Session , path : Path | None , name : str , obj : Any
40
+ ) -> PTask | None :
35
41
"""Perform some checks."""
36
42
__tracebackhide__ = True
37
43
38
44
if (
39
45
(name .startswith ("task_" ) or has_mark (obj , "task" ))
40
- and callable (obj )
46
+ and is_task_function (obj )
41
47
and has_mark (obj , "r" )
42
48
):
49
+ # Parse @pytask.mark.r decorator.
43
50
obj , marks = remove_marks (obj , "r" )
44
-
45
51
if len (marks ) > 1 :
46
52
raise ValueError (
47
53
f"Task { name !r} has multiple @pytask.mark.r marks, but only one is "
@@ -54,40 +60,97 @@ def pytask_collect_task(
54
60
default_serializer = session .config ["r_serializer" ],
55
61
default_suffix = session .config ["r_suffix" ],
56
62
)
57
- script , options , _ , _ = r (** marks [0 ].kwargs )
63
+ script , options , _ , suffix = r (** marks [0 ].kwargs )
58
64
59
65
obj .pytask_meta .markers .append (mark )
60
66
61
- dependencies = parse_nodes ( session , path , name , obj , depends_on )
62
- products = parse_nodes ( session , path , name , obj , produces )
67
+ # Collect the nodes in @pytask.mark.julia and validate them.
68
+ path_nodes = Path . cwd () if path is None else path . parent
63
69
64
- markers = obj .pytask_meta .markers if hasattr (obj , "pytask_meta" ) else []
65
- kwargs = obj .pytask_meta .kwargs if hasattr (obj , "pytask_meta" ) else {}
66
-
67
- task = Task (
68
- base_name = name ,
69
- path = path ,
70
- function = _copy_func (run_r_script ), # type: ignore[arg-type]
71
- depends_on = dependencies ,
72
- produces = products ,
73
- markers = markers ,
74
- kwargs = kwargs ,
75
- )
70
+ if isinstance (script , str ):
71
+ warnings .warn (
72
+ "Passing a string to the @pytask.mark.r parameter 'script' is "
73
+ "deprecated. Please, use a pathlib.Path instead." ,
74
+ stacklevel = 1 ,
75
+ )
76
+ script = Path (script )
76
77
77
78
script_node = session .hook .pytask_collect_node (
78
- session = session , path = path , node = script
79
+ session = session ,
80
+ path = path_nodes ,
81
+ node_info = NodeInfo (
82
+ arg_name = "_script" ,
83
+ path = (),
84
+ value = script ,
85
+ task_path = path ,
86
+ task_name = name ,
87
+ ),
88
+ )
89
+
90
+ if not (isinstance (script_node , PathNode ) and script_node .path .suffix == ".r" ):
91
+ raise ValueError (
92
+ "The 'script' keyword of the @pytask.mark.r decorator must point "
93
+ f"to Julia file with the .r suffix, but it is { script_node } ."
94
+ )
95
+
96
+ options_node = session .hook .pytask_collect_node (
97
+ session = session ,
98
+ path = path_nodes ,
99
+ node_info = NodeInfo (
100
+ arg_name = "_options" ,
101
+ path = (),
102
+ value = options ,
103
+ task_path = path ,
104
+ task_name = name ,
105
+ ),
106
+ )
107
+
108
+ dependencies = parse_dependencies_from_task_function (
109
+ session , path , name , path_nodes , obj
110
+ )
111
+ products = parse_products_from_task_function (
112
+ session , path , name , path_nodes , obj
79
113
)
80
114
81
- if isinstance (task .depends_on , dict ):
82
- task .depends_on [R_SCRIPT_KEY ] = script_node
83
- task .attributes ["r_keep_dict" ] = True
115
+ # Add script
116
+ dependencies ["_script" ] = script_node
117
+ dependencies ["_options" ] = options_node
118
+
119
+ markers = obj .pytask_meta .markers if hasattr (obj , "pytask_meta" ) else []
120
+
121
+ task : PTask
122
+ if path is None :
123
+ task = TaskWithoutPath (
124
+ name = name ,
125
+ function = run_r_script ,
126
+ depends_on = dependencies ,
127
+ produces = products ,
128
+ markers = markers ,
129
+ )
84
130
else :
85
- task .depends_on = {0 : task .depends_on , R_SCRIPT_KEY : script_node }
86
- task .attributes ["r_keep_dict" ] = False
131
+ task = Task (
132
+ base_name = name ,
133
+ path = path ,
134
+ function = run_r_script ,
135
+ depends_on = dependencies ,
136
+ produces = products ,
137
+ markers = markers ,
138
+ )
87
139
88
- task .function = functools .partial (
89
- task .function , script = task .depends_on [R_SCRIPT_KEY ].path , options = options
140
+ # Add serialized node that depends on the task id.
141
+ serialized = create_path_to_serialized (task , suffix ) # type: ignore[arg-type]
142
+ serialized_node = session .hook .pytask_collect_node (
143
+ session = session ,
144
+ path = path_nodes ,
145
+ node_info = NodeInfo (
146
+ arg_name = "_serialized" ,
147
+ path = (),
148
+ value = PythonNode (value = serialized ),
149
+ task_path = path ,
150
+ task_name = name ,
151
+ ),
90
152
)
153
+ task .depends_on ["_serialized" ] = serialized_node
91
154
92
155
return task
93
156
return None
@@ -120,28 +183,3 @@ def _parse_r_mark(
120
183
121
184
mark = Mark ("r" , (), parsed_kwargs )
122
185
return mark
123
-
124
-
125
- def _copy_func (func : FunctionType ) -> FunctionType :
126
- """Create a copy of a function.
127
-
128
- Based on https://stackoverflow.com/a/13503277/7523785.
129
-
130
- Example
131
- -------
132
- >>> def _func(): pass
133
- >>> copied_func = _copy_func(_func)
134
- >>> _func is copied_func
135
- False
136
-
137
- """
138
- new_func = FunctionType (
139
- func .__code__ ,
140
- func .__globals__ ,
141
- name = func .__name__ ,
142
- argdefs = func .__defaults__ ,
143
- closure = func .__closure__ ,
144
- )
145
- new_func = functools .update_wrapper (new_func , func )
146
- new_func .__kwdefaults__ = func .__kwdefaults__
147
- return new_func
0 commit comments