diff --git a/CHANGELOG.md b/CHANGELOG.md
index d976ad4b9..8efe6b6da 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,6 +32,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fix typo in `ar_model.py` that causes `AttributeError` during evaluation [\#204](https://github.com/mllam/neural-lam/pull/204) @ritinikhil
 
+- Fix `get_integer_time` to avoid floating-point precision issues and correctly handle zero timedelta [#494](https://github.com/mllam/neural-lam/pull/494) @Saptami191
+
 - Changed the hardcoded True to a conditional check "persistent_workers=self.num_workers > 0" [\#235](https://github.com/mllam/neural-lam/pull/235) @santhil-cyber
 
 - Avoid eager download of the MEPS example dataset during pytest collection by lazily initializing it in `tests/conftest.py`, allowing tests to run without triggering a dataset download at import time. [#391](https://github.com/mllam/neural-lam/pull/391) @Saptami191
diff --git a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py
index 7dbcc7ef8..bcf4088e5 100644
--- a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py
+++ b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py
@@ -48,7 +48,7 @@ def get_original_indices(self):
         return self.original_indices
 
     def get_original_window_indices(self, step_length):
-        step_int, _ = get_integer_time(step_length.total_seconds())
+        step_int, _ = get_integer_time(step_length)
         return [
             i // step_int for i in range(len(self.original_indices) * step_int)
         ]
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 742ef9823..d510809e3 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -601,22 +601,31 @@ def get_integer_time(tdelta) -> tuple[int, str]:
     >>> get_integer_time(timedelta(milliseconds=1000))
     (1, 'seconds')
     >>> get_integer_time(timedelta(days=0.001))
-    (1, 'unknown')
+    (86400, 'milliseconds')
+    >>> get_integer_time(timedelta(0))
+    (0, 'seconds')
     """
-    total_seconds = tdelta.total_seconds()
+    total_microseconds = (
+        tdelta.days * 86400_000000
+        + tdelta.seconds * 1_000000
+        + tdelta.microseconds
+    )
+
+    if total_microseconds == 0:
+        return 0, "seconds"
 
     units = {
-        "weeks": 604800,
-        "days": 86400,
-        "hours": 3600,
-        "minutes": 60,
-        "seconds": 1,
-        "milliseconds": 0.001,
-        "microseconds": 0.000001,
+        "weeks": 604800_000000,
+        "days": 86400_000000,
+        "hours": 3600_000000,
+        "minutes": 60_000000,
+        "seconds": 1_000000,
+        "milliseconds": 1_000,
+        "microseconds": 1,
     }
 
-    for unit, unit_in_seconds in units.items():
-        if total_seconds % unit_in_seconds == 0:
-            return int(total_seconds / unit_in_seconds), unit
+    for unit, unit_us in units.items():
+        if total_microseconds % unit_us == 0:
+            return total_microseconds // unit_us, unit
 
     return 1, "unknown"
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 000000000..c37b6a881
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,29 @@
+# Standard library
+from datetime import timedelta
+
+# First-party
+from neural_lam.utils import get_integer_time
+
+
+def test_days():
+    assert get_integer_time(timedelta(days=14)) == (2, "weeks")
+
+
+def test_hours():
+    assert get_integer_time(timedelta(hours=5)) == (5, "hours")
+
+
+def test_zero():
+    assert get_integer_time(timedelta(0)) == (0, "seconds")
+
+
+def test_milliseconds():
+    assert get_integer_time(timedelta(milliseconds=1000)) == (1, "seconds")
+
+
+def test_negative():
+    assert get_integer_time(timedelta(days=-7)) == (-1, "weeks")
+
+
+def test_float_days():
+    assert get_integer_time(timedelta(days=0.001)) == (86400, "milliseconds")