Skip to content

fix(type): make some decorator utility functions type-safe and add some type annotations #1785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 44 additions & 28 deletions bentoml/_internal/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@
import socket
import tarfile
from io import StringIO
from typing import Callable, Optional
from typing import (
Optional, TypeVar, Type, Union, overload, Dict, Iterator, Any, Tuple,
TYPE_CHECKING, Generic, Callable)
from urllib.parse import urlparse, uses_netloc, uses_params, uses_relative

from google.protobuf.message import Message
from mypy.typeshed.stdlib.contextlib import _GeneratorContextManager
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means adding mypy into our library dependency. I would rather to reserve all typeshed import under TYPE_CHECKING

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I forgot to add the relevant dependencies.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no worries since its only a type call i just added to the files. I think we shouldn't add mypy to library dependency anw

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because mypy uses typeshed types to annotate the standard libraries types during static analysis, so that mypy can only recognize the types in typeshed.

image

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I don't know if its a good practice to include mypy into library dependency. I don't see a lot of other library doing so.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how to do it, I'll research how other libraries are doing it.


if TYPE_CHECKING:
from bentoml._internal.yatai_client import YataiClient

from ..utils.gcs import is_gcs_url
from ..utils.lazy_loader import LazyLoader
Expand Down Expand Up @@ -41,17 +47,21 @@


class _Missing(object):
def __repr__(self):
def __repr__(self) -> str:
return "no value"

def __reduce__(self):
def __reduce__(self) -> str:
return "_missing"


_missing = _Missing()


class cached_property(property):
T = TypeVar("T")
V = TypeVar("V")


class cached_property(property, Generic[T, V]):
"""A decorator that converts a function into a lazy property. The
function wrapped is called the first time to retrieve the result
and then that calculated result is used the next time you access
Expand All @@ -76,28 +86,32 @@ def foo(self):
manual invocation.
"""

def __init__(
self, func: Callable, name: str = None, doc: str = None
): # pylint:disable=super-init-not-called
def __init__(self, func: Callable[[T], V], name: Optional[str] = None, doc: Optional[str] = None): # pylint:disable=super-init-not-called
self.__name__ = name or func.__name__
self.__module__ = func.__module__
self.__doc__ = doc or func.__doc__
self.func = func

def __set__(self, obj, value):
def __set__(self, obj: T, value: V) -> None:
obj.__dict__[self.__name__] = value

def __get__(self, obj, type=None): # pylint:disable=redefined-builtin
@overload
def __get__(self, obj: None, type: Optional[Type[T]] = None) -> "cached_property": ...

@overload
def __get__(self, obj: T, type: Optional[Type[T]] = None) -> V: ...

def __get__(self, obj: Optional[T], type: Optional[Type[T]] = None) -> Union["cached_property", V]: # pylint:disable=redefined-builtin
if obj is None:
return self
value = obj.__dict__.get(self.__name__, _missing)
value: V = obj.__dict__.get(self.__name__, _missing)
if value is _missing:
value = self.func(obj)
obj.__dict__[self.__name__] = value
return value


class cached_contextmanager:
class cached_contextmanager(Generic[T]):
"""
Just like contextlib.contextmanager, but will cache the yield value for the same
arguments. When one instance of the contextmanager exits, the cache value will
Expand All @@ -113,20 +127,21 @@ def start_docker_container_from_image(docker_image, timeout=60):
container.stop()
"""

def __init__(self, cache_key_template=None):
def __init__(self, cache_key_template: Optional[str] = None) -> None:
self._cache_key_template = cache_key_template
self._cache = {}
self._cache: Dict[Union[str, Tuple], T] = {}

def __call__(self, func):
# TODO: use ParamSpec: https://github.com/python/mypy/issues/8645
def __call__(self, func: Callable[..., Iterator[T]]) -> Callable[..., _GeneratorContextManager[T]]:
func_m = contextlib.contextmanager(func)

@contextlib.contextmanager
@functools.wraps(func)
def _func(*args, **kwargs):
def _func(*args: Any, **kwargs: Any) -> Iterator[T]:
bound_args = inspect.signature(func).bind(*args, **kwargs)
bound_args.apply_defaults()
if self._cache_key_template:
cache_key = self._cache_key_template.format(**bound_args.arguments)
cache_key: Union[str, Tuple] = self._cache_key_template.format(**bound_args.arguments)
else:
cache_key = tuple(bound_args.arguments.values())
if cache_key in self._cache:
Expand All @@ -141,7 +156,7 @@ def _func(*args, **kwargs):


