22import copy
33import functools
44import subprocess
5- from pathlib import Path
65from typing import Iterable
76from typing import Optional
7+ from typing import Sequence
88from typing import Union
99
1010from _pytask .config import hookimpl
1414from _pytask .nodes import PythonFunctionTask
1515from _pytask .parametrize import _copy_func
1616from pytask_stata .shared import convert_task_id_to_name_of_log_file
17+ from pytask_stata .shared import get_node_from_dictionary
1718
1819
1920def stata (options : Optional [Union [str , Iterable [str ]]] = None ):
@@ -25,16 +26,15 @@ def stata(options: Optional[Union[str, Iterable[str]]] = None):
2526 One or multiple command line options passed to Stata.
2627
2728 """
28- if options is None :
29- options = []
30- elif isinstance (options , str ):
31- options = [options ]
29+ options = _to_list (options ) if options is not None else []
30+ options = [str (i ) for i in options ]
3231 return options
3332
3433
35- def run_stata_script (stata ):
34+ def run_stata_script (stata , cwd ):
3635 """Run an R script."""
37- subprocess .run (stata , check = True )
36+ print ("Executing " + " " .join (stata ) + "." ) # noqa: T001
37+ subprocess .run (stata , cwd = cwd , check = True )
3838
3939
4040@hookimpl
@@ -58,7 +58,7 @@ def pytask_collect_task(session, path, name, obj):
5858def pytask_collect_task_teardown (session , task ):
5959 """Perform some checks and prepare the task function."""
6060 if get_specific_markers_from_task (task , "stata" ):
61- source = _get_node_from_dictionary (
61+ source = get_node_from_dictionary (
6262 task .depends_on , session .config ["stata_source_key" ]
6363 )
6464 if not (isinstance (source , FilePathNode ) and source .value .suffix == ".do" ):
@@ -72,19 +72,13 @@ def pytask_collect_task_teardown(session, task):
7272 merged_marks = _merge_all_markers (task )
7373 args = stata (* merged_marks .args , ** merged_marks .kwargs )
7474 options = _prepare_cmd_options (session , task , args )
75- stata_function = functools .partial (stata_function , stata = options )
75+ stata_function = functools .partial (
76+ stata_function , stata = options , cwd = task .path .parent
77+ )
7678
7779 task .function = stata_function
7880
7981
80- def _get_node_from_dictionary (obj , key , fallback = 0 ):
81- if isinstance (obj , Path ):
82- pass
83- elif isinstance (obj , dict ):
84- obj = obj .get (key ) or obj .get (fallback )
85- return obj
86-
87-
8882def _merge_all_markers (task ):
8983 """Combine all information from markers for the Stata function."""
9084 stata_marks = get_specific_markers_from_task (task , "stata" )
@@ -101,15 +95,45 @@ def _prepare_cmd_options(session, task, args):
10195 is unique and does not cause any errors when parallelizing the execution.
10296
10397 """
104- source = _get_node_from_dictionary (
98+ source = get_node_from_dictionary (
10599 task .depends_on , session .config ["stata_source_key" ]
106100 )
107- log_name = convert_task_id_to_name_of_log_file ( task . name )
108- return [
101+
102+ cmd_options = [
109103 session .config ["stata" ],
110104 "-e" ,
111105 "do" ,
112- source .value .as_posix (),
106+ source .path .as_posix (),
113107 * args ,
114- f"-{ log_name } " ,
115108 ]
109+ if session .config ["platform" ] == "win32" :
110+ log_name = convert_task_id_to_name_of_log_file (task .name )
111+ cmd_options .append (f"-{ log_name } " )
112+
113+ return cmd_options
114+
115+
116+ def _to_list (scalar_or_iter ):
117+ """Convert scalars and iterables to list.
118+
119+ Parameters
120+ ----------
121+ scalar_or_iter : str or list
122+
123+ Returns
124+ -------
125+ list
126+
127+ Examples
128+ --------
129+ >>> _to_list("a")
130+ ['a']
131+ >>> _to_list(["b"])
132+ ['b']
133+
134+ """
135+ return (
136+ [scalar_or_iter ]
137+ if isinstance (scalar_or_iter , str ) or not isinstance (scalar_or_iter , Sequence )
138+ else list (scalar_or_iter )
139+ )
0 commit comments