Skip to content

Commit 4fbf65b

Browse files
authored
fix(type): make some decorator utility functions type-safe and add some type annotations (#1785)
1 parent 4c95ad0 commit 4fbf65b

File tree

7 files changed

+71
-48
lines changed

7 files changed

+71
-48
lines changed

bentoml/_internal/utils/__init__.py

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,16 @@
55
import socket
66
import tarfile
77
from io import StringIO
8-
from typing import Callable, Optional
8+
from typing import (
9+
Optional, TypeVar, Type, Union, overload, Dict, Iterator, Any, Tuple,
10+
TYPE_CHECKING, Generic, Callable)
911
from urllib.parse import urlparse, uses_netloc, uses_params, uses_relative
1012

1113
from google.protobuf.message import Message
14+
from mypy.typeshed.stdlib.contextlib import _GeneratorContextManager
15+
16+
if TYPE_CHECKING:
17+
from bentoml._internal.yatai_client import YataiClient
1218

1319
from ..utils.gcs import is_gcs_url
1420
from ..utils.lazy_loader import LazyLoader
@@ -41,17 +47,21 @@
4147

4248

4349
class _Missing(object):
44-
def __repr__(self):
50+
def __repr__(self) -> str:
4551
return "no value"
4652

47-
def __reduce__(self):
53+
def __reduce__(self) -> str:
4854
return "_missing"
4955

5056

5157
_missing = _Missing()
5258

5359

54-
class cached_property(property):
60+
T = TypeVar("T")
61+
V = TypeVar("V")
62+
63+
64+
class cached_property(property, Generic[T, V]):
5565
"""A decorator that converts a function into a lazy property. The
5666
function wrapped is called the first time to retrieve the result
5767
and then that calculated result is used the next time you access
@@ -76,28 +86,32 @@ def foo(self):
7686
manual invocation.
7787
"""
7888

79-
def __init__(
80-
self, func: Callable, name: str = None, doc: str = None
81-
): # pylint:disable=super-init-not-called
89+
def __init__(self, func: Callable[[T], V], name: Optional[str] = None, doc: Optional[str] = None): # pylint:disable=super-init-not-called
8290
self.__name__ = name or func.__name__
8391
self.__module__ = func.__module__
8492
self.__doc__ = doc or func.__doc__
8593
self.func = func
8694

87-
def __set__(self, obj, value):
95+
def __set__(self, obj: T, value: V) -> None:
8896
obj.__dict__[self.__name__] = value
8997

90-
def __get__(self, obj, type=None): # pylint:disable=redefined-builtin
98+
@overload
99+
def __get__(self, obj: None, type: Optional[Type[T]] = None) -> "cached_property": ...
100+
101+
@overload
102+
def __get__(self, obj: T, type: Optional[Type[T]] = None) -> V: ...
103+
104+
def __get__(self, obj: Optional[T], type: Optional[Type[T]] = None) -> Union["cached_property", V]: # pylint:disable=redefined-builtin
91105
if obj is None:
92106
return self
93-
value = obj.__dict__.get(self.__name__, _missing)
107+
value: V = obj.__dict__.get(self.__name__, _missing)
94108
if value is _missing:
95109
value = self.func(obj)
96110
obj.__dict__[self.__name__] = value
97111
return value
98112

99113

100-
class cached_contextmanager:
114+
class cached_contextmanager(Generic[T]):
101115
"""
102116
Just like contextlib.contextmanager, but will cache the yield value for the same
103117
arguments. When one instance of the contextmanager exits, the cache value will
@@ -113,20 +127,21 @@ def start_docker_container_from_image(docker_image, timeout=60):
113127
container.stop()
114128
"""
115129

116-
def __init__(self, cache_key_template=None):
130+
def __init__(self, cache_key_template: Optional[str] = None) -> None:
117131
self._cache_key_template = cache_key_template
118-
self._cache = {}
132+
self._cache: Dict[Union[str, Tuple], T] = {}
119133

120-
def __call__(self, func):
134+
# TODO: use ParamSpec: https://github.com/python/mypy/issues/8645
135+
def __call__(self, func: Callable[..., Iterator[T]]) -> Callable[..., _GeneratorContextManager[T]]:
121136
func_m = contextlib.contextmanager(func)
122137

123138
@contextlib.contextmanager
124139
@functools.wraps(func)
125-
def _func(*args, **kwargs):
140+
def _func(*args: Any, **kwargs: Any) -> Iterator[T]:
126141
bound_args = inspect.signature(func).bind(*args, **kwargs)
127142
bound_args.apply_defaults()
128143
if self._cache_key_template:
129-
cache_key = self._cache_key_template.format(**bound_args.arguments)
144+
cache_key: Union[str, Tuple] = self._cache_key_template.format(**bound_args.arguments)
130145
else:
131146
cache_key = tuple(bound_args.arguments.values())
132147
if cache_key in self._cache:
@@ -141,7 +156,7 @@ def _func(*args, **kwargs):
141156

142157

143158
@contextlib.contextmanager
144-
def reserve_free_port(host="localhost"):
159+
def reserve_free_port(host: str = "localhost") -> Iterator[int]:
145160
"""
146161
detect free port and reserve until exit the context
147162
"""
@@ -152,13 +167,13 @@ def reserve_free_port(host="localhost"):
152167
sock.close()
153168

154169

155-
def get_free_port(host="localhost"):
170+
def get_free_port(host: str = "localhost") -> int:
156171
"""
157172
detect free port and reserve until exit the context
158173
"""
159174
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
160175
sock.bind((host, 0))
161-
port = sock.getsockname()[1]
176+
port: int = sock.getsockname()[1]
162177
sock.close()
163178
return port
164179

@@ -170,7 +185,7 @@ def is_url(url: str) -> bool:
170185
return False
171186

172187

173-
def dump_to_yaml_str(yaml_dict):
188+
def dump_to_yaml_str(yaml_dict: Dict) -> str:
174189
from ..utils.ruamel_yaml import YAML
175190

176191
yaml = YAML()
@@ -186,7 +201,7 @@ def pb_to_yaml(message: Message) -> str:
186201
return dump_to_yaml_str(message_dict)
187202

188203

189-
def ProtoMessageToDict(protobuf_msg: Message, **kwargs) -> object:
204+
def ProtoMessageToDict(protobuf_msg: Message, **kwargs: Any) -> object:
190205
from google.protobuf.json_format import MessageToDict
191206

192207
if "preserving_proto_field_name" not in kwargs:
@@ -196,7 +211,7 @@ def ProtoMessageToDict(protobuf_msg: Message, **kwargs) -> object:
196211

197212

198213
# This function assume the status is not status.OK
199-
def status_pb_to_error_code_and_message(pb_status) -> (int, str):
214+
def status_pb_to_error_code_and_message(pb_status) -> Tuple[int, str]:
200215
from ..yatai_client.proto import status_pb2
201216

202217
assert pb_status.status_code != status_pb2.Status.OK
@@ -205,14 +220,15 @@ def status_pb_to_error_code_and_message(pb_status) -> (int, str):
205220
return error_code, error_message
206221

207222

208-
class catch_exceptions(object):
209-
def __init__(self, exceptions, fallback=None):
223+
class catch_exceptions(object, Generic[T]):
224+
def __init__(self, exceptions: Union[Type[BaseException], Tuple[Type[BaseException], ...]], fallback: Optional[T] = None) -> None:
210225
self.exceptions = exceptions
211226
self.fallback = fallback
212227

213-
def __call__(self, func):
228+
# TODO: use ParamSpec: https://github.com/python/mypy/issues/8645
229+
def __call__(self, func: Callable[..., T]) -> Callable[..., Optional[T]]:
214230
@functools.wraps(func)
215-
def _(*args, **kwargs):
231+
def _(*args: Any, **kwargs: Any) -> Optional[T]:
216232
try:
217233
return func(*args, **kwargs)
218234
except self.exceptions:
@@ -253,7 +269,7 @@ def resolve_bundle_path(
253269
)
254270

255271

256-
def get_default_yatai_client():
272+
def get_default_yatai_client() -> YataiClient:
257273
from bentoml._internal.yatai_client import YataiClient
258274

259275
return YataiClient()
@@ -271,7 +287,7 @@ def resolve_bento_bundle_uri(bento_pb):
271287

272288
def archive_directory_to_tar(
273289
source_dir: str, tarfile_dir: str, tarfile_name: str
274-
) -> (str, str):
290+
) -> Tuple[str, str]:
275291
file_name = f"{tarfile_name}.tar"
276292
tarfile_path = os.path.join(tarfile_dir, file_name)
277293
with tarfile.open(tarfile_path, mode="w:gz") as tar:

