Skip to content

Commit 33403a8

Browse files
committed
small fix when calling get_hyperparameter_search_space()
1 parent b69a271 commit 33403a8

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

autoPyTorch/core/api.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,28 +70,31 @@ def get_current_autonet_config(self):
7070
return self.pipeline.get_pipeline_config(**self.base_config)
7171

7272
def get_hyperparameter_search_space(self, X_train=None, Y_train=None, X_valid=None, Y_valid=None, **autonet_config):
73-
"""Return hyperparameter search space of Auto-PyTorch. Does depend on the dataset!
73+
"""Return hyperparameter search space of Auto-PyTorch. Does depend on the dataset and the configuration.!
7474
7575
Keyword Arguments:
7676
X_train {array} -- Training data.
7777
Y_train {array} -- Targets of training data.
7878
X_valid {array} -- Validation data. Will be ignored if cv_splits > 1. (default: {None})
7979
Y_valid {array} -- Validation data. Will be ignored if cv_splits > 1. (default: {None})
80+
autonet_config{dict} -- if not given and fit already called, config of last fit will be used
8081
8182
Returns:
8283
ConfigurationSpace -- The configuration space that should be optimized.
8384
"""
8485

8586
dataset_info = None
87+
pipeline_config = dict(self.base_config, **autonet_config) if autonet_config else \
88+
self.get_current_autonet_config()
8689
if X_train is not None and Y_train is not None:
8790
dataset_info_node = self.pipeline[CreateDatasetInfo.get_name()]
88-
dataset_info = dataset_info_node.fit(pipeline_config=dict(self.base_config, **autonet_config),
91+
dataset_info = dataset_info_node.fit(pipeline_config=pipeline_config,
8992
X_train=X_train,
9093
Y_train=Y_train,
9194
X_valid=X_valid,
9295
Y_valid=Y_valid)["dataset_info"]
9396

94-
return self.pipeline.get_hyperparameter_search_space(dataset_info=dataset_info, **self.get_current_autonet_config())
97+
return self.pipeline.get_hyperparameter_search_space(dataset_info=dataset_info, **pipeline_config)
9598

9699
@classmethod
97100
def get_default_pipeline(cls):

0 commit comments

Comments
 (0)