File tree 1 file changed +25
-1
lines changed
1 file changed +25
-1
lines changed Original file line number Diff line number Diff line change 2
2
Main engine for PyTorch
3
3
"""
4
4
5
+ from torch .utils .data import DataLoader
6
+
7
+ from returnn .log import log
5
8
from returnn .engine .base import EngineBase
6
9
from returnn .datasets .basic import init_dataset
10
+ from returnn .torch .dataset_wrapper import DatasetWrapper
7
11
8
12
9
13
class Engine (EngineBase ):
@@ -42,4 +46,24 @@ def train(self):
42
46
"""
43
47
Main training loop.
44
48
"""
45
- pass
49
+ start_epoch , _ = self .get_train_start_epoch_batch (self .config )
50
+ final_epoch = self .config_get_final_epoch (self .config )
51
+
52
+ print (f"Starting training at epoch { start_epoch } ." , file = log .v3 )
53
+
54
+ self .epoch = start_epoch
55
+ while self .epoch <= final_epoch :
56
+ print ("Starting " + self .get_epoch_str () + "..." , file = log .v4 )
57
+
58
+ self .train_dataset .init_seq_order (epoch = self .epoch )
59
+
60
+ train_data = DatasetWrapper (self .train_dataset )
61
+
62
+ data_loader = DataLoader (train_data , batch_size = 1 ) # TODO: implement batching
63
+
64
+ for batch_index , data in enumerate (data_loader ):
65
+ pass # TODO: only iterates through dataset so far
66
+
67
+ self .epoch += 1
68
+
69
+ print (f"Finished training at epoch { self .epoch } ." , file = log .v3 )
You can’t perform that action at this time.
0 commit comments