Skip to content

Commit 8007071

Browse files
committed
add type hint system
1 parent ee80325 commit 8007071

File tree

3 files changed

+127
-32
lines changed

3 files changed

+127
-32
lines changed

.coveragerc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
exclude_also =
33
# pragma: no cover
44
if TYPE_CHECKING:
5+
if t.TYPE_CHECKING:
56
raise NotImplementedError

src/pyper/_core/decorators.py

Lines changed: 78 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,111 @@
11
from __future__ import annotations
22

33
import functools
4-
from typing import Callable, Dict, Optional, overload, Tuple
4+
import typing as t
55

6-
from .pipeline import Pipeline
6+
from .pipeline import AsyncPipeline, Pipeline
77
from .task import Task
88

99

10+
_P = t.ParamSpec('P')
11+
_R = t.TypeVar('R')
12+
_ArgsKwargs: t.TypeAlias = t.Optional[t.Tuple[t.Tuple[t.Any], t.Dict[str, t.Any]]]
13+
14+
1015
class task:
1116
"""Decorator class to transform a function into a `Task` object, and then initialize a `Pipeline` with this task.
1217
A Pipeline initialized in this way consists of one Task, and can be piped into other Pipelines.
1318
1419
The behaviour of each task within a Pipeline is determined by the parameters:
15-
`join`: allows the function to take all previous results as input, instead of single results
16-
`concurrency`: runs the functions with multiple (async or threaded) workers
17-
`throttle`: limits the number of results the function is able to produce when all consumers are busy
20+
* `join`: allows the function to take all previous results as input, instead of single results
21+
* `concurrency`: runs the functions with multiple (async or threaded) workers
22+
* `throttle`: limits the number of results the function is able to produce when all consumers are busy
23+
* `daemon`: determines whether threaded workers are daemon threads (cannot be True for async tasks)
24+
* `bind`: additional args and kwargs to bind to the function when defining a pipeline
1825
"""
19-
@overload
20-
def __new__(cls, func: None = None, /, *, join: bool = False, concurrency: int = 1, throttle: int = 0, daemon: bool = False, bind: Optional[Tuple[Tuple, Dict]] = None) -> Callable[..., Pipeline]:
21-
"""Enable type hints for functions decorated with `@task()`."""
26+
@t.overload
27+
def __new__(
28+
cls,
29+
func: None = None,
30+
/,
31+
*,
32+
join: bool = False,
33+
concurrency: int = 1,
34+
throttle: int = 0,
35+
daemon: bool = False,
36+
bind: _ArgsKwargs = None
37+
) -> t.Type[task]: ...
2238

23-
@overload
24-
def __new__(cls, func: Callable, /, *, join: bool = False, concurrency: int = 1, throttle: int = 0, daemon: bool = False, bind: Optional[Tuple[Tuple, Dict]] = None) -> Pipeline:
25-
"""Enable type hints for functions decorated with `@task`."""
39+
@t.overload
40+
def __new__(
41+
cls,
42+
func: t.Callable[_P, t.Awaitable[_R]],
43+
/,
44+
*,
45+
join: bool = False,
46+
concurrency: int = 1,
47+
throttle: int = 0,
48+
daemon: bool = False,
49+
bind: _ArgsKwargs = None
50+
) -> AsyncPipeline[_P, _R]: ...
2651

