Skip to content

Commit a870004

Browse files
committed
test_relative_positional_encoding
1 parent 51df3ba commit a870004

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tests/test_nn_attention.py

+11
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,14 @@ def __call__(self, x: nn.Tensor, *, axis: nn.Dim) -> nn.Tensor:
3232
config, net_dict, net = dummy_config_net_dict(_Net, with_axis=True)
3333
pprint(net_dict)
3434
dummy_run_net(config, net=net)
35+
36+
37+
def test_relative_positional_encoding():
38+
class _Net(nn.Module):
39+
def __call__(self, x: nn.Tensor, *, axis: nn.Dim) -> nn.Tensor:
40+
x, _ = nn.relative_positional_encoding(axis, x.feature_dim)
41+
return x
42+
43+
config, net_dict, net = dummy_config_net_dict(_Net, with_axis=True, in_dim=nn.FeatureDim("in", 12))
44+
pprint(net_dict)
45+
dummy_run_net(config, net=net)

0 commit comments

Comments
 (0)