Skip to content

Commit

Permalink
accord-netGH-807: Add an Example for CrossValidating NaiveBayes
Browse files Browse the repository at this point in the history
  • Loading branch information
cesarsouza committed Aug 24, 2017
1 parent 0438083 commit 3035647
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 1 deletion.
4 changes: 3 additions & 1 deletion Setup/Accord.Setup.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
<None Include="NuGet\Accord.MachineLearning.nuspec" />
<None Include="NuGet\Accord.Math.Noncommercial.nuspec" />
<None Include="NuGet\Accord.Math.nuspec" />
<None Include="NuGet\Accord.Neuro.nuspec" />
<None Include="NuGet\Accord.Neuro.nuspec">
<SubType>Designer</SubType>
</None>
<None Include="NuGet\Accord.nuspec" />
<None Include="NuGet\Accord.Statistics.nuspec" />
<None Include="NuGet\Accord.Text.nuspec" />
Expand Down
5 changes: 5 additions & 0 deletions Sources/Accord.MachineLearning/Bayes/NaiveBayes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ namespace Accord.MachineLearning.Bayes
/// Naive Bayes on those vectors.</para>
///
/// <code source="Unit Tests\Accord.Tests.MachineLearning\Bayes\NaiveBayesTest.cs" region="doc_multiclass" />
///
/// <para>
/// As with all other learning algorithms in the framework, it is also possible to obtain a better measure
/// of the performance of the Naive Bayes algorithm using cross-validation, as shown in the example below:</para>
/// <code source="Unit Tests\Accord.Tests.MachineLearning\Bayes\NaiveBayesTest.cs" region="doc_cross_validation" />
/// </example>
///
/// <seealso cref="NaiveBayesLearning"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ namespace Accord.MachineLearning
/// <code source="Unit Tests\Accord.Tests.MachineLearning\CrossValidationTest.cs" region="doc_learn" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\CrossValidationTest.cs" region="doc_learn_hmm" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\DecisionTrees\DecisionTreeTest.cs" region="doc_cross_validation" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\Bayes\NaiveBayesTest.cs" region="doc_cross_validation" />
/// </example>
///
[Serializable]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public delegate CrossValidationValues<TModel>
/// <code source="Unit Tests\Accord.Tests.MachineLearning\CrossValidationTest.cs" region="doc_learn" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\CrossValidationTest.cs" region="doc_learn_hmm" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\DecisionTrees\DecisionTreeTest.cs" region="doc_cross_validation" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\Bayes\NaiveBayesTest.cs" region="doc_cross_validation" />
/// </example>
///
[Serializable]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ namespace Accord.MachineLearning.Performance
/// <code source="Unit Tests\Accord.Tests.MachineLearning\CrossValidationTest.cs" region="doc_learn" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\CrossValidationTest.cs" region="doc_learn_hmm" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\DecisionTrees\DecisionTreeTest.cs" region="doc_cross_validation" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\Bayes\NaiveBayesTest.cs" region="doc_cross_validation" />
/// </example>
///
[Serializable]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ namespace Accord.MachineLearning.Performance
/// <code source="Unit Tests\Accord.Tests.MachineLearning\CrossValidationTest.cs" region="doc_learn" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\CrossValidationTest.cs" region="doc_learn_hmm" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\DecisionTrees\DecisionTreeTest.cs" region="doc_cross_validation" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\Bayes\NaiveBayesTest.cs" region="doc_cross_validation" />
/// </example>
///
/// <seealso cref="Bootstrap{TModel, TInput, TOutput}"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ namespace Accord.MachineLearning.Performance
/// <code source="Unit Tests\Accord.Tests.MachineLearning\CrossValidationTest.cs" region="doc_learn" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\CrossValidationTest.cs" region="doc_learn_hmm" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\DecisionTrees\DecisionTreeTest.cs" region="doc_cross_validation" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\Bayes\NaiveBayesTest.cs" region="doc_cross_validation" />
/// </example>
///
public class CrossValidation<TModel, TInput, TOutput> :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ namespace Accord.MachineLearning.Performance
/// <code source="Unit Tests\Accord.Tests.MachineLearning\CrossValidationTest.cs" region="doc_learn" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\CrossValidationTest.cs" region="doc_learn_hmm" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\DecisionTrees\DecisionTreeTest.cs" region="doc_cross_validation" />
/// <code source="Unit Tests\Accord.Tests.MachineLearning\Bayes\NaiveBayesTest.cs" region="doc_cross_validation" />
/// </example>
///
public class CrossValidation<TModel, TLearner, TInput, TOutput> :
Expand Down
89 changes: 89 additions & 0 deletions Unit Tests/Accord.Tests.MachineLearning/Bayes/NaiveBayesTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,95 @@ public void gh_758()

