5
5
from __future__ import annotations
6
6
7
7
from . import _setup_test_env # noqa
8
- from .returnn_helpers import dummy_run_net , dummy_config_net_dict , dummy_default_in_dim
8
+ from .returnn_helpers import dummy_run_net , dummy_config_net_dict , dummy_default_in_dim , dummy_run_net_single_custom , \
9
+ make_feed_dict
9
10
from pprint import pprint
10
11
import typing
12
+ import functools
11
13
12
14
if typing .TYPE_CHECKING :
13
15
from .. import nn
@@ -48,7 +50,7 @@ def __call__(self, x: nn.Tensor, *, axis: nn.Dim) -> nn.Tensor:
48
50
def test_rel_pos_self_attention ():
49
51
class _Net (nn .Module ):
50
52
# noinspection PyShadowingNames
51
- def __init__ (self , in_dim : nn .FeatureDim ):
53
+ def __init__ (self , in_dim : nn .Dim ):
52
54
super ().__init__ ()
53
55
self .self_att = nn .RelPosSelfAttention (
54
56
in_dim = in_dim , proj_dim = nn .FeatureDim ("out" , 5 ),
@@ -64,3 +66,66 @@ def __call__(self, x: nn.Tensor, *, axis: nn.Dim) -> nn.Tensor:
64
66
config , net_dict , net = dummy_config_net_dict (lambda : _Net (in_dim ), with_axis = True , in_dim = in_dim )
65
67
pprint (net_dict )
66
68
dummy_run_net (config , net = net )
69
+
70
+
71
+ def test_rel_pos_self_attention_learnable ():
72
+ class _Net (nn .Module ):
73
+ # noinspection PyShadowingNames
74
+ def __init__ (self , in_dim : nn .Dim ):
75
+ super ().__init__ ()
76
+ self .self_att = nn .RelPosSelfAttention (
77
+ in_dim = in_dim , proj_dim = nn .FeatureDim ("out" , 5 ),
78
+ key_dim_total = nn .FeatureDim ("key-dim-total" , 21 ),
79
+ value_dim_total = nn .FeatureDim ("value-dim-total" , 33 ),
80
+ num_heads = 3 ,
81
+ # Shawn et al 2018 style, old RETURNN way.
82
+ with_bias = False ,
83
+ with_linear_pos = False ,
84
+ with_pos_bias = False ,
85
+ learnable_pos_emb = True ,
86
+ learnable_pos_emb_clipping = 3 ,
87
+ separate_pos_emb_per_head = False ,
88
+ )
89
+
90
+ def __call__ (self , x : nn .Tensor , * , axis : nn .Dim ) -> nn .Tensor :
91
+ """forward"""
92
+ return self .self_att (x , axis = axis )
93
+
94
+ in_dim = nn .FeatureDim ("in" , 12 )
95
+ config , net_dict , net = dummy_config_net_dict (lambda : _Net (in_dim ), with_axis = True , in_dim = in_dim )
96
+ pprint (net_dict )
97
+ dummy_run_net (config , net = net , seq_len = 3 ) # ok
98
+ dummy_run_net (config , net = net , seq_len = 3 ) # try again, to see that running again is ok.
99
+ dummy_run_net (config , net = net , seq_len = 1 ) # ok
100
+ dummy_run_net (config , net = net , seq_len = 4 ) # problem currently...
101
+
102
+
103
+ def test_learned_rel_pos_enc ():
104
+ class _Net (nn .Module ):
105
+ # noinspection PyShadowingNames
106
+ def __init__ (self , in_dim : nn .Dim ):
107
+ super ().__init__ ()
108
+ self .in_dim = in_dim
109
+ self .self_att = nn .LearnedRelativePositionalEncoding (in_dim , clipping = 3 )
110
+
111
+ def __call__ (self , x : nn .Tensor , * , axis : nn .Dim ) -> nn .Tensor :
112
+ y , _ = self .self_att (axis )
113
+ print ("y:" , y )
114
+ return y + nn .reduce (x , axis = (axis , nn .batch_dim ), mode = "mean" )
115
+
116
+ nn .reset_default_root_name_ctx ()
117
+ net = _Net (in_dim = nn .FeatureDim ("in" , 12 ))
118
+ time_dim = nn .SpatialDim ("time" )
119
+ data = nn .get_extern_data (nn .Data ("data" , dim_tags = [nn .batch_dim , time_dim , net .in_dim ]))
120
+ out = net (data , axis = time_dim )
121
+ out .mark_as_default_output ()
122
+
123
+ config_code_str = nn .get_returnn_config ().get_complete_py_code_str (net )
124
+ print (config_code_str )
125
+
126
+ for seq_len in [1 , 2 , 3 , 4 , 5 ]:
127
+ res = dummy_run_net_single_custom (
128
+ config_code_str , make_feed_dict = functools .partial (make_feed_dict , n_time = seq_len ))
129
+ shape = res ["layer:output" ].shape
130
+ print ("res shape:" , shape )
131
+ assert shape == (2 * seq_len - 1 , net .in_dim .dimension )
0 commit comments