
Commit 8e15eec

Reg cocktails apt1.0+reg cocktails pytorch embedding reduced (#454)
* reduce number of hyperparameters for pytorch embedding
* remove todos for the preprocessing PR, and apply suggestion from code review
* remove unwanted exclude in test
1 parent d58dd9d commit 8e15eec

File tree (5 files changed: +55 −37 lines)

autoPyTorch/api/base_task.py
autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py
test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py
test/test_pipeline/components/setup/test_setup_networks.py
test/test_pipeline/test_tabular_regression.py


autoPyTorch/api/base_task.py

Lines changed: 4 additions & 0 deletions
@@ -273,6 +273,10 @@ def build_pipeline(
     ) -> BasePipeline:
         """
         Build pipeline according to current task
+        and for the passed dataset properties
+
+        Args:
+            dataset_properties (Dict[str, Any]):
                 Characteristics of the dataset to guide the pipeline
                 choices of components
             include_components (Optional[Dict[str, Any]]):

autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py

Lines changed: 46 additions & 24 deletions
@@ -4,6 +4,8 @@
 from ConfigSpace.configuration_space import ConfigurationSpace
 from ConfigSpace.hyperparameters import (
     UniformFloatHyperparameter,
+    UniformIntegerHyperparameter,
+    CategoricalHyperparameter
 )

 import numpy as np
@@ -16,6 +18,34 @@
 from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter


+def get_num_output_dimensions(config: Dict[str, Any], num_categs_per_feature: List[int]) -> List[int]:
+    """
+    Returns a list of embedding sizes for each categorical variable.
+    Selects these adaptively based on the training dataset.
+    Note: Assumes there is at least one embedded feature.
+    Args:
+        config (Dict[str, Any]):
+            contains the hyperparameters required to calculate the `num_output_dimensions`
+        num_categs_per_feature (List[int]):
+            list containing the number of categories for each feature that is to be embedded,
+            0 if the column is not an embed column
+    Returns:
+        List[int]:
+            list containing the output embedding size for each column,
+            1 if the column is not an embed column
+    """
+
+    max_embedding_dim = config['max_embedding_dim']
+    embed_exponent = config['embed_exponent']
+    size_factor = config['embedding_size_factor']
+    num_output_dimensions = [int(size_factor * max(
+        2,
+        min(max_embedding_dim,
+            1.6 * num_categories**embed_exponent)))
+        if num_categories > 0 else 1 for num_categories in num_categs_per_feature]
+    return num_output_dimensions
+
+
 class _LearnedEntityEmbedding(nn.Module):
     """ Learned entity embedding module for categorical features"""

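To make the new sizing rule concrete: a minimal sketch (not part of the commit) that reproduces the arithmetic of get_num_output_dimensions with the default hyperparameter values; the category counts below are invented for illustration.

config = {'max_embedding_dim': 100, 'embed_exponent': 0.56, 'embedding_size_factor': 1.0}

def embed_size(num_categories: int) -> int:
    if num_categories == 0:  # not an embed column
        return 1
    return int(config['embedding_size_factor'] * max(
        2, min(config['max_embedding_dim'],
               1.6 * num_categories ** config['embed_exponent'])))

print([embed_size(n) for n in (0, 3, 50, 10000)])  # -> [1, 2, 14, 100]

The cap only bites for very high-cardinality columns (10000 categories would otherwise get 278 dimensions), while small columns are floored at 2 dimensions.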

@@ -35,9 +65,7 @@ def __init__(self, config: Dict[str, Any], num_categories_per_col: np.ndarray, n

         self.num_embed_features = self.num_categories_per_col[self.embed_features]

-        self.num_output_dimensions = [1] * num_features_excl_embed
-        self.num_output_dimensions.extend([ceil(config["dimension_reduction_" + str(i)] * num_in) for i, num_in in
-                                           enumerate(self.num_embed_features)])
+        self.num_output_dimensions = get_num_output_dimensions(config, self.num_categories_per_col)

         self.num_out_feats = num_features_excl_embed + sum(self.num_output_dimensions)

@@ -78,12 +106,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # before passing it through the model
         concat_seq = []

-        x_pointer = 0
         layer_pointer = 0
         for x_pointer, embed in enumerate(self.embed_features):
             current_feature_slice = x[:, x_pointer]
             if not embed:
-                x_pointer += 1
                 concat_seq.append(current_feature_slice.view(-1, 1))
                 continue
             current_feature_slice = current_feature_slice.to(torch.int)
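Aside (not in the diff): the removed x_pointer lines were dead bookkeeping, since enumerate reassigns x_pointer on every iteration and the manual increment never carried over. A self-contained sketch of the surviving pattern; the toy tensor and embed mask are invented for illustration.

import torch

x = torch.tensor([[0.0, 2.0, 1.5]])    # one row, three feature columns
embed_features = [False, True, False]  # only column 1 is embedded

for x_pointer, embed in enumerate(embed_features):
    current_feature_slice = x[:, x_pointer]
    if not embed:
        # enumerate already advances x_pointer; no manual `x_pointer += 1` needed
        print(x_pointer, 'passthrough', current_feature_slice.view(-1, 1))
        continue
    print(x_pointer, 'embed', current_feature_slice.to(torch.int))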
@@ -153,28 +179,24 @@ def build_embedding(self, num_categories_per_col: np.ndarray, num_features_excl_
     @staticmethod
     def get_hyperparameter_search_space(
         dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None,
-        dimension_reduction: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="dimension_reduction",
-                                                                                   value_range=(0, 1),
-                                                                                   default_value=0.5),
+        embed_exponent: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="embed_exponent",
+                                                                              value_range=(0.56,),
+                                                                              default_value=0.56),
+        max_embedding_dim: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="max_embedding_dim",
+                                                                                 value_range=(100,),
+                                                                                 default_value=100),
+        embedding_size_factor: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="embedding_size_factor",
+                                                                                     value_range=(0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5),
+                                                                                     default_value=1,
+                                                                                     ),
     ) -> ConfigurationSpace:
         cs = ConfigurationSpace()
         if dataset_properties is not None:
-            for i in range(len(dataset_properties['categorical_columns'])
-                           if isinstance(dataset_properties['categorical_columns'], List) else 0):
-                # currently as we dont have information about the embedding columns
-                # we search for more dimensions than necessary. This can be solved by
-                # not having `min_unique_values_for_embedding` as a hyperparameter and
-                # instead passing it as a parameter to the feature validator, which
-                # allows us to pass embed_columns to the dataset properties.
-                # TODO: test the trade off
-                # Another solution is to combine `OneHotEncoding`, `Embedding` and `NoEncoding`
-                # in one custom transformer. this will also allow users to use this transformer
-                # outside the pipeline
-                ee_dimensions_search_space = HyperparameterSearchSpace(hyperparameter="dimension_reduction_" + str(i),
-                                                                       value_range=dimension_reduction.value_range,
-                                                                       default_value=dimension_reduction.default_value,
-                                                                       log=dimension_reduction.log)
-                add_hyperparameter(cs, ee_dimensions_search_space, UniformFloatHyperparameter)
+            if len(dataset_properties['categorical_columns']) > 0:
+                add_hyperparameter(cs, embed_exponent, UniformFloatHyperparameter)
+                add_hyperparameter(cs, max_embedding_dim, UniformIntegerHyperparameter)
+                add_hyperparameter(cs, embedding_size_factor, CategoricalHyperparameter)
+
         return cs

     @staticmethod
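To see the reduced space end to end, a hedged sketch: it assumes the component class in this file is importable as below and that a dict with a non-empty 'categorical_columns' entry is enough for get_hyperparameter_search_space.

# Assumed import path, following this file's location in the repo
from autoPyTorch.pipeline.components.setup.network_embedding.LearnedEntityEmbedding import (
    LearnedEntityEmbedding
)

# made-up minimal dataset properties; two categorical columns
cs = LearnedEntityEmbedding.get_hyperparameter_search_space(
    dataset_properties={'categorical_columns': [0, 1]}
)
print(cs)  # expected: embed_exponent, embedding_size_factor, max_embedding_dim

Compared with the removed per-column dimension_reduction_i hyperparameters, the space now has a fixed size of three, regardless of how many categorical columns the dataset has.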

test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py

Lines changed: 0 additions & 2 deletions
@@ -13,8 +13,6 @@
 )


-# TODO: fix in preprocessing PR
-# @pytest.mark.skip("Skipping tests as preprocessing is not finalised")
 @pytest.mark.parametrize("fit_dictionary_tabular", ['classification_numerical_only',
                                                     'classification_categorical_only',
                                                     'classification_numerical_and_categorical'], indirect=True)

test/test_pipeline/components/setup/test_setup_networks.py

Lines changed: 0 additions & 1 deletion
@@ -19,7 +19,6 @@ def head(request):
     return request.param


-# TODO: add 'LearnedEntityEmbedding' after preprocessing dix
 @pytest.fixture(params=['NoEmbedding', 'LearnedEntityEmbedding'])
 def embedding(request):
     return request.param

test/test_pipeline/test_tabular_regression.py

Lines changed: 5 additions & 10 deletions
@@ -61,11 +61,9 @@ def test_pipeline_fit(self, fit_dictionary_tabular):
         """This test makes sure that the pipeline is able to fit
         given random combinations of hyperparameters across the pipeline"""
         # TODO: fix issue where adversarial also works for regression
-        # TODO: Fix issue with learned entity embedding after preprocessing PR
         pipeline = TabularRegressionPipeline(
             dataset_properties=fit_dictionary_tabular['dataset_properties'],
-            exclude={'trainer': ['AdversarialTrainer'],
-                     'network_embedding': ['LearnedEntityEmbedding']})
+            exclude={'trainer': ['AdversarialTrainer']})
         cs = pipeline.get_hyperparameter_search_space()

         config = cs.sample_configuration()
@@ -91,8 +89,7 @@ def test_pipeline_predict(self, fit_dictionary_tabular):
         X = fit_dictionary_tabular['X_train'].copy()
         pipeline = TabularRegressionPipeline(
             dataset_properties=fit_dictionary_tabular['dataset_properties'],
-            exclude={'trainer': ['AdversarialTrainer'],
-                     'network_embedding': ['LearnedEntityEmbedding']})
+            exclude={'trainer': ['AdversarialTrainer']})

         cs = pipeline.get_hyperparameter_search_space()
         config = cs.sample_configuration()
@@ -121,8 +118,7 @@ def test_pipeline_transform(self, fit_dictionary_tabular):

         pipeline = TabularRegressionPipeline(
             dataset_properties=fit_dictionary_tabular['dataset_properties'],
-            exclude={'trainer': ['AdversarialTrainer'],
-                     'network_embedding': ['LearnedEntityEmbedding']})
+            exclude={'trainer': ['AdversarialTrainer']})
         cs = pipeline.get_hyperparameter_search_space()
         config = cs.sample_configuration()
         pipeline.set_hyperparameters(config)
@@ -139,11 +135,10 @@ def test_pipeline_transform(self, fit_dictionary_tabular):
         assert fit_dictionary_tabular.items() <= transformed_fit_dictionary_tabular.items()

         # Then the pipeline should have added the following keys
-        # Removing 'imputer', 'encoder', 'scaler', these will be
-        # TODO: added back after a PR fixing preprocessing
         expected_keys = {'tabular_transformer', 'preprocess_transforms', 'network',
                          'optimizer', 'lr_scheduler', 'train_data_loader',
-                         'val_data_loader', 'run_summary', 'feature_preprocessor'}
+                         'val_data_loader', 'run_summary', 'feature_preprocessor',
+                         'imputer', 'encoder', 'scaler'}
         assert expected_keys.issubset(set(transformed_fit_dictionary_tabular.keys()))

         # Then we need to have transformations being created.
