Skip to content

Commit 45a7043

Browse files
committed
fixes after rebase
1 parent fb86e23 commit 45a7043

File tree

4 files changed

+14
-24
lines changed

4 files changed

+14
-24
lines changed

autoPyTorch/api/tabular_classification.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -432,13 +432,8 @@ def search(
432432
y_test=y_test,
433433
resampling_strategy=self.resampling_strategy,
434434
resampling_strategy_args=self.resampling_strategy_args,
435-
<<<<<<< HEAD
436435
dataset_name=dataset_name,
437436
dataset_compression=self._dataset_compression)
438-
=======
439-
dataset_name=dataset_name
440-
)
441-
>>>>>>> [FIX] Enable preprocessing in reg_cocktails (#369)
442437

443438
return self._search(
444439
dataset=self.dataset,
@@ -479,23 +474,23 @@ def predict(
479474
raise ValueError("predict() is only supported after calling search. Kindly call first "
480475
"the estimator search() method.")
481476

482-
X_test = self.InputValidator.feature_validator.transform(X_test)
477+
X_test = self.input_validator.feature_validator.transform(X_test)
483478
predicted_probabilities = super().predict(X_test, batch_size=batch_size,
484479
n_jobs=n_jobs)
485480

486-
if self.InputValidator.target_validator.is_single_column_target():
481+
if self.input_validator.target_validator.is_single_column_target():
487482
predicted_indexes = np.argmax(predicted_probabilities, axis=1)
488483
else:
489484
predicted_indexes = (predicted_probabilities > 0.5).astype(int)
490485

491486
# Allow to predict in the original domain -- that is, the user is not interested
492487
# in our encoded values
493-
return self.InputValidator.target_validator.inverse_transform(predicted_indexes)
488+
return self.input_validator.target_validator.inverse_transform(predicted_indexes)
494489

495490
def predict_proba(self,
496491
X_test: Union[np.ndarray, pd.DataFrame, List],
497492
batch_size: Optional[int] = None, n_jobs: int = 1) -> np.ndarray:
498-
if self.InputValidator is None or not self.InputValidator._is_fitted:
493+
if self.input_validator is None or not self.input_validator._is_fitted:
499494
raise ValueError("predict() is only supported after calling search. Kindly call first "
500495
"the estimator search() method.")
501496
X_test = self.input_validator.feature_validator.transform(X_test)

autoPyTorch/api/tabular_regression.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -432,13 +432,8 @@ def search(
432432
y_test=y_test,
433433
resampling_strategy=self.resampling_strategy,
434434
resampling_strategy_args=self.resampling_strategy_args,
435-
<<<<<<< HEAD
436435
dataset_name=dataset_name,
437436
dataset_compression=self._dataset_compression)
438-
=======
439-
dataset_name=dataset_name
440-
)
441-
>>>>>>> [FIX] Enable preprocessing in reg_cocktails (#369)
442437

443438
return self._search(
444439
dataset=self.dataset,
@@ -465,14 +460,14 @@ def predict(
465460
batch_size: Optional[int] = None,
466461
n_jobs: int = 1
467462
) -> np.ndarray:
468-
if self.InputValidator is None or not self.InputValidator._is_fitted:
463+
if self.input_validator is None or not self.input_validator._is_fitted:
469464
raise ValueError("predict() is only supported after calling search. Kindly call first "
470465
"the estimator search() method.")
471466

472-
X_test = self.InputValidator.feature_validator.transform(X_test)
467+
X_test = self.input_validator.feature_validator.transform(X_test)
473468
predicted_values = super().predict(X_test, batch_size=batch_size,
474469
n_jobs=n_jobs)
475470

476471
# Allow to predict in the original domain -- that is, the user is not interested
477472
# in our encoded values
478-
return self.InputValidator.target_validator.inverse_transform(predicted_values)
473+
return self.input_validator.target_validator.inverse_transform(predicted_values)

autoPyTorch/data/base_feature_validator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,13 @@ def _fit(
113113

114114
def _check_data(
115115
self,
116-
X: SUPPORTED_FEAT_TYPES,
116+
X: SupportedFeatTypes,
117117
) -> None:
118118
"""
119119
Feature dimensionality and data type checks
120120
121121
Args:
122-
X (SUPPORTED_FEAT_TYPES):
122+
X (SupportedFeatTypes):
123123
A set of features that are going to be validated (type and dimensionality
124124
checks) and a encoder fitted in the case the data needs encoding
125125
"""
@@ -145,19 +145,19 @@ def transform(
145145

146146
def list_to_pandas(
147147
self,
148-
X_train: SUPPORTED_FEAT_TYPES,
149-
X_test: Optional[SUPPORTED_FEAT_TYPES] = None,
148+
X_train: SupportedFeatTypes,
149+
X_test: Optional[SupportedFeatTypes] = None,
150150
) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
151151
"""
152152
Converts a list to a pandas DataFrame. In this process, column types are inferred.
153153
154154
If test data is provided, we proactively match it to train data
155155
156156
Args:
157-
X_train (SUPPORTED_FEAT_TYPES):
157+
X_train (SupportedFeatTypes):
158158
A set of features that are going to be validated (type and dimensionality
159159
checks) and a encoder fitted in the case the data needs encoding
160-
X_test (Optional[SUPPORTED_FEAT_TYPES]):
160+
X_test (Optional[SupportedFeatTypes]):
161161
A hold out set of data used for checking
162162
Returns:
163163
pd.DataFrame:

autoPyTorch/data/tabular_target_validator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional, cast
1+
from typing import List, Optional, Union, cast
22

33
import numpy as np
44

0 commit comments

Comments
 (0)