Skip to content

Commit 95f7f25

Browse files
authored
Merge: Fix dtypes related to floating point precision (#254)
Since floating point precision can be controlled via env vars (#226), various problems have surfaced that cause tests to fail in single precision. This PR fixes those. They were mostly related to the way `values` and `comp_df` were created for parameters, the way `selection` was treated in `SubSelectionCondition`, and a `lookup` in a different float precision being used in a simulation. The only remaining issues with tests in single precision are numerical instabilities (out of scope).
2 parents 2555bda + 79433aa commit 95f7f25

24 files changed

+153
-60
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111
### Fixed
1212
- Non-GP surrogates not working with `deepcopy` and the simulation module due to slotted
1313
base class
14+
- Datatype inconsistencies for various parameters' `values` and `comp_df` and
15+
`SubSelectionCondition`'s `selection` related to floating point precision
1416

1517
## [0.9.0] - 2024-05-21
1618
### Added

baybe/constraints/base.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,19 @@
66
from collections.abc import Sequence
77
from typing import TYPE_CHECKING, Any, ClassVar
88

9+
import numpy as np
910
import pandas as pd
1011
from attr import define, field
1112
from attr.validators import min_len
1213

13-
from baybe.constraints.conditions import Condition
1414
from baybe.parameters import NumericalContinuousParameter
1515
from baybe.serialization import (
1616
SerialMixin,
1717
converter,
1818
get_base_structure_hook,
1919
unstructure_base,
2020
)
21+
from baybe.utils.numerical import DTypeFloatNumpy
2122

2223
if TYPE_CHECKING:
2324
from torch import Tensor
@@ -173,16 +174,13 @@ def to_botorch(
173174
if p in param_names
174175
]
175176

176-
# TODO: Cast rhs to correct precision once BoTorch also supports single point.
177177
return (
178178
torch.tensor(param_indices),
179179
torch.tensor(self.coefficients, dtype=DTypeFloatTorch),
180-
self.rhs,
180+
np.asarray(self.rhs, dtype=DTypeFloatNumpy).item(),
181181
)
182182

183183

184184
# Register (un-)structure hooks
185-
converter.register_unstructure_hook(Condition, unstructure_base)
186-
converter.register_structure_hook(Condition, get_base_structure_hook(Condition))
187185
converter.register_unstructure_hook(Constraint, unstructure_base)
188186
converter.register_structure_hook(Constraint, get_base_structure_hook(Constraint))

baybe/constraints/conditions.py

+42-3
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,26 @@
22

33
import operator as ops
44
from abc import ABC, abstractmethod
5+
from functools import partial
56
from typing import Any, Callable, Optional, Union
67

78
import numpy as np
89
import pandas as pd
910
from attr import define, field
1011
from attr.validators import in_
12+
from attrs.validators import min_len
13+
from cattrs.gen import override
1114
from funcy import rpartial
1215
from numpy.typing import ArrayLike
1316

14-
from baybe.serialization import SerialMixin
17+
from baybe.parameters.validation import validate_unique_values
18+
from baybe.serialization import (
19+
SerialMixin,
20+
converter,
21+
get_base_structure_hook,
22+
unstructure_base,
23+
)
24+
from baybe.utils.numerical import DTypeFloatNumpy
1525

1626

1727
def _is_not_close(x: ArrayLike, y: ArrayLike, rtol: float, atol: float) -> np.ndarray:
@@ -135,9 +145,38 @@ class SubSelectionCondition(Condition):
135145
"""Class for defining valid parameter entries."""
136146

137147
# object variables
138-
selection: list[Any] = field()
139-
"""The list of items which are considered valid."""
148+
_selection: tuple = field(
149+
converter=tuple,
150+
# FIXME[typing]: https://github.com/python-attrs/attrs/issues/1197
151+
validator=[
152+
min_len(1),
153+
validate_unique_values, # type: ignore
154+
],
155+
)
156+
"""The internal list of items which are considered valid."""
157+
158+
@property
159+
def selection(self) -> tuple: # noqa: D102
160+
"""The list of items which are considered valid."""
161+
return tuple(
162+
DTypeFloatNumpy(itm) if isinstance(itm, (float, int, bool)) else itm
163+
for itm in self._selection
164+
)
140165

141166
def evaluate(self, data: pd.Series) -> pd.Series: # noqa: D102
142167
# See base class.
143168
return data.isin(self.selection)
169+
170+
171+
# Register (un-)structure hooks
172+
_overrides = {
173+
"_selection": override(rename="selection"),
174+
}
175+
# FIXME[typing]: https://github.com/python/mypy/issues/4717
176+
converter.register_structure_hook(
177+
Condition,
178+
get_base_structure_hook(Condition, overrides=_overrides), # type: ignore
179+
)
180+
converter.register_unstructure_hook(
181+
Condition, partial(unstructure_base, overrides=_overrides)
182+
)

baybe/objectives/desirability.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from baybe.targets.numerical import NumericalTarget
1818
from baybe.utils.basic import to_tuple
1919
from baybe.utils.numerical import geom_mean
20+
from baybe.utils.validation import finite_float
2021

2122

2223
def _is_all_numerical_targets(
@@ -73,7 +74,7 @@ class DesirabilityObjective(Objective):
7374

7475
weights: tuple[float, ...] = field(
7576
converter=lambda w: cattrs.structure(w, tuple[float, ...]),
76-
validator=deep_iterable(member_validator=gt(0.0)),
77+
validator=deep_iterable(member_validator=[finite_float, gt(0.0)]),
7778
)
7879
"""The weights to balance the different targets.
7980
By default, all targets are considered equally important."""

baybe/parameters/base.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,15 @@ class ContinuousParameter(Parameter):
130130

131131

132132
# Register (un-)structure hooks
133-
overrides = {
133+
_overrides = {
134134
"_values": override(rename="values"),
135135
"decorrelate": override(struct_hook=lambda x, _: x),
136136
}
137137
# FIXME[typing]: https://github.com/python/mypy/issues/4717
138138
converter.register_structure_hook(
139139
Parameter,
140-
get_base_structure_hook(Parameter, overrides=overrides), # type: ignore
140+
get_base_structure_hook(Parameter, overrides=_overrides), # type: ignore
141141
)
142142
converter.register_unstructure_hook(
143-
Parameter, partial(unstructure_base, overrides=overrides)
143+
Parameter, partial(unstructure_base, overrides=_overrides)
144144
)

baybe/parameters/categorical.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from baybe.parameters.base import DiscreteParameter
1212
from baybe.parameters.enum import CategoricalEncoding
1313
from baybe.parameters.validation import validate_unique_values
14+
from baybe.utils.numerical import DTypeFloatNumpy
1415

1516

1617
@define(frozen=True, slots=False)
@@ -47,9 +48,13 @@ def comp_df(self) -> pd.DataFrame: # noqa: D102
4748
# See base class.
4849
if self.encoding is CategoricalEncoding.OHE:
4950
cols = [f"{self.name}_{val}" for val in self.values]
50-
comp_df = pd.DataFrame(np.eye(len(self.values), dtype=int), columns=cols)
51+
comp_df = pd.DataFrame(
52+
np.eye(len(self.values), dtype=DTypeFloatNumpy), columns=cols
53+
)
5154
elif self.encoding is CategoricalEncoding.INT:
52-
comp_df = pd.DataFrame(range(len(self.values)), columns=[self.name])
55+
comp_df = pd.DataFrame(
56+
range(len(self.values)), dtype=DTypeFloatNumpy, columns=[self.name]
57+
)
5358
comp_df.index = pd.Index(self.values)
5459

5560
return comp_df

baybe/parameters/custom.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from baybe.parameters.validation import validate_decorrelation
1414
from baybe.utils.boolean import eq_dataframe
1515
from baybe.utils.dataframe import df_uncorrelated_features
16+
from baybe.utils.numerical import DTypeFloatNumpy
1617

1718

1819
@define(frozen=True, slots=False)
@@ -100,7 +101,9 @@ def comp_df(self) -> pd.DataFrame: # noqa: D102
100101
# The encoding is directly provided by the user
101102
# We prepend the parameter name to the columns names to avoid potential
102103
# conflicts with other parameters
103-
comp_df = self.data.rename(columns=lambda x: f"{self.name}_{x}")
104+
comp_df = self.data.rename(columns=lambda x: f"{self.name}_{x}").astype(
105+
DTypeFloatNumpy
106+
)
104107

105108
# Get a decorrelated subset of the provided features
106109
if self.decorrelate:

baybe/parameters/numerical.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def _validate_tolerance( # noqa: DOC101, DOC103
6262
if tolerance == 0.0:
6363
return
6464

65-
min_dist = np.diff(self.values).min()
66-
if min_dist == (eps := np.nextafter(0, 1, dtype=DTypeFloatNumpy)):
65+
min_dist = np.diff(self._values).min()
66+
if min_dist == (eps := np.nextafter(0, 1)):
6767
raise NumericalUnderflowError(
6868
f"The distance between any two parameter values must be at least "
6969
f"twice the size of the used floating point resolution of {eps}."
@@ -79,12 +79,14 @@ def _validate_tolerance( # noqa: DOC101, DOC103
7979
@property
8080
def values(self) -> tuple: # noqa: D102
8181
# See base class.
82-
return self._values
82+
return tuple(DTypeFloatNumpy(itm) for itm in self._values)
8383

8484
@cached_property
8585
def comp_df(self) -> pd.DataFrame: # noqa: D102
8686
# See base class.
87-
comp_df = pd.DataFrame({self.name: self.values}, index=self.values)
87+
comp_df = pd.DataFrame(
88+
{self.name: self.values}, index=self.values, dtype=DTypeFloatNumpy
89+
)
8890
return comp_df
8991

9092
def is_in_range(self, item: float) -> bool: # noqa: D102

baybe/parameters/substance.py

-2
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,6 @@ def _validate_substance_data( # noqa: DOC101, DOC103
108108
@property
109109
def values(self) -> tuple:
110110
"""Returns the labels of the given set of molecules."""
111-
# Since the order of dictionary keys is important here, this will only work
112-
# for Python 3.7 or higher
113111
return tuple(self.data.keys())
114112

115113
@cached_property

baybe/simulation/core.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from baybe.simulation.lookup import _look_up_target_values
1717
from baybe.targets.enum import TargetMode
1818
from baybe.utils.dataframe import add_parameter_noise
19-
from baybe.utils.numerical import closer_element, closest_element
19+
from baybe.utils.numerical import DTypeFloatNumpy, closer_element, closest_element
2020
from baybe.utils.random import temporary_seed
2121

2222

@@ -112,6 +112,12 @@ def simulate_experiment(
112112
"Impute mode 'ignore' is only available for dataframe lookups."
113113
)
114114

115+
# Enforce correct float precision in lookup dataframes
116+
if isinstance(lookup, pd.DataFrame):
117+
lookup = lookup.copy()
118+
float_cols = lookup.select_dtypes(include=["float"]).columns
119+
lookup[float_cols] = lookup[float_cols].astype(DTypeFloatNumpy)
120+
115121
# Clone the campaign to avoid mutating the original object
116122
# TODO: Reconsider if deepcopies are required once [16605] is resolved
117123
campaign = deepcopy(campaign)

baybe/simulation/lookup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _look_up_target_values(
5959
# column ordering, which is not robust. Instead, the callable should return
6060
# a dataframe with properly labeled columns.
6161

62-
# Since the return of a lookup function is a a tuple, the following code stores
62+
# Since the return of a lookup function is a tuple, the following code stores
6363
# tuples of floats in a single column with label 0:
6464
measured_targets = queries.apply(lambda x: lookup(*x.values), axis=1).to_frame()
6565
# We transform this column to a DataFrame in which there is an individual
@@ -79,7 +79,7 @@ def _look_up_target_values(
7979
queries[target.name] = measured_targets.iloc[:, k_target]
8080

8181
# Get results via dataframe lookup (works only for exact matches)
82-
# IMPROVE: Although its not too important for a simulation, this
82+
# IMPROVE: Although it's not too important for a simulation, this
8383
# could also be implemented for approximate matches
8484
elif isinstance(lookup, pd.DataFrame):
8585
all_match_vals = []

baybe/utils/chemistry.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,10 @@ def _smiles_to_mordred_features(smiles: str) -> np.ndarray:
8888
"""
8989
try:
9090
return np.asarray(
91-
_mordred_calculator(Chem.MolFromSmiles(smiles)).fill_missing(),
92-
dtype=DTypeFloatNumpy,
91+
_mordred_calculator(Chem.MolFromSmiles(smiles)).fill_missing()
9392
)
9493
except Exception:
95-
return np.full(
96-
len(_mordred_calculator.descriptors), np.NaN, dtype=DTypeFloatNumpy
97-
)
94+
return np.full(len(_mordred_calculator.descriptors), np.NaN)
9895

9996

10097
def smiles_to_mordred_features(
@@ -117,7 +114,7 @@ def smiles_to_mordred_features(
117114
features = [_smiles_to_mordred_features(smiles) for smiles in smiles_list]
118115
descriptor_names = list(_mordred_calculator.descriptors)
119116
columns = [prefix + "MORDRED_" + str(name) for name in descriptor_names]
120-
dataframe = pd.DataFrame(data=features, columns=columns)
117+
dataframe = pd.DataFrame(data=features, columns=columns, dtype=DTypeFloatNumpy)
121118

122119
if dropna:
123120
dataframe = dataframe.dropna(axis=1)
@@ -169,7 +166,7 @@ def smiles_to_rdkit_features(
169166
res = []
170167
for mol in mols:
171168
desc = {
172-
prefix + "RDKIT_" + dname: func(mol)
169+
prefix + "RDKIT_" + dname: DTypeFloatNumpy(func(mol))
173170
for dname, func in Chem.Descriptors.descList
174171
}
175172
res.append(desc)

baybe/utils/memory.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ def bytes_to_human_readable(num: float, /) -> tuple[float, str]:
1414
if abs(num) < 1024.0:
1515
return num, unit
1616
num /= 1024.0
17-
return num, "YB"
17+
return round(num, 2), "YB"

tests/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def fixture_parameters(
231231
CategoricalParameter(
232232
name="Categorical_2",
233233
values=("bad", "OK", "good"),
234-
encoding="OHE",
234+
encoding="INT",
235235
),
236236
CategoricalParameter(
237237
name="Switch_1",

tests/hypothesis_strategies/acquisition.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,18 @@
1717
qUpperConfidenceBound,
1818
)
1919

20+
from ..hypothesis_strategies.basic import finite_floats
21+
2022
# These acqfs are ordered roughly according to increasing complexity
2123
acquisition_functions = st.one_of(
2224
st.builds(ExpectedImprovement),
2325
st.builds(ProbabilityOfImprovement),
24-
st.builds(
25-
UpperConfidenceBound, beta=st.floats(min_value=0.0, allow_infinity=False)
26-
),
26+
st.builds(UpperConfidenceBound, beta=finite_floats(min_value=0.0)),
2727
st.builds(PosteriorMean),
2828
st.builds(LogExpectedImprovement),
2929
st.builds(qExpectedImprovement),
3030
st.builds(qProbabilityOfImprovement),
31-
st.builds(
32-
qUpperConfidenceBound, beta=st.floats(min_value=0.0, allow_infinity=False)
33-
),
31+
st.builds(qUpperConfidenceBound, beta=finite_floats(min_value=0.0)),
3432
st.builds(qSimpleRegret),
3533
st.builds(qLogExpectedImprovement),
3634
st.builds(qNoisyExpectedImprovement),

tests/hypothesis_strategies/basic.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@
33
from functools import partial
44

55
import hypothesis.strategies as st
6+
import numpy as np
67

7-
finite_floats = partial(st.floats, allow_infinity=False, allow_nan=False)
8+
from baybe.utils.numerical import DTypeFloatNumpy
9+
10+
finite_floats = partial(
11+
st.floats,
12+
allow_infinity=False,
13+
allow_nan=False,
14+
width=32 if DTypeFloatNumpy == np.float32 else 64,
15+
)
816
"""A strategy producing finite (i.e., non-nan and non-infinite) floats."""

0 commit comments

Comments
 (0)