@contextlib.contextmanager
def reserve_free_port(host="localhost"):
def reserve_free_port(host: str = "localhost") -> Iterator[int]:
"""
detect free port and reserve until exit the context
"""
Expand All @@ -152,13 +167,13 @@ def reserve_free_port(host="localhost"):
sock.close()


def get_free_port(host="localhost"):
def get_free_port(host: str = "localhost") -> int:
"""
detect free port and reserve until exit the context
"""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind((host, 0))
port = sock.getsockname()[1]
port: int = sock.getsockname()[1]
sock.close()
return port

Expand All @@ -170,7 +185,7 @@ def is_url(url: str) -> bool:
return False


def dump_to_yaml_str(yaml_dict):
def dump_to_yaml_str(yaml_dict: Dict) -> str:
from ..utils.ruamel_yaml import YAML

yaml = YAML()
Expand All @@ -186,7 +201,7 @@ def pb_to_yaml(message: Message) -> str:
return dump_to_yaml_str(message_dict)


def ProtoMessageToDict(protobuf_msg: Message, **kwargs) -> object:
def ProtoMessageToDict(protobuf_msg: Message, **kwargs: Any) -> object:
from google.protobuf.json_format import MessageToDict

if "preserving_proto_field_name" not in kwargs:
Expand All @@ -196,7 +211,7 @@ def ProtoMessageToDict(protobuf_msg: Message, **kwargs) -> object:


# This function assume the status is not status.OK
def status_pb_to_error_code_and_message(pb_status) -> (int, str):
def status_pb_to_error_code_and_message(pb_status) -> Tuple[int, str]:
from ..yatai_client.proto import status_pb2

assert pb_status.status_code != status_pb2.Status.OK
Expand All @@ -205,14 +220,15 @@ def status_pb_to_error_code_and_message(pb_status) -> (int, str):
return error_code, error_message


class catch_exceptions(object):
def __init__(self, exceptions, fallback=None):
class catch_exceptions(object, Generic[T]):
def __init__(self, exceptions: Union[Type[BaseException], Tuple[Type[BaseException], ...]], fallback: Optional[T] = None) -> None:
self.exceptions = exceptions
self.fallback = fallback

def __call__(self, func):
# TODO: use ParamSpec: https://github.com/python/mypy/issues/8645
def __call__(self, func: Callable[..., T]) -> Callable[..., Optional[T]]:
@functools.wraps(func)
def _(*args, **kwargs):
def _(*args: Any, **kwargs: Any) -> Optional[T]:
try:
return func(*args, **kwargs)
except self.exceptions:
Expand Down Expand Up @@ -253,7 +269,7 @@ def resolve_bundle_path(
)


def get_default_yatai_client():
def get_default_yatai_client() -> YataiClient:
from bentoml._internal.yatai_client import YataiClient

return YataiClient()
Expand All @@ -271,7 +287,7 @@ def resolve_bento_bundle_uri(bento_pb):

def archive_directory_to_tar(
source_dir: str, tarfile_dir: str, tarfile_name: str
) -> (str, str):
) -> Tuple[str, str]:
file_name = f"{tarfile_name}.tar"
tarfile_path = os.path.join(tarfile_dir, file_name)
with tarfile.open(tarfile_path, mode="w:gz") as tar:
Expand Down
16 changes: 8 additions & 8 deletions bentoml/_internal/utils/csv.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# CSV utils following https://tools.ietf.org/html/rfc4180
from typing import Iterable, Iterator
from typing import Iterable, Iterator, Union


def csv_splitlines(string) -> Iterator[str]:
def csv_splitlines(string: str) -> Iterator[str]:
if '"' in string:

def _iter_line(line):
def _iter_line(line: str) -> Iterator[str]:
quoted = False
last_cur = 0
for i, c in enumerate(line):
Expand All @@ -25,11 +25,11 @@ def _iter_line(line):
return iter(string.splitlines())


def csv_split(string, delimiter) -> Iterator[str]:
def csv_split(string: str, delimiter: str) -> Iterator[str]:
if '"' in string:
dlen = len(delimiter)

