diff --git a/pyproject.toml b/pyproject.toml index 1a383cd..782e15b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = ["coloredlogs", "docstring-parser", "python-dateutil", "xxhash"] [project.optional-dependencies] test = ["flaky", "pytest", "pytest-asyncio", "pytest-coverage", "types-python-dateutil", "types-xxhash"] -dev = ["black", "replete[test]", "ruff"] +dev = ["replete[test]", "ruff"] [tool.flit.sdist] include = ["README.md"] @@ -27,10 +27,6 @@ minversion = "6.0" addopts = "--doctest-modules --no-success-flaky-report" asyncio_mode = "auto" -[tool.black] -line-length = 120 -skip-magic-trailing-comma = true - [tool.yamlfix] line_length = 120 section_whitelines = 1 @@ -47,15 +43,19 @@ preview = true select = ["A", "ARG", "B", "BLE", "C4", "COM", "E", "ERA", "F", "FBT", "FIX", "FLY", "FURB", "I", "IC", "INP", "ISC", "LOG", "N", "NPY", "PERF", "PIE", "PT", "PTH", "Q", "R", "RET", "RSE", "S", "SIM", "SLF", "T20", "TCH", "TD", "TID", "TRY", "UP", "W"] fixable = ["ALL"] ignore = ["A003", "E203", "FIX002", "FURB113", "N817", "PTH123", "RET503", "S113", "TD002", "TD003", "TRY003", "UP007", "UP035"] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "**/__init__.py" = [ - "F401", # Allow unused imports in module files + "F401", # Allow unused imports in module files +] +"tests/test_cli.py" = [ + "W291", # Need to ignore it for correct formatting ] "tests/**/*.py" = [ - "E501", # Test strings can be long - "S101", # Asserts in tests are fine - "T201", # Prints are useful for debugging - "TCH001", - "TCH002", - "TCH003", # Tests don't need to be super performant, prefer simpler code + "E501", # Test strings can be long + "FBT001", # We don't expect to call the tests + "S101", # Asserts in tests are fine + "T201", # Prints are useful for debugging + "TCH001", + "TCH002", + "TCH003", # Tests don't need to be super performant, prefer simpler code ] diff --git a/replete/__init__.py b/replete/__init__.py index ae4e32a..4396bfe 100644 --- a/replete/__init__.py +++ b/replete/__init__.py @@ -1,12 +1,13 @@ -""" Assorted utilities with minimal dependencies """ +"""Assorted utilities with minimal dependencies""" + from __future__ import annotations -from .aio import LazyWrapAsync, achunked, alist +from .aio import LazyWrapAsync, achunked from .cli import autocli from .consistent_hash import consistent_hash, picklehash from .datetime import date_range, datetime_range, round_dt from .enum import ComparableEnum -from .logging import assert_with_logging, setup_logging, warn_with_traceback +from .logging import WarnWithTraceback, assert_with_logging, setup_logging from .register import Register from .timing import RateLimiter, Timer from .utils import chunks, deep_update, ensure_unique_keys, grouped, split_list diff --git a/replete/aio.py b/replete/aio.py index b13e23f..89a9e6b 100644 --- a/replete/aio.py +++ b/replete/aio.py @@ -11,14 +11,6 @@ TLazyWrapValue = TypeVar("TLazyWrapValue") -async def alist(async_iter: AsyncIterable[T]) -> list[T]: - """Simple gatherer of an async iterable into a list""" - result = [] - async for item in async_iter: - result.append(item) - return result - - async def achunked(aiterable: AsyncIterable[T], size: int) -> AsyncIterable[list[T]]: """Async iterable chunker""" chunk: list[T] = [] diff --git a/replete/cli.py b/replete/cli.py index 4e85cb1..3a9d4f9 100644 --- a/replete/cli.py +++ b/replete/cli.py @@ -6,17 +6,19 @@ import functools import inspect import typing +from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, ClassVar, Generic, NamedTuple, TypeVar, Union, cast, overload # noqa: TC002 +from typing import Any, ClassVar, Generic, NamedTuple, Sequence, TypeVar, Union, cast, overload import docstring_parser -if TYPE_CHECKING: - from collections.abc import Sequence - T = TypeVar("T") +class CLIError(Exception): + """Exception for all things CLI""" + +T = TypeVar("T") TRet = TypeVar("TRet") TCallable = TypeVar("TCallable", bound=Union[Callable[..., TRet], Callable[..., None]]) # type: ignore NoneType = type(None) @@ -33,21 +35,25 @@ def pairs_to_dict( key_parse: Callable[[Any], Any] = str, val_parse: Callable[[Any], Any] = str, + *, allow_empty: bool = True, - name: str = "", ) -> Callable[[Any], Any]: + """ + Convert list of keys and values into dict + """ + def _pairs_to_dict_configured(items: Sequence[T] | None) -> dict[T, T]: if not items: if allow_empty: return {} - raise ValueError(f"{name!r}: Must not be empty") + raise ValueError("Got empty list of items") if isinstance(items, dict): return items if len(items) % 2 != 0: - raise ValueError("Must have key-value pairs (even number)", dict(items=items)) + raise ValueError("Must have key-value pairs (even number)", {"items": items}) keys = [key_parse(key) for key in items[::2]] values = [val_parse(val) for val in items[1::2]] - return dict(zip(keys, values)) + return dict(zip(keys, values, strict=True)) return _pairs_to_dict_configured @@ -65,8 +71,8 @@ def _resolve_type_annotation(type_annotation: Any, annotations_ns: dict[str, Any return None, None if isinstance(type_annotation, str): try: - type_annotation = eval(type_annotation, COMMON_ANNOTATIONS_NS | (annotations_ns or {})) - except Exception: + type_annotation = eval(type_annotation, COMMON_ANNOTATIONS_NS | (annotations_ns or {})) # noqa: S307 + except Exception: # noqa: BLE001 return None, None type_origin = getattr(type_annotation, "__origin__", None) @@ -87,11 +93,11 @@ def _resolve_type_annotation(type_annotation: Any, annotations_ns: dict[str, Any if isinstance(type_args, tuple) and len(type_args) == 2: return type_origin, type_args return type_origin, None - elif type_origin in (list, tuple): + if type_origin in (list, tuple): if isinstance(type_args, tuple) and len(type_args) == 1: return type_origin, type_args[0] return type_origin, None - elif type_name in ("Sequence", "Iterable"): + if type_name in ("Sequence", "Iterable"): if isinstance(type_args, tuple) and len(type_args) == 1: return list, type_args[0] return list, None @@ -153,7 +159,7 @@ def replace(self, **kwargs: Any) -> ParamExtras: @dataclass -class AutoCLIBase(Generic[TCallable]): +class AutoCLIBase(Generic[TCallable], ABC): """Base interface for an automatic function-to-CLI processing""" func: TCallable @@ -169,20 +175,21 @@ class AutoCLIBase(Generic[TCallable]): } TYPE_AUTO_PARAMS: ClassVar[dict[Any, Callable[[ParamInfo], ParamExtras]]] = { # TODO: actually make the boolean arguments into `--flag/--no-flag`. - bool: lambda param_info: ParamExtras(arg_argparse_extra_kwargs=dict(type=parse_bool)), + bool: lambda _: ParamExtras(arg_argparse_extra_kwargs={"type": parse_bool}), dict: lambda param_info: ParamExtras( - arg_argparse_extra_kwargs=dict(nargs="*", type=str), + arg_argparse_extra_kwargs={"nargs": "*", "type": str}, arg_postprocess=pairs_to_dict( key_parse=param_info.contained_type[0] if param_info.contained_type else str, val_parse=param_info.contained_type[1] if param_info.contained_type else str, ), ), list: lambda param_info: ParamExtras( - arg_argparse_extra_kwargs=dict(nargs="*", type=param_info.contained_type or str) + arg_argparse_extra_kwargs={"nargs": "*", "type": param_info.contained_type or str}, ), } - def __call__(self, *args: Any, **kwargs: Any) -> TRet | None: + @abstractmethod + def __call__(self, *args: Any, **kwargs: Any) -> TRet | None: # type: ignore raise NotImplementedError @classmethod @@ -199,7 +206,10 @@ def _make_param_infos( defaults = defaults or {} return [ ParamInfo.from_arg_param( - arg_param, doc=params_docs.get(name), annotations_ns=annotations_ns, default=defaults.get(name) + arg_param, + doc=params_docs.get(name), + annotations_ns=annotations_ns, + default=defaults.get(name), ) for name, arg_param in signature.parameters.items() ] @@ -233,7 +243,7 @@ class AutoCLI(AutoCLIBase[TCallable]): signature_override: inspect.Signature | None = None _description_indent: str = " " # mock default for `dataclass`, gets filled in `__post_init__` with an actual value. - __wrapped__: Callable[..., Any] = lambda: None + __wrapped__: Callable[..., Any] = lambda: None # noqa: E731 @classmethod def _base_param_extras(cls, param_info: ParamInfo) -> ParamExtras: @@ -274,7 +284,8 @@ def _full_params_extras(self, param_infos: Sequence[ParamInfo]) -> dict[str, Par return result def _make_full_parser_and_params_extras( - self, defaults: dict[str, Any] + self, + defaults: dict[str, Any], ) -> tuple[argparse.ArgumentParser, dict[str, ParamExtras]]: docs = docstring_parser.parse(self.func.__doc__ or "") params_docs = {param.arg_name: param.description for param in docs.params} if docs.params else {} @@ -282,7 +293,11 @@ def _make_full_parser_and_params_extras( description = _indent(description, self._description_indent) parser = self._make_base_parser(description=description) param_infos = self._make_param_infos( - self.func, params_docs, self.annotations_ns, defaults, self.signature_override + self.func, + params_docs, + self.annotations_ns, + defaults, + self.signature_override, ) params_extras = self._full_params_extras(param_infos) for param in param_infos: @@ -300,7 +315,7 @@ def _make_base_parser(self, description: str) -> argparse.ArgumentParser: formatter_class = type( f"_Custom_{formatter_class.__name__}", (formatter_class,), - dict(_default_width=self.help_width, _default_max_help_position=self.max_help_position), + {"_default_width": self.help_width, "_default_max_help_position": self.max_help_position}, ) return argparse.ArgumentParser(formatter_class=formatter_class, description=description) # type: ignore @@ -324,28 +339,27 @@ def _add_argument( param: ParamInfo, arg_extra_kwargs: dict[str, Any] | None = None, ) -> None: - if param.required: - arg_name = cls._pos_param_to_arg_name(name) - else: - arg_name = cls._opt_param_to_arg_name(name) + arg_name = cls._pos_param_to_arg_name(name) if param.required else cls._opt_param_to_arg_name(name) type_converter = cls.TYPE_CONVERTERS.get(param.value_type) or param.value_type - arg_kwargs = dict(type=type_converter, help=param.doc) + arg_kwargs = {"type": type_converter, "help": param.doc} if not param.required: - arg_kwargs.update(dict(default=param.default, help=param.doc or " ")) + arg_kwargs.update({"default": param.default, "help": param.doc or " "}) if param.extra_args: # `func(*args)` # Implementation: `arg_kwargs.update(nargs="*")` - raise Exception("Not currently supported: `**args` in function") + raise CLIError("Not currently supported: `**args` in function") if param.extra_kwargs: # `func(**kwargs)` - raise Exception("Not currently supported: `**kwargs` in function") + raise CLIError("Not currently supported: `**kwargs` in function") arg_extra_kwargs = dict(arg_extra_kwargs or {}) arg_extra_args = arg_extra_kwargs.pop("__args__", None) or () arg_kwargs.update(arg_extra_kwargs) parser.add_argument(arg_name, *arg_extra_args, **arg_kwargs) # type: ignore # very tricky typing. def _make_overridden_defaults( - self, args: tuple[Any, ...], kwargs: dict[str, Any] + self, + args: tuple[Any, ...], + kwargs: dict[str, Any], ) -> tuple[dict[str, Any], tuple[Any, ...]]: var_args = cast(tuple[Any, ...], ()) if not args: @@ -353,20 +367,24 @@ def _make_overridden_defaults( signature = inspect.signature(self.func) var_arg = next( - (param for param in signature.parameters.values() if param.kind is inspect.Parameter.VAR_POSITIONAL), None + (param for param in signature.parameters.values() if param.kind is inspect.Parameter.VAR_POSITIONAL), + None, ) var_kwarg = next( - (param for param in signature.parameters.values() if param.kind is inspect.Parameter.VAR_KEYWORD), None + (param for param in signature.parameters.values() if param.kind is inspect.Parameter.VAR_KEYWORD), + None, ) if var_arg is not None: - raise Exception("Not currently supported: `**args` in function") + raise CLIError("Not currently supported: `**args` in function") if var_kwarg is not None: - raise Exception("Not currently supported: `**kwargs` in function") + raise CLIError("Not currently supported: `**kwargs` in function") return signature.bind_partial(*args, **kwargs).arguments, var_args def _postprocess_values( - self, parsed_kwargs: dict[str, Any], params_extras: dict[str, ParamExtras] + self, + parsed_kwargs: dict[str, Any], + params_extras: dict[str, ParamExtras], ) -> dict[str, Any]: parsed_kwargs = parsed_kwargs.copy() for name, param_extras in params_extras.items(): @@ -374,15 +392,15 @@ def _postprocess_values( parsed_kwargs[name] = param_extras.arg_postprocess(parsed_kwargs[name]) return parsed_kwargs - def __call__(self, *args: Any, **kwargs: Any) -> TRet | None: + def __call__(self, *args: Any, **kwargs: Any) -> TRet | None: # type: ignore defaults, var_args = self._make_overridden_defaults(args, kwargs) parser, params_extras = self._make_full_parser_and_params_extras(defaults=defaults) params, unknown_args = parser.parse_known_args(self.argv) if unknown_args and self.fail_on_unknown_args: - raise Exception("Unrecognized arguments", dict(unknown_args=unknown_args)) - parsed_args = params._get_args() # generally, an empty list. - parsed_kwargs_base = params._get_kwargs() + raise SystemExit("Unrecognized arguments", {"unknown_args": unknown_args}) + parsed_args = params._get_args() # generally, an empty list. # noqa: SLF001 + parsed_kwargs_base = params._get_kwargs() # noqa: SLF001 arg_name_to_param_name = { extras.name_norm or param_name: param_name for param_name, extras in params_extras.items() } @@ -391,7 +409,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> TRet | None: full_args = [*var_args, *parsed_args] full_kwargs = {**kwargs, **parsed_kwargs} - return self.func(*full_args, **full_kwargs) + return self.func(*full_args, **full_kwargs) # type: ignore def get_caller_ns(extra_stack: int = 0) -> dict[str, Any] | None: @@ -408,17 +426,16 @@ def get_caller_ns(extra_stack: int = 0) -> dict[str, Any] | None: @overload -def autocli(func: TCallable, **config: Any) -> AutoCLI[TCallable]: - ... +def autocli(func: TCallable, **config: Any) -> AutoCLI[TCallable]: ... @overload -def autocli(func: None = None, **config: Any) -> Callable[[TCallable], AutoCLI[TCallable]]: - ... +def autocli(func: None = None, **config: Any) -> Callable[[TCallable], AutoCLI[TCallable]]: ... def autocli( - func: TCallable | None = None, **config: Any + func: TCallable | None = None, + **config: Any, ) -> AutoCLI[TCallable] | Callable[[TCallable], AutoCLI[TCallable]]: actual_config = config.copy() diff --git a/replete/consistent_hash.py b/replete/consistent_hash.py index 3e8096f..5f2a7d3 100644 --- a/replete/consistent_hash.py +++ b/replete/consistent_hash.py @@ -3,7 +3,7 @@ import collections.abc import contextlib import datetime -import pickle +import pickle # noqa S403, not loading anything here, only dumping from typing import TYPE_CHECKING import xxhash @@ -28,7 +28,7 @@ # And pickling is very slow for `tzinfo` (~25 microseconds). # `.timestamp()` is around ~5 microseconds with timezone. datetime.datetime, - ) + ), ) @@ -36,6 +36,7 @@ def consistent_hash_raw_update( hasher: xxhash.xxh3_64, params: Sequence[Any] = (), primitive_types: frozenset[type] = PRIMITIVE_TYPES, + *, type_name_dependence: bool = False, try_pickle: bool = True, ) -> None: @@ -51,7 +52,9 @@ def consistent_hash_raw_update( hasher.update(b"\x02") hasher.update(repr(param).encode()) elif (chashmeth := getattr(param, "_consistent_hash", None)) is not None and getattr( - chashmeth, "__self__", None + chashmeth, + "__self__", + None, ) is not None: rec_int = chashmeth() hasher.update(b"\x03") @@ -88,7 +91,7 @@ def _normalize(value: Any) -> Any: value_type = type(value) if value_type in PRIMITIVE_TYPES: return value - chashmeth = getattr(value, "_consistent_hash", None) + chashmeth = getattr(value, "consistent_hash", None) if chashmeth is not None and getattr(chashmeth, "__self__", None) is not None: return chashmeth() # Note: makes the result type-independent. if value_type is list or value_type is tuple or isinstance(value, (list, tuple)): diff --git a/replete/datetime.py b/replete/datetime.py index 348c8fd..15f8c35 100644 --- a/replete/datetime.py +++ b/replete/datetime.py @@ -3,7 +3,7 @@ import datetime as dt from typing import Iterable -from dateutil.relativedelta import relativedelta +from dateutil.relativedelta import relativedelta # noqa: TCH002, required for doctests def date_range(start: dt.date, stop: dt.date, step_days: int = 1) -> Iterable[dt.date]: @@ -41,7 +41,8 @@ def datetime_range(start: dt.datetime, stop: dt.datetime | None, step: dt.timede >>> _dts_s(datetime_range(dt1, dt2, dt.timedelta(seconds=11.11111111111)))[-1] '2022-02-03T23:59:59.998272' """ - assert step, "must be non-zero" + if not step: + raise ValueError(f"Step must be positive, step = {step}") forward = step > dt.timedelta() current = start while True: @@ -52,7 +53,9 @@ def datetime_range(start: dt.datetime, stop: dt.datetime | None, step: dt.timede def round_dt( - datetime: dt.datetime, delta: dt.timedelta | relativedelta, start_time: dt.time = dt.time.min + datetime: dt.datetime, + delta: dt.timedelta | relativedelta, + start_time: dt.time = dt.time.min, ) -> dt.datetime: """ Round time-from-midnight to the specified timedelta. @@ -100,8 +103,7 @@ def round_dt( microsecond=start_time.microsecond, ) return ts_start + dt.timedelta(seconds=(datetime - ts_start).total_seconds() // seconds * seconds) - else: - raise ValueError(f"Timedelta should be one day or less, got: {delta}") + raise ValueError(f"Timedelta should be one day or less, got: {delta}") if any([delta.microseconds, delta.seconds, delta.minutes, delta.hours]): raise ValueError("relativedelta more precise than day is not supported, use timedelta") @@ -117,9 +119,11 @@ def round_dt( ) if delta.years == 0: - assert delta.months in [1, 2, 3, 4, 6], "months should be 1, 2, 3, 4 or 6" + if delta.months not in [1, 2, 3, 4, 6]: + raise ValueError(f"delta.months should be 1, 2, 3, 4 or 6. Got {delta=}") return datetime.replace(month=1 + (datetime.month - 1) // delta.months * delta.months) - assert delta.months == 0, "months should be 0 if years are used" + if delta.months != 0: + raise ValueError(f"delta.months should be 0 if years are used. Got {delta=}") datetime = datetime.replace(month=1) return datetime.replace(year=datetime.year // delta.years * delta.years) diff --git a/replete/funcutils.py b/replete/funcutils.py index 69341b1..b2c91f1 100644 --- a/replete/funcutils.py +++ b/replete/funcutils.py @@ -2,7 +2,7 @@ import logging import operator -from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Type, TypeVar, cast, overload +from typing import TYPE_CHECKING, Callable, Iterable, Iterator, TypeVar, cast, overload if TYPE_CHECKING: TRightDefault = TypeVar("TRightDefault") @@ -22,8 +22,7 @@ def join_ffill( right_lst: Iterable[TRight], condition: Callable[[TLeft, TRight], bool] = cast(Callable[[TLeft, TRight], bool], operator.ge), default: None = None, -) -> Iterable[tuple[TLeft, TRight | None]]: - ... +) -> Iterable[tuple[TLeft, TRight | None]]: ... @overload @@ -32,8 +31,7 @@ def join_ffill( right_lst: Iterable[TRight], condition: Callable[[TLeft, TRight], bool], default: TRightDefault, -) -> Iterable[tuple[TLeft, TRight | TRightDefault]]: - ... +) -> Iterable[tuple[TLeft, TRight | TRightDefault]]: ... def join_ffill( @@ -85,8 +83,7 @@ def join_backfill( right_lst: Iterable[TRight], condition: Callable[[TLeft, TRight], bool] = cast(Callable[[TLeft, TRight], bool], operator.le), default: None = None, -) -> Iterable[tuple[TLeft, TRight | None]]: - ... +) -> Iterable[tuple[TLeft, TRight | None]]: ... @overload @@ -95,8 +92,7 @@ def join_backfill( right_lst: Iterable[TRight], condition: Callable[[TLeft, TRight], bool], default: TRightDefault, -) -> Iterable[tuple[TLeft, TRight | TRightDefault]]: - ... +) -> Iterable[tuple[TLeft, TRight | TRightDefault]]: ... def join_backfill( @@ -139,12 +135,12 @@ def join_backfill( yield left_item, default if right_item is right_done_marker else cast(TRight, right_item) -def yield_or_skip(iter_: Iterable, func: Callable, skip_on_errors: Iterable[Type[Exception]]) -> Iterator: +def yield_or_skip(iter_: Iterable, func: Callable, skip_on_errors: Iterable[type[Exception]]) -> Iterator: skip_on_errors = tuple(skip_on_errors) for item in iter_: try: yield func(item) - except Exception as e: + except Exception as e: # noqa: PERF203 intentional if isinstance(e, skip_on_errors): LOGGER.debug(f"Skipping {item} due to {e}") continue diff --git a/replete/logging.py b/replete/logging.py index efd04d1..d9062b7 100644 --- a/replete/logging.py +++ b/replete/logging.py @@ -4,11 +4,13 @@ import sys import traceback import warnings -from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any from coloredlogs import DEFAULT_FIELD_STYLES, DEFAULT_LEVEL_STYLES, ColoredFormatter +if TYPE_CHECKING: + from pathlib import Path + ORIGINAL_SHOWWARNINGS = warnings.showwarning @@ -18,7 +20,7 @@ def _warn_with_traceback(message, category, filename, lineno, file=None, line=No ORIGINAL_SHOWWARNINGS(message, category, filename, lineno, file, line) -class warn_with_traceback: +class WarnWithTraceback: def __enter__(self) -> None: self._orig_show = warnings.showwarning warnings.showwarning = _warn_with_traceback @@ -27,13 +29,13 @@ def __exit__(self, *_): warnings.showwarning = ORIGINAL_SHOWWARNINGS -def assert_with_logging(status: bool, message: str): +def assert_with_logging(status: bool, message: str): # noqa: FBT001 if not status: logging.critical(message) - assert status, message + assert status, message # noqa: S101 -class change_logging_level: +class ChangeLoggingLevel: def __init__(self, level: int): self._level = level self._original_level: int = None @@ -57,7 +59,7 @@ def change_level(record: logging.LogRecord): logger.addFilter(change_level) -def get_file_handler(log_file: Path, logging_level: int | str = logging.DEBUG, append=True, use_year=False): +def get_file_handler(log_file: Path, logging_level: int | str = logging.DEBUG, *, append=True, use_year=False): file_handler = logging.FileHandler(log_file, mode="a" if append else "w") formatter = get_logging_formatter(use_year=use_year) file_handler.setFormatter(formatter) @@ -79,7 +81,11 @@ def get_file_handler(log_file: Path, logging_level: int | str = logging.DEBUG, a def get_logging_formatter( - color=False, level_styles_update: StylesType = None, field_styles_update: StylesType = None, use_year=False + *, + color=False, + use_year=False, + level_styles_update: StylesType = None, + field_styles_update: StylesType = None, ): fmt = CUSTOM_FORMAT maybe_year = "%Y-" if use_year else "" @@ -115,9 +121,10 @@ def filter(self, record: logging.LogRecord) -> bool: # noqa: A003 def setup_logging( log_file: Path = None, print_level: int | str = logging.INFO, + *, append=False, - level_styles_update: dict[str, dict[str, Any]] = {}, - field_styles_update: dict[str, dict[str, Any]] = {}, + level_styles_update: dict[str, dict[str, Any]] = None, + field_styles_update: dict[str, dict[str, Any]] = None, disable_colors=False, log_uncaught_exceptions=True, use_year=False, @@ -135,11 +142,16 @@ def setup_logging( :param use_year: Add year to logs :param whitelist: Only log these modules """ + level_styles_update = level_styles_update or {} + field_styles_update = field_styles_update or {} if isinstance(print_level, str): print_level = int(print_level) if print_level.isdigit() else getattr(logging, print_level) formatter = get_logging_formatter(use_year=use_year) colored_formatter = get_logging_formatter( - color=True, level_styles_update=level_styles_update, field_styles_update=field_styles_update, use_year=use_year + color=True, + level_styles_update=level_styles_update, + field_styles_update=field_styles_update, + use_year=use_year, ) handlers = [] diff --git a/replete/py.typed b/replete/py.typed deleted file mode 100644 index e69de29..0000000 diff --git a/replete/register.py b/replete/register.py index 9d80328..cf90d95 100644 --- a/replete/register.py +++ b/replete/register.py @@ -1,9 +1,7 @@ from __future__ import annotations import logging -import re -from collections.abc import Iterable -from typing import Any, ClassVar +from typing import Any, ClassVar, Iterable LOGGER = logging.getLogger(__name__) @@ -25,15 +23,17 @@ class Register: def _set_register_base_class(cls, obj: type[Register], suffix: str = None) -> None: if hasattr(cls, "_register_base_class"): raise RegisterError( - f"Can't set {obj} as base class, base class is already set to {cls._register_base_class}" + f"Can't set {obj} as base class, base class is already set to {cls._register_base_class}", ) - obj._register_base_class = obj - obj._register_data = {} + obj._register_base_class = obj # noqa: SLF001 + obj._register_data = {} # noqa: SLF001 if suffix is None: - suffix_match = re.match(r".*([A-Z][^A-Z]+)$", obj.__name__) # 'BaseClass2' -> 'Class2' - assert suffix_match - suffix = suffix_match.group(1) - obj._register_suffix = suffix + suffix = obj.__name__ + if suffix.startswith("Base") and len(suffix) != 4: + suffix = suffix[4:] + if not suffix: + raise ValueError(f"Can't find suffix from name = {obj.__name__}, please provide manually") + obj._register_suffix = suffix # noqa: SLF001 # type: ignore @classmethod def register_class(cls, obj: type, name: str = None) -> None: @@ -51,18 +51,19 @@ def register_class(cls, obj: type, name: str = None) -> None: if old_cls is not None: # Using `__dict__` instead of `hasattr` because `hasattr` includes superclasses. if "__attrs_attrs__" in cls.__dict__ and "__attrs_attrs__" not in old_cls.__dict__: - old_cls._name_in_register = None + old_cls._name_in_register = None # noqa: SLF001 # type: ignore else: raise KeyError(f"{name} is already registered! (Most likely subclass name clash)") cls._register_data[name] = obj - obj._name_in_register = name + obj._name_in_register = name # noqa: SLF001 # Have to do this, since registered class might not be a subclass of Register. # "Cannot assign to a method", "expression has type "classmethod[Any]", variable has type "Callable[[], str]"" obj.get_name_in_register = classmethod(Register.get_name_in_register.__func__) # type: ignore def __init_subclass__( cls, + *, base: bool = False, base_class: type = None, abstract: bool = False, diff --git a/replete/timing.py b/replete/timing.py index b14b942..3fb7264 100644 --- a/replete/timing.py +++ b/replete/timing.py @@ -6,7 +6,7 @@ class Timer: - def __init__(self, base_time: float = None, process_only=False): + def __init__(self, base_time: float = None, *, process_only=False): self._base_time = base_time or self.__class__.get_sample_base_time() self._time_func = getattr(time, "process_time" if process_only else "perf_counter") @@ -30,7 +30,7 @@ def base_time_ratio(self) -> float: return self.time / self.base_time @classmethod - def get_sample_base_time(cls, length=24, process_only=False) -> float: + def get_sample_base_time(cls, length=24, *, process_only=False) -> float: def dumb_fibonacci(n: int) -> int: if n < 2: return n diff --git a/replete/utils.py b/replete/utils.py index cac4cce..8e8487d 100644 --- a/replete/utils.py +++ b/replete/utils.py @@ -5,13 +5,13 @@ import logging import weakref from concurrent import futures -from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterable, Iterator, Mapping, Sequence, TypeVar, cast - -from replete.abc import Comparable +from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterable, Iterator, Mapping, Sequence, TypeVar LOGGER = logging.getLogger(__name__) if TYPE_CHECKING: + from replete.abc import Comparable + TKey = TypeVar("TKey", bound=Hashable) TVal = TypeVar("TVal") # For `sort`-like `key=...` argument: @@ -56,7 +56,8 @@ def iterchunks(iterable: Iterable[TVal], size: int) -> Iterator[Sequence[TVal]]: >>> list(iterchunks(range(10), 3)) [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9,)] """ - assert size > 0 + if size < 1: + raise ValueError(f"Invalid chunk size = {size}") iterator = iter(iterable) while True: chunk = tuple(itertools.islice(iterator, size)) @@ -121,10 +122,7 @@ def deep_update(target: dict, updates: Mapping) -> dict: target = target.copy() for key, value in updates.items(): old_value = target.get(key) - if isinstance(old_value, dict): - new_value = deep_update(old_value, value) - else: - new_value = value + new_value = deep_update(old_value, value) if isinstance(old_value, dict) else value target[key] = new_value return target @@ -134,6 +132,7 @@ def futures_processing( args_list: list[tuple] = None, kwargs_list: list[dict] = None, max_workers: int = None, + *, in_order=False, only_log_exceptions=False, ) -> Iterator: @@ -141,10 +140,11 @@ def futures_processing( if args_list is None and kwargs_list is None: raise ValueError("Must provide either args_list or kwargs_list") if args_list is None: - args_list = [()] * len(kwargs_list) + args_list = [()] * len(kwargs_list) # type: ignore if kwargs_list is None: kwargs_list = [{}] * len(args_list) - assert len(args_list) == len(kwargs_list), "args_list and kwargs_list must be the same length" + if len(args_list) != len(kwargs_list): + raise ValueError("args_list and kwargs_list must be the same length") def func_with_idx(idx, *args, **kwargs) -> tuple[int, Any]: try: @@ -163,7 +163,7 @@ def func_with_idx(idx, *args, **kwargs) -> tuple[int, Any]: current_result_idx = 0 for future in futures.as_completed(results): if future.exception(): - raise future.exception() + raise future.exception() # noqa: RSE102 False, positive # type: ignore idx, func_result = future.result() if in_order: cache_results[idx] = func_result @@ -174,10 +174,13 @@ def func_with_idx(idx, *args, **kwargs) -> tuple[int, Any]: yield func_result -def weak_lru_cache(maxsize=128, typed=False): - """LRU Cache decorator that keeps a weak reference to 'self'""" +def weak_lru_cache(maxsize: Callable | int | None = 128, *, typed=False): + """ + LRU Cache decorator that keeps a weak reference to 'self' + Should be used instead of functools.lru_cache on methods + """ - def helper(maxsize: int, typed: bool, user_function: Callable) -> Callable: + def helper(maxsize: int, user_function: Callable, *, typed: bool) -> Callable: @functools.lru_cache(maxsize, typed) def _func(_self, *args, **kwargs): return user_function(_self(), *args, **kwargs) @@ -191,7 +194,8 @@ def inner(self, *args, **kwargs): if callable(maxsize) and isinstance(typed, bool): # The user_function was passed in directly via the maxsize argument user_function, maxsize = maxsize, 128 - return helper(maxsize, typed, user_function) - assert type(maxsize) == int or maxsize is None, "Expected maxsize to be an integer or None" + return helper(maxsize, user_function, typed=typed) + if not (type(maxsize) is int or maxsize is None): + raise ValueError("Expected maxsize to be an integer or None") - return functools.partial(helper, maxsize, typed) + return functools.partial(helper, maxsize, typed=typed) diff --git a/tests/test_cli.py b/tests/test_cli.py index 4428e01..fe96982 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -20,7 +20,7 @@ @autocli(fail_on_unknown_args=False, help_width=100, max_help_position=50) def example_cli( output_dir: Path, - no_help: str, + no_help: str, # noqa: ARG001 res_name: str = "test", date_: Optional[dt.date] = None, items: Optional[dict[str, int]] = None, @@ -54,9 +54,7 @@ def example_cli( usage: {MAIN_NAME} [-h] [--res-name RES_NAME] [--date DATE] [--items [ITEMS ...]] [--no-type-longname NO_TYPE_LONGNAME] output_dir no_help -""".strip( - "\n" -) +""".strip("\n") # fmt: off OPTIONAL_HELP_TEXT = """ --res-name RES_NAME Name for result (default: test) @@ -64,7 +62,7 @@ def example_cli( --items [ITEMS ...] --no-type-longname NO_TYPE_LONGNAME """.strip( - "\n" + "\n", ) # fmt: on diff --git a/tests/test_consistent_hash.py b/tests/test_consistent_hash.py index 6f4785d..159e08e 100644 --- a/tests/test_consistent_hash.py +++ b/tests/test_consistent_hash.py @@ -40,7 +40,7 @@ def __repr__(self) -> str: ({"foo": 1, "bar": 2}, {"foo": 1, "quux": 2}, False), (ConsistentHashObj(123), ConsistentHashObj(123), True), (ConsistentHashObj(123), ConsistentHashObj(1234), False), - (dict(cls=ConsistentHashObj), dict(cls=ConsistentHashObj), True), + ({"cls": ConsistentHashObj}, {"cls": ConsistentHashObj}, True), ] @@ -61,7 +61,7 @@ def consistent_hash_ref(*args: Any, **kwargs: Any) -> int: hashes.append(consistent_hash(*param)) else: hashes.append(repr(param)) - hasher = hashlib.md5() + hasher = hashlib.md5() # noqa: S324 for hash_piece in hashes: hasher.update(str(hash_piece).encode()) return int(hasher.hexdigest(), 16) @@ -71,8 +71,8 @@ def consistent_hash_ref2_raw(args: Sequence[Any] = (), kwargs: dict[str, Any] | params = [*args, *sorted(kwargs.items())] if kwargs else args hasher = xxhash.xxh128() for param in params: - if hasattr(param, "_consistent_hash") and hasattr(param._consistent_hash, "__self__"): - rec_int = param._consistent_hash() + if hasattr(param, "consistent_hash") and hasattr(param.consistent_hash, "__self__"): + rec_int = param.consistent_hash() hasher.update(rec_int.to_bytes(16, "little")) elif isinstance(param, Mapping): rec = consistent_hash_ref2_raw((), {str(key): value for key, value in param.items()}) diff --git a/tests/test_logging.py b/tests/test_logging.py index 492304b..941ba80 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -6,19 +6,19 @@ import pytest -from replete import assert_with_logging, setup_logging, warn_with_traceback -from replete.logging import change_logging_level, offset_logger_level +from replete import WarnWithTraceback, assert_with_logging, setup_logging +from replete.logging import ChangeLoggingLevel, offset_logger_level def test_warnings_traceback(capsys): - with pytest.warns(UserWarning), warn_with_traceback(): - warnings.warn("Test") + with pytest.warns(UserWarning), WarnWithTraceback(): + warnings.warn("Test") # noqa: B028 captured = capsys.readouterr() assert captured.out == "" assert len(captured.err) > 4 with pytest.warns(UserWarning): - warnings.warn("Test") + warnings.warn("Test") # noqa: B028 captured = capsys.readouterr() assert captured.out == "" assert captured.err == "" @@ -29,7 +29,7 @@ def test_change_logging_level(caplog): logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) logger.info("before") - with change_logging_level(logging.WARNING): + with ChangeLoggingLevel(logging.WARNING): logger.info("inside") logger.info("after") assert "before" in caplog.text @@ -38,9 +38,9 @@ def test_change_logging_level(caplog): def test_assert_with_logging(caplog): - assert_with_logging(True, "bar") + assert_with_logging(True, "bar") # noqa: FBT003 with pytest.raises(AssertionError): - assert_with_logging(False, "foo") + assert_with_logging(False, "foo") # noqa: FBT003 assert caplog.record_tuples == [("root", logging.CRITICAL, "foo")] diff --git a/tests/test_register.py b/tests/test_register.py index a7eab47..d9899f6 100644 --- a/tests/test_register.py +++ b/tests/test_register.py @@ -1,8 +1,6 @@ # type: ignore from __future__ import annotations -from typing import Type - import pytest from replete import Register @@ -16,35 +14,35 @@ class BaseClass(Register, base=True, abstract=True): return BaseClass -def test_registered_names(base_class: Type[Register]) -> None: +def test_registered_names(base_class: type[Register]) -> None: class SubClass(base_class): pass assert set(base_class.get_registered_names()) == {"Sub"} -def test_get_subclass(base_class: Type[Register]) -> None: +def test_get_subclass(base_class: type[Register]) -> None: class SubClass(base_class): pass assert base_class.get_subclass("Sub") == SubClass -def test_get_all_subclasses(base_class: Type[Register]) -> None: +def test_get_all_subclasses(base_class: type[Register]) -> None: class SubClass(base_class): pass assert set(base_class.get_all_subclases()) == {SubClass} -def test_get_name_in_register(base_class: Type[Register]) -> None: +def test_get_name_in_register(base_class: type[Register]) -> None: class SubClass(base_class): pass assert SubClass.get_name_in_register() == "Sub" -def test_check_double_register_error(base_class: Type[Register]) -> None: +def test_check_double_register_error(base_class: type[Register]) -> None: class SubClass(base_class): pass diff --git a/tests/test_timing.py b/tests/test_timing.py index d0c9dd3..6d19677 100644 --- a/tests/test_timing.py +++ b/tests/test_timing.py @@ -48,7 +48,7 @@ def test_base_time_ratio(): def test_rate_limiter_weight(): rate_limiter = RateLimiter(20, period_seconds=0.2) - weights = [random.randint(3, 7) for _ in range(20)] + weights = [random.randint(3, 7) for _ in range(20)] # noqa: S311, false positive with Timer() as t: for weight in weights: rate_limiter.check_rate(weight) diff --git a/tests/test_utils.py b/tests/test_utils.py index ffe3c06..a361db7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -33,7 +33,7 @@ def futures_processing_test_vars(): wait_time = 0.1 def func(str_: str, *, num: int) -> str: - time.sleep(wait_time + random.random() * 0.1) + time.sleep(wait_time + random.random() * 0.1) # noqa: S311 false positive return str_ * num args = [(c,) for c in list("abcdefg")] @@ -125,5 +125,5 @@ def func() -> None: func() gc.collect() # collect garbage - # Since f went out of scope after func() finished, it should be garbage collected + # Since foo went out of scope after func() finished, it should be garbage collected assert len([obj for obj in gc.get_objects() if isinstance(obj, WeakCacheTester)]) == 0