Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

try type hints again #303

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions latch/resources/map_tasks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
A map task lets you run a pod task or a regular task over a
list of inputs within a single workflow node. This means you
can run thousands of instances of the task without creating
a node for every instance, providing valuable performance
A map task lets you run a pod task or a regular task over a
list of inputs within a single workflow node. This means you
can run thousands of instances of the task without creating
a node for every instance, providing valuable performance
gains!

Some use cases of map tasks include:
Expand Down Expand Up @@ -40,4 +40,24 @@ def my_map_workflow(a: typing.List[int]) -> str:
return coalesced
"""

from flytekit.core.map_task import map_task
from typing import Callable, List, Protocol

from flytekit.core.map_task import map_task as flyte_map_task
from typing_extensions import TypeVar

# Input type of the mapped function. Contravariant because it appears only in
# parameter position of MapTaskCallable.__call__ (required by PEP 544 variance rules).
T = TypeVar("T", contravariant=True)
# Output type of the mapped function. Covariant because it appears only in
# return position of MapTaskCallable.__call__.
S = TypeVar("S", covariant=True)


# Necessary because parameters of `Callable[[T], S]` are positional-only per the
# typing spec; a Protocol lets callers pass the argument positionally or by name.
class MapTaskCallable(Protocol[T, S]):
    """A callable accepting arguments of type ``T`` and returning ``S``.

    Unlike ``Callable[[T], S]``, arguments may be supplied by keyword.
    """

    def __call__(self, *args: T, **kwargs: T) -> S: ...


def map_task(
    f: Callable[[T], S],
    concurrency: int = 0,
    min_success_rate: float = 1,
    **kwargs,
) -> MapTaskCallable[List[T], List[S]]:
    """Run ``f`` once per element of an input list within a single node.

    Args:
        f: The task function to map over a list of inputs.
        concurrency: Forwarded positionally to flytekit's ``map_task``.
        min_success_rate: Forwarded positionally to flytekit's ``map_task``
            (NOTE(review): flytekit's own parameter may be named differently,
            e.g. ``min_success_ratio`` — the positional forwarding below keeps
            this wrapper independent of that name; confirm before switching to
            keyword arguments).
        **kwargs: Passed through unchanged to flytekit's ``map_task``.

    Returns:
        A callable that accepts a list of inputs and produces a list of outputs.
    """
    mapped = flyte_map_task(f, concurrency, min_success_rate, **kwargs)
    return mapped
117 changes: 107 additions & 10 deletions latch/resources/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,114 @@ def my_task(a: int) -> str:
"""

import datetime
import functools
from typing import Union
from pathlib import Path
from typing import Callable, Dict, List, Literal, Optional, TypeVar, Union, overload
from warnings import warn

from flytekit import task
from flytekit import task as flyte_task
from flytekit.core.base_task import TaskResolverMixin
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.resources import Resources
from flytekit.models.security import Secret
from flytekitplugins.pod import Pod
from kubernetes.client.models import (
V1Container,
V1PodSpec,
V1ResourceRequirements,
V1Toleration,
)
from typing_extensions import ParamSpec

# Parameter specification of the wrapped task function (preserves its exact
# signature through the decorator for type checkers).
P = ParamSpec("P")
# Return type of the wrapped task function.
T = TypeVar("T")


