@@ -2473,8 +2473,8 @@ def test_TileLayer():
2473
2473
def test_ScatterNdLayer_RangeLayer_RangeInAxisLayer ():
2474
2474
from returnn .tf .util .data import BatchDim , DimensionTag , ImplicitDynSizeDim
2475
2475
n_batch , n_time , n_ts , n_in , n_out = 2 , 3 , 6 , 7 , 11
2476
- time_dim = DimensionTag (kind = DimensionTag .Types .Spatial , description = "T " )
2477
- feat_dim = DimensionTag (kind = DimensionTag .Types .Feature , description = "F " , dimension = n_in )
2476
+ time_dim = DimensionTag (kind = DimensionTag .Types .Spatial , description = "time " )
2477
+ feat_dim = DimensionTag (kind = DimensionTag .Types .Feature , description = "in-feature " , dimension = n_in )
2478
2478
ts_dim = DimensionTag (kind = DimensionTag .Types .Spatial , description = "ts" , dimension = n_ts )
2479
2479
rnd = numpy .random .RandomState (42 )
2480
2480
config = Config ({
@@ -2489,9 +2489,9 @@ def test_ScatterNdLayer_RangeLayer_RangeInAxisLayer():
2489
2489
"add_t" : {
2490
2490
"class" : "combine" , "kind" : "add" , "from" : ["t" , "range" ],
2491
2491
"out_shape" : {time_dim , ts_dim , ImplicitDynSizeDim (BatchDim )}}, # (T,Ts)
2492
- "t_rel_var" : {"class" : "variable" , "shape" : (n_ts , n_out ), "init" : "glorot_uniform" }, # (B, Ts,D)
2493
- "output" : {"class" : "scatter_nd" , "from" : "t_rel_var" , "position" : "add_t" , "position_axis" : - 1 ,
2494
- "output_dim_via_time_from" : "data" , "filter_invalid_indices" : True }
2492
+ "t_rel_var" : {"class" : "variable" , "shape" : (ts_dim , n_out ), "init" : "glorot_uniform" }, # (Ts,D)
2493
+ "output" : {"class" : "scatter_nd" , "from" : "t_rel_var" , "position" : "add_t" , "position_axis" : ts_dim ,
2494
+ "output_dim_via_time_from" : "data" , "filter_invalid_indices" : True } # (T,T,Ts,D)
2495
2495
}
2496
2496
with make_scope () as session :
2497
2497
network = TFNetwork (config = config , train_flag = True )
0 commit comments