From 48456ec8c599939f9b296cfb87bc8f30b6bddadf Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 30 Nov 2021 16:30:51 +0100 Subject: [PATCH 1/2] MaskedComputationLayer, out_spatial_dim option #597 --- returnn/tf/layers/rec.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py index 36f3283253..66041bb48e 100644 --- a/returnn/tf/layers/rec.py +++ b/returnn/tf/layers/rec.py @@ -7320,6 +7320,7 @@ class MaskedComputationLayer(LayerBase): recurrent = True def __init__(self, mask, unit, masked_from, + out_spatial_dim=None, _layer_class, _layer_desc, _parent_layer_cache=None, **kwargs): @@ -7327,10 +7328,12 @@ def __init__(self, mask, unit, masked_from, :param LayerBase|None mask: :param dict[str] unit: :param LayerBase|None masked_from: + :param DimensionTag|None out_spatial_dim: :param type[LayerBase] _layer_class: :param dict[str] _layer_desc: :param dict[str,LayerBase]|None _parent_layer_cache: """ + out_spatial_dim # noqa # handled in transform_config_dict from returnn.tf.network import get_layer_class from .base import WrappedInternalLayer from returnn.tf.util.basic import where_bc, get_shape, nd_indices @@ -7514,8 +7517,10 @@ def transform_config_dict(cls, d, network, get_layer): # Just call it for dep resolution. parent_layer_cache = d.setdefault("_parent_layer_cache", {}) d["_layer_class"], d["_layer_desc"] = cls._create_template( - name=d["_name"], network=network, sources=d["sources"], masked_from=masked_from, + name=d["_name"], network=network, sources=d["sources"], + masked_from=masked_from, unit=d["unit"], + out_spatial_dim=d.get("out_spatial_dim", None), get_layer=get_layer, _parent_layer_cache=parent_layer_cache) if masked_from and not parent_layer_cache: # We explicitly do not want to have these as deps. @@ -7526,13 +7531,15 @@ def transform_config_dict(cls, d, network, get_layer): # noinspection PyUnusedLocal @classmethod def _create_template(cls, name, network, sources, masked_from, unit, + out_spatial_dim=None, get_layer=None, _parent_layer_cache=None, **kwargs): """ :param str name: :param returnn.tf.network.TFNetwork network: :param list[LayerBase] sources: - :param dict[str] unit: :param LayerBase masked_from: + :param dict[str] unit: + :param DimensionTag|None out_spatial_dim: :param (str)->LayerBase get_layer: :param dict[str,LayerBase]|None parent_layer_cache: :return: layer_class, layer_desc @@ -7543,6 +7550,8 @@ def _create_template(cls, name, network, sources, masked_from, unit, get_layer = network.get_layer # We don't care about the right masked input here, but just about deriving the right output shape. if masked_from: + if out_spatial_dim: + masked_from.output.get_time_dim_tag().declare_same_as(out_spatial_dim) if network.is_inside_rec_layer(inside_loop=True): source_data = ( masked_from.output @@ -7562,9 +7571,13 @@ def _create_template(cls, name, network, sources, masked_from, unit, if not network.is_inside_rec_layer() and source: source_data = source.output.copy_template().copy_as_time_major() # Create own time dim tag, to make sure we have some own custom. + if not out_spatial_dim: + out_spatial_dim = DimensionTag( + kind=DimensionTag.Types.Spatial, description="%s:masked:time" % name, + derived_from_tag=source_data.get_time_dim_tag()) source_data = source_data.copy_template_replace_dim_tag( axis=0, - new_dim_tag=DimensionTag(kind=DimensionTag.Types.Spatial, description="%s:masked:time" % name)) + new_dim_tag=out_spatial_dim) source = WrappedInternalLayer( base_layer=source, network=source.network, name=source.name, output=source_data) @@ -7612,11 +7625,13 @@ def sub_get_layer(sub_layer_name): return layer_class, layer_desc @classmethod - def get_out_data_from_opts(cls, network, **kwargs): + def get_out_data_from_opts(cls, network, out_spatial_dim=None, **kwargs): """ :param returnn.tf.network.TFNetwork network: + :param DimensionTag|None out_spatial_dim: :rtype: Data """ + out_spatial_dim # noqa # handled in transform_config_dict layer_class, layer_desc = kwargs["_layer_class"], kwargs["_layer_desc"] output = layer_class.get_out_data_from_opts(**layer_desc) assert isinstance(output, Data) From 0305a4a064370d3aa2ba1b16856fbacdf2541723 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 30 Nov 2021 16:34:25 +0100 Subject: [PATCH 2/2] small fix --- returnn/tf/layers/rec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py index 66041bb48e..36528a3353 100644 --- a/returnn/tf/layers/rec.py +++ b/returnn/tf/layers/rec.py @@ -7320,8 +7320,8 @@ class MaskedComputationLayer(LayerBase): recurrent = True def __init__(self, mask, unit, masked_from, - out_spatial_dim=None, _layer_class, _layer_desc, + out_spatial_dim=None, _parent_layer_cache=None, **kwargs): """