def _task_with_config(task_config: Pod):
    """Return a ``task`` decorator pre-bound to the given pod configuration.

    The returned decorator forwards every argument to flytekit's ``task``,
    injecting ``task_config``. The two overloads preserve the decorated
    function's signature for type checkers in both usage forms:
    bare ``@task`` and parameterized ``@task(...)``.

    NOTE(review): the keyword list below must stay in sync with the
    ``flyte_task`` signature (including ``dockerfile``, which is presumably
    handled by the flytekit build in use — confirm it is accepted upstream).
    """

    # Overload for the bare-decorator form: `@task` applied directly to a
    # function returns the decorated function with its signature intact.
    # Defaults in overloads are type-information only; runtime defaults live
    # in the implementation below.
    @overload
    def task(
        _task_function: Callable[P, T],
        cache: bool = False,
        cache_serialize: bool = False,
        cache_version: str = "",
        retries: int = 0,
        interruptible: Optional[bool] = None,
        deprecated: str = "",
        timeout: Union[datetime.timedelta, int] = 0,
        container_image: Optional[str] = None,
        environment: Optional[Dict[str, str]] = None,
        requests: Optional[Resources] = None,
        limits: Optional[Resources] = None,
        secret_requests: Optional[List[Secret]] = None,
        execution_mode: Optional[
            PythonFunctionTask.ExecutionBehavior
        ] = PythonFunctionTask.ExecutionBehavior.DEFAULT,
        dockerfile: Optional[Path] = None,
        task_resolver: Optional[TaskResolverMixin] = None,
    ) -> Callable[P, T]: ...

    # Overload for the parameterized form: `@task(...)` (no function yet)
    # returns a decorator that will preserve the function's signature.
    @overload
    def task(
        _task_function: Literal[None] = None,
        cache: bool = False,
        cache_serialize: bool = False,
        cache_version: str = "",
        retries: int = 0,
        interruptible: Optional[bool] = None,
        deprecated: str = "",
        timeout: Union[datetime.timedelta, int] = 0,
        container_image: Optional[str] = None,
        environment: Optional[Dict[str, str]] = None,
        requests: Optional[Resources] = None,
        limits: Optional[Resources] = None,
        secret_requests: Optional[List[Secret]] = None,
        execution_mode: Optional[
            PythonFunctionTask.ExecutionBehavior
        ] = PythonFunctionTask.ExecutionBehavior.DEFAULT,
        dockerfile: Optional[Path] = None,
        task_resolver: Optional[TaskResolverMixin] = None,
    ) -> Callable[[Callable[P, T]], Callable[P, T]]: ...

    def task(
        _task_function: Optional[Callable[P, T]] = None,
        cache: bool = False,
        cache_serialize: bool = False,
        cache_version: str = "",
        retries: int = 0,
        interruptible: Optional[bool] = None,
        deprecated: str = "",
        timeout: Union[datetime.timedelta, int] = 0,
        container_image: Optional[str] = None,
        environment: Optional[Dict[str, str]] = None,
        requests: Optional[Resources] = None,
        limits: Optional[Resources] = None,
        secret_requests: Optional[List[Secret]] = None,
        execution_mode: Optional[
            PythonFunctionTask.ExecutionBehavior
        ] = PythonFunctionTask.ExecutionBehavior.DEFAULT,
        dockerfile: Optional[Path] = None,
        task_resolver: Optional[TaskResolverMixin] = None,
    ) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T]]:
        """Forward every argument to flytekit's ``task`` with the captured
        ``task_config`` injected; flytekit handles both decorator forms."""
        return flyte_task(
            _task_function=_task_function,
            task_config=task_config,
            cache=cache,
            cache_serialize=cache_serialize,
            cache_version=cache_version,
            retries=retries,
            interruptible=interruptible,
            deprecated=deprecated,
            timeout=timeout,
            container_image=container_image,
            environment=environment,
            requests=requests,
            limits=limits,
            secret_requests=secret_requests,
            execution_mode=execution_mode,
            dockerfile=dockerfile,
            task_resolver=task_resolver,
        )

    return task


def _get_large_gpu_pod() -> Pod:
Expand Down Expand Up @@ -177,7 +273,7 @@ def _get_small_pod() -> Pod:
)


large_gpu_task = functools.partial(task, task_config=_get_large_gpu_pod())
large_gpu_task = _task_with_config(_get_large_gpu_pod())
"""This task will get scheduled on a large GPU-enabled node.

This node is not necessarily dedicated to the task, but the node itself will be
Expand Down Expand Up @@ -205,7 +301,7 @@ def _get_small_pod() -> Pod:
"""


