diff --git a/pyproject.toml b/pyproject.toml index e65dfd740..89b2dee13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "scipy >=1.8,<2.0", "pandas >=1.3.0,<3.0.0", "scikit-learn >=1.2,<2.0", + "scikit-base <0.13.0", ] [project.optional-dependencies] @@ -102,7 +103,6 @@ dev = [ "pytest-dotenv>=0.5.2,<1.0.0", "tensorboard>=2.12.1,<3.0.0", "pandoc>=2.3,<3.0.0", - "scikit-base", ] # docs - dependencies for building the documentation diff --git a/pytorch_forecasting/models/base/_base_model.py b/pytorch_forecasting/models/base/_base_model.py index 1337f33c6..a321ab5d3 100644 --- a/pytorch_forecasting/models/base/_base_model.py +++ b/pytorch_forecasting/models/base/_base_model.py @@ -20,6 +20,7 @@ from numpy import iterable import pandas as pd import scipy.stats +from skbase.utils.dependencies._dependencies import _check_soft_dependencies import torch import torch.nn as nn from torch.nn.utils import rnn @@ -61,10 +62,7 @@ to_list, ) from pytorch_forecasting.utils._classproperty import classproperty -from pytorch_forecasting.utils._dependencies import ( - _check_matplotlib, - _get_installed_packages, -) +from pytorch_forecasting.utils._dependencies import _check_matplotlib # todo: compile models @@ -1355,7 +1353,7 @@ def configure_optimizers(self): Returns: Tuple[List]: first entry is list of optimizers and second is list of schedulers """ # noqa: E501 - ptopt_in_env = "pytorch_optimizer" in _get_installed_packages() + ptopt_in_env = _check_soft_dependencies("pytorch_optimizer", severity="none") # either set a schedule of lrs or find it dynamically if self.hparams.optimizer_params is None: optimizer_params = {} diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py index 6acb76203..e62ca18ac 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py @@ -13,12 +13,12 @@ from lightning.pytorch.tuner import Tuner import numpy as np import scipy._lib._util +from skbase.utils.dependencies._dependencies import _check_soft_dependencies from torch.utils.data import DataLoader from pytorch_forecasting import TemporalFusionTransformer from pytorch_forecasting.data import TimeSeriesDataSet from pytorch_forecasting.metrics import QuantileLoss -from pytorch_forecasting.utils._dependencies import _get_installed_packages optuna_logger = logging.getLogger("optuna") @@ -108,9 +108,7 @@ def optimize_hyperparameters( Returns: optuna.Study: optuna study results """ # noqa : E501 - pkgs = _get_installed_packages() - - if "optuna" not in pkgs or "statsmodels" not in pkgs: + if not _check_soft_dependencies(["optuna", "statsmodels"], severity="none"): raise ImportError( "optimize_hyperparameters requires optuna and statsmodels. " "Please install these packages with `pip install optuna statsmodels`. " diff --git a/pytorch_forecasting/utils/_dependencies/__init__.py b/pytorch_forecasting/utils/_dependencies/__init__.py index fbf75137b..c198be6c1 100644 --- a/pytorch_forecasting/utils/_dependencies/__init__.py +++ b/pytorch_forecasting/utils/_dependencies/__init__.py @@ -1,13 +1,9 @@ """Utilities for managing dependencies.""" -from pytorch_forecasting.utils._dependencies._dependencies import ( - _check_matplotlib, - _get_installed_packages, -) +from pytorch_forecasting.utils._dependencies._dependencies import _check_matplotlib from pytorch_forecasting.utils._dependencies._safe_import import _safe_import __all__ = [ - "_get_installed_packages", "_check_matplotlib", "_safe_import", ] diff --git a/pytorch_forecasting/utils/_dependencies/_dependencies.py b/pytorch_forecasting/utils/_dependencies/_dependencies.py index 0edf653fc..d4bca2f01 100644 --- a/pytorch_forecasting/utils/_dependencies/_dependencies.py +++ b/pytorch_forecasting/utils/_dependencies/_dependencies.py @@ -3,45 +3,9 @@ Copied from sktime/skbase. """ -from functools import lru_cache +from skbase.utils.dependencies._dependencies import _check_soft_dependencies - -@lru_cache -def _get_installed_packages_private(): - """Get a dictionary of installed packages and their versions. - - Same as _get_installed_packages, but internal to avoid mutating the lru_cache - by accident. - """ - from importlib.metadata import distributions, version - - dists = distributions() - package_names = { - dist.metadata["Name"] - for dist in dists - if dist.metadata and "Name" in dist.metadata - } - package_versions = {pkg_name: version(pkg_name) for pkg_name in package_names} - # developer note: - # we cannot just use distributions naively, - # because the same top level package name may appear *twice*, - # e.g., in a situation where a virtual env overrides a base env, - # such as in deployment environments like databricks. - # the "version" contract ensures we always get the version that corresponds - # to the importable distribution, i.e., the top one in the sys.path. - return package_versions - - -def _get_installed_packages(): - """Get a dictionary of installed packages and their versions. - - Returns - ------- - dict : dictionary of installed packages and their versions - keys are PEP 440 compatible package names, values are package versions - MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3" - """ - return _get_installed_packages_private().copy() +__all__ = ["_check_soft_dependencies", "_check_matplotlib"] def _check_matplotlib(ref="This feature", raise_error=True): @@ -58,12 +22,11 @@ def _check_matplotlib(ref="This feature", raise_error=True): ------- bool : whether matplotlib is installed """ - pkgs = _get_installed_packages() - - if raise_error and "matplotlib" not in pkgs: + matplotlib_present = _check_soft_dependencies("matplotlib", severity="none") + if raise_error and not matplotlib_present: raise ImportError( f"{ref} requires matplotlib." " Please install matplotlib with `pip install matplotlib`." ) - return "matplotlib" in pkgs + return matplotlib_present diff --git a/pytorch_forecasting/utils/_dependencies/_safe_import.py b/pytorch_forecasting/utils/_dependencies/_safe_import.py index f11313dd4..0c7d80a17 100644 --- a/pytorch_forecasting/utils/_dependencies/_safe_import.py +++ b/pytorch_forecasting/utils/_dependencies/_safe_import.py @@ -8,7 +8,7 @@ import importlib from unittest.mock import MagicMock -from pytorch_forecasting.utils._dependencies import _get_installed_packages +from skbase.utils.dependencies._dependencies import _get_installed_packages def _safe_import(import_path, pkg_name=None): diff --git a/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py b/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py index e0e7e7ecb..bc98d7c27 100644 --- a/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py +++ b/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py @@ -1,16 +1,6 @@ __author__ = ["jgyasu", "fkiraly"] -from pytorch_forecasting.utils._dependencies import ( - _get_installed_packages, - _safe_import, -) - - -def test_import_present_module(): - """Test importing a dependency that is installed.""" - result = _safe_import("pandas") - assert result is not None - assert "pandas" in _get_installed_packages() +from pytorch_forecasting.utils._dependencies import _safe_import def test_import_missing_module(): diff --git a/pytorch_forecasting/utils/_maint/_show_versions.py b/pytorch_forecasting/utils/_maint/_show_versions.py index 39f4c61bc..b9b31d907 100644 --- a/pytorch_forecasting/utils/_maint/_show_versions.py +++ b/pytorch_forecasting/utils/_maint/_show_versions.py @@ -82,7 +82,7 @@ def _get_deps_info(deps=None, source="distributions"): deps = ["pytorch-forecasting"] if source == "distributions": - from pytorch_forecasting.utils._dependencies import _get_installed_packages + from skbase.utils.dependencies._dependencies import _get_installed_packages KEY_ALIAS = {"sklearn": "scikit-learn", "skbase": "scikit-base"} diff --git a/tests/test_models/test_nbeats.py b/tests/test_models/test_nbeats.py index c3379fbf1..2eb43249e 100644 --- a/tests/test_models/test_nbeats.py +++ b/tests/test_models/test_nbeats.py @@ -5,9 +5,9 @@ from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.loggers import TensorBoardLogger import pytest +from skbase.utils.dependencies import _check_soft_dependencies from pytorch_forecasting.models import NBeats -from pytorch_forecasting.utils._dependencies import _get_installed_packages def test_integration(dataloaders_fixed_window_without_covariates, tmp_path): @@ -90,7 +90,7 @@ def test_pickle(model): @pytest.mark.skipif( - "matplotlib" not in _get_installed_packages(), + not _check_soft_dependencies("matplotlib", severity="none"), reason="skip test if required package matplotlib not installed", ) def test_interpretation(model, dataloaders_fixed_window_without_covariates): diff --git a/tests/test_models/test_nhits.py b/tests/test_models/test_nhits.py index a79e7a93f..2260a47f1 100644 --- a/tests/test_models/test_nhits.py +++ b/tests/test_models/test_nhits.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd import pytest +from skbase.utils.dependencies import _check_soft_dependencies from pytorch_forecasting.data.timeseries import TimeSeriesDataSet from pytorch_forecasting.metrics import MQF2DistributionLoss, QuantileLoss @@ -14,7 +15,6 @@ ImplicitQuantileNetworkDistributionLoss, ) from pytorch_forecasting.models import NHiTS -from pytorch_forecasting.utils._dependencies import _get_installed_packages def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs): @@ -96,7 +96,7 @@ def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs): "implicit-quantiles", ] -if "cpflows" in _get_installed_packages(): +if _check_soft_dependencies("cpflows", severity="none"): LOADERS += ["multivariate-quantiles"] @@ -158,7 +158,7 @@ def test_pickle(model): @pytest.mark.skipif( - "matplotlib" not in _get_installed_packages(), + not _check_soft_dependencies("matplotlib", severity="none"), reason="skip test if required package matplotlib not installed", ) def test_interpretation(model, dataloaders_with_covariates): diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index 3e09b8b7e..747cd53b5 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd import pytest +from skbase.utils.dependencies import _check_soft_dependencies from test_models.conftest import make_dataloaders import torch @@ -29,7 +30,6 @@ from pytorch_forecasting.models.temporal_fusion_transformer.tuning import ( optimize_hyperparameters, ) -from pytorch_forecasting.utils._dependencies import _get_installed_packages def test_integration(multiple_dataloaders_with_covariates, tmp_path): @@ -71,7 +71,7 @@ def test_distribution_loss(data_with_covariates, tmp_path): @pytest.mark.skipif( - "cpflows" not in _get_installed_packages(), + not _check_soft_dependencies("cpflows", severity="none"), reason="Test skipped if required package cpflows not available", ) def test_mqf2_loss(data_with_covariates, tmp_path): @@ -331,7 +331,7 @@ def test_predict_dependency( @pytest.mark.skipif( - "matplotlib" not in _get_installed_packages(), + not _check_soft_dependencies("matplotlib", severity="none"), reason="skip test if required package matplotlib not installed", ) def test_actual_vs_predicted_plot(model, dataloaders_with_covariates): @@ -424,8 +424,7 @@ def test_prediction_with_dataframe(model, data_with_covariates): SKIP_HYPEPARAM_TEST = ( sys.platform.startswith("win") # Test skipped on Windows OS due to issues with ddp, see #1632" - or "optuna" not in _get_installed_packages() - or "statsmodels" not in _get_installed_packages() + or not _check_soft_dependencies(["optuna", "statsmodels"], severity="none") # Test skipped if required package optuna or statsmodels not available ) diff --git a/tests/test_models/test_tide.py b/tests/test_models/test_tide.py index 3b73ba380..4e7a6eddc 100644 --- a/tests/test_models/test_tide.py +++ b/tests/test_models/test_tide.py @@ -7,12 +7,12 @@ import numpy as np import pandas as pd import pytest +from skbase.utils.dependencies import _check_soft_dependencies from pytorch_forecasting.data.timeseries import TimeSeriesDataSet from pytorch_forecasting.metrics import SMAPE from pytorch_forecasting.models import TiDEModel from pytorch_forecasting.tests._conftest import make_dataloaders -from pytorch_forecasting.utils._dependencies import _get_installed_packages def _integration( @@ -192,7 +192,7 @@ def test_pickle(model): @pytest.mark.skipif( - "matplotlib" not in _get_installed_packages(), + not _check_soft_dependencies("matplotlib", severity="none"), reason="skip test if required package matplotlib not installed", ) def test_prediction_visualization(model, dataloaders_with_covariates):