Skip to content

Commit 225b5b7

Browse files
committed
test_rel_pos_self_attention_learnable, test_learned_rel_pos_enc
1 parent 41b0714 commit 225b5b7

File tree

1 file changed

+67
-2
lines changed

1 file changed

+67
-2
lines changed

tests/test_nn_attention.py

+67-2
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
from __future__ import annotations
66

77
from . import _setup_test_env # noqa
8-
from .returnn_helpers import dummy_run_net, dummy_config_net_dict, dummy_default_in_dim
8+
from .returnn_helpers import dummy_run_net, dummy_config_net_dict, dummy_default_in_dim, dummy_run_net_single_custom, \
9+
make_feed_dict
910
from pprint import pprint
1011
import typing
12+
import functools
1113

1214
if typing.TYPE_CHECKING:
1315
from .. import nn
@@ -48,7 +50,7 @@ def __call__(self, x: nn.Tensor, *, axis: nn.Dim) -> nn.Tensor:
4850
def test_rel_pos_self_attention():
4951
class _Net(nn.Module):
5052
# noinspection PyShadowingNames
51-
def __init__(self, in_dim: nn.FeatureDim):
53+
def __init__(self, in_dim: nn.Dim):
5254
super().__init__()
5355
self.self_att = nn.RelPosSelfAttention(
5456
in_dim=in_dim, proj_dim=nn.FeatureDim("out", 5),
@@ -64,3 +66,66 @@ def __call__(self, x: nn.Tensor, *, axis: nn.Dim) -> nn.Tensor:
6466
config, net_dict, net = dummy_config_net_dict(lambda: _Net(in_dim), with_axis=True, in_dim=in_dim)
6567
pprint(net_dict)
6668
dummy_run_net(config, net=net)
69+
70+
71+
def test_rel_pos_self_attention_learnable():
72+
class _Net(nn.Module):
73+
# noinspection PyShadowingNames
74+
def __init__(self, in_dim: nn.Dim):
75+
super().__init__()
76+
self.self_att = nn.RelPosSelfAttention(
77+
in_dim=in_dim, proj_dim=nn.FeatureDim("out", 5),
78+
key_dim_total=nn.FeatureDim("key-dim-total", 21),
79+
value_dim_total=nn.FeatureDim("value-dim-total", 33),
80+
num_heads=3,
81+
# Shawn et al 2018 style, old RETURNN way.
82+
with_bias=False,
83+
with_linear_pos=False,
84+
with_pos_bias=False,
85+
learnable_pos_emb=True,
86+
learnable_pos_emb_clipping=3,
87+
separate_pos_emb_per_head=False,
88+
)
89+
90+
def __call__(self, x: nn.Tensor, *, axis: nn.Dim) -> nn.Tensor:
91+
"""forward"""
92+
return self.self_att(x, axis=axis)
93+
94+
in_dim = nn.FeatureDim("in", 12)
95+
config, net_dict, net = dummy_config_net_dict(lambda: _Net(in_dim), with_axis=True, in_dim=in_dim)
96+
pprint(net_dict)
97+
dummy_run_net(config, net=net, seq_len=3) # ok
98+
dummy_run_net(config, net=net, seq_len=3) # try again, to see that running again is ok.
99+
dummy_run_net(config, net=net, seq_len=1) # ok
100+
dummy_run_net(config, net=net, seq_len=4) # problem currently...
101+
102+
103+
def test_learned_rel_pos_enc():
104+
class _Net(nn.Module):
105+
# noinspection PyShadowingNames
106+
def __init__(self, in_dim: nn.Dim):
107+
super().__init__()
108+
self.in_dim = in_dim
109+
self.self_att = nn.LearnedRelativePositionalEncoding(in_dim, clipping=3)
110+
111+
def __call__(self, x: nn.Tensor, *, axis: nn.Dim) -> nn.Tensor:
112+
y, _ = self.self_att(axis)
113+
print("y:", y)
114+
return y + nn.reduce(x, axis=(axis, nn.batch_dim), mode="mean")
115+
116+
nn.reset_default_root_name_ctx()
117+
net = _Net(in_dim=nn.FeatureDim("in", 12))
118+
time_dim = nn.SpatialDim("time")
119+
data = nn.get_extern_data(nn.Data("data", dim_tags=[nn.batch_dim, time_dim, net.in_dim]))
120+
out = net(data, axis=time_dim)
121+
out.mark_as_default_output()
122+
123+
config_code_str = nn.get_returnn_config().get_complete_py_code_str(net)
124+
print(config_code_str)
125+
126+
for seq_len in [1, 2, 3, 4, 5]:
127+
res = dummy_run_net_single_custom(
128+
config_code_str, make_feed_dict=functools.partial(make_feed_dict, n_time=seq_len))
129+
shape = res["layer:output"].shape
130+
print("res shape:", shape)
131+
assert shape == (2 * seq_len - 1, net.in_dim.dimension)

0 commit comments

Comments
 (0)