Skip to content

Commit 5f9713a

Browse files
committed
adapt embedding for forecasting tasks
1 parent 9d62c2b commit 5f9713a

File tree

9 files changed

+103
-47
lines changed

9 files changed

+103
-47
lines changed

autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/TimeSeriesTransformer.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = N
2424
self.add_fit_requirements([
2525
FitRequirement('numerical_features', (List,), user_defined=True, dataset_property=True),
2626
FitRequirement('categorical_features', (List,), user_defined=True, dataset_property=True)])
27+
self.output_feature_order = None
2728

2829
def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
2930
"""
@@ -74,6 +75,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
7475
X_train = X['backend'].load_datamanager().train_tensors[0]
7576

7677
self.preprocessor.fit(X_train)
78+
self.output_feature_order = self.get_output_column_orders(len(X['dataset_properties']['feature_names']))
7779
return self
7880

7981
def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
@@ -86,7 +88,8 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
8688
Returns:
8789
X (Dict[str, Any]): updated fit dictionary
8890
"""
89-
X.update({'time_series_feature_transformer': self})
91+
X.update({'time_series_feature_transformer': self,
92+
'feature_order_after_preprocessing': self.output_feature_order})
9093
return X
9194

9295
def __call__(self, X: pd.DataFrame) -> pd.DataFrame:
@@ -108,6 +111,33 @@ def get_column_transformer(self) -> ColumnTransformer:
108111
.format(self.__class__.__name__))
109112
return self.preprocessor
110113

114+
def get_output_column_orders(self, n_input_columns: int) -> List[int]:
115+
"""
116+
get the order of the output features transformed by self.preprocessor
117+
TODO: replace this function with self.preprocessor.get_feature_names_out() when switch to sklearn 1.0 !
118+
119+
Args:
120+
n_input_columns (int): number of input columns that will be transformed
121+
122+
Returns:
123+
np.ndarray: a list of index indicating the order of each columns after transformation. Its length should
124+
equal to n_input_columns
125+
"""
126+
if self.preprocessor is None:
127+
raise ValueError("cant call {} without fitting the column transformer first."
128+
.format(self.__class__.__name__))
129+
transformers = self.preprocessor.transformers
130+
131+
n_reordered_input = np.arange(n_input_columns)
132+
processed_columns = np.asarray([], dtype=np.int)
133+
134+
for tran in transformers:
135+
trans_columns = np.array(tran[-1], dtype=np.int)
136+
unprocessed_columns = np.setdiff1d(processed_columns, trans_columns)
137+
processed_columns = np.hstack([unprocessed_columns, trans_columns])
138+
unprocessed_columns = np.setdiff1d(n_reordered_input, processed_columns)
139+
return np.hstack([processed_columns, unprocessed_columns]).tolist() # type: ignore[return-value]
140+
111141

112142
class TimeSeriesTargetTransformer(autoPyTorchTimeSeriesTargetPreprocessingComponent):
113143
def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = None):

autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/column_spliting/ColumnSplitter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ def __init__(
2525

2626
def fit(self, X: Dict[str, Any], y: Optional[Any] = None) -> 'TimeSeriesColumnSplitter':
2727
super(TimeSeriesColumnSplitter, self).fit(X, y)
28+
2829
self.num_categories_per_col_encoded = X['dataset_properties']['num_categories_per_col']
30+
for i in range(len(self.num_categories_per_col_encoded)):
31+
if i in self.special_feature_types['embed_columns']:
32+
self.num_categories_per_col_encoded[i] = 1
2933
return self
3034

3135
def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:

autoPyTorch/pipeline/components/setup/early_preprocessor/TimeSeriesEarlyPreProcessing.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,12 @@ def __init__(self, random_state: Optional[np.random.RandomState] = None) -> None
2222
FitRequirement('X_train', (pd.DataFrame, ), user_defined=True,
2323
dataset_property=False),
2424
FitRequirement('feature_names', (tuple,), user_defined=True, dataset_property=True),
25-
FitRequirement('numerical_columns', (List,), user_defined=True, dataset_property=True),
26-
FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True),
25+
FitRequirement('feature_order_after_preprocessing', (List,), user_defined=False, dataset_property=False)
2726
])
2827

