Skip to content

Commit edfd4f5

Browse files
author
Masaru Kimura
committed
Add torchvision.transforms.Resize interpolation and antialias.
torchvision.transforms.Resize forced nearest interpolation and no antialias, but shouln't. Based on my understanding, original torchvision.transforms.Resize calls like; - torchvision.transforms.Resize - torchvision.transforms.functional.resize - torchvision.transforms._functional_pil.resize - PIL.Image.Image.resize - torchvision.transforms._functional_tensor.resize - torch.nn.functional.interpolate Note, this PR still keeps nearest interpolation and no antialias by default for torchvision.transforms.Resize to maximize compatibility for existing code using TorchSharp and make it being incompatible to original torchvision.transforms.Resize default, however, it would be up to the upstream decision. See also; * https://pytorch.org/vision/main/generated/torchvision.transforms.Resize.html * https://pytorch.org/vision/main/generated/torchvision.transforms.functional.resize.html * https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
1 parent 3760ba3 commit edfd4f5

File tree

8 files changed

+84
-22
lines changed

8 files changed

+84
-22
lines changed

src/Native/LibTorchSharp/THSNN.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ void ApplyInterpolateMode(T& opts, const int8_t mode)
109109
opts = opts.mode(torch::kTrilinear);
110110
if (mode == 5)
111111
opts = opts.mode(torch::kArea);
112+
if (mode == 6)
113+
opts = opts.mode(torch::kNearestExact);
112114
}
113115

114116
template<typename T>
@@ -176,13 +178,14 @@ Tensor THSNN_affine_grid(const Tensor theta, const int64_t* size, const int size
176178
}
177179

178180

