Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support Optional samplesheet parameters #475

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 179 additions & 12 deletions latch/resources/workflow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import inspect
import sys
from dataclasses import is_dataclass
from textwrap import dedent
from typing import Callable, Dict, Union, get_args, get_origin
from typing import Any
from typing import Callable
from typing import Dict
from typing import Union
from typing import _GenericAlias # type: ignore[attr-defined] # NB: mypy can't see private attrs
from typing import get_args
from typing import get_origin
from typing_extensions import TypeAlias
from typing_extensions import TypeGuard

import click
import os
Expand All @@ -12,6 +21,34 @@
from latch_cli.utils import best_effort_display_name


if sys.version_info >= (3, 10):
from types import UnionType
else:
# NB: `types.UnionType`, available since Python 3.10, is **not** a `type`, but is a class.
# We declare an empty class here to use in the instance checks below.
class UnionType:
pass


TypeAnnotation: TypeAlias = Union[type, _GenericAlias, UnionType]
"""
A function parameter's type annotation may be any of the following:
1) `type`, when declaring any of the built-in Python types
2) `typing._GenericAlias`, when declaring generic collection types or union types using pre-PEP
585 and pre-PEP 604 syntax (e.g. `List[int]`, `Optional[int]`, or `Union[int, None]`)
3) `types.UnionType`, when declaring union types using PEP604 syntax (e.g. `int | None`)
4) `types.GenericAlias`, when declaring generic collection types using PEP 585 syntax (e.g.
`list[int]`)

`types.GenericAlias` is a subclass of `type`, but `typing._GenericAlias` and `types.UnionType` are
not and must be considered explicitly.
"""

# TODO When dropping support for Python 3.9, deprecate this in favor of performing instance checks
# directly on the `TypeAnnotation` union type.
TYPE_ANNOTATION_TYPES = (type, _GenericAlias, UnionType) # type: ignore[attr-defined]


def _generate_metadata(f: Callable) -> LatchMetadata:
signature = inspect.signature(f)
metadata = LatchMetadata(f.__name__, LatchAuthor())
Expand All @@ -33,7 +70,7 @@ def _inject_metadata(f: Callable, metadata: LatchMetadata) -> None:
# so that when users call @workflow without any arguments or
# parentheses, the workflow still serializes as expected
def workflow(
metadata: Union[LatchMetadata, Callable]
metadata: Union[LatchMetadata, Callable],
) -> Union[PythonFunctionWorkflow, Callable]:
if isinstance(metadata, Callable):
f = metadata
Expand Down Expand Up @@ -81,16 +118,7 @@ def decorator(f: Callable):
if meta_param.samplesheet is not True:
continue

annotation = wf_params[name].annotation

origin = get_origin(annotation)
args = get_args(annotation)
valid = (
origin is not None
and issubclass(origin, list)
and is_dataclass(args[0])
)
if not valid:
if not _is_valid_samplesheet_parameter_type(wf_params[name].annotation):
click.secho(
f"parameter marked as samplesheet is not valid: {name} "
f"in workflow {f.__name__} must be a list of dataclasses",
Expand All @@ -108,3 +136,142 @@ def decorator(f: Callable):
return _workflow(f, wf_name_override=wf_name_override)

return decorator


def _is_valid_samplesheet_parameter_type(annotation: Any) -> TypeGuard[TypeAnnotation]:
"""Check if a workflow parameter is hinted with a valid type for a samplesheet LatchParameter.

Currently, a samplesheet LatchParameter must be defined as a list of dataclasses, or as an
`Optional` list of dataclasses when the parameter is part of a `ForkBranch`.

Args:
annotation: The type annotation of a parameter from the workflow function's signature.

Returns:
True if the parameter is annotated as a list of dataclasses, or as an `Optional` list of
dataclasses.
False otherwise.
"""
# If the parameter did not have a type annotation, short-circuit and return False
if not _is_type_annotation(annotation):
return False
Comment on lines +155 to +157
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the wrong place for this check. I can't remember off the top of my head whether we already throw a semantic error for a missing type annotation, but we should do this outside of a samplesheet-specific check (it'll always be an error, regardless of whether or not the specific param is a samplesheet).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that the project does not enforce any type checking, I would prefer to include runtime checks on all argument types.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I'm happy to make this raise a TypeError instead, though in context I thought returning False was sensible.)


