|
2 | 2 |
|
3 | 3 | import operator as ops
|
4 | 4 | from abc import ABC, abstractmethod
|
| 5 | +from functools import partial |
5 | 6 | from typing import Any, Callable, Optional, Union
|
6 | 7 |
|
7 | 8 | import numpy as np
|
8 | 9 | import pandas as pd
|
9 | 10 | from attr import define, field
|
10 | 11 | from attr.validators import in_
|
| 12 | +from attrs.validators import min_len |
| 13 | +from cattrs.gen import override |
11 | 14 | from funcy import rpartial
|
12 | 15 | from numpy.typing import ArrayLike
|
13 | 16 |
|
14 |
| -from baybe.serialization import SerialMixin |
| 17 | +from baybe.parameters.validation import validate_unique_values |
| 18 | +from baybe.serialization import ( |
| 19 | + SerialMixin, |
| 20 | + converter, |
| 21 | + get_base_structure_hook, |
| 22 | + unstructure_base, |
| 23 | +) |
| 24 | +from baybe.utils.numerical import DTypeFloatNumpy |
15 | 25 |
|
16 | 26 |
|
17 | 27 | def _is_not_close(x: ArrayLike, y: ArrayLike, rtol: float, atol: float) -> np.ndarray:
|
@@ -135,9 +145,38 @@ class SubSelectionCondition(Condition):
|
135 | 145 | """Class for defining valid parameter entries."""
|
136 | 146 |
|
137 | 147 | # object variables
|
138 |
| - selection: list[Any] = field() |
139 |
| - """The list of items which are considered valid.""" |
| 148 | + _selection: tuple = field( |
| 149 | + converter=tuple, |
| 150 | + # FIXME[typing]: https://github.com/python-attrs/attrs/issues/1197 |
| 151 | + validator=[ |
| 152 | + min_len(1), |
| 153 | + validate_unique_values, # type: ignore |
| 154 | + ], |
| 155 | + ) |
| 156 | + """The internal list of items which are considered valid.""" |
| 157 | + |
| 158 | + @property |
| 159 | + def selection(self) -> tuple: # noqa: D102 |
| 160 | + """The list of items which are considered valid.""" |
| 161 | + return tuple( |
| 162 | + DTypeFloatNumpy(itm) if isinstance(itm, (float, int, bool)) else itm |
| 163 | + for itm in self._selection |
| 164 | + ) |
140 | 165 |
|
141 | 166 | def evaluate(self, data: pd.Series) -> pd.Series: # noqa: D102
|
142 | 167 | # See base class.
|
143 | 168 | return data.isin(self.selection)
|
| 169 | + |
| 170 | + |
| 171 | +# Register (un-)structure hooks |
| 172 | +_overrides = { |
| 173 | + "_selection": override(rename="selection"), |
| 174 | +} |
| 175 | +# FIXME[typing]: https://github.com/python/mypy/issues/4717 |
| 176 | +converter.register_structure_hook( |
| 177 | + Condition, |
| 178 | + get_base_structure_hook(Condition, overrides=_overrides), # type: ignore |
| 179 | +) |
| 180 | +converter.register_unstructure_hook( |
| 181 | + Condition, partial(unstructure_base, overrides=_overrides) |
| 182 | +) |
0 commit comments