def _iter_line(line):
def _iter_line(line: str) -> Iterator[str]:
quoted = False
last_cur = 0
for i, c in enumerate(line):
Expand All @@ -45,19 +45,19 @@ def _iter_line(line):
return iter(string.split(delimiter))


def csv_row(tds: Iterable):
def csv_row(tds: Iterable) -> str:
return ",".join(csv_quote(td) for td in tds)


def csv_unquote(string):
def csv_unquote(string: str) -> str:
if '"' in string:
string = string.strip()
assert string[0] == '"' and string[-1] == '"'
return string[1:-1].replace('""', '"')
return string


def csv_quote(td):
def csv_quote(td: Union[int, str]) -> str:
"""
>>> csv_quote(1)
'1'
Expand Down
4 changes: 2 additions & 2 deletions bentoml/_internal/utils/dataframe_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import io
import itertools
import json
from typing import Iterable, Iterator, Mapping
from typing import Iterable, Iterator, Mapping, Any, Union, Set

from bentoml.exceptions import BadInput

Expand All @@ -24,7 +24,7 @@ def check_dataframe_column_contains(required_column_names, df):


@catch_exceptions(Exception, fallback=None)
def guess_orient(table, strict=False):
def guess_orient(table: Any, strict: bool = False) -> Union[None, str, Set[str]]:
if isinstance(table, list):
if not table:
if strict:
Expand Down
12 changes: 8 additions & 4 deletions bentoml/_internal/yatai_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
#
# logger = logging.getLogger(__name__)

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from bentoml._internal.yatai_client.proto.yatai_service_pb2_grpc import YataiStub

import logging

Expand All @@ -33,11 +37,11 @@ def __init__(self, yatai_server_name: str):
self.deploy_api_client = None

@cached_property
def bundles(self):
def bundles(self) -> BentoRepositoryAPIClient:
return BentoRepositoryAPIClient(self._yatai_service)

@cached_property
def deployment(self):
def deployment(self) -> DeploymentAPIClient:
return DeploymentAPIClient(self._yatai_service)

# def __init__(self, yatai_service: Optional["YataiStub"] = None):
Expand All @@ -50,7 +54,7 @@ def deployment(self):
# return BentoRepositoryAPIClient(self.yatai_service)


def get_yatai_client(yatai_url: str = None) -> "YataiClient":
def get_yatai_client(yatai_url: Optional[str] = None) -> YataiClient:
"""
Args:
yatai_url (`str`):
Expand Down Expand Up @@ -80,7 +84,7 @@ def get_yatai_service(
tls_root_ca_cert: str,
tls_client_key: str,
tls_client_cert: str,
):
) -> YataiStub:
import certifi
import grpc

Expand Down
5 changes: 4 additions & 1 deletion bentoml/_internal/yatai_client/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
def parse_grpc_url(url):
from typing import Tuple, Optional


def parse_grpc_url(url: str) -> Tuple[Optional[str], str]:
"""
>>> parse_grpc_url("grpcs://yatai.com:43/query")
('grpcs', 'yatai.com:43/query')
Expand Down
6 changes: 3 additions & 3 deletions bentoml/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@

import cloudpickle

from ._internal.models.base import (
from bentoml._internal.models.base import (
H5_EXTENSION,
HDF5_EXTENSION,
JSON_EXTENSION,
MODEL_NAMESPACE,
PICKLE_EXTENSION,
Model,
)
from ._internal.types import MetadataType, PathType
from .exceptions import MissingDependencyException
from bentoml._internal.types import MetadataType, PathType
from bentoml.exceptions import MissingDependencyException

# fmt: off
try:
Expand Down
4 changes: 2 additions & 2 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[mypy]
show_error_codes = True
disable_error_code = attr-defined
exclude = "|(bentoml/_internal/yatai_client/proto)|(yatai/yatai/proto)|(yatai/versioneer.py)|"
exclude = "|venv|(bentoml/_internal/yatai_client/proto)|(yatai/yatai/proto)|(yatai/versioneer.py)|"
ignore_missing_imports = True

# mypy --strict --allow-any-generics --allow-subclassing-any --no-check-untyped-defs --allow-untyped-call
Expand All @@ -22,4 +22,4 @@ ignore_errors = True
ignore_errors = True

[mypy-*.exceptions.*]
ignore_errors = True
ignore_errors = True