diff --git a/latch/resources/map_tasks.py b/latch/resources/map_tasks.py index 53a8d52d..e7b51cb4 100644 --- a/latch/resources/map_tasks.py +++ b/latch/resources/map_tasks.py @@ -1,8 +1,8 @@ """ -A map task lets you run a pod task or a regular task over a -list of inputs within a single workflow node. This means you -can run thousands of instances of the task without creating -a node for every instance, providing valuable performance +A map task lets you run a pod task or a regular task over a +list of inputs within a single workflow node. This means you +can run thousands of instances of the task without creating +a node for every instance, providing valuable performance gains! Some use cases of map tasks include: @@ -40,4 +40,24 @@ def my_map_workflow(a: typing.List[int]) -> str: return coalesced """ -from flytekit.core.map_task import map_task +from typing import Callable, List, Protocol + +from flytekit.core.map_task import map_task as flyte_map_task +from typing_extensions import TypeVar + +T = TypeVar("T", contravariant=True) +S = TypeVar("S", covariant=True) + + +# Necessary bc Callable[[T], S] is stupid and assumes the arg must be positional +class MapTaskCallable(Protocol[T, S]): + def __call__(self, *args: T, **kwargs: T) -> S: ... 
+ + +def map_task( + f: Callable[[T], S], + concurrency: int = 0, + min_success_rate: float = 1, + **kwargs, +) -> MapTaskCallable[List[T], List[S]]: + return flyte_map_task(f, concurrency, min_success_rate, **kwargs) diff --git a/latch/resources/tasks.py b/latch/resources/tasks.py index 5b2e00fc..a447b396 100644 --- a/latch/resources/tasks.py +++ b/latch/resources/tasks.py @@ -25,11 +25,15 @@ def my_task(a: int) -> str: """ import datetime -import functools -from typing import Union +from pathlib import Path +from typing import Callable, Dict, List, Literal, Optional, TypeVar, Union, overload from warnings import warn -from flytekit import task +from flytekit import task as flyte_task +from flytekit.core.base_task import TaskResolverMixin +from flytekit.core.python_function_task import PythonFunctionTask +from flytekit.core.resources import Resources +from flytekit.models.security import Secret from flytekitplugins.pod import Pod from kubernetes.client.models import ( V1Container, @@ -37,6 +41,98 @@ def my_task(a: int) -> str: V1ResourceRequirements, V1Toleration, ) +from typing_extensions import ParamSpec + +P = ParamSpec("P") +T = TypeVar("T") + + +def _task_with_config(task_config: Pod): + @overload + def task( + _task_function: Callable[P, T], + cache: bool = False, + cache_serialize: bool = False, + cache_version: str = "", + retries: int = 0, + interruptible: Optional[bool] = None, + deprecated: str = "", + timeout: Union[datetime.timedelta, int] = 0, + container_image: Optional[str] = None, + environment: Optional[Dict[str, str]] = None, + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, + secret_requests: Optional[List[Secret]] = None, + execution_mode: Optional[ + PythonFunctionTask.ExecutionBehavior + ] = PythonFunctionTask.ExecutionBehavior.DEFAULT, + dockerfile: Optional[Path] = None, + task_resolver: Optional[TaskResolverMixin] = None, + ) -> Callable[P, T]: ... 
+ + @overload + def task( + _task_function: Literal[None] = None, + cache: bool = False, + cache_serialize: bool = False, + cache_version: str = "", + retries: int = 0, + interruptible: Optional[bool] = None, + deprecated: str = "", + timeout: Union[datetime.timedelta, int] = 0, + container_image: Optional[str] = None, + environment: Optional[Dict[str, str]] = None, + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, + secret_requests: Optional[List[Secret]] = None, + execution_mode: Optional[ + PythonFunctionTask.ExecutionBehavior + ] = PythonFunctionTask.ExecutionBehavior.DEFAULT, + dockerfile: Optional[Path] = None, + task_resolver: Optional[TaskResolverMixin] = None, + ) -> Callable[[Callable[P, T]], Callable[P, T]]: ... + + def task( + _task_function: Optional[Callable[P, T]] = None, + cache: bool = False, + cache_serialize: bool = False, + cache_version: str = "", + retries: int = 0, + interruptible: Optional[bool] = None, + deprecated: str = "", + timeout: Union[datetime.timedelta, int] = 0, + container_image: Optional[str] = None, + environment: Optional[Dict[str, str]] = None, + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, + secret_requests: Optional[List[Secret]] = None, + execution_mode: Optional[ + PythonFunctionTask.ExecutionBehavior + ] = PythonFunctionTask.ExecutionBehavior.DEFAULT, + dockerfile: Optional[Path] = None, + task_resolver: Optional[TaskResolverMixin] = None, + ) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T]]: + return flyte_task( + _task_function=_task_function, + task_config=task_config, + cache=cache, + cache_serialize=cache_serialize, + cache_version=cache_version, + retries=retries, + interruptible=interruptible, + deprecated=deprecated, + timeout=timeout, + container_image=container_image, + environment=environment, + requests=requests, + limits=limits, + secret_requests=secret_requests, + execution_mode=execution_mode, + dockerfile=dockerfile, + 
task_resolver=task_resolver, + ) + + return task def _get_large_gpu_pod() -> Pod: @@ -177,7 +273,7 @@ def _get_small_pod() -> Pod: ) -large_gpu_task = functools.partial(task, task_config=_get_large_gpu_pod()) +large_gpu_task = _task_with_config(_get_large_gpu_pod()) """This task will get scheduled on a large GPU-enabled node. This node is not necessarily dedicated to the task, but the node itself will be @@ -205,7 +301,7 @@ def _get_small_pod() -> Pod: """ -small_gpu_task = functools.partial(task, task_config=_get_small_gpu_pod()) +small_gpu_task = _task_with_config(_get_small_gpu_pod()) """This task will get scheduled on a small GPU-enabled node. This node will be dedicated to the task. No other tasks will be allowed to run @@ -232,7 +328,8 @@ def _get_small_pod() -> Pod: - True """ -large_task = functools.partial(task, task_config=_get_large_pod()) + +large_task = _task_with_config(_get_large_pod()) """This task will get scheduled on a large node. This node will be dedicated to the task. No other tasks will be allowed to run @@ -260,7 +357,7 @@ def _get_small_pod() -> Pod: """ -medium_task = functools.partial(task, task_config=_get_medium_pod()) +medium_task = _task_with_config(_get_medium_pod()) """This task will get scheduled on a medium node. This node will be dedicated to the task. No other tasks will be allowed to run @@ -288,7 +385,7 @@ def _get_small_pod() -> Pod: """ -small_task = functools.partial(task, task_config=_get_small_pod()) +small_task = _task_with_config(_get_small_pod()) """This task will get scheduled on a small node. .. 
list-table:: Title @@ -360,7 +457,7 @@ def custom_memory_optimized_task(cpu: int, memory: int): ), primary_container_name="primary", ) - return functools.partial(task, task_config=task_config) + return _task_with_config(task_config) def custom_task( @@ -469,4 +566,4 @@ def custom_task( " 4949 GiB)" ) - return functools.partial(task, task_config=task_config, timeout=timeout) + return lambda *args, **kwargs: _task_with_config(task_config)(*args, **{"timeout": timeout, **kwargs}) diff --git a/latch/resources/workflow.py b/latch/resources/workflow.py index 2358ca35..9a04888c 100644 --- a/latch/resources/workflow.py +++ b/latch/resources/workflow.py @@ -1,11 +1,11 @@ import inspect from dataclasses import is_dataclass from textwrap import dedent -from typing import Callable, Union, get_args, get_origin +from typing import Callable, TypeVar, Union, get_args, get_origin, overload import click from flytekit import workflow as _workflow -from flytekit.core.workflow import PythonFunctionWorkflow +from typing_extensions import ParamSpec from latch.types.metadata import LatchAuthor, LatchMetadata, LatchParameter from latch_cli.utils import best_effort_display_name @@ -28,12 +28,26 @@ def _inject_metadata(f: Callable, metadata: LatchMetadata) -> None: f.__doc__ = f"{short_desc}\n{dedent(long_desc)}\n\n" + str(metadata) +P = ParamSpec("P") +T = TypeVar("T") + + +@overload +def workflow(metadata: LatchMetadata) -> Callable[[Callable[P, T]], Callable[P, T]]: ... + + +@overload +def workflow( + metadata: Callable[P, T] +) -> Callable[P, T]: ... 
+ + # this weird Union thing is to ensure backwards compatibility, # so that when users call @workflow without any arguments or # parentheses, the workflow still serializes as expected def workflow( - metadata: Union[LatchMetadata, Callable] -) -> Union[PythonFunctionWorkflow, Callable]: + metadata: Union[LatchMetadata, Callable[P, T]] +) -> Union[Callable[P, T], Callable[[Callable[P, T]], Callable[P, T]]]: if isinstance(metadata, Callable): f = metadata if f.__doc__ is None or "__metadata__:" not in f.__doc__: @@ -41,7 +55,7 @@ def workflow( _inject_metadata(f, metadata) return _workflow(f) - def decorator(f: Callable): + def decorator(f: Callable[P, T]) -> Callable[P, T]: signature = inspect.signature(f) wf_params = signature.parameters