return _is_list_of_dataclasses_type(annotation) or (
_is_optional_type(annotation)
and _is_list_of_dataclasses_type(_unpack_optional_type(annotation))
)


def _is_list_of_dataclasses_type(dtype: TypeAnnotation) -> bool:
    """Tell whether a type annotation describes a list whose elements are dataclasses.

    Args:
        dtype: The type annotation to inspect.

    Returns:
        True when `dtype` is a generic whose origin is `list` (or a subclass of it) and whose
        single type argument is a dataclass.
        False for every other type annotation.

    Raises:
        TypeError: If `dtype` is not an instance of one of `TYPE_ANNOTATION_TYPES`.
    """
    if not isinstance(dtype, TYPE_ANNOTATION_TYPES):
        raise TypeError(f"Expected type annotation, got {type(dtype)}: {dtype}")

    container = get_origin(dtype)

    # Unparameterized types have no origin, and some origins (e.g. `typing.Union`) are not
    # classes, in which case `issubclass` would raise instead of answering.
    if container is None or not inspect.isclass(container):
        return False

    if not issubclass(container, list):
        return False

    element_types = get_args(dtype)
    return len(element_types) == 1 and is_dataclass(element_types[0])


def _is_optional_type(dtype: TypeAnnotation) -> bool:
    """Tell whether a type annotation is an `Optional`.

    An optional type may be spelled three ways: `Optional[T]`, `Union[T, None]`, or `T | None`.
    All of these spellings are supported by this function.

    Args:
        dtype: The type annotation to inspect.

    Returns:
        True when `dtype` is a union of exactly two members, one of which is `None`.
        False for every other type annotation.

    Raises:
        TypeError: If `dtype` is not an instance of one of `TYPE_ANNOTATION_TYPES`.
    """
    if not isinstance(dtype, TYPE_ANNOTATION_TYPES):
        raise TypeError(f"Expected type annotation, got {type(dtype)}: {dtype}")

    # `Optional[T]` and `Union[T, None]` report `typing.Union` as their origin, whereas PEP 604
    # syntax (e.g. `int | None`) reports `types.UnionType`. Anything else (including a missing
    # origin) is not a union at all.
    union_kind = get_origin(dtype)
    if union_kind is not Union and union_kind is not UnionType:
        return False

    members = get_args(dtype)
    return len(members) == 2 and type(None) in members


def _unpack_optional_type(dtype: TypeAnnotation) -> TypeAnnotation:
    """Given a type of `Optional[T]`, return `T`.

    Args:
        dtype: A type of `Optional[T]`, `T | None`, or `Union[T, None]`.

    Returns:
        The type `T`. Note that `T` is not necessarily a `type` instance: for
        `Optional[List[int]]` it is the generic alias `List[int]`.

    Raises:
        TypeError: If the input is not a valid `TypeAnnotation` type.
        ValueError: If the input type is not `Optional[T]`.
    """
    if not isinstance(dtype, TYPE_ANNOTATION_TYPES):
        raise TypeError(f"Expected type annotation, got {type(dtype)}: {dtype}")

    if not _is_optional_type(dtype):
        raise ValueError(f"Expected `Optional[T]`, got {type(dtype)}: {dtype}")

    # Types declared as `Optional[T]` or `T | None` should have the non-None type as the first
    # argument. However, it is technically correct for someone to write `None | T`, so we don't
    # make assumptions about the argument ordering and instead pick the first member that is not
    # `NoneType`. `_is_optional_type` has already established that the union has exactly two
    # members, one of which is `NoneType`.
    return next(arg for arg in get_args(dtype) if arg is not type(None))


# NB: `inspect.Parameter.annotation` is typed as `Any`; this guard narrows it for type checkers.
def _is_type_annotation(annotation: Any) -> TypeGuard[TypeAnnotation]:
    """Tell whether an `inspect.Parameter` annotation is an actual type annotation.

    When a parameter is declared **without** a type annotation, its `annotation` attribute is the
    sentinel class `inspect.Parameter.empty`. Any other value is expected to be a type annotation.

    Args:
        annotation: The `annotation` attribute of an `inspect.Parameter` instance.

    Returns:
        False if `annotation` is `inspect.Parameter.empty`.
        True for any other instance of one of `TYPE_ANNOTATION_TYPES`.

    Raises:
        TypeError: If `annotation` is not an instance of one of `TYPE_ANNOTATION_TYPES`.
    """
    if isinstance(annotation, TYPE_ANNOTATION_TYPES):
        # NB: `inspect.Parameter.empty` is itself a class (an instance of `type`), so it passes
        # the instance check above and has to be singled out by identity here.
        return annotation is not inspect.Parameter.empty

    raise TypeError(f"Annotation must be a type, not {type(annotation).__name__}")
