Skip to content

Commit 86c846f

Browse files
committed
Layer out_shape option for verification
Fix #706
1 parent 6706707 commit 86c846f

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

returnn/tf/layers/base.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ def get_out_data_from_opts(cls, **kwargs):
5454

5555
# For compatibility, we have some parameter names (e.g. "L2") which do not conform to PEP8.
5656
# noinspection PyPep8Naming
57-
def __init__(self, name, network, output, n_out=NotSpecified, out_type=None, sources=(),
57+
def __init__(self, name, network, output,
58+
n_out=NotSpecified, out_type=None,
59+
out_shape=None,
60+
sources=(),
5861
target=None, _target_layers=None, loss=None, size_target=None,
5962
reuse_params=None,
6063
name_scope=None,
@@ -87,6 +90,8 @@ def __init__(self, name, network, output, n_out=NotSpecified, out_type=None, sou
8790
:param Data output: Set a specific output instead of using :func:`get_out_data_from_opts`
8891
:param NotSpecified|None|int n_out: output dim
8992
:param dict[str] out_type: kwargs for Data class. more explicit than n_out.
93+
:param set[DimensionTag|_ImplicitDim]|tuple|list|None out_shape: verifies the output shape (dim tags).
94+
See :func:`Data.verify_out_shape`.
9095
:param list[LayerBase] sources: via self.transform_config_dict()
9196
:param str|list[str]|None target: if some loss is set, this is the target data-key,
9297
i.e. network.extern_data.get_data(target). alternatively, this also can be a layer name.
@@ -164,6 +169,7 @@ def __init__(self, name, network, output, n_out=NotSpecified, out_type=None, sou
164169
assert self.output.shape == out_type["shape"]
165170
if "dim" in out_type:
166171
assert self.output.dim == out_type["dim"]
172+
out_shape # noqa # not used here but in fixup_out_data
167173
self.output_before_activation = None # type: typing.Optional[OutputWithActivation]
168174
self.output_loss = None # type: typing.Optional[tf.Tensor]
169175
if copy_output_loss_from_source_idx is not None:
@@ -409,7 +415,7 @@ def _post_init_output(cls, output, network, target=None, size_target=None, _targ
409415
output.available_for_inference = False
410416

411417
@classmethod
412-
def fixup_out_data(cls, output, network, **_kwargs):
418+
def fixup_out_data(cls, output, network, out_shape=None, **_kwargs):
413419
"""
414420
This is called after get_out_data_from_opts, to fixup incomplete information.
415421
E.g. we can patch batch or beam information here
@@ -420,6 +426,8 @@ def fixup_out_data(cls, output, network, **_kwargs):
420426
421427
:param Data output:
422428
:param returnn.tf.network.TFNetwork network:
429+
:param set[DimensionTag|_ImplicitDim]|tuple|list|None out_shape: verifies the output shape (dim tags).
430+
See :func:`Data.verify_out_shape`.
423431
:rtype: Data
424432
"""
425433
from ..network import ExternData
@@ -447,6 +455,8 @@ def fixup_out_data(cls, output, network, **_kwargs):
447455
# Some layers might just copy the input. But the input might have buggy ctx.
448456
# Just leave the placeholder as-is. Most layers should anyway reset this.
449457
output.placeholder = x
458+
if out_shape is not None:
459+
output.verify_out_shape(out_shape)
450460
return output
451461

452462
@classmethod

0 commit comments

Comments
 (0)