Skip to content

Commit 33de58f

Browse files
PyTorch: dummy training loop, iterate through dataset
1 parent 6a756d7 commit 33de58f

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

returnn/torch/engine.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
Main engine for PyTorch
33
"""
44

5+
from torch.utils.data import DataLoader
6+
7+
from returnn.log import log
58
from returnn.engine.base import EngineBase
69
from returnn.datasets.basic import init_dataset
10+
from returnn.torch.dataset_wrapper import DatasetWrapper
711

812

913
class Engine(EngineBase):
@@ -42,4 +46,24 @@ def train(self):
4246
"""
4347
Main training loop.
4448
"""
45-
pass
49+
start_epoch, _ = self.get_train_start_epoch_batch(self.config)
50+
final_epoch = self.config_get_final_epoch(self.config)
51+
52+
print(f"Starting training at epoch {start_epoch}.", file=log.v3)
53+
54+
self.epoch = start_epoch
55+
while self.epoch <= final_epoch:
56+
print("Starting " + self.get_epoch_str() + "...", file=log.v4)
57+
58+
self.train_dataset.init_seq_order(epoch=self.epoch)
59+
60+
train_data = DatasetWrapper(self.train_dataset)
61+
62+
data_loader = DataLoader(train_data, batch_size=1) # TODO: implement batching
63+
64+
for batch_index, data in enumerate(data_loader):
65+
pass # TODO: only iterates through dataset so far
66+
67+
self.epoch += 1
68+
69+
print(f"Finished training at epoch {self.epoch}.", file=log.v3)

0 commit comments

Comments
 (0)