Skip to content

Commit 6be2a5c

Browse files
committed
test_ScatterNdLayer_RangeLayer_RangeInAxisLayer fix
1 parent 618b708 commit 6be2a5c

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

tests/test_TFNetworkLayer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2473,8 +2473,8 @@ def test_TileLayer():
24732473
def test_ScatterNdLayer_RangeLayer_RangeInAxisLayer():
24742474
from returnn.tf.util.data import BatchDim, DimensionTag, ImplicitDynSizeDim
24752475
n_batch, n_time, n_ts, n_in, n_out = 2, 3, 6, 7, 11
2476-
time_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="T")
2477-
feat_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="F", dimension=n_in)
2476+
time_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
2477+
feat_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="in-feature", dimension=n_in)
24782478
ts_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="ts", dimension=n_ts)
24792479
rnd = numpy.random.RandomState(42)
24802480
config = Config({
@@ -2489,9 +2489,9 @@ def test_ScatterNdLayer_RangeLayer_RangeInAxisLayer():
24892489
"add_t": {
24902490
"class": "combine", "kind": "add", "from": ["t", "range"],
24912491
"out_shape": {time_dim, ts_dim, ImplicitDynSizeDim(BatchDim)}}, # (T,Ts)
2492-
"t_rel_var": {"class": "variable", "shape": (n_ts, n_out), "init": "glorot_uniform"}, # (B,Ts,D)
2493-
"output": {"class": "scatter_nd", "from": "t_rel_var", "position": "add_t", "position_axis": -1,
2494-
"output_dim_via_time_from": "data", "filter_invalid_indices": True}
2492+
"t_rel_var": {"class": "variable", "shape": (ts_dim, n_out), "init": "glorot_uniform"}, # (Ts,D)
2493+
"output": {"class": "scatter_nd", "from": "t_rel_var", "position": "add_t", "position_axis": ts_dim,
2494+
"output_dim_via_time_from": "data", "filter_invalid_indices": True} # (T,T,D)
24952495
}
24962496
with make_scope() as session:
24972497
network = TFNetwork(config=config, train_flag=True)
@@ -2506,7 +2506,7 @@ def test_ScatterNdLayer_RangeLayer_RangeInAxisLayer():
25062506
out_layer = network.get_default_output_layer()
25072507
assert isinstance(out_layer, ScatterNdLayer)
25082508
assert out_layer.output.shape == (None, None, 11)
2509-
assert out_layer.output.feature_dim_axis_or_unspecified is NotSpecified and out_layer.output.feature_dim_axis == 3
2509+
assert out_layer.output.feature_dim_axis_or_unspecified is NotSpecified and out_layer.output.feature_dim_axis == 2
25102510
assert out_layer.output.time_dim_axis == 0
25112511

25122512
session.run(tf_compat.v1.variables_initializer(tf_compat.v1.global_variables() + [network.global_train_step]))

0 commit comments

Comments
 (0)