Skip to content

Commit 4de8da3

Browse files
SigureMolixcli
authored andcommitted
[Typing][B-93] Add type annotations for python/paddle/reader/decorator.py (PaddlePaddle#66305)
* [Typing] Add type annotations for `python/paddle/reader/decorator.py` * missing pep563
1 parent b788ae2 commit 4de8da3

File tree

1 file changed

+110
-23
lines changed

1 file changed

+110
-23
lines changed

python/paddle/reader/decorator.py

Lines changed: 110 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import itertools
1618
import logging
1719
import multiprocessing
@@ -21,9 +23,27 @@
2123
from itertools import zip_longest
2224
from queue import Queue
2325
from threading import Thread
26+
from typing import (
27+
TYPE_CHECKING,
28+
Any,
29+
Callable,
30+
Generator,
31+
Sequence,
32+
TypedDict,
33+
TypeVar,
34+
overload,
35+
)
36+
37+
from typing_extensions import NotRequired, TypeAlias, Unpack
2438

2539
from paddle.base.reader import QUEUE_GET_TIMEOUT
2640

41+
if TYPE_CHECKING:
42+
43+
class _ComposeOptions(TypedDict):
44+
check_alignment: NotRequired[bool]
45+
46+
2747
__all__ = []
2848

2949
# On macOS, the 'spawn' start method is now the default in Python3.8 multiprocessing,
@@ -41,8 +61,18 @@
4161
else:
4262
fork_context = multiprocessing
4363

64+
_T = TypeVar('_T')
65+
_T1 = TypeVar('_T1')
66+
_T2 = TypeVar('_T2')
67+
_T3 = TypeVar('_T3')
68+
_T4 = TypeVar('_T4')
69+
_U = TypeVar('_U')
70+
4471

45-
def cache(reader):
72+
_Reader: TypeAlias = Callable[[], Generator[_T, None, None]]
73+
74+
75+
def cache(reader: _Reader[_T]) -> _Reader[_T]:
4676
"""
4777
Cache the reader data into memory.
4878
@@ -77,12 +107,60 @@ def cache(reader):
77107
"""
78108
all_data = tuple(reader())
79109

80-
def __impl__():
110+
def __impl__() -> Generator[_T, None, None]:
81111
yield from all_data
82112

83113
return __impl__
84114

85115

116+
# A temporary solution like builtin map function.
117+
# `Map` maybe the final solution in the future.
118+
# See https://github.com/python/typing/issues/1383
119+
@overload
120+
def map_readers(
121+
func: Callable[[_T1], _U], reader1: _Reader[_T1], /
122+
) -> _Reader[_U]:
123+
...
124+
125+
126+
@overload
127+
def map_readers(
128+
func: Callable[[_T1, _T2], _U],
129+
reader1: _Reader[_T1],
130+
reader2: _Reader[_T2],
131+
/,
132+
) -> _Reader[_U]:
133+
...
134+
135+
136+
@overload
137+
def map_readers(
138+
func: Callable[[_T1, _T2, _T3], _U],
139+
reader1: _Reader[_T1],
140+
reader2: _Reader[_T2],
141+
reader3: _Reader[_T3],
142+
/,
143+
) -> _Reader[_U]:
144+
...
145+
146+
147+
@overload
148+
def map_readers(
149+
func: Callable[[_T1, _T2, _T3, _T4], _U],
150+
reader1: _Reader[_T1],
151+
reader2: _Reader[_T2],
152+
reader3: _Reader[_T3],
153+
reader4: _Reader[_T4],
154+
/,
155+
) -> _Reader[_U]:
156+
...
157+
158+
159+
@overload
160+
def map_readers(func: Callable[..., _U], *readers: _Reader[Any]) -> _Reader[_U]:
161+
...
162+
163+
86164
def map_readers(func, *readers):
87165
"""
88166
Creates a data reader that outputs return value of function using
@@ -124,7 +202,7 @@ def reader():
124202
return reader
125203

126204

127-
def shuffle(reader, buf_size):
205+
def shuffle(reader: _Reader[_T], buf_size: int) -> _Reader[_T]:
128206
"""
129207
This API creates a decorated reader that outputs the shuffled data.
130208
@@ -151,7 +229,7 @@ def shuffle(reader, buf_size):
151229
>>> # outputs are 0~4 unordered arrangement
152230
"""
153231

154-
def data_reader():
232+
def data_reader() -> Generator[_T, None, None]:
155233
buf = []
156234
for e in reader():
157235
buf.append(e)
@@ -169,7 +247,7 @@ def data_reader():
169247
return data_reader
170248

171249

172-
def chain(*readers):
250+
def chain(*readers: _Reader[_T]) -> _Reader[_T]:
173251
"""
174252
Use the input data readers to create a chained data reader. The new created reader
175253
chains the outputs of input readers together as its output, and it do not change
@@ -218,8 +296,8 @@ def chain(*readers):
218296
219297
"""
220298

221-
def reader():
222-
rs = []
299+
def reader() -> Generator[_T, None, None]:
300+
rs: list[Generator[_T, None, None]] = []
223301
for r in readers:
224302
rs.append(r())
225303

@@ -232,7 +310,9 @@ class ComposeNotAligned(ValueError):
232310
pass
233311

234312

235-
def compose(*readers, **kwargs):
313+
def compose(
314+
*readers: _Reader[Any], **kwargs: Unpack[_ComposeOptions]
315+
) -> _Reader[Any]:
236316
"""
237317
Creates a data reader whose output is the combination of input readers.
238318
@@ -289,7 +369,7 @@ def reader():
289369
return reader
290370

291371

292-
def buffered(reader, size):
372+
def buffered(reader: _Reader[_T], size: int) -> _Reader[_T]:
293373
"""
294374
Creates a buffered data reader.
295375
@@ -339,10 +419,7 @@ def data_reader():
339419
q = Queue(maxsize=size)
340420
t = Thread(
341421
target=read_worker,
342-
args=(
343-
r,
344-
q,
345-
),
422+
args=(r, q),
346423
)
347424
t.daemon = True
348425
t.start()
@@ -354,7 +431,7 @@ def data_reader():
354431
return data_reader
355432

356433

357-
def firstn(reader, n):
434+
def firstn(reader: _Reader[_T], n: int) -> _Reader[_T]:
358435
"""
359436
360437
This API creates a decorated reader, and limits the max number of
@@ -399,7 +476,13 @@ class XmapEndSignal:
399476
pass
400477

401478

402-
def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
479+
def xmap_readers(
480+
mapper: Callable[[_T], _U],
481+
reader: _Reader[_T],
482+
process_num: int,
483+
buffer_size: int,
484+
order: bool = False,
485+
) -> _Reader[_U]:
403486
"""
404487
Use multi-threads to map samples from reader by a mapper defined by user.
405488
@@ -495,7 +578,11 @@ def xreader():
495578
return xreader
496579

497580

498-
def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
581+
def multiprocess_reader(
582+
readers: Sequence[_Reader[_T]],
583+
use_pipe: bool = True,
584+
queue_size: int = 1000,
585+
) -> _Reader[list[_T]]:
499586
"""
500587
This API use python ``multiprocessing`` to read data from ``readers`` parallelly,
501588
and then ``multiprocess.Queue`` or ``multiprocess.Pipe`` is used to merge
@@ -508,13 +595,13 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
508595
in some platforms.
509596
510597
Parameters:
511-
readers (list( ``generator`` ) | tuple( ``generator`` )): a python ``generator`` list
512-
used to read input data
513-
use_pipe (bool, optional): control the inner API used to implement the multi-processing,
514-
default True - use ``multiprocess.Pipe`` which is recommended
515-
queue_size (int, optional): only useful when ``use_pipe`` is False - ``multiprocess.Queue``
516-
is used, default 1000. Increase this value can speed up the data reading, and more memory
517-
will be consumed.
598+
readers (list( ``generator`` ) | tuple( ``generator`` )): a python ``generator`` list
599+
used to read input data
600+
use_pipe (bool, optional): control the inner API used to implement the multi-processing,
601+
default True - use ``multiprocess.Pipe`` which is recommended
602+
queue_size (int, optional): only useful when ``use_pipe`` is False - ``multiprocess.Queue``
603+
is used, default 1000. Increase this value can speed up the data reading, and more memory
604+
will be consumed.
518605
519606
Returns:
520607
``generator``: a new reader which can be run parallelly

0 commit comments

Comments
 (0)