Skip to content

Commit 939d3e1

Browse files
committed
feat: Support Optional samplesheet types
1 parent f60f342 commit 939d3e1

File tree

2 files changed

+282
-6
lines changed

2 files changed

+282
-6
lines changed

latch/resources/workflow.py

+158-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import inspect
2+
import sys
3+
import typing
24
from dataclasses import is_dataclass
35
from textwrap import dedent
4-
from typing import Callable, Dict, Union, get_args, get_origin
6+
from types import UnionType
7+
from typing import Any, Callable, Dict, Union, get_args, get_origin
58

69
import click
710
import os
@@ -12,6 +15,42 @@
1215
from latch_cli.utils import best_effort_display_name
1316

1417

18+
if sys.version_info >= (3, 10):
19+
from typing import TypeAlias
20+
from typing import TypeGuard
21+
from types import UnionType
22+
else:
23+
from typing_extensions import TypeAlias
24+
from typing_extensions import TypeGuard
25+
26+
# NB: `types.UnionType`, available since Python 3.10, is **not** a `type`, but is a class.
27+
# We declare an empty class here to use in the instance checks below.
28+
class UnionType:
29+
pass
30+
31+
32+
# NB: since `_GenericAlias` is a private attribute of the `typing` module, mypy doesn't find it
33+
TypeAnnotation: TypeAlias = Union[type, typing._GenericAlias, UnionType] # type: ignore[name-defined]
34+
"""
35+
A function parameter's type annotation may be any of the following:
36+
1) `type`, when declaring any of the built-in Python types
37+
2) `typing._GenericAlias`, when declaring generic collection types or union types using pre-PEP
38+
585 and pre-PEP 604 syntax (e.g. `List[int]`, `Optional[int]`, or `Union[int, None]`)
39+
3) `types.UnionType`, when declaring union types using PEP604 syntax (e.g. `int | None`)
40+
4) `types.GenericAlias`, when declaring generic collection types using PEP 585 syntax (e.g.
41+
`list[int]`)
42+
43+
`types.GenericAlias` is a subclass of `type`, but `typing._GenericAlias` and `types.UnionType` are
44+
not and must be considered explicitly.
45+
"""
46+
47+
# TODO When dropping support for Python 3.9, deprecate this in favor of performing instance checks
48+
# directly on the `TypeAnnotation` union type.
49+
# NB: since `_GenericAlias` is a private attribute of the `typing` module, mypy doesn't find it
50+
TYPE_ANNOTATION_TYPES = (type, typing._GenericAlias, UnionType) # type: ignore[attr-defined]
51+
52+
53+
1554
def _generate_metadata(f: Callable) -> LatchMetadata:
1655
signature = inspect.signature(f)
1756
metadata = LatchMetadata(f.__name__, LatchAuthor())
@@ -108,12 +147,125 @@ def _is_valid_samplesheet_parameter_type(parameter: inspect.Parameter) -> bool:
108147
"""
109148
annotation = parameter.annotation
110149

111-
origin = get_origin(annotation)
112-
args = get_args(annotation)
113-
valid = (
114-
origin is not None
150+
# If the parameter did not have a type annotation, short-circuit and return False
151+
if not _is_type_annotation(annotation):
152+
return False
153+
154+
return (
155+
_is_list_of_dataclasses_type(annotation)
156+
or (_is_optional_type(annotation) and _is_list_of_dataclasses_type(_unpack_optional_type(annotation)))
157+
)
158+
159+
160+
def _is_list_of_dataclasses_type(dtype: TypeAnnotation) -> bool:
161+
"""
162+
Check if the type is a list of dataclasses.
163+
164+
Args:
165+
dtype: A type.
166+
167+
Returns:
168+
True if the type is a list of dataclasses.
169+
False otherwise.
170+
171+
Raises:
172+
TypeError: If the input is not a `type`.
173+
"""
174+
if not isinstance(dtype, TYPE_ANNOTATION_TYPES):
175+
raise TypeError(f"Expected `type`, got {type(dtype)}: {dtype}")
176+
177+
origin = get_origin(dtype)
178+
args = get_args(dtype)
179+
180+
return (
181+
not _is_optional_type(dtype)
182+
and origin is not None
115183
and issubclass(origin, list)
184+
and len(args) == 1
116185
and is_dataclass(args[0])
117186
)
118187

119-
return valid
188+
189+
def _is_optional_type(dtype: TypeAnnotation) -> bool:
190+
"""
191+
Check if a type is `Optional`.
192+
193+
An optional type may be declared using three syntaxes: `Optional[T]`, `Union[T, None]`, or `T |
194+
None`. All of these syntaxes is supported by this function.
195+
196+
Args:
197+
dtype: A type.
198+
199+
Returns:
200+
True if the type is a union type with exactly two elements, one of which is `None`.
201+
False otherwise.
202+
203+
Raises:
204+
TypeError: If the input is not a `type`.
205+
"""
206+
if not isinstance(dtype, TYPE_ANNOTATION_TYPES):
207+
raise TypeError(f"Expected `type`, got {type(dtype)}: {dtype}")
208+
209+
origin = get_origin(dtype)
210+
args = get_args(dtype)
211+
212+
# Optional[T] has `typing.Union` as its origin, but PEP604 syntax (e.g. `int | None`) has
213+
# `types.UnionType` as its origin.
214+
return (origin is Union or origin is UnionType) and len(args) == 2 and type(None) in args
215+
216+
217+
def _unpack_optional_type(dtype: TypeAnnotation) -> type:
218+
"""
219+
Given a type of `Optional[T]`, return `T`.
220+
221+
Args:
222+
dtype: A type of `Optional[T]`, `T | None`, or `Union[T, None]`.
223+
224+
Returns:
225+
The type `T`.
226+
227+
Raises:
228+
TypeError: If the input is not a `type`.
229+
ValueError: If the input type is not `Optional[T]`.
230+
"""
231+
if not isinstance(dtype, TYPE_ANNOTATION_TYPES):
232+
raise TypeError(f"Expected `type`, got {type(dtype)}: {dtype}")
233+
234+
if not _is_optional_type(dtype):
235+
raise ValueError(f"Expected Optional[T], got {type(dtype)}: {dtype}")
236+
237+
args = get_args(dtype)
238+
239+
# Types declared as `Optional[T]` or `T | None` should have the non-None type as the first
240+
# argument. However, it is technically correct for someone to write `None | T`, so we shouldn't
241+
# make assumptions about the argument ordering. (And I'm not certain the ordering is guaranteed
242+
# anywhere by Python spec.)
243+
base_type = [arg for arg in args if arg is not type(None)][0]
244+
245+
return base_type
246+
247+
248+
def _is_type_annotation(annotation: Any) -> TypeGuard[TypeAnnotation]:
249+
"""
250+
Check if the annotation on an `inspect.Parameter` instance is a type annotation.
251+
252+
If the corresponding parameter **did not** have a type annotation, `annotation` is set to the
253+
special class variable `Parameter.empty`.
254+
255+
NB: `Parameter.empty` itself is a subclass of `type`
256+
Otherwise, the annotation is assumed to be a type.
257+
258+
Args:
259+
annotation: The annotation on an `inspect.Parameter` instance.
260+
261+
Returns:
262+
True if the annotation is not `Parameter.empty`.
263+
False otherwise.
264+
265+
Raises:
266+
TypeError: If the annotation is neither a type nor `Parameter.empty`.
267+
"""
268+
if not isinstance(annotation, TYPE_ANNOTATION_TYPES):
269+
raise TypeError(f"Annotation must be a type, not {type(annotation).__name__}")
270+
271+
return annotation is not inspect.Parameter.empty

tests/resources/test_workflow.py

+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import inspect
2+
import sys
3+
import typing
4+
from dataclasses import dataclass
5+
from typing import List
6+
from typing import Any
7+
from typing import Collection, Iterable, Optional, Union, Mapping, Dict, Set, Tuple
8+
9+
import pytest
10+
11+
from anno import _is_list_of_dataclasses_type
12+
from anno import _is_valid_samplesheet_parameter_type
13+
from anno import _is_optional_type
14+
from anno import _is_type_annotation
15+
from anno import _unpack_optional_type
16+
from anno import TypeAnnotation
17+
18+
PRIMITIVE_TYPES = [int, float, bool, str]
19+
COLLECTION_TYPES = [List[int], Dict[str, int], Set[int], Tuple[int], Mapping[str, int], Iterable[int], Collection[int]]
20+
21+
if sys.version_info >= (3, 10):
22+
COLLECTION_TYPES += [list[int], dict[str, int], set[int], tuple[int]]
23+
24+
OPTIONAL_TYPES = [Optional[T] for T in (PRIMITIVE_TYPES + COLLECTION_TYPES)]
25+
OPTIONAL_TYPES += [Union[T, None] for T in (PRIMITIVE_TYPES + COLLECTION_TYPES)]
26+
27+
28+
@dataclass
29+
class FakeDataclass:
30+
"""A dataclass for testing."""
31+
foo: str
32+
bar: int
33+
34+
35+
@pytest.mark.parametrize(
36+
"dtype",
37+
[
38+
List[FakeDataclass],
39+
Optional[List[FakeDataclass]],
40+
Union[List[FakeDataclass], None],
41+
]
42+
)
43+
def test_is_valid_samplesheet_parameter_type(dtype: TypeAnnotation) -> None:
44+
"""
45+
`_is_valid_samplesheet_parameter_type` should accept a type that is a list of dataclasses, or an
46+
`Optional` list of dataclasses.
47+
"""
48+
parameter = inspect.Parameter("foo", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=dtype)
49+
assert _is_valid_samplesheet_parameter_type(parameter) is True
50+
51+
52+
def test_is_list_of_dataclasses_type() -> None:
53+
"""
54+
`_is_list_of_dataclasses_type` should accept a type that is a list of dataclasses.
55+
"""
56+
assert _is_list_of_dataclasses_type(List[FakeDataclass]) is True
57+
58+
59+
@pytest.mark.parametrize("bad_type", [
60+
str, # Not a list
61+
int, # Not a list
62+
List[str], # Not a list of dataclasses
63+
List[int], # Not a list of dataclasses
64+
FakeDataclass, # Not a list
65+
])
66+
def test_is_list_of_dataclasses_type_rejects_bad_type(bad_type: type) -> None:
67+
"""
68+
`_is_list_of_dataclasses_type` should reject anything else.
69+
"""
70+
assert _is_list_of_dataclasses_type(bad_type) is False
71+
72+
73+
def test_is_list_of_dataclasses_type_raises_if_not_a_type() -> None:
74+
"""
75+
`is_list_of_dataclasses_type` should raise a `TypeError` if the input is not a type.
76+
"""
77+
with pytest.raises(TypeError):
78+
_is_list_of_dataclasses_type([FakeDataclass("hello", 1)])
79+
80+
81+
@pytest.mark.parametrize("dtype", OPTIONAL_TYPES)
82+
def test_is_optional_type(dtype: TypeAnnotation) -> None:
83+
"""`_is_optional_type` should return True for `Optional[T]` types."""
84+
assert _is_optional_type(dtype) is True
85+
86+
87+
@pytest.mark.parametrize("dtype", PRIMITIVE_TYPES + COLLECTION_TYPES)
88+
def test_is_optional_type_returns_false_if_not_optional(dtype: TypeAnnotation) -> None:
89+
"""`_is_optional_type` should return False for non-Optional types."""
90+
assert _is_optional_type(dtype) is False
91+
92+
93+
@pytest.mark.parametrize("dtype", PRIMITIVE_TYPES + COLLECTION_TYPES)
94+
def test_unpack_optional_type(dtype: TypeAnnotation) -> None:
95+
"""`_unpack_optional_type()` should return the base type of `Optional[T]` types."""
96+
assert _unpack_optional_type(Optional[dtype]) is dtype
97+
assert _unpack_optional_type(Union[dtype, None]) is dtype
98+
if sys.version_info >= (3, 10):
99+
assert _unpack_optional_type(dtype | None) is dtype
100+
101+
102+
103+
@pytest.mark.parametrize("annotation", PRIMITIVE_TYPES + COLLECTION_TYPES + OPTIONAL_TYPES)
104+
def test_is_type_annotation(annotation: TypeAnnotation) -> None:
105+
"""
106+
`_is_type_annotation()` should return True for any valid type annotation.
107+
"""
108+
assert _is_type_annotation(annotation) is True
109+
110+
111+
def test_is_type_annotation_returns_false_if_empty() -> None:
112+
"""
113+
`_is_type_annotation()` should only return False if the annotation is `Parameter.empty`.
114+
"""
115+
assert _is_type_annotation(inspect.Parameter.empty) is False
116+
117+
118+
@pytest.mark.parametrize("bad_annotation", [1, "abc", [1, 2], {"foo": 1}, FakeDataclass("hello", 1)])
119+
def test_is_type_annotation_raises_if_annotation_is_not_a_type(bad_annotation: Any) -> None:
120+
"""
121+
`_is_type_annotation()` should raise `TypeError` for any non-type object.
122+
"""
123+
with pytest.raises(TypeError):
124+
_is_type_annotation(bad_annotation)

0 commit comments

Comments
 (0)