Commit 2eea80f
[rebase] Rebase to the latest version and merge test_evaluator into train_evaluator

Since test_evaluator could be merged, I merged it into train_evaluator.

* [rebase] Rebase and merge the changes in non-test files without issues
* [refactor] Merge the test- and train-evaluators
* [fix] Fix the import error due to the change xxx_evaluator --> evaluator
* [test] Fix errors in tests
* [fix] Fix the handling of test predictions in no resampling
* [refactor] Move save_y_opt=False for no resampling deeper for simplicity
* [test] Increase the budget size for the no-resampling tests
1 parent b32e8be · commit 2eea80f

File tree: 13 files changed, +298 -603 lines

autoPyTorch/api/base_task.py
Lines changed: 5 additions & 5 deletions

@@ -315,7 +315,7 @@ def _get_dataset_input_validator(
                 Testing feature set
             y_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
                 Testing target set
-            resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
+            resampling_strategy (Optional[ResamplingStrategies]):
                 Strategy to split the training data. if None, uses
                 HoldoutValTypes.holdout_validation.
             resampling_strategy_args (Optional[Dict[str, Any]]):
@@ -355,7 +355,7 @@ def get_dataset(
                 Testing feature set
             y_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
                 Testing target set
-            resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
+            resampling_strategy (Optional[ResamplingStrategies]):
                 Strategy to split the training data. if None, uses
                 HoldoutValTypes.holdout_validation.
             resampling_strategy_args (Optional[Dict[str, Any]]):
@@ -973,7 +973,7 @@ def _search(
                 `SMAC <https://automl.github.io/SMAC3/master/index.html>`_.
             tae_func (Optional[Callable]):
                 TargetAlgorithm to be optimised. If None, `eval_function`
-                available in autoPyTorch/evaluation/train_evaluator is used.
+                available in autoPyTorch/evaluation/evaluator is used.
                 Must be child class of AbstractEvaluator.
             all_supported_metrics (bool: default=True):
                 If True, all metrics supporting current task will be calculated
@@ -1380,7 +1380,7 @@ def fit_pipeline(
         X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
         y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
         dataset_name: Optional[str] = None,
-        resampling_strategy: Optional[Union[HoldoutValTypes, CrossValTypes, NoResamplingStrategyTypes]] = None,
+        resampling_strategy: Optional[ResamplingStrategies] = None,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         run_time_limit_secs: int = 60,
         memory_limit: Optional[int] = None,
@@ -1415,7 +1415,7 @@ def fit_pipeline(
                 be provided to track the generalization performance of each stage.
             dataset_name (Optional[str]):
                 Name of the dataset, if None, random value is used.
-            resampling_strategy (Optional[RESAMPLING_STRATEGIES]):
+            resampling_strategy (Optional[ResamplingStrategies]):
                 Strategy to split the training data. if None, uses
                 HoldoutValTypes.holdout_validation.
             resampling_strategy_args (Optional[Dict[str, Any]]):
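
For context, a minimal sketch of what the unified annotation buys. fit_pipeline_stub is a hypothetical stand-in for the real fit_pipeline API, and the enum member names k_fold_cross_validation and no_resampling are assumed from the library's enums; the point is that one Optional[ResamplingStrategies] parameter now covers all three strategy families:

from typing import Optional

from autoPyTorch.datasets.resampling_strategy import (
    CrossValTypes,
    HoldoutValTypes,
    NoResamplingStrategyTypes,
    ResamplingStrategies,
)


def fit_pipeline_stub(resampling_strategy: Optional[ResamplingStrategies] = None) -> None:
    # Mirrors the documented default: if None, holdout validation is used.
    strategy = resampling_strategy if resampling_strategy is not None else HoldoutValTypes.holdout_validation
    print(f"splitting with {strategy}")


fit_pipeline_stub()                                         # defaults to holdout validation
fit_pipeline_stub(CrossValTypes.k_fold_cross_validation)    # cross-validation
fit_pipeline_stub(NoResamplingStrategyTypes.no_resampling)  # train on everything, evaluate on test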

autoPyTorch/api/tabular_classification.py
Lines changed: 1 addition & 1 deletion

@@ -330,7 +330,7 @@ def search(
                 `SMAC <https://automl.github.io/SMAC3/master/index.html>`_.
             tae_func (Optional[Callable]):
                 TargetAlgorithm to be optimised. If None, `eval_function`
-                available in autoPyTorch/evaluation/train_evaluator is used.
+                available in autoPyTorch/evaluation/evaluator is used.
                 Must be child class of AbstractEvaluator.
             all_supported_metrics (bool: default=True):
                 If True, all metrics supporting current task will be calculated

autoPyTorch/api/tabular_regression.py
Lines changed: 1 addition & 1 deletion

@@ -331,7 +331,7 @@ def search(
                 `SMAC <https://automl.github.io/SMAC3/master/index.html>`_.
             tae_func (Optional[Callable]):
                 TargetAlgorithm to be optimised. If None, `eval_function`
-                available in autoPyTorch/evaluation/train_evaluator is used.
+                available in autoPyTorch/evaluation/evaluator is used.
                 Must be child class of AbstractEvaluator.
             all_supported_metrics (bool: default=True):
                 If True, all metrics supporting current task will be calculated

autoPyTorch/datasets/resampling_strategy.py
Lines changed: 8 additions & 0 deletions

@@ -93,6 +93,14 @@ def is_stratified(self) -> bool:
 # TODO: replace it with another way
 ResamplingStrategies = Union[CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]

+
+def check_resampling_strategy(resampling_strategy: Optional[ResamplingStrategies]) -> None:
+    choices = (CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes)
+    if not isinstance(resampling_strategy, choices):
+        rs_names = (rs.__mro__[0].__name__ for rs in choices)
+        raise ValueError(f'resampling_strategy must be in {rs_names}, but got {resampling_strategy}')
+
+
 DEFAULT_RESAMPLING_PARAMETERS: Dict[
     ResamplingStrategies,
     Dict[str, Any]
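
A short usage sketch of the new helper added above. The holdout member name is taken from the docstrings elsewhere in this commit; the ValueError text comes from the f-string in the diff:

from autoPyTorch.datasets.resampling_strategy import (
    HoldoutValTypes,
    check_resampling_strategy,
)

# A valid strategy passes silently (returns None).
check_resampling_strategy(HoldoutValTypes.holdout_validation)

try:
    # Anything that is not one of the three strategy enums raises.
    check_resampling_strategy(None)
except ValueError as e:
    print(e)  # ValueError: resampling_strategy must be in <...>, but got None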

autoPyTorch/evaluation/abstract_evaluator.py
Lines changed: 1 addition & 1 deletion

@@ -207,7 +207,7 @@ def __init__(self, backend: Backend,
     An evaluator is an object that:
         + constructs a pipeline (i.e. a classification or regression estimator) for a given
           pipeline_config and run settings (budget, seed)
-        + Fits and trains this pipeline (TrainEvaluator) or tests a given
+        + Fits and trains this pipeline (Evaluator) or tests a given
          configuration (TestEvaluator)

     The provided configuration determines the type of pipeline created. For more
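
The docstring describes the evaluator contract: build a pipeline for a configuration, fit or test it, and report the result back to the main process. A minimal, self-contained sketch of the queue-based reporting pattern it alludes to (toy_eval_fn is illustrative, not the library's API):

from multiprocessing import Queue


def toy_eval_fn(queue: Queue, num_run: int) -> None:
    # Stand-in for an evaluator's evaluate_loss(): fit, score, then report.
    loss = 1.0 / (1 + num_run)  # pretend validation loss
    queue.put({'num_run': num_run, 'loss': loss, 'status': 'SUCCESS'})


q = Queue()
toy_eval_fn(q, num_run=3)
print(q.get())  # {'num_run': 3, 'loss': 0.25, 'status': 'SUCCESS'}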

autoPyTorch/evaluation/train_evaluator.py renamed to autoPyTorch/evaluation/evaluator.py
Lines changed: 38 additions & 78 deletions

@@ -7,12 +7,11 @@

 from smac.tae import StatusType

-from autoPyTorch.automl_common.common.utils.backend import Backend
-from autoPyTorch.constants import (
-    CLASSIFICATION_TASKS,
-    MULTICLASSMULTIOUTPUT,
+from autoPyTorch.datasets.resampling_strategy import (
+    CrossValTypes,
+    NoResamplingStrategyTypes,
+    check_resampling_strategy
 )
-from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
 from autoPyTorch.evaluation.abstract_evaluator import (
     AbstractEvaluator,
     EvaluationResults,
@@ -21,7 +20,8 @@
 from autoPyTorch.evaluation.abstract_evaluator import EvaluatorParams, FixedPipelineParams
 from autoPyTorch.utils.common import dict_repr, subsampler

-__all__ = ['TrainEvaluator', 'eval_train_function']
+__all__ = ['Evaluator', 'eval_fn']
+

 class _CrossValidationResultsManager:
     def __init__(self, num_folds: int):
@@ -83,15 +83,13 @@ def get_result_dict(self) -> Dict[str, Any]:
         )


-class TrainEvaluator(AbstractEvaluator):
+class Evaluator(AbstractEvaluator):
     """
     This class builds a pipeline using the provided configuration.
     A pipeline implementing the provided configuration is fitted
     using the datamanager object retrieved from disc, via the backend.
     After the pipeline is fitted, it is save to disc and the performance estimate
-    is communicated to the main process via a Queue. It is only compatible
-    with `CrossValTypes`, `HoldoutValTypes`, i.e, when the training data
-    is split and the validation set is used for SMBO optimisation.
+    is communicated to the main process via a Queue.

     Args:
         queue (Queue):
@@ -101,43 +99,17 @@ class TrainEvaluator(AbstractEvaluator):
             Fixed parameters for a pipeline
         evaluator_params (EvaluatorParams):
             The parameters for an evaluator.
+
+    Attributes:
+        train (bool):
+            Whether the training data is split and the validation set is used for SMBO optimisation.
+        cross_validation (bool):
+            Whether we use cross validation or not.
     """
-    def __init__(self, backend: Backend, queue: Queue,
-                 metric: autoPyTorchMetric,
-                 budget: float,
-                 configuration: Union[int, str, Configuration],
-                 budget_type: str = None,
-                 pipeline_config: Optional[Dict[str, Any]] = None,
-                 seed: int = 1,
-                 output_y_hat_optimization: bool = True,
-                 num_run: Optional[int] = None,
-                 include: Optional[Dict[str, Any]] = None,
-                 exclude: Optional[Dict[str, Any]] = None,
-                 disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
-                 init_params: Optional[Dict[str, Any]] = None,
-                 logger_port: Optional[int] = None,
-                 keep_models: Optional[bool] = None,
-                 all_supported_metrics: bool = True,
-                 search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None) -> None:
-        super().__init__(
-            backend=backend,
-            queue=queue,
-            configuration=configuration,
-            metric=metric,
-            seed=seed,
-            output_y_hat_optimization=output_y_hat_optimization,
-            num_run=num_run,
-            include=include,
-            exclude=exclude,
-            disable_file_output=disable_file_output,
-            init_params=init_params,
-            budget=budget,
-            budget_type=budget_type,
-            logger_port=logger_port,
-            all_supported_metrics=all_supported_metrics,
-            pipeline_config=pipeline_config,
-            search_space_updates=search_space_updates
-        )
+    def __init__(self, queue: Queue, fixed_pipeline_params: FixedPipelineParams, evaluator_params: EvaluatorParams):
+        resampling_strategy = fixed_pipeline_params.backend.load_datamanager().resampling_strategy
+        self.train = not isinstance(resampling_strategy, NoResamplingStrategyTypes)
+        self.cross_validation = isinstance(resampling_strategy, CrossValTypes)

         if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
             raise ValueError(
@@ -175,7 +147,7 @@ def _evaluate_on_split(self, split_id: int) -> EvaluationResults:

         return EvaluationResults(
             pipeline=pipeline,
-            opt_loss=self._loss(labels=self.y_train[opt_split], preds=opt_pred),
+            opt_loss=self._loss(labels=self.y_train[opt_split] if self.train else self.y_test, preds=opt_pred),
             train_loss=self._loss(labels=self.y_train[train_split], preds=train_pred),
             opt_pred=opt_pred,
             valid_pred=valid_pred,
@@ -201,6 +173,7 @@ def _cross_validation(self) -> EvaluationResults:
             results = self._evaluate_on_split(split_id)

             self.pipelines[split_id] = results.pipeline
+            assert opt_split is not None  # mypy redefinition
             cv_results.update(split_id, results, len(train_split), len(opt_split))

         self.y_opt = np.concatenate([y_opt for y_opt in Y_opt if y_opt is not None])
@@ -212,15 +185,16 @@ def evaluate_loss(self) -> None:
         if self.splits is None:
             raise ValueError(f"cannot fit pipeline {self.__class__.__name__} with datamanager.splits None")

-        if self.num_folds == 1:
+        if self.cross_validation:
+            results = self._cross_validation()
+        else:
             _, opt_split = self.splits[0]
             results = self._evaluate_on_split(split_id=0)
-            self.y_opt, self.pipelines[0] = self.y_train[opt_split], results.pipeline
-        else:
-            results = self._cross_validation()
+            self.pipelines[0] = results.pipeline
+            self.y_opt = self.y_train[opt_split] if self.train else self.y_test

         self.logger.debug(
-            f"In train evaluator.evaluate_loss, num_run: {self.num_run}, loss:{results.opt_loss},"
+            f"In evaluate_loss, num_run: {self.num_run}, loss:{results.opt_loss},"
             f" status: {results.status},\nadditional run info:\n{dict_repr(results.additional_run_info)}"
         )
         self.record_evaluation(results=results)
@@ -240,41 +214,23 @@ def _fit_and_evaluate_loss(

         kwargs = {'pipeline': pipeline, 'unique_train_labels': self.unique_train_labels[split_id]}
         train_pred = self.predict(subsampler(self.X_train, train_indices), **kwargs)
-        opt_pred = self.predict(subsampler(self.X_train, opt_indices), **kwargs)
-        valid_pred = self.predict(self.X_valid, **kwargs)
         test_pred = self.predict(self.X_test, **kwargs)
+        valid_pred = self.predict(self.X_valid, **kwargs)
+
+        # No resampling ===> evaluate on test dataset
+        opt_pred = self.predict(subsampler(self.X_train, opt_indices), **kwargs) if self.train else test_pred

         assert train_pred is not None and opt_pred is not None  # mypy check
         return train_pred, opt_pred, valid_pred, test_pred


-# create closure for evaluating an algorithm
-def eval_train_function(
-    backend: Backend,
-    queue: Queue,
-    metric: autoPyTorchMetric,
-    budget: float,
-    config: Optional[Configuration],
-    seed: int,
-    output_y_hat_optimization: bool,
-    num_run: int,
-    include: Optional[Dict[str, Any]],
-    exclude: Optional[Dict[str, Any]],
-    disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
-    pipeline_config: Optional[Dict[str, Any]] = None,
-    budget_type: str = None,
-    init_params: Optional[Dict[str, Any]] = None,
-    logger_port: Optional[int] = None,
-    all_supported_metrics: bool = True,
-    search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
-    instance: str = None,
-) -> None:
+def eval_fn(queue: Queue, fixed_pipeline_params: FixedPipelineParams, evaluator_params: EvaluatorParams) -> None:
     """
     This closure allows the communication between the TargetAlgorithmQuery and the
-    pipeline trainer (TrainEvaluator).
+    pipeline trainer (Evaluator).

     Fundamentally, smac calls the TargetAlgorithmQuery.run() method, which internally
-    builds a TrainEvaluator. The TrainEvaluator builds a pipeline, stores the output files
+    builds an Evaluator. The Evaluator builds a pipeline, stores the output files
     to disc via the backend, and puts the performance result of the run in the queue.

     Args:
@@ -286,7 +242,11 @@ def eval_train_function(
         evaluator_params (EvaluatorParams):
             The parameters for an evaluator.
     """
-    evaluator = TrainEvaluator(
+    resampling_strategy = fixed_pipeline_params.backend.load_datamanager().resampling_strategy
+    check_resampling_strategy(resampling_strategy)
+
+    # NoResamplingStrategyTypes ==> test evaluator, otherwise ==> train evaluator
+    evaluator = Evaluator(
         queue=queue,
         evaluator_params=evaluator_params,
         fixed_pipeline_params=fixed_pipeline_params
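
Taken together, this file replaces the old TrainEvaluator/TestEvaluator split with two booleans derived from the resampling strategy. A standalone sketch of that dispatch rule: describe_evaluation is hypothetical, the isinstance checks mirror the new __init__ above, and the enum member names in the usage lines are assumed from the library:

from autoPyTorch.datasets.resampling_strategy import (
    CrossValTypes,
    HoldoutValTypes,
    NoResamplingStrategyTypes,
    ResamplingStrategies,
)


def describe_evaluation(resampling_strategy: ResamplingStrategies) -> str:
    # The same two flags the new Evaluator.__init__ computes.
    train = not isinstance(resampling_strategy, NoResamplingStrategyTypes)
    cross_validation = isinstance(resampling_strategy, CrossValTypes)

    if cross_validation:
        return "fit one pipeline per fold; opt loss aggregated over validation folds"
    if train:
        return "fit one pipeline; opt loss on the holdout validation split"
    # No resampling: the former TestEvaluator path, where opt_pred = test_pred.
    return "fit one pipeline on all training data; opt loss on the test set"


print(describe_evaluation(HoldoutValTypes.holdout_validation))
print(describe_evaluation(NoResamplingStrategyTypes.no_resampling))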