52+
@t.overload
53+
def __new__(
54+
cls,
55+
func: t.Callable[_P, t.AsyncGenerator[_R]],
56+
/,
57+
*,
58+
join: bool = False,
59+
concurrency: int = 1,
60+
throttle: int = 0,
61+
daemon: bool = False,
62+
bind: _ArgsKwargs = None
63+
) -> AsyncPipeline[_P, _R]: ...
64+
65+
@t.overload
66+
def __new__(
67+
cls,
68+
func: t.Callable[_P, t.Generator[_R]],
69+
/,
70+
*,
71+
join: bool = False,
72+
concurrency: int = 1,
73+
throttle: int = 0,
74+
daemon: bool = False,
75+
bind: _ArgsKwargs = None
76+
) -> Pipeline[_P, _R]: ...
77+
78+
@t.overload
79+
def __new__(
80+
cls,
81+
func: t.Callable[_P, _R],
82+
/,
83+
*,
84+
join: bool = False,
85+
concurrency: int = 1,
86+
throttle: int = 0,
87+
daemon: bool = False,
88+
bind: _ArgsKwargs = None
89+
) -> Pipeline[_P, _R]: ...
90+
2791
def __new__(
2892
cls,
29-
func: Optional[Callable] = None,
93+
func: t.Optional[t.Callable] = None,
3094
/,
3195
*,
3296
join: bool = False,
3397
concurrency: int = 1,
3498
throttle: int = 0,
3599
daemon: bool = False,
36-
bind: Optional[Tuple[Tuple, Dict]] = None
100+
bind: _ArgsKwargs = None
37101
):
38102
# Classic decorator trick: @task() means func is None, @task without parentheses means func is passed.
39103
if func is None:
40104
return functools.partial(cls, join=join, concurrency=concurrency, throttle=throttle, daemon=daemon, bind=bind)
41105
return Pipeline([Task(func=func, join=join, concurrency=concurrency, throttle=throttle, daemon=daemon, bind=bind)])
42106

43107
@staticmethod
44-
def bind(*args, **kwargs) -> Optional[Tuple[Tuple, Dict]]:
108+
def bind(*args, **kwargs) -> _ArgsKwargs:
45109
"""Utility method, to be used with `functools.partial`."""
46110
if not args and not kwargs:
47111
return None

src/pyper/_core/pipeline.py

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
from __future__ import annotations
22

33
import inspect
4-
from typing import Callable, List, TYPE_CHECKING
4+
import typing as t
55

66
from .async_helper.output import AsyncPipelineOutput
77
from .sync_helper.output import PipelineOutput
88

9-
if TYPE_CHECKING:
9+
if t.TYPE_CHECKING:
1010
from .task import Task
1111

1212

13-
class Pipeline:
13+
_P = t.ParamSpec('P')
14+
_R = t.TypeVar('R')
15+
_P_Other = t.ParamSpec("P_Other")
16+
_R_Other = t.TypeVar("R_Other")
17+
18+
19+
class Pipeline(t.Generic[_P, _R]):
1420
"""A sequence of at least 1 Tasks.
1521
1622
Two pipelines can be piped into another via:
@@ -21,59 +27,83 @@ class Pipeline:
2127
```
2228
"""
2329

24-
def __new__(cls, tasks: List[Task]):
30+
def __new__(cls, tasks: t.List[Task]):
2531
if any(task.is_async for task in tasks):
2632
instance = object.__new__(AsyncPipeline)
2733
else:
2834
instance = object.__new__(cls)
2935
instance.__init__(tasks=tasks)
3036
return instance
3137

32-
def __init__(self, tasks: List[Task]):
38+
def __init__(self, tasks: t.List[Task]):
3339
self.tasks = tasks
3440

35-
def __call__(self, *args, **kwargs):
41+
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> t.Generator[_R]:
3642
"""Return the pipeline output."""
3743
output = PipelineOutput(self)
3844
return output(*args, **kwargs)
45+
46+
@t.overload
47+
def pipe(self: AsyncPipeline[_P, _R], other: AsyncPipeline[_P_Other, _R_Other]) -> AsyncPipeline[_P, _R_Other]: ...
48+
49+
@t.overload
50+
def pipe(self: AsyncPipeline[_P, _R], other: Pipeline[_P_Other, _R_Other]) -> AsyncPipeline[_P, _R_Other]: ...
3951

