Skip to content

Commit 76ce5e0

Browse files
committed
LayerBase.fixup_out_data, prepare for other layer dict args
1 parent aba4232 commit 76ce5e0

File tree

3 files changed: +5 −5 lines changed

returnn/tf/layers/base.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -409,7 +409,7 @@ def _post_init_output(cls, output, network, target=None, size_target=None, _targ
409409
output.available_for_inference = False
410410

411411
@classmethod
412-
def fixup_out_data(cls, output, network):
412+
def fixup_out_data(cls, output, network, **_kwargs):
413413
"""
414414
This is called after get_out_data_from_opts, to fixup incomplete information.
415415
E.g. we can patch batch or beam information here

returnn/tf/layers/rec.py

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -1250,13 +1250,13 @@ def add_templated_layer(lself, name, layer_class, **layer_desc):
12501250
layer_desc["network"] = self.net
12511251
old_layer_kwargs = layer_.kwargs
12521252
layer_.kwargs = layer_desc.copy() # set it now already for better debugging
1253-
if "output" not in layer_desc:
1253+
if "output" not in layer_.kwargs:
12541254
if old_layer_kwargs and "output" in old_layer_kwargs:
12551255
# First copy old output. Maybe the get_out_data_from_opts raises an exception,
12561256
# and we don't want this to be unset.
12571257
layer_.kwargs["output"] = old_layer_kwargs["output"]
12581258
layer_.kwargs["output"] = layer_class.get_out_data_from_opts(**layer_desc)
1259-
layer_.kwargs["output"] = layer_class.fixup_out_data(layer_.kwargs["output"], network=self.net)
1259+
layer_.kwargs["output"] = layer_class.fixup_out_data(**layer_.kwargs)
12601260
layer_.kwargs["output"].sanity_check(ignore_placeholder=True) # placeholder might be overwritten later
12611261
layer_.init(layer_class=layer_class, **layer_.kwargs)
12621262
if layer_.need_last:
@@ -1529,7 +1529,7 @@ def _add_template_layer(self, layer_name, layer_dict):
15291529
layer_class.transform_config_dict(
15301530
layer_dict, network=self.net, get_layer=lambda _name: self.layer_data_templates[_name])
15311531
out = layer_class.get_out_data_from_opts(name=layer_name, network=self.net, **layer_dict)
1532-
out = layer_class.fixup_out_data(output=out, network=self.net)
1532+
out = layer_class.fixup_out_data(output=out, network=self.net, **layer_dict)
15331533
layer.init(output=out, layer_class=layer_class, **layer_dict)
15341534
self.layer_data_templates[layer_name] = layer
15351535
return layer

returnn/tf/network.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -983,7 +983,7 @@ def _create_layer(self, name, layer_class, **layer_desc):
983983
output_template = layer_desc["output"]
984984
assert isinstance(output_template, Data), "%s %r layer_desc %r ['output'] is not a Data instance" % (
985985
layer_class.__name__, name, layer_desc)
986-
output_template = layer_class.fixup_out_data(output_template, network=self)
986+
output_template = layer_class.fixup_out_data(**layer_desc)
987987
layer_desc["output"] = output_template
988988
print(
989989
"layer %s/%r output: %r" % (self.name, name, output_template),

0 commit comments

Comments (0)