3
3
"""
4
4
5
5
from __future__ import annotations
6
+
6
7
import _setup_test_env # noqa
8
+
9
+ import sys
10
+ import unittest
11
+
12
+ from returnn .util import better_exchook
7
13
import returnn .frontend as rf
8
14
from returnn .tensor import Tensor , Dim , TensorDict , batch_dim
9
15
from rf_utils import run_model
@@ -30,6 +36,39 @@ def _forward_step(*, model: _Net, extern_data: TensorDict):
30
36
run_model (extern_data , lambda * , epoch , step : _Net (), _forward_step )
31
37
32
38
39
+ def test_compare_bc ():
40
+ beam_dim = Dim (3 , name = "beam" )
41
+ in_dim = Dim (7 , name = "in" )
42
+ extern_data = TensorDict ({"idx" : Tensor ("idx" , [batch_dim , beam_dim ], dtype = "int32" , sparse_dim = in_dim )})
43
+
44
+ # noinspection PyShadowingNames,PyUnusedLocal
45
+ def _forward_step (* , model : rf .Module , extern_data : TensorDict ):
46
+ idx = extern_data ["idx" ]
47
+ cond = rf .compare_bc (idx , "!=" , rf .range_over_dim (in_dim ))
48
+ cond .mark_as_default_output (shape = (batch_dim , beam_dim , in_dim ))
49
+
50
+ run_model (extern_data , lambda * , epoch , step : rf .Module (), _forward_step )
51
+
52
+
53
+ def test_logical_or ():
54
+ beam_dim = Dim (3 , name = "beam" )
55
+ in_dim = Dim (7 , name = "in" )
56
+ extern_data = TensorDict (
57
+ {
58
+ "a" : Tensor ("a" , [batch_dim , beam_dim ], dtype = "bool" ),
59
+ "b" : Tensor ("b" , [batch_dim , beam_dim , in_dim ], dtype = "bool" ),
60
+ }
61
+ )
62
+
63
+ # noinspection PyShadowingNames,PyUnusedLocal
64
+ def _forward_step (* , model : rf .Module , extern_data : TensorDict ):
65
+ a , b = extern_data ["a" ], extern_data ["b" ]
66
+ cond = a | b
67
+ cond .mark_as_default_output (shape = (batch_dim , beam_dim , in_dim ))
68
+
69
+ run_model (extern_data , lambda * , epoch , step : rf .Module (), _forward_step )
70
+
71
+
33
72
def test_squared_difference ():
34
73
time_dim = Dim (Tensor ("time" , [batch_dim ], dtype = "int32" ))
35
74
in_dim = Dim (7 , name = "in" )
@@ -125,3 +164,26 @@ def _forward_step(*, model: rf.Module, extern_data: TensorDict):
125
164
out .mark_as_default_output (shape = (batch_dim , time_dim ))
126
165
127
166
run_model (extern_data , lambda * , epoch , step : rf .Module (), _forward_step )
167
+
168
+
169
+ if __name__ == "__main__" :
170
+ better_exchook .install ()
171
+ if len (sys .argv ) <= 1 :
172
+ for k , v in sorted (globals ().items ()):
173
+ if k .startswith ("test_" ):
174
+ print ("-" * 40 )
175
+ print ("Executing: %s" % k )
176
+ try :
177
+ v ()
178
+ except unittest .SkipTest as exc :
179
+ print ("SkipTest:" , exc )
180
+ print ("-" * 40 )
181
+ print ("Finished all tests." )
182
+ else :
183
+ assert len (sys .argv ) >= 2
184
+ for arg in sys .argv [1 :]:
185
+ print ("Executing: %s" % arg )
186
+ if arg in globals ():
187
+ globals ()[arg ]() # assume function and execute
188
+ else :
189
+ eval (arg ) # assume Python code and execute
0 commit comments