2
2
import copy
3
3
import functools
4
4
import subprocess
5
- from pathlib import Path
6
5
from typing import Iterable
7
6
from typing import Optional
7
+ from typing import Sequence
8
8
from typing import Union
9
9
10
10
from _pytask .config import hookimpl
14
14
from _pytask .nodes import PythonFunctionTask
15
15
from _pytask .parametrize import _copy_func
16
16
from pytask_stata .shared import convert_task_id_to_name_of_log_file
17
+ from pytask_stata .shared import get_node_from_dictionary
17
18
18
19
19
20
def stata (options : Optional [Union [str , Iterable [str ]]] = None ):
@@ -25,16 +26,15 @@ def stata(options: Optional[Union[str, Iterable[str]]] = None):
25
26
One or multiple command line options passed to Stata.
26
27
27
28
"""
28
- if options is None :
29
- options = []
30
- elif isinstance (options , str ):
31
- options = [options ]
29
+ options = _to_list (options ) if options is not None else []
30
+ options = [str (i ) for i in options ]
32
31
return options
33
32
34
33
35
- def run_stata_script (stata ):
34
+ def run_stata_script (stata , cwd ):
36
35
"""Run an R script."""
37
- subprocess .run (stata , check = True )
36
+ print ("Executing " + " " .join (stata ) + "." ) # noqa: T001
37
+ subprocess .run (stata , cwd = cwd , check = True )
38
38
39
39
40
40
@hookimpl
@@ -58,7 +58,7 @@ def pytask_collect_task(session, path, name, obj):
58
58
def pytask_collect_task_teardown (session , task ):
59
59
"""Perform some checks and prepare the task function."""
60
60
if get_specific_markers_from_task (task , "stata" ):
61
- source = _get_node_from_dictionary (
61
+ source = get_node_from_dictionary (
62
62
task .depends_on , session .config ["stata_source_key" ]
63
63
)
64
64
if not (isinstance (source , FilePathNode ) and source .value .suffix == ".do" ):
@@ -72,19 +72,13 @@ def pytask_collect_task_teardown(session, task):
72
72
merged_marks = _merge_all_markers (task )
73
73
args = stata (* merged_marks .args , ** merged_marks .kwargs )
74
74
options = _prepare_cmd_options (session , task , args )
75
- stata_function = functools .partial (stata_function , stata = options )
75
+ stata_function = functools .partial (
76
+ stata_function , stata = options , cwd = task .path .parent
77
+ )
76
78
77
79
task .function = stata_function
78
80
79
81
80
- def _get_node_from_dictionary (obj , key , fallback = 0 ):
81
- if isinstance (obj , Path ):
82
- pass
83
- elif isinstance (obj , dict ):
84
- obj = obj .get (key ) or obj .get (fallback )
85
- return obj
86
-
87
-
88
82
def _merge_all_markers (task ):
89
83
"""Combine all information from markers for the Stata function."""
90
84
stata_marks = get_specific_markers_from_task (task , "stata" )
@@ -101,15 +95,45 @@ def _prepare_cmd_options(session, task, args):
101
95
is unique and does not cause any errors when parallelizing the execution.
102
96
103
97
"""
104
- source = _get_node_from_dictionary (
98
+ source = get_node_from_dictionary (
105
99
task .depends_on , session .config ["stata_source_key" ]
106
100
)
107
- log_name = convert_task_id_to_name_of_log_file ( task . name )
108
- return [
101
+
102
+ cmd_options = [
109
103
session .config ["stata" ],
110
104
"-e" ,
111
105
"do" ,
112
- source .value .as_posix (),
106
+ source .path .as_posix (),
113
107
* args ,
114
- f"-{ log_name } " ,
115
108
]
109
+ if session .config ["platform" ] == "win32" :
110
+ log_name = convert_task_id_to_name_of_log_file (task .name )
111
+ cmd_options .append (f"-{ log_name } " )
112
+
113
+ return cmd_options
114
+
115
+
116
+ def _to_list (scalar_or_iter ):
117
+ """Convert scalars and iterables to list.
118
+
119
+ Parameters
120
+ ----------
121
+ scalar_or_iter : str or list
122
+
123
+ Returns
124
+ -------
125
+ list
126
+
127
+ Examples
128
+ --------
129
+ >>> _to_list("a")
130
+ ['a']
131
+ >>> _to_list(["b"])
132
+ ['b']
133
+
134
+ """
135
+ return (
136
+ [scalar_or_iter ]
137
+ if isinstance (scalar_or_iter , str ) or not isinstance (scalar_or_iter , Sequence )
138
+ else list (scalar_or_iter )
139
+ )
0 commit comments