2928
def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
3029
"""
3130
if dataset is small process, we transform the entire dataset here.
32-
Before transformation, the order of the dataset is:
33-
[(unknown_columns), categorical_columns, numerical_columns]
34-
While after transformation, the order of the dataset is:
35-
[numerical_columns, categorical_columns, unknown_columns]
36-
we need to change feature_names and feature_shapes accordingly
3731
3832
Args:
3933
X(Dict): fit dictionary
@@ -52,20 +46,9 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
5246
X['X_train'] = time_series_preprocess(dataset=X_train, transforms=transforms)
5347

5448
feature_names = X['dataset_properties']['feature_names']
55-
numerical_columns = X['dataset_properties']['numerical_columns']
56-
categorical_columns = X['dataset_properties']['categorical_columns']
57-
# encoding_columns = X['dataset_properties']['encoding_columns']
58-
encode_columns = X['encode_columns']
59-
import pdb
60-
pdb.set_trace()
61-
62-
# resort feature_names
63-
# Previously, the categorical features are sorted before numerical features. However,
64-
# After the preprocessing. The numerical features are sorted at the first place.
65-
new_feature_names = [feature_names[num_col] for num_col in numerical_columns]
66-
new_feature_names += [feature_names[cat_col] for cat_col in categorical_columns]
67-
if set(feature_names) != set(new_feature_names):
68-
new_feature_names += list(set(feature_names) - set(new_feature_names))
49+
50+
feature_order_after_preprocessing = X['feature_order_after_preprocessing']
51+
new_feature_names = (feature_names[i] for i in feature_order_after_preprocessing)
6952
X['dataset_properties']['feature_names'] = tuple(new_feature_names)
7053

7154
preprocessed_dtype = get_preprocessed_dtype(X['X_train'])

autoPyTorch/pipeline/components/setup/network/forecasting_architecture.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from abc import abstractmethod
33
from typing import Any, Dict, List, Optional, Tuple, Union
44

5+
import numpy as np
56
import torch
67
from torch import nn
78
from torch.distributions import AffineTransform, TransformedDistribution
@@ -205,6 +206,7 @@ def __init__(self,
205206
auto_regressive: bool,
206207
feature_names: Union[Tuple[str], Tuple[()]] = (),
207208
known_future_features: Union[Tuple[str], Tuple[()]] = (),
209+
embed_features_idx: Tuple[int] = (),
208210
feature_shapes: Dict[str, int] = {},
209211
static_features: Union[Tuple[str], Tuple[()]] = (),
210212
time_feature_names: Union[Tuple[str], Tuple[()]] = (),
@@ -218,7 +220,16 @@ def __init__(self,
218220
self.embedding = network_embedding
219221
if len(known_future_features) > 0:
220222
known_future_features_idx = [feature_names.index(kff) for kff in known_future_features]
221-
self.decoder_embedding = self.embedding.get_partial_models(known_future_features_idx)
223+
known_future_embed_features = np.where(
224+
np.in1d(embed_features_idx, known_future_features_idx, assume_unique=True)
225+
)[0]
226+
idx_excl_embed_future_features = np.setdiff1d(known_future_features_idx, embed_features_idx)
227+
n_excl_embed_features = sum(feature_shapes[feature_names[i]] for i in idx_excl_embed_future_features)
228+
229+
self.decoder_embedding = self.embedding.get_partial_models(
230+
n_excl_embed_features=n_excl_embed_features,
231+
idx_embed_feat_partial=known_future_embed_features
232+
)
222233
else:
223234
self.decoder_embedding = _NoEmbedding()
224235
# modules that generate tensors while doing forward pass
@@ -558,7 +569,7 @@ def pre_processing(self,
558569
return x_past, x_future, x_static, loc, scale, static_context_initial_hidden, past_targets
559570
else:
560571
if past_features is not None:
561-
x_past = torch.cat([truncated_past_targets, past_features], dim=-1).to(device=self.device)
572+
x_past = torch.cat([past_features, truncated_past_targets], dim=-1).to(device=self.device)
562573
x_past = self.embedding(x_past.to(device=self.device))
563574
else:
564575
x_past = self.embedding(truncated_past_targets.to(device=self.device))
@@ -615,8 +626,8 @@ def forward(self,
615626
return self.rescale_output(output, loc, scale, self.device)
616627

617628
def _unwrap_past_targets(
618-
self,
619-
past_targets: dict
629+
self,
630+
past_targets: dict
620631
) -> Tuple[torch.Tensor,
621632
Optional[torch.Tensor],
622633
Optional[torch.Tensor],

autoPyTorch/pipeline/components/setup/network/forecasting_network.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
FitRequirement("auto_regressive", (bool,), user_defined=False, dataset_property=False),
4545
FitRequirement("target_scaler", (BaseTargetScaler,), user_defined=False, dataset_property=False),
4646
FitRequirement("net_output_type", (str,), user_defined=False, dataset_property=False),
47+
FitRequirement('embed_features_idx', (tuple,), user_defined=False, dataset_property=False),
4748
FitRequirement("feature_names", (Iterable,), user_defined=False, dataset_property=True),
4849
FitRequirement("feature_shapes", (Iterable,), user_defined=False, dataset_property=True),
4950
FitRequirement('transform_time_features', (bool,), user_defined=False, dataset_property=False),
@@ -85,6 +86,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchTrainingComponent:
8586
feature_names=feature_names,
8687
feature_shapes=feature_shapes,
8788
known_future_features=known_future_features,
89+
embed_features_idx=X['embed_features_idx'],
8890
time_feature_names=time_feature_names,
8991
static_features=X['dataset_properties']['static_features']
9092
)

autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,37 @@ def __init__(self, config: Dict[str, Any], num_categories_per_col: np.ndarray, n
7272

7373
self.ee_layers = self._create_ee_layers()
7474

75+
def insert_new_input_features(self, n_new_features: int):
76+
"""
77+
Time series tasks need to add targets to the embeddings. However, the target information is not recorded
78+
by autoPyTorch's embeddings. Therefore, we need to add the targets to the input features manually, which is
79+
located in front of the features
80+
81+
Args:
82+
n_new_features (int):
83+
number of new features that is inserted in front of the input features
84+
"""
85+
self.num_categories_per_col = np.hstack([np.zeros(n_new_features, dtype=np.int16), self.num_categories_per_col])
86+
self.embed_features = np.hstack([np.zeros(n_new_features, dtype=np.bool), self.num_categories_per_col])
87+
88+
self.num_features_excl_embed += n_new_features
89+
self.num_output_dimensions = [1] * n_new_features + self.num_output_dimensions
90+
self.num_out_feats += n_new_features
91+
7592
def get_partial_models(self,
7693
n_excl_embed_features: int,
7794
idx_embed_feat_partial: List[int]) -> "_LearnedEntityEmbedding":
7895
"""
7996
extract a partial models that only works on a subset of the data that ought to be passed to the embedding
8097
network, this function is implemented for time series forecasting tasks where the known future features is only
8198
a subset of the past features
99+
82100
Args:
83101
n_excl_embed_features (int):
84102
number of unembedded features
85103
idx_embed_feat_partial (List[int]):
86-
a set of index identifying the which embedding features will be inherited by the partial model
104+
a set of index identifying the which embedding features will be inherited by the partial model. This
105+
index is used to extract self.ee_layers
87106
88107
Returns:
89108
partial_model (_LearnedEntityEmbedding)
@@ -119,11 +138,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
119138
concat_seq = []
120139

