Skip to content

Commit fb86e23

Browse files
committed
[FIX] Enable preprocessing in reg_cocktails (#369)
* enable preprocessing and remove is_small_preprocess
* address comments from shuhei and fix precommit checks
* fix tests
* fix precommit checks
* add suggestions from shuhei for astype use
* address speed issue when using object_dtype_mapping
* make code more readable
* improve documentation for base network embedding
1 parent 596b21d commit fb86e23

34 files changed

+269
-771
lines changed

autoPyTorch/api/tabular_classification.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
1919
from autoPyTorch.datasets.resampling_strategy import (
2020
HoldoutValTypes,
21-
CrossValTypes,
2221
ResamplingStrategies,
2322
)
2423
from autoPyTorch.datasets.tabular_dataset import TabularDataset
@@ -433,8 +432,13 @@ def search(
433432
y_test=y_test,
434433
resampling_strategy=self.resampling_strategy,
435434
resampling_strategy_args=self.resampling_strategy_args,
435+
<<<<<<< HEAD
436436
dataset_name=dataset_name,
437437
dataset_compression=self._dataset_compression)
438+
=======
439+
dataset_name=dataset_name
440+
)
441+
>>>>>>> [FIX] Enable preprocessing in reg_cocktails (#369)
438442

439443
return self._search(
440444
dataset=self.dataset,

autoPyTorch/api/tabular_regression.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
1919
from autoPyTorch.datasets.resampling_strategy import (
2020
HoldoutValTypes,
21-
CrossValTypes,
2221
ResamplingStrategies,
2322
)
2423
from autoPyTorch.datasets.tabular_dataset import TabularDataset
@@ -433,8 +432,13 @@ def search(
433432
y_test=y_test,
434433
resampling_strategy=self.resampling_strategy,
435434
resampling_strategy_args=self.resampling_strategy_args,
435+
<<<<<<< HEAD
436436
dataset_name=dataset_name,
437437
dataset_compression=self._dataset_compression)
438+
=======
439+
dataset_name=dataset_name
440+
)
441+
>>>>>>> [FIX] Enable preprocessing in reg_cocktails (#369)
438442

439443
return self._search(
440444
dataset=self.dataset,

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 80 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sklearn.exceptions import NotFittedError
1717
from sklearn.impute import SimpleImputer
1818
from sklearn.pipeline import make_pipeline
19-
from sklearn.preprocessing import OneHotEncoder, StandardScaler
19+
from sklearn.preprocessing import OrdinalEncoder
2020

2121
from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SupportedFeatTypes
2222
from autoPyTorch.utils.common import ispandas
@@ -25,7 +25,6 @@
2525

2626
def _create_column_transformer(
2727
preprocessors: Dict[str, List[BaseEstimator]],
28-
numerical_columns: List[str],
2928
categorical_columns: List[str],
3029
) -> ColumnTransformer:
3130
"""
@@ -36,49 +35,36 @@ def _create_column_transformer(
3635
Args:
3736
preprocessors (Dict[str, List[BaseEstimator]]):
3837
Dictionary containing list of numerical and categorical preprocessors.
39-
numerical_columns (List[str]):
40-
List of names of numerical columns
4138
categorical_columns (List[str]):
4239
List of names of categorical columns
4340
4441
Returns:
4542
ColumnTransformer
4643
"""
4744

48-
numerical_pipeline = 'drop'
49-
categorical_pipeline = 'drop'
50-
if len(numerical_columns) > 0:
51-
numerical_pipeline = make_pipeline(*preprocessors['numerical'])
52-
if len(categorical_columns) > 0:
53-
categorical_pipeline = make_pipeline(*preprocessors['categorical'])
45+
categorical_pipeline = make_pipeline(*preprocessors['categorical'])
5446

5547
return ColumnTransformer([
56-
('categorical_pipeline', categorical_pipeline, categorical_columns),
57-
('numerical_pipeline', numerical_pipeline, numerical_columns)],
58-
remainder='drop'
48+
('categorical_pipeline', categorical_pipeline, categorical_columns)],
49+
remainder='passthrough'
5950
)
6051

6152

6253
def get_tabular_preprocessors() -> Dict[str, List[BaseEstimator]]:
6354
"""
6455
This function creates a Dictionary containing a list
6556
of numerical and categorical preprocessors
66-
6757
Returns:
6858
Dict[str, List[BaseEstimator]]
6959
"""
7060
preprocessors: Dict[str, List[BaseEstimator]] = dict()
7161

7262
# Categorical Preprocessors
73-
onehot_encoder = OneHotEncoder(categories='auto', sparse=False, handle_unknown='ignore')
63+
ordinal_encoder = OrdinalEncoder(handle_unknown='use_encoded_value',
64+
unknown_value=-1)
7465
categorical_imputer = SimpleImputer(strategy='constant', copy=False)
7566

76-
# Numerical Preprocessors
77-
numerical_imputer = SimpleImputer(strategy='median', copy=False)
78-
standard_scaler = StandardScaler(with_mean=True, with_std=True, copy=False)
79-
80-
preprocessors['categorical'] = [categorical_imputer, onehot_encoder]
81-
preprocessors['numerical'] = [numerical_imputer, standard_scaler]
67+
preprocessors['categorical'] = [categorical_imputer, ordinal_encoder]
8268

8369
return preprocessors
8470

@@ -170,31 +156,47 @@ def _fit(
170156
if ispandas(X) and not issparse(X):
171157
X = cast(pd.DataFrame, X)
172158

173-
self.all_nan_columns = set([column for column in X.columns if X[column].isna().all()])
159+
all_nan_columns = X.columns[X.isna().all()]
160+
for col in all_nan_columns:
161+
X[col] = pd.to_numeric(X[col])
162+
163+
# Handle objects if possible
164+
exist_object_columns = has_object_columns(X.dtypes.values)
165+
if exist_object_columns:
166+
X = self.infer_objects(X)
174167

175-
categorical_columns, numerical_columns, feat_type = self._get_columns_info(X)
168+
self.dtypes = [dt.name for dt in X.dtypes] # Also note this change in self.dtypes
169+
self.all_nan_columns = set(all_nan_columns)
176170

177-
self.enc_columns = categorical_columns
171+
self.enc_columns, self.feat_type = self._get_columns_info(X)
178172

179-
preprocessors = get_tabular_preprocessors()
180-
self.column_transformer = _create_column_transformer(
181-
preprocessors=preprocessors,
182-
numerical_columns=numerical_columns,
183-
categorical_columns=categorical_columns,
184-
)
173+
if len(self.enc_columns) > 0:
185174

186-
# Mypy redefinition
187-
assert self.column_transformer is not None
188-
self.column_transformer.fit(X)
175+
preprocessors = get_tabular_preprocessors()
176+
self.column_transformer = _create_column_transformer(
177+
preprocessors=preprocessors,
178+
categorical_columns=self.enc_columns,
179+
)
189180

190-
# The column transformer reorders the feature types
191-
# therefore, we need to change the order of columns as well
192-
# This means categorical columns are shifted to the left
181+
# Mypy redefinition
182+
assert self.column_transformer is not None
183+
self.column_transformer.fit(X)
193184

194-
self.feat_type = sorted(
195-
feat_type,
196-
key=functools.cmp_to_key(self._comparator)
197-
)
185+
# The column transformer moves categorical columns before all numerical columns
186+
# therefore, we need to sort the feature types so that they comply with this change
187+
188+
self.feat_type = sorted(
189+
self.feat_type,
190+
key=functools.cmp_to_key(self._comparator)
191+
)
192+
193+
encoded_categories = self.column_transformer.\
194+
named_transformers_['categorical_pipeline'].\
195+
named_steps['ordinalencoder'].categories_
196+
self.categories = [
197+
list(range(len(cat)))
198+
for cat in encoded_categories
199+
]
198200

199201
# differently to categorical_columns and numerical_columns,
200202
# this saves the index of the column.
@@ -274,6 +276,23 @@ def transform(
274276
if ispandas(X) and not issparse(X):
275277
X = cast(pd.DataFrame, X)
276278

279+
if self.all_nan_columns is None:
280+
raise ValueError('_fit must be called before calling transform')
281+
282+
for col in list(self.all_nan_columns):
283+
X[col] = np.nan
284+
X[col] = pd.to_numeric(X[col])
285+
286+
if len(self.categorical_columns) > 0:
287+
# when some categorical columns are not all nan in the training set
288+
# but they are all nan in the testing or validation set
289+
# we change those columns to `object` dtype
290+
# to ensure that these columns are changed to appropriate dtype
291+
# in self.infer_objects
292+
all_nan_cat_cols = set(X[self.enc_columns].columns[X[self.enc_columns].isna().all()])
293+
dtype_dict = {col: 'object' for col in self.enc_columns if col in all_nan_cat_cols}
294+
X = X.astype(dtype_dict)
295+
277296
# Check the data here so we catch problems on new test data
278297
self._check_data(X)
279298

@@ -282,11 +301,6 @@ def transform(
282301
# We need to convert the column in test data to
283302
# object otherwise the test column is interpreted as float
284303
if self.column_transformer is not None:
285-
if len(self.categorical_columns) > 0:
286-
categorical_columns = self.column_transformer.transformers_[0][-1]
287-
for column in categorical_columns:
288-
if X[column].isna().all():
289-
X[column] = X[column].astype('object')
290304
X = self.column_transformer.transform(X)
291305

292306
# Sparse related transformations
@@ -371,7 +385,6 @@ def _check_data(
371385
self.column_order = column_order
372386

373387
dtypes = [dtype.name for dtype in X.dtypes]
374-
375388
diff_cols = X.columns[[s_dtype != dtype for s_dtype, dtype in zip(self.dtypes, dtypes)]]
376389
if len(self.dtypes) == 0:
377390
self.dtypes = dtypes
@@ -383,7 +396,7 @@ def _check_data(
383396
def _get_columns_info(
384397
self,
385398
X: pd.DataFrame,
386-
) -> Tuple[List[str], List[str], List[str]]:
399+
) -> Tuple[List[str], List[str]]:
387400
"""
388401
Return the columns to be encoded from a pandas dataframe
389402
@@ -402,15 +415,12 @@ def _get_columns_info(
402415
"""
403416

404417
# Register if a column needs encoding
405-
numerical_columns = []
406418
categorical_columns = []
407419
# Also, register the feature types for the estimator
408420
feat_type = []
409421

410422
# Make sure each column is a valid type
411423
for i, column in enumerate(X.columns):
412-
if self.all_nan_columns is not None and column in self.all_nan_columns:
413-
continue
414424
column_dtype = self.dtypes[i]
415425
err_msg = "Valid types are `numerical`, `categorical` or `boolean`, " \
416426
"but input column {} has an invalid type `{}`.".format(column, column_dtype)
@@ -421,7 +431,6 @@ def _get_columns_info(
421431
# TypeError: data type not understood in certain pandas types
422432
elif is_numeric_dtype(column_dtype):
423433
feat_type.append('numerical')
424-
numerical_columns.append(column)
425434
elif column_dtype == 'object':
426435
# TODO verify how would this happen when we always convert the object dtypes to category
427436
raise TypeError(
@@ -447,7 +456,7 @@ def _get_columns_info(
447456
"before feeding it to AutoPyTorch.".format(err_msg)
448457
)
449458

450-
return categorical_columns, numerical_columns, feat_type
459+
return categorical_columns, feat_type
451460

452461
def list_to_pandas(
453462
self,
@@ -517,22 +526,26 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
517526
pd.DataFrame
518527
"""
519528
if hasattr(self, 'object_dtype_mapping'):
520-
# Mypy does not process the has attr. This dict is defined below
521-
for key, dtype in self.object_dtype_mapping.items(): # type: ignore[has-type]
522-
# honor the training data types
523-
try:
524-
X[key] = X[key].astype(dtype.name)
525-
except Exception as e:
526-
# Try inference if possible
527-
self.logger.warning(f'Casting the column {key} to {dtype} caused the exception {e}')
528-
pass
529+
# honor the training data types
530+
try:
531+
# Mypy does not process the has attr.
532+
X = X.astype(self.object_dtype_mapping) # type: ignore[has-type]
533+
except Exception as e:
534+
# Try inference if possible
535+
self.logger.warning(f'Casting the columns to training dtypes ' # type: ignore[has-type]
536+
f'{self.object_dtype_mapping} caused the exception {e}')
537+
pass
529538
else:
530-
# Calling for the first time to infer the categories
531-
X = X.infer_objects()
532-
for column, data_type in zip(X.columns, X.dtypes):
533-
if not is_numeric_dtype(data_type):
534-
X[column] = X[column].astype('category')
535-
539+
if len(self.dtypes) != 0:
540+
# when train data has no object dtype, but test does
541+
# we prioritise the datatype given in training data
542+
dtype_dict = {col: dtype for col, dtype in zip(X.columns, self.dtypes)}
543+
X = X.astype(dtype_dict)
544+
else:
545+
# Calling for the first time to infer the categories
546+
X = X.infer_objects()
547+
dtype_dict = {col: 'category' for col, dtype in zip(X.columns, X.dtypes) if not is_numeric_dtype(dtype)}
548+
X = X.astype(dtype_dict)
536549
# only numerical attributes and categories
537550
self.object_dtype_mapping = {column: data_type for column, data_type in zip(X.columns, X.dtypes)}
538551

autoPyTorch/datasets/base_dataset.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,6 @@ def __init__(
155155
self.holdout_validators: Dict[str, HoldOutFunc] = {}
156156
self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
157157
self.random_state = np.random.RandomState(seed=seed)
158-
self.no_resampling_validators: Dict[str, NoResamplingFunc] = {}
159158
self.shuffle = shuffle
160159
self.resampling_strategy = resampling_strategy
161160
self.resampling_strategy_args = resampling_strategy_args
@@ -165,10 +164,6 @@ def __init__(
165164
if len(self.train_tensors) == 2 and self.train_tensors[1] is not None:
166165
self.output_shape, self.output_type = _get_output_properties(self.train_tensors)
167166

168-
# TODO: Look for a criteria to define small enough to preprocess
169-
# False for the regularization cocktails initially
170-
self.is_small_preprocess = False
171-
172167
# Make sure cross validation splits are created once
173168
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
174169
self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,6 @@ def __call__(self, random_state: np.random.RandomState, val_share: float,
3939
...
4040

4141

42-
class NoResamplingFunc(Protocol):
43-
def __call__(self,
44-
random_state: np.random.RandomState,
45-
indices: np.ndarray) -> np.ndarray:
46-
...
47-
48-
4942
class CrossValTypes(IntEnum):
5043
"""The type of cross validation
5144

0 commit comments

Comments
 (0)