Empty file added tests/resources/__init__.py
Empty file.
154 changes: 154 additions & 0 deletions tests/resources/test_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import inspect
import sys
from dataclasses import dataclass
from typing import List
from typing import Any
from typing import Collection, Iterable, Optional, Union, Mapping, Dict, Set, Tuple

import pytest

from latch.resources.workflow import _is_list_of_dataclasses_type
from latch.resources.workflow import _is_valid_samplesheet_parameter_type
from latch.resources.workflow import _is_optional_type
from latch.resources.workflow import _is_type_annotation
from latch.resources.workflow import _unpack_optional_type
from latch.resources.workflow import TypeAnnotation

# Scalar built-in types used as non-collection, non-optional examples.
PRIMITIVE_TYPES: List[type] = [int, float, bool, str]

# Parameterized generics declared through the `typing` aliases.
COLLECTION_TYPES: List[TypeAnnotation] = [
    List[int],
    Dict[str, int],
    Set[int],
    Tuple[int],
    Mapping[str, int],
    Iterable[int],
    Collection[int],
]

# The same kind of generics, declared with built-in (PEP 585) syntax.
if sys.version_info >= (3, 10):
    COLLECTION_TYPES.extend([list[int], dict[str, int], set[int], tuple[int]])

# Every primitive and collection type above, wrapped first as `Optional[T]` and then as
# `Union[T, None]`.
OPTIONAL_TYPES: List[TypeAnnotation] = [
    *(Optional[T] for T in (PRIMITIVE_TYPES + COLLECTION_TYPES)),
    *(Union[T, None] for T in (PRIMITIVE_TYPES + COLLECTION_TYPES)),
]

# ... and once more using PEP 604 syntax, where the interpreter supports it.
if sys.version_info >= (3, 10):
    OPTIONAL_TYPES.extend(T | None for T in (PRIMITIVE_TYPES + COLLECTION_TYPES))


@dataclass
class FakeDataclass:
    """A minimal two-field dataclass used as the list element type in the tests below."""
    # Field order matters: tests in this module construct instances positionally, e.g.
    # `FakeDataclass("hello", 1)`.
    foo: str
    bar: int


# Every supported spelling of "a list of dataclasses" or "an optional list of dataclasses".
SAMPLESHEET_TYPES: List[TypeAnnotation] = [
    List[FakeDataclass],
    Optional[List[FakeDataclass]],
    Union[List[FakeDataclass], None],
]

# Spellings that rely on built-in generics and/or PEP 604 unions.
if sys.version_info >= (3, 10):
    SAMPLESHEET_TYPES.extend(
        [
            list[FakeDataclass],
            Optional[list[FakeDataclass]],
            Union[list[FakeDataclass], None],
            list[FakeDataclass] | None,
            List[FakeDataclass] | None,
        ]
    )

# Annotations that must be rejected: everything that is not a list of dataclasses, plus unions
# and lists that merely *contain* a list of dataclasses or a dataclass among other members.
BAD_SAMPLESHEET_TYPES: List[TypeAnnotation] = [
    *PRIMITIVE_TYPES,
    *COLLECTION_TYPES,
    *OPTIONAL_TYPES,
    Union[List[FakeDataclass], int],
    Union[Optional[List[FakeDataclass]], int],
    Optional[Union[List[FakeDataclass], int]],
    Union[List[FakeDataclass], Optional[int]],
    Union[Optional[int], List[FakeDataclass]],
    List[Union[FakeDataclass, int]],
    List[Optional[FakeDataclass]],
]


@pytest.mark.parametrize("dtype", SAMPLESHEET_TYPES)
def test_is_valid_samplesheet_parameter_type(dtype: TypeAnnotation) -> None:
    """
    A list of dataclasses, or an `Optional` list of dataclasses, should be accepted by
    `_is_valid_samplesheet_parameter_type`.
    """
    accepted = _is_valid_samplesheet_parameter_type(dtype)
    assert accepted is True