bentoml/_internal/utils/csv.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# CSV utils following https://tools.ietf.org/html/rfc4180
2-
from typing import Iterable, Iterator
2+
from typing import Iterable, Iterator, Union
33

44

5-
def csv_splitlines(string) -> Iterator[str]:
5+
def csv_splitlines(string: str) -> Iterator[str]:
66
if '"' in string:
77

8-
def _iter_line(line):
8+
def _iter_line(line: str) -> Iterator[str]:
99
quoted = False
1010
last_cur = 0
1111
for i, c in enumerate(line):
@@ -25,11 +25,11 @@ def _iter_line(line):
2525
return iter(string.splitlines())
2626

2727

28-
def csv_split(string, delimiter) -> Iterator[str]:
28+
def csv_split(string: str, delimiter: str) -> Iterator[str]:
2929
if '"' in string:
3030
dlen = len(delimiter)
3131

32-
def _iter_line(line):
32+
def _iter_line(line: str) -> Iterator[str]:
3333
quoted = False
3434
last_cur = 0
3535
for i, c in enumerate(line):
@@ -45,19 +45,19 @@ def _iter_line(line):
4545
return iter(string.split(delimiter))
4646

4747

48-
def csv_row(tds: Iterable):
48+
def csv_row(tds: Iterable) -> str:
4949
return ",".join(csv_quote(td) for td in tds)
5050

