# See the License for the specific language governing permissions and
# limitations under the License.

+ from __future__ import annotations
+
import itertools
import logging
import multiprocessing
...
from itertools import zip_longest
from queue import Queue
from threading import Thread
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     Callable,
+     Generator,
+     Sequence,
+     TypedDict,
+     TypeVar,
+     overload,
+ )
+
+ from typing_extensions import NotRequired, TypeAlias, Unpack

from paddle.base.reader import QUEUE_GET_TIMEOUT

+ if TYPE_CHECKING:
+
+     class _ComposeOptions(TypedDict):
+         check_alignment: NotRequired[bool]
+
+
__all__ = []

# On macOS, the 'spawn' start method is now the default in Python3.8 multiprocessing,
...
else:
    fork_context = multiprocessing

+ _T = TypeVar('_T')
+ _T1 = TypeVar('_T1')
+ _T2 = TypeVar('_T2')
+ _T3 = TypeVar('_T3')
+ _T4 = TypeVar('_T4')
+ _U = TypeVar('_U')
+

- def cache(reader):
+ _Reader: TypeAlias = Callable[[], Generator[_T, None, None]]
+
+
+ def cache(reader: _Reader[_T]) -> _Reader[_T]:
    """
    Cache the reader data into memory.

@@ -77,12 +107,60 @@ def cache(reader):
    """
    all_data = tuple(reader())

-     def __impl__():
+     def __impl__() -> Generator[_T, None, None]:
        yield from all_data

    return __impl__


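Rough usage sketch (not part of the diff) for the new `_Reader[_T]` alias and the typed `cache`, assuming both are imported from this module; `toy_reader` is a made-up example reader:

    def toy_reader():              # a _Reader[int]: no-arg callable returning a generator
        yield from range(3)

    cached = cache(toy_reader)     # the source reader is consumed once, right here
    print(list(cached()))          # [0, 1, 2]
    print(list(cached()))          # [0, 1, 2] again, replayed from memory
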
+ # A temporary solution like the builtin map function.
+ # `Map` may be the final solution in the future.
+ # See https://github.com/python/typing/issues/1383
+ @overload
+ def map_readers(
+     func: Callable[[_T1], _U], reader1: _Reader[_T1], /
+ ) -> _Reader[_U]:
+     ...
+
+
+ @overload
+ def map_readers(
+     func: Callable[[_T1, _T2], _U],
+     reader1: _Reader[_T1],
+     reader2: _Reader[_T2],
+     /,
+ ) -> _Reader[_U]:
+     ...
+
+
+ @overload
+ def map_readers(
+     func: Callable[[_T1, _T2, _T3], _U],
+     reader1: _Reader[_T1],
+     reader2: _Reader[_T2],
+     reader3: _Reader[_T3],
+     /,
+ ) -> _Reader[_U]:
+     ...
+
+
+ @overload
+ def map_readers(
+     func: Callable[[_T1, _T2, _T3, _T4], _U],
+     reader1: _Reader[_T1],
+     reader2: _Reader[_T2],
+     reader3: _Reader[_T3],
+     reader4: _Reader[_T4],
+     /,
+ ) -> _Reader[_U]:
+     ...
+
+
+ @overload
+ def map_readers(func: Callable[..., _U], *readers: _Reader[Any]) -> _Reader[_U]:
+     ...
+
+
def map_readers(func, *readers):
    """
    Creates a data reader that outputs return value of function using
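
Rough usage sketch (not part of the diff) for the overloaded `map_readers`: the per-arity overloads let a type checker infer the output type from the mapper, while the runtime behavior stays the same, `func` being applied element-wise across the readers. `xs` and `ys` below are made-up example readers:

    def xs():
        yield from [1, 2, 3]

    def ys():
        yield from [10, 20, 30]

    summed = map_readers(lambda a, b: a + b, xs, ys)   # inferred as _Reader[int]
    print(list(summed()))                              # [11, 22, 33]
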
@@ -124,7 +202,7 @@ def reader():
    return reader


- def shuffle(reader, buf_size):
+ def shuffle(reader: _Reader[_T], buf_size: int) -> _Reader[_T]:
    """
    This API creates a decorated reader that outputs the shuffled data.

@@ -151,7 +229,7 @@ def shuffle(reader, buf_size):
            >>> # outputs are 0~4 unordered arrangement
    """

-     def data_reader():
+     def data_reader() -> Generator[_T, None, None]:
        buf = []
        for e in reader():
            buf.append(e)
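
Rough usage sketch (not part of the diff) for the typed `shuffle`, assuming it is imported from this module; shuffling happens within a buffer of at most `buf_size` samples:

    def toy_reader():
        yield from range(5)

    shuffled = shuffle(toy_reader, buf_size=5)
    print(list(shuffled()))    # 0..4 in some shuffled order (shuffled per buffer)
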
@@ -169,7 +247,7 @@ def data_reader():
    return data_reader


- def chain(*readers):
+ def chain(*readers: _Reader[_T]) -> _Reader[_T]:
    """
    Use the input data readers to create a chained data reader. The newly created reader
    chains the outputs of input readers together as its output, and it does not change
@@ -218,8 +296,8 @@ def chain(*readers):

    """

-     def reader():
-         rs = []
+     def reader() -> Generator[_T, None, None]:
+         rs: list[Generator[_T, None, None]] = []
        for r in readers:
            rs.append(r())

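Rough usage sketch (not part of the diff) for the typed `chain`; per the docstring, the outputs of the input readers are concatenated in order:

    def r0():
        yield from [0, 1]

    def r1():
        yield from [2, 3]

    chained = chain(r0, r1)
    print(list(chained()))     # [0, 1, 2, 3]
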
@@ -232,7 +310,9 @@ class ComposeNotAligned(ValueError):
    pass


- def compose(*readers, **kwargs):
+ def compose(
+     *readers: _Reader[Any], **kwargs: Unpack[_ComposeOptions]
+ ) -> _Reader[Any]:
    """
    Creates a data reader whose output is the combination of input readers.

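Rough usage sketch (not part of the diff) for the typed `compose`; `check_alignment` is the only key of the `_ComposeOptions` TypedDict above, so the `**kwargs` call site is now checked. Per the docstring, each output combines the corresponding items of the input readers (expected output shown as a comment, not verified here):

    def words():
        yield from ['a', 'b']

    def labels():
        yield from [0, 1]

    composed = compose(words, labels, check_alignment=True)
    print(list(composed()))    # expected: [('a', 0), ('b', 1)]
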
@@ -289,7 +369,7 @@ def reader():
    return reader


- def buffered(reader, size):
+ def buffered(reader: _Reader[_T], size: int) -> _Reader[_T]:
    """
    Creates a buffered data reader.

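Rough usage sketch (not part of the diff) for the typed `buffered`; a background thread reads ahead and keeps up to `size` samples queued, while the data itself is unchanged:

    def toy_reader():
        yield from range(10)

    prefetched = buffered(toy_reader, size=4)
    print(list(prefetched()))  # [0, 1, ..., 9], same samples, read ahead of the consumer
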
@@ -339,10 +419,7 @@ def data_reader():
        q = Queue(maxsize=size)
        t = Thread(
            target=read_worker,
-             args=(
-                 r,
-                 q,
-             ),
+             args=(r, q),
        )
        t.daemon = True
        t.start()
@@ -354,7 +431,7 @@ def data_reader():
    return data_reader


- def firstn(reader, n):
+ def firstn(reader: _Reader[_T], n: int) -> _Reader[_T]:
    """

    This API creates a decorated reader, and limits the max number of
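
Rough usage sketch (not part of the diff) for the typed `firstn`; it simply truncates the decorated reader after `n` samples:

    def toy_reader():
        yield from range(100)

    first5 = firstn(toy_reader, 5)
    print(list(first5()))      # [0, 1, 2, 3, 4]
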
@@ -399,7 +476,13 @@ class XmapEndSignal:
    pass


- def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
+ def xmap_readers(
+     mapper: Callable[[_T], _U],
+     reader: _Reader[_T],
+     process_num: int,
+     buffer_size: int,
+     order: bool = False,
+ ) -> _Reader[_U]:
    """
    Use multiple threads to map samples from the reader with a user-defined mapper.

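Rough usage sketch (not part of the diff) for the typed `xmap_readers`; `square` is a made-up mapper. Per the docstring this is a multi-thread implementation, and `order=True` keeps the output in input order:

    def toy_reader():
        yield from range(6)

    def square(x: int) -> int:       # the user-defined mapper, _T -> _U
        return x * x

    mapped = xmap_readers(square, toy_reader, process_num=2, buffer_size=4, order=True)
    print(list(mapped()))            # [0, 1, 4, 9, 16, 25]
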
@@ -495,7 +578,11 @@ def xreader():
    return xreader


- def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
+ def multiprocess_reader(
+     readers: Sequence[_Reader[_T]],
+     use_pipe: bool = True,
+     queue_size: int = 1000,
+ ) -> _Reader[list[_T]]:
    """
    This API uses python ``multiprocessing`` to read data from ``readers`` in parallel,
    and then ``multiprocess.Queue`` or ``multiprocess.Pipe`` is used to merge
@@ -508,13 +595,13 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
    in some platforms.

    Parameters:
-         readers (list( ``generator`` ) | tuple( ``generator`` )): a python ``generator`` list
-             used to read input data
-         use_pipe (bool, optional): control the inner API used to implement the multi-processing,
-             default True - use ``multiprocess.Pipe``, which is recommended
-         queue_size (int, optional): only useful when ``use_pipe`` is False - ``multiprocess.Queue``
-             is used, default 1000. Increasing this value can speed up the data reading, and more memory
-             will be consumed.
+         readers (list( ``generator`` ) | tuple( ``generator`` )): a python ``generator`` list
+             used to read input data
+         use_pipe (bool, optional): control the inner API used to implement the multi-processing,
+             default True - use ``multiprocess.Pipe``, which is recommended
+         queue_size (int, optional): only useful when ``use_pipe`` is False - ``multiprocess.Queue``
+             is used, default 1000. Increasing this value can speed up the data reading, and more memory
+             will be consumed.

    Returns:
        ``generator``: a new reader which can be run in parallel
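
Rough usage sketch (not part of the diff) for the typed `multiprocess_reader`, loosely following its docstring: each reader in the sequence runs in its own process and the parent merges their outputs (via ``multiprocess.Pipe`` by default). `make_range_reader` is a made-up helper:

    def make_range_reader(start, end):
        def __impl__():
            yield from range(start, end)
        return __impl__

    merged = multiprocess_reader(
        [make_range_reader(0, 5), make_range_reader(100, 105)],
        use_pipe=True,
    )

    if __name__ == '__main__':       # standard multiprocessing guard for scripts
        for sample in merged():
            print(sample)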