5
5
import socket
6
6
import tarfile
7
7
from io import StringIO
8
- from typing import Callable , Optional
8
+ from typing import (
9
+ Optional , TypeVar , Type , Union , overload , Dict , Iterator , Any , Tuple ,
10
+ TYPE_CHECKING , Generic , Callable )
9
11
from urllib .parse import urlparse , uses_netloc , uses_params , uses_relative
10
12
11
13
from google .protobuf .message import Message
14
+ from mypy .typeshed .stdlib .contextlib import _GeneratorContextManager
15
+
16
+ if TYPE_CHECKING :
17
+ from bentoml ._internal .yatai_client import YataiClient
12
18
13
19
from ..utils .gcs import is_gcs_url
14
20
from ..utils .lazy_loader import LazyLoader
41
47
42
48
43
49
class _Missing (object ):
44
- def __repr__ (self ):
50
+ def __repr__ (self ) -> str :
45
51
return "no value"
46
52
47
- def __reduce__ (self ):
53
+ def __reduce__ (self ) -> str :
48
54
return "_missing"
49
55
50
56
51
57
_missing = _Missing ()
52
58
53
59
54
- class cached_property (property ):
60
+ T = TypeVar ("T" )
61
+ V = TypeVar ("V" )
62
+
63
+
64
+ class cached_property (property , Generic [T , V ]):
55
65
"""A decorator that converts a function into a lazy property. The
56
66
function wrapped is called the first time to retrieve the result
57
67
and then that calculated result is used the next time you access
@@ -76,28 +86,32 @@ def foo(self):
76
86
manual invocation.
77
87
"""
78
88
79
- def __init__ (
80
- self , func : Callable , name : str = None , doc : str = None
81
- ): # pylint:disable=super-init-not-called
89
+ def __init__ (self , func : Callable [[T ], V ], name : Optional [str ] = None , doc : Optional [str ] = None ): # pylint:disable=super-init-not-called
82
90
self .__name__ = name or func .__name__
83
91
self .__module__ = func .__module__
84
92
self .__doc__ = doc or func .__doc__
85
93
self .func = func
86
94
87
- def __set__ (self , obj , value ) :
95
+ def __set__ (self , obj : T , value : V ) -> None :
88
96
obj .__dict__ [self .__name__ ] = value
89
97
90
- def __get__ (self , obj , type = None ): # pylint:disable=redefined-builtin
98
+ @overload
99
+ def __get__ (self , obj : None , type : Optional [Type [T ]] = None ) -> "cached_property" : ...
100
+
101
+ @overload
102
+ def __get__ (self , obj : T , type : Optional [Type [T ]] = None ) -> V : ...
103
+
104
+ def __get__ (self , obj : Optional [T ], type : Optional [Type [T ]] = None ) -> Union ["cached_property" , V ]: # pylint:disable=redefined-builtin
91
105
if obj is None :
92
106
return self
93
- value = obj .__dict__ .get (self .__name__ , _missing )
107
+ value : V = obj .__dict__ .get (self .__name__ , _missing )
94
108
if value is _missing :
95
109
value = self .func (obj )
96
110
obj .__dict__ [self .__name__ ] = value
97
111
return value
98
112
99
113
100
- class cached_contextmanager :
114
+ class cached_contextmanager ( Generic [ T ]) :
101
115
"""
102
116
Just like contextlib.contextmanager, but will cache the yield value for the same
103
117
arguments. When one instance of the contextmanager exits, the cache value will
@@ -113,20 +127,21 @@ def start_docker_container_from_image(docker_image, timeout=60):
113
127
container.stop()
114
128
"""
115
129
116
- def __init__ (self , cache_key_template = None ):
130
+ def __init__ (self , cache_key_template : Optional [ str ] = None ) -> None :
117
131
self ._cache_key_template = cache_key_template
118
- self ._cache = {}
132
+ self ._cache : Dict [ Union [ str , Tuple ], T ] = {}
119
133
120
- def __call__ (self , func ):
134
+ # TODO: use ParamSpec: https://github.com/python/mypy/issues/8645
135
+ def __call__ (self , func : Callable [..., Iterator [T ]]) -> Callable [..., _GeneratorContextManager [T ]]:
121
136
func_m = contextlib .contextmanager (func )
122
137
123
138
@contextlib .contextmanager
124
139
@functools .wraps (func )
125
- def _func (* args , ** kwargs ) :
140
+ def _func (* args : Any , ** kwargs : Any ) -> Iterator [ T ] :
126
141
bound_args = inspect .signature (func ).bind (* args , ** kwargs )
127
142
bound_args .apply_defaults ()
128
143
if self ._cache_key_template :
129
- cache_key = self ._cache_key_template .format (** bound_args .arguments )
144
+ cache_key : Union [ str , Tuple ] = self ._cache_key_template .format (** bound_args .arguments )
130
145
else :
131
146
cache_key = tuple (bound_args .arguments .values ())
132
147
if cache_key in self ._cache :
@@ -141,7 +156,7 @@ def _func(*args, **kwargs):
141
156
142
157
143
158
@contextlib .contextmanager
144
- def reserve_free_port (host = "localhost" ):
159
+ def reserve_free_port (host : str = "localhost" ) -> Iterator [ int ] :
145
160
"""
146
161
detect free port and reserve until exit the context
147
162
"""
@@ -152,13 +167,13 @@ def reserve_free_port(host="localhost"):
152
167
sock .close ()
153
168
154
169
155
- def get_free_port (host = "localhost" ):
170
+ def get_free_port (host : str = "localhost" ) -> int :
156
171
"""
157
172
detect free port and reserve until exit the context
158
173
"""
159
174
sock = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
160
175
sock .bind ((host , 0 ))
161
- port = sock .getsockname ()[1 ]
176
+ port : int = sock .getsockname ()[1 ]
162
177
sock .close ()
163
178
return port
164
179
@@ -170,7 +185,7 @@ def is_url(url: str) -> bool:
170
185
return False
171
186
172
187
173
- def dump_to_yaml_str (yaml_dict ) :
188
+ def dump_to_yaml_str (yaml_dict : Dict ) -> str :
174
189
from ..utils .ruamel_yaml import YAML
175
190
176
191
yaml = YAML ()
@@ -186,7 +201,7 @@ def pb_to_yaml(message: Message) -> str:
186
201
return dump_to_yaml_str (message_dict )
187
202
188
203
189
- def ProtoMessageToDict (protobuf_msg : Message , ** kwargs ) -> object :
204
+ def ProtoMessageToDict (protobuf_msg : Message , ** kwargs : Any ) -> object :
190
205
from google .protobuf .json_format import MessageToDict
191
206
192
207
if "preserving_proto_field_name" not in kwargs :
@@ -196,7 +211,7 @@ def ProtoMessageToDict(protobuf_msg: Message, **kwargs) -> object:
196
211
197
212
198
213
# This function assume the status is not status.OK
199
- def status_pb_to_error_code_and_message (pb_status ) -> ( int , str ) :
214
+ def status_pb_to_error_code_and_message (pb_status ) -> Tuple [ int , str ] :
200
215
from ..yatai_client .proto import status_pb2
201
216
202
217
assert pb_status .status_code != status_pb2 .Status .OK
@@ -205,14 +220,15 @@ def status_pb_to_error_code_and_message(pb_status) -> (int, str):
205
220
return error_code , error_message
206
221
207
222
208
- class catch_exceptions (object ):
209
- def __init__ (self , exceptions , fallback = None ):
223
+ class catch_exceptions (object , Generic [ T ] ):
224
+ def __init__ (self , exceptions : Union [ Type [ BaseException ], Tuple [ Type [ BaseException ], ...]], fallback : Optional [ T ] = None ) -> None :
210
225
self .exceptions = exceptions
211
226
self .fallback = fallback
212
227
213
- def __call__ (self , func ):
228
+ # TODO: use ParamSpec: https://github.com/python/mypy/issues/8645
229
+ def __call__ (self , func : Callable [..., T ]) -> Callable [..., Optional [T ]]:
214
230
@functools .wraps (func )
215
- def _ (* args , ** kwargs ) :
231
+ def _ (* args : Any , ** kwargs : Any ) -> Optional [ T ] :
216
232
try :
217
233
return func (* args , ** kwargs )
218
234
except self .exceptions :
@@ -253,7 +269,7 @@ def resolve_bundle_path(
253
269
)
254
270
255
271
256
- def get_default_yatai_client ():
272
+ def get_default_yatai_client () -> YataiClient :
257
273
from bentoml ._internal .yatai_client import YataiClient
258
274
259
275
return YataiClient ()
@@ -271,7 +287,7 @@ def resolve_bento_bundle_uri(bento_pb):
271
287
272
288
def archive_directory_to_tar (
273
289
source_dir : str , tarfile_dir : str , tarfile_name : str
274
- ) -> ( str , str ) :
290
+ ) -> Tuple [ str , str ] :
275
291
file_name = f"{ tarfile_name } .tar"
276
292
tarfile_path = os .path .join (tarfile_dir , file_name )
277
293
with tarfile .open (tarfile_path , mode = "w:gz" ) as tar :
0 commit comments