Skip to content

Commit d58dd9d

Browse files
committed
TODO: fix errors after rebase
1 parent d4717fb commit d58dd9d

File tree

5 files changed

+17
-6
lines changed

5 files changed

+17
-6
lines changed

autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
1212
from autoPyTorch.pipeline.components.setup.base_setup import autoPyTorchSetupComponent
13-
from autoPyTorch.pipeline.components.setup.early_preprocessor.utils import get_preprocess_transforms, preprocess
13+
from autoPyTorch.pipeline.components.setup.early_preprocessor.utils import get_preprocess_transforms, get_preprocessed_dtype, preprocess
1414
from autoPyTorch.utils.common import FitRequirement
1515

1616

@@ -39,11 +39,13 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
3939

4040
X['X_train'] = preprocess(dataset=X_train, transforms=transforms)
4141

42+
preprocessed_dtype = get_preprocessed_dtype(X['X_train'])
43+
4244
# We need to also save the preprocess transforms for inference
4345
X.update({
4446
'preprocess_transforms': transforms,
4547
'shape_after_preprocessing': X['X_train'].shape[1:],
46-
'preprocessed_dtype': X['X_train'].dtype.name
48+
'preprocessed_dtype': preprocessed_dtype
4749
})
4850
return X
4951

autoPyTorch/pipeline/components/setup/early_preprocessor/TimeSeriesEarlyPreProcessing.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from autoPyTorch.pipeline.components.setup.early_preprocessor.EarlyPreprocessing import \
1111
EarlyPreprocessing
1212
from autoPyTorch.pipeline.components.setup.early_preprocessor.utils import (
13-
get_preprocess_transforms, time_series_preprocess)
13+
get_preprocess_transforms, get_preprocessed_dtype, time_series_preprocess)
1414
from autoPyTorch.utils.common import FitRequirement
1515

1616

@@ -62,11 +62,12 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
6262
new_feature_names += list(set(feature_names) - set(new_feature_names))
6363
X['dataset_properties']['feature_names'] = tuple(new_feature_names)
6464

65+
preprocessed_dtype = get_preprocessed_dtype(X['X_train'])
6566
# We need to also save the preprocess transforms for inference
6667
X.update({
6768
'preprocess_transforms': transforms,
6869
'shape_after_preprocessing': X['X_train'].shape[1:],
69-
'preprocessed_dtype': X['X_train'].dtype.name
70+
'preprocessed_dtype': preprocessed_dtype
7071
})
7172
return X
7273

autoPyTorch/pipeline/components/setup/early_preprocessor/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
autoPyTorchPreprocessingComponent as aPTPre,
1414
autoPyTorchTargetPreprocessingComponent as aPTTPre
1515
)
16+
from .....utils.common import ispandas
1617

1718

1819
def get_preprocess_transforms(X: Dict[str, Any],
@@ -71,3 +72,10 @@ def time_series_preprocess(dataset: pd.DataFrame, transforms: torchvision.transf
7172
sub_dataset = composite_transforms(sub_dataset)
7273
dataset.iloc[:, indices] = sub_dataset
7374
return dataset
75+
76+
77+
def get_preprocessed_dtype(X_train: Union[np.ndarray, pd.DataFrame]) -> Union[str, None]:
    """Return the dtype name of the preprocessed training data.

    Args:
        X_train: Preprocessed training features, either a numpy array or a
            pandas DataFrame.

    Returns:
        The dtype name, e.g. ``'float64'``. For a DataFrame the dtype of the
        first column is reported (NOTE(review): this assumes the columns share
        one dtype after preprocessing -- confirm against the callers).
        ``None`` is returned for a DataFrame without columns.
    """
    if ispandas(X_train):
        column_dtypes = X_train.dtypes
        # ``X_train.dtypes[X_train.columns]`` is itself a Series, so its
        # ``.name`` is the Series label (``None``), not a dtype name. Read the
        # dtype object of a single column instead.
        if len(column_dtypes) == 0:
            return None
        return column_dtypes.iloc[0].name
    return X_train.dtype.name

autoPyTorch/pipeline/components/training/trainer/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic
453453
if preprocessed_dtype is None:
454454
use_double = True
455455
else:
456-
use_double = 'float64' in preprocessed_dtype
456+
use_double = 'float64' in preprocessed_dtype or 'int64' in preprocessed_dtype
457457

458458
# update batch norm statistics
459459
swa_model = self.choice.swa_model.double() if use_double else self.choice.swa_model

test/test_pipeline/test_time_series_forecasting_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class TestTimeSeriesForecastingPipeline:
4646
"multi_variant_only_num"], indirect=True)
4747
def test_fit_predict(self, fit_dictionary_forecasting, forecasting_budgets):
4848
dataset_properties = fit_dictionary_forecasting['dataset_properties']
49-
if not dataset_properties['uni_variant'] and len(dataset_properties['categories']) > 0:
49+
if not dataset_properties['uni_variant'] and len(dataset_properties['num_categories_per_col']) > 0:
5050
include = {'network_embedding': ['LearnedEntityEmbedding']}
5151
else:
5252
include = None

0 commit comments

Comments
 (0)