179-
EXPORT_API(Tensor) THSNN_interpolate(const Tensor input, const int64_t* size, const int size_len, const double* scale_factor, const int scale_factor_len, const int8_t mode, const int8_t align_corners, const bool recompute_scale_factor, NNAnyModule* outAsAnyModule)
181+
EXPORT_API(Tensor) THSNN_interpolate(const Tensor input, const int64_t* size, const int size_len, const double* scale_factor, const int scale_factor_len, const int8_t mode, const int8_t align_corners, const bool recompute_scale_factor, const bool antialias, NNAnyModule* outAsAnyModule)
180182
{
181183
auto opts = torch::nn::functional::InterpolateFuncOptions().recompute_scale_factor(recompute_scale_factor);
182184
// align_corners -- 0=None, 1=true, 2=false
183185
if (align_corners != 0)
184186
opts.align_corners(align_corners == 1);
185187
ApplyInterpolateMode(opts, mode);
188+
opts.antialias(antialias);
186189

187190
if (size_len > 0) {
188191
std::vector<int64_t> sizes;

src/Native/LibTorchSharp/THSNN.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ EXPORT_API(Tensor) THSNN_pixel_unshuffle(const Tensor tensor, const int64_t do
7171
// Vision -- Functions
7272

7373
EXPORT_API(Tensor) THSNN_pad(const Tensor input, const int64_t* pad, const int pad_length, const int8_t mode, const double value);
74-
EXPORT_API(Tensor) THSNN_interpolate(const Tensor input, const int64_t* size, const int size_len, const double* scale_factor, const int scale_factor_len, const int8_t mode, const int8_t align_corners, const bool recompute_scale_factor, NNAnyModule* outAsAnyModule);
74+
EXPORT_API(Tensor) THSNN_interpolate(const Tensor input, const int64_t* size, const int size_len, const double* scale_factor, const int scale_factor_len, const int8_t mode, const int8_t align_corners, const bool recompute_scale_factor, const bool antialias, NNAnyModule* outAsAnyModule);
7575
EXPORT_API(Tensor) THSNN_grid_sample(const Tensor input, const Tensor grid, const int8_t mode, const int8_t padding_mode, const int8_t align_corners);
7676
EXPORT_API(Tensor) THSNN_affine_grid(const Tensor theta, const int64_t* size, const int size_len, const bool align_corners);
7777

src/TorchSharp/NN/Vision.cs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ public enum InterpolationMode
2323
Bilinear = 2,
2424
Bicubic = 3,
2525
Trilinear = 4,
26-
Area = 5
26+
Area = 5,
27+
NearestExact = 6
2728
}
2829

2930
public enum GridSampleMode
@@ -194,7 +195,7 @@ public static Tensor affine_grid(Tensor theta, long[]? size = null, bool align_c
194195
/// <param name="x">The input tensor</param>
195196
/// <param name="size">Output spatial size</param>
196197
/// <param name="scale_factor">Multiplier for spatial size. Has to match input size if it is a tuple.</param>
197-
/// <param name="mode">The algorithm used for upsampling: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'</param>
198+
/// <param name="mode">The algorithm used for upsampling: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area' | 'nearest-exact'</param>
198199
/// <param name="align_corners">Geometrically, we consider the pixels of the input and output as squares rather than points.
199200
/// If set to true, the input and output tensors are aligned by the center points of their corner pixels, preserving the values at the corner pixels.
200201
/// If set to false, the input and output tensors are aligned by the corner points of their corner pixels, and the interpolation uses edge value padding for out-of-boundary values, making this operation independent of input size when scale_factor is kept the same.</param>
@@ -205,14 +206,19 @@ public static Tensor affine_grid(Tensor theta, long[]? size = null, bool align_c
205206
/// Otherwise, a new scale_factor will be computed based on the output and input sizes for use in the interpolation computation
206207
/// (i.e. the computation will be identical to if the computed output_size were passed-in explicitly).
207208
/// </param>
209+
/// <param name="antialias">
210+
/// Flag to apply anti-aliasing. Using anti-alias
211+
/// option together with align_corners = false, interpolation result would match Pillow
212+
/// result for downsampling operation. Supported modes: 'bilinear', 'bicubic'.
213+
/// </param>
208214
/// <returns></returns>
209-
public static Tensor interpolate(Tensor x, long[]? size = null, double[]? scale_factor = null, InterpolationMode mode = InterpolationMode.Nearest, bool? align_corners = null, bool recompute_scale_factor = false)
215+
public static Tensor interpolate(Tensor x, long[]? size = null, double[]? scale_factor = null, InterpolationMode mode = InterpolationMode.Nearest, bool? align_corners = null, bool recompute_scale_factor = false, bool antialias = false)
210216
{
211217
unsafe {
212218
fixed (long* psize = size) {
213219
fixed (double* pSF = scale_factor) {
214220
byte ac = (byte)((align_corners.HasValue) ? (align_corners.Value ? 1 : 2) : 0);
215-
var res = THSNN_interpolate(x.Handle, (IntPtr)psize, size is null ? 0 : size.Length, (IntPtr)pSF, scale_factor is null ? 0 : scale_factor.Length, (byte)mode, ac, recompute_scale_factor);
221+
var res = THSNN_interpolate(x.Handle, (IntPtr)psize, size is null ? 0 : size.Length, (IntPtr)pSF, scale_factor is null ? 0 : scale_factor.Length, (byte)mode, ac, recompute_scale_factor, antialias);
216222
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
217223
return new Tensor(res);
218224
}

src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ internal static extern IntPtr THSNN_custom_module(
4444

4545
[DllImport("LibTorchSharp")]
4646
// align_corners -- 0=None, 1=true, 2=false
47-
internal static extern IntPtr THSNN_interpolate(IntPtr input, IntPtr size, int size_len, IntPtr scale_factor, int scale_factor_len, byte mode, byte align_corners, [MarshalAs(UnmanagedType.U1)] bool recompute_scale_factor);
47+
internal static extern IntPtr THSNN_interpolate(IntPtr input, IntPtr size, int size_len, IntPtr scale_factor, int scale_factor_len, byte mode, byte align_corners, [MarshalAs(UnmanagedType.U1)] bool recompute_scale_factor, [MarshalAs(UnmanagedType.U1)] bool antialias);
4848

4949
[DllImport("LibTorchSharp")]
5050
// align_corners -- 0=None, 1=true, 2=false

src/TorchVision/Functional.cs

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -694,13 +694,22 @@ public static Tensor posterize(Tensor input, int bits)
694694
/// <param name="input">An image tensor.</param>
695695
/// <param name="height">The height of the resized image. Must be > 0.</param>
696696
/// <param name="width">The width of the resized image. Must be > 0.</param>
697+
/// <param name="interpolation">
698+
/// Desired interpolation enum defined by TorchSharp.torch.InterpolationMode.
699+
/// Default is InterpolationMode.Nearest; not InterpolationMode.Bilinear (incompatible to Python's torchvision v0.17 or later for historical reasons).
700+
/// Only InterpolationMode.Nearest, InterpolationMode.NearestExact, InterpolationMode.Bilinear and InterpolationMode.Bicubic are supported.
701+
/// </param>
697702
/// <param name="maxSize">The maximum allowed for the longer edge of the resized image.</param>
703+
/// <param name="antialias">
704+
/// Whether to apply antialiasing.
705+
/// It only affects bilinear or bicubic modes and it is ignored otherwise.
706+
/// Possible values are:
707+
/// * true: will apply antialiasing for bilinear or bicubic modes. Other mode aren't affected. This is probably what you want to use.
708+
/// * false (default, incompatible to Python's torchvision v0.17 or later for historical reasons): will not apply antialiasing on any mode.
709+
/// </param>
698710
/// <returns></returns>
699-
public static Tensor resize(Tensor input, int height, int width, int? maxSize = null)
711+
public static Tensor resize(Tensor input, int height, int width, InterpolationMode interpolation = InterpolationMode.Nearest, int ? maxSize = null, bool antialias = false)
700712
{
701-
// For now, we don't allow any other modes.
702-
const InterpolationMode interpolation = InterpolationMode.Nearest;
703-
704713
var hoffset = input.Dimensions - 2;
705714
var iHeight = input.shape[hoffset];
706715
var iWidth = input.shape[hoffset + 1];
@@ -727,9 +736,12 @@ public static Tensor resize(Tensor input, int height, int width, int? maxSize =
727736
}
728737
}
729738

739+
if (antialias && interpolation != InterpolationMode.Bilinear && interpolation != InterpolationMode.Bicubic)
740+
antialias = false;
741+
730742
using var img0 = SqueezeIn(input, new ScalarType[] { ScalarType.Float32, ScalarType.Float64 }, out var needCast, out var needSqueeze, out var dtype);
731743

732-
using var img1 = torch.nn.functional.interpolate(img0, new long[] { h, w }, mode: interpolation, align_corners: null);
744+
using var img1 = torch.nn.functional.interpolate(img0, new long[] { h, w }, mode: interpolation, align_corners: null, antialias: antialias);
733745

734746
return SqueezeOut(img1, needCast, needSqueeze, dtype);
735747
}

src/TorchVision/Resize.cs

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,24 @@ public static partial class torchvision
88
{
99
internal class Resize : ITransform
1010
{
11-
internal Resize(int height, int width, int? maxSize)
11+
internal Resize(int height, int width, InterpolationMode interpolation, int? maxSize, bool antialias)
1212
{
1313
this.height = height;
1414
this.width = width;
15+
this.interpolation = interpolation;
1516
this.maxSize = maxSize;
17+
this.antialias = antialias;
1618
}
1719

1820
public Tensor call(Tensor input)
1921
{
20-
return transforms.functional.resize(input, height, width, maxSize);
22+
return transforms.functional.resize(input, height, width, interpolation, maxSize, antialias);
2123
}
2224

2325
private int height, width;
26+
private InterpolationMode interpolation;
2427
private int? maxSize;
28+
private bool antialias;
2529
}
2630

2731
public static partial class transforms
@@ -31,20 +35,45 @@ public static partial class transforms
3135
/// </summary>
3236
/// <param name="height">Desired output height</param>
3337
/// <param name="width">Desired output width</param>
38+
/// <param name="interpolation">
39+
/// Desired interpolation enum defined by TorchSharp.torch.InterpolationMode.
40+
/// Default is InterpolationMode.Nearest; not InterpolationMode.Bilinear (incompatible to Python's torchvision v0.17 or later for historical reasons).
41+
/// Only InterpolationMode.Nearest, InterpolationMode.NearestExact, InterpolationMode.Bilinear and InterpolationMode.Bicubic are supported.
42+
/// </param>
43+
/// <param name="maxSize">The maximum allowed for the longer edge of the resized image.</param>
44+
/// <param name="antialias">
45+
/// Whether to apply antialiasing.
46+
/// It only affects bilinear or bicubic modes and it is ignored otherwise.
47+
/// Possible values are:
48+
/// * true: will apply antialiasing for bilinear or bicubic modes. Other mode aren't affected. This is probably what you want to use.
49+
/// * false (default, incompatible to Python's torchvision v0.17 or later for historical reasons): will not apply antialiasing on any mode.
50+
/// </param>
3451
/// <returns></returns>
35-
static public ITransform Resize(int height, int width)
52+
static public ITransform Resize(int height, int width, InterpolationMode interpolation = InterpolationMode.Nearest, int? maxSize = null, bool antialias = false)
3653
{
37-
return new Resize(height, width, null);
54+
return new Resize(height, width, interpolation, maxSize, antialias);
3855
}
3956

4057
/// <summary>
4158
/// Resize the input image to the given size.
4259
/// </summary>
4360
/// <param name="size">Desired output size</param>
44-
/// <param name="maxSize">Max size</param>
45-
static public ITransform Resize(int size, int? maxSize = null)
61+
/// <param name="interpolation">
62+
/// Desired interpolation enum defined by TorchSharp.torch.InterpolationMode.
63+
/// Default is InterpolationMode.Nearest; not InterpolationMode.Bilinear (incompatible to Python's torchvision v0.17 or later for historical reasons).
64+
/// Only InterpolationMode.Nearest, InterpolationMode.NearestExact, InterpolationMode.Bilinear and InterpolationMode.Bicubic are supported.
65+
/// </param>
66+
/// <param name="maxSize">The maximum allowed for the longer edge of the resized image.</param>
67+
/// <param name="antialias">
68+
/// Whether to apply antialiasing.
69+
/// It only affects bilinear or bicubic modes and it is ignored otherwise.
70+
/// Possible values are:
71+
/// * true: will apply antialiasing for bilinear or bicubic modes. Other mode aren't affected. This is probably what you want to use.
72+
/// * false (default, incompatible to Python's torchvision v0.17 or later for historical reasons): will not apply antialiasing on any mode.
73+
/// </param>
74+
static public ITransform Resize(int size, InterpolationMode interpolation = InterpolationMode.Nearest, int? maxSize = null, bool antialias = false)
4675
{
47-
return new Resize(size, -1, maxSize);
76+
return new Resize(size, -1, interpolation, maxSize, antialias);
4877
}
4978
}
5079
}

test/TorchSharpTest/NN.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6668,6 +6668,18 @@ public void TestInterpolateTrilinear()
66686668
}
66696669
}
66706670

6671+
[Fact]
6672+
public void TestInterpolateNearestExact()
6673+
{
6674+
foreach (var device in TestUtils.AvailableDevices()) {
6675+
using (Tensor input = torch.arange(1, 5, float32, device: device).view(1, 1, 2, 2))
6676+
using (var res = interpolate(input, scale_factor: new double[] { 2, 2 }, mode: InterpolationMode.NearestExact)) {
6677+
Assert.Equal(device.type, res.device_type);
6678+
Assert.Equal(new long[] { 1, 1, 4, 4 }, res.shape);
6679+
}
6680+
}
6681+
}
6682+
66716683
[Fact]
66726684
public void TestUpsampleNearest()
66736685
{

test/TorchSharpTest/TestTorchVision.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,7 @@ public void Resize_WithSizeAndMaxSize_ReturnsTensor()
938938
int size = 20;
939939
int? maxSize = 30;
940940
var input = torch.randn(1, 3, 256, 256);
941-
var transform = Resize(size, maxSize);
941+
var transform = Resize(size, maxSize: maxSize);
942942

943943
//Act
944944
var result = transform.call(input);
@@ -1345,7 +1345,7 @@ public void Resize_WhenMaxSizeNotMet_ThrowsArgumentException()
13451345
int? maxSize = 8;
13461346

13471347
// Act + Assert
1348-
Assert.Throws<System.ArgumentException>(() => functional.resize(input, height, -1, maxSize));
1348+
Assert.Throws<System.ArgumentException>(() => functional.resize(input, height, -1, maxSize: maxSize));
13491349
}
13501350

13511351
[Fact]
@@ -1357,7 +1357,7 @@ public void Resize_WhenMaxSizeMet_DoesNotThrowException()
13571357
int? maxSize = 10;
13581358

13591359
// Act + Assert
1360-
functional.resize(input, height, -1, maxSize);
1360+
functional.resize(input, height, -1, maxSize: maxSize);
13611361
}
13621362

13631363

0 commit comments

Comments
 (0)