40-
def pipe(self, other) -> Pipeline:
52+
@t.overload
53+
def pipe(self, other: AsyncPipeline[_P_Other, _R_Other]) -> AsyncPipeline[_P, _R_Other]: ...
54+
55+
@t.overload
56+
def pipe(self, other: Pipeline[_P_Other, _R_Other]) -> Pipeline[_P, _R_Other]: ...
57+
58+
def pipe(self, other: Pipeline):
4159
"""Connect two pipelines, returning a new Pipeline."""
4260
if not isinstance(other, Pipeline):
4361
raise TypeError(f"{other} of type {type(other)} cannot be piped into a Pipeline")
4462
return Pipeline(self.tasks + other.tasks)
4563

46-
def __or__(self, other: Pipeline) -> Pipeline:
47-
"""Allow the syntax `pipeline1 | pipeline2`."""
64+
@t.overload
65+
def __or__(self: AsyncPipeline[_P, _R], other: AsyncPipeline[_P_Other, _R_Other]) -> AsyncPipeline[_P, _R_Other]: ...
66+
67+
@t.overload
68+
def __or__(self: AsyncPipeline[_P, _R], other: Pipeline[_P_Other, _R_Other]) -> AsyncPipeline[_P, _R_Other]: ...
69+
70+
@t.overload
71+
def __or__(self, other: AsyncPipeline[_P_Other, _R_Other]) -> AsyncPipeline[_P, _R_Other]: ...
72+
73+
@t.overload
74+
def __or__(self, other: Pipeline[_P_Other, _R_Other]) -> Pipeline[_P, _R_Other]: ...
75+
76+
def __or__(self, other: Pipeline):
77+
"""Connect two pipelines, returning a new Pipeline."""
4878
return self.pipe(other)
4979

50-
def consume(self, other: Callable) -> Callable:
80+
def consume(self, other: t.Callable[..., _R_Other]) -> t.Callable[_P, _R_Other]:
5181
"""Connect the pipeline to a consumer function (a callable that takes the pipeline output as input)."""
5282
if callable(other):
53-
def consumer(*args, **kwargs):
83+
def consumer(*args: _P.args, **kwargs: _P.kwargs) -> _R_Other:
5484
return other(self(*args, **kwargs))
5585
return consumer
5686
raise TypeError(f"{other} must be a callable that takes a generator")
5787

58-
def __gt__(self, other: Callable) -> Callable:
59-
"""Allow the syntax `pipeline > consumer`."""
88+
def __gt__(self, other: t.Callable[..., _R_Other]) -> t.Callable[_P, _R_Other]:
89+
"""Connect the pipeline to a consumer function (a callable that takes the pipeline output as input)."""
6090
return self.consume(other)
6191

6292
def __repr__(self):
6393
return f"{self.__class__.__name__} {[task.func for task in self.tasks]}"
6494

6595

66-
class AsyncPipeline(Pipeline):
67-
def __call__(self, *args, **kwargs):
96+
class AsyncPipeline(Pipeline[_P, _R]):
97+
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> t.AsyncGenerator[_R]:
6898
"""Return the pipeline output."""
6999
output = AsyncPipelineOutput(self)
70100
return output(*args, **kwargs)
71-
72-
def consume(self, other: Callable) -> Callable:
101+
102+
def consume(self, other: t.Callable[..., _R_Other]) -> t.Callable[_P, _R_Other]:
73103
"""Connect the pipeline to a consumer function (a callable that takes the pipeline output as input)."""
74104
if callable(other) and \
75105
(inspect.iscoroutinefunction(other) or inspect.iscoroutinefunction(other.__call__)):
76-
async def consumer(*args, **kwargs):
106+
async def consumer(*args: _P.args, **kwargs: _P.kwargs) -> _R_Other:
77107
return await other(self(*args, **kwargs))
78108
return consumer
79109
raise TypeError(f"{other} must be an async callable that takes an async generator")

0 commit comments

Comments
 (0)