Skip to content

Commit 7836108

Browse files
committed
PT engine forward task
Implement the forward task for the PT engine; fixes #1336
1 parent 281a4cb commit 7836108

File tree

6 files changed

+200
-29
lines changed

6 files changed

+200
-29
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ jobs:
261261
action:
262262
- TEST=demos RETURNN_DISABLE_TF=1
263263
- TEST=PTDataset
264+
- TEST=torch_engine
264265
- TEST=torch_frontend
265266
- TEST=torch_internal_frontend
266267

returnn/__main__.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -445,22 +445,36 @@ def execute_main_task():
445445
lr_control_update_scores=lr_control_update_scores,
446446
)
447447
elif task in ["forward", "hpx"]:
448-
assert eval_data is not None, "no eval data provided"
449-
combine_labels = config.value("combine_labels", "")
450-
engine.use_search_flag = config.bool("forward_use_search", False)
451-
if config.has("epoch"):
452-
config.set("load_epoch", config.int("epoch", 0))
453-
engine.init_network_from_config(config)
454-
output_file = config.value("output_file", "dump-fwd-epoch-%i.hdf" % engine.epoch)
455-
forward_batch_size = config.int("forward_batch_size", 0)
456-
if not forward_batch_size:
457-
raise Exception("forward_batch_size not set")
458-
engine.forward_to_hdf(
459-
data=eval_data,
460-
output_file=output_file,
461-
combine_labels=combine_labels,
462-
batch_size=forward_batch_size,
463-
)
448+
if config.typed_value("forward_callback") or not BackendEngine.is_tensorflow_selected():
449+
engine.init_network_from_config(config)
450+
if config.value("forward_data", "eval") in ["train", "dev", "eval"]:
451+
data = {"train": train_data, "dev": dev_data, "eval": eval_data}[config.value("forward_data", "eval")]
452+
assert data, "set forward_data"
453+
else:
454+
data = init_dataset(config.opt_typed_value("forward_data"))
455+
forward_callback = config.typed_value("forward_callback")
456+
assert forward_callback, "no forward_callback specified"
457+
if callable(forward_callback):
458+
forward_callback = forward_callback()
459+
engine.forward_with_callback(dataset=data, callback=forward_callback)
460+
else:
461+
assert BackendEngine.is_tensorflow_selected()
462+
assert eval_data is not None, "no eval data provided"
463+
combine_labels = config.value("combine_labels", "")
464+
engine.use_search_flag = config.bool("forward_use_search", False)
465+
if config.has("epoch"):
466+
config.set("load_epoch", config.int("epoch", 0))
467+
engine.init_network_from_config(config)
468+
output_file = config.value("output_file", "dump-fwd-epoch-%i.hdf" % engine.epoch)
469+
forward_batch_size = config.int("forward_batch_size", 0)
470+
if not forward_batch_size:
471+
raise Exception("forward_batch_size not set")
472+
engine.forward_to_hdf(
473+
data=eval_data,
474+
output_file=output_file,
475+
combine_labels=combine_labels,
476+
batch_size=forward_batch_size,
477+
)
464478
elif task == "search":
465479
engine.use_search_flag = True
466480
engine.use_eval_flag = config.bool("search_do_eval", True)

returnn/engine/base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from returnn.log import log
1414
from returnn.pretrain import Pretrain
1515
from returnn.util import basic as util
16+
from returnn.forward_iface import ForwardCallbackIface
17+
from returnn.datasets import Dataset
1618

1719

1820
class EngineBase:
@@ -241,3 +243,11 @@ def is_first_epoch_after_pretrain(self):
241243
:rtype: bool
242244
"""
243245
return self.pretrain and self.epoch == self.pretrain.get_train_num_epochs() + 1
246+
247+
def forward_with_callback(self, *, dataset: Dataset, callback: ForwardCallbackIface):
    """
    Run the forward task: iterate over `dataset`, execute the user-defined
    `forward_step` from the config, collect whatever was marked via
    `mark_as_output` on `rf.get_run_ctx()`, and hand each finished entry
    over to `callback`.

    Backend-specific engine subclasses must implement this.

    :param dataset: the data to iterate over
    :param callback: receives `init`, one `process_seq` per entry, then `finish`
    """
    raise NotImplementedError

returnn/forward_iface.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""
2+
Defines the interface for the "forward" task,
3+
which can be used for recognition, alignment, search, etc.
4+
5+
https://github.com/rwth-i6/returnn/issues/1336
6+
"""
7+
8+
from __future__ import annotations
9+
from returnn.tensor import TensorDict
10+
11+
12+
class ForwardCallbackIface:
    """
    Callback interface for the forward task.

    Define `forward_callback` in your config to an instance or class of this.

    https://github.com/rwth-i6/returnn/issues/1336
    """

    def init(self, *, model):
        """
        Called once, before any sequence is processed.

        :param model: the model being used for forwarding
        """

    def process_seq(self, *, seq_tag: str, outputs: TensorDict):
        """
        Called once per sequence, i.e. per entry in the dataset.
        The batch dim has already been stripped away at this point,
        and the values in `outputs` are Numpy arrays.

        :param seq_tag: tag identifying the current sequence
        :param outputs: the model outputs for this single sequence
        """

    def finish(self):
        """
        Called once, after all sequences have been processed.
        """