Assert.IsTrue(teacher.optimized);
}

// Cross-validates a NaiveBayesLearning teacher over a small three-class data set
// and checks the reported sample/input/output counts and the training/validation
// error statistics. The code inside the doc_cross_validation region is referenced
// by <code source="..." region="doc_cross_validation" /> documentation tags, so it
// doubles as a published usage example.
[Test]
public void CrossValidationTest()
{
#region doc_cross_validation
// Ensure we have reproducible results
Accord.Math.Random.Generator.Seed = 0;

// Let's say we have the following data to be classified
// into three possible classes. Those are the samples:
//
int[][] inputs =
{
// Each row is one sample with four 0/1 features; the
// trailing comment on each row is its class label (output).
new int[] { 0, 1, 1, 0 }, // 0
new int[] { 0, 1, 0, 0 }, // 0
new int[] { 0, 0, 1, 0 }, // 0
new int[] { 0, 1, 1, 0 }, // 0
new int[] { 0, 1, 0, 0 }, // 0
new int[] { 1, 0, 0, 0 }, // 1
new int[] { 1, 0, 0, 0 }, // 1
new int[] { 1, 0, 0, 1 }, // 1
new int[] { 0, 0, 0, 1 }, // 1
new int[] { 0, 0, 0, 1 }, // 1
new int[] { 1, 1, 1, 1 }, // 2
new int[] { 1, 0, 1, 1 }, // 2
new int[] { 1, 1, 0, 1 }, // 2
new int[] { 0, 1, 1, 1 }, // 2
new int[] { 1, 1, 1, 1 }, // 2
};

int[] outputs = // those are the class labels, one per row of 'inputs'
{
0, 0, 0, 0, 0,
1, 1, 1, 1, 1,
2, 2, 2, 2, 2,
};

// Let's say we want to measure the cross-validation
// performance of Naive Bayes on the above data set:
var cv = CrossValidation.Create(

k: 10, // We will be using 10-fold cross validation

// First we define the learning algorithm:
learner: (p) => new NaiveBayesLearning(),

// Now we have to specify how the model performance should be measured
// (here, using a zero-one loss between expected and actual labels):
loss: (actual, expected, p) => new ZeroOneLoss(expected).Loss(actual),

// This function can be used to perform any special
// operations before the actual learning is done, but
// here we will just leave it as simple as it can be:
fit: (teacher, x, y, w) => teacher.Learn(x, y, w),

// Finally, we have to pass the input and output data
// that will be used in cross-validation.
x: inputs, y: outputs
);

// After the cross-validation object has been created,
// we can call its .Learn method with the input and
// output data that will be partitioned into the folds:
var result = cv.Learn(inputs, outputs);

// We can grab some information about the problem:
int numberOfSamples = result.NumberOfSamples; // should be 15
int numberOfInputs = result.NumberOfInputs; // should be 4
int numberOfOutputs = result.NumberOfOutputs; // should be 3

// And the mean error over the training and validation folds:
double trainingError = result.Training.Mean; // should be 0
double validationError = result.Validation.Mean; // should be 0.05
#endregion

// The problem dimensions must match the data declared above.
Assert.AreEqual(15, numberOfSamples);
Assert.AreEqual(4, numberOfInputs);
Assert.AreEqual(3, numberOfOutputs);

// Mean errors are read from 'result' again and compared with a 1e-10 tolerance.
// NOTE(review): the non-zero expected values presumably depend on the seed set
// at the top of the test - confirm before changing the seed or the data.
Assert.AreEqual(10, cv.K);
Assert.AreEqual(0, result.Training.Mean, 1e-10);
Assert.AreEqual(0.05, result.Validation.Mean, 1e-10);

Assert.AreEqual(0, result.Training.Variance, 1e-10);
Assert.AreEqual(0.025000000000000005, result.Validation.Variance, 1e-10);

// One fold and one model are expected per value of k.
Assert.AreEqual(10, cv.Folds.Length);
Assert.AreEqual(10, result.Models.Length);

}
}
}
#endif

0 comments on commit 3035647

Please sign in to comment.