Skip to content

Commit 9d62c2b

Browse files
committed
maint
1 parent 1d9ddc0 commit 9d62c2b

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
119119
concat_seq = []
120120

121121
layer_pointer = 0
122-
# Given that our embedding network is only applied to the last few feature columns self.embed_features
122+
# Time series tasks need to add targets to the embeddings. However, the target information is not recorded
123+
# by autoPyTorch's embeddings. Therefore, we need to add the targets parts to `concat_seq` manually, which is
124+
# the last few dimensions of the input x
125+
# we assign x_pointer to 0 beforehand to avoid the case that self.embed_features has 0 length
123126
x_pointer = 0
124127
for x_pointer, embed in enumerate(self.embed_features):
125128
if not embed:

0 commit comments

Comments
 (0)