returnn/torch/engine.py

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from returnn.util import basic as util
2525
from returnn.util import NumbersDict
2626
from returnn.util.basic import NotSpecified
27+
from returnn.forward_iface import ForwardCallbackIface
28+
2729
from .updater import Updater
2830
from .data import pipeline as data_pipeline
2931
from .data import returnn_dataset_wrapper
@@ -55,6 +57,11 @@ def __init__(self, config: Config):
5557
self._orig_model = None # type: Optional[Union[rf.Module, torch.nn.Module]]
5658
self._pt_model = None # type: Optional[torch.nn.Module]
5759
self._train_step_func = None # type: Optional[Callable]
60+
self._forward_step_func = self.config.typed_value("forward_step") # type: Optional[Callable]
61+
self._forward_step_expected_outputs = None # type: Optional[TensorDict]
62+
if self.config.typed_value("model_outputs") is not None:
63+
self._forward_step_expected_outputs = TensorDict()
64+
self._forward_step_expected_outputs.update(self.config.typed_value("model_outputs"), auto_convert=True)
5865
self._save_model_epoch_interval = 1
5966
self._updater = None # type: Optional[Updater]
6067

@@ -98,6 +105,7 @@ def init_network_from_config(self, config: Optional[Config] = None):
98105

99106
extern_data = TensorDict()
100107
extern_data_dict = self.config.typed_value("extern_data")
108+
assert extern_data_dict, "extern_data is not specified in config"
101109
extern_data.update(extern_data_dict, auto_convert=True)
102110
if "seq_tag" not in extern_data.data:
103111
batch_dim = _get_batch_dim_from_extern_data(extern_data)
@@ -194,9 +202,12 @@ def train_epoch(self):
194202
accumulated_losses_dict = NumbersDict()
195203
accumulated_inv_norm_factors_dict = NumbersDict()
196204
step_idx = 0
197-
for data in self._train_dataloader:
205+
for extern_data_raw in self._train_dataloader:
198206
self._updater.get_optimizer().zero_grad()
199-
self._run_step(data, train_flag=True)
207+
extern_data = _raw_dict_to_extern_data(
208+
extern_data_raw, extern_data_template=self.extern_data, device=self._device
209+
)
210+
self._run_step(extern_data, train_func=True, train_flag=True)
200211

201212
train_ctx = rf.get_run_ctx()
202213
total_loss = train_ctx.total_loss()
@@ -265,9 +276,12 @@ def eval_model(self):
265276
step_idx = 0
266277

267278
with torch.no_grad():
268-
for data in data_loader:
279+
for extern_data_raw in data_loader:
280+
extern_data = _raw_dict_to_extern_data(
281+
extern_data_raw, extern_data_template=self.extern_data, device=self._device
282+
)
269283

270-
self._run_step(data)
284+
self._run_step(extern_data, train_func=True)
271285
train_ctx = rf.get_run_ctx()
272286

273287
if score_keys is None:
@@ -345,19 +359,23 @@ def _create_data_loader(self, dataset: Dataset) -> DataLoader2:
345359
raise ModuleNotFoundError("Possible type error in DataLoader2 due to missing module 'dill'") from exc
346360
raise
347361

348-
def _run_step(self, extern_data: TensorDict, *, train_flag: bool = False, train_func: bool):
    """
    Run the user-defined step function once on the given data.

    :param extern_data: model inputs for the step
    :param train_flag: passed through to the train-step run context
    :param train_func: True -> use `train_step` from the config, False -> use `forward_step`
    """
    # Pick the user step function and initialize the matching run context.
    if train_func:
        assert self._train_step_func is not None
        rf.init_train_step_run_ctx(train_flag=train_flag)
        step_func = self._train_step_func
    else:
        assert self._forward_step_func is not None, "define forward_step in the config"
        rf.init_forward_step_run_ctx(expected_outputs=self._forward_step_expected_outputs)
        step_func = self._forward_step_func

    amp_context = (
        autocast(device_type=self._device, dtype=self._autocast_dtype) if self._use_autocast else nullcontext()
    )
    with amp_context:
        # The random sentinel kwarg enforces that user step functions accept **kwargs
        # (forward compatibility for future extra arguments).
        sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None}
        step_func(model=self._orig_model, extern_data=extern_data, **sentinel_kw)

