1
1
from __future__ import annotations
2
2
3
3
import inspect
4
- from typing import Callable , List , TYPE_CHECKING
4
+ import typing as t
5
5
6
6
from .async_helper .output import AsyncPipelineOutput
7
7
from .sync_helper .output import PipelineOutput
8
8
9
- if TYPE_CHECKING :
9
+ if t . TYPE_CHECKING :
10
10
from .task import Task
11
11
12
12
13
- class Pipeline :
13
+ _P = t .ParamSpec ('P' )
14
+ _R = t .TypeVar ('R' )
15
+ _P_Other = t .ParamSpec ("P_Other" )
16
+ _R_Other = t .TypeVar ("R_Other" )
17
+
18
+
19
+ class Pipeline (t .Generic [_P , _R ]):
14
20
"""A sequence of at least 1 Tasks.
15
21
16
22
Two pipelines can be piped into another via:
@@ -21,59 +27,83 @@ class Pipeline:
21
27
```
22
28
"""
23
29
24
- def __new__ (cls , tasks : List [Task ]):
30
+ def __new__ (cls , tasks : t . List [Task ]):
25
31
if any (task .is_async for task in tasks ):
26
32
instance = object .__new__ (AsyncPipeline )
27
33
else :
28
34
instance = object .__new__ (cls )
29
35
instance .__init__ (tasks = tasks )
30
36
return instance
31
37
32
- def __init__ (self , tasks : List [Task ]):
38
+ def __init__ (self , tasks : t . List [Task ]):
33
39
self .tasks = tasks
34
40
35
- def __call__ (self , * args , ** kwargs ) :
41
+ def __call__ (self , * args : _P . args , ** kwargs : _P . kwargs ) -> t . Generator [ _R ] :
36
42
"""Return the pipeline output."""
37
43
output = PipelineOutput (self )
38
44
return output (* args , ** kwargs )
45
+
46
+ @t .overload
47
+ def pipe (self : AsyncPipeline [_P , _R ], other : AsyncPipeline [_P_Other , _R_Other ]) -> AsyncPipeline [_P , _R_Other ]: ...
48
+
49
+ @t .overload
50
+ def pipe (self : AsyncPipeline [_P , _R ], other : Pipeline [_P_Other , _R_Other ]) -> AsyncPipeline [_P , _R_Other ]: ...
39
51
40
- def pipe (self , other ) -> Pipeline :
52
+ @t .overload
53
+ def pipe (self , other : AsyncPipeline [_P_Other , _R_Other ]) -> AsyncPipeline [_P , _R_Other ]: ...
54
+
55
+ @t .overload
56
+ def pipe (self , other : Pipeline [_P_Other , _R_Other ]) -> Pipeline [_P , _R_Other ]: ...
57
+
58
+ def pipe (self , other : Pipeline ):
41
59
"""Connect two pipelines, returning a new Pipeline."""
42
60
if not isinstance (other , Pipeline ):
43
61
raise TypeError (f"{ other } of type { type (other )} cannot be piped into a Pipeline" )
44
62
return Pipeline (self .tasks + other .tasks )
45
63
46
- def __or__ (self , other : Pipeline ) -> Pipeline :
47
- """Allow the syntax `pipeline1 | pipeline2`."""
64
+ @t .overload
65
+ def __or__ (self : AsyncPipeline [_P , _R ], other : AsyncPipeline [_P_Other , _R_Other ]) -> AsyncPipeline [_P , _R_Other ]: ...
66
+
67
+ @t .overload
68
+ def __or__ (self : AsyncPipeline [_P , _R ], other : Pipeline [_P_Other , _R_Other ]) -> AsyncPipeline [_P , _R_Other ]: ...
69
+
70
+ @t .overload
71
+ def __or__ (self , other : AsyncPipeline [_P_Other , _R_Other ]) -> AsyncPipeline [_P , _R_Other ]: ...
72
+
73
+ @t .overload
74
+ def __or__ (self , other : Pipeline [_P_Other , _R_Other ]) -> Pipeline [_P , _R_Other ]: ...
75
+
76
+ def __or__ (self , other : Pipeline ):
77
+ """Connect two pipelines, returning a new Pipeline."""
48
78
return self .pipe (other )
49
79
50
- def consume (self , other : Callable ) -> Callable :
80
+ def consume (self , other : t . Callable [..., _R_Other ] ) -> t . Callable [ _P , _R_Other ] :
51
81
"""Connect the pipeline to a consumer function (a callable that takes the pipeline output as input)."""
52
82
if callable (other ):
53
- def consumer (* args , ** kwargs ) :
83
+ def consumer (* args : _P . args , ** kwargs : _P . kwargs ) -> _R_Other :
54
84
return other (self (* args , ** kwargs ))
55
85
return consumer
56
86
raise TypeError (f"{ other } must be a callable that takes a generator" )
57
87
58
- def __gt__ (self , other : Callable ) -> Callable :
59
- """Allow the syntax ` pipeline > consumer` ."""
88
+ def __gt__ (self , other : t . Callable [..., _R_Other ] ) -> t . Callable [ _P , _R_Other ] :
89
+ """Connect the pipeline to a consumer function (a callable that takes the pipeline output as input) ."""
60
90
return self .consume (other )
61
91
62
92
def __repr__ (self ):
63
93
return f"{ self .__class__ .__name__ } { [task .func for task in self .tasks ]} "
64
94
65
95
66
- class AsyncPipeline (Pipeline ):
67
- def __call__ (self , * args , ** kwargs ) :
96
+ class AsyncPipeline (Pipeline [ _P , _R ] ):
97
+ def __call__ (self , * args : _P . args , ** kwargs : _P . kwargs ) -> t . AsyncGenerator [ _R ] :
68
98
"""Return the pipeline output."""
69
99
output = AsyncPipelineOutput (self )
70
100
return output (* args , ** kwargs )
71
-
72
- def consume (self , other : Callable ) -> Callable :
101
+
102
+ def consume (self , other : t . Callable [..., _R_Other ] ) -> t . Callable [ _P , _R_Other ] :
73
103
"""Connect the pipeline to a consumer function (a callable that takes the pipeline output as input)."""
74
104
if callable (other ) and \
75
105
(inspect .iscoroutinefunction (other ) or inspect .iscoroutinefunction (other .__call__ )):
76
- async def consumer (* args , ** kwargs ) :
106
+ async def consumer (* args : _P . args , ** kwargs : _P . kwargs ) -> _R_Other :
77
107
return await other (self (* args , ** kwargs ))
78
108
return consumer
79
109
raise TypeError (f"{ other } must be an async callable that takes an async generator" )
0 commit comments