|
24 | 24 | from returnn.util import basic as util
|
25 | 25 | from returnn.util import NumbersDict
|
26 | 26 | from returnn.util.basic import NotSpecified
|
| 27 | +from returnn.forward_iface import ForwardCallbackIface |
| 28 | + |
27 | 29 | from .updater import Updater
|
28 | 30 | from .data import pipeline as data_pipeline
|
29 | 31 | from .data import returnn_dataset_wrapper
|
@@ -55,6 +57,11 @@ def __init__(self, config: Config):
|
55 | 57 | self._orig_model = None # type: Optional[Union[rf.Module, torch.nn.Module]]
|
56 | 58 | self._pt_model = None # type: Optional[torch.nn.Module]
|
57 | 59 | self._train_step_func = None # type: Optional[Callable]
|
| 60 | + self._forward_step_func = self.config.typed_value("forward_step") # type: Optional[Callable] |
| 61 | + self._forward_step_expected_outputs = None # type: Optional[TensorDict] |
| 62 | + if self.config.typed_value("model_outputs") is not None: |
| 63 | + self._forward_step_expected_outputs = TensorDict() |
| 64 | + self._forward_step_expected_outputs.update(self.config.typed_value("model_outputs"), auto_convert=True) |
58 | 65 | self._save_model_epoch_interval = 1
|
59 | 66 | self._updater = None # type: Optional[Updater]
|
60 | 67 |
|
@@ -98,6 +105,7 @@ def init_network_from_config(self, config: Optional[Config] = None):
|
98 | 105 |
|
99 | 106 | extern_data = TensorDict()
|
100 | 107 | extern_data_dict = self.config.typed_value("extern_data")
|
| 108 | + assert extern_data_dict, "extern_data is not specified in config" |
101 | 109 | extern_data.update(extern_data_dict, auto_convert=True)
|
102 | 110 | if "seq_tag" not in extern_data.data:
|
103 | 111 | batch_dim = _get_batch_dim_from_extern_data(extern_data)
|
@@ -194,9 +202,12 @@ def train_epoch(self):
|
194 | 202 | accumulated_losses_dict = NumbersDict()
|
195 | 203 | accumulated_inv_norm_factors_dict = NumbersDict()
|
196 | 204 | step_idx = 0
|
197 |
| - for data in self._train_dataloader: |
| 205 | + for extern_data_raw in self._train_dataloader: |
198 | 206 | self._updater.get_optimizer().zero_grad()
|
199 |
| - self._run_step(data, train_flag=True) |
| 207 | + extern_data = _raw_dict_to_extern_data( |
| 208 | + extern_data_raw, extern_data_template=self.extern_data, device=self._device |
| 209 | + ) |
| 210 | + self._run_step(extern_data, train_func=True, train_flag=True) |
200 | 211 |
|
201 | 212 | train_ctx = rf.get_run_ctx()
|
202 | 213 | total_loss = train_ctx.total_loss()
|
@@ -265,9 +276,12 @@ def eval_model(self):
|
265 | 276 | step_idx = 0
|
266 | 277 |
|
267 | 278 | with torch.no_grad():
|
268 |
| - for data in data_loader: |
| 279 | + for extern_data_raw in data_loader: |
| 280 | + extern_data = _raw_dict_to_extern_data( |
| 281 | + extern_data_raw, extern_data_template=self.extern_data, device=self._device |
| 282 | + ) |
269 | 283 |
|
270 |
| - self._run_step(data) |
| 284 | + self._run_step(extern_data, train_func=True) |
271 | 285 | train_ctx = rf.get_run_ctx()
|
272 | 286 |
|
273 | 287 | if score_keys is None:
|
@@ -345,19 +359,23 @@ def _create_data_loader(self, dataset: Dataset) -> DataLoader2:
|
345 | 359 | raise ModuleNotFoundError("Possible type error in DataLoader2 due to missing module 'dill'") from exc
|
346 | 360 | raise
|
347 | 361 |
|
348 |
| - def _run_step(self, extern_data_raw: Dict[str, torch.Tensor], *, train_flag: bool = False): |
| 362 | + def _run_step(self, extern_data: TensorDict, *, train_flag: bool = False, train_func: bool): |
349 | 363 | """
|
350 |
| - :param dict[str, torch.Tensor] extern_data_raw: model inputs for the step |
| 364 | + :param extern_data: model inputs for the step |
351 | 365 | """
|
352 |
| - extern_data = _raw_dict_to_extern_data( |
353 |
| - extern_data_raw, extern_data_template=self.extern_data, device=self._device |
354 |
| - ) |
355 |
| - |
356 |
| - rf.init_train_step_run_ctx(train_flag=train_flag) |
| 366 | + if train_func: |
| 367 | + assert self._train_step_func is not None |
| 368 | + rf.init_train_step_run_ctx(train_flag=train_flag) |
| 369 | + else: |
| 370 | + assert self._forward_step_func is not None, "define forward_step in the config" |
| 371 | + rf.init_forward_step_run_ctx(expected_outputs=self._forward_step_expected_outputs) |
357 | 372 |
|
358 | 373 | with autocast(device_type=self._device, dtype=self._autocast_dtype) if self._use_autocast else nullcontext():
|
359 | 374 | sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None}
|
360 |
| - self._train_step_func(model=self._orig_model, extern_data=extern_data, **sentinel_kw) |
| 375 | + if train_func: |
| 376 | + self._train_step_func(model=self._orig_model, extern_data=extern_data, **sentinel_kw) |
| 377 | + else: |
| 378 | + self._forward_step_func(model=self._orig_model, extern_data=extern_data, **sentinel_kw) |
361 | 379 |
|
362 | 380 | def _load_model(self):
|
363 | 381 | """
|
@@ -393,7 +411,7 @@ def _load_model(self):
|
393 | 411 | rf.set_random_seed(random_seed)
|
394 | 412 |
|
395 | 413 | get_model_func = self.config.typed_value("get_model")
|
396 |
| - assert get_model_func, "get_model not defined" |
| 414 | + assert get_model_func, "get_model not defined in config" |
397 | 415 | sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None}
|
398 | 416 | # Note on the `epoch` and `step` args:
|
399 | 417 | # This is the current epoch and step, i.e. the epoch and step we are about to run.
|
@@ -521,6 +539,42 @@ def _save_optimizer(self):
|
521 | 539 | if os.path.isfile(filename):
|
522 | 540 | os.unlink(filename)
|
523 | 541 |
|
| 542 | + def forward_with_callback(self, *, dataset: Dataset, callback: ForwardCallbackIface): |
| 543 | + """forward""" |
| 544 | + assert isinstance(dataset, Dataset) |
| 545 | + assert isinstance(callback, ForwardCallbackIface) |
| 546 | + |
| 547 | + self._pt_model.eval() |
| 548 | + |
| 549 | + data_loader = self._create_data_loader(dataset) |
| 550 | + batch_dim = _get_batch_dim_from_extern_data(self.extern_data) |
| 551 | + |
| 552 | + with torch.no_grad(): |
| 553 | + callback.init(model=self._orig_model) |
| 554 | + |
| 555 | + for extern_data_raw in data_loader: |
| 556 | + extern_data = _raw_dict_to_extern_data( |
| 557 | + extern_data_raw, extern_data_template=self.extern_data, device=self._device |
| 558 | + ) |
| 559 | + self._run_step(extern_data, train_func=False) |
| 560 | + ctx = rf.get_run_ctx() |
| 561 | + ctx.check_outputs_complete() |
| 562 | + |
| 563 | + model_outputs = ctx.outputs |
| 564 | + model_outputs_per_batch_template = TensorDict( |
| 565 | + {k: v.copy_template_excluding_axis(0) for k, v in model_outputs.data.items()} |
| 566 | + ) |
| 567 | + for batch_idx in range(batch_dim.get_dim_value()): |
| 568 | + seq_tag = extern_data["seq_tag"].raw_tensor[batch_idx].item() |
| 569 | + model_outputs_per_batch = TensorDict( |
| 570 | + {k: v.copy() for k, v in model_outputs_per_batch_template.data.items()} |
| 571 | + ) |
| 572 | + for k, v in model_outputs.data.items(): |
| 573 | + model_outputs_per_batch[k].raw_tensor = v.raw_tensor[batch_idx] |
| 574 | + callback.process_seq(seq_tag=seq_tag, outputs=model_outputs_per_batch) |
| 575 | + |
| 576 | + callback.finish() |
| 577 | + |
524 | 578 |
|
525 | 579 | def _to_raw(n: Union[int, float, Tensor]):
|
526 | 580 | if isinstance(n, (int, float)):
|
|
0 commit comments