@@ -27,8 +27,10 @@ def cum_concat_step(
27
27
Concatenates all previous frames of a time-axis.
28
28
See RETURNN :class:`CumConcatLayer` for details.
29
29
"""
30
- from ._generated_layers import _cum_concat
31
- return _cum_concat (source = source , state = state , out_spatial_dim = out_spatial_dim , name = name )
30
+ from ._generated_layers import rec_cum_concat
31
+ return rec_cum_concat (
32
+ source = source , axis = nn .single_step_dim ,
33
+ state = state , out_spatial_dim = out_spatial_dim , name = name )
32
34
33
35
34
36
def split (source : nn .LayerRef , * ,
@@ -40,9 +42,16 @@ def split(source: nn.LayerRef, *,
40
42
Basically a wrapper around tf.split.
41
43
"""
42
44
from ._generated_layers import _split
43
- from .base import get_sub_layer
45
+ from .base import _get_sub_layer
44
46
res = _split (source , axis = axis , out_dims = out_dims , name = name )
45
- return tuple (get_sub_layer (res , str (i )) for i in range (len (out_dims )))
47
+ src_axis_int = source .data .get_axis_from_description (axis )
48
+ return tuple (
49
+ _get_sub_layer (
50
+ layer = res , name = str (i ),
51
+ data = source .data .copy_template_replace_dim_tag (
52
+ axis = src_axis_int , new_dim_tag = dim ,
53
+ name = f"{ source .data .name } /split:{ i } :{ dim .description } " ))
54
+ for i , dim in enumerate (out_dims ))
46
55
47
56
48
57
def window (
@@ -57,8 +66,8 @@ def window(
57
66
"""
58
67
Window. See :func:`_generated_layers._window`.
59
68
"""
60
- from ._generated_layers import _window
61
- layer , state = _window (
69
+ from ._generated_layers import rec_window
70
+ layer , state = rec_window (
62
71
source ,
63
72
window_dim = window_dim , window_left = window_left , window_right = window_right ,
64
73
axis = axis , padding = padding , stride = stride ,
@@ -69,7 +78,6 @@ def window(
69
78
70
79
def window_step (
71
80
source : nn .LayerRef , * , state : nn .LayerState ,
72
- axis : nn .Dim ,
73
81
window_dim : nn .Dim ,
74
82
padding : str = NotSpecified ,
75
83
stride : int = NotSpecified ,
@@ -78,9 +86,9 @@ def window_step(
78
86
Window into the past when iterating.
79
87
See :func:`_generated_layers._window`.
80
88
"""
81
- from ._generated_layers import _window
82
- return _window (
89
+ from ._generated_layers import rec_window
90
+ return rec_window (
83
91
source , state = state ,
84
92
window_dim = window_dim , window_left = window_dim .dimension - 1 , window_right = 0 ,
85
- axis = axis , padding = padding , stride = stride ,
93
+ axis = nn . single_step_dim , padding = padding , stride = stride ,
86
94
name = name )
0 commit comments