|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import functools
|
4 |
| -from typing import Callable, Dict, Optional, overload, Tuple |
| 4 | +import typing as t |
5 | 5 |
|
6 |
| -from .pipeline import Pipeline |
| 6 | +from .pipeline import AsyncPipeline, Pipeline |
7 | 7 | from .task import Task
|
8 | 8 |
|
9 | 9 |
|
| 10 | +_P = t.ParamSpec('P') |
| 11 | +_R = t.TypeVar('R') |
| 12 | +_ArgsKwargs: t.TypeAlias = t.Optional[t.Tuple[t.Tuple[t.Any], t.Dict[str, t.Any]]] |
| 13 | + |
| 14 | + |
10 | 15 | class task:
|
11 | 16 | """Decorator class to transform a function into a `Task` object, and then initialize a `Pipeline` with this task.
|
12 | 17 | A Pipeline initialized in this way consists of one Task, and can be piped into other Pipelines.
|
13 | 18 |
|
14 | 19 | The behaviour of each task within a Pipeline is determined by the parameters:
|
15 |
| - `join`: allows the function to take all previous results as input, instead of single results |
16 |
| - `concurrency`: runs the functions with multiple (async or threaded) workers |
17 |
| - `throttle`: limits the number of results the function is able to produce when all consumers are busy |
| 20 | + * `join`: allows the function to take all previous results as input, instead of single results |
| 21 | + * `concurrency`: runs the functions with multiple (async or threaded) workers |
| 22 | + * `throttle`: limits the number of results the function is able to produce when all consumers are busy |
| 23 | + * `daemon`: determines whether threaded workers are daemon threads (cannot be True for async tasks) |
| 24 | + * `bind`: additional args and kwargs to bind to the function when defining a pipeline |
18 | 25 | """
|
19 |
| - @overload |
20 |
| - def __new__(cls, func: None = None, /, *, join: bool = False, concurrency: int = 1, throttle: int = 0, daemon: bool = False, bind: Optional[Tuple[Tuple, Dict]] = None) -> Callable[..., Pipeline]: |
21 |
| - """Enable type hints for functions decorated with `@task()`.""" |
| 26 | + @t.overload |
| 27 | + def __new__( |
| 28 | + cls, |
| 29 | + func: None = None, |
| 30 | + /, |
| 31 | + *, |
| 32 | + join: bool = False, |
| 33 | + concurrency: int = 1, |
| 34 | + throttle: int = 0, |
| 35 | + daemon: bool = False, |
| 36 | + bind: _ArgsKwargs = None |
| 37 | + ) -> t.Type[task]: ... |
22 | 38 |
|
23 |
| - @overload |
24 |
| - def __new__(cls, func: Callable, /, *, join: bool = False, concurrency: int = 1, throttle: int = 0, daemon: bool = False, bind: Optional[Tuple[Tuple, Dict]] = None) -> Pipeline: |
25 |
| - """Enable type hints for functions decorated with `@task`.""" |
| 39 | + @t.overload |
| 40 | + def __new__( |
| 41 | + cls, |
| 42 | + func: t.Callable[_P, t.Awaitable[_R]], |
| 43 | + /, |
| 44 | + *, |
| 45 | + join: bool = False, |
| 46 | + concurrency: int = 1, |
| 47 | + throttle: int = 0, |
| 48 | + daemon: bool = False, |
| 49 | + bind: _ArgsKwargs = None |
| 50 | + ) -> AsyncPipeline[_P, _R]: ... |
26 | 51 |
|
| 52 | + @t.overload |
| 53 | + def __new__( |
| 54 | + cls, |
| 55 | + func: t.Callable[_P, t.AsyncGenerator[_R]], |
| 56 | + /, |
| 57 | + *, |
| 58 | + join: bool = False, |
| 59 | + concurrency: int = 1, |
| 60 | + throttle: int = 0, |
| 61 | + daemon: bool = False, |
| 62 | + bind: _ArgsKwargs = None |
| 63 | + ) -> AsyncPipeline[_P, _R]: ... |
| 64 | + |
| 65 | + @t.overload |
| 66 | + def __new__( |
| 67 | + cls, |
| 68 | + func: t.Callable[_P, t.Generator[_R]], |
| 69 | + /, |
| 70 | + *, |
| 71 | + join: bool = False, |
| 72 | + concurrency: int = 1, |
| 73 | + throttle: int = 0, |
| 74 | + daemon: bool = False, |
| 75 | + bind: _ArgsKwargs = None |
| 76 | + ) -> Pipeline[_P, _R]: ... |
| 77 | + |
| 78 | + @t.overload |
| 79 | + def __new__( |
| 80 | + cls, |
| 81 | + func: t.Callable[_P, _R], |
| 82 | + /, |
| 83 | + *, |
| 84 | + join: bool = False, |
| 85 | + concurrency: int = 1, |
| 86 | + throttle: int = 0, |
| 87 | + daemon: bool = False, |
| 88 | + bind: _ArgsKwargs = None |
| 89 | + ) -> Pipeline[_P, _R]: ... |
| 90 | + |
27 | 91 | def __new__(
|
28 | 92 | cls,
|
29 |
| - func: Optional[Callable] = None, |
| 93 | + func: t.Optional[t.Callable] = None, |
30 | 94 | /,
|
31 | 95 | *,
|
32 | 96 | join: bool = False,
|
33 | 97 | concurrency: int = 1,
|
34 | 98 | throttle: int = 0,
|
35 | 99 | daemon: bool = False,
|
36 |
| - bind: Optional[Tuple[Tuple, Dict]] = None |
| 100 | + bind: _ArgsKwargs = None |
37 | 101 | ):
|
38 | 102 | # Classic decorator trick: @task() means func is None, @task without parentheses means func is passed.
|
39 | 103 | if func is None:
|
40 | 104 | return functools.partial(cls, join=join, concurrency=concurrency, throttle=throttle, daemon=daemon, bind=bind)
|
41 | 105 | return Pipeline([Task(func=func, join=join, concurrency=concurrency, throttle=throttle, daemon=daemon, bind=bind)])
|
42 | 106 |
|
43 | 107 | @staticmethod
|
44 |
| - def bind(*args, **kwargs) -> Optional[Tuple[Tuple, Dict]]: |
| 108 | + def bind(*args, **kwargs) -> _ArgsKwargs: |
45 | 109 | """Utility method, to be used with `functools.partial`."""
|
46 | 110 | if not args and not kwargs:
|
47 | 111 | return None
|
|
0 commit comments