
Commit 3cb7446

renamed train_metric to optimize_metric

- improved handling of metrics and logs
- improved return result of fit
- added training_time as budget type
- updated benchmark
- summary plotting fixes
1 parent 33403a8 commit 3cb7446

Note: some of the 54 changed files are hidden by default and not shown below.

54 files changed: +557, -283 lines

autoPyTorch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,5 +3,6 @@
 sys.path.append(hpbandster)
 
 from autoPyTorch.core.autonet_classes import AutoNetClassification, AutoNetMultilabel, AutoNetRegression
+from autoPyTorch.data_management.data_manager import DataManager
 from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 from autoPyTorch.core.ensemble import AutoNetEnsemble
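
With this change, DataManager is re-exported at the package root. A minimal sketch of the resulting import path (only the re-export itself is part of this commit; everything past the import depends on DataManager's own API):

from autoPyTorch import DataManager

dm = DataManager()  # data loading/generation then goes through DataManager's own methods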

autoPyTorch/components/ensembles/ensemble_selection.py

Lines changed: 7 additions & 8 deletions
@@ -7,12 +7,11 @@
 
 
 class EnsembleSelection(AbstractEnsemble):
-    def __init__(self, ensemble_size, metric, minimize,
+    def __init__(self, ensemble_size, metric,
                  sorted_initialization_n_best=0, only_consider_n_best=0,
                  bagging=False, mode='fast'):
         self.ensemble_size = ensemble_size
-        self.metric = metric
-        self.minimize = 1 if minimize else -1
+        self.metric = metric.get_loss_value
         self.sorted_initialization_n_best = sorted_initialization_n_best
         self.only_consider_n_best = only_consider_n_best
         self.bagging = bagging
@@ -56,7 +55,7 @@ def _fast(self, predictions, labels):
                 ensemble.append(predictions[idx])
                 order.append(idx)
                 ensemble_ = np.array(ensemble).mean(axis=0)
-                ensemble_performance = self.metric(ensemble_, labels) * self.minimize
+                ensemble_performance = self.metric(ensemble_, labels)
                 trajectory.append(ensemble_performance)
             ensemble_size -= self.sorted_initialization_n_best
 
@@ -82,7 +81,7 @@ def _fast(self, predictions, labels):
                    continue
                fant_ensemble_prediction[:,:] = weighted_ensemble_prediction + \
                    (1. / float(s + 1)) * pred
-               scores[j] = self.metric(fant_ensemble_prediction, labels) * self.minimize
+               scores[j] = self.metric(fant_ensemble_prediction, labels)
            all_best = np.argwhere(scores == np.nanmin(scores)).flatten()
            best = np.random.choice(all_best)
            ensemble.append(predictions[best])
@@ -113,7 +112,7 @@ def _slow(self, predictions, labels):
                 ensemble.append(predictions[idx])
                 order.append(idx)
                 ensemble_ = np.array(ensemble).mean(axis=0)
-                ensemble_performance = self.metric(ensemble_, labels) * self.minimize
+                ensemble_performance = self.metric(ensemble_, labels)
                 trajectory.append(ensemble_performance)
             ensemble_size -= self.sorted_initialization_n_best
 
@@ -129,7 +128,7 @@ def _slow(self, predictions, labels):
                     continue
                 ensemble.append(pred)
                 ensemble_prediction = np.mean(np.array(ensemble), axis=0)
-                scores[j] = self.metric(ensemble_prediction, labels) * self.minimize
+                scores[j] = self.metric(ensemble_prediction, labels)
                 ensemble.pop()
             best = np.nanargmin(scores)
             ensemble.append(predictions[best])
@@ -160,7 +159,7 @@ def _sorted_initialization(self, predictions, labels, n_best):
         perf = np.zeros([predictions.shape[0]])
 
         for idx, prediction in enumerate(predictions):
-            perf[idx] = self.metric(prediction, labels) * self.minimize
+            perf[idx] = self.metric(prediction, labels)
 
         indices = np.argsort(perf)[:n_best]
         return indices
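
The removed minimize flag (and the `* self.minimize` sign flips) are superseded by the metric object's get_loss_value, which always yields a value to minimize. A hypothetical sketch of such a metric wrapper, with the class name and constructor inferred from this commit's call sites (metric(Y_pred, Y_true), metric.name, metric.loss_transform) rather than from a file shown here:

class AutoNetMetric:
    """Hypothetical wrapper: callable for scoring, get_loss_value for optimization."""

    def __init__(self, name, metric, loss_transform):
        self.name = name                      # used as "val_" + metric.name in logs
        self.metric = metric                  # e.g. accuracy(y_true, y_pred)
        self.loss_transform = loss_transform  # maps a score to a loss, e.g. lambda s: 1 - s

    def __call__(self, Y_pred, Y_true):
        return self.metric(Y_true, Y_pred)

    def get_loss_value(self, Y_pred, Y_true):
        # always "smaller is better", so callers no longer branch on a minimize flag
        return self.loss_transform(self(Y_pred, Y_true))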

autoPyTorch/components/metrics/balanced_accuracy.py

Lines changed: 2 additions & 5 deletions
@@ -4,11 +4,8 @@
 from sklearn.metrics.classification import _check_targets, type_of_target
 
 
-def balanced_accuracy(y_pred, y_true):
-    return _balanced_accuracy(np.argmax(y_true, axis=1), np.argmax(y_pred, axis=1)) * 100
+def balanced_accuracy(solution, prediction):
 
-
-def _balanced_accuracy(solution, prediction):
     y_type, solution, prediction = _check_targets(solution, prediction)
 
     if y_type not in ["binary", "multiclass", 'multilabel-indicator']:
@@ -65,4 +62,4 @@ def _balanced_accuracy(solution, prediction):
     else:
         raise ValueError(y_type)
 
-    return np.mean(bac) # average over all classes
+    return np.mean(bac)  # average over all classes

autoPyTorch/components/metrics/pac_score.py

Lines changed: 1 addition & 5 deletions
@@ -4,11 +4,7 @@
 from sklearn.metrics.classification import _check_targets, type_of_target
 
 
-def pac_metric(y_pred, y_true):
-    return _pac_score(y_true, y_pred) * 100
-
-
-def _pac_score(solution, prediction):
+def pac_metric(solution, prediction):
     """
     Probabilistic Accuracy based on log_loss metric.
     We assume the solution is in {0, 1} and prediction in [0, 1].

autoPyTorch/components/metrics/standard_metrics.py

Lines changed: 10 additions & 12 deletions
@@ -2,22 +2,20 @@
 import numpy as np
 
 # classification metrics
-def accuracy(y_pred, y_true):
-    return np.mean((undo_ohe(y_true) == undo_ohe(y_pred))) * 100
 
-def auc_metric(y_pred, y_true):
-    return (2 * metrics.roc_auc_score(y_true, y_pred) - 1) * 100
+
+def accuracy(y_true, y_pred):
+    return np.mean(y_true == y_pred)
+
+
+def auc_metric(y_true, y_pred):
+    return (2 * metrics.roc_auc_score(y_true, y_pred) - 1)
 
 
 # multilabel metric
-def multilabel_accuracy(y_pred, y_true):
-    return np.mean(y_true == (y_pred > 0.5)) * 100
+def multilabel_accuracy(y_true, y_pred):
+    return np.mean(y_true == (y_pred > 0.5))
 
 # regression metric
-def mean_distance(y_pred, y_true):
+def mean_distance(y_true, y_pred):
     return np.mean(np.abs(y_true - y_pred))
-
-def undo_ohe(y):
-    if len(y.shape) == 1:
-        return(y)
-    return np.argmax(y, axis=1)
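
Two conventions change here: the metrics now take (y_true, y_pred) instead of (y_pred, y_true), and they return raw fractions instead of percentages; accuracy also expects plain class labels now that undo_ohe is gone. A quick sanity check of the new behavior:

import numpy as np
from autoPyTorch.components.metrics.standard_metrics import accuracy, mean_distance

y_true = np.array([0, 1, 1, 0])
y_pred = np.array([0, 1, 0, 0])
print(accuracy(y_true, y_pred))       # 0.75 (would have been 75.0 before this commit)
print(mean_distance(y_true, y_pred))  # 0.25; note y_true now comes first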

autoPyTorch/components/training/budget_types.py

Lines changed: 32 additions & 3 deletions
@@ -21,9 +21,9 @@ def on_batch_end(self, **kwargs):
 
     # OVERRIDE
     def on_epoch_end(self, trainer, **kwargs):
-        elapsed = time.time() - self.start_time
+        elapsed = time.time() - trainer.fit_start_time
         trainer.model.budget_trained = elapsed
-        trainer.logger.debug("Budget used: " + str(elapsed) + "/" + str(self.end_time - self.start_time))
+        trainer.logger.debug("Budget used: " + str(elapsed) + "/" + str(trainer.budget - self.compensate))
 
         if time.time() >= self.end_time:
             trainer.logger.debug("Budget exhausted!")
@@ -47,4 +47,33 @@ def on_epoch_end(self, trainer, epoch, **kwargs):
         if epoch >= self.target:
             trainer.logger.debug("Budget exhausted!")
             return True
-        return False
+        return False
+
+class BudgetTypeTrainingTime(BaseTrainingTechnique):
+    default_min_budget = 120
+    default_max_budget = 6000
+
+    # OVERRIDE
+    def set_up(self, trainer, pipeline_config, **kwargs):
+        super(BudgetTypeTrainingTime, self).set_up(trainer, pipeline_config)
+        self.end_time = trainer.budget + time.time()
+        self.start_time = time.time()
+
+        if self.start_time >= self.end_time:
+            raise Exception("Budget exhausted before training started")
+
+    # OVERRIDE
+    def on_batch_end(self, **kwargs):
+        return time.time() >= self.end_time
+
+    # OVERRIDE
+    def on_epoch_end(self, trainer, **kwargs):
+        elapsed = time.time() - self.start_time
+        trainer.model.budget_trained = elapsed
+        trainer.logger.debug("Budget used: " + str(elapsed) +
+                             "/" + str(self.end_time - self.start_time))
+
+        if time.time() >= self.end_time:
+            trainer.logger.debug("Budget exhausted!")
+            return True
+        return False
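
BudgetTypeTrainingTime stops a run once trainer.budget seconds of wall-clock time have elapsed, checking after every batch and epoch. How the class is registered under a budget_type name is not part of this diff, so the following usage is a hedged sketch that assumes a key such as "training_time":

from autoPyTorch import AutoNetClassification

autonet = AutoNetClassification(budget_type="training_time",  # assumed registry key
                                min_budget=120,   # seconds, cf. default_min_budget
                                max_budget=6000)  # seconds, cf. default_max_budget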

autoPyTorch/components/training/early_stopping.py

Lines changed: 3 additions & 6 deletions
@@ -11,8 +11,8 @@ class EarlyStopping(BaseTrainingTechnique):
     def set_up(self, trainer, pipeline_config, **kwargs):
         super(EarlyStopping, self).set_up(trainer, pipeline_config)
         self.reset_parameters = pipeline_config["early_stopping_reset_parameters"]
-        self.minimize = pipeline_config["minimize"]
         self.patience = pipeline_config["early_stopping_patience"]
+        self.loss_transform = trainer.metrics[0].loss_transform
 
         # does not work with e.g. cosine anealing with warm restarts
         if hasattr(trainer, "lr_scheduler") and not trainer.lr_scheduler.allows_early_stopping:
@@ -21,8 +21,6 @@ def set_up(self, trainer, pipeline_config, **kwargs):
         # initialize current best performance to +/- infinity
         if trainer.model.current_best_epoch_performance is None:
             trainer.model.current_best_epoch_performance = float("inf")
-            if not self.minimize:
-                trainer.model.current_best_epoch_performance = -float("inf")
 
         trainer.logger.debug("Using Early stopping with patience: " + str(self.patience))
         trainer.logger.debug("Reset Parameters to parameters with best validation performance: " + str(self.reset_parameters))
@@ -35,11 +33,10 @@ def on_epoch_end(self, trainer, log, **kwargs):
             return False
         if self.reset_parameters and (not hasattr(trainer, "lr_scheduler") or not trainer.lr_scheduler.snapshot_before_restart):
             log["best_parameters"] = False
-        current_performance = log["val_" + trainer.metrics[0]]
+        current_performance = self.loss_transform(log["val_" + trainer.metrics[0]])
 
         # new best performance
-        if ((self.minimize and current_performance < trainer.model.current_best_epoch_performance) or
-                (not self.minimize and current_performance > trainer.model.current_best_epoch_performance)):
+        if current_performance < trainer.model.current_best_epoch_performance:
             trainer.model.num_epochs_no_progress = 0
             trainer.model.current_best_epoch_performance = current_performance
             trainer.logger.debug("New best performance!")

autoPyTorch/components/training/trainer.py

Lines changed: 4 additions & 4 deletions
@@ -66,10 +66,10 @@ def on_epoch_start(self, log, epoch):
     def on_epoch_end(self, log, epoch):
         return any([t.on_epoch_end(trainer=self, log=log, epoch=epoch) for t in self.training_techniques])
 
-    def final_eval(self, opt_metric_name, logs, train_loader, valid_loader, minimize, best_over_epochs, refit):
+    def final_eval(self, opt_metric_name, logs, train_loader, valid_loader, best_over_epochs, refit):
         # select log
         if best_over_epochs:
-            final_log = (min if minimize else max)(logs, key=lambda log: log[opt_metric_name])
+            final_log = min(logs, key=lambda log: self.metrics[0].loss_transform(log[opt_metric_name]))
         else:
             final_log = None
             for t in self.training_techniques:
@@ -87,10 +87,10 @@ def final_eval(self, opt_metric_name, logs, train_loader, valid_loader, minimize, best_over_epochs, refit):
 
         for i, metric in enumerate(self.metrics):
             if valid_metric_results:
-                final_log['val_' + metric.__name__] = valid_metric_results[i]
+                final_log['val_' + metric.name] = valid_metric_results[i]
         if self.eval_additional_logs_on_snapshot and not refit:
             for additional_log in self.log_functions:
-                final_log[additional_log.__name__] = additional_log(self.model, None)
+                final_log[additional_log.name] = additional_log(self.model, None)
         return final_log
 
     def train(self, epoch, train_loader):

autoPyTorch/core/api.py

Lines changed: 33 additions & 12 deletions
@@ -4,6 +4,7 @@
 
 
 import numpy as np
+import scipy.sparse
 import torch
 import torch.nn as nn
 import copy
@@ -21,6 +22,7 @@
 from autoPyTorch.utils.config.config_file_parser import ConfigFileParser
 
 class AutoNet():
+    """Find an optimal neural network given a ML-task using BOHB"""
     preset_folder_name = None
 
     def __init__(self, config_preset="medium_cs", pipeline=None, **autonet_config):
@@ -34,6 +36,7 @@ def __init__(self, config_preset="medium_cs", pipeline=None, **autonet_config):
         self.base_config = autonet_config
         self.autonet_config = None
         self.fit_result = None
+        self.dataset_info = None
 
         if config_preset is not None:
             parser = self.get_autonet_config_file_parser()
@@ -70,10 +73,11 @@ def get_current_autonet_config(self):
         return self.pipeline.get_pipeline_config(**self.base_config)
 
     def get_hyperparameter_search_space(self, X_train=None, Y_train=None, X_valid=None, Y_valid=None, **autonet_config):
-        """Return hyperparameter search space of Auto-PyTorch. Does depend on the dataset and the configuration.!
+        """Return hyperparameter search space of Auto-PyTorch. Does depend on the dataset and the configuration!
+        You can either pass the dataset and the configuration or use dataset and configuration of last fit call.
 
         Keyword Arguments:
-            X_train {array} -- Training data.
+            X_train {array} -- Training data. ConfigSpace depends on Training data.
             Y_train {array} -- Targets of training data.
            X_valid {array} -- Validation data. Will be ignored if cv_splits > 1. (default: {None})
            Y_valid {array} -- Validation data. Will be ignored if cv_splits > 1. (default: {None})
@@ -82,8 +86,8 @@ def get_hyperparameter_search_space(self, X_train=None, Y_train=None, X_valid=None, Y_valid=None, **autonet_config):
         Returns:
             ConfigurationSpace -- The configuration space that should be optimized.
         """
-
-        dataset_info = None
+        X_train, Y_train, X_valid, Y_valid = self.check_data_array_types(X_train, Y_train, X_valid, Y_valid)
+        dataset_info = self.dataset_info
         pipeline_config = dict(self.base_config, **autonet_config) if autonet_config else \
             self.get_current_autonet_config()
         if X_train is not None and Y_train is not None:
@@ -129,21 +133,22 @@ def fit(self, X_train, Y_train, X_valid=None, Y_valid=None, refit=True, **autonet_config):
 
         Returns:
             optimized_hyperparameter_config -- The best found hyperparameter config.
-            final_metric_score -- The final score of the specified train metric.
             **autonet_config -- Configure AutoNet for your needs. You can also configure AutoNet in the constructor(). Call print_help() for more info.
         """
+        X_train, Y_train, X_valid, Y_valid = self.check_data_array_types(X_train, Y_train, X_valid, Y_valid)
        self.autonet_config = self.pipeline.get_pipeline_config(**dict(self.base_config, **autonet_config))
 
        self.fit_result = self.pipeline.fit_pipeline(pipeline_config=self.autonet_config,
                                                     X_train=X_train, Y_train=Y_train, X_valid=X_valid, Y_valid=Y_valid)
+        self.dataset_info = self.pipeline[CreateDatasetInfo.get_name()].fit_output["dataset_info"]
        self.pipeline.clean()
 
        if not self.fit_result["optimized_hyperparameter_config"]:
            raise RuntimeError("No models fit during training, please retry with a larger max_runtime.")
 
        if (refit):
            self.refit(X_train, Y_train, X_valid, Y_valid)
-        return self.fit_result["optimized_hyperparameter_config"], self.fit_result['final_metric_score']
+        return self.fit_result
 
    def refit(self, X_train, Y_train, X_valid=None, Y_valid=None, hyperparameter_config=None, autonet_config=None, budget=None, rescore=False):
        """Refit AutoNet to given hyperparameters. This will skip hyperparameter search.
@@ -163,6 +168,7 @@ def refit(self, X_train, Y_train, X_valid=None, Y_valid=None, hyperparameter_config=None, autonet_config=None, budget=None, rescore=False):
         Raises:
             ValueError -- No hyperparameter config available
         """
+        X_train, Y_train, X_valid, Y_valid = self.check_data_array_types(X_train, Y_train, X_valid, Y_valid)
         if (autonet_config is None):
             autonet_config = self.autonet_config
         if (autonet_config is None):
@@ -182,9 +188,8 @@
             'budget': budget,
             'rescore': rescore}
 
-        result = self.pipeline.fit_pipeline(pipeline_config=autonet_config, refit=refit_data,
-                                            X_train=X_train, Y_train=Y_train, X_valid=X_valid, Y_valid=Y_valid)
-        return result["final_metric_score"]
+        return self.pipeline.fit_pipeline(pipeline_config=autonet_config, refit=refit_data,
+                                          X_train=X_train, Y_train=Y_train, X_valid=X_valid, Y_valid=Y_valid)
 
     def predict(self, X, return_probabilities=False):
         """Predict the targets for a data matrix X.
@@ -200,6 +205,7 @@
         """
 
         # run predict pipeline
+        X, = self.check_data_array_types(X)
         autonet_config = self.autonet_config or self.base_config
         Y_pred = self.pipeline.predict_pipeline(pipeline_config=autonet_config, X=X)['Y']
 
@@ -208,8 +214,8 @@
         result = OHE.reverse_transform_y(Y_pred, OHE.fit_output['y_one_hot_encoder'])
         return result if not return_probabilities else (result, Y_pred)
 
-    def score(self, X_test, Y_test):
-        """Calculate the score on test data using the specified train_metric
+    def score(self, X_test, Y_test, return_loss_value=False):
+        """Calculate the score on test data using the specified optimize_metric
 
         Arguments:
             X_test {array} -- The test data matrix.
@@ -220,6 +226,7 @@ def score(self, X_test, Y_test):
         """
 
         # run predict pipeline
+        X_test, Y_test = self.check_data_array_types(X_test, Y_test)
         autonet_config = self.autonet_config or self.base_config
         self.pipeline.predict_pipeline(pipeline_config=autonet_config, X=X_test)
         Y_pred = self.pipeline[OptimizationAlgorithm.get_name()].predict_output['Y']
@@ -228,5 +235,19 @@
         OHE = self.pipeline[OneHotEncoding.get_name()]
         Y_test = OHE.transform_y(Y_test, OHE.fit_output['y_one_hot_encoder'])
 
-        metric = self.pipeline[MetricSelector.get_name()].fit_output['train_metric']
+        metric = self.pipeline[MetricSelector.get_name()].fit_output['optimize_metric']
+        if return_loss_value:
+            return metric.get_loss_value(Y_pred, Y_test)
         return metric(Y_pred, Y_test)
+
+    def check_data_array_types(self, *arrays):
+        result = []
+        for array in arrays:
+            if array is None or scipy.sparse.issparse(array):
+                result.append(array)
+                continue
+
+            result.append(np.asanyarray(array))
+            if not result[-1].shape:
+                raise RuntimeError("Given data-array is of unexpected type %s. Please pass numpy arrays instead." % type(array))
        return result
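
fit now returns the whole fit_result dict instead of an (optimized_hyperparameter_config, final_metric_score) tuple, and score can report the metric's loss form. A usage sketch (the data variables are assumed; optimized_hyperparameter_config is the only fit_result key this diff guarantees):

from autoPyTorch import AutoNetClassification

autonet = AutoNetClassification("medium_cs")          # the default preset
fit_result = autonet.fit(X_train, y_train)            # X_train/y_train assumed given
print(fit_result["optimized_hyperparameter_config"])

score = autonet.score(X_test, y_test)                         # metric value
loss = autonet.score(X_test, y_test, return_loss_value=True)  # via metric.get_loss_value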
