From a6eec896ef9dad724b439fe8e458371b7b5339f3 Mon Sep 17 00:00:00 2001 From: AhmedZero Date: Thu, 26 Dec 2024 17:41:37 +0200 Subject: [PATCH 1/8] Refactor Normalize method like pytorch. --- src/TorchVision/Functional.cs | 46 ++++++++++++++++++++++++++--------- src/TorchVision/Normalize.cs | 34 +++++++++----------------- 2 files changed, 45 insertions(+), 35 deletions(-) diff --git a/src/TorchVision/Functional.cs b/src/TorchVision/Functional.cs index 0f8b00259..081d1244e 100644 --- a/src/TorchVision/Functional.cs +++ b/src/TorchVision/Functional.cs @@ -24,6 +24,19 @@ public static partial class transforms { public static partial class functional { + + private static bool IsTensorImage(Tensor img) + { + return img.ndim >= 2; + } + + private static bool AssertTensorImage(Tensor img) + { + if (!IsTensorImage(img)) + throw new ArgumentException("Tensor is not a torch image."); + return true; + } + /// /// Get the image dimensions /// @@ -533,20 +546,29 @@ public static Tensor invert(Tensor input) /// An image tensor. /// Sequence of means for each channel. /// Sequence of standard deviations for each channel. - /// Bool to make this operation inplace. + /// Bool to make this operation inplace. /// - public static Tensor normalize(Tensor input, double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32) + public static Tensor normalize(Tensor input, double[] means, double[] stdevs, bool inplace = false) { - if (means.Length != stdevs.Length) - throw new ArgumentException("means and stdevs must be the same length in call to Normalize"); - if (means.Length != input.shape[1]) - throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations"); - - using var mean = means.ToTensor(new long[] { 1, means.Length, 1, 1 }).to(input.dtype, input.device); // Assumes NxCxHxW - using var stdev = stdevs.ToTensor(new long[] { 1, stdevs.Length, 1, 1 }).to(input.dtype, input.device); // Assumes NxCxHxW - using var t0 = input - mean; - - return t0 / stdev; + using var _ = NewDisposeScope(); + AssertTensorImage(input); + if (!input.is_floating_point()) + throw new ArgumentException($"Input tensor should be a float tensor. Got {input.dtype}."); + if (input.ndim < 3) + throw new ArgumentException($"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {input.size()}"); + if (!inplace) + input = input.clone(); + + + var mean = as_tensor(means, dtype: input.dtype, device: input.device); + var stdev = as_tensor(stdevs, dtype: input.dtype, device: input.device); + if (stdev.eq(0).any().ToBoolean()) + throw new ArgumentException($"std evaluated to zero after conversion to {input.dtype}, leading to division by zero."); + if (mean.ndim == 1) + mean = mean.view(-1, 1, 1); + if (stdev.ndim == 1) + stdev = stdev.view(-1, 1, 1); + return input.sub_(mean).div_(stdev).MoveToOuterDisposeScope(); } private static Tensor _pad(Tensor input, ReadOnlySpan padding, double fill = 0, PaddingModes padding_mode = PaddingModes.Constant) diff --git a/src/TorchVision/Normalize.cs b/src/TorchVision/Normalize.cs index 27ec1bb44..b5b4b1d6b 100644 --- a/src/TorchVision/Normalize.cs +++ b/src/TorchVision/Normalize.cs @@ -9,7 +9,7 @@ public static partial class torchvision { internal class Normalize : ITransform, IDisposable { - internal Normalize(double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32, torch.Device? device = null) + internal Normalize(double[] means, double[] stdevs,bool inplace = false) { if (means is null) throw new ArgumentNullException(nameof(means)); if (stdevs is null) throw new ArgumentNullException(nameof(stdevs)); @@ -17,36 +17,25 @@ internal Normalize(double[] means, double[] stdevs, ScalarType dtype = ScalarTyp throw new ArgumentException($"{nameof(means)} and {nameof(stdevs)} must be the same length in call to Normalize"); if (means.Length != 1 && means.Length != 3) throw new ArgumentException($"Since they correspond to the number of channels in an image, {nameof(means)} and {nameof(stdevs)} must both be either 1 or 3 long"); + this.means = means; + this.stdevs = stdevs; + this.inplace = inplace; - this.means = means.ToTensor(new long[] { 1, means.Length, 1, 1 }); // Assumes NxCxHxW - this.stdevs = stdevs.ToTensor(new long[] { 1, stdevs.Length, 1, 1 }); // Assumes NxCxHxW - - if (dtype != ScalarType.Float64) { - this.means = this.means.to_type(dtype); - this.stdevs = this.stdevs.to_type(dtype); - } - - if (device != null && device.type != DeviceType.CPU) { - this.means = this.means.to(device); - this.stdevs = this.stdevs.to(device); - } } public Tensor call(Tensor input) { - if (means.size(1) != input.size(1)) throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations"); - return (input - means) / stdevs; + return transforms.functional.normalize(input, means, stdevs, inplace); } - private Tensor means; - private Tensor stdevs; + private readonly double[] means; + private readonly double[] stdevs; + private readonly bool inplace; bool disposedValue; protected virtual void Dispose(bool disposing) { if (!disposedValue) { - means?.Dispose(); - stdevs?.Dispose(); disposedValue = true; } } @@ -72,12 +61,11 @@ public static partial class transforms /// /// Sequence of means for each channel. /// Sequence of standard deviations for each channel. - /// Bool to make this operation inplace. - /// The device to place the output tensor on. + /// Bool to make this operation inplace. /// - static public ITransform Normalize(double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32, torch.Device? device = null) + static public ITransform Normalize(double[] means, double[] stdevs, bool inplace = false) { - return new Normalize(means, stdevs, dtype, device); + return new Normalize(means, stdevs, inplace); } } } From ecfc4b542aaf94cc81475f3007a40d78798bf2f6 Mon Sep 17 00:00:00 2001 From: AhmedZero Date: Fri, 27 Dec 2024 15:31:25 +0200 Subject: [PATCH 2/8] Update the condition --- src/TorchVision/Functional.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TorchVision/Functional.cs b/src/TorchVision/Functional.cs index 081d1244e..279477eb9 100644 --- a/src/TorchVision/Functional.cs +++ b/src/TorchVision/Functional.cs @@ -562,7 +562,7 @@ public static Tensor normalize(Tensor input, double[] means, double[] stdevs, bo var mean = as_tensor(means, dtype: input.dtype, device: input.device); var stdev = as_tensor(stdevs, dtype: input.dtype, device: input.device); - if (stdev.eq(0).any().ToBoolean()) + if ((stdev == 0).any().item()) throw new ArgumentException($"std evaluated to zero after conversion to {input.dtype}, leading to division by zero."); if (mean.ndim == 1) mean = mean.view(-1, 1, 1); From da3117367b052d874e6184a887860c2b4e85b1f1 Mon Sep 17 00:00:00 2001 From: AhmedZero Date: Fri, 27 Dec 2024 15:41:41 +0200 Subject: [PATCH 3/8] Remove IDisposable and constructor checks from Normalize --- src/TorchVision/Normalize.cs | 28 ++-------------------------- 1 file changed, 2 insertions(+), 26 deletions(-) diff --git a/src/TorchVision/Normalize.cs b/src/TorchVision/Normalize.cs index b5b4b1d6b..b50f379c7 100644 --- a/src/TorchVision/Normalize.cs +++ b/src/TorchVision/Normalize.cs @@ -7,16 +7,12 @@ namespace TorchSharp { public static partial class torchvision { - internal class Normalize : ITransform, IDisposable + internal class Normalize : ITransform { internal Normalize(double[] means, double[] stdevs,bool inplace = false) { if (means is null) throw new ArgumentNullException(nameof(means)); if (stdevs is null) throw new ArgumentNullException(nameof(stdevs)); - if (means.Length != stdevs.Length) - throw new ArgumentException($"{nameof(means)} and {nameof(stdevs)} must be the same length in call to Normalize"); - if (means.Length != 1 && means.Length != 3) - throw new ArgumentException($"Since they correspond to the number of channels in an image, {nameof(means)} and {nameof(stdevs)} must both be either 1 or 3 long"); this.means = means; this.stdevs = stdevs; this.inplace = inplace; @@ -31,27 +27,7 @@ public Tensor call(Tensor input) private readonly double[] means; private readonly double[] stdevs; private readonly bool inplace; - bool disposedValue; - - protected virtual void Dispose(bool disposing) - { - if (!disposedValue) { - disposedValue = true; - } - } - - ~Normalize() - { - // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method - Dispose(disposing: false); - } - - public void Dispose() - { - // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method - Dispose(disposing: true); - GC.SuppressFinalize(this); - } + } public static partial class transforms From 6999a5a3b5211b802f0421cbd27e9fda8e60304d Mon Sep 17 00:00:00 2001 From: AhmedZero Date: Fri, 27 Dec 2024 19:33:05 +0200 Subject: [PATCH 4/8] Add validation checks in Normalize constructor and call method --- src/TorchVision/Normalize.cs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/TorchVision/Normalize.cs b/src/TorchVision/Normalize.cs index b50f379c7..12411c5e8 100644 --- a/src/TorchVision/Normalize.cs +++ b/src/TorchVision/Normalize.cs @@ -13,6 +13,8 @@ internal Normalize(double[] means, double[] stdevs,bool inplace = false) { if (means is null) throw new ArgumentNullException(nameof(means)); if (stdevs is null) throw new ArgumentNullException(nameof(stdevs)); + if (means.Length != stdevs.Length) + throw new ArgumentException($"{nameof(means)} and {nameof(stdevs)} must be the same length in call to Normalize"); this.means = means; this.stdevs = stdevs; this.inplace = inplace; @@ -21,6 +23,9 @@ internal Normalize(double[] means, double[] stdevs,bool inplace = false) public Tensor call(Tensor input) { + var expectedChannels = (input.shape.Length == 4) ? input.size(1) : input.size(0); + if (expectedChannels != means.Length) + throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations"); return transforms.functional.normalize(input, means, stdevs, inplace); } From 196124b85ba01b249e40f54409263e9e4b47e464 Mon Sep 17 00:00:00 2001 From: AhmedZero Date: Sat, 28 Dec 2024 19:00:26 +0200 Subject: [PATCH 5/8] Improve tensor error messages and code readability. --- src/TorchVision/Functional.cs | 2 +- src/TorchVision/Normalize.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/TorchVision/Functional.cs b/src/TorchVision/Functional.cs index 279477eb9..44f5d834f 100644 --- a/src/TorchVision/Functional.cs +++ b/src/TorchVision/Functional.cs @@ -555,7 +555,7 @@ public static Tensor normalize(Tensor input, double[] means, double[] stdevs, bo if (!input.is_floating_point()) throw new ArgumentException($"Input tensor should be a float tensor. Got {input.dtype}."); if (input.ndim < 3) - throw new ArgumentException($"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {input.size()}"); + throw new ArgumentException($"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = ({string.Join(", ", input.shape)})"); if (!inplace) input = input.clone(); diff --git a/src/TorchVision/Normalize.cs b/src/TorchVision/Normalize.cs index 12411c5e8..ef3f1c03d 100644 --- a/src/TorchVision/Normalize.cs +++ b/src/TorchVision/Normalize.cs @@ -23,7 +23,7 @@ internal Normalize(double[] means, double[] stdevs,bool inplace = false) public Tensor call(Tensor input) { - var expectedChannels = (input.shape.Length == 4) ? input.size(1) : input.size(0); + var expectedChannels = (input.ndim == 4) ? input.size(1) : input.size(0); if (expectedChannels != means.Length) throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations"); return transforms.functional.normalize(input, means, stdevs, inplace); From 0e6513231172debb81cfdeb02125971e73462af3 Mon Sep 17 00:00:00 2001 From: AhmedZero Date: Sat, 28 Dec 2024 19:15:12 +0200 Subject: [PATCH 6/8] Refactor channel determination logic in TorchSharp. --- src/TorchVision/Functional.cs | 11 +++++++++++ src/TorchVision/Normalize.cs | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/TorchVision/Functional.cs b/src/TorchVision/Functional.cs index 44f5d834f..b24463e4a 100644 --- a/src/TorchVision/Functional.cs +++ b/src/TorchVision/Functional.cs @@ -37,6 +37,17 @@ private static bool AssertTensorImage(Tensor img) return true; } + public static long GetImageNumChannels(Tensor img) + { + AssertTensorImage(img); + var ndim_ = img.ndim; + return ndim_ switch { + 2 => 1, + > 2 => img.shape[ndim_ - 3], + _ => throw new ArgumentException($"Input ndim should be 2 or more. Got {ndim_}"), + }; + } + /// /// Get the image dimensions /// diff --git a/src/TorchVision/Normalize.cs b/src/TorchVision/Normalize.cs index ef3f1c03d..e15fc114d 100644 --- a/src/TorchVision/Normalize.cs +++ b/src/TorchVision/Normalize.cs @@ -23,7 +23,7 @@ internal Normalize(double[] means, double[] stdevs,bool inplace = false) public Tensor call(Tensor input) { - var expectedChannels = (input.ndim == 4) ? input.size(1) : input.size(0); + var expectedChannels = transforms.functional.GetImageNumChannels(input); if (expectedChannels != means.Length) throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations"); return transforms.functional.normalize(input, means, stdevs, inplace); From 2e72fac291024d613f13488d18729917f1f4788e Mon Sep 17 00:00:00 2001 From: AhmedZero Date: Sat, 28 Dec 2024 20:12:34 +0200 Subject: [PATCH 7/8] Remove `TestConstructor_ThrowsArgumentException_IfMeansAndStdevsHaveWrongLengths` test for ArgumentException in Normalize function --- test/TorchSharpTest/TestTorchVision.cs | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/test/TorchSharpTest/TestTorchVision.cs b/test/TorchSharpTest/TestTorchVision.cs index c8f1bc341..4fd563244 100644 --- a/test/TorchSharpTest/TestTorchVision.cs +++ b/test/TorchSharpTest/TestTorchVision.cs @@ -845,17 +845,6 @@ public void TestConstructor_ThrowsArgumentException_IfMeansAndStdevsHaveDifferen Assert.Throws(() => Normalize(means, stdevs)); } - [Fact] - public void TestConstructor_ThrowsArgumentException_IfMeansAndStdevsHaveWrongLengths() - { - // Arrange - double[] means = { 0.485, 0.456 }; - double[] stdevs = { 0.229, 0.224 }; // Not 1 or 3 - - // Act & Assert - Assert.Throws(() => Normalize(means, stdevs)); - } - [Fact] public void TestConstructor_CreatesNewNormalizeObject_WithValidArguments() { From 80fce55e98c3bed2d661e32e2acec7982e24e87c Mon Sep 17 00:00:00 2001 From: AhmedZero Date: Sat, 4 Jan 2025 20:32:43 +0200 Subject: [PATCH 8/8] Rename GetImageNumChannels and update references Renamed the method `GetImageNumChannels` to `get_image_num_channels` in `Functional.cs`. Added a summary documentation comment to the `get_image_num_channels` method, detailing its purpose, parameter, and return value. Updated the call to this method in `Normalize.cs` to reflect the new name. --- src/TorchVision/Functional.cs | 7 ++++++- src/TorchVision/Normalize.cs | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/TorchVision/Functional.cs b/src/TorchVision/Functional.cs index b24463e4a..9d98c1ea3 100644 --- a/src/TorchVision/Functional.cs +++ b/src/TorchVision/Functional.cs @@ -37,7 +37,12 @@ private static bool AssertTensorImage(Tensor img) return true; } - public static long GetImageNumChannels(Tensor img) + /// + /// Returns the number of channels of an image. + /// + /// (Tensor) – The image to be checked. + /// The number of channels. + public static long get_image_num_channels(Tensor img) { AssertTensorImage(img); var ndim_ = img.ndim; diff --git a/src/TorchVision/Normalize.cs b/src/TorchVision/Normalize.cs index e15fc114d..6bcb597b9 100644 --- a/src/TorchVision/Normalize.cs +++ b/src/TorchVision/Normalize.cs @@ -23,7 +23,7 @@ internal Normalize(double[] means, double[] stdevs,bool inplace = false) public Tensor call(Tensor input) { - var expectedChannels = transforms.functional.GetImageNumChannels(input); + var expectedChannels = transforms.functional.get_image_num_channels(input); if (expectedChannels != means.Length) throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations"); return transforms.functional.normalize(input, means, stdevs, inplace);