File tree Expand file tree Collapse file tree 1 file changed +14
-0
lines changed
pytorch_forecasting/tests Expand file tree Collapse file tree 1 file changed +14
-0
lines changed Original file line number Diff line number Diff line change 66import lightning .pytorch as pl
77from lightning .pytorch .callbacks import EarlyStopping
88from lightning .pytorch .loggers import TensorBoardLogger
9+ import torch
910import torch .nn as nn
1011
1112from pytorch_forecasting .tests .test_all_estimators import (
@@ -77,6 +78,19 @@ def _integration(
7778 test_outputs = trainer .test (net , dataloaders = test_dataloader )
7879 assert len (test_outputs ) > 0
7980
81+ # todo: add the predict pipeline and make this test cleaner
82+ x , y = next (iter (test_dataloader ))
83+ net .eval ()
84+ with torch .no_grad ():
85+ output = net (x )
86+ net .train ()
87+ prediction = output ["prediction" ]
88+ n_dims = len (prediction .shape )
89+ assert n_dims == 3 , (
90+ f"Prediction output must be 3D, but got { n_dims } D tensor "
91+ f"with shape { output .shape } "
92+ )
93+
8094 shutil .rmtree (tmp_path , ignore_errors = True )
8195
8296
You can’t perform that action at this time.
0 commit comments