Skip to content

Commit d4717fb

Browse files
committed
test fix in progress
1 parent d6bb8c8 commit d4717fb

File tree

6 files changed

+19
-15
lines changed

6 files changed

+19
-15
lines changed

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,9 @@ class TabularFeatureValidator(BaseFeatureValidator):
7777
transformer.
7878
7979
Attributes:
80-
categories (List[List[str]]):
81-
List for which an element at each index is a
82-
list containing the categories for the respective
83-
categorical column.
80+
num_categories_per_col (List[int]):
81+
List for which an element at each index is the number
82+
of categories for the respective categorical column.
8483
transformed_columns (List[str])
8584
List of columns that were transformed.
8685
column_transformer (Optional[BaseEstimator])

autoPyTorch/datasets/time_series_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ def __init__(self,
559559
self.num_features: int = self.validator.feature_validator.num_features # type: ignore[assignment]
560560
self.num_targets: int = self.validator.target_validator.out_dimensionality # type: ignore[assignment]
561561

562-
self.categories = self.validator.feature_validator.categories
562+
self.num_categories_per_col = self.validator.feature_validator.num_categories_per_col
563563

564564
self.feature_shapes = self.validator.feature_shapes
565565
self.feature_names = tuple(self.validator.feature_names)
@@ -1072,7 +1072,7 @@ def get_required_dataset_info(self) -> Dict[str, Any]:
10721072
'categorical_features': self.categorical_features,
10731073
'numerical_columns': self.numerical_columns,
10741074
'categorical_columns': self.categorical_columns,
1075-
'categories': self.categories,
1075+
'num_categories_per_col': self.num_categories_per_col,
10761076
})
10771077
return info
10781078

autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/column_splitting/ColumnSplitter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
class ColumnSplitter(autoPyTorchTabularPreprocessingComponent):
1818
"""
19-
Removes features that have the same value in the training data.
19+
Splits categorical columns into embed or encode columns based on a hyperparameter.
2020
"""
2121
def __init__(
2222
self,

autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/encoding/OneHotEncoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ def __init__(self,
1919
def fit(self, X: Dict[str, Any], y: Any = None) -> TimeSeriesBaseEncoder:
2020
OneHotEncoder.fit(self, X, y)
2121
categorical_columns = X['dataset_properties']['categorical_columns']
22-
n_features_cat = X['dataset_properties']['categories']
22+
num_categories_per_col = X['dataset_properties']['num_categories_per_col']
2323
feature_names = X['dataset_properties']['feature_names']
2424
feature_shapes = X['dataset_properties']['feature_shapes']
2525

26-
if len(n_features_cat) == 0:
27-
n_features_cat = self.preprocessor['categorical'].categories # type: ignore
26+
if len(num_categories_per_col) == 0:
27+
num_categories_per_col = [len(cat) for cat in self.preprocessor['categorical'].categories] # type: ignore
2828
for i, cat_column in enumerate(categorical_columns):
29-
feature_shapes[feature_names[cat_column]] = len(n_features_cat[i])
29+
feature_shapes[feature_names[cat_column]] = num_categories_per_col[i]
3030
self.feature_shapes = feature_shapes
3131
return self
3232

autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/encoding/time_series_base_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ def __init__(self) -> None:
1515
super(TimeSeriesBaseEncoder, self).__init__()
1616
self.add_fit_requirements([
1717
FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True),
18-
FitRequirement('categories', (List,), user_defined=True, dataset_property=True),
18+
FitRequirement('num_categories_per_col', (List,), user_defined=True, dataset_property=True),
1919
FitRequirement('feature_names', (tuple,), user_defined=True, dataset_property=True),
2020
FitRequirement('feature_shapes', (Dict, ), user_defined=True, dataset_property=True),
2121
])
22-
self.feature_shapes: Union[Dict[str, int]] = {}
22+
self.feature_shapes: Dict[str, int] = {}
2323

2424
def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
2525
"""

autoPyTorch/pipeline/components/training/trainer/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,13 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
447447
raise RuntimeError("Budget exhausted without finishing an epoch.")
448448

449449
if self.choice.use_stochastic_weight_averaging and self.choice.swa_updated:
450-
use_double = 'float64' in X['preprocessed_dtype']
450+
# By default, we assume the data is double. Only if the data was preprocessed,
451+
# we check the dtype and use it accordingly
452+
preprocessed_dtype = X.get('preprocessed_dtype', None)
453+
if preprocessed_dtype is None:
454+
use_double = True
455+
else:
456+
use_double = 'float64' in preprocessed_dtype
451457

452458
# update batch norm statistics
453459
swa_model = self.choice.swa_model.double() if use_double else self.choice.swa_model
@@ -458,7 +464,6 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
458464
# we update only the last network which pertains to the stochastic weight averaging model
459465
snapshot_model = self.choice.model_snapshots[-1].double() if use_double else self.choice.model_snapshots[-1]
460466
swa_utils.update_bn(X['train_data_loader'], snapshot_model)
461-
update_model_state_dict_from_swa(X['network_snapshots'][-1], self.choice.swa_model.state_dict())
462467

463468
# wrap up -- add score if not evaluating every epoch
464469
if not self.eval_valid_each_epoch(X):

0 commit comments

Comments (0)