-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import numpy as np

    TASK_TYPES_TO_STRING,
)
from autoPyTorch.data.tabular_validator import TabularInputValidator
+from autoPyTorch.data.utils import (
+    get_dataset_compression_mapping
+)
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.datasets.resampling_strategy import (
    HoldoutValTypes,
@@ -163,6 +166,7 @@ def _get_dataset_input_validator(
        resampling_strategy: Optional[ResamplingStrategies] = None,
        resampling_strategy_args: Optional[Dict[str, Any]] = None,
        dataset_name: Optional[str] = None,
+        dataset_compression: Optional[Mapping[str, Any]] = None,
    ) -> Tuple[TabularDataset, TabularInputValidator]:
        """
        Returns an object of `TabularDataset` and an object of
@@ -199,26 +203,27 @@ def _get_dataset_input_validator(

        # Create a validator object to make sure that the data provided by
        # the user matches the autopytorch requirements
-        InputValidator = TabularInputValidator(
+        input_validator = TabularInputValidator(
            is_classification=True,
            logger_port=self._logger_port,
+            dataset_compression=dataset_compression
        )

        # Fit an input validator to check the provided data
        # Also, an encoder is fit to both train and test data,
        # to prevent unseen categories during inference
-        InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
+        input_validator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)

        dataset = TabularDataset(
            X=X_train, Y=y_train,
            X_test=X_test, Y_test=y_test,
-            validator=InputValidator,
+            validator=input_validator,
            resampling_strategy=resampling_strategy,
            resampling_strategy_args=resampling_strategy_args,
            dataset_name=dataset_name
        )

-        return dataset, InputValidator
+        return dataset, input_validator

    def search(
        self,
@@ -234,14 +239,15 @@ def search(
        total_walltime_limit: int = 100,
        func_eval_time_limit_secs: Optional[int] = None,
        enable_traditional_pipeline: bool = True,
-        memory_limit: Optional[int] = 4096,
+        memory_limit: int = 4096,
        smac_scenario_args: Optional[Dict[str, Any]] = None,
        get_smac_object_callback: Optional[Callable] = None,
        all_supported_metrics: bool = True,
        precision: int = 32,
        disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
        load_models: bool = True,
        portfolio_selection: Optional[str] = None,
+        dataset_compression: Union[Mapping[str, Any], bool] = False,
    ) -> 'BaseTask':
        """
        Search for the best pipeline configuration for the given dataset.
@@ -310,7 +316,7 @@ def search(
                feature by turning this flag to False. All machine learning
                algorithms that are fitted during search() are considered for
                ensemble building.
-            memory_limit (Optional[int]: default=4096):
+            memory_limit (int: default=4096):
                Memory limit in MB for the machine learning algorithm.
                Autopytorch will stop fitting the machine learning algorithm
                if it tries to allocate more than memory_limit MB. If None
@@ -368,20 +374,52 @@ def search(
                Additionally, the keyword 'greedy' is supported,
                which would use the default portfolio from
                `AutoPyTorch Tabular <https://arxiv.org/abs/2006.13799>`_.
+            dataset_compression: Union[bool, Mapping[str, Any]] = False
+                We compress datasets so that they fit into some predefined amount of memory.
+                **NOTE**
+
+                Default configuration when set to ``True``:
+                .. code-block:: python
+                    {
+                        "memory_allocation": 0.1,
+                        "methods": ["precision"]
+                    }
+                You can also pass your own configuration with the same keys, choosing
+                from the available ``"methods"``.
+                The available options are described here:
+                **memory_allocation**
+                    By default, we attempt to fit the dataset into ``0.1 * memory_limit``. This
+                    float value can be set with ``"memory_allocation": 0.1``. We also allow for
+                    specifying absolute memory in MB, e.g. 10MB is ``"memory_allocation": 10``.
+                    The memory used by the dataset is checked after each reduction method is
+                    performed. If the dataset fits into the allocated memory, any further methods
+                    listed in ``"methods"`` will not be performed.
+
+                **methods**
+                    We currently provide the following methods for reducing the dataset size.
+                    These can be provided in a list and are performed in the order given.
+                    * ``"precision"`` - We reduce floating point precision as follows:
+                        * ``np.float128 -> np.float64``
+                        * ``np.float96 -> np.float64``
+                        * ``np.float64 -> np.float32``
+                        * pandas dataframes are reduced using the downcast option of `pd.to_numeric`
+                          to the lowest possible precision.

        Returns:
            self

        """
+        self._dataset_compression = get_dataset_compression_mapping(memory_limit, dataset_compression)

-        self.dataset, self.InputValidator = self._get_dataset_input_validator(
+        self.dataset, self.input_validator = self._get_dataset_input_validator(
            X_train=X_train,
            y_train=y_train,
            X_test=X_test,
            y_test=y_test,
            resampling_strategy=self.resampling_strategy,
            resampling_strategy_args=self.resampling_strategy_args,
-            dataset_name=dataset_name)
+            dataset_name=dataset_name,
+            dataset_compression=self._dataset_compression)

        return self._search(
            dataset=self.dataset,
@@ -418,28 +456,28 @@ def predict(
        Returns:
            Array with estimator predictions.
        """
-        if self.InputValidator is None or not self.InputValidator._is_fitted:
+        if self.input_validator is None or not self.input_validator._is_fitted:
            raise ValueError("predict() is only supported after calling search. Kindly call first "
                             "the estimator search() method.")

-        X_test = self.InputValidator.feature_validator.transform(X_test)
+        X_test = self.input_validator.feature_validator.transform(X_test)
        predicted_probabilities = super().predict(X_test, batch_size=batch_size,
                                                  n_jobs=n_jobs)

-        if self.InputValidator.target_validator.is_single_column_target():
+        if self.input_validator.target_validator.is_single_column_target():
            predicted_indexes = np.argmax(predicted_probabilities, axis=1)
        else:
            predicted_indexes = (predicted_probabilities > 0.5).astype(int)

        # Allow to predict in the original domain -- that is, the user is not interested
        # in our encoded values
-        return self.InputValidator.target_validator.inverse_transform(predicted_indexes)
+        return self.input_validator.target_validator.inverse_transform(predicted_indexes)

    def predict_proba(self,
                      X_test: Union[np.ndarray, pd.DataFrame, List],
                      batch_size: Optional[int] = None, n_jobs: int = 1) -> np.ndarray:
-        if self.InputValidator is None or not self.InputValidator._is_fitted:
+        if self.input_validator is None or not self.input_validator._is_fitted:
            raise ValueError("predict() is only supported after calling search. Kindly call first "
                             "the estimator search() method.")
-        X_test = self.InputValidator.feature_validator.transform(X_test)
+        X_test = self.input_validator.feature_validator.transform(X_test)
        return super().predict(X_test, batch_size=batch_size, n_jobs=n_jobs)
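
For reference, a minimal usage sketch of the new ``dataset_compression`` argument. It assumes autoPyTorch's public ``TabularClassificationTask`` entry point; the data and the remaining ``search()`` arguments are illustrative only:

    import numpy as np
    from autoPyTorch.api.tabular_classification import TabularClassificationTask

    X = np.random.rand(1000, 20)            # float64 features, a candidate for precision reduction
    y = np.random.randint(0, 2, size=1000)  # binary targets

    api = TabularClassificationTask()
    api.search(
        X_train=X,
        y_train=y,
        optimize_metric='accuracy',
        total_walltime_limit=300,
        memory_limit=4096,
        # Try to fit the dataset into 0.1 * memory_limit, reducing float precision first.
        dataset_compression={"memory_allocation": 0.1, "methods": ["precision"]},
    )

Passing ``True`` instead of a mapping selects the default configuration shown in the docstring; ``False`` (the default) disables compression.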
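The ``"precision"`` method described in the docstring boils down to dtype downcasting. The helper below is a hypothetical simplification for illustration, not the library's implementation (the real reduction is driven by the mapping built by ``get_dataset_compression_mapping`` and handed to ``TabularInputValidator``):

    import numpy as np
    import pandas as pd

    def downcast_precision(data):
        # Hypothetical helper mirroring the documented "precision" method.
        if isinstance(data, np.ndarray) and data.dtype == np.float64:
            # np.float128/np.float96 inputs would first step down to np.float64.
            return data.astype(np.float32)
        if isinstance(data, pd.DataFrame):
            # Downcast each float column to the lowest possible precision.
            return data.apply(lambda col: pd.to_numeric(col, downcast="float")
                              if pd.api.types.is_float_dtype(col) else col)
        return data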