Skip to content

Commit 5e223b2

Browse files
committed
LearnedRelativePositionalEncoding, fix out shape
1 parent 225b5b7 commit 5e223b2

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

nn/attention.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,12 @@ class LearnedRelativePositionalEncoding(nn.Module):
303303
"""
304304

305305
def __init__(self, feat_dim: nn.Dim, *, clipping: int = 16, dtype: str = "float32"):
306+
"""
307+
:param feat_dim: feature dim, for the emb matrix and output
308+
:param clipping: max distance to consider. emb matrix shape is [2 * clipping + 1, feat_dim].
309+
The first and last frames of the emb matrix are the clipping frames, i.e. they are reused for all relative distances beyond the clipping range.
310+
:param dtype: for the emb matrix and output
311+
"""
306312
super(LearnedRelativePositionalEncoding, self).__init__()
307313
self.feat_dim = feat_dim
308314
self.clipping = clipping
@@ -319,11 +325,12 @@ def __call__(self, spatial_dim: nn.Dim) -> Tuple[nn.Tensor, nn.Dim]:
319325
In the center is the rel pos i-j=0. All to the right are for i-j>0, all to the left for i-j<0.
320326
"""
321327
out_spatial_dim = spatial_dim - 1 + spatial_dim
322-
with nn.Cond(nn.dim_value(spatial_dim) > self.clipping) as cond:
328+
mat_spatial_size = self.clipping + 1
329+
with nn.Cond(nn.dim_value(spatial_dim) > mat_spatial_size) as cond:
323330
# True branch
324331
left = nn.gather(self.pos_emb, axis=self.clipped_spatial_dim, position=0)
325332
right = nn.gather(self.pos_emb, axis=self.clipped_spatial_dim, position=self.clipped_spatial_dim.dimension - 1)
326-
remaining_dim = spatial_dim - self.clipping
333+
remaining_dim = spatial_dim - mat_spatial_size
327334
left = nn.expand_dim(left, dim=remaining_dim)
328335
right = nn.expand_dim(right, dim=remaining_dim)
329336
cond.true, out_spatial_dim_ = nn.concat(
@@ -335,7 +342,7 @@ def __call__(self, spatial_dim: nn.Dim) -> Tuple[nn.Tensor, nn.Dim]:
335342
# False branch, spatial_dim <= mat_spatial_size (= self.clipping + 1)
336343
cond.false, _ = nn.slice_nd(
337344
self.pos_emb, axis=self.clipped_spatial_dim,
338-
start=self.clipping - nn.dim_value(spatial_dim),
345+
start=mat_spatial_size - nn.dim_value(spatial_dim),
339346
size=out_spatial_dim)
340347

341348
return cond.result, out_spatial_dim

0 commit comments

Comments
 (0)