@@ -70,10 +70,65 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
7070
7171 Returns
7272 -------
73- x: dict[str, torch.Tensor]
74- A dictionary containing the processed data.
75- y: torch.Tensor
76- The target variable.
73+ x : dict[str, torch.Tensor]
74+ Dict containing processed inputs for the model, with the following keys:
75+
76+ * ``history_cont`` : torch.Tensor of shape
77+ (context_length, n_history_cont_features)
78+ Continuous features for the encoder (historical data).
79+ * ``history_cat`` : torch.Tensor of shape
80+ (context_length, n_history_cat_features)
81+ Categorical features for the encoder (historical data).
82+ * ``future_cont`` : torch.Tensor of shape
83+ (prediction_length, n_future_cont_features)
84+ Known continuous features for the decoder (future data).
85+ * ``future_cat`` : torch.Tensor of shape
86+ (prediction_length, n_future_cat_features)
87+ Known categorical features for the decoder (future data).
88+ * ``history_length`` : torch.Tensor of shape (1,)
89+ Length of the encoder sequence.
90+ * ``future_length`` : torch.Tensor of shape (1,)
91+ Length of the decoder sequence.
92+ * ``history_mask`` : torch.Tensor of shape (context_length,)
93+ Boolean mask indicating valid encoder time points.
94+ * ``future_mask`` : torch.Tensor of shape (prediction_length,)
95+ Boolean mask indicating valid decoder time points.
96+ * ``groups`` : torch.Tensor of shape (1,)
97+ Group identifier for the time series instance.
98+ * ``history_time_idx`` : torch.Tensor of shape (context_length,)
99+ Time indices for the encoder sequence.
100+ * ``future_time_idx`` : torch.Tensor of shape (prediction_length,)
101+ Time indices for the decoder sequence.
102+ * ``history_target`` : torch.Tensor of shape (context_length,)
103+ Historical target values for the encoder sequence.
104+ * ``future_target`` : torch.Tensor of shape (prediction_length,)
105+ Target values for the decoder sequence.
106+ * ``future_target_len`` : torch.Tensor of shape (1,)
107+ Length of the decoder target sequence.
108+
109+ Optional fields, depending on dataset configuration:
110+
111+ * ``history_relative_time_idx`` : torch.Tensor of shape (context_length,),
112+ optional
113+ Relative time indices for the encoder sequence, present if
114+ `add_relative_time_idx` is True.
115+ * ``future_relative_time_idx`` : torch.Tensor of shape (prediction_length,),
116+ optional
117+ Relative time indices for the decoder sequence, present if
118+ `add_relative_time_idx` is True.
119+ * ``static_categorical_features`` : torch.Tensor of shape
120+ (1, n_static_features), optional
121+ Static categorical features if available.
122+ * ``static_continuous_features`` : torch.Tensor of shape
123+ (1, n_static_features), optional
124+ Static continuous features if available.
125+ * ``target_scale`` : torch.Tensor of shape (1,), optional
126+ Scaling factor for the target values if provided by the dataset.
127+
128+ y : torch.Tensor or list of torch.Tensor
129+ Target values for the decoder sequence.
130+ If ``n_targets`` > 1, a list of tensors each of shape (prediction_length,)
131+ is returned. Otherwise, a tensor of shape (prediction_length,) is returned.
77132 """
78133
79134 series_idx , start_idx , context_length , prediction_length = self .windows [idx ]
@@ -170,6 +225,10 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
170225 x ["target_scale" ] = processed_data ["target_scale" ]
171226
172227 y = processed_data ["target" ][future_indices ]
228+ if self .data_module .n_targets > 1 :
229+ y = [t .squeeze (- 1 ) for t in torch .split (y , 1 , dim = 1 )]
230+ else :
231+ y = y .squeeze (- 1 )
173232
174233 return x , y
175234
@@ -294,6 +353,7 @@ def __init__(
294353 self .window_stride = window_stride
295354
296355 self .time_series_metadata = time_series_dataset .get_metadata ()
356+ self .n_targets = len (self .time_series_metadata ["cols" ]["y" ])
297357
298358 for idx , col in enumerate (self .time_series_metadata ["cols" ]["x" ]):
299359 if self .time_series_metadata ["col_type" ].get (col ) == "C" :
@@ -774,8 +834,11 @@ def collate_fn(batch):
774834
775835 Returns
776836 -------
777- tuple[dict[str, torch.Tensor], torch.Tensor]
837+ tuple[dict[str, torch.Tensor], torch.Tensor or list of torch.Tensor ]
778838 A tuple containing the collated data and the target variable.
839+ If the dataset has multiple targets, a list of tensors each of shape
840+ (batch_size, prediction_length,). Otherwise, a single tensor of shape
841+ (batch_size, prediction_length).
779842 """
780843
781844 x_batch = {
@@ -816,5 +879,13 @@ def collate_fn(batch):
816879 [x ["static_continuous_features" ] for x , _ in batch ]
817880 )
818881
819- y_batch = torch .stack ([y for _ , y in batch ])
882+ if isinstance (batch [0 ][1 ], (list , tuple )):
883+ num_targets = len (batch [0 ][1 ])
884+ y_batch = []
885+ for i in range (num_targets ):
886+ target_tensors = [sample_y [i ] for _ , sample_y in batch ]
887+ stacked_target = torch .stack (target_tensors )
888+ y_batch .append (stacked_target )
889+ else :
890+ y_batch = torch .stack ([y for _ , y in batch ])
820891 return x_batch , y_batch
0 commit comments