Skip to content

Commit 574e48a

Browse files
committed
explicit parameter handling
Fix #82
1 parent bfb948f commit 574e48a

File tree

5 files changed

+320
-126
lines changed

5 files changed

+320
-126
lines changed

nn/array_.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ def cum_concat_step(
2727
Concatenates all previous frames of a time-axis.
2828
See RETURNN :class:`CumConcatLayer` for details.
2929
"""
30-
from ._generated_layers import _cum_concat
31-
return _cum_concat(source=source, state=state, out_spatial_dim=out_spatial_dim, name=name)
30+
from ._generated_layers import rec_cum_concat
31+
return rec_cum_concat(
32+
source=source, axis=nn.single_step_dim,
33+
state=state, out_spatial_dim=out_spatial_dim, name=name)
3234

3335

3436
def split(source: nn.LayerRef, *,
@@ -40,9 +42,16 @@ def split(source: nn.LayerRef, *,
4042
Basically a wrapper around tf.split.
4143
"""
4244
from ._generated_layers import _split
43-
from .base import get_sub_layer
45+
from .base import _get_sub_layer
4446
res = _split(source, axis=axis, out_dims=out_dims, name=name)
45-
return tuple(get_sub_layer(res, str(i)) for i in range(len(out_dims)))
47+
src_axis_int = source.data.get_axis_from_description(axis)
48+
return tuple(
49+
_get_sub_layer(
50+
layer=res, name=str(i),
51+
data=source.data.copy_template_replace_dim_tag(
52+
axis=src_axis_int, new_dim_tag=dim,
53+
name=f"{source.data.name}/split:{i}:{dim.description}"))
54+
for i, dim in enumerate(out_dims))
4655

4756

4857
def window(
@@ -57,8 +66,8 @@ def window(
5766
"""
5867
Window. See :func:`_generated_layers._window`.
5968
"""
60-
from ._generated_layers import _window
61-
layer, state = _window(
69+
from ._generated_layers import rec_window
70+
layer, state = rec_window(
6271
source,
6372
window_dim=window_dim, window_left=window_left, window_right=window_right,
6473
axis=axis, padding=padding, stride=stride,
@@ -69,7 +78,6 @@ def window(
6978

7079
def window_step(
7180
source: nn.LayerRef, *, state: nn.LayerState,
72-
axis: nn.Dim,
7381
window_dim: nn.Dim,
7482
padding: str = NotSpecified,
7583
stride: int = NotSpecified,
@@ -78,9 +86,9 @@ def window_step(
7886
Window into the past when iterating.
7987
See :func:`_generated_layers._window`.
8088
"""
81-
from ._generated_layers import _window
82-
return _window(
89+
from ._generated_layers import rec_window
90+
return rec_window(
8391
source, state=state,
8492
window_dim=window_dim, window_left=window_dim.dimension - 1, window_right=0,
85-
axis=axis, padding=padding, stride=stride,
93+
axis=nn.single_step_dim, padding=padding, stride=stride,
8694
name=name)

0 commit comments

Comments
 (0)