Commit 1d9ddc0

embedding compatible for time series tasks
1 parent c9f1ca7 commit 1d9ddc0

File tree

3 files changed: +38 -24 lines changed

autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py

Lines changed: 26 additions & 19 deletions

@@ -72,34 +72,45 @@ def __init__(self, config: Dict[str, Any], num_categories_per_col: np.ndarray, n
         self.ee_layers = self._create_ee_layers()
 
-    def get_partial_models(self, subset_features: List[int]) -> "_LearnedEntityEmbedding":
+    def get_partial_models(self,
+                           n_excl_embed_features: int,
+                           idx_embed_feat_partial: List[int]) -> "_LearnedEntityEmbedding":
         """
         Extract a partial model that only works on a subset of the data that ought to be passed to the embedding
         network. This function is implemented for time series forecasting tasks, where the known future features are
         only a subset of the past features.
 
         Args:
-            subset_features (List[int]):
-                a set of indices identifying which features will pass through the partial model
+            n_excl_embed_features (int):
+                number of unembedded features
+            idx_embed_feat_partial (List[int]):
+                a set of indices identifying which embedded features will be inherited by the partial model
 
         Returns:
             partial_model (_LearnedEntityEmbedding)
                 a new partial model
         """
-        num_input_features = self.num_categories_per_col[subset_features]
-        num_features_excl_embed = sum([sf < self.num_features_excl_embed for sf in subset_features])
+        n_partial_features = n_excl_embed_features + len(idx_embed_feat_partial)
 
-        num_output_dimensions = [self.num_output_dimensions[sf] for sf in subset_features]
-        embed_features = [self.embed_features[sf] for sf in subset_features]
+        num_categories_per_col = np.zeros(n_partial_features, dtype=np.int16)
+        num_output_dimensions = [1] * n_partial_features
 
         ee_layers = []
-        ee_layer_tracker = 0
-        for sf in subset_features:
-            if self.embed_features[sf]:
-                ee_layers.append(self.ee_layers[ee_layer_tracker])
-                ee_layer_tracker += 1
+        for idx, idx_embed in enumerate(idx_embed_feat_partial):
+            idx_raw = self.num_features_excl_embed + idx_embed
+            n_embed = self.num_categories_per_col[idx_raw]
+            n_output = self.num_output_dimensions[idx_raw]
+
+            idx_new = n_excl_embed_features + idx
+            num_categories_per_col[idx_new] = n_embed
+            num_output_dimensions[idx_new] = n_output
+
+            ee_layers.append(self.ee_layers[idx_embed])
 
         ee_layers = nn.ModuleList(ee_layers)
 
-        return PartialLearnedEntityEmbedding(num_input_features, num_features_excl_embed, embed_features,
+        embed_features = num_categories_per_col > 0
+
+        return PartialLearnedEntityEmbedding(num_categories_per_col, n_excl_embed_features, embed_features,
                                              num_output_dimensions, ee_layers)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:

@@ -108,10 +119,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         concat_seq = []
 
         layer_pointer = 0
-        # Time series tasks need to add targets to the embeddings. However, the target information is not recorded
-        # by autoPyTorch's embeddings. Therefore, we need to add the targets parts to `concat_seq` manually, which is
-        # the last few dimensions of the input x
-        # we assign x_pointer to 0 beforehand to avoid the case that self.embed_features has 0 length
+        # the embedding network is only applied to the last few feature columns, as flagged by self.embed_features
         x_pointer = 0
         for x_pointer, embed in enumerate(self.embed_features):
             if not embed:

@@ -121,9 +129,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             current_feature_slice = x[..., x_pointer]
             current_feature_slice = current_feature_slice.to(torch.int)
             concat_seq.append(self.ee_layers[layer_pointer](current_feature_slice))
-            layer_pointer += 1
 
-        concat_seq.append(x[..., x_pointer:])
+            layer_pointer += 1
 
         return torch.cat(concat_seq, dim=-1)
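
The reworked get_partial_models no longer slices an arbitrary subset_features list: it assumes the partial model keeps n_excl_embed_features unembedded columns up front and inherits the already-trained embedding layers of the selected embedded features, while the reworked forward appends each column inside the loop instead of concatenating a trailing slice. A minimal usage sketch follows; the cardinalities are hypothetical and `config` is assumed to come from the pipeline's embedding hyperparameter configuration, neither is taken from this commit.

import numpy as np

# Hypothetical setup: 3 unembedded (numerical) columns followed by 4 embedded
# categorical columns with the cardinalities below; columns with cardinality 0
# bypass the embedding network.
full_embedding = _LearnedEntityEmbedding(
    config=config,  # assumed: embedding hyperparameter config from the search space
    num_categories_per_col=np.array([0, 0, 0, 5, 12, 7, 3]),
    num_features_excl_embed=3,
)

# Future-known features: 1 unembedded column plus embedded features 0 and 2
# (the columns with cardinalities 5 and 7). The partial model reuses their
# already-fitted embedding layers.
future_embedding = full_embedding.get_partial_models(
    n_excl_embed_features=1,
    idx_embed_feat_partial=[0, 2],
)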

autoPyTorch/pipeline/components/setup/network_embedding/NoEmbedding.py

Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from ConfigSpace.configuration_space import ConfigurationSpace

@@ -12,7 +12,7 @@
 
 class _NoEmbedding(nn.Module):
-    def get_partial_models(self, subset_features: List[int]) -> "_NoEmbedding":
+    def get_partial_models(self, **kwargs: Any) -> "_NoEmbedding":
         return self
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
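
Switching _NoEmbedding.get_partial_models to **kwargs keeps the two embedding variants call-compatible under the new signature: a caller can request a partial model without checking which module it holds. A sketch of such a call site, with a hypothetical `embedding` variable:

# Works whether `embedding` is a _NoEmbedding (returns itself, ignoring the
# arguments) or a _LearnedEntityEmbedding (builds a true partial model).
partial_embedding = embedding.get_partial_models(
    n_excl_embed_features=1,
    idx_embed_feat_partial=[0, 2],
)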

autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py

Lines changed: 10 additions & 3 deletions

@@ -20,6 +20,7 @@ def __init__(self, random_state: Optional[np.random.RandomState] = None):
         self.embedding: Optional[nn.Module] = None
         self.random_state = random_state
         self.feature_shapes: Dict[str, int] = {}
+        self.embed_features_idx: Optional[Tuple] = None
 
     def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:

@@ -30,22 +31,28 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
             num_features_excl_embed=num_features_excl_embed
         )
         if "feature_shapes" in X['dataset_properties']:
+            n_features_embedded = len(num_categories_per_col) - num_features_excl_embed
             if num_output_features is not None:
                 feature_shapes = X['dataset_properties']['feature_shapes']
                 # forecasting tasks
                 feature_names = X['dataset_properties']['feature_names']
-                for idx_cat, n_output_cat in enumerate(num_output_features[num_features_excl_embed:]):
-                    cat_feature_name = feature_names[idx_cat + num_features_excl_embed]
-                    feature_shapes[cat_feature_name] = n_output_cat
+                n_features_all = len(feature_names)
+                embed_features_idx = tuple(range(n_features_all - n_features_embedded, n_features_all))
+                for idx, n_output_embedded in zip(embed_features_idx, num_output_features[-n_features_embedded:]):
+                    feat_name = feature_names[idx]
+                    feature_shapes[feat_name] = n_output_embedded
+                self.embed_features_idx = embed_features_idx
                 self.feature_shapes = feature_shapes
             else:
                 self.feature_shapes = X['dataset_properties']['feature_shapes']
+                self.embed_features_idx = []
         return self
 
     def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
         X.update({'network_embedding': self.embedding})
         if "feature_shapes" in X['dataset_properties']:
             X['dataset_properties'].update({"feature_shapes": self.feature_shapes})
+            X['embed_features_idx'] = self.embed_features_idx
         return X
 
     def build_embedding(self,
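
With these changes, fit records which trailing feature columns are embedded and transform publishes that index tuple alongside the fitted embedding, so downstream forecasting components can map embedded columns back to feature names. A rough sketch of what a later pipeline step might read; `embedding_component` is a hypothetical name for the fitted embedding setup component:

# `X` is the pipeline's fit dictionary after the embedding component ran.
X = embedding_component.transform(X)

network_embedding = X['network_embedding']    # the fitted embedding nn.Module
embed_features_idx = X['embed_features_idx']  # e.g. (3, 4, 5, 6): trailing embedded columns
feature_shapes = X['dataset_properties']['feature_shapes']  # feature name -> output width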
