Skip to content

Commit 15fed20

Browse files
authored
test_LinearLayer_in_dim_spatial, allow in_dim not feature dim (#783)
#597
1 parent 5c6e440 commit 15fed20

File tree

3 files changed

+48
-2
lines changed

3 files changed

+48
-2
lines changed

returnn/tf/layers/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def _base_get_out_data_from_opts(cls, network, name,
298298
out_type=None, out_dim=None, n_out=NotSpecified,
299299
out_shape=None,
300300
target=None, _target_layers=None, size_target=None,
301-
sources=(), loss=None,
301+
sources=(), in_dim=None, loss=None,
302302
**kwargs):
303303
"""
304304
Called via BaseLayer.get_out_data_from_opts().
@@ -313,6 +313,7 @@ def _base_get_out_data_from_opts(cls, network, name,
313313
:param dict[str,LayerBase]|None _target_layers: if target.startswith("layer:"), then this is target -> layer
314314
:param str|None size_target:
315315
:param list[LayerBase] sources:
316+
:param DimensionTag|None in_dim:
316317
:param Loss|None loss:
317318
:param kwargs: remaining kwargs of self.__init__(), ignored here
318319
:return: Data template (placeholder not set)
@@ -338,6 +339,15 @@ def _base_get_out_data_from_opts(cls, network, name,
338339
if n_out is not NotSpecified:
339340
assert out_type["dim"] == n_out
340341
sources_data_list = [src.output for src in sources if src]
342+
if in_dim:
343+
assert len(sources_data_list) == 1
344+
if sources_data_list[0].feature_dim_or_sparse_dim != in_dim:
345+
# Allow to specify some in_dim which is not the feature dim.
346+
# However, the follow-up code will expect it to be the feature dim, thus reassign it if possible.
347+
assert in_dim in sources_data_list[0].dim_tags
348+
axis = sources_data_list[0].get_axis_from_description(in_dim)
349+
sources_data_list = [sources_data_list[0].copy()]
350+
sources_data_list[0].feature_dim_axis = axis
341351
allow_broadcast_all_sources = NotSpecified
342352
if "shape" in out_type or "dim_tags" in out_type or out_shape is not None:
343353
allow_broadcast_all_sources = True

returnn/tf/layers/basic.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,17 @@ def concat_sources(src_layers, out_dim=None, allow_broadcast_all_sources=NotSpec
9494
if len(src_layers) == 1:
9595
data = src_layers[0].output.copy()
9696
if out_dim:
97-
assert out_dim == data.feature_dim_or_sparse_dim
97+
if out_dim == data.feature_dim_or_sparse_dim:
98+
pass # good
99+
elif out_dim in data.dim_tags:
100+
# We found out_dim in the input but it is not marked as the feature dim.
101+
# This is explicitly allowed. Follow-up code will expect this to be the feature-dim though,
102+
# So we mark it accordingly.
103+
assert not data.sparse
104+
axis = data.get_axis_from_description(out_dim)
105+
data.feature_dim_axis = axis
106+
else:
107+
raise Exception("%s not found in %s" % (out_dim, data))
98108
return data
99109
network = src_layers[0].network
100110
cache_key = (tuple(src_layers), out_dim, 0.0, None)

tests/test_TFNetworkLayer.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,32 @@ def test_LinearLayer():
136136
session.run(net.get_default_output_layer().output.placeholder, feed_dict=make_feed_dict(net.extern_data))
137137

138138

139+
def test_LinearLayer_in_dim_spatial():
  """
  Test LinearLayer with ``in_dim`` pointing at a static spatial dim tag
  that is *not* the feature dim of the input.
  The layer is expected to reduce over that dim (weight matrix [in_dim, out_dim])
  and keep the original feature dim in the output.
  Runs the construction twice to check that repeated graph builds behave the same.
  """
  from returnn.tf.util.data import BatchDim
  # Input layout: [B, T, D1, D2] where D1 is the dim we feed as in_dim.
  tag_time = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
  tag_spatial = DimensionTag(kind=DimensionTag.Types.Feature, description="static-spatial", dimension=3)
  tag_feat = DimensionTag(kind=DimensionTag.Types.Feature, description="in-feature", dimension=5)
  tag_out = DimensionTag(kind=DimensionTag.Types.Feature, description="out-feature", dimension=7)
  config = Config({
    "extern_data": {
      "data": {"dim_tags": [BatchDim, tag_time, tag_spatial, tag_feat]}  # [B,T,D1,D2]
    }
  })
  for _trial in range(2):
    with make_scope() as session:
      net = TFNetwork(config=config)
      net.construct_from_dict({
        "output": {"class": "linear", "from": "data", "in_dim": tag_spatial, "out_dim": tag_out}})
      out_layer = net.get_default_output_layer()
      print("Output:", out_layer.output)
      # The spatial dim got consumed; the original feature dim survives alongside out_dim.
      expected_tags = {BatchDim, tag_time, tag_out, tag_feat}
      assert out_layer.output.dim_tags_set_implicit == expected_tags
      weight = out_layer.params["W"]
      assert isinstance(weight, tf.Variable)
      assert weight.shape.as_list() == [tag_spatial.dimension, tag_out.dimension]
      session.run(tf_compat.v1.global_variables_initializer())
      session.run(out_layer.output.placeholder, feed_dict=make_feed_dict(net.extern_data))
163+
164+
139165
def test_LinearLayer_two_time_dims_allow_broadcast_all_sources():
140166
from returnn.tf.util.data import BatchDim
141167
with make_scope() as session:

0 commit comments

Comments (0)