@@ -70,28 +70,31 @@ def get_current_autonet_config(self):
70
70
return self .pipeline .get_pipeline_config (** self .base_config )
71
71
72
72
def get_hyperparameter_search_space (self , X_train = None , Y_train = None , X_valid = None , Y_valid = None , ** autonet_config ):
73
- """Return hyperparameter search space of Auto-PyTorch. Does depend on the dataset!
73
+ """Return hyperparameter search space of Auto-PyTorch. Does depend on the dataset and the configuration. !
74
74
75
75
Keyword Arguments:
76
76
X_train {array} -- Training data.
77
77
Y_train {array} -- Targets of training data.
78
78
X_valid {array} -- Validation data. Will be ignored if cv_splits > 1. (default: {None})
79
79
Y_valid {array} -- Validation data. Will be ignored if cv_splits > 1. (default: {None})
80
+ autonet_config{dict} -- if not given and fit already called, config of last fit will be used
80
81
81
82
Returns:
82
83
ConfigurationSpace -- The configuration space that should be optimized.
83
84
"""
84
85
85
86
dataset_info = None
87
+ pipeline_config = dict (self .base_config , ** autonet_config ) if autonet_config else \
88
+ self .get_current_autonet_config ()
86
89
if X_train is not None and Y_train is not None :
87
90
dataset_info_node = self .pipeline [CreateDatasetInfo .get_name ()]
88
- dataset_info = dataset_info_node .fit (pipeline_config = dict ( self . base_config , ** autonet_config ) ,
91
+ dataset_info = dataset_info_node .fit (pipeline_config = pipeline_config ,
89
92
X_train = X_train ,
90
93
Y_train = Y_train ,
91
94
X_valid = X_valid ,
92
95
Y_valid = Y_valid )["dataset_info" ]
93
96
94
- return self .pipeline .get_hyperparameter_search_space (dataset_info = dataset_info , ** self . get_current_autonet_config () )
97
+ return self .pipeline .get_hyperparameter_search_space (dataset_info = dataset_info , ** pipeline_config )
95
98
96
99
@classmethod
97
100
def get_default_pipeline (cls ):
0 commit comments