@pytest.mark.parametrize("dtype", BAD_SAMPLESHEET_TYPES)
def test_is_valid_samplesheet_parameter_type_rejects_invalid_types(dtype: TypeAnnotation) -> None:
    """
    Any annotation other than a (possibly optional) list of dataclasses should be rejected by
    `_is_valid_samplesheet_parameter_type`.
    """
    accepted = _is_valid_samplesheet_parameter_type(dtype)
    assert accepted is False


def test_is_list_of_dataclasses_type() -> None:
    """
    A `List` parameterized with a dataclass should be accepted by `_is_list_of_dataclasses_type`.
    """
    list_of_dataclasses = List[FakeDataclass]
    assert _is_list_of_dataclasses_type(list_of_dataclasses) is True


@pytest.mark.parametrize(
    "bad_type",
    [
        str,  # A scalar, not a list
        int,  # A scalar, not a list
        List[str],  # A list, but its elements are not dataclasses
        List[int],  # A list, but its elements are not dataclasses
        FakeDataclass,  # A bare dataclass, not a list of them
    ],
)
def test_is_list_of_dataclasses_type_rejects_bad_type(bad_type: type) -> None:
    """
    Types that are not a list of dataclasses should be rejected by `_is_list_of_dataclasses_type`.
    """
    outcome = _is_list_of_dataclasses_type(bad_type)
    assert outcome is False


def test_is_list_of_dataclasses_type_raises_if_not_a_type() -> None:
    """
    `_is_list_of_dataclasses_type` should raise a `TypeError` when given a value (here, an actual
    list holding a dataclass instance) rather than a type.
    """
    not_a_type = [FakeDataclass("hello", 1)]
    with pytest.raises(TypeError):
        _is_list_of_dataclasses_type(not_a_type)


@pytest.mark.parametrize("dtype", OPTIONAL_TYPES)
def test_is_optional_type(dtype: TypeAnnotation) -> None:
    """Every spelling of `Optional[T]` should make `_is_optional_type` return True."""
    verdict = _is_optional_type(dtype)
    assert verdict is True


@pytest.mark.parametrize("dtype", PRIMITIVE_TYPES + COLLECTION_TYPES)
def test_is_optional_type_returns_false_if_not_optional(dtype: TypeAnnotation) -> None:
    """Primitive and collection types are not optional, so `_is_optional_type` returns False."""
    verdict = _is_optional_type(dtype)
    assert verdict is False


@pytest.mark.parametrize("dtype", PRIMITIVE_TYPES + COLLECTION_TYPES)
def test_unpack_optional_type(dtype: TypeAnnotation) -> None:
    """`_unpack_optional_type()` should hand back `T` for every spelling of `Optional[T]`."""
    unpacked_from_optional = _unpack_optional_type(Optional[dtype])
    assert unpacked_from_optional is dtype

    unpacked_from_union = _unpack_optional_type(Union[dtype, None])
    assert unpacked_from_union is dtype

    # The `T | None` spelling is only exercised on interpreters that support it.
    if sys.version_info < (3, 10):
        return

    assert _unpack_optional_type(dtype | None) is dtype



@pytest.mark.parametrize("annotation", PRIMITIVE_TYPES + COLLECTION_TYPES + OPTIONAL_TYPES)
def test_is_type_annotation(annotation: TypeAnnotation) -> None:
    """
    Primitive, collection, and optional type annotations should all make `_is_type_annotation()`
    return True.
    """
    verdict = _is_type_annotation(annotation)
    assert verdict is True


def test_is_type_annotation_returns_false_if_empty() -> None:
    """
    `_is_type_annotation()` should return False for the `Parameter.empty` sentinel, which marks a
    parameter that has no annotation.
    """
    missing_annotation = inspect.Parameter.empty
    assert _is_type_annotation(missing_annotation) is False


@pytest.mark.parametrize(
    "bad_annotation",
    [1, "abc", [1, 2], {"foo": 1}, FakeDataclass("hello", 1)],
)
def test_is_type_annotation_raises_if_annotation_is_not_a_type(bad_annotation: Any) -> None:
    """
    Plain values (numbers, strings, containers, dataclass instances) are not types, so
    `_is_type_annotation()` should raise `TypeError` for them.
    """
    with pytest.raises(TypeError):
        _is_type_annotation(bad_annotation)