Skip to content

Commit 484ead4

Browse files
committed
fixes after rebase
1 parent 9c28f3a commit 484ead4

File tree

5 files changed

+15
-34
lines changed

5 files changed

+15
-34
lines changed

autoPyTorch/api/tabular_classification.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -418,13 +418,8 @@ def search(
418418
y_test=y_test,
419419
resampling_strategy=self.resampling_strategy,
420420
resampling_strategy_args=self.resampling_strategy_args,
421-
<<<<<<< HEAD
422421
dataset_name=dataset_name,
423422
dataset_compression=self._dataset_compression)
424-
=======
425-
dataset_name=dataset_name
426-
)
427-
>>>>>>> [FIX] Enable preprocessing in reg_cocktails (#369)
428423

429424
return self._search(
430425
dataset=self.dataset,
@@ -465,23 +460,23 @@ def predict(
465460
raise ValueError("predict() is only supported after calling search. Kindly call first "
466461
"the estimator search() method.")
467462

468-
X_test = self.InputValidator.feature_validator.transform(X_test)
463+
X_test = self.input_validator.feature_validator.transform(X_test)
469464
predicted_probabilities = super().predict(X_test, batch_size=batch_size,
470465
n_jobs=n_jobs)
471466

472-
if self.InputValidator.target_validator.is_single_column_target():
467+
if self.input_validator.target_validator.is_single_column_target():
473468
predicted_indexes = np.argmax(predicted_probabilities, axis=1)
474469
else:
475470
predicted_indexes = (predicted_probabilities > 0.5).astype(int)
476471

477472
# Allow to predict in the original domain -- that is, the user is not interested
478473
# in our encoded values
479-
return self.InputValidator.target_validator.inverse_transform(predicted_indexes)
474+
return self.input_validator.target_validator.inverse_transform(predicted_indexes)
480475

481476
def predict_proba(self,
482477
X_test: Union[np.ndarray, pd.DataFrame, List],
483478
batch_size: Optional[int] = None, n_jobs: int = 1) -> np.ndarray:
484-
if self.InputValidator is None or not self.InputValidator._is_fitted:
479+
if self.input_validator is None or not self.input_validator._is_fitted:
485480
raise ValueError("predict() is only supported after calling search. Kindly call first "
486481
"the estimator search() method.")
487482
X_test = self.input_validator.feature_validator.transform(X_test)

autoPyTorch/api/tabular_regression.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -419,13 +419,8 @@ def search(
419419
y_test=y_test,
420420
resampling_strategy=self.resampling_strategy,
421421
resampling_strategy_args=self.resampling_strategy_args,
422-
<<<<<<< HEAD
423422
dataset_name=dataset_name,
424423
dataset_compression=self._dataset_compression)
425-
=======
426-
dataset_name=dataset_name
427-
)
428-
>>>>>>> [FIX] Enable preprocessing in reg_cocktails (#369)
429424

430425
return self._search(
431426
dataset=self.dataset,
@@ -452,14 +447,14 @@ def predict(
452447
batch_size: Optional[int] = None,
453448
n_jobs: int = 1
454449
) -> np.ndarray:
455-
if self.InputValidator is None or not self.InputValidator._is_fitted:
450+
if self.input_validator is None or not self.input_validator._is_fitted:
456451
raise ValueError("predict() is only supported after calling search. Kindly call first "
457452
"the estimator search() method.")
458453

459-
X_test = self.InputValidator.feature_validator.transform(X_test)
454+
X_test = self.input_validator.feature_validator.transform(X_test)
460455
predicted_values = super().predict(X_test, batch_size=batch_size,
461456
n_jobs=n_jobs)
462457

463458
# Allow to predict in the original domain -- that is, the user is not interested
464459
# in our encoded values
465-
return self.InputValidator.target_validator.inverse_transform(predicted_values)
460+
return self.input_validator.target_validator.inverse_transform(predicted_values)

autoPyTorch/data/base_feature_validator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,13 @@ def _fit(
112112

113113
def _check_data(
114114
self,
115-
X: SUPPORTED_FEAT_TYPES,
115+
X: SupportedFeatTypes,
116116
) -> None:
117117
"""
118118
Feature dimensionality and data type checks
119119
120120
Args:
121-
X (SUPPORTED_FEAT_TYPES):
121+
X (SupportedFeatTypes):
122122
A set of features that are going to be validated (type and dimensionality
123123
checks) and a encoder fitted in the case the data needs encoding
124124
"""
@@ -144,19 +144,19 @@ def transform(
144144

145145
def list_to_pandas(
146146
self,
147-
X_train: SUPPORTED_FEAT_TYPES,
148-
X_test: Optional[SUPPORTED_FEAT_TYPES] = None,
147+
X_train: SupportedFeatTypes,
148+
X_test: Optional[SupportedFeatTypes] = None,
149149
) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
150150
"""
151151
Converts a list to a pandas DataFrame. In this process, column types are inferred.
152152
153153
If test data is provided, we proactively match it to train data
154154
155155
Args:
156-
X_train (SUPPORTED_FEAT_TYPES):
156+
X_train (SupportedFeatTypes):
157157
A set of features that are going to be validated (type and dimensionality
158158
checks) and a encoder fitted in the case the data needs encoding
159-
X_test (Optional[SUPPORTED_FEAT_TYPES]):
159+
X_test (Optional[SupportedFeatTypes]):
160160
A hold out set of data used for checking
161161
Returns:
162162
pd.DataFrame:

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
import functools
2-
<<<<<<< HEAD
32
from logging import Logger
4-
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union, cast
5-
=======
6-
from typing import Dict, List, Optional, Tuple, Type, Union, cast
7-
>>>>>>> [FIX] Tests after rebase of `reg_cocktails` (#359)
3+
from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, Union, cast
84

95
import numpy as np
106

@@ -283,13 +279,8 @@ def transform(
283279
if isinstance(X, np.ndarray):
284280
X = self.numpy_to_pandas(X)
285281

286-
<<<<<<< HEAD
287282
if hasattr(X, "iloc") and not issparse(X):
288-
X = cast(pd.DataFrame, X)
289-
=======
290-
if hasattr(X, "iloc") and not scipy.sparse.issparse(X):
291283
X = cast(Type[pd.DataFrame], X)
292-
>>>>>>> [FIX] Tests after rebase of `reg_cocktails` (#359)
293284

294285
if self.all_nan_columns is None:
295286
raise ValueError('_fit must be called before calling transform')

autoPyTorch/data/tabular_target_validator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional, cast
1+
from typing import List, Optional, Union, cast
22

33
import numpy as np
44

0 commit comments

Comments
 (0)