5151

52-
def csv_unquote(string):
52+
def csv_unquote(string: str) -> str:
5353
if '"' in string:
5454
string = string.strip()
5555
assert string[0] == '"' and string[-1] == '"'
5656
return string[1:-1].replace('""', '"')
5757
return string
5858

5959

60-
def csv_quote(td):
60+
def csv_quote(td: Union[int, str]) -> str:
6161
"""
6262
>>> csv_quote(1)
6363
'1'

bentoml/_internal/utils/dataframe_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import io
22
import itertools
33
import json
4-
from typing import Iterable, Iterator, Mapping
4+
from typing import Iterable, Iterator, Mapping, Any, Union, Set
55

66
from bentoml.exceptions import BadInput
77

@@ -24,7 +24,7 @@ def check_dataframe_column_contains(required_column_names, df):
2424

2525

2626
@catch_exceptions(Exception, fallback=None)
27-
def guess_orient(table, strict=False):
27+
def guess_orient(table: Any, strict: bool = False) -> Union[None, str, Set[str]]:
2828
if isinstance(table, list):
2929
if not table:
3030
if strict:

bentoml/_internal/yatai_client/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
#
1010
# logger = logging.getLogger(__name__)
1111

12+
from typing import TYPE_CHECKING, Optional
13+
14+
if TYPE_CHECKING:
15+
from bentoml._internal.yatai_client.proto.yatai_service_pb2_grpc import YataiStub
1216

1317
import logging
1418

@@ -33,11 +37,11 @@ def __init__(self, yatai_server_name: str):
3337
self.deploy_api_client = None
3438

3539
@cached_property
36-
def bundles(self):
40+
def bundles(self) -> BentoRepositoryAPIClient:
3741
return BentoRepositoryAPIClient(self._yatai_service)
3842

3943
@cached_property
40-
def deployment(self):
44+
def deployment(self) -> DeploymentAPIClient:
4145
return DeploymentAPIClient(self._yatai_service)
4246

4347
# def __init__(self, yatai_service: Optional["YataiStub"] = None):
@@ -50,7 +54,7 @@ def deployment(self):
5054
# return BentoRepositoryAPIClient(self.yatai_service)
5155

5256

53-
def get_yatai_client(yatai_url: str = None) -> "YataiClient":
57+
def get_yatai_client(yatai_url: Optional[str] = None) -> YataiClient:
5458
"""
5559
Args:
5660
yatai_url (`str`):
@@ -80,7 +84,7 @@ def get_yatai_service(
8084
tls_root_ca_cert: str,
8185
tls_client_key: str,
8286
tls_client_cert: str,
83-
):
87+
) -> YataiStub:
8488
import certifi
8589
import grpc
8690

bentoml/_internal/yatai_client/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
def parse_grpc_url(url):
1+
from typing import Tuple, Optional
2+
3+
4+
def parse_grpc_url(url: str) -> Tuple[Optional[str], str]:
25
"""
36
>>> parse_grpc_url("grpcs://yatai.com:43/query")
47
('grpcs', 'yatai.com:43/query')

bentoml/keras.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33

44
import cloudpickle
55

6-
from ._internal.models.base import (
6+
from bentoml._internal.models.base import (
77
H5_EXTENSION,
88
HDF5_EXTENSION,
99
JSON_EXTENSION,
1010
MODEL_NAMESPACE,
1111
PICKLE_EXTENSION,
1212
Model,
1313
)
14-
from ._internal.types import MetadataType, PathType
15-
from .exceptions import MissingDependencyException
14+
from bentoml._internal.types import MetadataType, PathType
15+
from bentoml.exceptions import MissingDependencyException
1616

1717
# fmt: off
1818
try:

mypy.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[mypy]
22
show_error_codes = True
33
disable_error_code = attr-defined
4-
exclude = "|(bentoml/_internal/yatai_client/proto)|(yatai/yatai/proto)|(yatai/versioneer.py)|"
4+
exclude = "|venv|(bentoml/_internal/yatai_client/proto)|(yatai/yatai/proto)|(yatai/versioneer.py)|"
55
ignore_missing_imports = True
66

77
# mypy --strict --allow-any-generics --allow-subclassing-any --no-check-untyped-defs --allow-untyped-call
@@ -22,4 +22,4 @@ ignore_errors = True
2222
ignore_errors = True
2323

2424
[mypy-*.exceptions.*]
25-
ignore_errors = True
25+
ignore_errors = True

0 commit comments

Comments
 (0)