Skip to content

MaskedComputationLayer, out_spatial_dim option #811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 1, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions returnn/tf/layers/rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7321,16 +7321,19 @@ class MaskedComputationLayer(LayerBase):

def __init__(self, mask, unit, masked_from,
_layer_class, _layer_desc,
out_spatial_dim=None,
_parent_layer_cache=None,
**kwargs):
"""
:param LayerBase|None mask:
:param dict[str] unit:
:param LayerBase|None masked_from:
:param DimensionTag|None out_spatial_dim:
:param type[LayerBase] _layer_class:
:param dict[str] _layer_desc:
:param dict[str,LayerBase]|None _parent_layer_cache:
"""
out_spatial_dim # noqa # handled in transform_config_dict
from returnn.tf.network import get_layer_class
from .base import WrappedInternalLayer
from returnn.tf.util.basic import where_bc, get_shape, nd_indices
Expand Down Expand Up @@ -7514,8 +7517,10 @@ def transform_config_dict(cls, d, network, get_layer):
# Just call it for dep resolution.
parent_layer_cache = d.setdefault("_parent_layer_cache", {})
d["_layer_class"], d["_layer_desc"] = cls._create_template(
name=d["_name"], network=network, sources=d["sources"], masked_from=masked_from,
name=d["_name"], network=network, sources=d["sources"],
masked_from=masked_from,
unit=d["unit"],
out_spatial_dim=d.get("out_spatial_dim", None),
get_layer=get_layer, _parent_layer_cache=parent_layer_cache)
if masked_from and not parent_layer_cache:
# We explicitly do not want to have these as deps.
Expand All @@ -7526,13 +7531,15 @@ def transform_config_dict(cls, d, network, get_layer):
# noinspection PyUnusedLocal
@classmethod
def _create_template(cls, name, network, sources, masked_from, unit,
out_spatial_dim=None,
get_layer=None, _parent_layer_cache=None, **kwargs):
"""
:param str name:
:param returnn.tf.network.TFNetwork network:
:param list[LayerBase] sources:
:param dict[str] unit:
:param LayerBase masked_from:
:param dict[str] unit:
:param DimensionTag|None out_spatial_dim:
:param (str)->LayerBase get_layer:
:param dict[str,LayerBase]|None _parent_layer_cache:
:return: layer_class, layer_desc
Expand All @@ -7543,6 +7550,8 @@ def _create_template(cls, name, network, sources, masked_from, unit,
get_layer = network.get_layer
# We don't care about the right masked input here, but just about deriving the right output shape.
if masked_from:
if out_spatial_dim:
masked_from.output.get_time_dim_tag().declare_same_as(out_spatial_dim)
if network.is_inside_rec_layer(inside_loop=True):
source_data = (
masked_from.output
Expand All @@ -7562,9 +7571,13 @@ def _create_template(cls, name, network, sources, masked_from, unit,
if not network.is_inside_rec_layer() and source:
source_data = source.output.copy_template().copy_as_time_major()
# Create our own time dim tag, to make sure we have a custom one.
if not out_spatial_dim:
out_spatial_dim = DimensionTag(
kind=DimensionTag.Types.Spatial, description="%s:masked:time" % name,
derived_from_tag=source_data.get_time_dim_tag())
source_data = source_data.copy_template_replace_dim_tag(
axis=0,
new_dim_tag=DimensionTag(kind=DimensionTag.Types.Spatial, description="%s:masked:time" % name))
new_dim_tag=out_spatial_dim)
source = WrappedInternalLayer(
base_layer=source, network=source.network, name=source.name,
output=source_data)
Expand Down Expand Up @@ -7612,11 +7625,13 @@ def sub_get_layer(sub_layer_name):
return layer_class, layer_desc

@classmethod
def get_out_data_from_opts(cls, network, **kwargs):
def get_out_data_from_opts(cls, network, out_spatial_dim=None, **kwargs):
"""
:param returnn.tf.network.TFNetwork network:
:param DimensionTag|None out_spatial_dim:
:rtype: Data
"""
out_spatial_dim # noqa # handled in transform_config_dict
layer_class, layer_desc = kwargs["_layer_class"], kwargs["_layer_desc"]
output = layer_class.get_out_data_from_opts(**layer_desc)
assert isinstance(output, Data)
Expand Down