Skip to content

Commit 6e8de68

Browse files
committed
RF test_compare_bc, test_logical_or
1 parent d0c6e2b commit 6e8de68

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

tests/test_rf_math.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
"""
44

55
from __future__ import annotations
6+
67
import _setup_test_env # noqa
8+
9+
import sys
10+
import unittest
11+
12+
from returnn.util import better_exchook
713
import returnn.frontend as rf
814
from returnn.tensor import Tensor, Dim, TensorDict, batch_dim
915
from rf_utils import run_model
@@ -30,6 +36,39 @@ def _forward_step(*, model: _Net, extern_data: TensorDict):
3036
run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)
3137

3238

39+
def test_compare_bc():
40+
beam_dim = Dim(3, name="beam")
41+
in_dim = Dim(7, name="in")
42+
extern_data = TensorDict({"idx": Tensor("idx", [batch_dim, beam_dim], dtype="int32", sparse_dim=in_dim)})
43+
44+
# noinspection PyShadowingNames,PyUnusedLocal
45+
def _forward_step(*, model: rf.Module, extern_data: TensorDict):
46+
idx = extern_data["idx"]
47+
cond = rf.compare_bc(idx, "!=", rf.range_over_dim(in_dim))
48+
cond.mark_as_default_output(shape=(batch_dim, beam_dim, in_dim))
49+
50+
run_model(extern_data, lambda *, epoch, step: rf.Module(), _forward_step)
51+
52+
53+
def test_logical_or():
54+
beam_dim = Dim(3, name="beam")
55+
in_dim = Dim(7, name="in")
56+
extern_data = TensorDict(
57+
{
58+
"a": Tensor("a", [batch_dim, beam_dim], dtype="bool"),
59+
"b": Tensor("b", [batch_dim, beam_dim, in_dim], dtype="bool"),
60+
}
61+
)
62+
63+
# noinspection PyShadowingNames,PyUnusedLocal
64+
def _forward_step(*, model: rf.Module, extern_data: TensorDict):
65+
a, b = extern_data["a"], extern_data["b"]
66+
cond = a | b
67+
cond.mark_as_default_output(shape=(batch_dim, beam_dim, in_dim))
68+
69+
run_model(extern_data, lambda *, epoch, step: rf.Module(), _forward_step)
70+
71+
3372
def test_squared_difference():
3473
time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
3574
in_dim = Dim(7, name="in")
@@ -125,3 +164,26 @@ def _forward_step(*, model: rf.Module, extern_data: TensorDict):
125164
out.mark_as_default_output(shape=(batch_dim, time_dim))
126165

127166
run_model(extern_data, lambda *, epoch, step: rf.Module(), _forward_step)
167+
168+
169+
if __name__ == "__main__":
170+
better_exchook.install()
171+
if len(sys.argv) <= 1:
172+
for k, v in sorted(globals().items()):
173+
if k.startswith("test_"):
174+
print("-" * 40)
175+
print("Executing: %s" % k)
176+
try:
177+
v()
178+
except unittest.SkipTest as exc:
179+
print("SkipTest:", exc)
180+
print("-" * 40)
181+
print("Finished all tests.")
182+
else:
183+
assert len(sys.argv) >= 2
184+
for arg in sys.argv[1:]:
185+
print("Executing: %s" % arg)
186+
if arg in globals():
187+
globals()[arg]() # assume function and execute
188+
else:
189+
eval(arg) # assume Python code and execute

0 commit comments

Comments
 (0)