 
 from smac.tae import StatusType
 
-from autoPyTorch.automl_common.common.utils.backend import Backend
-from autoPyTorch.constants import (
-    CLASSIFICATION_TASKS,
-    MULTICLASSMULTIOUTPUT,
+from autoPyTorch.datasets.resampling_strategy import (
+    CrossValTypes,
+    NoResamplingStrategyTypes,
+    check_resampling_strategy
 )
-from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
 from autoPyTorch.evaluation.abstract_evaluator import (
     AbstractEvaluator,
     EvaluationResults,
 from autoPyTorch.evaluation.abstract_evaluator import EvaluatorParams, FixedPipelineParams
 from autoPyTorch.utils.common import dict_repr, subsampler
 
-__all__ = ['TrainEvaluator', 'eval_train_function']
+__all__ = ['Evaluator', 'eval_fn']
+
 
 class _CrossValidationResultsManager:
     def __init__(self, num_folds: int):
@@ -83,15 +83,13 @@ def get_result_dict(self) -> Dict[str, Any]:
         )
 
 
-class TrainEvaluator(AbstractEvaluator):
+class Evaluator(AbstractEvaluator):
     """
     This class builds a pipeline using the provided configuration.
     A pipeline implementing the provided configuration is fitted
     using the datamanager object retrieved from disc, via the backend.
     After the pipeline is fitted, it is saved to disc and the performance estimate
-    is communicated to the main process via a Queue. It is only compatible
-    with `CrossValTypes`, `HoldoutValTypes`, i.e., when the training data
-    is split and the validation set is used for SMBO optimisation.
+    is communicated to the main process via a Queue.
 
     Args:
         queue (Queue):
@@ -101,43 +99,17 @@ class TrainEvaluator(AbstractEvaluator):
             Fixed parameters for a pipeline
         evaluator_params (EvaluatorParams):
             The parameters for an evaluator.
+
+    Attributes:
+        train (bool):
+            Whether the training data is split and the validation set is used for SMBO optimisation.
+        cross_validation (bool):
+            Whether we use cross validation or not.
     """
-    def __init__(self, backend: Backend, queue: Queue,
-                 metric: autoPyTorchMetric,
-                 budget: float,
-                 configuration: Union[int, str, Configuration],
-                 budget_type: str = None,
-                 pipeline_config: Optional[Dict[str, Any]] = None,
-                 seed: int = 1,
-                 output_y_hat_optimization: bool = True,
-                 num_run: Optional[int] = None,
-                 include: Optional[Dict[str, Any]] = None,
-                 exclude: Optional[Dict[str, Any]] = None,
-                 disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
-                 init_params: Optional[Dict[str, Any]] = None,
-                 logger_port: Optional[int] = None,
-                 keep_models: Optional[bool] = None,
-                 all_supported_metrics: bool = True,
-                 search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None) -> None:
-        super().__init__(
-            backend=backend,
-            queue=queue,
-            configuration=configuration,
-            metric=metric,
-            seed=seed,
-            output_y_hat_optimization=output_y_hat_optimization,
-            num_run=num_run,
-            include=include,
-            exclude=exclude,
-            disable_file_output=disable_file_output,
-            init_params=init_params,
-            budget=budget,
-            budget_type=budget_type,
-            logger_port=logger_port,
-            all_supported_metrics=all_supported_metrics,
-            pipeline_config=pipeline_config,
-            search_space_updates=search_space_updates
-        )
+    def __init__(self, queue: Queue, fixed_pipeline_params: FixedPipelineParams, evaluator_params: EvaluatorParams):
+        resampling_strategy = fixed_pipeline_params.backend.load_datamanager().resampling_strategy
+        self.train = not isinstance(resampling_strategy, NoResamplingStrategyTypes)
+        self.cross_validation = isinstance(resampling_strategy, CrossValTypes)
 
         if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
             raise ValueError(
@@ -175,7 +147,7 @@ def _evaluate_on_split(self, split_id: int) -> EvaluationResults:
 
         return EvaluationResults(
             pipeline=pipeline,
-            opt_loss=self._loss(labels=self.y_train[opt_split], preds=opt_pred),
+            opt_loss=self._loss(labels=self.y_train[opt_split] if self.train else self.y_test, preds=opt_pred),
             train_loss=self._loss(labels=self.y_train[train_split], preds=train_pred),
             opt_pred=opt_pred,
             valid_pred=valid_pred,
@@ -201,6 +173,7 @@ def _cross_validation(self) -> EvaluationResults:
             results = self._evaluate_on_split(split_id)
 
             self.pipelines[split_id] = results.pipeline
+            assert opt_split is not None  # mypy redefinition
             cv_results.update(split_id, results, len(train_split), len(opt_split))
 
         self.y_opt = np.concatenate([y_opt for y_opt in Y_opt if y_opt is not None])
@@ -212,15 +185,16 @@ def evaluate_loss(self) -> None:
         if self.splits is None:
             raise ValueError(f"cannot fit pipeline {self.__class__.__name__} with datamanager.splits None")
 
-        if self.num_folds == 1:
+        if self.cross_validation:
+            results = self._cross_validation()
+        else:
             _, opt_split = self.splits[0]
             results = self._evaluate_on_split(split_id=0)
-            self.y_opt, self.pipelines[0] = self.y_train[opt_split], results.pipeline
-        else:
-            results = self._cross_validation()
+            self.pipelines[0] = results.pipeline
+            self.y_opt = self.y_train[opt_split] if self.train else self.y_test
 
         self.logger.debug(
-            f"In train evaluator. evaluate_loss, num_run: {self.num_run}, loss:{results.opt_loss},"
+            f"In evaluate_loss, num_run: {self.num_run}, loss:{results.opt_loss},"
             f" status: {results.status},\n additional run info:\n{dict_repr(results.additional_run_info)}"
         )
         self.record_evaluation(results=results)
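
For orientation, the holdout-versus-CV dispatch this hunk introduces can be sketched in isolation. The sketch below is illustrative only, not code from the PR: `cross_validation`, `train`, `splits`, `y_train`, and `y_test` mirror the evaluator attributes of the same names, and `choose_opt_targets` is a hypothetical helper.

```python
from typing import List, Tuple

import numpy as np


def choose_opt_targets(
    cross_validation: bool,  # mirrors Evaluator.cross_validation
    train: bool,             # mirrors Evaluator.train
    splits: List[Tuple[np.ndarray, np.ndarray]],
    y_train: np.ndarray,
    y_test: np.ndarray,
) -> np.ndarray:
    """Labels that the optimisation loss (y_opt) is computed against."""
    if cross_validation:
        # Cross-validation: concatenate the held-out targets of every fold.
        return np.concatenate([y_train[opt_split] for _, opt_split in splits])
    # Single split: holdout targets when the training data is split,
    # test targets under a no-resampling strategy.
    _, opt_split = splits[0]
    return y_train[opt_split] if train else y_test
```
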
@@ -240,41 +214,23 @@ def _fit_and_evaluate_loss(
 
         kwargs = {'pipeline': pipeline, 'unique_train_labels': self.unique_train_labels[split_id]}
         train_pred = self.predict(subsampler(self.X_train, train_indices), **kwargs)
-        opt_pred = self.predict(subsampler(self.X_train, opt_indices), **kwargs)
-        valid_pred = self.predict(self.X_valid, **kwargs)
         test_pred = self.predict(self.X_test, **kwargs)
+        valid_pred = self.predict(self.X_valid, **kwargs)
+
+        # No resampling ===> evaluate on test dataset
+        opt_pred = self.predict(subsampler(self.X_train, opt_indices), **kwargs) if self.train else test_pred
 
         assert train_pred is not None and opt_pred is not None  # mypy check
         return train_pred, opt_pred, valid_pred, test_pred
 
 
-# create closure for evaluating an algorithm
-def eval_train_function(
-    backend: Backend,
-    queue: Queue,
-    metric: autoPyTorchMetric,
-    budget: float,
-    config: Optional[Configuration],
-    seed: int,
-    output_y_hat_optimization: bool,
-    num_run: int,
-    include: Optional[Dict[str, Any]],
-    exclude: Optional[Dict[str, Any]],
-    disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
-    pipeline_config: Optional[Dict[str, Any]] = None,
-    budget_type: str = None,
-    init_params: Optional[Dict[str, Any]] = None,
-    logger_port: Optional[int] = None,
-    all_supported_metrics: bool = True,
-    search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
-    instance: str = None,
-) -> None:
+def eval_fn(queue: Queue, fixed_pipeline_params: FixedPipelineParams, evaluator_params: EvaluatorParams) -> None:
     """
     This closure allows the communication between the TargetAlgorithmQuery and the
-    pipeline trainer (TrainEvaluator).
+    pipeline trainer (Evaluator).
 
     Fundamentally, smac calls the TargetAlgorithmQuery.run() method, which internally
-    builds a TrainEvaluator. The TrainEvaluator builds a pipeline, stores the output files
+    builds an Evaluator. The Evaluator builds a pipeline, stores the output files
     to disc via the backend, and puts the performance result of the run in the queue.
 
     Args:
@@ -286,7 +242,11 @@ def eval_train_function(
         evaluator_params (EvaluatorParams):
             The parameters for an evaluator.
     """
-    evaluator = TrainEvaluator(
+    resampling_strategy = fixed_pipeline_params.backend.load_datamanager().resampling_strategy
+    check_resampling_strategy(resampling_strategy)
+
+    # NoResamplingStrategyTypes ==> test evaluator, otherwise ==> train evaluator
+    evaluator = Evaluator(
         queue=queue,
         evaluator_params=evaluator_params,
         fixed_pipeline_params=fixed_pipeline_params
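
Taken together, the diff keys the evaluator's behaviour off the dataset's resampling strategy, as the `# NoResamplingStrategyTypes ==> test evaluator` comment above notes. A minimal sketch of that dispatch, assuming only the strategy types the diff shows `autoPyTorch.datasets.resampling_strategy` exposing (the `describe` helper is hypothetical):

```python
from autoPyTorch.datasets.resampling_strategy import (
    CrossValTypes,
    HoldoutValTypes,
    NoResamplingStrategyTypes,
)


def describe(resampling_strategy) -> str:
    """Hypothetical helper: which evaluation mode a strategy implies."""
    if isinstance(resampling_strategy, NoResamplingStrategyTypes):
        # No resampling: fit on all training data, score on the test set.
        return "test evaluator"
    if isinstance(resampling_strategy, CrossValTypes):
        # k-fold style strategies: Evaluator.cross_validation is True.
        return "train evaluator (cross-validation)"
    if isinstance(resampling_strategy, HoldoutValTypes):
        # Holdout: a single split whose validation set drives SMBO.
        return "train evaluator (holdout)"
    raise ValueError(f"Unsupported resampling strategy: {resampling_strategy}")
```
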