From 6b53a4a2504cb51d53fd8639cd694ec7f502f6f5 Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Sat, 22 Jul 2023 12:07:42 -0700 Subject: [PATCH 1/5] try type hints again Signed-off-by: Ayush Kamat --- latch/resources/map_tasks.py | 25 ++++++++++++---- latch/resources/tasks.py | 58 ++++++++++++++++++++++++++++++++++-- latch/resources/workflow.py | 24 +++++++++++---- 3 files changed, 95 insertions(+), 12 deletions(-) diff --git a/latch/resources/map_tasks.py b/latch/resources/map_tasks.py index 53a8d52d..6e9ca5ab 100644 --- a/latch/resources/map_tasks.py +++ b/latch/resources/map_tasks.py @@ -1,8 +1,8 @@ """ -A map task lets you run a pod task or a regular task over a -list of inputs within a single workflow node. This means you -can run thousands of instances of the task without creating -a node for every instance, providing valuable performance +A map task lets you run a pod task or a regular task over a +list of inputs within a single workflow node. This means you +can run thousands of instances of the task without creating +a node for every instance, providing valuable performance gains! Some use cases of map tasks include: @@ -40,4 +40,19 @@ def my_map_workflow(a: typing.List[int]) -> str: return coalesced """ -from flytekit.core.map_task import map_task +from typing import Callable, List + +from flytekit.core.map_task import map_task as flyte_map_task +from typing_extensions import TypeVar + +T = TypeVar("T") +S = TypeVar("S") + + +def map_task( + f: Callable[[T], S], + concurrency: int = 0, + min_success_rate: float = 1, + **kwargs, +) -> Callable[[List[T]], S]: + return flyte_map_task(f, concurrency, min_success_rate, **kwargs) diff --git a/latch/resources/tasks.py b/latch/resources/tasks.py index b1ad61ca..5a0c9310 100644 --- a/latch/resources/tasks.py +++ b/latch/resources/tasks.py @@ -24,9 +24,16 @@ def my_task(a: int) -> str: https://docs.flyte.org/en/latest/ """ +import datetime import functools - -from flytekit import task +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, ParamSpec, TypeVar, Union + +from flytekit import task as flyte_task +from flytekit.core.base_task import TaskResolverMixin +from flytekit.core.python_function_task import PythonFunctionTask +from flytekit.core.resources import Resources +from flytekit.models.security import Secret from flytekitplugins.pod import Pod from kubernetes.client.models import ( V1Container, @@ -35,6 +42,53 @@ def my_task(a: int) -> str: V1Toleration, ) +P = ParamSpec("P") +T = TypeVar("T") + + +def task( + _task_function: Callable[P, T], + task_config: Optional[Any] = None, + cache: bool = False, + cache_serialize: bool = False, + cache_version: str = "", + retries: int = 0, + interruptible: Optional[bool] = None, + deprecated: str = "", + timeout: Union[datetime.timedelta, int] = 0, + container_image: Optional[str] = None, + environment: Optional[Dict[str, str]] = None, + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, + secret_requests: Optional[List[Secret]] = None, + execution_mode: Optional[ + PythonFunctionTask.ExecutionBehavior + ] = PythonFunctionTask.ExecutionBehavior.DEFAULT, + dockerfile: Optional[Path] = None, + task_resolver: Optional[TaskResolverMixin] = None, + disable_deck: bool = False, +) -> Callable[P, T]: + return flyte_task( + _task_function=_task_function, + task_config=task_config, + cache=cache, + cache_serialize=cache_serialize, + cache_version=cache_version, + retries=retries, + interruptible=interruptible, + deprecated=deprecated, + timeout=timeout, + container_image=container_image, + environment=environment, + requests=requests, + limits=limits, + secret_requests=secret_requests, + execution_mode=execution_mode, + dockerfile=dockerfile, + task_resolver=task_resolver, + disable_deck=disable_deck, + ) + def _get_large_gpu_pod() -> Pod: """g5.8xlarge,g5.16xlarge on-demand""" diff --git a/latch/resources/workflow.py b/latch/resources/workflow.py index 2f63f8fb..c018ac00 100644 --- a/latch/resources/workflow.py +++ b/latch/resources/workflow.py @@ -1,25 +1,39 @@ import inspect from dataclasses import is_dataclass from textwrap import dedent -from typing import Callable, Union, get_args, get_origin +from typing import Callable, ParamSpec, TypeVar, Union, get_args, get_origin, overload from flytekit import workflow as _workflow -from flytekit.core.workflow import PythonFunctionWorkflow from latch.types.metadata import LatchMetadata +P = ParamSpec("P") +T = TypeVar("T") + + +@overload +def workflow(metadata: LatchMetadata) -> Callable[[Callable[P, T]], Callable[P, T]]: + ... + + +@overload +def workflow( + metadata: Union[LatchMetadata, Callable[P, T]] +) -> Union[Callable[P, T], Callable[[Callable[P, T]], Callable[P, T]]]: + ... + # this weird Union thing is to ensure backwards compatibility, # so that when users call @workflow without any arguments or # parentheses, the workflow still serializes as expected def workflow( - metadata: Union[LatchMetadata, Callable] -) -> Union[PythonFunctionWorkflow, Callable]: + metadata: Union[LatchMetadata, Callable[P, T]] +) -> Union[Callable[P, T], Callable[[Callable[P, T]], Callable[P, T]]]: if isinstance(metadata, Callable): return _workflow(metadata) else: - def decorator(f: Callable): + def decorator(f: Callable[P, T]) -> Callable[P, T]: if f.__doc__ is None: f.__doc__ = f"{f.__name__}\n\nSample Description" short_desc, long_desc = f.__doc__.split("\n", 1) From 35d83d38aeb357347cc60bbb0e981d304bd2a517 Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Mon, 7 Aug 2023 10:24:25 -0700 Subject: [PATCH 2/5] kw-only Signed-off-by: Ayush Kamat --- latch/resources/tasks.py | 151 ++++++++++++++++++++++++++------------- 1 file changed, 100 insertions(+), 51 deletions(-) diff --git a/latch/resources/tasks.py b/latch/resources/tasks.py index cc7b9bbb..65d5f9cf 100644 --- a/latch/resources/tasks.py +++ b/latch/resources/tasks.py @@ -25,9 +25,8 @@ def my_task(a: int) -> str: """ import datetime -import functools from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, ParamSpec, TypeVar, Union +from typing import Callable, Dict, List, Literal, Optional, TypeVar, Union, overload from warnings import warn from flytekit import task as flyte_task @@ -42,53 +41,103 @@ def my_task(a: int) -> str: V1ResourceRequirements, V1Toleration, ) +from typing_extensions import ParamSpec P = ParamSpec("P") T = TypeVar("T") -def task( - _task_function: Callable[P, T], - task_config: Optional[Any] = None, - cache: bool = False, - cache_serialize: bool = False, - cache_version: str = "", - retries: int = 0, - interruptible: Optional[bool] = None, - deprecated: str = "", - timeout: Union[datetime.timedelta, int] = 0, - container_image: Optional[str] = None, - environment: Optional[Dict[str, str]] = None, - requests: Optional[Resources] = None, - limits: Optional[Resources] = None, - secret_requests: Optional[List[Secret]] = None, - execution_mode: Optional[ - PythonFunctionTask.ExecutionBehavior - ] = PythonFunctionTask.ExecutionBehavior.DEFAULT, - dockerfile: Optional[Path] = None, - task_resolver: Optional[TaskResolverMixin] = None, - disable_deck: bool = False, -) -> Callable[P, T]: - return flyte_task( - _task_function=_task_function, - task_config=task_config, - cache=cache, - cache_serialize=cache_serialize, - cache_version=cache_version, - retries=retries, - interruptible=interruptible, - deprecated=deprecated, - timeout=timeout, - container_image=container_image, - environment=environment, - requests=requests, - limits=limits, - secret_requests=secret_requests, - execution_mode=execution_mode, - dockerfile=dockerfile, - task_resolver=task_resolver, - disable_deck=disable_deck, - ) +def task_with_config(task_config: Pod): + @overload + def task( + _task_function: Callable[P, T], + *, + cache: bool = False, + cache_serialize: bool = False, + cache_version: str = "", + retries: int = 0, + interruptible: Optional[bool] = None, + deprecated: str = "", + timeout: Union[datetime.timedelta, int] = 0, + container_image: Optional[str] = None, + environment: Optional[Dict[str, str]] = None, + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, + secret_requests: Optional[List[Secret]] = None, + execution_mode: Optional[ + PythonFunctionTask.ExecutionBehavior + ] = PythonFunctionTask.ExecutionBehavior.DEFAULT, + dockerfile: Optional[Path] = None, + task_resolver: Optional[TaskResolverMixin] = None, + ) -> Callable[P, T]: + ... + + @overload + def task( + _task_function: Literal[None] = None, + *, + cache: bool = False, + cache_serialize: bool = False, + cache_version: str = "", + retries: int = 0, + interruptible: Optional[bool] = None, + deprecated: str = "", + timeout: Union[datetime.timedelta, int] = 0, + container_image: Optional[str] = None, + environment: Optional[Dict[str, str]] = None, + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, + secret_requests: Optional[List[Secret]] = None, + execution_mode: Optional[ + PythonFunctionTask.ExecutionBehavior + ] = PythonFunctionTask.ExecutionBehavior.DEFAULT, + dockerfile: Optional[Path] = None, + task_resolver: Optional[TaskResolverMixin] = None, + ) -> Callable[[Callable[P, T]], Callable[P, T]]: + ... + + def task( + _task_function: Optional[Callable[P, T]] = None, + *, + cache: bool = False, + cache_serialize: bool = False, + cache_version: str = "", + retries: int = 0, + interruptible: Optional[bool] = None, + deprecated: str = "", + timeout: Union[datetime.timedelta, int] = 0, + container_image: Optional[str] = None, + environment: Optional[Dict[str, str]] = None, + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, + secret_requests: Optional[List[Secret]] = None, + execution_mode: Optional[ + PythonFunctionTask.ExecutionBehavior + ] = PythonFunctionTask.ExecutionBehavior.DEFAULT, + dockerfile: Optional[Path] = None, + task_resolver: Optional[TaskResolverMixin] = None, + ) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T]]: + return flyte_task( + _task_function=_task_function, + task_config=task_config, + cache=cache, + cache_serialize=cache_serialize, + cache_version=cache_version, + retries=retries, + interruptible=interruptible, + deprecated=deprecated, + timeout=timeout, + container_image=container_image, + environment=environment, + requests=requests, + limits=limits, + secret_requests=secret_requests, + execution_mode=execution_mode, + dockerfile=dockerfile, + task_resolver=task_resolver, + ) + + return task def _get_large_gpu_pod() -> Pod: @@ -229,7 +278,7 @@ def _get_small_pod() -> Pod: ) -large_gpu_task = functools.partial(task, task_config=_get_large_gpu_pod()) +large_gpu_task = task_with_config(_get_large_gpu_pod()) """This task will get scheduled on a large GPU-enabled node. This node is not necessarily dedicated to the task, but the node itself will be @@ -257,7 +306,7 @@ def _get_small_pod() -> Pod: """ -small_gpu_task = functools.partial(task, task_config=_get_small_gpu_pod()) +small_gpu_task = task_with_config(_get_small_gpu_pod()) """This task will get scheduled on a small GPU-enabled node. This node will be dedicated to the task. No other tasks will be allowed to run @@ -285,7 +334,7 @@ def _get_small_pod() -> Pod: """ -large_task = functools.partial(task, task_config=_get_large_pod()) +large_task = task_with_config(_get_large_pod()) """This task will get scheduled on a large node. This node will be dedicated to the task. No other tasks will be allowed to run @@ -313,7 +362,7 @@ def _get_small_pod() -> Pod: """ -medium_task = functools.partial(task, task_config=_get_medium_pod()) +medium_task = task_with_config(_get_medium_pod()) """This task will get scheduled on a medium node. This node will be dedicated to the task. No other tasks will be allowed to run @@ -341,7 +390,7 @@ def _get_small_pod() -> Pod: """ -small_task = functools.partial(task, task_config=_get_small_pod()) +small_task = task_with_config(_get_small_pod()) """This task will get scheduled on a small node. .. list-table:: Title @@ -415,7 +464,7 @@ def custom_memory_optimized_task(cpu: int, memory: int): ), primary_container_name="primary", ) - return functools.partial(task, task_config=task_config) + return task_with_config(task_config) def custom_task(cpu: int, memory: int, *, storage_gib: int = 500): @@ -518,4 +567,4 @@ def custom_task(cpu: int, memory: int, *, storage_gib: int = 500): " 4949 GiB)" ) - return functools.partial(task, task_config=task_config) + return task_with_config(task_config) From 2d158641965654b7432b1933426465d95661690d Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Mon, 7 Aug 2023 10:37:05 -0700 Subject: [PATCH 3/5] fix import + make helper private Signed-off-by: Ayush Kamat --- .pre-commit-config.yaml | 3 ++- latch/resources/tasks.py | 22 ++++++++++------------ latch/resources/workflow.py | 3 ++- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 58b5b1b2..bb801d91 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,9 +36,10 @@ repos: # - id: sort-simple-yaml # - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 23.7.0 hooks: - id: black + args: [--preview] - repo: https://github.com/PyCQA/isort rev: 5.12.0 hooks: diff --git a/latch/resources/tasks.py b/latch/resources/tasks.py index 65d5f9cf..bd430e36 100644 --- a/latch/resources/tasks.py +++ b/latch/resources/tasks.py @@ -47,7 +47,7 @@ def my_task(a: int) -> str: T = TypeVar("T") -def task_with_config(task_config: Pod): +def _task_with_config(task_config: Pod): @overload def task( _task_function: Callable[P, T], @@ -278,7 +278,7 @@ def _get_small_pod() -> Pod: ) -large_gpu_task = task_with_config(_get_large_gpu_pod()) +large_gpu_task = _task_with_config(_get_large_gpu_pod()) """This task will get scheduled on a large GPU-enabled node. This node is not necessarily dedicated to the task, but the node itself will be @@ -306,7 +306,7 @@ def _get_small_pod() -> Pod: """ -small_gpu_task = task_with_config(_get_small_gpu_pod()) +small_gpu_task = _task_with_config(_get_small_gpu_pod()) """This task will get scheduled on a small GPU-enabled node. This node will be dedicated to the task. No other tasks will be allowed to run @@ -334,7 +334,7 @@ def _get_small_pod() -> Pod: """ -large_task = task_with_config(_get_large_pod()) +large_task = _task_with_config(_get_large_pod()) """This task will get scheduled on a large node. This node will be dedicated to the task. No other tasks will be allowed to run @@ -362,7 +362,7 @@ def _get_small_pod() -> Pod: """ -medium_task = task_with_config(_get_medium_pod()) +medium_task = _task_with_config(_get_medium_pod()) """This task will get scheduled on a medium node. This node will be dedicated to the task. No other tasks will be allowed to run @@ -390,7 +390,7 @@ def _get_small_pod() -> Pod: """ -small_task = task_with_config(_get_small_pod()) +small_task = _task_with_config(_get_small_pod()) """This task will get scheduled on a small node. .. list-table:: Title @@ -426,10 +426,8 @@ def custom_memory_optimized_task(cpu: int, memory: int): memory: An integer number of Gibibytes of RAM to request, up to 511 GiB """ warn( - ( - "`custom_memory_optimized_task` is deprecated and will be removed in a" - " future release: use `custom_task` instead" - ), + "`custom_memory_optimized_task` is deprecated and will be removed in a" + " future release: use `custom_task` instead", DeprecationWarning, stacklevel=2, ) @@ -464,7 +462,7 @@ def custom_memory_optimized_task(cpu: int, memory: int): ), primary_container_name="primary", ) - return task_with_config(task_config) + return _task_with_config(task_config) def custom_task(cpu: int, memory: int, *, storage_gib: int = 500): @@ -567,4 +565,4 @@ def custom_task(cpu: int, memory: int, *, storage_gib: int = 500): " 4949 GiB)" ) - return task_with_config(task_config) + return _task_with_config(task_config) diff --git a/latch/resources/workflow.py b/latch/resources/workflow.py index c018ac00..0c4a4508 100644 --- a/latch/resources/workflow.py +++ b/latch/resources/workflow.py @@ -1,9 +1,10 @@ import inspect from dataclasses import is_dataclass from textwrap import dedent -from typing import Callable, ParamSpec, TypeVar, Union, get_args, get_origin, overload +from typing import Callable, TypeVar, Union, get_args, get_origin, overload from flytekit import workflow as _workflow +from typing_extensions import ParamSpec from latch.types.metadata import LatchMetadata From cdbfa4bb8c00903ce8890a08be9dc1a8aae1a959 Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Wed, 21 Feb 2024 16:56:59 -0800 Subject: [PATCH 4/5] formatting Signed-off-by: Ayush Kamat --- latch/resources/tasks.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/latch/resources/tasks.py b/latch/resources/tasks.py index 75fcdfae..c04d9845 100644 --- a/latch/resources/tasks.py +++ b/latch/resources/tasks.py @@ -69,8 +69,7 @@ def task( ] = PythonFunctionTask.ExecutionBehavior.DEFAULT, dockerfile: Optional[Path] = None, task_resolver: Optional[TaskResolverMixin] = None, - ) -> Callable[P, T]: - ... + ) -> Callable[P, T]: ... @overload def task( @@ -93,8 +92,7 @@ def task( ] = PythonFunctionTask.ExecutionBehavior.DEFAULT, dockerfile: Optional[Path] = None, task_resolver: Optional[TaskResolverMixin] = None, - ) -> Callable[[Callable[P, T]], Callable[P, T]]: - ... + ) -> Callable[[Callable[P, T]], Callable[P, T]]: ... def task( _task_function: Optional[Callable[P, T]] = None, From 77b2bd4359405b28f06e3167ea166f1e96c0a9b7 Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Thu, 22 Feb 2024 14:39:30 -0800 Subject: [PATCH 5/5] rm kw only + redo map tasks Signed-off-by: Ayush Kamat --- latch/resources/map_tasks.py | 13 +++++++++---- latch/resources/tasks.py | 3 --- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/latch/resources/map_tasks.py b/latch/resources/map_tasks.py index 6e9ca5ab..e7b51cb4 100644 --- a/latch/resources/map_tasks.py +++ b/latch/resources/map_tasks.py @@ -40,13 +40,18 @@ def my_map_workflow(a: typing.List[int]) -> str: return coalesced """ -from typing import Callable, List +from typing import Callable, List, Protocol from flytekit.core.map_task import map_task as flyte_map_task from typing_extensions import TypeVar -T = TypeVar("T") -S = TypeVar("S") +T = TypeVar("T", contravariant=True) +S = TypeVar("S", covariant=True) + + +# Necessary bc Callable[[T], S] is stupid and assumes the arg must be positional +class MapTaskCallable(Protocol[T, S]): + def __call__(self, *args: T, **kwargs: T) -> S: ... def map_task( @@ -54,5 +59,5 @@ def map_task( concurrency: int = 0, min_success_rate: float = 1, **kwargs, -) -> Callable[[List[T]], S]: +) -> MapTaskCallable[List[T], List[S]]: return flyte_map_task(f, concurrency, min_success_rate, **kwargs) diff --git a/latch/resources/tasks.py b/latch/resources/tasks.py index c04d9845..a447b396 100644 --- a/latch/resources/tasks.py +++ b/latch/resources/tasks.py @@ -51,7 +51,6 @@ def _task_with_config(task_config: Pod): @overload def task( _task_function: Callable[P, T], - *, cache: bool = False, cache_serialize: bool = False, cache_version: str = "", @@ -74,7 +73,6 @@ def task( @overload def task( _task_function: Literal[None] = None, - *, cache: bool = False, cache_serialize: bool = False, cache_version: str = "", @@ -96,7 +94,6 @@ def task( def task( _task_function: Optional[Callable[P, T]] = None, - *, cache: bool = False, cache_serialize: bool = False, cache_version: str = "",