2
2
Main engine for PyTorch
3
3
"""
4
4
5
- from returnn .config import Config
6
5
from returnn .engine .base import EngineBase
6
+ from returnn .datasets .basic import init_dataset
7
7
8
8
9
9
class Engine (EngineBase ):
@@ -13,7 +13,33 @@ class Engine(EngineBase):
13
13
14
14
def __init__ (self , config ):
15
15
"""
16
- :param Config config:
16
+ :param returnn.config. Config config:
17
17
"""
18
18
super (Engine , self ).__init__ ()
19
19
self .config = config
20
+ self .train_dataset = None
21
+ self .eval_datasets = {}
22
+
23
+ def init_train_from_config (self , config = None , train_data = None , dev_data = None , eval_data = None ):
24
+ """
25
+ :param returnn.config.Config|None config:
26
+ :param returnn.datasets.basic.Dataset|None train_data:
27
+ :param returnn.datasets.basic.Dataset|None dev_data:
28
+ :param returnn.datasets.basic.Dataset|None eval_data:
29
+ """
30
+ assert config is self .config
31
+ self .train_dataset = train_data
32
+ self .eval_datasets .clear ()
33
+ if dev_data :
34
+ self .eval_datasets ["dev" ] = dev_data
35
+ if eval_data :
36
+ self .eval_datasets ["eval" ] = eval_data
37
+ if config .has ("eval_datasets" ):
38
+ for dataset_name , dataset_opts in config .typed_value ("eval_datasets" , {}).items ():
39
+ self .eval_datasets [dataset_name ] = init_dataset (dataset_opts , default_kwargs = {"name" : dataset_name })
40
+
41
+ def train (self ):
42
+ """
43
+ Main training loop.
44
+ """
45
+ pass
0 commit comments