diff --git a/latch/resources/workflow.py b/latch/resources/workflow.py index 14f84a222..e742e0195 100644 --- a/latch/resources/workflow.py +++ b/latch/resources/workflow.py @@ -1,7 +1,16 @@ import inspect +import sys from dataclasses import is_dataclass from textwrap import dedent -from typing import Callable, Dict, Union, get_args, get_origin +from typing import Any +from typing import Callable +from typing import Dict +from typing import Union +from typing import _GenericAlias # type: ignore[attr-defined] # NB: mypy can't see private attrs +from typing import get_args +from typing import get_origin +from typing_extensions import TypeAlias +from typing_extensions import TypeGuard import click import os @@ -12,6 +21,34 @@ from latch_cli.utils import best_effort_display_name +if sys.version_info >= (3, 10): + from types import UnionType +else: + # NB: `types.UnionType`, available since Python 3.10, is **not** a `type`, but is a class. + # We declare an empty class here to use in the instance checks below. + class UnionType: + pass + + +TypeAnnotation: TypeAlias = Union[type, _GenericAlias, UnionType] +""" +A function parameter's type annotation may be any of the following: + 1) `type`, when declaring any of the built-in Python types + 2) `typing._GenericAlias`, when declaring generic collection types or union types using pre-PEP + 585 and pre-PEP 604 syntax (e.g. `List[int]`, `Optional[int]`, or `Union[int, None]`) + 3) `types.UnionType`, when declaring union types using PEP604 syntax (e.g. `int | None`) + 4) `types.GenericAlias`, when declaring generic collection types using PEP 585 syntax (e.g. + `list[int]`) + +`types.GenericAlias` is a subclass of `type`, but `typing._GenericAlias` and `types.UnionType` are +not and must be considered explicitly. +""" + +# TODO When dropping support for Python 3.9, deprecate this in favor of performing instance checks +# directly on the `TypeAnnotation` union type. 
+TYPE_ANNOTATION_TYPES = (type, _GenericAlias, UnionType) # type: ignore[attr-defined] + + def _generate_metadata(f: Callable) -> LatchMetadata: signature = inspect.signature(f) metadata = LatchMetadata(f.__name__, LatchAuthor()) @@ -33,7 +70,7 @@ def _inject_metadata(f: Callable, metadata: LatchMetadata) -> None: # so that when users call @workflow without any arguments or # parentheses, the workflow still serializes as expected def workflow( - metadata: Union[LatchMetadata, Callable] + metadata: Union[LatchMetadata, Callable], ) -> Union[PythonFunctionWorkflow, Callable]: if isinstance(metadata, Callable): f = metadata @@ -81,16 +118,7 @@ def decorator(f: Callable): if meta_param.samplesheet is not True: continue - annotation = wf_params[name].annotation - - origin = get_origin(annotation) - args = get_args(annotation) - valid = ( - origin is not None - and issubclass(origin, list) - and is_dataclass(args[0]) - ) - if not valid: + if not _is_valid_samplesheet_parameter_type(wf_params[name].annotation): click.secho( f"parameter marked as samplesheet is not valid: {name} " f"in workflow {f.__name__} must be a list of dataclasses", @@ -108,3 +136,142 @@ def decorator(f: Callable): return _workflow(f, wf_name_override=wf_name_override) return decorator + + +def _is_valid_samplesheet_parameter_type(annotation: Any) -> TypeGuard[TypeAnnotation]: + """Check if a workflow parameter is hinted with a valid type for a samplesheet LatchParameter. + + Currently, a samplesheet LatchParameter must be defined as a list of dataclasses, or as an + `Optional` list of dataclasses when the parameter is part of a `ForkBranch`. + + Args: + parameter: A parameter from the workflow function's signature. + + Returns: + True if the parameter is annotated as a list of dataclasses, or as an `Optional` list of + dataclasses. + False otherwise. 
+ """ + # If the parameter did not have a type annotation, short-circuit and return False + if not _is_type_annotation(annotation): + return False + + return _is_list_of_dataclasses_type(annotation) or ( + _is_optional_type(annotation) + and _is_list_of_dataclasses_type(_unpack_optional_type(annotation)) + ) + + +def _is_list_of_dataclasses_type(dtype: TypeAnnotation) -> bool: + """Check if the type is a list of dataclasses. + + Args: + dtype: A type. + + Returns: + True if the type is a list of dataclasses. + False otherwise. + + Raises: + TypeError: If the input is not a valid `TypeAnnotation` type (see above). + """ + if not isinstance(dtype, TYPE_ANNOTATION_TYPES): + raise TypeError(f"Expected type annotation, got {type(dtype)}: {dtype}") + + origin = get_origin(dtype) + args = get_args(dtype) + + return ( + origin is not None + and inspect.isclass(origin) + and issubclass(origin, list) + and len(args) == 1 + and is_dataclass(args[0]) + ) + + +def _is_optional_type(dtype: TypeAnnotation) -> bool: + """Check if a type is `Optional`. + + An optional type may be declared using three syntaxes: `Optional[T]`, `Union[T, None]`, or `T | + None`. All of these syntaxes is supported by this function. + + Args: + dtype: A type. + + Returns: + True if the type is a union type with exactly two elements, one of which is `None`. + False otherwise. + + Raises: + TypeError: If the input is not a valid `TypeAnnotation` type (see above). + """ + if not isinstance(dtype, TYPE_ANNOTATION_TYPES): + raise TypeError(f"Expected type annotation, got {type(dtype)}: {dtype}") + + origin = get_origin(dtype) + args = get_args(dtype) + + # Optional[T] has `typing.Union` as its origin, but PEP604 syntax (e.g. `int | None`) has + # `types.UnionType` as its origin. 
+ return ( + origin is not None + and (origin is Union or origin is UnionType) + and len(args) == 2 + and type(None) in args + ) + + +def _unpack_optional_type(dtype: TypeAnnotation) -> type: + """Given a type of `Optional[T]`, return `T`. + + Args: + dtype: A type of `Optional[T]`, `T | None`, or `Union[T, None]`. + + Returns: + The type `T`. + + Raises: + TypeError: If the input is not a valid `TypeAnnotation` type (see above). + ValueError: If the input type is not `Optional[T]`. + """ + if not isinstance(dtype, TYPE_ANNOTATION_TYPES): + raise TypeError(f"Expected type annotation, got {type(dtype)}: {dtype}") + + if not _is_optional_type(dtype): + raise ValueError(f"Expected `Optional[T]`, got {type(dtype)}: {dtype}") + + # Types declared as `Optional[T]` or `T | None` should have the non-None type as the first + # argument. However, it is technically correct for someone to write `None | T`, so we shouldn't + # make assumptions about the argument ordering. (And I'm not certain the ordering is guaranteed + # anywhere by Python spec.) + base_type = [arg for arg in get_args(dtype) if arg is not type(None)][0] + + return base_type + + +# NB: `inspect.Parameter.annotation` is typed as `Any`, so here we narrow the type. +def _is_type_annotation(annotation: Any) -> TypeGuard[TypeAnnotation]: + """Check if the annotation on an `inspect.Parameter` instance is a type annotation. + + If the corresponding parameter **did not** have a type annotation, `annotation` is set to the + special class variable `inspect.Parameter.empty`. Otherwise, the annotation should be a valid + type annotation. + + Args: + annotation: The annotation on an `inspect.Parameter` instance. + + Returns: + True if the type annotation is not `inspect.Parameter.empty`. + False otherwise. + + Raises: + TypeError: If the annotation is neither a valid `TypeAnnotation` type (see above) nor + `inspect.Parameter.empty`. 
+ """ + # NB: `inspect.Parameter.empty` is a subclass of `type`, so this check passes for unannotated + # parameters. + if not isinstance(annotation, TYPE_ANNOTATION_TYPES): + raise TypeError(f"Annotation must be a type, not {type(annotation).__name__}") + + return annotation is not inspect.Parameter.empty diff --git a/tests/resources/__init__.py b/tests/resources/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/resources/test_workflow.py b/tests/resources/test_workflow.py new file mode 100644 index 000000000..da6d0fd36 --- /dev/null +++ b/tests/resources/test_workflow.py @@ -0,0 +1,154 @@ +import inspect +import sys +from dataclasses import dataclass +from typing import List +from typing import Any +from typing import Collection, Iterable, Optional, Union, Mapping, Dict, Set, Tuple + +import pytest + +from latch.resources.workflow import _is_list_of_dataclasses_type +from latch.resources.workflow import _is_valid_samplesheet_parameter_type +from latch.resources.workflow import _is_optional_type +from latch.resources.workflow import _is_type_annotation +from latch.resources.workflow import _unpack_optional_type +from latch.resources.workflow import TypeAnnotation + +PRIMITIVE_TYPES: List[type] = [int, float, bool, str] +COLLECTION_TYPES: List[TypeAnnotation] = [List[int], Dict[str, int], Set[int], Tuple[int], Mapping[str, int], Iterable[int], Collection[int]] + +if sys.version_info >= (3, 10): + COLLECTION_TYPES += [list[int], dict[str, int], set[int], tuple[int]] + +OPTIONAL_TYPES: List[TypeAnnotation] = [Optional[T] for T in (PRIMITIVE_TYPES + COLLECTION_TYPES)] +OPTIONAL_TYPES += [Union[T, None] for T in (PRIMITIVE_TYPES + COLLECTION_TYPES)] + +if sys.version_info >= (3, 10): + OPTIONAL_TYPES += [T | None for T in (PRIMITIVE_TYPES + COLLECTION_TYPES)] + + +@dataclass +class FakeDataclass: + """A dataclass for testing.""" + foo: str + bar: int + + +# Enumerate the possible ways to declare a list or optional list of dataclasses 
SAMPLESHEET_TYPES: List[TypeAnnotation] = [
    List[FakeDataclass],
    Optional[List[FakeDataclass]],
    Union[List[FakeDataclass], None],
]

# PEP 585 built-in generics (`list[T]`) are subscriptable from Python 3.9, so they do not need to
# wait for the 3.10 gate used by the PEP 604 cases below.
if sys.version_info >= (3, 9):
    SAMPLESHEET_TYPES += [
        list[FakeDataclass],
        Optional[list[FakeDataclass]],
        Union[list[FakeDataclass], None],
    ]

# PEP 604 unions (`T | None`) require Python 3.10.
if sys.version_info >= (3, 10):
    SAMPLESHEET_TYPES += [
        list[FakeDataclass] | None,
        List[FakeDataclass] | None,
    ]

# Everything that must be rejected: non-list types, lists of non-dataclasses, and unions that
# are not exactly "list of dataclasses or None".
BAD_SAMPLESHEET_TYPES: List[TypeAnnotation] = PRIMITIVE_TYPES + COLLECTION_TYPES + OPTIONAL_TYPES
BAD_SAMPLESHEET_TYPES += [
    Union[List[FakeDataclass], int],
    Union[Optional[List[FakeDataclass]], int],
    Optional[Union[List[FakeDataclass], int]],
    Union[List[FakeDataclass], Optional[int]],
    Union[Optional[int], List[FakeDataclass]],
    List[Union[FakeDataclass, int]],
    List[Optional[FakeDataclass]],
]


@pytest.mark.parametrize("dtype", SAMPLESHEET_TYPES)
def test_is_valid_samplesheet_parameter_type(dtype: TypeAnnotation) -> None:
    """
    `_is_valid_samplesheet_parameter_type` should accept a type that is a list of dataclasses, or an
    `Optional` list of dataclasses.
    """
    assert _is_valid_samplesheet_parameter_type(dtype) is True


@pytest.mark.parametrize("dtype", BAD_SAMPLESHEET_TYPES)
def test_is_valid_samplesheet_parameter_type_rejects_invalid_types(dtype: TypeAnnotation) -> None:
    """
    `_is_valid_samplesheet_parameter_type` should reject any other type.
    """
    assert _is_valid_samplesheet_parameter_type(dtype) is False


def test_is_list_of_dataclasses_type() -> None:
    """
    `_is_list_of_dataclasses_type` should accept a type that is a list of dataclasses.
    """
    assert _is_list_of_dataclasses_type(List[FakeDataclass]) is True


@pytest.mark.parametrize(
    "bad_type",
    [
        str,  # Not a list
        int,  # Not a list
        List[str],  # Not a list of dataclasses
        List[int],  # Not a list of dataclasses
        FakeDataclass,  # Not a list
    ],
)
def test_is_list_of_dataclasses_type_rejects_bad_type(bad_type: TypeAnnotation) -> None:
    """
    `_is_list_of_dataclasses_type` should reject anything else.
    """
    assert _is_list_of_dataclasses_type(bad_type) is False


def test_is_list_of_dataclasses_type_raises_if_not_a_type() -> None:
    """
    `_is_list_of_dataclasses_type` should raise a `TypeError` if the input is not a type.
    """
    with pytest.raises(TypeError):
        _is_list_of_dataclasses_type([FakeDataclass("hello", 1)])


@pytest.mark.parametrize("dtype", OPTIONAL_TYPES)
def test_is_optional_type(dtype: TypeAnnotation) -> None:
    """`_is_optional_type` should return True for `Optional[T]` types."""
    assert _is_optional_type(dtype) is True


@pytest.mark.parametrize("dtype", PRIMITIVE_TYPES + COLLECTION_TYPES)
def test_is_optional_type_returns_false_if_not_optional(dtype: TypeAnnotation) -> None:
    """`_is_optional_type` should return False for non-Optional types."""
    assert _is_optional_type(dtype) is False


@pytest.mark.parametrize("dtype", PRIMITIVE_TYPES + COLLECTION_TYPES)
def test_unpack_optional_type(dtype: TypeAnnotation) -> None:
    """`_unpack_optional_type()` should return the base type of `Optional[T]` types."""
    assert _unpack_optional_type(Optional[dtype]) is dtype
    assert _unpack_optional_type(Union[dtype, None]) is dtype
    if sys.version_info >= (3, 10):
        assert _unpack_optional_type(dtype | None) is dtype


@pytest.mark.parametrize("annotation", PRIMITIVE_TYPES + COLLECTION_TYPES + OPTIONAL_TYPES)
def test_is_type_annotation(annotation: TypeAnnotation) -> None:
    """
    `_is_type_annotation()` should return True for any valid type annotation.
    """
    assert _is_type_annotation(annotation) is True


def test_is_type_annotation_returns_false_if_empty() -> None:
    """
    `_is_type_annotation()` should only return False if the annotation is `Parameter.empty`.
    """
    assert _is_type_annotation(inspect.Parameter.empty) is False


@pytest.mark.parametrize(
    "bad_annotation",
    [1, "abc", [1, 2], {"foo": 1}, FakeDataclass("hello", 1)],
)
def test_is_type_annotation_raises_if_annotation_is_not_a_type(bad_annotation: Any) -> None:
    """
    `_is_type_annotation()` should raise `TypeError` for any non-type object.
    """
    with pytest.raises(TypeError):
        _is_type_annotation(bad_annotation)