4
4
5
5
import functools
6
6
import subprocess
7
- from types import FunctionType
8
- from typing import TYPE_CHECKING
7
+ import warnings
8
+ from pathlib import Path
9
9
from typing import Any
10
10
11
11
from pytask import Mark
12
+ from pytask import NodeInfo
13
+ from pytask import PathNode
14
+ from pytask import PTask
15
+ from pytask import PythonNode
12
16
from pytask import Session
13
17
from pytask import Task
14
- from pytask import depends_on
18
+ from pytask import TaskWithoutPath
15
19
from pytask import has_mark
16
20
from pytask import hookimpl
17
- from pytask import parse_nodes
18
- from pytask import produces
21
+ from pytask import is_task_function
22
+ from pytask import parse_dependencies_from_task_function
23
+ from pytask import parse_products_from_task_function
19
24
from pytask import remove_marks
20
25
21
26
from pytask_stata .shared import convert_task_id_to_name_of_log_file
22
27
from pytask_stata .shared import stata
23
28
24
- if TYPE_CHECKING :
25
- from pathlib import Path
26
-
27
29
28
30
def run_stata_script (
29
- executable : str , script : Path , options : list [str ], log_name : list [str ], cwd : Path
31
+ _executable : str ,
32
+ _script : Path ,
33
+ _options : list [str ],
34
+ _log_name : str ,
35
+ _cwd : Path ,
30
36
) -> None :
31
37
"""Run an R script."""
32
- cmd = [executable , "-e" , "do" , script .as_posix (), * options , * log_name ]
38
+ cmd = [_executable , "-e" , "do" , _script .as_posix (), * _options , f"- { _log_name } " ]
33
39
print ("Executing " + " " .join (cmd ) + "." ) # noqa: T201
34
- subprocess .run (cmd , cwd = cwd , check = True ) # noqa: S603
40
+ subprocess .run (cmd , cwd = _cwd , check = True ) # noqa: S603
35
41
36
42
37
43
@hookimpl
@@ -43,11 +49,11 @@ def pytask_collect_task(
43
49
44
50
if (
45
51
(name .startswith ("task_" ) or has_mark (obj , "task" ))
46
- and callable (obj )
52
+ and is_task_function (obj )
47
53
and has_mark (obj , "stata" )
48
54
):
55
+ # Parse the @pytask.mark.stata decorator.
49
56
obj , marks = remove_marks (obj , "stata" )
50
-
51
57
if len (marks ) > 1 :
52
58
msg = (
53
59
f"Task { name !r} has multiple @pytask.mark.stata marks, but only one is "
@@ -57,50 +63,123 @@ def pytask_collect_task(
57
63
58
64
mark = _parse_stata_mark (mark = marks [0 ])
59
65
script , options = stata (** marks [0 ].kwargs )
60
-
61
66
obj .pytask_meta .markers .append (mark )
62
67
63
- dependencies = parse_nodes ( session , path , name , obj , depends_on )
64
- products = parse_nodes ( session , path , name , obj , produces )
68
+ # Collect the nodes in @pytask.mark.julia and validate them.
69
+ path_nodes = Path . cwd () if path is None else path . parent
65
70
66
- markers = obj .pytask_meta .markers if hasattr (obj , "pytask_meta" ) else []
67
- kwargs = obj .pytask_meta .kwargs if hasattr (obj , "pytask_meta" ) else {}
68
-
69
- task = Task (
70
- base_name = name ,
71
- path = path ,
72
- function = _copy_func (run_stata_script ), # type: ignore[arg-type]
73
- depends_on = dependencies ,
74
- produces = products ,
75
- markers = markers ,
76
- kwargs = kwargs ,
77
- )
71
+ if isinstance (script , str ):
72
+ warnings .warn (
73
+ "Passing a string to the @pytask.mark.stata parameter 'script' is "
74
+ "deprecated. Please, use a pathlib.Path instead." ,
75
+ stacklevel = 1 ,
76
+ )
77
+ script = Path (script )
78
78
79
79
script_node = session .hook .pytask_collect_node (
80
- session = session , path = path , node = script
80
+ session = session ,
81
+ path = path_nodes ,
82
+ node_info = NodeInfo (
83
+ arg_name = "script" , path = (), value = script , task_path = path , task_name = name
84
+ ),
81
85
)
82
86
83
- if isinstance (task .depends_on , dict ):
84
- task .depends_on ["__script" ] = script_node
87
+ if not (isinstance (script_node , PathNode ) and script_node .path .suffix == ".do" ):
88
+ msg = (
89
+ "The 'script' keyword of the @pytask.mark.stata decorator must point "
90
+ f"to a file with the .do suffix, but it is { script_node } ."
91
+ )
92
+ raise ValueError (msg )
93
+
94
+ options_node = session .hook .pytask_collect_node (
95
+ session = session ,
96
+ path = path_nodes ,
97
+ node_info = NodeInfo (
98
+ arg_name = "_options" ,
99
+ path = (),
100
+ value = options ,
101
+ task_path = path ,
102
+ task_name = name ,
103
+ ),
104
+ )
105
+
106
+ executable_node = session .hook .pytask_collect_node (
107
+ session = session ,
108
+ path = path_nodes ,
109
+ node_info = NodeInfo (
110
+ arg_name = "_executable" ,
111
+ path = (),
112
+ value = session .config ["stata" ],
113
+ task_path = path ,
114
+ task_name = name ,
115
+ ),
116
+ )
117
+
118
+ cwd_node = session .hook .pytask_collect_node (
119
+ session = session ,
120
+ path = path_nodes ,
121
+ node_info = NodeInfo (
122
+ arg_name = "_cwd" ,
123
+ path = (),
124
+ value = path .parent .as_posix (),
125
+ task_path = path ,
126
+ task_name = name ,
127
+ ),
128
+ )
129
+
130
+ dependencies = parse_dependencies_from_task_function (
131
+ session , path , name , path_nodes , obj
132
+ )
133
+ products = parse_products_from_task_function (
134
+ session , path , name , path_nodes , obj
135
+ )
136
+
137
+ # Add script
138
+ dependencies ["_script" ] = script_node
139
+ dependencies ["_options" ] = options_node
140
+ dependencies ["_cwd" ] = cwd_node
141
+ dependencies ["_executable" ] = executable_node
142
+
143
+ partialed = functools .partial (run_stata_script , _cwd = path .parent )
144
+ markers = obj .pytask_meta .markers if hasattr (obj , "pytask_meta" ) else []
145
+
146
+ task : PTask
147
+ if path is None :
148
+ task = TaskWithoutPath (
149
+ name = name ,
150
+ function = partialed ,
151
+ depends_on = dependencies ,
152
+ produces = products ,
153
+ markers = markers ,
154
+ )
85
155
else :
86
- task .depends_on = {0 : task .depends_on , "__script" : script_node }
156
+ task = Task (
157
+ base_name = name ,
158
+ path = path ,
159
+ function = partialed ,
160
+ depends_on = dependencies ,
161
+ produces = products ,
162
+ markers = markers ,
163
+ )
87
164
165
+ # Add log_name node that depends on the task id.
88
166
if session .config ["platform" ] == "win32" :
89
- log_name = convert_task_id_to_name_of_log_file (task .short_name )
90
- log_name_arg = [f"-{ log_name } " ]
167
+ log_name = convert_task_id_to_name_of_log_file (task )
91
168
else :
92
- log_name_arg = []
93
-
94
- stata_function = functools .partial (
95
- task .function ,
96
- executable = session .config ["stata" ],
97
- script = task .depends_on ["__script" ].path ,
98
- options = options ,
99
- log_name = log_name_arg ,
100
- cwd = task .path .parent ,
169
+ log_name = ""
170
+
171
+ log_name_node = session .hook .pytask_collect_node (
172
+ session = session ,
173
+ path = path_nodes ,
174
+ node_info = NodeInfo (
175
+ arg_name = "_log_name" ,
176
+ path = (),
177
+ value = PythonNode (value = log_name ),
178
+ task_path = path ,
179
+ task_name = name ,
180
+ ),
101
181
)
102
-
103
- task .function = stata_function
182
+ task .depends_on ["_log_name" ] = log_name_node
104
183
105
184
return task
106
185
return None
@@ -109,32 +188,5 @@ def pytask_collect_task(
109
188
def _parse_stata_mark (mark : Mark ) -> Mark :
110
189
"""Parse a Stata mark."""
111
190
script , options = stata (** mark .kwargs )
112
-
113
191
parsed_kwargs = {"script" : script or None , "options" : options or []}
114
-
115
192
return Mark ("stata" , (), parsed_kwargs )
116
-
117
-
118
- def _copy_func (func : FunctionType ) -> FunctionType :
119
- """Create a copy of a function.
120
-
121
- Based on https://stackoverflow.com/a/13503277/7523785.
122
-
123
- Example
124
- -------
125
- >>> def _func(): pass
126
- >>> copied_func = _copy_func(_func)
127
- >>> _func is copied_func
128
- False
129
-
130
- """
131
- new_func = FunctionType (
132
- func .__code__ ,
133
- func .__globals__ ,
134
- name = func .__name__ ,
135
- argdefs = func .__defaults__ ,
136
- closure = func .__closure__ ,
137
- )
138
- new_func = functools .update_wrapper (new_func , func )
139
- new_func .__kwdefaults__ = func .__kwdefaults__
140
- return new_func
0 commit comments