121140
layer_pointer = 0
122-
# Time series tasks need to add targets to the embeddings. However, the target information is not recorded
123-
# by autoPyTorch's embeddings. Therefore, we need to add the targets parts to `concat_seq` manually, which is
124-
# the last few dimensions of the input x
125-
# we assign x_pointer to 0 beforehand to avoid the case that self.embed_features has 0 length
126141
x_pointer = 0
142+
# For forcasting architectures,besides the input features, we might also need to feed targets and time features
143+
# to the embedding layers, which are not counted by self.embed_features.
127144
for x_pointer, embed in enumerate(self.embed_features):
128145
if not embed:
129146
current_feature_slice = x[..., [x_pointer]]
@@ -134,6 +151,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
134151
concat_seq.append(self.ee_layers[layer_pointer](current_feature_slice))
135152

136153
layer_pointer += 1
154+
concat_seq.append(x[..., x_pointer + 1:])
137155

138156
return torch.cat(concat_seq, dim=-1)
139157

autoPyTorch/pipeline/components/setup/network_embedding/NoEmbedding.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212

1313

1414
class _NoEmbedding(nn.Module):
15-
def get_partial_models(self, **kwargs: Any) -> "_NoEmbedding":
15+
def get_partial_models(self, *args, **kwargs) -> "_NoEmbedding":
16+
return self
17+
18+
def insert_new_input_features(self, *args, **kwargs) -> "_NoEmbedding":
1619
return self
1720

