diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index c386175284..fe2230f0d9 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -7246,7 +7246,7 @@ class VariableLayer(LayerBase): """ layer_class = "variable" - def __init__(self, shape, dtype="float32", add_batch_axis=True, add_time_axis=False, trainable=True, + def __init__(self, shape, dtype="float32", add_batch_axis=False, add_time_axis=False, trainable=True, init=0, **kwargs): """ @@ -7298,7 +7298,7 @@ def transform_config_dict(cls, d, network, get_layer): @classmethod def get_out_data_from_opts(cls, name, network, - shape, dtype="float32", add_batch_axis=True, add_time_axis=False, **kwargs): + shape, dtype="float32", add_batch_axis=False, add_time_axis=False, **kwargs): """ :param str name: :param returnn.tf.network.TFNetwork network: diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py index 1833d74aba..f5b37abdbd 100644 --- a/tests/test_TFNetworkLayer.py +++ b/tests/test_TFNetworkLayer.py @@ -2473,8 +2473,8 @@ def test_TileLayer(): def test_ScatterNdLayer_RangeLayer_RangeInAxisLayer(): from returnn.tf.util.data import BatchDim, DimensionTag, ImplicitDynSizeDim n_batch, n_time, n_ts, n_in, n_out = 2, 3, 6, 7, 11 - time_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="T") - feat_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="F", dimension=n_in) + time_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="time") + feat_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="in-feature", dimension=n_in) ts_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="ts", dimension=n_ts) rnd = numpy.random.RandomState(42) config = Config({ @@ -2489,9 +2489,9 @@ def test_ScatterNdLayer_RangeLayer_RangeInAxisLayer(): "add_t": { "class": "combine", "kind": "add", "from": ["t", "range"], "out_shape": {time_dim, ts_dim, ImplicitDynSizeDim(BatchDim)}}, # (T,Ts) - "t_rel_var": {"class": "variable", "shape": (n_ts, n_out), "init": "glorot_uniform"}, # (B,Ts,D) - "output": {"class": "scatter_nd", "from": "t_rel_var", "position": "add_t", "position_axis": -1, - "output_dim_via_time_from": "data", "filter_invalid_indices": True} + "t_rel_var": {"class": "variable", "shape": (ts_dim, n_out), "init": "glorot_uniform"}, # (Ts,D) + "output": {"class": "scatter_nd", "from": "t_rel_var", "position": "add_t", "position_axis": ts_dim, + "output_dim_via_time_from": "data", "filter_invalid_indices": True} # (T,T,D) } with make_scope() as session: network = TFNetwork(config=config, train_flag=True) @@ -2506,7 +2506,7 @@ def test_ScatterNdLayer_RangeLayer_RangeInAxisLayer(): out_layer = network.get_default_output_layer() assert isinstance(out_layer, ScatterNdLayer) assert out_layer.output.shape == (None, None, 11) - assert out_layer.output.feature_dim_axis_or_unspecified is NotSpecified and out_layer.output.feature_dim_axis == 3 + assert out_layer.output.feature_dim_axis_or_unspecified is NotSpecified and out_layer.output.feature_dim_axis == 2 assert out_layer.output.time_dim_axis == 0 session.run(tf_compat.v1.variables_initializer(tf_compat.v1.global_variables() + [network.global_train_step]))