diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py
index 5d89d659bd..f4944c697e 100644
--- a/returnn/tf/layers/rec.py
+++ b/returnn/tf/layers/rec.py
@@ -585,8 +585,9 @@ def get_absolute_name_scope_prefix(self):
         return self.get_base_absolute_name_scope_prefix() + "rec/"  # all under "rec" sub-name-scope
 
     @classmethod
-    def get_rec_initial_extra_outputs(cls, **kwargs):
+    def get_rec_initial_extra_outputs(cls, unit, **kwargs):
         """
+        :param str|_SubnetworkRecCell unit: cell name (minus the "Cell" at the end), or the subnet rec cell
         :rtype: dict[str,tf.Tensor|tuple[tf.Tensor]]
         """
         # axis is handled in transform_config_dict
@@ -597,7 +598,17 @@ def get_rec_initial_extra_outputs(cls, **kwargs):
         if axis != single_step_dim:
             return {}
         # We expect to be inside another RecLayer, and should do a single step (like RnnCellLayer).
-        return {"state": RnnCellLayer.get_rec_initial_state(**kwargs)}
+        if isinstance(unit, _SubnetworkRecCell):
+            # noinspection PyProtectedMember
+            initial_outputs = {k: unit._get_init_output(k) for k in sorted(unit.prev_layers_needed)}
+            # noinspection PyProtectedMember
+            initial_extra_outputs = {
+                k: unit._get_init_extra_outputs(k) for k in sorted(unit.layer_data_templates.keys())
+            }
+            initial_extra_outputs = {k: v for (k, v) in initial_extra_outputs.items() if v}
+            return {"outputs": initial_outputs, "extra_outputs": initial_extra_outputs}
+        assert isinstance(unit, str)
+        return {"state": RnnCellLayer.get_rec_initial_state(unit=unit, **kwargs)}
 
     @classmethod
     def get_rec_initial_output(cls, **kwargs):
diff --git a/tests/test_TFNetworkRecLayer.py b/tests/test_TFNetworkRecLayer.py
index 476a13b9f4..2f73161026 100644
--- a/tests/test_TFNetworkRecLayer.py
+++ b/tests/test_TFNetworkRecLayer.py
@@ -9786,6 +9786,40 @@ def test_MaskedComputationLayer_sub_layers_RecLayer_construct():
     print("seq lens:", out_seq_lens_v)
 
 
+def test_MaskedComputationLayer_sub_rec_net_opt_out():
+    from returnn.tf.util.data import single_step_dim
+
+    check_reclayer_optimize_out(
+        {
+            "class": "masked_computation",
+            "mask": "mask",
+            "unit": {
+                "class": "rec",
+                "from": [],
+                "axis": single_step_dim,
+                "unit": {
+                    "in": {"class": "copy", "from": "base:in"},
+                    "layer1": {"class": "linear", "from": "in", "n_out": 5},
+                    "layer2": {"class": "rec", "from": "layer1", "unit": "lstm", "n_out": 5, "axis": single_step_dim},
+                    "layer3": {"class": "combine", "from": ["layer2", "prev:layer3"], "kind": "add"},
+                    "output": {"class": "linear", "from": "layer3", "n_out": 3},
+                },
+            },
+            "n_out": 3,
+        },
+        {
+            "const1": {"class": "constant", "value": 1, "with_batch_dim": True},  # just to broadcast mask
+            "mask": {
+                "class": "eval",
+                "from": [":i", "const1"],
+                "out_type": {"dtype": "bool"},
+                "eval": "tf.equal(source(0) % 2, source(1))",
+            },
+            "in": {"class": "copy", "from": "data:source"},
+        },
+    )
+
+
 def test_att_train_search_loss_prev_beam():
     beam_size = 1
     num_ner_labels = 13