From bb14330559ff902edc5d84c950600e895eb2c0b5 Mon Sep 17 00:00:00 2001 From: Victoria A <52001888+adjeiv@users.noreply.github.com> Date: Mon, 6 Nov 2023 20:11:26 +0000 Subject: [PATCH 1/4] Report CV scores from within OptunaSearchCV --- optuna/integration/sklearn.py | 2 ++ tests/integration_tests/test_sklearn.py | 26 ++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/optuna/integration/sklearn.py b/optuna/integration/sklearn.py index f5a4b7a329b..d42e0909d03 100644 --- a/optuna/integration/sklearn.py +++ b/optuna/integration/sklearn.py @@ -24,6 +24,7 @@ from optuna._imports import try_import from optuna.distributions import _convert_old_distribution_to_new_distribution from optuna.study import StudyDirection +from optuna.terminator import report_cross_validation_scores from optuna.trial import FrozenTrial from optuna.trial import Trial @@ -242,6 +243,7 @@ def __call__(self, trial: Trial) -> float: } self._store_scores(trial, scores) + report_cross_validation_scores(trial, scores["test_score"]) return trial.user_attrs["mean_test_score"] diff --git a/tests/integration_tests/test_sklearn.py b/tests/integration_tests/test_sklearn.py index def1a222cd9..85b8a3e017b 100644 --- a/tests/integration_tests/test_sklearn.py +++ b/tests/integration_tests/test_sklearn.py @@ -1,6 +1,7 @@ from __future__ import annotations from unittest.mock import MagicMock +from unittest.mock import patch import warnings import numpy as np @@ -20,6 +21,7 @@ from optuna import integration from optuna.samplers import BruteForceSampler from optuna.study import create_study +from optuna.terminator.erroreval import _CROSS_VALIDATION_SCORES_KEY pytestmark = pytest.mark.integration @@ -378,7 +380,10 @@ def test_optuna_search_convert_deprecated_distribution() -> None: assert optuna_search.param_distributions == expected_param_dist -def test_callbacks() -> None: +@patch("optuna.integration.sklearn.report_cross_validation_scores") +def test_callbacks(mock: MagicMock) -> None: + mock.return_value = None + callbacks = [] for _ in range(2): @@ -409,3 +414,22 @@ def test_callbacks() -> None: for trial in optuna_search.trials_: callback.assert_any_call(optuna_search.study_, trial) assert callback.call_count == n_trials + + +@pytest.mark.filterwarnings("ignore::UserWarning") +@patch("optuna.integration.sklearn.cross_validate") +def test_terminator_cv_score_reporting(mock: MagicMock) -> None: + scores = { + "fit_time": np.array([2.01, 1.78, 3.22]), + "score_time": np.array([0.33, 0.35, 0.48]), + "test_score": np.array([0.04, 2.0, 1.5]), + } + mock.return_value = scores + + X, _ = make_blobs(n_samples=10) + est = PCA() + optuna_search = integration.OptunaSearchCV(est, {}, cv=3, error_score="raise", random_state=0) + optuna_search.fit(X) + + for trial in optuna_search.study_.trials: + assert (trial.system_attrs[_CROSS_VALIDATION_SCORES_KEY] == scores["test_score"]).all() From 867fa4a4626cd9294daced267a2e9e276fc1c858 Mon Sep 17 00:00:00 2001 From: Victoria A <52001888+adjeiv@users.noreply.github.com> Date: Mon, 6 Nov 2023 20:35:49 +0000 Subject: [PATCH 2/4] Fix typing --- optuna/integration/sklearn.py | 5 ++++- tests/integration_tests/test_sklearn.py | 5 +---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/optuna/integration/sklearn.py b/optuna/integration/sklearn.py index d42e0909d03..eeaa83a7986 100644 --- a/optuna/integration/sklearn.py +++ b/optuna/integration/sklearn.py @@ -243,7 +243,10 @@ def __call__(self, trial: Trial) -> float: } self._store_scores(trial, scores) - report_cross_validation_scores(trial, scores["test_score"]) + test_scores = ( + scores if isinstance(scores["test_score"], list) else scores["test_score"].tolist() + ) + report_cross_validation_scores(trial, test_scores) return trial.user_attrs["mean_test_score"] diff --git a/tests/integration_tests/test_sklearn.py b/tests/integration_tests/test_sklearn.py index 85b8a3e017b..80322528e82 100644 --- a/tests/integration_tests/test_sklearn.py +++ b/tests/integration_tests/test_sklearn.py @@ -380,10 +380,7 @@ def test_optuna_search_convert_deprecated_distribution() -> None: assert optuna_search.param_distributions == expected_param_dist -@patch("optuna.integration.sklearn.report_cross_validation_scores") -def test_callbacks(mock: MagicMock) -> None: - mock.return_value = None - +def test_callbacks() -> None: callbacks = [] for _ in range(2): From cc9f661719eda13282dbd8f6f0f6572414845fa0 Mon Sep 17 00:00:00 2001 From: Victoria A <52001888+adjeiv@users.noreply.github.com> Date: Mon, 6 Nov 2023 20:44:32 +0000 Subject: [PATCH 3/4] Use correct var --- optuna/integration/sklearn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/optuna/integration/sklearn.py b/optuna/integration/sklearn.py index eeaa83a7986..83a0d17db3b 100644 --- a/optuna/integration/sklearn.py +++ b/optuna/integration/sklearn.py @@ -243,10 +243,10 @@ def __call__(self, trial: Trial) -> float: } self._store_scores(trial, scores) - test_scores = ( - scores if isinstance(scores["test_score"], list) else scores["test_score"].tolist() - ) - report_cross_validation_scores(trial, test_scores) + + test_scores = scores["test_score"] + scores_list = test_scores if isinstance(test_scores, list) else test_scores.tolist() + report_cross_validation_scores(trial, scores_list) return trial.user_attrs["mean_test_score"] From b4dd96021cbaca27d9549438d04f85bf9d93ecb3 Mon Sep 17 00:00:00 2001 From: Victoria A <52001888+adjeiv@users.noreply.github.com> Date: Mon, 6 Nov 2023 21:16:08 +0000 Subject: [PATCH 4/4] Change test_score values --- tests/integration_tests/test_sklearn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/test_sklearn.py b/tests/integration_tests/test_sklearn.py index 80322528e82..d52bf898033 100644 --- a/tests/integration_tests/test_sklearn.py +++ b/tests/integration_tests/test_sklearn.py @@ -419,7 +419,7 @@ def test_terminator_cv_score_reporting(mock: MagicMock) -> None: scores = { "fit_time": np.array([2.01, 1.78, 3.22]), "score_time": np.array([0.33, 0.35, 0.48]), - "test_score": np.array([0.04, 2.0, 1.5]), + "test_score": np.array([0.04, 0.80, 0.70]), } mock.return_value = scores