Skip to content

Commit d49ed68

Browse files
committed
[FIX] Tests after rebase of reg_cocktails (#359)
* update requirements
* update requirements
* resolve remaining conflicts and fix flake and mypy
* Fix remaining tests and examples
* fix failing checks
* fix flake
1 parent ff6e8c4 commit d49ed68

40 files changed

+329
-1057
lines changed

autoPyTorch/api/base_task.py

Lines changed: 49 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727

2828
import pandas as pd
2929

30-
from smac.runhistory.runhistory import DataOrigin, RunHistory, RunInfo, RunValue
30+
from smac.runhistory.runhistory import DataOrigin, RunHistory
3131
from smac.stats.stats import Stats
3232
from smac.tae import StatusType
3333

@@ -593,11 +593,16 @@ def _load_models(self) -> bool:
593593
raise ValueError("Resampling strategy is needed to determine what models to load")
594594
self.ensemble_ = self._backend.load_ensemble(self.seed)
595595

596-
if isinstance(self._disable_file_output, List):
597-
disabled_file_outputs = self._disable_file_output
596+
# TODO: remove this code after `fit_pipeline` is rebased.
597+
if hasattr(self, '_disable_file_output'):
598+
if isinstance(self._disable_file_output, List):
599+
disabled_file_outputs = self._disable_file_output
600+
disable_file_output = False
601+
elif isinstance(self._disable_file_output, bool):
602+
disable_file_output = self._disable_file_output
603+
disabled_file_outputs = []
604+
else:
598605
disable_file_output = False
599-
elif isinstance(self._disable_file_output, bool):
600-
disable_file_output = self._disable_file_output
601606
disabled_file_outputs = []
602607

603608
# If no ensemble is loaded, try to get the best performing model
@@ -901,18 +906,15 @@ def run_traditional_ml(
901906
learning algorithm runs over the time limit.
902907
"""
903908
assert self._logger is not None # for mypy compliancy
904-
if STRING_TO_TASK_TYPES[self.task_type] in REGRESSION_TASKS:
905-
self._logger.warning("Traditional Pipeline is not enabled for regression. Skipping...")
906-
else:
907-
traditional_task_name = 'runTraditional'
908-
self._stopwatch.start_task(traditional_task_name)
909-
elapsed_time = self._stopwatch.wall_elapsed(current_task_name)
910-
time_for_traditional = int(runtime_limit - elapsed_time)
911-
self._do_traditional_prediction(
912-
func_eval_time_limit_secs=func_eval_time_limit_secs,
913-
time_left=time_for_traditional,
914-
)
915-
self._stopwatch.stop_task(traditional_task_name)
909+
traditional_task_name = 'runTraditional'
910+
self._stopwatch.start_task(traditional_task_name)
911+
elapsed_time = self._stopwatch.wall_elapsed(current_task_name)
912+
time_for_traditional = int(runtime_limit - elapsed_time)
913+
self._do_traditional_prediction(
914+
func_eval_time_limit_secs=func_eval_time_limit_secs,
915+
time_left=time_for_traditional,
916+
)
917+
self._stopwatch.stop_task(traditional_task_name)
916918

917919
def _search(
918920
self,
@@ -1282,22 +1284,7 @@ def _search(
12821284
self._logger.info("Starting Shutdown")
12831285

12841286
if proc_ensemble is not None:
1285-
self._results_manager.ensemble_performance_history = list(proc_ensemble.history)
1286-
1287-
if len(proc_ensemble.futures) > 0:
1288-
# Also add ensemble runs that did not finish within smac time
1289-
# and add them into the ensemble history
1290-
self._logger.info("Ensemble script still running, waiting for it to finish.")
1291-
result = proc_ensemble.futures.pop().result()
1292-
if result:
1293-
ensemble_history, _, _, _ = result
1294-
self._results_manager.ensemble_performance_history.extend(ensemble_history)
1295-
self._logger.info("Ensemble script finished, continue shutdown.")
1296-
1297-
# save the ensemble performance history file
1298-
if len(self.ensemble_performance_history) > 0:
1299-
pd.DataFrame(self.ensemble_performance_history).to_json(
1300-
os.path.join(self._backend.internals_directory, 'ensemble_history.json'))
1287+
self._collect_results_ensemble(proc_ensemble)
13011288

13021289
if load_models:
13031290
self._logger.info("Loading models...")
@@ -1557,7 +1544,7 @@ def fit_pipeline(
15571544
exclude=self.exclude_components,
15581545
search_space_updates=self.search_space_updates)
15591546
dataset_properties = dataset.get_dataset_properties(dataset_requirements)
1560-
self._backend.replace_datamanager(dataset)
1547+
self._backend.save_datamanager(dataset)
15611548

15621549
if self._logger is None:
15631550
self._logger = self._get_logger(dataset.dataset_name)
@@ -1747,7 +1734,7 @@ def fit_ensemble(
17471734
ensemble_fit_task_name = 'EnsembleFit'
17481735
self._stopwatch.start_task(ensemble_fit_task_name)
17491736
if enable_traditional_pipeline:
1750-
if func_eval_time_limit_secs is None or func_eval_time_limit_secs > time_for_task:
1737+
if func_eval_time_limit_secs > time_for_task:
17511738
self._logger.warning(
17521739
'Time limit for a single run is higher than total time '
17531740
'limit. Capping the limit for a single run to the total '
@@ -1788,12 +1775,8 @@ def fit_ensemble(
17881775
)
17891776

17901777
manager.build_ensemble(self._dask_client)
1791-
future = manager.futures.pop()
1792-
result = future.result()
1793-
if result is None:
1794-
raise ValueError("Errors occurred while building the ensemble - please"
1795-
" check the log file and command line output for error messages.")
1796-
self.ensemble_performance_history, _, _, _ = result
1778+
if manager is not None:
1779+
self._collect_results_ensemble(manager)
17971780

17981781
if load_models:
17991782
self._load_models()
@@ -1871,6 +1854,31 @@ def _init_ensemble_builder(
18711854

18721855
return proc_ensemble
18731856

1857+
def _collect_results_ensemble(
1858+
self,
1859+
manager: EnsembleBuilderManager
1860+
) -> None:
1861+
1862+
if self._logger is None:
1863+
raise ValueError("logger should be initialized to fit ensemble")
1864+
1865+
self._results_manager.ensemble_performance_history = list(manager.history)
1866+
1867+
if len(manager.futures) > 0:
1868+
# Also add ensemble runs that did not finish within smac time
1869+
# and add them into the ensemble history
1870+
self._logger.info("Ensemble script still running, waiting for it to finish.")
1871+
result = manager.futures.pop().result()
1872+
if result:
1873+
ensemble_history, _, _, _ = result
1874+
self._results_manager.ensemble_performance_history.extend(ensemble_history)
1875+
self._logger.info("Ensemble script finished, continue shutdown.")
1876+
1877+
# save the ensemble performance history file
1878+
if len(self.ensemble_performance_history) > 0:
1879+
pd.DataFrame(self.ensemble_performance_history).to_json(
1880+
os.path.join(self._backend.internals_directory, 'ensemble_history.json'))
1881+
18741882
def predict(
18751883
self,
18761884
X_test: np.ndarray,

autoPyTorch/api/tabular_classification.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -371,6 +371,18 @@ def search(
371371
self
372372
373373
"""
374+
if dataset_name is None:
375+
dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
376+
377+
# we have to create a logger for at this point for the validator
378+
self._logger = self._get_logger(dataset_name)
379+
380+
# Create a validator object to make sure that the data provided by
381+
# the user matches the autopytorch requirements
382+
self.InputValidator = TabularInputValidator(
383+
is_classification=True,
384+
logger_port=self._logger_port,
385+
)
374386

375387
self.dataset, self.InputValidator = self._get_dataset_input_validator(
376388
X_train=X_train,
@@ -389,9 +401,9 @@ def search(
389401
'(CrossValTypes, HoldoutValTypes), but got {}'.format(self.resampling_strategy)
390402
)
391403

392-
393404
if self.dataset is None:
394405
raise ValueError("`dataset` in {} must be initialized, but got None".format(self.__class__.__name__))
406+
395407
return self._search(
396408
dataset=self.dataset,
397409
optimize_metric=optimize_metric,
@@ -431,23 +443,23 @@ def predict(
431443
raise ValueError("predict() is only supported after calling search. Kindly call first "
432444
"the estimator search() method.")
433445

434-
X_test = self.input_validator.feature_validator.transform(X_test)
446+
X_test = self.InputValidator.feature_validator.transform(X_test)
435447
predicted_probabilities = super().predict(X_test, batch_size=batch_size,
436448
n_jobs=n_jobs)
437449

438-
if self.input_validator.target_validator.is_single_column_target():
450+
if self.InputValidator.target_validator.is_single_column_target():
439451
predicted_indexes = np.argmax(predicted_probabilities, axis=1)
440452
else:
441453
predicted_indexes = (predicted_probabilities > 0.5).astype(int)
442454

443455
# Allow to predict in the original domain -- that is, the user is not interested
444456
# in our encoded values
445-
return self.input_validator.target_validator.inverse_transform(predicted_indexes)
457+
return self.InputValidator.target_validator.inverse_transform(predicted_indexes)
446458

447459
def predict_proba(self,
448460
X_test: Union[np.ndarray, pd.DataFrame, List],
449461
batch_size: Optional[int] = None, n_jobs: int = 1) -> np.ndarray:
450-
if self.input_validator is None or not self.input_validator._is_fitted:
462+
if self.InputValidator is None or not self.InputValidator._is_fitted:
451463
raise ValueError("predict() is only supported after calling search. Kindly call first "
452464
"the estimator search() method.")
453465
X_test = self.InputValidator.feature_validator.transform(X_test)

autoPyTorch/api/tabular_regression.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -84,9 +84,9 @@ def __init__(
8484
delete_output_folder_after_terminate: bool = True,
8585
include_components: Optional[Dict] = None,
8686
exclude_components: Optional[Dict] = None,
87-
resampling_strategy:Union[CrossValTypes,
88-
HoldoutValTypes,
89-
NoResamplingStrategyTypes] = HoldoutValTypes.holdout_validation,
87+
resampling_strategy: Union[CrossValTypes,
88+
HoldoutValTypes,
89+
NoResamplingStrategyTypes] = HoldoutValTypes.holdout_validation,
9090
resampling_strategy_args: Optional[Dict[str, Any]] = None,
9191
backend: Optional[Backend] = None,
9292
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
@@ -386,9 +386,9 @@ def search(
386386
'(CrossValTypes, HoldoutValTypes), but got {}'.format(self.resampling_strategy)
387387
)
388388

389-
390389
if self.dataset is None:
391390
raise ValueError("`dataset` in {} must be initialized, but got None".format(self.__class__.__name__))
391+
392392
return self._search(
393393
dataset=self.dataset,
394394
optimize_metric=optimize_metric,
@@ -414,14 +414,14 @@ def predict(
414414
batch_size: Optional[int] = None,
415415
n_jobs: int = 1
416416
) -> np.ndarray:
417-
if self.input_validator is None or not self.input_validator._is_fitted:
417+
if self.InputValidator is None or not self.InputValidator._is_fitted:
418418
raise ValueError("predict() is only supported after calling search. Kindly call first "
419419
"the estimator search() method.")
420420

421-
X_test = self.input_validator.feature_validator.transform(X_test)
421+
X_test = self.InputValidator.feature_validator.transform(X_test)
422422
predicted_values = super().predict(X_test, batch_size=batch_size,
423423
n_jobs=n_jobs)
424424

425425
# Allow to predict in the original domain -- that is, the user is not interested
426426
# in our encoded values
427-
return self.input_validator.target_validator.inverse_transform(predicted_values)
427+
return self.InputValidator.target_validator.inverse_transform(predicted_values)

autoPyTorch/data/base_target_validator.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,6 @@ def fit(
9898
np.shape(y_test)
9999
))
100100
if isinstance(y_train, pd.DataFrame):
101-
y_train = cast(pd.DataFrame, y_train)
102101
y_test = cast(pd.DataFrame, y_test)
103102
if y_train.columns.tolist() != y_test.columns.tolist():
104103
raise ValueError(

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import functools
2-
from typing import Dict, List, Optional, Tuple, Union, cast
2+
from typing import Dict, List, Optional, Tuple, Type, Union, cast
33

44
import numpy as np
55

@@ -263,7 +263,7 @@ def transform(
263263
X = self.numpy_to_pandas(X)
264264

265265
if hasattr(X, "iloc") and not scipy.sparse.issparse(X):
266-
X = cast(pd.DataFrame, X)
266+
X = cast(Type[pd.DataFrame], X)
267267

268268
# Check the data here so we catch problems on new test data
269269
self._check_data(X)
@@ -391,9 +391,6 @@ def _get_columns_info(
391391
Type of each column numerical/categorical
392392
"""
393393

394-
if len(self.transformed_columns) > 0 and self.feat_type is not None:
395-
return self.transformed_columns, self.feat_type
396-
397394
# Register if a column needs encoding
398395
numerical_columns = []
399396
categorical_columns = []

autoPyTorch/data/tabular_target_validator.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional, Union, cast
1+
from typing import List, Optional, cast
22

33
import numpy as np
44

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,10 @@ class NoResamplingStrategyTypes(IntEnum):
9292
RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes]
9393

9494

95-
DEFAULT_RESAMPLING_PARAMETERS: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]] = {
95+
DEFAULT_RESAMPLING_PARAMETERS: Dict[Union[CrossValTypes,
96+
HoldoutValTypes,
97+
NoResamplingStrategyTypes],
98+
Dict[str, Any]] = {
9699
HoldoutValTypes.holdout_validation: {
97100
'val_share': 0.33,
98101
},
@@ -117,7 +120,7 @@ class NoResamplingStrategyTypes(IntEnum):
117120
NoResamplingStrategyTypes.shuffle_no_resampling: {
118121
'shuffle': True
119122
}
120-
} # type: Dict[Union[HoldoutValTypes, CrossValTypes, NoResamplingStrategyTypes], Dict[str, Any]]
123+
}
121124

122125

123126
class HoldOutFuncs():

autoPyTorch/evaluation/fit_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,13 +10,13 @@
1010

1111
from smac.tae import StatusType
1212

13+
from autoPyTorch.automl_common.common.utils.backend import Backend
1314
from autoPyTorch.datasets.resampling_strategy import NoResamplingStrategyTypes
1415
from autoPyTorch.evaluation.abstract_evaluator import (
1516
AbstractEvaluator,
1617
fit_and_suppress_warnings
1718
)
1819
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
19-
from autoPyTorch.utils.backend import Backend
2020
from autoPyTorch.utils.common import subsampler
2121
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
2222

autoPyTorch/evaluation/tae.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424

2525
import autoPyTorch.evaluation.fit_evaluator
2626
import autoPyTorch.evaluation.train_evaluator
27+
from autoPyTorch.automl_common.common.utils.backend import Backend
2728
from autoPyTorch.datasets.resampling_strategy import (
2829
CrossValTypes,
2930
HoldoutValTypes,

autoPyTorch/optimizer/smbo.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
CrossValTypes,
2424
DEFAULT_RESAMPLING_PARAMETERS,
2525
HoldoutValTypes,
26+
NoResamplingStrategyTypes
2627
)
2728
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
2829
from autoPyTorch.evaluation.tae import ExecuteTaFuncWithQueue, get_cost_of_crash
@@ -98,11 +99,13 @@ def __init__(self,
9899
pipeline_config: Dict[str, Any],
99100
start_num_run: int = 1,
100101
seed: int = 1,
101-
resampling_strategy: Union[HoldoutValTypes, CrossValTypes] = HoldoutValTypes.holdout_validation,
102+
resampling_strategy: Union[HoldoutValTypes,
103+
CrossValTypes,
104+
NoResamplingStrategyTypes] = HoldoutValTypes.holdout_validation,
102105
resampling_strategy_args: Optional[Dict[str, Any]] = None,
103106
include: Optional[Dict[str, Any]] = None,
104107
exclude: Optional[Dict[str, Any]] = None,
105-
disable_file_output: List = [],
108+
disable_file_output: Union[bool, List[str]] = False,
106109
smac_scenario_args: Optional[Dict[str, Any]] = None,
107110
get_smac_object_callback: Optional[Callable] = None,
108111
all_supported_metrics: bool = True,
@@ -245,6 +248,10 @@ def __init__(self,
245248
if portfolio_selection is not None:
246249
self.initial_configurations = read_return_initial_configurations(config_space=config_space,
247250
portfolio_selection=portfolio_selection)
251+
if len(self.initial_configurations) == 0:
252+
self.initial_configurations = None
253+
self.logger.warning("None of the portfolio configurations are compatible"
254+
" with the current search space. Skipping initial configuration...")
248255

249256
def reset_data_manager(self) -> None:
250257
if self.datamanager is not None:

0 commit comments

Comments
 (0)