@@ -54,7 +54,10 @@ def get_out_data_from_opts(cls, **kwargs):
54
54
55
55
# For compatibility, we have some parameter names (e.g. "L2") which do not conform to PEP8.
56
56
# noinspection PyPep8Naming
57
- def __init__ (self , name , network , output , n_out = NotSpecified , out_type = None , sources = (),
57
+ def __init__ (self , name , network , output ,
58
+ n_out = NotSpecified , out_type = None ,
59
+ out_shape = None ,
60
+ sources = (),
58
61
target = None , _target_layers = None , loss = None , size_target = None ,
59
62
reuse_params = None ,
60
63
name_scope = None ,
@@ -87,6 +90,8 @@ def __init__(self, name, network, output, n_out=NotSpecified, out_type=None, sou
87
90
:param Data output: Set a specific output instead of using :func:`get_out_data_from_opts`
88
91
:param NotSpecified|None|int n_out: output dim
89
92
:param dict[str] out_type: kwargs for Data class. more explicit than n_out.
93
+ :param set[DimensionTag|_ImplicitDim]|tuple|list|None out_shape: verifies the output shape (dim tags).
94
+ See :func:`Data.verify_out_shape`.
90
95
:param list[LayerBase] sources: via self.transform_config_dict()
91
96
:param str|list[str]|None target: if some loss is set, this is the target data-key,
92
97
i.e. network.extern_data.get_data(target). alternatively, this also can be a layer name.
@@ -164,6 +169,7 @@ def __init__(self, name, network, output, n_out=NotSpecified, out_type=None, sou
164
169
assert self .output .shape == out_type ["shape" ]
165
170
if "dim" in out_type :
166
171
assert self .output .dim == out_type ["dim" ]
172
+ out_shape # noqa # not used here but in fixup_out_data
167
173
self .output_before_activation = None # type: typing.Optional[OutputWithActivation]
168
174
self .output_loss = None # type: typing.Optional[tf.Tensor]
169
175
if copy_output_loss_from_source_idx is not None :
@@ -409,7 +415,7 @@ def _post_init_output(cls, output, network, target=None, size_target=None, _targ
409
415
output .available_for_inference = False
410
416
411
417
@classmethod
412
- def fixup_out_data (cls , output , network , ** _kwargs ):
418
+ def fixup_out_data (cls , output , network , out_shape = None , ** _kwargs ):
413
419
"""
414
420
This is called after get_out_data_from_opts, to fixup incomplete information.
415
421
E.g. we can patch batch or beam information here
@@ -420,6 +426,8 @@ def fixup_out_data(cls, output, network, **_kwargs):
420
426
421
427
:param Data output:
422
428
:param returnn.tf.network.TFNetwork network:
429
+ :param set[DimensionTag|_ImplicitDim]|tuple|list|None out_shape: verifies the output shape (dim tags).
430
+ See :func:`Data.verify_out_shape`.
423
431
:rtype: Data
424
432
"""
425
433
from ..network import ExternData
@@ -447,6 +455,8 @@ def fixup_out_data(cls, output, network, **_kwargs):
447
455
# Some layers might just copy the input. But the input might have buggy ctx.
448
456
# Just leave the placeholder as-is. Most layers should anyway reset this.
449
457
output .placeholder = x
458
+ if out_shape is not None :
459
+ output .verify_out_shape (out_shape )
450
460
return output
451
461
452
462
@classmethod
0 commit comments