Skip to content

Commit 49d49c2

Browse files
committed
fix embeddings after rebase
1 parent 8e15eec commit 49d49c2

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __init__(self, config: Dict[str, Any], num_categories_per_col: np.ndarray, n
6262
# or 0 for numerical data
6363
self.num_categories_per_col = num_categories_per_col
6464
self.embed_features = self.num_categories_per_col > 0
65+
self.num_features_excl_embed = num_features_excl_embed
6566

6667
self.num_embed_features = self.num_categories_per_col[self.embed_features]
6768

@@ -84,8 +85,8 @@ def get_partial_models(self, subset_features: List[int]) -> "_LearnedEntityEmbed
8485
partial_model (_LearnedEntityEmbedding)
8586
a new partial model
8687
"""
87-
num_input_features = self.num_input_features[subset_features]
88-
num_numerical_features = sum([sf < self.num_numerical for sf in subset_features])
88+
num_input_features = self.num_categories_per_col[subset_features]
89+
num_features_excl_embed = sum([sf < self.num_features_excl_embed for sf in subset_features])
8990

9091
num_output_dimensions = [self.num_output_dimensions[sf] for sf in subset_features]
9192
embed_features = [self.embed_features[sf] for sf in subset_features]
@@ -98,7 +99,7 @@ def get_partial_models(self, subset_features: List[int]) -> "_LearnedEntityEmbed
9899
ee_layer_tracker += 1
99100
ee_layers = nn.ModuleList(ee_layers)
100101

101-
return PartialLearnedEntityEmbedding(num_input_features, num_numerical_features, embed_features,
102+
return PartialLearnedEntityEmbedding(num_input_features, num_features_excl_embed, embed_features,
102103
num_output_dimensions, ee_layers)
103104

104105
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -136,28 +137,27 @@ class PartialLearnedEntityEmbedding(_LearnedEntityEmbedding):
136137
of the input features. This is applied to forecasting tasks where not all the features might be known beforehand
137138
"""
138139
def __init__(self,
139-
num_input_features: np.ndarray,
140-
num_numerical_features: int,
140+
num_categories_per_col: np.ndarray,
141+
num_features_excl_embed: int,
141142
embed_features: List[bool],
142143
num_output_dimensions: List[int],
143144
ee_layers: nn.Module
144145
):
145146
super(_LearnedEntityEmbedding, self).__init__()
146-
self.num_numerical = num_numerical_features
147+
self.num_features_excl_embed = num_features_excl_embed
147148
# list of number of categories of categorical data
148149
# or 0 for numerical data
149-
self.num_input_features = num_input_features
150-
categorical_features: np.ndarray = self.num_input_features > 0
151-
152-
self.num_categorical_features = self.num_input_features[categorical_features]
150+
self.num_categories_per_col = num_categories_per_col
153151

154152
self.embed_features = embed_features
155153

156154
self.num_output_dimensions = num_output_dimensions
157-
self.num_out_feats = self.num_numerical + sum(self.num_output_dimensions)
155+
self.num_out_feats = self.num_features_excl_embed + sum(self.num_output_dimensions)
158156

159157
self.ee_layers = ee_layers
160158

159+
self.num_embed_features = self.num_categories_per_col[self.embed_features]
160+
161161

162162
class LearnedEntityEmbedding(NetworkEmbeddingComponent):
163163
"""

0 commit comments

Comments
 (0)