diff --git a/docs/source/changes.md b/docs/source/changes.md index fcbcf708..8106610a 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -70,6 +70,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and - {pull}`485` adds missing steps to unconfigure pytask after the job is done, which caused flaky tests. - {pull}`486` adds default names to {class}`~pytask.PPathNode`. +- {pull}`487` implements task generators and provisional nodes. - {pull}`488` raises an error when an invalid value is used in a return annotation. - {pull}`489` and {pull}`491` simplifies parsing products and does not raise an error when a product annotation is used with the argument name `produces`. And allow diff --git a/docs/source/how_to_guides/index.md b/docs/source/how_to_guides/index.md index 0babca6a..8f0e9f47 100644 --- a/docs/source/how_to_guides/index.md +++ b/docs/source/how_to_guides/index.md @@ -19,6 +19,7 @@ capture_warnings how_to_influence_build_order hashing_inputs_of_tasks using_task_returns +provisional_nodes_and_task_generators writing_custom_nodes extending_pytask the_data_catalog diff --git a/docs/source/how_to_guides/provisional_nodes_and_task_generators.md b/docs/source/how_to_guides/provisional_nodes_and_task_generators.md new file mode 100644 index 00000000..d5530eee --- /dev/null +++ b/docs/source/how_to_guides/provisional_nodes_and_task_generators.md @@ -0,0 +1,94 @@ +# Provisional nodes and task generators + +pytask's execution model can usually be separated into three phases. + +1. Collection of tasks, dependencies, and products. +1. Building the DAG. +1. Executing the tasks. + +But, in some situations, pytask needs to be more flexible. + +Imagine you want to download a folder with files from an online storage. Before the task +is completed you do not know the total number of files or their filenames. How can you +still describe the files as products of the task? + +And how would you define another task that depends on these files? 
+ +The following sections will explain how you use pytask in these situations. + +## Producing provisional nodes + +As an example of the aforementioned scenario, let us write a task that downloads all +files without a file extension from the root folder of the pytask GitHub repository. The +files are downloaded to a folder called `downloads`. `downloads` is in the same folder +as the task module because it is a relative path. + +```{literalinclude} ../../../docs_src/how_to_guides/provisional_products.py +--- +emphasize-lines: 4, 22 +--- +``` + +Since the names of the files are not known when pytask is started, we need to use a +{class}`~pytask.DirectoryNode` to define the task's product. With a +{class}`~pytask.DirectoryNode` we can specify where pytask can find the files. The files +are described with a root path (default is the directory of the task module) and a glob +pattern (default is `*`). + +When we use the {class}`~pytask.DirectoryNode` as a product annotation, we get access to +the `root_dir` as a {class}`~pathlib.Path` object inside the function, which allows us +to store the files. + +```{note} +The {class}`~pytask.DirectoryNode` is a provisional node that implements +{class}`~pytask.PProvisionalNode`. A provisional node is not a {class}`~pytask.PNode`, +but when its {meth}`~pytask.PProvisionalNode.collect` method is called, it returns +actual nodes. A {class}`~pytask.DirectoryNode`, for example, returns +{class}`~pytask.PathNode`. +``` + +## Depending on provisional nodes + +In the next step, we want to define a task that consumes and merges all previously +downloaded files into one file. + +The difficulty here is how we can reference the downloaded files before they have been +downloaded. + +```{literalinclude} ../../../docs_src/how_to_guides/provisional_task.py +--- +emphasize-lines: 9 +--- +``` + +To reference the files that will be downloaded, we use the +{class}`~pytask.DirectoryNode` as a dependency. 
Before the task is executed, the list of +files in the folder defined by the root path and the pattern is automatically collected +and passed to the task. + +If we use a {class}`~pytask.DirectoryNode` with the same `root_dir` and `pattern` in +both tasks, pytask will automatically recognize that the second task depends on the +first. If that is not true, you might need to make this dependency more explicit by +using {func}`@task(after=...) <pytask.task>`, which is explained {ref}`here <after>`. + +## Task generators + +What if we wanted to process each downloaded file separately instead of dealing with +them in one task? + +For that, we have to write a task generator to define an unknown number of tasks for an +unknown number of downloaded files. + +A task generator is a task function in which we define more tasks, just as if we were +writing functions in a task module. + +The code snippet shows each task takes one of the downloaded files and copies its +content to a `.txt` file. + +```{literalinclude} ../../../docs_src/how_to_guides/provisional_task_generator.py +``` + +```{important} +The generated tasks need to be decorated with {func}`@task <pytask.task>` to be +collected. +``` diff --git a/docs/source/reference_guides/api.md b/docs/source/reference_guides/api.md index b70d8e88..22872173 100644 --- a/docs/source/reference_guides/api.md +++ b/docs/source/reference_guides/api.md @@ -190,6 +190,8 @@ Protocols define how tasks and nodes for dependencies and products have to be se :show-inheritance: .. autoprotocol:: pytask.PTaskWithPath :show-inheritance: +.. autoprotocol:: pytask.PProvisionalNode + :show-inheritance: ``` ## Nodes @@ -203,6 +205,8 @@ Nodes are the interface for different kinds of dependencies or products. :members: .. autoclass:: pytask.PythonNode :members: +.. autoclass:: pytask.DirectoryNode + :members: ``` To parse dependencies and products from nodes, use the following functions. 
diff --git a/docs/source/tutorials/skipping_tasks.md b/docs/source/tutorials/skipping_tasks.md index e223b1cd..1ad3b38c 100644 --- a/docs/source/tutorials/skipping_tasks.md +++ b/docs/source/tutorials/skipping_tasks.md @@ -13,18 +13,18 @@ skip tasks during development that take too much time to compute right now. ```{literalinclude} ../../../docs_src/tutorials/skipping_tasks_example_1.py ``` -Not only will this task be skipped, but all tasks that depend on +Not only will this task be skipped, but also all tasks depending on `time_intensive_product.pkl`. ## Conditional skipping In large projects, you may have many long-running tasks that you only want to execute on -a remote server but not when you are not working locally. +a remote server, but not when you are working locally. In this case, use the {func}`@pytask.mark.skipif <pytask.mark.skipif>` decorator, which requires a condition and a reason as arguments. -Place the condition variable in a different module than the task, so you can change it +Place the condition variable in a module different from the task so you can change it without causing a rerun of the task. 
```python diff --git a/docs_src/how_to_guides/provisional_products.py b/docs_src/how_to_guides/provisional_products.py new file mode 100644 index 00000000..269c327a --- /dev/null +++ b/docs_src/how_to_guides/provisional_products.py @@ -0,0 +1,33 @@ +from pathlib import Path + +import httpx +from pytask import DirectoryNode +from pytask import Product +from typing_extensions import Annotated + + +def get_files_without_file_extensions_from_repo() -> list[str]: + url = "https://api.github.com/repos/pytask-dev/pytask/git/trees/main" + response = httpx.get(url) + elements = response.json()["tree"] + return [ + e["path"] + for e in elements + if e["type"] == "blob" and Path(e["path"]).suffix == "" + ] + + +def task_download_files( + download_folder: Annotated[ + Path, DirectoryNode(root_dir=Path("downloads"), pattern="*"), Product + ], +) -> None: + """Download files.""" + # Contains names like CITATION or LICENSE. + files_to_download = get_files_without_file_extensions_from_repo() + + for file_ in files_to_download: + url = "raw.githubusercontent.com/pytask-dev/pytask/main" + response = httpx.get(url=f"{url}/{file_}", timeout=5) + content = response.text + download_folder.joinpath(file_).write_text(content) diff --git a/docs_src/how_to_guides/provisional_task.py b/docs_src/how_to_guides/provisional_task.py new file mode 100644 index 00000000..1f158eef --- /dev/null +++ b/docs_src/how_to_guides/provisional_task.py @@ -0,0 +1,14 @@ +from pathlib import Path + +from pytask import DirectoryNode +from typing_extensions import Annotated + + +def task_merge_files( + paths: Annotated[ + list[Path], DirectoryNode(root_dir=Path("downloads"), pattern="*") + ], +) -> Annotated[str, Path("all_text.txt")]: + """Merge files.""" + contents = [path.read_text() for path in paths] + return "\n".join(contents) diff --git a/docs_src/how_to_guides/provisional_task_generator.py b/docs_src/how_to_guides/provisional_task_generator.py new file mode 100644 index 00000000..435643a7 --- /dev/null 
+++ b/docs_src/how_to_guides/provisional_task_generator.py @@ -0,0 +1,21 @@ +from pathlib import Path + +from pytask import DirectoryNode +from pytask import task +from typing_extensions import Annotated + + +@task(is_generator=True) +def task_copy_files( + paths: Annotated[ + list[Path], DirectoryNode(root_dir=Path("downloads"), pattern="*") + ], +) -> None: + """Create tasks to copy each file to a ``.txt`` file.""" + for path in paths: + # The path of the copy will be CITATION.txt, for example. + path_to_copy = path.with_suffix(".txt") + + @task + def copy_file(path: Annotated[Path, path]) -> Annotated[str, path_to_copy]: + return path.read_text() diff --git a/src/_pytask/collect.py b/src/_pytask/collect.py index 804c06bf..1b1f9a2e 100644 --- a/src/_pytask/collect.py +++ b/src/_pytask/collect.py @@ -30,7 +30,9 @@ from _pytask.mark_utils import has_mark from _pytask.node_protocols import PNode from _pytask.node_protocols import PPathNode +from _pytask.node_protocols import PProvisionalNode from _pytask.node_protocols import PTask +from _pytask.nodes import DirectoryNode from _pytask.nodes import PathNode from _pytask.nodes import PythonNode from _pytask.nodes import Task @@ -299,6 +301,8 @@ def pytask_collect_task( raise ValueError(msg) path_nodes = Path.cwd() if path is None else path.parent + + # Collect dependencies and products. dependencies = parse_dependencies_from_task_function( session, path, name, path_nodes, obj ) @@ -309,6 +313,9 @@ def pytask_collect_task( markers = get_all_marks(obj) collection_id = obj.pytask_meta._id if hasattr(obj, "pytask_meta") else None after = obj.pytask_meta.after if hasattr(obj, "pytask_meta") else [] + is_generator = ( + obj.pytask_meta.is_generator if hasattr(obj, "pytask_meta") else False + ) # Get the underlying function to avoid having different states of the function, # e.g. due to pytask_meta, in different layers of the wrapping. 
@@ -321,7 +328,11 @@ def pytask_collect_task( depends_on=dependencies, produces=products, markers=markers, - attributes={"collection_id": collection_id, "after": after}, + attributes={ + "collection_id": collection_id, + "after": after, + "is_generator": is_generator, + }, ) return Task( base_name=name, @@ -330,41 +341,25 @@ def pytask_collect_task( depends_on=dependencies, produces=products, markers=markers, - attributes={"collection_id": collection_id, "after": after}, + attributes={ + "collection_id": collection_id, + "after": after, + "is_generator": is_generator, + }, ) if isinstance(obj, PTask): return obj return None -_TEMPLATE_ERROR: str = """\ -The provided path of the dependency/product is - -{} - -, but the path of the file on disk is - -{} - -Case-sensitive file systems would raise an error because the upper and lower case \ -format of the paths does not match. - -Please, align the names to ensure reproducibility on case-sensitive file systems \ -(often Linux or macOS) or disable this error with 'check_casing_of_paths = false' in \ -the pyproject.toml file. - -Hint: If parts of the path preceding your project directory are not properly \ -formatted, check whether you need to call `.resolve()` on `SRC`, `BLD` or other paths \ -created from the `__file__` attribute of a module. -""" - - _TEMPLATE_ERROR_DIRECTORY: str = """\ The path '{path}' points to a directory, although only files are allowed.""" @hookimpl(trylast=True) -def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> PNode: # noqa: C901, PLR0912 +def pytask_collect_node( # noqa: C901, PLR0912 + session: Session, path: Path, node_info: NodeInfo +) -> PNode | PProvisionalNode: """Collect a node of a task as a :class:`pytask.PNode`. Strings are assumed to be paths. 
This might be a strict assumption, but since this @@ -384,6 +379,21 @@ def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> PN """ node = node_info.value + if isinstance(node, DirectoryNode): + if node.root_dir is None: + node.root_dir = path + if ( + not node.name + or node.name == node.root_dir.joinpath(node.pattern).as_posix() + ): + short_root_dir = shorten_path( + node.root_dir, session.config["paths"] or (session.config["root"],) + ) + node.name = Path(short_root_dir, node.pattern).as_posix() + + if isinstance(node, PProvisionalNode): + return node + if isinstance(node, PythonNode): node.node_info = node_info if not node.name: @@ -418,9 +428,11 @@ def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> PN raise ValueError(_TEMPLATE_ERROR_DIRECTORY.format(path=node.path)) if isinstance(node, PNode): + if not node.name: + node.name = create_name_of_python_node(node_info) return node - if isinstance(node, UPath): + if isinstance(node, UPath): # pragma: no cover if not node.protocol: node = Path(node) else: @@ -459,6 +471,28 @@ def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> PN return PythonNode(value=node, name=node_name, node_info=node_info) +_TEMPLATE_ERROR: str = """\ +The provided path of the dependency/product is + +{} + +, but the path of the file on disk is + +{} + +Case-sensitive file systems would raise an error because the upper and lower case \ +format of the paths does not match. + +Please, align the names to ensure reproducibility on case-sensitive file systems \ +(often Linux or macOS) or disable this error with 'check_casing_of_paths = false' in \ +your pytask configuration file. + +Hint: If parts of the path preceding your project directory are not properly \ +formatted, check whether you need to call `.resolve()` on `SRC`, `BLD` or other paths \ +created from the `__file__` attribute of a module. 
+""" + + def _raise_error_if_casing_of_path_is_wrong( path: Path, check_casing_of_paths: bool ) -> None: diff --git a/src/_pytask/collect_utils.py b/src/_pytask/collect_utils.py index 04ac4bd4..fc09bda7 100644 --- a/src/_pytask/collect_utils.py +++ b/src/_pytask/collect_utils.py @@ -15,8 +15,10 @@ from _pytask.exceptions import NodeNotCollectedError from _pytask.models import NodeInfo from _pytask.node_protocols import PNode +from _pytask.node_protocols import PProvisionalNode from _pytask.nodes import PythonNode from _pytask.task_utils import parse_keyword_arguments_from_signature_defaults +from _pytask.tree_util import PyTree from _pytask.tree_util import tree_leaves from _pytask.tree_util import tree_map_with_path from _pytask.typing import ProductType @@ -34,6 +36,7 @@ __all__ = [ + "collect_dependency", "parse_dependencies_from_task_function", "parse_products_from_task_function", ] @@ -65,6 +68,7 @@ def parse_dependencies_from_task_function( kwargs.pop("produces", None) parameters_with_product_annot = _find_args_with_product_annotation(obj) + parameters_with_product_annot.append("return") parameters_with_node_annot = _find_args_with_node_annotation(obj) # Complete kwargs with node annotations, when no value is given by kwargs. 
@@ -79,25 +83,16 @@ def parse_dependencies_from_task_function( raise ValueError(msg) for parameter_name, value in kwargs.items(): - if ( - parameter_name in parameters_with_product_annot - or parameter_name == "return" - ): + if parameter_name in parameters_with_product_annot: continue - nodes = tree_map_with_path( - lambda p, x: _collect_dependency( - session, - node_path, - task_name, - NodeInfo( - arg_name=parameter_name, # noqa: B023 - path=p, - value=x, - task_path=task_path, - task_name=task_name, - ), - ), + nodes = _collect_nodes_and_provisional_nodes( + collect_dependency, + session, + node_path, + task_name, + task_path, + parameter_name, value, ) @@ -106,7 +101,10 @@ def parse_dependencies_from_task_function( are_all_nodes_python_nodes_without_hash = all( isinstance(x, PythonNode) and not x.hash for x in tree_leaves(nodes) ) - if not isinstance(nodes, PNode) and are_all_nodes_python_nodes_without_hash: + if ( + not isinstance(nodes, (PNode, PProvisionalNode)) + and are_all_nodes_python_nodes_without_hash + ): node_name = create_name_of_python_node( NodeInfo( arg_name=parameter_name, @@ -122,7 +120,9 @@ def parse_dependencies_from_task_function( return dependencies -def _find_args_with_node_annotation(func: Callable[..., Any]) -> dict[str, PNode]: +def _find_args_with_node_annotation( + func: Callable[..., Any], +) -> dict[str, PNode | PProvisionalNode]: """Find args with node annotations.""" annotations = get_annotations(func, eval_str=True) metas = { @@ -220,20 +220,13 @@ def parse_products_from_task_function( value = kwargs.get(parameter_name) or parameters_with_node_annot.get( parameter_name ) - - collected_products = tree_map_with_path( - lambda p, x: _collect_product( - session, - node_path, - task_name, - NodeInfo( - arg_name=parameter_name, # noqa: B023 - path=p, - value=x, - task_path=task_path, - task_name=task_name, - ), - ), + collected_products = _collect_nodes_and_provisional_nodes( + _collect_product, + session, + node_path, + task_name, + 
task_path, + parameter_name, value, ) out[parameter_name] = collected_products @@ -241,19 +234,13 @@ def parse_products_from_task_function( task_produces = obj.pytask_meta.produces if hasattr(obj, "pytask_meta") else None if task_produces: has_task_decorator = True - collected_products = tree_map_with_path( - lambda p, x: _collect_product( - session, - node_path, - task_name, - NodeInfo( - arg_name="return", - path=p, - value=x, - task_path=task_path, - task_name=task_name, - ), - ), + collected_products = _collect_nodes_and_provisional_nodes( + _collect_product, + session, + node_path, + task_name, + task_path, + "return", task_produces, ) out = {"return": collected_products} @@ -264,6 +251,32 @@ def parse_products_from_task_function( return out +def _collect_nodes_and_provisional_nodes( # noqa: PLR0913 + collection_func: Callable[..., Any], + session: Session, + node_path: Path, + task_name: str, + task_path: Path | None, + parameter_name: str, + value: Any, +) -> PyTree[PProvisionalNode | PNode]: + return tree_map_with_path( + lambda p, x: collection_func( + session, + node_path, + task_name, + NodeInfo( + arg_name=parameter_name, + path=p, + value=x, + task_path=task_path, + task_name=task_name, + ), + ), + value, + ) + + def _find_args_with_product_annotation(func: Callable[..., Any]) -> list[str]: """Find args with product annotations.""" annotations = get_annotations(func, eval_str=True) @@ -282,9 +295,9 @@ def _find_args_with_product_annotation(func: Callable[..., Any]) -> list[str]: return args_with_product_annot -def _collect_dependency( +def collect_dependency( session: Session, path: Path, name: str, node_info: NodeInfo -) -> PNode: +) -> PNode | PProvisionalNode: """Collect nodes for a task. Raises @@ -318,7 +331,7 @@ def _collect_product( path: Path, task_name: str, node_info: NodeInfo, -) -> PNode: +) -> PNode | PProvisionalNode: """Collect products for a task. Defining products with strings is only allowed when using the decorator. 
Parameter diff --git a/src/_pytask/console.py b/src/_pytask/console.py index 4bd1b04b..ecc0c384 100644 --- a/src/_pytask/console.py +++ b/src/_pytask/console.py @@ -28,6 +28,7 @@ from _pytask.node_protocols import PNode from _pytask.node_protocols import PPathNode +from _pytask.node_protocols import PProvisionalNode from _pytask.node_protocols import PTaskWithPath from _pytask.path import shorten_path @@ -139,7 +140,9 @@ def format_task_name(task: PTask, editor_url_scheme: str) -> Text: return Text(task.name, style=url_style) -def format_node_name(node: PNode, paths: Sequence[Path] = ()) -> Text: +def format_node_name( + node: PNode | PProvisionalNode, paths: Sequence[Path] = () +) -> Text: """Format the name of a node.""" if isinstance(node, PPathNode): if node.name != node.path.as_posix(): diff --git a/src/_pytask/dag.py b/src/_pytask/dag.py index 14d15401..d8c8b325 100644 --- a/src/_pytask/dag.py +++ b/src/_pytask/dag.py @@ -1,4 +1,4 @@ -"""Contains code related to resolving dependencies.""" +"""Contains code related to the DAG.""" from __future__ import annotations @@ -21,6 +21,7 @@ from _pytask.mark import select_by_after_keyword from _pytask.mark import select_tasks_by_marks_and_expressions from _pytask.node_protocols import PNode +from _pytask.node_protocols import PProvisionalNode from _pytask.node_protocols import PTask from _pytask.nodes import PythonNode from _pytask.reports import DagReport @@ -33,18 +34,13 @@ from _pytask.session import Session -__all__ = ["create_dag"] +__all__ = ["create_dag", "create_dag_from_session"] def create_dag(session: Session) -> nx.DiGraph: """Create a directed acyclic graph (DAG) for the workflow.""" try: - dag = _create_dag(tasks=session.tasks) - _check_if_dag_has_cycles(dag) - _check_if_tasks_have_the_same_products(dag, session.config["paths"]) - _modify_dag(session=session, dag=dag) - select_tasks_by_marks_and_expressions(session=session, dag=dag) - + dag = create_dag_from_session(session) except Exception: # noqa: 
BLE001 report = DagReport.from_exception(sys.exc_info()) _log_dag(report=report) @@ -54,10 +50,22 @@ def create_dag(session: Session) -> nx.DiGraph: return dag -def _create_dag(tasks: list[PTask]) -> nx.DiGraph: +def create_dag_from_session(session: Session) -> nx.DiGraph: + """Create a DAG from a session.""" + dag = _create_dag_from_tasks(tasks=session.tasks) + _check_if_dag_has_cycles(dag) + _check_if_tasks_have_the_same_products(dag, session.config["paths"]) + dag = _modify_dag(session=session, dag=dag) + select_tasks_by_marks_and_expressions(session=session, dag=dag) + return dag + + +def _create_dag_from_tasks(tasks: list[PTask]) -> nx.DiGraph: """Create the DAG from tasks, dependencies and products.""" - def _add_dependency(dag: nx.DiGraph, task: PTask, node: PNode) -> None: + def _add_dependency( + dag: nx.DiGraph, task: PTask, node: PNode | PProvisionalNode + ) -> None: """Add a dependency to the DAG.""" dag.add_node(node.signature, node=node) dag.add_edge(node.signature, task.signature) @@ -68,7 +76,9 @@ def _add_dependency(dag: nx.DiGraph, task: PTask, node: PNode) -> None: if isinstance(node, PythonNode) and isinstance(node.value, PythonNode): dag.add_edge(node.value.signature, node.signature) - def _add_product(dag: nx.DiGraph, task: PTask, node: PNode) -> None: + def _add_product( + dag: nx.DiGraph, task: PTask, node: PNode | PProvisionalNode + ) -> None: """Add a product to the DAG.""" dag.add_node(node.signature, node=node) dag.add_edge(task.signature, node.signature) @@ -93,7 +103,7 @@ def _add_product(dag: nx.DiGraph, task: PTask, node: PNode) -> None: return dag -def _modify_dag(session: Session, dag: nx.DiGraph) -> None: +def _modify_dag(session: Session, dag: nx.DiGraph) -> nx.DiGraph: """Create dependencies between tasks when using ``@task(after=...)``.""" temporary_id_to_task = { task.attributes["collection_id"]: task @@ -114,6 +124,7 @@ def _modify_dag(session: Session, dag: nx.DiGraph) -> None: for signature in signatures: for successor in 
dag.successors(signature): dag.add_edge(successor, task.signature) + return dag def _check_if_dag_has_cycles(dag: nx.DiGraph) -> None: @@ -144,7 +155,7 @@ def _format_cycles(dag: nx.DiGraph, cycles: list[tuple[str, ...]]) -> str: node = dag.nodes[x].get("task") or dag.nodes[x].get("node") if isinstance(node, PTask): short_name = format_task_name(node, editor_url_scheme="no_link").plain - elif isinstance(node, PNode): + elif isinstance(node, (PNode, PProvisionalNode)): short_name = node.name lines.extend((short_name, " " + ARROW_DOWN_ICON)) # Join while removing last arrow. diff --git a/src/_pytask/dag_command.py b/src/_pytask/dag_command.py index b8df2f26..e41d4f37 100644 --- a/src/_pytask/dag_command.py +++ b/src/_pytask/dag_command.py @@ -248,3 +248,5 @@ def _write_graph(dag: nx.DiGraph, path: Path, layout: str) -> None: path.parent.mkdir(exist_ok=True, parents=True) graph = nx.nx_agraph.to_agraph(dag) graph.draw(path, prog=layout) + console.print() + console.print(f"Written to {path}.") diff --git a/src/_pytask/data_catalog.py b/src/_pytask/data_catalog.py index f59c2040..615fc021 100644 --- a/src/_pytask/data_catalog.py +++ b/src/_pytask/data_catalog.py @@ -20,6 +20,7 @@ from _pytask.models import NodeInfo from _pytask.node_protocols import PNode from _pytask.node_protocols import PPathNode +from _pytask.node_protocols import PProvisionalNode from _pytask.nodes import PickleNode from _pytask.pluginmanager import storage from _pytask.session import Session @@ -59,7 +60,7 @@ class DataCatalog: """ default_node: type[PNode] = PickleNode - entries: dict[str, PNode] = field(factory=dict) + entries: dict[str, PNode | PProvisionalNode] = field(factory=dict) name: str = "default" path: Path | None = None _session_config: dict[str, Any] = field( @@ -84,13 +85,13 @@ def _initialize(self) -> None: node = pickle.loads(path.read_bytes()) # noqa: S301 self.entries[node.name] = node - def __getitem__(self, name: str) -> PNode: + def __getitem__(self, name: str) -> PNode | 
PProvisionalNode: """Allow to access entries with the squared brackets syntax.""" if name not in self.entries: self.add(name) return self.entries[name] - def add(self, name: str, node: PNode | None = None) -> None: + def add(self, name: str, node: PNode | PProvisionalNode | None = None) -> None: """Add an entry to the data catalog.""" assert isinstance(self.path, Path) @@ -109,7 +110,7 @@ def add(self, name: str, node: PNode | None = None) -> None: self.path.joinpath(f"{filename}-node.pkl").write_bytes( pickle.dumps(self.entries[name]) ) - elif isinstance(node, PNode): + elif isinstance(node, (PNode, PProvisionalNode)): self.entries[name] = node else: # Acquire the latest pluginmanager. diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index 310889fb..7ef1dd00 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -29,6 +29,7 @@ from _pytask.mark_utils import has_mark from _pytask.node_protocols import PNode from _pytask.node_protocols import PPathNode +from _pytask.node_protocols import PProvisionalNode from _pytask.node_protocols import PTask from _pytask.outcomes import Exit from _pytask.outcomes import SkippedUnchanged @@ -36,11 +37,13 @@ from _pytask.outcomes import WouldBeExecuted from _pytask.outcomes import count_outcomes from _pytask.pluginmanager import hookimpl +from _pytask.provisional_utils import collect_provisional_products from _pytask.reports import ExecutionReport from _pytask.traceback import remove_traceback_from_exc_info from _pytask.tree_util import tree_leaves from _pytask.tree_util import tree_map from _pytask.tree_util import tree_structure +from _pytask.typing import is_task_generator if TYPE_CHECKING: from _pytask.session import Session @@ -115,7 +118,7 @@ def pytask_execute_task_protocol(session: Session, task: PTask) -> ExecutionRepo @hookimpl(trylast=True) -def pytask_execute_task_setup(session: Session, task: PTask) -> None: +def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C901 
"""Set up the execution of a task. 1. Check whether all dependencies of a task are available. @@ -127,14 +130,25 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: dag = session.dag - needs_to_be_executed = session.config["force"] + # Task generators are always executed since their states are not updated, but we + # skip the checks as well. + needs_to_be_executed = session.config["force"] or is_task_generator(task) + if not needs_to_be_executed: predecessors = set(dag.predecessors(task.signature)) | {task.signature} for node_signature in node_and_neighbors(dag, task.signature): node = dag.nodes[node_signature].get("task") or dag.nodes[ node_signature ].get("node") + + # Skip provisional nodes that are products since they do not have a state. + if node_signature not in predecessors and isinstance( + node, PProvisionalNode + ): + continue + node_state = node.state() + if node_signature in predecessors and not node_state: msg = f"{task.name!r} requires missing node {node.name!r}." if IS_FILE_SYSTEM_CASE_SENSITIVE: @@ -150,6 +164,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: break if not needs_to_be_executed: + collect_provisional_products(session, task) raise SkippedUnchanged # Create directory for product if it does not exist. Maybe this should be a `setup` @@ -160,7 +175,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: node.path.parent.mkdir(parents=True, exist_ok=True) -def _safe_load(node: PNode, task: PTask, is_product: bool) -> Any: +def _safe_load(node: PNode | PProvisionalNode, task: PTask, is_product: bool) -> Any: try: return node.load(is_product=is_product) except Exception as e: # noqa: BLE001 @@ -189,6 +204,7 @@ def pytask_execute_task(session: Session, task: PTask) -> bool: if "return" in task.produces: structure_out = tree_structure(out) structure_return = tree_structure(task.produces["return"]) + # strict must be false when none is leaf. 
if not structure_return.is_prefix(structure_out, strict=False): msg = ( @@ -201,14 +217,19 @@ def pytask_execute_task(session: Session, task: PTask) -> bool: nodes = tree_leaves(task.produces["return"]) values = structure_return.flatten_up_to(out) for node, value in zip(nodes, values): - node.save(value) + if not isinstance(node, PProvisionalNode): + node.save(value) return True -@hookimpl +@hookimpl(trylast=True) def pytask_execute_task_teardown(session: Session, task: PTask) -> None: - """Check if :class:`_pytask.nodes.PathNode` are produced by a task.""" + """Check if nodes are produced by a task.""" + if is_task_generator(task): + return + + collect_provisional_products(session, task) missing_nodes = [node for node in tree_leaves(task.produces) if not node.state()] if missing_nodes: paths = session.config["paths"] diff --git a/src/_pytask/hookspecs.py b/src/_pytask/hookspecs.py index 347af6d6..02ce0721 100644 --- a/src/_pytask/hookspecs.py +++ b/src/_pytask/hookspecs.py @@ -20,6 +20,7 @@ from _pytask.models import NodeInfo from _pytask.node_protocols import PNode + from _pytask.node_protocols import PProvisionalNode from _pytask.node_protocols import PTask from _pytask.outcomes import CollectionOutcome from _pytask.outcomes import TaskOutcome @@ -195,7 +196,7 @@ def pytask_collect_task_teardown(session: Session, task: PTask) -> None: @hookspec(firstresult=True) def pytask_collect_node( session: Session, path: Path, node_info: NodeInfo -) -> PNode | None: +) -> PNode | PProvisionalNode | None: """Collect a node which is a dependency or a product of a task.""" diff --git a/src/_pytask/models.py b/src/_pytask/models.py index 24a0559b..afd28eda 100644 --- a/src/_pytask/models.py +++ b/src/_pytask/models.py @@ -33,6 +33,8 @@ class CollectionMetadata: id will be generated. See :doc:`this tutorial <../tutorials/repeating_tasks_with_different_inputs>` for more information. + is_generator + An indicator for whether a task generates other tasks or not. 
kwargs A dictionary containing keyword arguments which are passed to the task when it is executed. @@ -48,6 +50,7 @@ class CollectionMetadata: """ after: str | list[Callable[..., Any]] = field(factory=list) + is_generator: bool = False id_: str | None = None kwargs: dict[str, Any] = field(factory=dict) markers: list[Mark] = field(factory=list) @@ -59,6 +62,6 @@ class CollectionMetadata: class NodeInfo(NamedTuple): arg_name: str path: tuple[str | int, ...] - value: Any task_path: Path | None task_name: str + value: Any diff --git a/src/_pytask/node_protocols.py b/src/_pytask/node_protocols.py index 58a6fc50..6a1d8fc0 100644 --- a/src/_pytask/node_protocols.py +++ b/src/_pytask/node_protocols.py @@ -13,7 +13,7 @@ from _pytask.tree_util import PyTree -__all__ = ["PNode", "PPathNode", "PTask", "PTaskWithPath"] +__all__ = ["PNode", "PPathNode", "PProvisionalNode", "PTask", "PTaskWithPath"] @runtime_checkable @@ -67,8 +67,8 @@ class PTask(Protocol): """Protocol for nodes.""" name: str - depends_on: dict[str, PyTree[PNode]] - produces: dict[str, PyTree[PNode]] + depends_on: dict[str, PyTree[PNode | PProvisionalNode]] + produces: dict[str, PyTree[PNode | PProvisionalNode]] function: Callable[..., Any] markers: list[Mark] report_sections: list[tuple[str, str, str]] @@ -99,3 +99,42 @@ class PTaskWithPath(PTask, Protocol): """ path: Path + + +@runtime_checkable +class PProvisionalNode(Protocol): + """A protocol for provisional nodes. + + This type of node is provisional since it resolves to actual nodes, :class:`PNode`, + right before a task is executed as a dependency and after the task is executed as a + product. + + Provisional nodes are nodes that define what the actual nodes look like. They can be + useful when, for example, a task produces an unknown number of nodes because it + downloads some files. 
+ + """ + + name: str + + @property + def signature(self) -> str: + """Return the signature of the node.""" + + def load(self, is_product: bool = False) -> Any: # pragma: no cover + """Load a provisional node. + + A provisional node will never be loaded as a dependency since it would be + collected before. + + It is possible to load a provisional node as a product so that it can inject + basic information about it in the task. For example, + :meth:`pytask.DirectoryNode.load` injects the root directory. + + """ + if is_product: + ... + raise NotImplementedError + + def collect(self) -> list[Any]: + """Collect the objects that are defined by the provisional nodes.""" diff --git a/src/_pytask/nodes.py b/src/_pytask/nodes.py index 80697feb..e7674e0f 100644 --- a/src/_pytask/nodes.py +++ b/src/_pytask/nodes.py @@ -19,6 +19,7 @@ from _pytask._hashlib import hash_value from _pytask.node_protocols import PNode from _pytask.node_protocols import PPathNode +from _pytask.node_protocols import PProvisionalNode from _pytask.node_protocols import PTask from _pytask.node_protocols import PTaskWithPath from _pytask.path import hash_path @@ -31,7 +32,14 @@ from _pytask.tree_util import PyTree -__all__ = ["PathNode", "PickleNode", "PythonNode", "Task", "TaskWithoutPath"] +__all__ = [ + "DirectoryNode", + "PathNode", + "PickleNode", + "PythonNode", + "Task", + "TaskWithoutPath", +] @define(kw_only=True) @@ -63,8 +71,8 @@ class TaskWithoutPath(PTask): name: str function: Callable[..., Any] - depends_on: dict[str, PyTree[PNode]] = field(factory=dict) - produces: dict[str, PyTree[PNode]] = field(factory=dict) + depends_on: dict[str, PyTree[PNode | PProvisionalNode]] = field(factory=dict) + produces: dict[str, PyTree[PNode | PProvisionalNode]] = field(factory=dict) markers: list[Mark] = field(factory=list) report_sections: list[tuple[str, str, str]] = field(factory=list) attributes: dict[Any, Any] = field(factory=dict) @@ -117,8 +125,8 @@ class Task(PTaskWithPath): path: Path
function: Callable[..., Any] name: str = field(default="", init=False) - depends_on: dict[str, PyTree[PNode]] = field(factory=dict) - produces: dict[str, PyTree[PNode]] = field(factory=dict) + depends_on: dict[str, PyTree[PNode | PProvisionalNode]] = field(factory=dict) + produces: dict[str, PyTree[PNode | PProvisionalNode]] = field(factory=dict) markers: list[Mark] = field(factory=list) report_sections: list[tuple[str, str, str]] = field(factory=list) attributes: dict[Any, Any] = field(factory=dict) @@ -323,6 +331,45 @@ def save(self, value: Any) -> None: pickle.dump(value, f) +@define(kw_only=True) +class DirectoryNode(PProvisionalNode): + """The class for a provisional node that works with directories. + + Attributes + ---------- + name + The name of the node. + pattern + Patterns are the same as for :mod:`fnmatch`, with the addition of ``**`` which + means "this directory and all subdirectories, recursively". + root_dir + The pattern is interpreted relative to the path given by ``root_dir``. If + ``root_dir = None``, it is the directory where the path is defined. + + """ + + name: str = "" + pattern: str = "*" + root_dir: Path | None = None + + @property + def signature(self) -> str: + """The unique signature of the node.""" + raw_key = "".join(str(hash_value(arg)) for arg in (self.root_dir, self.pattern)) + return hashlib.sha256(raw_key.encode()).hexdigest() + + def load(self, is_product: bool = False) -> Path: + """Inject a path into the task when loaded as a product.""" + if is_product: + return self.root_dir # type: ignore[return-value] + msg = "'DirectoryNode' cannot be loaded as a dependency" # pragma: no cover + raise NotImplementedError(msg) # pragma: no cover + + def collect(self) -> list[Path]: + """Collect paths defined by the pattern.""" + return list(self.root_dir.glob(self.pattern)) # type: ignore[union-attr] + + def _get_state(path: Path) -> str | None: """Get state of a path. 
diff --git a/src/_pytask/persist.py b/src/_pytask/persist.py index ebbfbca5..7cb272f0 100644 --- a/src/_pytask/persist.py +++ b/src/_pytask/persist.py @@ -12,6 +12,7 @@ from _pytask.outcomes import Persisted from _pytask.outcomes import TaskOutcome from _pytask.pluginmanager import hookimpl +from _pytask.provisional_utils import collect_provisional_products if TYPE_CHECKING: from _pytask.node_protocols import PTask @@ -60,6 +61,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: ) ) if any_node_changed: + collect_provisional_products(session, task) raise Persisted diff --git a/src/_pytask/pluginmanager.py b/src/_pytask/pluginmanager.py index 3159915c..4674c28b 100644 --- a/src/_pytask/pluginmanager.py +++ b/src/_pytask/pluginmanager.py @@ -46,6 +46,7 @@ def pytask_add_hooks(pm: PluginManager) -> None: "_pytask.dag_command", "_pytask.database", "_pytask.debugging", + "_pytask.provisional", "_pytask.execute", "_pytask.live", "_pytask.logging", diff --git a/src/_pytask/provisional.py b/src/_pytask/provisional.py new file mode 100644 index 00000000..084aabd2 --- /dev/null +++ b/src/_pytask/provisional.py @@ -0,0 +1,127 @@ +"""Contains hook implementations for provisional nodes and task generators.""" + +from __future__ import annotations + +import inspect +import sys +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import Mapping + +from pytask import TaskOutcome + +from _pytask.config import hookimpl +from _pytask.exceptions import NodeLoadError +from _pytask.node_protocols import PNode +from _pytask.node_protocols import PProvisionalNode +from _pytask.node_protocols import PTask +from _pytask.node_protocols import PTaskWithPath +from _pytask.outcomes import CollectionOutcome +from _pytask.provisional_utils import TASKS_WITH_PROVISIONAL_NODES +from _pytask.provisional_utils import collect_provisional_nodes +from _pytask.provisional_utils import recreate_dag +from _pytask.reports import 
ExecutionReport +from _pytask.task_utils import COLLECTED_TASKS +from _pytask.task_utils import parse_collected_tasks_with_task_marker +from _pytask.tree_util import tree_map +from _pytask.tree_util import tree_map_with_path +from _pytask.typing import is_task_generator + +if TYPE_CHECKING: + from _pytask.session import Session + + +@hookimpl +def pytask_execute_task_setup(session: Session, task: PTask) -> None: + """Collect provisional nodes and parse them.""" + task.depends_on = tree_map_with_path( # type: ignore[assignment] + lambda p, x: collect_provisional_nodes(session, task, x, p), task.depends_on + ) + if task.signature in TASKS_WITH_PROVISIONAL_NODES: + recreate_dag(session, task) + + +def _safe_load(node: PNode | PProvisionalNode, task: PTask, is_product: bool) -> Any: + try: + return node.load(is_product=is_product) + except Exception as e: # noqa: BLE001 + msg = f"Exception while loading node {node.name!r} of task {task.name!r}" + raise NodeLoadError(msg) from e + + +@hookimpl +def pytask_execute_task(session: Session, task: PTask) -> None: + """Execute task generators and collect the tasks.""" + if is_task_generator(task): + kwargs = {} + for name, value in task.depends_on.items(): + kwargs[name] = tree_map(lambda x: _safe_load(x, task, False), value) + + parameters = inspect.signature(task.function).parameters + for name, value in task.produces.items(): + if name in parameters: + kwargs[name] = tree_map(lambda x: _safe_load(x, task, True), value) + + task.execute(**kwargs) + + # Parse tasks created with @task. 
+ name_to_function: Mapping[str, Callable[..., Any] | PTask] + if isinstance(task, PTaskWithPath) and task.path in COLLECTED_TASKS: + tasks = COLLECTED_TASKS.pop(task.path) + name_to_function = parse_collected_tasks_with_task_marker(tasks) + elif None in COLLECTED_TASKS: + tasks = COLLECTED_TASKS.pop(None) + name_to_function = parse_collected_tasks_with_task_marker(tasks) + else: + msg = "The task generator {task.name!r} did not create any tasks." + raise RuntimeError(msg) + + new_reports = [] + for name, function in name_to_function.items(): + report = session.hook.pytask_collect_task_protocol( + session=session, + reports=session.collection_reports, + path=task.path if isinstance(task, PTaskWithPath) else None, + name=name, + obj=function, + ) + new_reports.append(report) + + session.tasks.extend( + i.node + for i in new_reports + if i.outcome == CollectionOutcome.SUCCESS and isinstance(i.node, PTask) + ) + + try: + session.hook.pytask_collect_modify_tasks( + session=session, tasks=session.tasks + ) + except Exception: # noqa: BLE001 # pragma: no cover + report = ExecutionReport.from_task_and_exception( + task=task, exc_info=sys.exc_info() + ) + session.collection_reports.append(report) + + recreate_dag(session, task) + + +@hookimpl +def pytask_execute_task_process_report(report: ExecutionReport) -> bool | None: + """Prevent update of states for successful task generators. + + It also leads to task generators always being executed, but we have an additional + switch implemented in ``pytask_execute_task_setup``. 
+ + """ + task = report.task + if report.outcome == TaskOutcome.SUCCESS and is_task_generator(task): + return True + return None + + +@hookimpl +def pytask_unconfigure() -> None: + """Clear the global variable after execution.""" + TASKS_WITH_PROVISIONAL_NODES.clear() diff --git a/src/_pytask/provisional_utils.py b/src/_pytask/provisional_utils.py new file mode 100644 index 00000000..9a09b096 --- /dev/null +++ b/src/_pytask/provisional_utils.py @@ -0,0 +1,107 @@ +"""Contains utilities related to provisional nodes and task generators.""" + +from __future__ import annotations + +import sys +from pathlib import Path +from typing import TYPE_CHECKING +from typing import Any + +from _pytask.collect_utils import collect_dependency +from _pytask.dag import create_dag_from_session +from _pytask.dag_utils import TopologicalSorter +from _pytask.models import NodeInfo +from _pytask.node_protocols import PNode +from _pytask.node_protocols import PProvisionalNode +from _pytask.node_protocols import PTask +from _pytask.node_protocols import PTaskWithPath +from _pytask.nodes import Task +from _pytask.reports import ExecutionReport +from _pytask.tree_util import PyTree +from _pytask.tree_util import tree_map_with_path +from _pytask.typing import is_task_generator + +if TYPE_CHECKING: + from _pytask.session import Session + + +TASKS_WITH_PROVISIONAL_NODES = set() + + +def collect_provisional_nodes( + session: Session, task: PTask, node: Any, path: tuple[Any, ...] +) -> PyTree[PNode | PProvisionalNode]: + """Collect provisional nodes. + + 1. Call the :meth:`pytask.PProvisionalNode.collect` to receive the raw nodes. + 2. Collect the raw nodes as usual. + + """ + if not isinstance(node, PProvisionalNode): + return node + + # Add task to register to update the DAG after the task is executed. + TASKS_WITH_PROVISIONAL_NODES.add(task.signature) + + # Collect provisional nodes and receive raw nodes. + provisional_nodes = node.collect() + + # Collect raw nodes. 
+ node_path = task.path.parent if isinstance(task, PTaskWithPath) else Path.cwd() + task_name = task.base_name if isinstance(task, Task) else task.name + task_path = task.path if isinstance(task, PTaskWithPath) else None + arg_name, *rest_path = path + + return tree_map_with_path( + lambda p, x: collect_dependency( + session, + node_path, + task_name, + NodeInfo( + arg_name=arg_name, + path=(*rest_path, *p), + value=x, + task_path=task_path, + task_name=task_name, + ), + ), + provisional_nodes, + ) + + +def recreate_dag(session: Session, task: PTask) -> None: + """Recreate the DAG when provisional nodes are resolved. + + If the DAG resolution fails, the error is attached as an execution report since + there is no better mechanism yet to display the error. + + """ + try: + session.dag = create_dag_from_session(session) + session.scheduler = TopologicalSorter.from_dag_and_sorter( + session.dag, session.scheduler + ) + + except Exception: # noqa: BLE001 + report = ExecutionReport.from_task_and_exception(task, sys.exc_info()) + session.execution_reports.append(report) + session.should_stop = True + + +def collect_provisional_products(session: Session, task: PTask) -> None: + """Collect provisional products. + + Unfortunately, this function needs to be called when a task finishes successfully + (skipped unchanged, persisted, etc.). + + """ + if is_task_generator(task): + return + + # Replace provisional nodes with their actually resolved nodes.
+ task.produces = tree_map_with_path( # type: ignore[assignment] + lambda p, x: collect_provisional_nodes(session, task, x, p), task.produces + ) + + if task.signature in TASKS_WITH_PROVISIONAL_NODES: + recreate_dag(session, task) diff --git a/src/_pytask/reports.py b/src/_pytask/reports.py index 40da6cca..9c74335b 100644 --- a/src/_pytask/reports.py +++ b/src/_pytask/reports.py @@ -23,6 +23,7 @@ from rich.console import RenderResult from _pytask.node_protocols import PNode + from _pytask.node_protocols import PProvisionalNode from _pytask.node_protocols import PTask @@ -31,7 +32,7 @@ class CollectionReport: """A collection report for a task.""" outcome: CollectionOutcome - node: PTask | PNode | None = None + node: PTask | PNode | PProvisionalNode | None = None exc_info: OptionalExceptionInfo | None = None @classmethod @@ -39,7 +40,7 @@ def from_exception( cls: type[CollectionReport], outcome: CollectionOutcome, exc_info: OptionalExceptionInfo, - node: PTask | PNode | None = None, + node: PTask | PNode | PProvisionalNode | None = None, ) -> CollectionReport: return cls(outcome=outcome, node=node, exc_info=exc_info) diff --git a/src/_pytask/shared.py b/src/_pytask/shared.py index 3b62dc1a..2174b2a7 100644 --- a/src/_pytask/shared.py +++ b/src/_pytask/shared.py @@ -14,6 +14,7 @@ from _pytask.console import format_node_name from _pytask.console import format_task_name from _pytask.node_protocols import PNode +from _pytask.node_protocols import PProvisionalNode from _pytask.node_protocols import PTask if TYPE_CHECKING: @@ -83,10 +84,13 @@ def reduce_names_of_multiple_nodes( if isinstance(node, PTask): short_name = format_task_name(node, editor_url_scheme="no_link").plain - elif isinstance(node, PNode): + elif isinstance(node, (PNode, PProvisionalNode)): short_name = format_node_name(node, paths).plain else: - msg = f"Requires a 'PTask' or a 'PNode' and not {type(node)!r}." + msg = ( + "Requires a 'PTask', 'PNode', or 'PProvisionalNode' and not " + f"{type(node)!r}." 
+ ) raise TypeError(msg) short_names.append(short_name) diff --git a/src/_pytask/skipping.py b/src/_pytask/skipping.py index b28ff8d1..a7678154 100644 --- a/src/_pytask/skipping.py +++ b/src/_pytask/skipping.py @@ -14,6 +14,7 @@ from _pytask.outcomes import SkippedUnchanged from _pytask.outcomes import TaskOutcome from _pytask.pluginmanager import hookimpl +from _pytask.provisional_utils import collect_provisional_products if TYPE_CHECKING: from _pytask.node_protocols import PTask @@ -52,6 +53,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: task, "would_be_executed" ) if is_unchanged and not session.config["force"]: + collect_provisional_products(session, task) raise SkippedUnchanged is_skipped = has_mark(task, "skip") @@ -89,8 +91,9 @@ def pytask_execute_task_process_report( if report.exc_info: if isinstance(report.exc_info[1], SkippedUnchanged): report.outcome = TaskOutcome.SKIP_UNCHANGED + return True - elif isinstance(report.exc_info[1], Skipped): + if isinstance(report.exc_info[1], Skipped): report.outcome = TaskOutcome.SKIP for descending_task_name in descending_tasks(task.signature, session.dag): @@ -102,12 +105,10 @@ def pytask_execute_task_process_report( {"reason": f"Previous task {task.name!r} was skipped."}, ) ) + return True - elif isinstance(report.exc_info[1], SkippedAncestorFailed): + if isinstance(report.exc_info[1], SkippedAncestorFailed): report.outcome = TaskOutcome.SKIP_PREVIOUS_FAILED + return True - if report.exc_info and isinstance( - report.exc_info[1], (Skipped, SkippedUnchanged, SkippedAncestorFailed) - ): - return True return None diff --git a/src/_pytask/task.py b/src/_pytask/task.py index f1a0c213..902f6eb0 100644 --- a/src/_pytask/task.py +++ b/src/_pytask/task.py @@ -82,3 +82,8 @@ def _raise_error_when_task_functions_are_duplicated( f"a lambda expression like 'task(...)(lambda **x: func(**x))'.\n\n{flat_tree}" ) raise ValueError(msg) + + +@hookimpl +def pytask_unconfigure() -> None: + 
COLLECTED_TASKS.clear() diff --git a/src/_pytask/task_utils.py b/src/_pytask/task_utils.py index 9727968b..a9f556c4 100644 --- a/src/_pytask/task_utils.py +++ b/src/_pytask/task_utils.py @@ -40,10 +40,11 @@ """ -def task( +def task( # noqa: PLR0913 name: str | None = None, *, after: str | Callable[..., Any] | list[Callable[..., Any]] | None = None, + is_generator: bool = False, id: str | None = None, # noqa: A002 kwargs: dict[Any, Any] | None = None, produces: Any | None = None, @@ -64,6 +65,19 @@ def task( An expression or a task function or a list of task functions that need to be executed before this task can be executed. See :ref:`after` for more information. + is_generator + An indicator whether this task is a task generator. + id + An id for the task if it is part of a parametrization. Otherwise, an automatic + id will be generated. See + :doc:`this tutorial <../tutorials/repeating_tasks_with_different_inputs>` for + more information. + kwargs + A dictionary containing keyword arguments which are passed to the task when it + is executed. + produces + Definition of products to parse the function returns and store them. See + :doc:`this how-to guide <../how_to_guides/using_task_returns>` for more id An id for the task if it is part of a repetition. Otherwise, an automatic id will be generated. See :ref:`how-to-repeat-a-task-with-different-inputs-the-id` @@ -86,7 +100,8 @@ def task( from typing import Annotated from pytask import task - @task def create_text_file() -> Annotated[str, Path("file.txt")]: + @task() + def create_text_file() -> Annotated[str, Path("file.txt")]: return "Hello, World!" 
""" @@ -121,20 +136,23 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: parsed_after = _parse_after(after) if hasattr(unwrapped, "pytask_meta"): - unwrapped.pytask_meta.name = parsed_name + unwrapped.pytask_meta.after = parsed_after + unwrapped.pytask_meta.is_generator = is_generator + unwrapped.pytask_meta.id_ = id unwrapped.pytask_meta.kwargs = parsed_kwargs unwrapped.pytask_meta.markers.append(Mark("task", (), {})) - unwrapped.pytask_meta.id_ = id + unwrapped.pytask_meta.name = parsed_name unwrapped.pytask_meta.produces = produces unwrapped.pytask_meta.after = parsed_after else: unwrapped.pytask_meta = CollectionMetadata( - name=parsed_name, + after=parsed_after, + is_generator=is_generator, + id_=id, kwargs=parsed_kwargs, markers=[Mark("task", (), {})], - id_=id, + name=parsed_name, produces=produces, - after=parsed_after, ) # Store it in the global variable ``COLLECTED_TASKS`` to avoid garbage diff --git a/src/_pytask/typing.py b/src/_pytask/typing.py index 770f820b..6bc64862 100644 --- a/src/_pytask/typing.py +++ b/src/_pytask/typing.py @@ -10,6 +10,7 @@ from attrs import define if TYPE_CHECKING: + from pytask import PTask from typing_extensions import TypeAlias @@ -38,6 +39,11 @@ def is_task_function(obj: Any) -> bool: ) +def is_task_generator(task: PTask) -> bool: + """Check if a task is a generator.""" + return task.attributes.get("is_generator", False) + + class _NoDefault(Enum): """A singleton for no defaults. 
diff --git a/src/pytask/__init__.py b/src/pytask/__init__.py index 5e5bdb2a..b4dab4f0 100644 --- a/src/pytask/__init__.py +++ b/src/pytask/__init__.py @@ -43,8 +43,10 @@ from _pytask.models import NodeInfo from _pytask.node_protocols import PNode from _pytask.node_protocols import PPathNode +from _pytask.node_protocols import PProvisionalNode from _pytask.node_protocols import PTask from _pytask.node_protocols import PTaskWithPath +from _pytask.nodes import DirectoryNode from _pytask.nodes import PathNode from _pytask.nodes import PickleNode from _pytask.nodes import PythonNode @@ -94,6 +96,7 @@ "DagReport", "DataCatalog", "DatabaseSession", + "DirectoryNode", "EnumChoice", "ExecutionError", "ExecutionReport", @@ -107,6 +110,7 @@ "NodeNotFoundError", "PNode", "PPathNode", + "PProvisionalNode", "PTask", "PTaskWithPath", "PathNode", diff --git a/tests/conftest.py b/tests/conftest.py index b15c9aaa..cf1f65d8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import re import sys from contextlib import contextmanager @@ -30,7 +29,7 @@ def _remove_variable_info_from_output(data: str, path: Any) -> str: # noqa: ARG # Remove dynamic versions. 
index_root = next(i for i, line in enumerate(lines) if line.startswith("Root:")) - new_info_line = "".join(lines[1:index_root]) + new_info_line = " ".join(lines[1:index_root]) for platform in ("linux", "win32", "darwin"): new_info_line = new_info_line.replace(platform, "") pattern = re.compile(version.VERSION_PATTERN, flags=re.IGNORECASE | re.VERBOSE) @@ -112,7 +111,6 @@ def runner(): def pytest_collection_modifyitems(session, config, items) -> None: # noqa: ARG001 """Add markers to Jupyter notebook tests.""" - if sys.platform == "darwin" and "CI" in os.environ: # pragma: no cover - for item in items: - if isinstance(item, NotebookItem): - item.add_marker(pytest.mark.xfail(reason="Fails regularly on MacOS")) + for item in items: + if isinstance(item, NotebookItem): + item.add_marker(pytest.mark.xfail(reason="The tests are flaky.")) diff --git a/tests/test_cache.py b/tests/test_cache.py index 218120f3..1bbc9600 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -2,10 +2,12 @@ import inspect +import pytest from _pytask.cache import Cache from _pytask.cache import _make_memoize_key +@pytest.mark.unit() def test_cache(): cache = Cache() @@ -29,6 +31,7 @@ def func(a, b): assert func.cache.cache_info.misses == 1 +@pytest.mark.unit() def test_cache_add(): cache = Cache() @@ -52,6 +55,7 @@ def func(a): assert cache.cache_info.misses == 1 +@pytest.mark.unit() def test_make_memoize_key(): def func(a, b): # pragma: no cover return a + b diff --git a/tests/test_capture.py b/tests/test_capture.py index 09a6ad1f..f947c783 100644 --- a/tests/test_capture.py +++ b/tests/test_capture.py @@ -112,6 +112,7 @@ def task_show_capture(): raise NotImplementedError +@pytest.mark.end_to_end() @pytest.mark.xfail( sys.platform == "win32", reason="from pytask ... 
cannot be found", diff --git a/tests/test_collect.py b/tests/test_collect.py index a6b129bb..93591998 100644 --- a/tests/test_collect.py +++ b/tests/test_collect.py @@ -425,6 +425,7 @@ def task_example(path: Annotated[Path, Path("file.txt"), Product]) -> None: ... assert "is defined twice" in result.output +@pytest.mark.end_to_end() @pytest.mark.parametrize( "node", [ @@ -446,6 +447,7 @@ def task_example(path = {node}): ... assert all(i in result.output for i in ("only", "files", "are", "allowed")) +@pytest.mark.end_to_end() @pytest.mark.parametrize( "node", [ @@ -469,6 +471,7 @@ def task_example(path: Annotated[Any, Product] = {node}): ... assert all(i in result.output for i in ("only", "files", "are", "allowed")) +@pytest.mark.end_to_end() @pytest.mark.parametrize( "node", [ @@ -494,6 +497,7 @@ def task_example() -> Annotated[str, {node}]: assert session.tasks[0].produces["return"].name == tmp_path.name + "/file.txt" +@pytest.mark.end_to_end() def test_error_when_return_annotation_cannot_be_parsed(runner, tmp_path): source = """ from typing_extensions import Annotated diff --git a/tests/test_collect_command.py b/tests/test_collect_command.py index e1400013..1112c3da 100644 --- a/tests/test_collect_command.py +++ b/tests/test_collect_command.py @@ -629,3 +629,85 @@ def task_example() -> Annotated[Dict[str, str], nodes]: assert result.exit_code == ExitCode.OK if sys.platform != "win32": assert result.output == snapshot_cli() + + +@pytest.mark.end_to_end() +@pytest.mark.parametrize( + "node_def", + [ + "paths: Annotated[List[Path], DirectoryNode(pattern='*.txt'), Product])", + "produces=DirectoryNode(pattern='*.txt'))", + ") -> Annotated[None, DirectoryNode(pattern='*.txt')]", + ], +) +def test_collect_task_with_provisional_path_node_as_product(runner, tmp_path, node_def): + source = f""" + from pytask import DirectoryNode, Product + from typing_extensions import Annotated, List + from pathlib import Path + + def task_example({node_def}: ... 
+ """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + # Without nodes. + result = runner.invoke(cli, ["collect", tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + captured = result.output.replace("\n", "").replace(" ", "") + assert "" in captured + assert "" in captured + + # With nodes. + result = runner.invoke(cli, ["collect", tmp_path.as_posix(), "--nodes"]) + assert result.exit_code == ExitCode.OK + captured = result.output.replace("\n", "").replace(" ", "") + assert "" in captured + assert "" in captured + assert "" in captured + + +@pytest.mark.end_to_end() +def test_collect_task_with_provisional_dependencies(runner, tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import DirectoryNode + from pathlib import Path + + def task_example( + paths = DirectoryNode(pattern="[ab].txt") + ) -> Annotated[str, Path("merged.txt")]: + path_dict = {path.stem: path for path in paths} + return path_dict["a"].read_text() + path_dict["b"].read_text() + """ + tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert "[ab].txt" in result.output + + +@pytest.mark.end_to_end() +def test_collect_custom_node_receives_default_name(runner, tmp_path): + source = """ + from typing_extensions import Annotated + + class CustomNode: + name: str = "" + + def state(self): return None + def signature(self): return "signature" + def load(self, is_product): ... + def save(self, value): ... + + def task_example() -> Annotated[None, CustomNode()]: ... 
+ """ + tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) + result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + output = result.output.replace(" ", "").replace("\n", "") + assert "task_example::return" in output diff --git a/tests/test_config.py b/tests/test_config.py index 3874e1ac..84af88b1 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -65,6 +65,7 @@ def test_passing_paths_via_configuration_file(tmp_path, file_or_folder): assert len(session.tasks) == 1 +@pytest.mark.end_to_end() def test_not_existing_path_in_config(runner, tmp_path): config = """ [tool.pytask.ini_options] @@ -76,6 +77,7 @@ def test_not_existing_path_in_config(runner, tmp_path): assert result.exit_code == ExitCode.CONFIGURATION_FAILED +@pytest.mark.end_to_end() def test_paths_are_relative_to_configuration_file_cli(tmp_path): tmp_path.joinpath("src").mkdir() tmp_path.joinpath("tasks").mkdir() @@ -96,6 +98,7 @@ def test_paths_are_relative_to_configuration_file_cli(tmp_path): assert "1 Succeeded" in result.stdout.decode() +@pytest.mark.end_to_end() @pytest.mark.skipif( sys.platform == "win32" and os.environ.get("CI") == "true", reason="Windows does not pick up the right Python interpreter.", diff --git a/tests/test_dag.py b/tests/test_dag.py index 0d4d711d..80c2bab5 100644 --- a/tests/test_dag.py +++ b/tests/test_dag.py @@ -5,7 +5,7 @@ from pathlib import Path import pytest -from _pytask.dag import _create_dag +from _pytask.dag import _create_dag_from_tasks from pytask import ExitCode from pytask import PathNode from pytask import Task @@ -26,7 +26,7 @@ def test_create_dag(): 1: PathNode.from_path(root / "node_2"), }, ) - dag = _create_dag(tasks=[task]) + dag = _create_dag_from_tasks(tasks=[task]) for signature in ( "90bb899a1b60da28ff70352cfb9f34a8bed485597c7f40eed9bd4c6449147525", @@ -97,6 +97,7 @@ def task_example(produces = Path("file.txt")): assert result.exit_code == ExitCode.OK 
+@pytest.mark.end_to_end() def test_python_nodes_are_unique(tmp_path): tmp_path.joinpath("a").mkdir() tmp_path.joinpath("a", "task_example.py").write_text("def task_example(a=1): pass") diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index b70acf35..d6050cfe 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -206,3 +206,31 @@ def test_adding_a_python_node(): data_catalog = DataCatalog() data_catalog.add("node", PythonNode(name="node", value=1)) assert isinstance(data_catalog["node"], PythonNode) + + +@pytest.mark.end_to_end() +def test_use_data_catalog_with_provisional_node(runner, tmp_path): + source = """ + from pathlib import Path + from typing_extensions import Annotated, List + + from pytask import DataCatalog + from pytask import DirectoryNode + + # Generate input data + data_catalog = DataCatalog() + data_catalog.add("directory", DirectoryNode(pattern="*.txt")) + + def task_add_content( + paths: Annotated[List[Path], data_catalog["directory"]] + ) -> Annotated[str, Path("output.txt")]: + name_to_path = {path.stem: path for path in paths} + return name_to_path["a"].read_text() + name_to_path["b"].read_text() + """ + tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) + tmp_path.joinpath("a.txt").write_text("Hello, ") + tmp_path.joinpath("b.txt").write_text("World!") + + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert tmp_path.joinpath("output.txt").read_text() == "Hello, World!" 
diff --git a/tests/test_execute.py b/tests/test_execute.py index 0bfb2613..8459914a 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -702,6 +702,7 @@ def task_example( assert "_pytask/execute.py" not in result.output +@pytest.mark.end_to_end() def test_hashing_works(tmp_path): """Use subprocess or otherwise the cache is filled from other tests.""" source = """ @@ -726,6 +727,7 @@ def task_example() -> Annotated[str, Path("file.txt")]: assert hashes == hashes_ +@pytest.mark.end_to_end() def test_python_node_as_product_with_product_annotation(runner, tmp_path): source = """ from typing_extensions import Annotated @@ -746,6 +748,7 @@ def task_write_file(text: Annotated[str, node]) -> Annotated[str, Path("file.txt assert tmp_path.joinpath("file.txt").read_text() == "Hello, World!" +@pytest.mark.end_to_end() def test_pickle_node_as_product_with_product_annotation(runner, tmp_path): source = """ from typing_extensions import Annotated @@ -822,6 +825,7 @@ def task_d(path=Path("../bld/in.txt"), produces=Path("out.txt")): assert "bld/in.txt" in result.output +@pytest.mark.end_to_end() def test_error_when_node_state_throws_error(runner, tmp_path): source = """ from pytask import PythonNode @@ -836,6 +840,7 @@ def task_example(a = PythonNode(value={"a": 1}, hash=True)): assert "TypeError: unhashable type: 'dict'" in result.output +@pytest.mark.end_to_end() def test_task_is_not_reexecuted(runner, tmp_path): source = """ from typing_extensions import Annotated @@ -860,6 +865,7 @@ def task_second(path = Path("out.txt")) -> Annotated[str, Path("copy.txt")]: assert "1 Skipped because unchanged" in result.output +@pytest.mark.end_to_end() def test_use_functional_interface_with_task(tmp_path): def func(path): path.touch() @@ -878,6 +884,7 @@ def func(path): assert session.exit_code == ExitCode.OK +@pytest.mark.end_to_end() def test_collect_task(runner, tmp_path): source = """ from pytask import Task, PathNode @@ -898,6 +905,7 @@ def func(path): path.touch() assert 
tmp_path.joinpath("out.txt").exists() +@pytest.mark.end_to_end() def test_collect_task_without_path(runner, tmp_path): source = """ from pytask import TaskWithoutPath, PathNode @@ -918,19 +926,221 @@ def func(path): path.touch() assert tmp_path.joinpath("out.txt").exists() -def test_with_http_path(runner, tmp_path): +@pytest.mark.end_to_end() +def test_task_that_produces_provisional_path_node(tmp_path): source = """ - from upath import UPath from typing_extensions import Annotated + from pytask import DirectoryNode, Product + from pathlib import Path - url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data" + def task_example( + root_path: Annotated[Path, DirectoryNode(pattern="*.txt"), Product] + ): + root_path.joinpath("a.txt").write_text("Hello, ") + root_path.joinpath("b.txt").write_text("World!") + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) - def task_example(path = UPath(url)) -> Annotated[str, UPath("data.txt")]: - return path.read_text() + session = build(paths=tmp_path) + + assert session.exit_code == ExitCode.OK + assert len(session.tasks) == 1 + assert len(session.tasks[0].produces["root_path"]) == 2 + + # Re-execution does skip the task.
+ session = build(paths=tmp_path) + assert session.execution_reports[0].outcome == TaskOutcome.SKIP_UNCHANGED + + +@pytest.mark.end_to_end() +def test_task_that_depends_on_relative_provisional_path_node(tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import DirectoryNode + from pathlib import Path + + def task_example( + paths = DirectoryNode(pattern="[ab].txt") + ) -> Annotated[str, Path("merged.txt")]: + path_dict = {path.stem: path for path in paths} + return path_dict["a"].read_text() + path_dict["b"].read_text() """ - tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + tmp_path.joinpath("a.txt").write_text("Hello, ") + tmp_path.joinpath("b.txt").write_text("World!") + + session = build(paths=tmp_path) + + assert session.exit_code == ExitCode.OK + assert len(session.tasks) == 1 + assert len(session.tasks[0].depends_on["paths"]) == 2 + + +@pytest.mark.end_to_end() +def test_task_that_depends_on_provisional_path_node_with_root_dir(tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import DirectoryNode + from pathlib import Path + + root_dir = Path(__file__).parent / "subfolder" + + def task_example( + paths = DirectoryNode(root_dir=root_dir, pattern="[ab].txt") + ) -> Annotated[str, Path(__file__).parent.joinpath("merged.txt")]: + path_dict = {path.stem: path for path in paths} + return path_dict["a"].read_text() + path_dict["b"].read_text() + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + tmp_path.joinpath("subfolder").mkdir() + tmp_path.joinpath("subfolder", "a.txt").write_text("Hello, ") + tmp_path.joinpath("subfolder", "b.txt").write_text("World!") + + session = build(paths=tmp_path) + + assert session.exit_code == ExitCode.OK + assert len(session.tasks) == 1 + assert len(session.tasks[0].depends_on["paths"]) == 2 + + +@pytest.mark.end_to_end() +def 
test_task_that_depends_on_provisional_task(runner, tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import DirectoryNode, task + from pathlib import Path + + def task_produces() -> Annotated[None, DirectoryNode(pattern="[ab].txt")]: + path = Path(__file__).parent + path.joinpath("a.txt").write_text("Hello, ") + path.joinpath("b.txt").write_text("World!") + + @task(after=task_produces) + def task_depends( + paths = DirectoryNode(pattern="[ab].txt") + ) -> Annotated[str, Path(__file__).parent.joinpath("merged.txt")]: + path_dict = {path.stem: path for path in paths} + return path_dict["a"].read_text() + path_dict["b"].read_text() + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) result = runner.invoke(cli, [tmp_path.as_posix()]) - print(result.output) # noqa: T201 assert result.exit_code == ExitCode.OK - assert tmp_path.joinpath("data.txt").exists() + assert "2 Collected tasks" in result.output + assert "2 Succeeded" in result.output + + +@pytest.mark.end_to_end() +def test_gracefully_fail_when_dag_raises_error(runner, tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import DirectoryNode, task + from pathlib import Path + + def task_produces() -> Annotated[None, DirectoryNode(pattern="*.txt")]: + path = Path(__file__).parent + path.joinpath("a.txt").write_text("Hello, ") + path.joinpath("b.txt").write_text("World!") + + @task(after=task_produces) + def task_depends( + paths = DirectoryNode(pattern="[ab].txt") + ) -> Annotated[str, Path(__file__).parent.joinpath("merged.txt")]: + path_dict = {path.stem: path for path in paths} + return path_dict["a"].read_text() + path_dict["b"].read_text() + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.FAILED + assert 
"There are some tasks which produce" in result.output + + +@pytest.mark.end_to_end() +def test_provisional_task_generation(runner, tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import DirectoryNode, task + from pathlib import Path + + def task_produces() -> Annotated[None, DirectoryNode(pattern="[ab].txt")]: + path = Path(__file__).parent + path.joinpath("a.txt").write_text("Hello, ") + path.joinpath("b.txt").write_text("World!") + + @task(after=task_produces, is_generator=True) + def task_depends( + paths = DirectoryNode(pattern="[ab].txt") + ): + for path in paths: + + @task + def task_copy( + path: Path = path + ) -> Annotated[str, path.with_name(path.stem + "-copy.txt")]: + return path.read_text() + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert "4 Collected tasks" in result.output + assert "4 Succeeded" in result.output + assert tmp_path.joinpath("a-copy.txt").exists() + assert tmp_path.joinpath("b-copy.txt").exists() + + +@pytest.mark.end_to_end() +def test_gracefully_fail_when_task_generator_raises_error(runner, tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import DirectoryNode, task, Product + from pathlib import Path + + @task(is_generator=True) + def task_example( + root_dir: Annotated[Path, DirectoryNode(pattern="[a].txt"), Product] + ) -> ...: + raise Exception + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.FAILED + assert "1 Collected task" in result.output + assert "1 Failed" in result.output + + +@pytest.mark.end_to_end() +def test_use_provisional_node_as_product_in_generator_without_rerun(runner, tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import DirectoryNode, task, Product + from 
pathlib import Path + + @task(is_generator=True) + def task_example( + root_dir: Annotated[Path, DirectoryNode(pattern="[ab].txt"), Product] + ) -> ...: + for path in (root_dir / "a.txt", root_dir / "b.txt"): + + @task + def create_file() -> Annotated[Path, path]: + return "content" + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert "3 Collected task" in result.output + assert "3 Succeeded" in result.output + + # No rerun. + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert "3 Collected task" in result.output + assert "1 Succeeded" in result.output + assert "2 Skipped because unchanged" in result.output diff --git a/tests/test_hashlib.py b/tests/test_hashlib.py index a25c7949..6fec4f1e 100644 --- a/tests/test_hashlib.py +++ b/tests/test_hashlib.py @@ -6,6 +6,7 @@ from _pytask._hashlib import hash_value +@pytest.mark.unit() @pytest.mark.parametrize( ("value", "expected"), [ diff --git a/tests/test_jupyter/test_task_generator.ipynb b/tests/test_jupyter/test_task_generator.ipynb new file mode 100644 index 00000000..2ef6aa61 --- /dev/null +++ b/tests/test_jupyter/test_task_generator.ipynb @@ -0,0 +1,79 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "\n", + "from pathlib import Path\n", + "\n", + "from typing_extensions import Annotated\n", + "\n", + "import pytask\n", + "from pytask import DirectoryNode, ExitCode, task" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "def task_create_files() -> Annotated[None, DirectoryNode(pattern=\"[ab].txt\")]:\n", + " path = Path()\n", + " path.joinpath(\"a.txt\").write_text(\"Hello, \")\n", + " 
path.joinpath(\"b.txt\").write_text(\"World!\")\n", + "\n", + "\n", + "@task(after=task_create_files, is_generator=True)\n", + "def task_generator_copy_files(\n", + " paths: Annotated[list[Path], DirectoryNode(pattern=\"[ab].txt\")]\n", + "):\n", + " for path in paths:\n", + "\n", + " @task\n", + " def task_copy(\n", + " path: Path = path,\n", + " ) -> Annotated[str, path.with_name(path.stem + \"-copy.txt\")]:\n", + " return path.read_text()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "session = pytask.build(tasks=[task_create_files, task_generator_copy_files])\n", + "assert session.exit_code == ExitCode.OK" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/test_mark_structures.py b/tests/test_mark_structures.py index 30540902..f0b554f9 100644 --- a/tests/test_mark_structures.py +++ b/tests/test_mark_structures.py @@ -4,6 +4,7 @@ import pytest +@pytest.mark.unit() @pytest.mark.parametrize( ("lhs", "rhs", "expected"), [ @@ -17,6 +18,7 @@ def test__eq__(lhs, rhs, expected) -> None: assert (lhs == rhs) == expected +@pytest.mark.unit() @pytest.mark.filterwarnings("ignore:Unknown pytask\\.mark\\.foo") def test_aliases() -> None: md = pytask.mark.foo(1, "2", three=3) diff --git a/tests/test_nodes.py b/tests/test_nodes.py index b396e326..48bda1a6 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -31,6 +31,7 @@ def test_hash_of_python_node(value, hash_, expected): assert state == expected +@pytest.mark.unit() @pytest.mark.parametrize( ("node", "expected"), [ @@ -110,6 +111,7 @@ 
def test_hash_of_pickle_node(tmp_path, value, exists, expected): assert state is expected +@pytest.mark.unit() @pytest.mark.parametrize( ("node", "protocol", "expected"), [ diff --git a/tests/test_skipping.py b/tests/test_skipping.py index 6f3a3298..81e20c69 100644 --- a/tests/test_skipping.py +++ b/tests/test_skipping.py @@ -260,6 +260,7 @@ def test_pytask_execute_task_setup(marker_name, force, expectation): pytask_execute_task_setup(session=session, task=task) +@pytest.mark.end_to_end() def test_skip_has_precendence_over_ancestor_failed(runner, tmp_path): source = """ from pathlib import Path @@ -276,6 +277,7 @@ def task_example_2(path=Path("file.txt")): ... assert "1 Skipped" in result.output +@pytest.mark.end_to_end() def test_skipif_has_precendence_over_ancestor_failed(runner, tmp_path): source = """ from pathlib import Path diff --git a/tests/test_task.py b/tests/test_task.py index bd82d41c..71c5507b 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -571,6 +571,7 @@ def task_example(): assert tmp_path.joinpath("file2.txt").read_text() == "World!" 
+@pytest.mark.end_to_end() def test_error_when_function_is_defined_outside_loop_body(runner, tmp_path): source = """ from pathlib import Path @@ -590,6 +591,7 @@ def func(path: Annotated[Path, Product]): assert "id=None" in result.output +@pytest.mark.end_to_end() def test_error_when_function_is_defined_outside_loop_body_with_id(runner, tmp_path): source = """ from pathlib import Path @@ -610,6 +612,7 @@ def func(path: Annotated[Path, Product]): assert "id=b.txt" in result.output +@pytest.mark.end_to_end() def test_task_will_be_executed_after_another_one_with_string(runner, tmp_path): source = """ from pytask import task @@ -689,6 +692,7 @@ def task_second(): assert result.returncode == ExitCode.OK +@pytest.mark.end_to_end() def test_raise_error_for_wrong_after_expression(runner, tmp_path): source = """ from pytask import task @@ -706,6 +710,7 @@ def task_example() -> Annotated[str, Path("out.txt")]: assert "Wrong expression passed to 'after'" in result.output +@pytest.mark.end_to_end() def test_raise_error_with_builtin_function_as_task(runner, tmp_path): source = """ from pytask import task @@ -723,6 +728,7 @@ def test_raise_error_with_builtin_function_as_task(runner, tmp_path): assert "Builtin functions cannot be wrapped" in result.output +@pytest.mark.end_to_end() def test_task_function_in_another_module(runner, tmp_path): source = """ def func(): diff --git a/tests/test_task_utils.py b/tests/test_task_utils.py index 1eac0218..37e3cc44 100644 --- a/tests/test_task_utils.py +++ b/tests/test_task_utils.py @@ -2,13 +2,17 @@ from contextlib import ExitStack as does_not_raise # noqa: N813 from functools import partial +from pathlib import Path from typing import NamedTuple import pytest +from _pytask.task_utils import COLLECTED_TASKS from _pytask.task_utils import _arg_value_to_id_component from _pytask.task_utils import _parse_name from _pytask.task_utils import _parse_task_kwargs from attrs import define +from pytask import Mark +from pytask import task 
@pytest.mark.unit() @@ -60,6 +64,23 @@ def test_parse_task_kwargs(kwargs, expectation, expected): assert result == expected +@pytest.mark.integration() +def test_default_values_of_pytask_meta(): + @task() + def task_example(): ... + + assert task_example.pytask_meta.after == [] + assert not task_example.pytask_meta.is_generator + assert task_example.pytask_meta.id_ is None + assert task_example.pytask_meta.kwargs == {} + assert task_example.pytask_meta.markers == [Mark("task", (), {})] + assert task_example.pytask_meta.name == "task_example" + assert task_example.pytask_meta.produces is None + + # Remove collected task. + COLLECTED_TASKS.pop(Path(__file__)) + + def task_func(x): # noqa: ARG001 # pragma: no cover pass diff --git a/tests/test_traceback.py b/tests/test_traceback.py index d636f770..6348484e 100644 --- a/tests/test_traceback.py +++ b/tests/test_traceback.py @@ -45,6 +45,7 @@ def helper(): assert ("This variable should not be shown." in result.output) is not is_hidden +@pytest.mark.unit() def test_render_traceback_with_string_traceback(): traceback = Traceback((Exception, Exception("Help"), "String traceback.")) rendered = render_to_string(traceback, console)