diff --git a/src/TorchVision/Functional.cs b/src/TorchVision/Functional.cs index b26d4715f..d3b660c7b 100644 --- a/src/TorchVision/Functional.cs +++ b/src/TorchVision/Functional.cs @@ -24,6 +24,35 @@ public static partial class transforms { public static partial class functional { + + private static bool IsTensorImage(Tensor img) + { + return img.ndim >= 2; + } + + private static bool AssertTensorImage(Tensor img) + { + if (!IsTensorImage(img)) + throw new ArgumentException("Tensor is not a torch image."); + return true; + } + + /// + /// Returns the number of channels of an image. + /// + /// (Tensor) – The image to be checked. + /// The number of channels. + public static long get_image_num_channels(Tensor img) + { + AssertTensorImage(img); + var ndim_ = img.ndim; + return ndim_ switch { + 2 => 1, + > 2 => img.shape[ndim_ - 3], + _ => throw new ArgumentException($"Input ndim should be 2 or more. Got {ndim_}"), + }; + } + /// /// Get the image dimensions /// @@ -533,20 +562,29 @@ public static Tensor invert(Tensor input) /// An image tensor. /// Sequence of means for each channel. /// Sequence of standard deviations for each channel. - /// Bool to make this operation inplace. + /// Bool to make this operation inplace. /// - public static Tensor normalize(Tensor input, double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32) + public static Tensor normalize(Tensor input, double[] means, double[] stdevs, bool inplace = false) { - if (means.Length != stdevs.Length) - throw new ArgumentException("means and stdevs must be the same length in call to Normalize"); - if (means.Length != input.shape[1]) - throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations"); - - using var mean = means.ToTensor(new long[] { 1, means.Length, 1, 1 }).to(input.dtype, input.device); // Assumes NxCxHxW - using var stdev = stdevs.ToTensor(new long[] { 1, stdevs.Length, 1, 1 }).to(input.dtype, input.device); // Assumes NxCxHxW - using var t0 = input - mean; - - return t0 / stdev; + using var _ = NewDisposeScope(); + AssertTensorImage(input); + if (!input.is_floating_point()) + throw new ArgumentException($"Input tensor should be a float tensor. Got {input.dtype}."); + if (input.ndim < 3) + throw new ArgumentException($"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = ({string.Join(", ", input.shape)})"); + if (!inplace) + input = input.clone(); + + + var mean = as_tensor(means, dtype: input.dtype, device: input.device); + var stdev = as_tensor(stdevs, dtype: input.dtype, device: input.device); + if ((stdev == 0).any().item()) + throw new ArgumentException($"std evaluated to zero after conversion to {input.dtype}, leading to division by zero."); + if (mean.ndim == 1) + mean = mean.view(-1, 1, 1); + if (stdev.ndim == 1) + stdev = stdev.view(-1, 1, 1); + return input.sub_(mean).div_(stdev).MoveToOuterDisposeScope(); } private static Tensor _pad(Tensor input, ReadOnlySpan padding, double fill = 0, PaddingModes padding_mode = PaddingModes.Constant) diff --git a/src/TorchVision/Normalize.cs b/src/TorchVision/Normalize.cs index 27ec1bb44..6bcb597b9 100644 --- a/src/TorchVision/Normalize.cs +++ b/src/TorchVision/Normalize.cs @@ -7,62 +7,32 @@ namespace TorchSharp { public static partial class torchvision { - internal class Normalize : ITransform, IDisposable + internal class Normalize : ITransform { - internal Normalize(double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32, torch.Device? device = null) + internal Normalize(double[] means, double[] stdevs,bool inplace = false) { if (means is null) throw new ArgumentNullException(nameof(means)); if (stdevs is null) throw new ArgumentNullException(nameof(stdevs)); if (means.Length != stdevs.Length) throw new ArgumentException($"{nameof(means)} and {nameof(stdevs)} must be the same length in call to Normalize"); - if (means.Length != 1 && means.Length != 3) - throw new ArgumentException($"Since they correspond to the number of channels in an image, {nameof(means)} and {nameof(stdevs)} must both be either 1 or 3 long"); + this.means = means; + this.stdevs = stdevs; + this.inplace = inplace; - this.means = means.ToTensor(new long[] { 1, means.Length, 1, 1 }); // Assumes NxCxHxW - this.stdevs = stdevs.ToTensor(new long[] { 1, stdevs.Length, 1, 1 }); // Assumes NxCxHxW - - if (dtype != ScalarType.Float64) { - this.means = this.means.to_type(dtype); - this.stdevs = this.stdevs.to_type(dtype); - } - - if (device != null && device.type != DeviceType.CPU) { - this.means = this.means.to(device); - this.stdevs = this.stdevs.to(device); - } } public Tensor call(Tensor input) { - if (means.size(1) != input.size(1)) throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations"); - return (input - means) / stdevs; - } - - private Tensor means; - private Tensor stdevs; - bool disposedValue; - - protected virtual void Dispose(bool disposing) - { - if (!disposedValue) { - means?.Dispose(); - stdevs?.Dispose(); - disposedValue = true; - } + var expectedChannels = transforms.functional.get_image_num_channels(input); + if (expectedChannels != means.Length) + throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations"); + return transforms.functional.normalize(input, means, stdevs, inplace); } - ~Normalize() - { - // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method - Dispose(disposing: false); - } - - public void Dispose() - { - // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method - Dispose(disposing: true); - GC.SuppressFinalize(this); - } + private readonly double[] means; + private readonly double[] stdevs; + private readonly bool inplace; + } public static partial class transforms @@ -72,12 +42,11 @@ public static partial class transforms /// /// Sequence of means for each channel. /// Sequence of standard deviations for each channel. - /// Bool to make this operation inplace. - /// The device to place the output tensor on. + /// Bool to make this operation inplace. /// - static public ITransform Normalize(double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32, torch.Device? device = null) + static public ITransform Normalize(double[] means, double[] stdevs, bool inplace = false) { - return new Normalize(means, stdevs, dtype, device); + return new Normalize(means, stdevs, inplace); } } } diff --git a/test/TorchSharpTest/TestTorchVision.cs b/test/TorchSharpTest/TestTorchVision.cs index c8f1bc341..4fd563244 100644 --- a/test/TorchSharpTest/TestTorchVision.cs +++ b/test/TorchSharpTest/TestTorchVision.cs @@ -845,17 +845,6 @@ public void TestConstructor_ThrowsArgumentException_IfMeansAndStdevsHaveDifferen Assert.Throws(() => Normalize(means, stdevs)); } - [Fact] - public void TestConstructor_ThrowsArgumentException_IfMeansAndStdevsHaveWrongLengths() - { - // Arrange - double[] means = { 0.485, 0.456 }; - double[] stdevs = { 0.229, 0.224 }; // Not 1 or 3 - - // Act & Assert - Assert.Throws(() => Normalize(means, stdevs)); - } - [Fact] public void TestConstructor_CreatesNewNormalizeObject_WithValidArguments() {