Skip to content

Commit

Permalink
refactor: fix ruff errors
Browse files Browse the repository at this point in the history
BREAKING CHANGE:
  • Loading branch information
Rizhiy committed Mar 3, 2024
1 parent e4bd7f3 commit 8d1a119
Show file tree
Hide file tree
Showing 18 changed files with 187 additions and 161 deletions.
26 changes: 13 additions & 13 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies = ["coloredlogs", "docstring-parser", "python-dateutil", "xxhash"]

[project.optional-dependencies]
test = ["flaky", "pytest", "pytest-asyncio", "pytest-coverage", "types-python-dateutil", "types-xxhash"]
dev = ["black", "replete[test]", "ruff"]
dev = ["replete[test]", "ruff"]

[tool.flit.sdist]
include = ["README.md"]
Expand All @@ -27,10 +27,6 @@ minversion = "6.0"
addopts = "--doctest-modules --no-success-flaky-report"
asyncio_mode = "auto"

[tool.black]
line-length = 120
skip-magic-trailing-comma = true

[tool.yamlfix]
line_length = 120
section_whitelines = 1
Expand All @@ -47,15 +43,19 @@ preview = true
select = ["A", "ARG", "B", "BLE", "C4", "COM", "E", "ERA", "F", "FBT", "FIX", "FLY", "FURB", "I", "IC", "INP", "ISC", "LOG", "N", "NPY", "PERF", "PIE", "PT", "PTH", "Q", "R", "RET", "RSE", "S", "SIM", "SLF", "T20", "TCH", "TD", "TID", "TRY", "UP", "W"]
fixable = ["ALL"]
ignore = ["A003", "E203", "FIX002", "FURB113", "N817", "PTH123", "RET503", "S113", "TD002", "TD003", "TRY003", "UP007", "UP035"]
[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"**/__init__.py" = [
"F401", # Allow unused imports in module files
"F401", # Allow unused imports in module files
]
"tests/test_cli.py" = [
"W291", # Need to ignore it for correct formatting
]
"tests/**/*.py" = [
"E501", # Test strings can be long
"S101", # Asserts in tests are fine
"T201", # Prints are useful for debugging
"TCH001",
"TCH002",
"TCH003", # Tests don't need to be super performant, prefer simpler code
"E501", # Test strings can be long
"FBT001", # We don't expect to call the tests
"S101", # Asserts in tests are fine
"T201", # Prints are useful for debugging
"TCH001",
"TCH002",
"TCH003", # Tests don't need to be super performant, prefer simpler code
]
7 changes: 4 additions & 3 deletions replete/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
""" Assorted utilities with minimal dependencies """
"""Assorted utilities with minimal dependencies"""

from __future__ import annotations

from .aio import LazyWrapAsync, achunked, alist
from .aio import LazyWrapAsync, achunked
from .cli import autocli
from .consistent_hash import consistent_hash, picklehash
from .datetime import date_range, datetime_range, round_dt
from .enum import ComparableEnum
from .logging import assert_with_logging, setup_logging, warn_with_traceback
from .logging import WarnWithTraceback, assert_with_logging, setup_logging
from .register import Register
from .timing import RateLimiter, Timer
from .utils import chunks, deep_update, ensure_unique_keys, grouped, split_list
Expand Down
8 changes: 0 additions & 8 deletions replete/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,6 @@
TLazyWrapValue = TypeVar("TLazyWrapValue")


async def alist(async_iter: AsyncIterable[T]) -> list[T]:
"""Simple gatherer of an async iterable into a list"""
result = []
async for item in async_iter:
result.append(item)
return result


async def achunked(aiterable: AsyncIterable[T], size: int) -> AsyncIterable[list[T]]:
"""Async iterable chunker"""
chunk: list[T] = []
Expand Down
109 changes: 63 additions & 46 deletions replete/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@
import functools
import inspect
import typing
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, ClassVar, Generic, NamedTuple, TypeVar, Union, cast, overload # noqa: TC002
from typing import Any, ClassVar, Generic, NamedTuple, Sequence, TypeVar, Union, cast, overload

import docstring_parser

if TYPE_CHECKING:
from collections.abc import Sequence

T = TypeVar("T")
class CLIError(Exception):
"""Exception for all things CLI"""


T = TypeVar("T")
TRet = TypeVar("TRet")
TCallable = TypeVar("TCallable", bound=Union[Callable[..., TRet], Callable[..., None]]) # type: ignore
NoneType = type(None)
Expand All @@ -33,21 +35,25 @@
def pairs_to_dict(
key_parse: Callable[[Any], Any] = str,
val_parse: Callable[[Any], Any] = str,
*,
allow_empty: bool = True,
name: str = "",
) -> Callable[[Any], Any]:
"""
Convert list of keys and values into dict
"""

def _pairs_to_dict_configured(items: Sequence[T] | None) -> dict[T, T]:
if not items:
if allow_empty:
return {}
raise ValueError(f"{name!r}: Must not be empty")
raise ValueError("Got empty list of items")
if isinstance(items, dict):
return items
if len(items) % 2 != 0:
raise ValueError("Must have key-value pairs (even number)", dict(items=items))
raise ValueError("Must have key-value pairs (even number)", {"items": items})
keys = [key_parse(key) for key in items[::2]]
values = [val_parse(val) for val in items[1::2]]
return dict(zip(keys, values))
return dict(zip(keys, values, strict=True))

return _pairs_to_dict_configured

Expand All @@ -65,8 +71,8 @@ def _resolve_type_annotation(type_annotation: Any, annotations_ns: dict[str, Any
return None, None
if isinstance(type_annotation, str):
try:
type_annotation = eval(type_annotation, COMMON_ANNOTATIONS_NS | (annotations_ns or {}))
except Exception:
type_annotation = eval(type_annotation, COMMON_ANNOTATIONS_NS | (annotations_ns or {})) # noqa: S307
except Exception: # noqa: BLE001
return None, None

type_origin = getattr(type_annotation, "__origin__", None)
Expand All @@ -87,11 +93,11 @@ def _resolve_type_annotation(type_annotation: Any, annotations_ns: dict[str, Any
if isinstance(type_args, tuple) and len(type_args) == 2:
return type_origin, type_args
return type_origin, None
elif type_origin in (list, tuple):
if type_origin in (list, tuple):
if isinstance(type_args, tuple) and len(type_args) == 1:
return type_origin, type_args[0]
return type_origin, None
elif type_name in ("Sequence", "Iterable"):
if type_name in ("Sequence", "Iterable"):
if isinstance(type_args, tuple) and len(type_args) == 1:
return list, type_args[0]
return list, None
Expand Down Expand Up @@ -153,7 +159,7 @@ def replace(self, **kwargs: Any) -> ParamExtras:


@dataclass
class AutoCLIBase(Generic[TCallable]):
class AutoCLIBase(Generic[TCallable], ABC):
"""Base interface for an automatic function-to-CLI processing"""

func: TCallable
Expand All @@ -169,20 +175,21 @@ class AutoCLIBase(Generic[TCallable]):
}
TYPE_AUTO_PARAMS: ClassVar[dict[Any, Callable[[ParamInfo], ParamExtras]]] = {
# TODO: actually make the boolean arguments into `--flag/--no-flag`.
bool: lambda param_info: ParamExtras(arg_argparse_extra_kwargs=dict(type=parse_bool)),
bool: lambda _: ParamExtras(arg_argparse_extra_kwargs={"type": parse_bool}),
dict: lambda param_info: ParamExtras(
arg_argparse_extra_kwargs=dict(nargs="*", type=str),
arg_argparse_extra_kwargs={"nargs": "*", "type": str},
arg_postprocess=pairs_to_dict(
key_parse=param_info.contained_type[0] if param_info.contained_type else str,
val_parse=param_info.contained_type[1] if param_info.contained_type else str,
),
),
list: lambda param_info: ParamExtras(
arg_argparse_extra_kwargs=dict(nargs="*", type=param_info.contained_type or str)
arg_argparse_extra_kwargs={"nargs": "*", "type": param_info.contained_type or str},
),
}

def __call__(self, *args: Any, **kwargs: Any) -> TRet | None:
@abstractmethod
def __call__(self, *args: Any, **kwargs: Any) -> TRet | None: # type: ignore
raise NotImplementedError

@classmethod
Expand All @@ -199,7 +206,10 @@ def _make_param_infos(
defaults = defaults or {}
return [
ParamInfo.from_arg_param(
arg_param, doc=params_docs.get(name), annotations_ns=annotations_ns, default=defaults.get(name)
arg_param,
doc=params_docs.get(name),
annotations_ns=annotations_ns,
default=defaults.get(name),
)
for name, arg_param in signature.parameters.items()
]
Expand Down Expand Up @@ -233,7 +243,7 @@ class AutoCLI(AutoCLIBase[TCallable]):
signature_override: inspect.Signature | None = None
_description_indent: str = " "
# mock default for `dataclass`, gets filled in `__post_init__` with an actual value.
__wrapped__: Callable[..., Any] = lambda: None
__wrapped__: Callable[..., Any] = lambda: None # noqa: E731

@classmethod
def _base_param_extras(cls, param_info: ParamInfo) -> ParamExtras:
Expand Down Expand Up @@ -274,15 +284,20 @@ def _full_params_extras(self, param_infos: Sequence[ParamInfo]) -> dict[str, Par
return result

def _make_full_parser_and_params_extras(
self, defaults: dict[str, Any]
self,
defaults: dict[str, Any],
) -> tuple[argparse.ArgumentParser, dict[str, ParamExtras]]:
docs = docstring_parser.parse(self.func.__doc__ or "")
params_docs = {param.arg_name: param.description for param in docs.params} if docs.params else {}
description = "\n\n".join(item for item in (docs.short_description, docs.long_description) if item)
description = _indent(description, self._description_indent)
parser = self._make_base_parser(description=description)
param_infos = self._make_param_infos(
self.func, params_docs, self.annotations_ns, defaults, self.signature_override
self.func,
params_docs,
self.annotations_ns,
defaults,
self.signature_override,
)
params_extras = self._full_params_extras(param_infos)
for param in param_infos:
Expand All @@ -300,7 +315,7 @@ def _make_base_parser(self, description: str) -> argparse.ArgumentParser:
formatter_class = type(
f"_Custom_{formatter_class.__name__}",
(formatter_class,),
dict(_default_width=self.help_width, _default_max_help_position=self.max_help_position),
{"_default_width": self.help_width, "_default_max_help_position": self.max_help_position},
)
return argparse.ArgumentParser(formatter_class=formatter_class, description=description) # type: ignore

Expand All @@ -324,65 +339,68 @@ def _add_argument(
param: ParamInfo,
arg_extra_kwargs: dict[str, Any] | None = None,
) -> None:
if param.required:
arg_name = cls._pos_param_to_arg_name(name)
else:
arg_name = cls._opt_param_to_arg_name(name)
arg_name = cls._pos_param_to_arg_name(name) if param.required else cls._opt_param_to_arg_name(name)

type_converter = cls.TYPE_CONVERTERS.get(param.value_type) or param.value_type
arg_kwargs = dict(type=type_converter, help=param.doc)
arg_kwargs = {"type": type_converter, "help": param.doc}
if not param.required:
arg_kwargs.update(dict(default=param.default, help=param.doc or " "))
arg_kwargs.update({"default": param.default, "help": param.doc or " "})

if param.extra_args: # `func(*args)`
# Implementation: `arg_kwargs.update(nargs="*")`
raise Exception("Not currently supported: `**args` in function")
raise CLIError("Not currently supported: `**args` in function")
if param.extra_kwargs: # `func(**kwargs)`
raise Exception("Not currently supported: `**kwargs` in function")
raise CLIError("Not currently supported: `**kwargs` in function")
arg_extra_kwargs = dict(arg_extra_kwargs or {})
arg_extra_args = arg_extra_kwargs.pop("__args__", None) or ()
arg_kwargs.update(arg_extra_kwargs)
parser.add_argument(arg_name, *arg_extra_args, **arg_kwargs) # type: ignore # very tricky typing.

def _make_overridden_defaults(
self, args: tuple[Any, ...], kwargs: dict[str, Any]
self,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> tuple[dict[str, Any], tuple[Any, ...]]:
var_args = cast(tuple[Any, ...], ())
if not args:
return kwargs or {}, var_args
signature = inspect.signature(self.func)

var_arg = next(
(param for param in signature.parameters.values() if param.kind is inspect.Parameter.VAR_POSITIONAL), None
(param for param in signature.parameters.values() if param.kind is inspect.Parameter.VAR_POSITIONAL),
None,
)
var_kwarg = next(
(param for param in signature.parameters.values() if param.kind is inspect.Parameter.VAR_KEYWORD), None
(param for param in signature.parameters.values() if param.kind is inspect.Parameter.VAR_KEYWORD),
None,
)
if var_arg is not None:
raise Exception("Not currently supported: `**args` in function")
raise CLIError("Not currently supported: `**args` in function")
if var_kwarg is not None:
raise Exception("Not currently supported: `**kwargs` in function")
raise CLIError("Not currently supported: `**kwargs` in function")

return signature.bind_partial(*args, **kwargs).arguments, var_args

def _postprocess_values(
self, parsed_kwargs: dict[str, Any], params_extras: dict[str, ParamExtras]
self,
parsed_kwargs: dict[str, Any],
params_extras: dict[str, ParamExtras],
) -> dict[str, Any]:
parsed_kwargs = parsed_kwargs.copy()
for name, param_extras in params_extras.items():
if param_extras.arg_postprocess and name in parsed_kwargs:
parsed_kwargs[name] = param_extras.arg_postprocess(parsed_kwargs[name])
return parsed_kwargs

def __call__(self, *args: Any, **kwargs: Any) -> TRet | None:
def __call__(self, *args: Any, **kwargs: Any) -> TRet | None: # type: ignore
defaults, var_args = self._make_overridden_defaults(args, kwargs)

parser, params_extras = self._make_full_parser_and_params_extras(defaults=defaults)
params, unknown_args = parser.parse_known_args(self.argv)
if unknown_args and self.fail_on_unknown_args:
raise Exception("Unrecognized arguments", dict(unknown_args=unknown_args))
parsed_args = params._get_args() # generally, an empty list.
parsed_kwargs_base = params._get_kwargs()
raise SystemExit("Unrecognized arguments", {"unknown_args": unknown_args})
parsed_args = params._get_args() # generally, an empty list. # noqa: SLF001
parsed_kwargs_base = params._get_kwargs() # noqa: SLF001
arg_name_to_param_name = {
extras.name_norm or param_name: param_name for param_name, extras in params_extras.items()
}
Expand All @@ -391,7 +409,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> TRet | None:

full_args = [*var_args, *parsed_args]
full_kwargs = {**kwargs, **parsed_kwargs}
return self.func(*full_args, **full_kwargs)
return self.func(*full_args, **full_kwargs) # type: ignore


def get_caller_ns(extra_stack: int = 0) -> dict[str, Any] | None:
Expand All @@ -408,17 +426,16 @@ def get_caller_ns(extra_stack: int = 0) -> dict[str, Any] | None:


@overload
def autocli(func: TCallable, **config: Any) -> AutoCLI[TCallable]:
...
def autocli(func: TCallable, **config: Any) -> AutoCLI[TCallable]: ...


@overload
def autocli(func: None = None, **config: Any) -> Callable[[TCallable], AutoCLI[TCallable]]:
...
def autocli(func: None = None, **config: Any) -> Callable[[TCallable], AutoCLI[TCallable]]: ...


def autocli(
func: TCallable | None = None, **config: Any
func: TCallable | None = None,
**config: Any,
) -> AutoCLI[TCallable] | Callable[[TCallable], AutoCLI[TCallable]]:
actual_config = config.copy()

Expand Down
Loading

0 comments on commit 8d1a119

Please sign in to comment.