362380
def _load_model(self):
363381
"""
@@ -393,7 +411,7 @@ def _load_model(self):
393411
rf.set_random_seed(random_seed)
394412

395413
get_model_func = self.config.typed_value("get_model")
396-
assert get_model_func, "get_model not defined"
414+
assert get_model_func, "get_model not defined in config"
397415
sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None}
398416
# Note on the `epoch` and `step` args:
399417
# This is the current epoch and step, i.e. the epoch and step we are about to run.
@@ -521,6 +539,42 @@ def _save_optimizer(self):
521539
if os.path.isfile(filename):
522540
os.unlink(filename)
523541

542+
def forward_with_callback(self, *, dataset: Dataset, callback: ForwardCallbackIface):
    """
    Forward task: iterate `dataset`, run `forward_step` from the config on each batch,
    and pass each individual sequence's outputs (batch dim removed) to `callback`.

    :param dataset: data to forward
    :param callback: gets `init`, then one `process_seq` per sequence, then `finish`
    """
    assert isinstance(dataset, Dataset)
    assert isinstance(callback, ForwardCallbackIface)

    self._pt_model.eval()  # inference mode for the PyTorch module

    data_loader = self._create_data_loader(dataset)
    batch_dim = _get_batch_dim_from_extern_data(self.extern_data)

    with torch.no_grad():
        callback.init(model=self._orig_model)

        for extern_data_raw in data_loader:
            extern_data = _raw_dict_to_extern_data(
                extern_data_raw, extern_data_template=self.extern_data, device=self._device
            )
            # Runs forward_step, which fills the run ctx outputs via mark_as_output.
            self._run_step(extern_data, train_func=False)
            ctx = rf.get_run_ctx()
            ctx.check_outputs_complete()

            model_outputs = ctx.outputs
            # Per-sequence template: drops axis 0 of every output
            # (i.e. the code assumes the batch axis is the first axis of each output).
            # Rebuilt per batch since dim values can differ between batches.
            model_outputs_per_batch_template = TensorDict(
                {k: v.copy_template_excluding_axis(0) for k, v in model_outputs.data.items()}
            )
            for batch_idx in range(batch_dim.get_dim_value()):
                # NOTE(review): presumably seq_tag raw_tensor holds string entries,
                # so .item() yields a str — confirm against _raw_dict_to_extern_data.
                seq_tag = extern_data["seq_tag"].raw_tensor[batch_idx].item()
                model_outputs_per_batch = TensorDict(
                    {k: v.copy() for k, v in model_outputs_per_batch_template.data.items()}
                )
                for k, v in model_outputs.data.items():
                    # Slice out this sequence's entry from the batched raw tensor.
                    model_outputs_per_batch[k].raw_tensor = v.raw_tensor[batch_idx]
                callback.process_seq(seq_tag=seq_tag, outputs=model_outputs_per_batch)

        callback.finish()
577+
524578

525579
def _to_raw(n: Union[int, float, Tensor]):
526580
if isinstance(n, (int, float)):

tests/test_torch_engine.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""
2+
Tests for PyTorch engine.
3+
"""
4+
5+
import _setup_test_env # noqa
6+
import torch
7+
from returnn.config import Config, global_config_ctx
8+
from returnn.tensor import TensorDict, Tensor
9+
from returnn.torch.engine import Engine
10+
import returnn.frontend as rf
11+
from returnn.forward_iface import ForwardCallbackIface
12+
from returnn.datasets import init_dataset
13+
14+
15+
def test_torch_engine():
    """Forward task via the PT engine: callback must see every sequence exactly once."""

    def _get_model(**_kwargs):
        return torch.nn.Module()

    def _forward_step(*, extern_data: TensorDict, **_kwargs):
        rf.get_run_ctx().mark_as_default_output(extern_data["data"])

    class _TrackingCallback(ForwardCallbackIface):
        def __init__(self):
            self.num_seqs = 0
            self.init_called = False
            self.finish_called = False

        def init(self, *, model):
            assert isinstance(model, torch.nn.Module)
            assert self.num_seqs == 0
            self.init_called = True

        def process_seq(self, *, seq_tag: str, outputs: TensorDict):
            out = outputs["output"]
            assert isinstance(out, Tensor)
            assert out.batch_ndim == 2 and out.batch_shape[-1] == 9
            self.num_seqs += 1

        def finish(self):
            self.finish_called = True

    config = Config(
        {
            "task": "forward",
            "extern_data": {"data": {"dim": 9}},
            "get_model": _get_model,
            "forward_step": _forward_step,
        }
    )
    dataset = init_dataset({"class": "Task12AXDataset", "num_seqs": 100, "name": "dev", "fixed_random_seed": 1})
    callback = _TrackingCallback()

    with global_config_ctx(config):
        engine = Engine(config=config)
        engine.init_network_from_config()
        engine.forward_with_callback(callback=callback, dataset=dataset)
    assert callback.num_seqs == 100
    assert callback.init_called and callback.finish_called

0 commit comments

Comments
 (0)