diff --git a/mizani/breaks.py b/mizani/breaks.py index cb41151..2446a39 100644 --- a/mizani/breaks.py +++ b/mizani/breaks.py @@ -14,9 +14,11 @@ from __future__ import annotations import sys +from dataclasses import KW_ONLY, dataclass, field from datetime import datetime, timedelta from itertools import product from typing import TYPE_CHECKING +from warnings import warn import numpy as np import pandas as pd @@ -29,7 +31,7 @@ from .utils import NANOSECONDS, SECONDS, log, min_max if TYPE_CHECKING: - from typing import Callable, Literal, Sequence + from typing import Literal, Sequence from mizani.typing import ( DatetimeBreaksUnits, @@ -53,6 +55,7 @@ ] +@dataclass class breaks_log: """ Integer breaks on log transformed scales @@ -76,9 +79,8 @@ class breaks_log: array([0.1, 0.3, 1. , 3. ]) """ - def __init__(self, n: int = 5, base: float = 10): - self.n = n - self.base = base + n: int = 5 + base: float = 10 def __call__(self, limits: tuple[float, float]) -> NDArrayFloat: """ @@ -124,6 +126,7 @@ def __call__(self, limits: tuple[float, float]) -> NDArrayFloat: return _breaks_log_sub(n=n, base=base)(limits) +@dataclass class _breaks_log_sub: """ Breaks for log transformed scales @@ -144,9 +147,8 @@ class _breaks_log_sub: algorithm in the r-scales package. """ - def __init__(self, n: int = 5, base: float = 10): - self.n = n - self.base = base + n: int = 5 + base: float = 10 def __call__(self, limits: tuple[float, float]) -> NDArrayFloat: base = self.base @@ -204,6 +206,7 @@ def delta(x): return breaks_extended(n=n)(limits) +@dataclass class minor_breaks: """ Compute minor breaks @@ -234,8 +237,7 @@ class minor_breaks: array([1.25, 1.5 , 1.75]) """ - def __init__(self, n: int = 1): - self.n = n + n: int = 1 def __call__( self, @@ -293,6 +295,7 @@ def __call__( return minor +@dataclass class minor_breaks_trans: """ Compute minor breaks for transformed scales @@ -335,9 +338,8 @@ class minor_breaks_trans: array([2.8, 4.6, 6.4, 8.2]) """ - def __init__(self, trans: Trans, n: int = 1): - self.trans = trans - self.n = n + trans: Trans + n: int = 1 def __call__( self, @@ -399,6 +401,7 @@ def _extend_breaks(self, major: FloatArrayLike) -> FloatArrayLike: return major +@dataclass class breaks_date: """ Regularly spaced dates @@ -426,27 +429,35 @@ class breaks_date: Breaks at 4 year intervals - >>> breaks = breaks_date('4 year') + >>> breaks = breaks_date(width='4 year') >>> [d.year for d in breaks(limits)] [2010, 2014, 2018, 2022, 2026] """ - n: int - width: int | None = None - units: DatetimeBreaksUnits | None = None - - def __init__(self, n: int = 5, width: str | None = None): - if isinstance(n, str): - width = n - - self.n = n + n: int = 5 + _: KW_ONLY + width: str | None = None + + _width: int | None = field(init=False, default=None) + _units: DatetimeBreaksUnits | None = field(init=False, default=None) + + def __post_init__(self): + # For backwards compatibility + if isinstance(self.n, str) and self.width is None: + warn( + "Passing the width as the parameter has been deprecated " + "and will not work in a future version. " + 'Use breaks_date(width="4 years")', + FutureWarning, + ) + self.width = self.n - if width: + if self.width: # Parse the width specification # e.g. '10 months' => (10, month) - _w, units = width.strip().lower().split() - self.width = int(_w) - self.units = units.rstrip("s") # type: ignore + _w, units = self.width.strip().lower().split() + self._width = int(_w) + self._units = units.rstrip("s") # type: ignore def __call__( self, limits: tuple[datetime, datetime] @@ -472,14 +483,15 @@ def __call__( ): limits = limits[0].astype(object), limits[1].astype(object) - if self.units and self.width: + if self._units and self._width: return calculate_date_breaks_byunits( - limits, self.units, self.width + limits, self._units, self._width ) else: return calculate_date_breaks_auto(limits, self.n) +@dataclass class breaks_timedelta: """ Timedelta breaks @@ -502,10 +514,11 @@ class breaks_timedelta: [0.0, 5.0, 10.0, 15.0, 20.0, 25.0] """ - _calculate_breaks: Callable[[tuple[float, float]], NDArrayFloat] + n: int = 5 + Q: Sequence[float] = (1, 2, 5, 10) - def __init__(self, n: int = 5, Q: Sequence[float] = (1, 2, 5, 10)): - self._calculate_breaks = breaks_extended(n=n, Q=Q) + def __post_init__(self): + self._calculate_breaks = breaks_extended(n=self.n, Q=self.Q) def __call__( self, limits: tuple[Timedelta, Timedelta] @@ -534,6 +547,7 @@ def __call__( # This could be cleaned up, state overload? +@dataclass class timedelta_helper: """ Helper for computing timedelta breaks @@ -561,22 +575,14 @@ class timedelta_helper: """ x: TimedeltaArrayLike - units: DurationUnit - limits: tuple[float, float] - package: Literal["pandas", "cpython"] - factor: float + units: DurationUnit | None = None - def __init__( - self, - x: TimedeltaArrayLike, - units: DurationUnit | None = None, - ): - self.x = x - self.package = self.determine_package(x[0]) - _limits = min(x), max(x) - self.limits = self.value(_limits[0]), self.value(_limits[1]) - self.units = units or self.best_units(_limits) - self.factor = self.get_scaling_factor(self.units) + def __post_init__(self): + l, h = min(self.x), max(self.x) + self.package = self.determine_package(self.x[0]) + self.limits = self.value(l), self.value(h) + self._units: DurationUnit = self.units or self.best_units((l, h)) + self.factor = self.get_scaling_factor(self._units) @classmethod def determine_package(cls, td: Timedelta) -> Literal["pandas", "cpython"]: @@ -594,7 +600,7 @@ def format_info( cls, x: TimedeltaArrayLike, units: DurationUnit | None = None ) -> tuple[NDArrayFloat, DurationUnit]: helper = cls(x, units) - return helper.timedelta_to_numeric(x), helper.units + return helper.timedelta_to_numeric(x), helper._units def best_units(self, x: TimedeltaArrayLike) -> DurationUnit: """ @@ -691,11 +697,12 @@ def to_numeric(self, td: Timedelta) -> float: determined with the object is initialised. """ if isinstance(td, pd.Timedelta): - return td.value / NANOSECONDS[self.units] + return td.value / NANOSECONDS[self._units] else: - return td.total_seconds() / SECONDS[self.units] + return td.total_seconds() / SECONDS[self._units] +@dataclass class breaks_extended: """ An extension of Wilkinson's tick position algorithm @@ -732,19 +739,14 @@ class breaks_extended: implementation is almost entirely based. """ - def __init__( - self, - n: int = 5, - Q: Sequence[float] = (1, 5, 2, 2.5, 4, 3), - only_inside: bool = False, - w: Sequence[float] = (0.25, 0.2, 0.5, 0.05), - ): - self.Q = Q - self.only_inside = only_inside - self.w = w - self.n = n + n: int = 5 + Q: Sequence[float] = (1, 5, 2, 2.5, 4, 3) + only_inside: bool = False + w: Sequence[float] = (0.25, 0.2, 0.5, 0.05) + + def __post_init__(self): # Used for lookups during the computations - self.Q_index = {q: i for i, q in enumerate(Q)} + self.Q_index = {q: i for i, q in enumerate(self.Q)} def coverage( self, dmin: float, dmax: float, lmin: float, lmax: float @@ -909,12 +911,6 @@ class breaks_symlog: """ Breaks for the Symmetric Logarithm Transform - Parameters - ---------- - n : int - Desired number of breaks - base : int - Base of logarithm Examples -------- diff --git a/tests/test_breaks.py b/tests/test_breaks.py index bef9d5d..a6d3080 100644 --- a/tests/test_breaks.py +++ b/tests/test_breaks.py @@ -182,12 +182,16 @@ class square_trans(trans): def test_breaks_date(): # cpython limits = (datetime(2010, 1, 1), datetime(2026, 1, 1)) - breaks = breaks_date("5 Years") + breaks = breaks_date(width="5 Years") assert [d.year for d in breaks(limits)] == [2010, 2015, 2020, 2025, 2030] - breaks = breaks_date("10 Years")(limits) + breaks = breaks_date(width="10 Years")(limits) assert [d.year for d in breaks] == [2010, 2020, 2030] + with pytest.warns(FutureWarning): + breaks = breaks_date("10 Years")(limits) + assert [d.year for d in breaks] == [2010, 2020, 2030] + # numpy datetime64 limits = (np.datetime64("1973"), np.datetime64("1997")) breaks = breaks_date(width="10 Years")(limits) @@ -195,7 +199,7 @@ def test_breaks_date(): # NaT limits = np.datetime64("NaT"), datetime(2017, 1, 1) - breaks = breaks_date("10 Years")(limits) + breaks = breaks_date(width="10 Years")(limits) assert len(breaks) == 0 # automatic monthly breaks