Skip to content

Commit 56fbd19

Browse files
committed
test_nn_transformer_search
Also related to search flag: #18
1 parent 95c201d commit 56fbd19

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

tests/test_nn_transformer.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""
2+
Test nn.transformer.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
from . import _setup_test_env # noqa
8+
from .returnn_helpers import dummy_run_net, config_net_dict_via_serialized
9+
import typing
10+
11+
if typing.TYPE_CHECKING:
12+
from .. import nn
13+
else:
14+
from returnn_common import nn # noqa
15+
16+
17+
def test_nn_transformer_search():
18+
with nn.NameCtx.new_root() as name_ctx:
19+
time_dim = nn.SpatialDim("time")
20+
input_dim = nn.FeatureDim("input", 4)
21+
data = nn.get_extern_data(nn.Data("data", dim_tags=[nn.batch_dim, time_dim, input_dim]))
22+
transformer = nn.Transformer()
23+
out, _ = transformer(data, source_spatial_axis=time_dim, search=True, beam_size=3, eos_symbol=0, name=name_ctx)
24+
out.mark_as_default_output()
25+
26+
config_code = name_ctx.get_returnn_config_serialized()
27+
config, net_dict = config_net_dict_via_serialized(config_code)
28+
dummy_run_net(config)

0 commit comments

Comments
 (0)