1821
def forward(self, x: torch.Tensor) -> torch.Tensor:

autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
3737
# forecasting tasks
3838
feature_names = X['dataset_properties']['feature_names']
3939
n_features_all = len(feature_names)
40+
# embedded feature index
4041
embed_features_idx = tuple(range(n_features_all - n_features_embedded, n_features_all))
4142
for idx, n_output_embedded in zip(embed_features_idx, num_output_features[-n_features_embedded:]):
4243
feat_name = feature_names[idx]

test/test_pipeline/components/setup/forecasting/forecasting_networks/test_forecasting_architecture.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,22 @@
2424
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdate
2525

2626

27-
class ReducedEmbedding(torch.nn.Module):
27+
class IncrementalEmbedding(torch.nn.Module):
2828
# a dummy reduced embedding, it simply cut row for each categorical features
29-
def __init__(self, num_input_features, num_numerical_features: int):
30-
super(ReducedEmbedding, self).__init__()
31-
self.num_input_features = num_input_features
32-
self.num_numerical_features = num_numerical_features
33-
self.n_cat_features = len(num_input_features) - num_numerical_features
29+
def __init__(self, n_excl_embed_features, embed_feat_idx):
30+
super(IncrementalEmbedding, self).__init__()
31+
self.n_excl_embed_features = n_excl_embed_features
32+
self.embed_feat_idx = embed_feat_idx
3433

3534
def forward(self, x):
36-
x = x[..., :-self.n_cat_features]
35+
if len(self.embed_feat_idx) > 0:
36+
x = torch.cat([x, x[..., -len(self.embed_feat_idx):]], dim=-1)
3737
return x
3838

39-
def get_partial_models(self, subset_features):
40-
num_numerical_features = sum([sf < self.num_numerical_features for sf in subset_features])
41-
num_input_features = [self.num_input_features[sf] for sf in subset_features]
42-
return ReducedEmbedding(num_input_features, num_numerical_features)
39+
def get_partial_models(self, n_excl_embed_features, idx_embed_feat_partial):
40+
n_excl_embed_features = n_excl_embed_features
41+
embed_feat_idx = [self.embed_feat_idx[idx] for idx in idx_embed_feat_partial]
42+
return IncrementalEmbedding(n_excl_embed_features, embed_feat_idx)
4343

4444

4545
@pytest.fixture(params=['ForecastingNet', 'ForecastingSeq2SeqNet', 'ForecastingDeepARNet', 'NBEATSNet'])
@@ -52,7 +52,7 @@ def network_encoder(request):
5252
return request.param
5353

5454

55-
@pytest.fixture(params=['ReducedEmbedding', 'NoEmbedding'])
55+
@pytest.fixture(params=['IncrementalEmbedding', 'NoEmbedding'])
5656
def embedding(request):
5757
return request.param
5858

@@ -110,7 +110,7 @@ def test_network_forward(self,
110110
dataset_properties['known_future_features'] = ('f1', 'f3', 'f5')
111111

112112
if with_static_features:
113-
dataset_properties['static_features'] = (0, 4)
113+
dataset_properties['static_features'] = (0, 3)
114114
else:
115115
dataset_properties['static_features'] = tuple()
116116

@@ -130,10 +130,14 @@ def test_network_forward(self,
130130
fit_dictionary['net_output_type'] = net_output_type
131131

132132
if embedding == 'NoEmbedding':
133+
embed_features_idx = ()
133134
fit_dictionary['network_embedding'] = _NoEmbedding()
135+
fit_dictionary['embed_features_idx'] = embed_features_idx
134136
else:
135-
fit_dictionary['network_embedding'] = ReducedEmbedding([10] * 5, 2)
136-
dataset_properties['feature_shapes'] = {'f1': 10, 'f2': 10, 'f3': 9, 'f4': 9, 'f5': 9}
137+
embed_features_idx = (3, 4)
138+
fit_dictionary['network_embedding'] = IncrementalEmbedding(50, embed_features_idx)
139+
fit_dictionary['embed_features_idx'] = embed_features_idx
140+
dataset_properties['feature_shapes'] = {'f1': 10, 'f2': 10, 'f3': 10, 'f4': 11, 'f5': 11}
137141

138142
if uni_variant_data:
139143
fit_dictionary['X_train'] = None

0 commit comments

Comments
 (0)