diff --git a/src/TorchVision/Functional.cs b/src/TorchVision/Functional.cs
index b26d4715f..d3b660c7b 100644
--- a/src/TorchVision/Functional.cs
+++ b/src/TorchVision/Functional.cs
@@ -24,6 +24,35 @@ public static partial class transforms
{
public static partial class functional
{
+
+ // A tensor is treated as an image when it has at least two dimensions.
+ private static bool IsTensorImage(Tensor img) => img.ndim >= 2;
+
+ // Returns true when img passes IsTensorImage; otherwise throws ArgumentException.
+ private static bool AssertTensorImage(Tensor img)
+ {
+     if (IsTensorImage(img)) {
+         return true;
+     }
+
+     throw new ArgumentException("Tensor is not a torch image.");
+ }
+
+ /// <summary>
+ /// Returns the number of channels of an image.
+ /// </summary>
+ /// <param name="img">(Tensor) – The image to be checked.</param>
+ /// <returns>
+ /// 1 for a two-dimensional tensor; otherwise the size of the third-from-last dimension.
+ /// </returns>
+ /// <exception cref="ArgumentException">The tensor has fewer than two dimensions.</exception>
+ public static long get_image_num_channels(Tensor img)
+ {
+     AssertTensorImage(img);
+
+     var rank = img.ndim;
+     if (rank == 2) {
+         // Two-dimensional tensors are reported as single-channel.
+         return 1;
+     }
+     if (rank > 2) {
+         // The channel count is read from the third dimension counting from the end.
+         return img.shape[rank - 3];
+     }
+
+     // Kept as a final guard; AssertTensorImage has already rejected rank < 2 above.
+     throw new ArgumentException($"Input ndim should be 2 or more. Got {rank}");
+ }
+
///
/// Get the image dimensions
///
@@ -533,20 +562,29 @@ public static Tensor invert(Tensor input)
/// An image tensor.
/// Sequence of means for each channel.
/// Sequence of standard deviations for each channel.
- /// Bool to make this operation inplace.
+ /// <param name="inplace">When false (the default) the input is cloned before normalizing; when true the input tensor itself is modified.</param>
///
- public static Tensor normalize(Tensor input, double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32)
+ /// <exception cref="ArgumentNullException">input, means or stdevs is null.</exception>
+ /// <exception cref="ArgumentException">
+ /// input is not a floating-point tensor with at least three dimensions, or a standard
+ /// deviation is zero after conversion to the input's dtype.
+ /// </exception>
+ public static Tensor normalize(Tensor input, double[] means, double[] stdevs, bool inplace = false)
{
- if (means.Length != stdevs.Length)
- throw new ArgumentException("means and stdevs must be the same length in call to Normalize");
- if (means.Length != input.shape[1])
- throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations");
-
- using var mean = means.ToTensor(new long[] { 1, means.Length, 1, 1 }).to(input.dtype, input.device); // Assumes NxCxHxW
- using var stdev = stdevs.ToTensor(new long[] { 1, stdevs.Length, 1, 1 }).to(input.dtype, input.device); // Assumes NxCxHxW
- using var t0 = input - mean;
-
- return t0 / stdev;
+ // Reject null arguments up front with a clear exception instead of failing later in tensor code.
+ if (input is null) throw new ArgumentNullException(nameof(input));
+ if (means is null) throw new ArgumentNullException(nameof(means));
+ if (stdevs is null) throw new ArgumentNullException(nameof(stdevs));
+
+ // Scope for the temporaries created below; the result is moved to the outer scope on return.
+ using var _ = NewDisposeScope();
+ AssertTensorImage(input);
+ if (!input.is_floating_point())
+     throw new ArgumentException($"Input tensor should be a float tensor. Got {input.dtype}.");
+ if (input.ndim < 3)
+     throw new ArgumentException($"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = ({string.Join(", ", input.shape)})");
+ // The arithmetic below is in-place, so work on a clone unless the caller asked for in-place.
+ if (!inplace)
+     input = input.clone();
+
+ var mean = as_tensor(means, dtype: input.dtype, device: input.device);
+ var stdev = as_tensor(stdevs, dtype: input.dtype, device: input.device);
+ // item<T>() needs an explicit type argument; the bool result drives the zero check.
+ if ((stdev == 0).any().item<bool>())
+     throw new ArgumentException($"std evaluated to zero after conversion to {input.dtype}, leading to division by zero.");
+ // Reshape 1-D statistics to (N, 1, 1) so they line up with the third-from-last dimension of input.
+ if (mean.ndim == 1)
+     mean = mean.view(-1, 1, 1);
+ if (stdev.ndim == 1)
+     stdev = stdev.view(-1, 1, 1);
+ // NOTE(review): with inplace == true, `input` was not created inside this scope - confirm that
+ // MoveToOuterDisposeScope() leaves the caller's ownership of that tensor unchanged.
+ return input.sub_(mean).div_(stdev).MoveToOuterDisposeScope();
}
private static Tensor _pad(Tensor input, ReadOnlySpan padding, double fill = 0, PaddingModes padding_mode = PaddingModes.Constant)
diff --git a/src/TorchVision/Normalize.cs b/src/TorchVision/Normalize.cs
index 27ec1bb44..6bcb597b9 100644
--- a/src/TorchVision/Normalize.cs
+++ b/src/TorchVision/Normalize.cs
@@ -7,62 +7,32 @@ namespace TorchSharp
{
public static partial class torchvision
{
- internal class Normalize : ITransform, IDisposable
+ // Transform that normalizes an image tensor with per-channel means and standard deviations.
+ // It holds only managed arrays and a flag, and delegates the tensor work to
+ // transforms.functional.normalize on every call.
+ internal class Normalize : ITransform
{
- internal Normalize(double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32, torch.Device? device = null)
+ // means/stdevs: one entry per channel, same length. inplace: forwarded to functional.normalize.
+ internal Normalize(double[] means, double[] stdevs, bool inplace = false)
{
if (means is null) throw new ArgumentNullException(nameof(means));
if (stdevs is null) throw new ArgumentNullException(nameof(stdevs));
if (means.Length != stdevs.Length)
throw new ArgumentException($"{nameof(means)} and {nameof(stdevs)} must be the same length in call to Normalize");
- if (means.Length != 1 && means.Length != 3)
- throw new ArgumentException($"Since they correspond to the number of channels in an image, {nameof(means)} and {nameof(stdevs)} must both be either 1 or 3 long");
+ // Keep private copies so later changes to the caller's arrays cannot alter this transform.
+ this.means = (double[])means.Clone();
+ this.stdevs = (double[])stdevs.Clone();
+ this.inplace = inplace;
- this.means = means.ToTensor(new long[] { 1, means.Length, 1, 1 }); // Assumes NxCxHxW
- this.stdevs = stdevs.ToTensor(new long[] { 1, stdevs.Length, 1, 1 }); // Assumes NxCxHxW
-
- if (dtype != ScalarType.Float64) {
- this.means = this.means.to_type(dtype);
- this.stdevs = this.stdevs.to_type(dtype);
- }
-
- if (device != null && device.type != DeviceType.CPU) {
- this.means = this.means.to(device);
- this.stdevs = this.stdevs.to(device);
- }
}
public Tensor call(Tensor input)
{
- if (means.size(1) != input.size(1)) throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations");
- return (input - means) / stdevs;
- }
-
- private Tensor means;
- private Tensor stdevs;
- bool disposedValue;
-
- protected virtual void Dispose(bool disposing)
- {
- if (!disposedValue) {
- means?.Dispose();
- stdevs?.Dispose();
- disposedValue = true;
- }
+ // Require exactly one mean/stdev pair per image channel before delegating.
+ var channels = transforms.functional.get_image_num_channels(input);
+ if (channels != means.Length)
+     throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations");
+ return transforms.functional.normalize(input, means, stdevs, inplace);
}
- ~Normalize()
- {
- // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
- Dispose(disposing: false);
- }
-
- public void Dispose()
- {
- // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
- Dispose(disposing: true);
- GC.SuppressFinalize(this);
- }
+ // Private copies of the constructor arguments; never modified after construction.
+ private readonly double[] means;
+ private readonly double[] stdevs;
+ private readonly bool inplace;
+
}
public static partial class transforms
@@ -72,12 +42,11 @@ public static partial class transforms
///
/// Sequence of means for each channel.
/// Sequence of standard deviations for each channel.
- /// Bool to make this operation inplace.
- /// The device to place the output tensor on.
+ /// <param name="inplace">Bool to make this operation inplace. Passed through unchanged to the transform.</param>
///
- static public ITransform Normalize(double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32, torch.Device? device = null)
+ static public ITransform Normalize(double[] means, double[] stdevs, bool inplace = false)
{
- return new Normalize(means, stdevs, dtype, device);
+ // Thin factory: null and length validation of means/stdevs happens in the Normalize constructor.
+ return new Normalize(means, stdevs, inplace);
}
}
}
diff --git a/test/TorchSharpTest/TestTorchVision.cs b/test/TorchSharpTest/TestTorchVision.cs
index c8f1bc341..4fd563244 100644
--- a/test/TorchSharpTest/TestTorchVision.cs
+++ b/test/TorchSharpTest/TestTorchVision.cs
@@ -845,17 +845,6 @@ public void TestConstructor_ThrowsArgumentException_IfMeansAndStdevsHaveDifferen
Assert.Throws(() => Normalize(means, stdevs));
}
- [Fact]
- public void TestConstructor_ThrowsArgumentException_IfMeansAndStdevsHaveWrongLengths()
- {
- // Arrange
- double[] means = { 0.485, 0.456 };
- double[] stdevs = { 0.229, 0.224 }; // Not 1 or 3
-
- // Act & Assert
- Assert.Throws(() => Normalize(means, stdevs));
- }
-
[Fact]
public void TestConstructor_CreatesNewNormalizeObject_WithValidArguments()
{