@@ -303,6 +303,12 @@ class LearnedRelativePositionalEncoding(nn.Module):
303
303
"""
304
304
305
305
def __init__ (self , feat_dim : nn .Dim , * , clipping : int = 16 , dtype : str = "float32" ):
306
+ """
307
+ :param feat_dim: feature dim, for the emb matrix and output
308
+ :param clipping: max distance to consider. emb matrix shape is [2 * clipping + 1, feat_dim].
309
+ The first and last frame will be the clipping frames.
310
+ :param dtype: for the emb matrix and output
311
+ """
306
312
super (LearnedRelativePositionalEncoding , self ).__init__ ()
307
313
self .feat_dim = feat_dim
308
314
self .clipping = clipping
@@ -319,11 +325,12 @@ def __call__(self, spatial_dim: nn.Dim) -> Tuple[nn.Tensor, nn.Dim]:
319
325
In the center is the rel pos i-j=0. All to the right are for i-j>0, all to the left for i-j<0.
320
326
"""
321
327
out_spatial_dim = spatial_dim - 1 + spatial_dim
322
- with nn .Cond (nn .dim_value (spatial_dim ) > self .clipping ) as cond :
328
+ mat_spatial_size = self .clipping + 1
329
+ with nn .Cond (nn .dim_value (spatial_dim ) > mat_spatial_size ) as cond :
323
330
# True branch
324
331
left = nn .gather (self .pos_emb , axis = self .clipped_spatial_dim , position = 0 )
325
332
right = nn .gather (self .pos_emb , axis = self .clipped_spatial_dim , position = self .clipped_spatial_dim .dimension - 1 )
326
- remaining_dim = spatial_dim - self . clipping
333
+ remaining_dim = spatial_dim - mat_spatial_size
327
334
left = nn .expand_dim (left , dim = remaining_dim )
328
335
right = nn .expand_dim (right , dim = remaining_dim )
329
336
cond .true , out_spatial_dim_ = nn .concat (
@@ -335,7 +342,7 @@ def __call__(self, spatial_dim: nn.Dim) -> Tuple[nn.Tensor, nn.Dim]:
335
342
# False branch, spatial_dim <= self.clipping + 1
336
343
cond .false , _ = nn .slice_nd (
337
344
self .pos_emb , axis = self .clipped_spatial_dim ,
338
- start = self . clipping - nn .dim_value (spatial_dim ),
345
+ start = mat_spatial_size - nn .dim_value (spatial_dim ),
339
346
size = out_spatial_dim )
340
347
341
348
return cond .result , out_spatial_dim
0 commit comments