Skip to content

Commit 4db5d8d

Browse files
committed
PyTorch engine dummy
#1120
1 parent a2388a5 commit 4db5d8d

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

returnn/torch/engine.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
Main engine for PyTorch
33
"""
44

5-
from returnn.config import Config
65
from returnn.engine.base import EngineBase
6+
from returnn.datasets.basic import init_dataset
77

88

99
class Engine(EngineBase):
@@ -13,7 +13,33 @@ class Engine(EngineBase):
1313

1414
def __init__(self, config):
1515
"""
16-
:param Config config:
16+
:param returnn.config.Config config:
1717
"""
1818
super(Engine, self).__init__()
1919
self.config = config
20+
self.train_dataset = None
21+
self.eval_datasets = {}
22+
23+
def init_train_from_config(self, config=None, train_data=None, dev_data=None, eval_data=None):
24+
"""
25+
:param returnn.config.Config|None config:
26+
:param returnn.datasets.basic.Dataset|None train_data:
27+
:param returnn.datasets.basic.Dataset|None dev_data:
28+
:param returnn.datasets.basic.Dataset|None eval_data:
29+
"""
30+
assert config is self.config
31+
self.train_dataset = train_data
32+
self.eval_datasets.clear()
33+
if dev_data:
34+
self.eval_datasets["dev"] = dev_data
35+
if eval_data:
36+
self.eval_datasets["eval"] = eval_data
37+
if config.has("eval_datasets"):
38+
for dataset_name, dataset_opts in config.typed_value("eval_datasets", {}).items():
39+
self.eval_datasets[dataset_name] = init_dataset(dataset_opts, default_kwargs={"name": dataset_name})
40+
41+
def train(self):
42+
"""
43+
Main training loop.
44+
"""
45+
pass

0 commit comments

Comments
 (0)