-
Notifications
You must be signed in to change notification settings - Fork 196
/
Copy pathNormalize.cs
53 lines (47 loc) · 2.15 KB
/
Normalize.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using static TorchSharp.torch;
#nullable enable
namespace TorchSharp
{
public static partial class torchvision
{
/// <summary>
/// Transform that normalizes an image tensor channel-by-channel using a fixed set of
/// means and standard deviations, delegating the arithmetic to
/// <c>transforms.functional.normalize</c>.
/// </summary>
internal class Normalize : ITransform
{
    /// <summary>
    /// Creates a normalization transform.
    /// </summary>
    /// <param name="means">Per-channel means; must be the same length as <paramref name="stdevs"/>.</param>
    /// <param name="stdevs">Per-channel standard deviations; none may be zero.</param>
    /// <param name="inplace">Forwarded to <c>transforms.functional.normalize</c> to request in-place normalization.</param>
    /// <exception cref="ArgumentNullException"><paramref name="means"/> or <paramref name="stdevs"/> is null.</exception>
    /// <exception cref="ArgumentException">The arrays differ in length, or <paramref name="stdevs"/> contains a zero.</exception>
    internal Normalize(double[] means, double[] stdevs, bool inplace = false)
    {
        if (means is null) throw new ArgumentNullException(nameof(means));
        if (stdevs is null) throw new ArgumentNullException(nameof(stdevs));
        if (means.Length != stdevs.Length)
            throw new ArgumentException($"{nameof(means)} and {nameof(stdevs)} must be the same length in call to Normalize");

        // torchvision rejects a zero std because it leads to division by zero;
        // fail fast here instead of producing a tensor full of inf/NaN later.
        if (Array.Exists(stdevs, s => s == 0.0))
            throw new ArgumentException($"{nameof(stdevs)} must not contain zero, since that leads to division by zero in Normalize", nameof(stdevs));

        // Copy the arrays so the caller cannot mutate the statistics of an
        // already-validated transform (the fields are readonly, the arrays were not).
        this.means = (double[])means.Clone();
        this.stdevs = (double[])stdevs.Clone();
        this.inplace = inplace;
    }

    /// <summary>
    /// Normalizes <paramref name="input"/> with the configured means and standard deviations.
    /// </summary>
    /// <param name="input">Image tensor whose channel count must equal the number of configured means.</param>
    /// <returns>The tensor returned by <c>transforms.functional.normalize</c>.</returns>
    /// <exception cref="ArgumentException">The image's channel count differs from the number of means.</exception>
    public Tensor call(Tensor input)
    {
        // Channel count of the incoming image, as reported by the functional helper.
        var actualChannels = transforms.functional.get_image_num_channels(input);
        if (actualChannels != means.Length)
            throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations");
        return transforms.functional.normalize(input, means, stdevs, inplace);
    }

    // Private copies of the per-channel statistics supplied at construction.
    private readonly double[] means;
    private readonly double[] stdevs;
    // Forwarded to functional.normalize on every call.
    private readonly bool inplace;
}
public static partial class transforms
{
    /// <summary>
    /// Normalize a float tensor image with mean and standard deviation.
    /// </summary>
    /// <remarks>
    /// Each time the returned transform is applied, it checks that the image's channel count
    /// equals the length of <paramref name="means"/> and throws <see cref="ArgumentException"/> otherwise.
    /// </remarks>
    /// <param name="means">Sequence of means for each channel.</param>
    /// <param name="stdevs">Sequence of standard deviations for each channel; must be the same length as <paramref name="means"/>.</param>
    /// <param name="inplace">Bool to make this operation inplace.</param>
    /// <returns>An <see cref="ITransform"/> that normalizes its input with the given per-channel statistics.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="means"/> or <paramref name="stdevs"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// The arguments are rejected by the transform's constructor, for example when
    /// <paramref name="means"/> and <paramref name="stdevs"/> differ in length.
    /// </exception>
    public static ITransform Normalize(double[] means, double[] stdevs, bool inplace = false)
        => new Normalize(means, stdevs, inplace);
}
}
}