diff --git a/modules/interpolator.py b/modules/interpolator.py index 699c1b8..d7594aa 100644 --- a/modules/interpolator.py +++ b/modules/interpolator.py @@ -29,5 +29,5 @@ def forward(self, x, pos, H, W): [B, N, C] sampled channels at 2d positions """ grid = self.normgrid(pos, H, W).unsqueeze(-2).to(x.dtype) - x = F.grid_sample(x, grid, mode = self.mode , align_corners = False) + x = F.grid_sample(x, grid, mode = self.mode , align_corners = self.align_corners) return x.permute(0,2,3,1).squeeze(-2) \ No newline at end of file