small_gpu_task = functools.partial(task, task_config=_get_small_gpu_pod())
small_gpu_task = _task_with_config(_get_small_gpu_pod())
"""This task will get scheduled on a small GPU-enabled node.

This node will be dedicated to the task. No other tasks will be allowed to run
Expand All @@ -232,7 +328,8 @@ def _get_small_pod() -> Pod:
- True
"""

large_task = functools.partial(task, task_config=_get_large_pod())

large_task = _task_with_config(_get_large_pod())
"""This task will get scheduled on a large node.

This node will be dedicated to the task. No other tasks will be allowed to run
Expand Down Expand Up @@ -260,7 +357,7 @@ def _get_small_pod() -> Pod:
"""


medium_task = functools.partial(task, task_config=_get_medium_pod())
medium_task = _task_with_config(_get_medium_pod())
"""This task will get scheduled on a medium node.

This node will be dedicated to the task. No other tasks will be allowed to run
Expand Down Expand Up @@ -288,7 +385,7 @@ def _get_small_pod() -> Pod:
"""


small_task = functools.partial(task, task_config=_get_small_pod())
small_task = _task_with_config(_get_small_pod())
"""This task will get scheduled on a small node.

.. list-table:: Title
Expand Down Expand Up @@ -360,7 +457,7 @@ def custom_memory_optimized_task(cpu: int, memory: int):
),
primary_container_name="primary",
)
return functools.partial(task, task_config=task_config)
return _task_with_config(task_config)


def custom_task(
Expand Down Expand Up @@ -469,4 +566,4 @@ def custom_task(
" 4949 GiB)"
)

return functools.partial(task, task_config=task_config, timeout=timeout)
return _task_with_config(task_config)
24 changes: 19 additions & 5 deletions latch/resources/workflow.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import inspect
from dataclasses import is_dataclass
from textwrap import dedent
from typing import Callable, Union, get_args, get_origin
from typing import Callable, TypeVar, Union, get_args, get_origin, overload

import click
from flytekit import workflow as _workflow
from flytekit.core.workflow import PythonFunctionWorkflow
from typing_extensions import ParamSpec

from latch.types.metadata import LatchAuthor, LatchMetadata, LatchParameter
from latch_cli.utils import best_effort_display_name
Expand All @@ -28,20 +28,34 @@ def _inject_metadata(f: Callable, metadata: LatchMetadata) -> None:
f.__doc__ = f"{short_desc}\n{dedent(long_desc)}\n\n" + str(metadata)


# Parameter specification of the decorated workflow function (preserves its
# exact signature through the decorator for type checkers).
P = ParamSpec("P")
# Return type of the decorated workflow function.
T = TypeVar("T")


# Parameterized form: `@workflow(metadata)` returns a decorator that
# preserves the decorated function's signature.
@overload
def workflow(metadata: LatchMetadata) -> Callable[[Callable[P, T]], Callable[P, T]]: ...


# Bare-decorator form: `@workflow` applied directly to a function.
# Accepting only `Callable[P, T]` here (rather than repeating the
# implementation's Union) lets type checkers narrow this case and infer
# that the decorated function keeps its exact signature; with the Union,
# the result degraded to a Union of both return types and the overloads
# provided no narrowing at all. Overloads are erased at runtime, so this
# changes nothing for callers at execution time.
@overload
def workflow(metadata: Callable[P, T]) -> Callable[P, T]: ...


# this weird Union thing is to ensure backwards compatibility,
# so that when users call @workflow without any arguments or
# parentheses, the workflow still serializes as expected
def workflow(
metadata: Union[LatchMetadata, Callable]
) -> Union[PythonFunctionWorkflow, Callable]:
metadata: Union[LatchMetadata, Callable[P, T]]
) -> Union[Callable[P, T], Callable[[Callable[P, T]], Callable[P, T]]]:
if isinstance(metadata, Callable):
f = metadata
if f.__doc__ is None or "__metadata__:" not in f.__doc__:
metadata = _generate_metadata(f)
_inject_metadata(f, metadata)
return _workflow(f)

def decorator(f: Callable):
def decorator(f: Callable[P, T]) -> Callable[P, T]:
signature = inspect.signature(f)
wf_params = signature.parameters

Expand Down