Skip to content

Commit

Permalink
Add new tests for UniversalRandom and SpatialPooler
Browse files Browse the repository at this point in the history
  • Loading branch information
cogmission committed Sep 5, 2016
1 parent 8c7b8b0 commit 7c9a9a4
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ public void connectAndConfigureInputs(Connections c) {
*/
public void compute(Connections c, int[] inputVector, int[] activeArray, boolean learn, boolean stripNeverLearned) {
if(inputVector.length != c.getNumInputs()) {
throw new IllegalArgumentException(
throw new InvalidSPParamValueException(
"Input array must be same size as the defined number of inputs: From Params: " + c.getNumInputs() +
", From Input Vector: " + inputVector.length);
}
Expand Down
21 changes: 21 additions & 0 deletions src/main/java/org/numenta/nupic/algorithms/TemporalMemory.java
Original file line number Diff line number Diff line change
@@ -1,3 +1,24 @@
/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2016, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.algorithms;

import static org.numenta.nupic.util.GroupBy2.Slot.NONE;
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/org/numenta/nupic/util/ArrayUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -805,8 +805,8 @@ public static double[] roundDivide(double[] dividend, double[] divisor, int scal
*
* @param multiplicand
* @param factor
* @param multiplicand adjustment
* @param factor adjustment
* @param multiplicandAdjustment
* @param factorAdjustment
*
* @return
* @throws IllegalArgumentException if the two argument arrays are not the same length
Expand All @@ -816,7 +816,7 @@ public static double[] multiply(

if (multiplicand.length != factor.length) {
throw new IllegalArgumentException(
"The multiplicand array and the factor array must be the same length");
"The multiplicand array and the factor array must be the same length");
}
double[] product = new double[multiplicand.length];
for (int i = 0; i < multiplicand.length; i++) {
Expand Down
47 changes: 43 additions & 4 deletions src/main/java/org/numenta/nupic/util/UniversalRandom.java
Original file line number Diff line number Diff line change
@@ -1,3 +1,24 @@
/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2016, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.util;

import java.math.BigDecimal;
Expand All @@ -11,6 +32,17 @@

import gnu.trove.list.array.TIntArrayList;

/**
* <p>
* This also has a Python version which is guaranteed to output the same random
* numbers if given the same initial seed value.
* </p><p>
* Implementation of George Marsaglia's elegant Xorshift random generator
* 30% faster and better quality than the built-in java.util.random.
* <p>
* see http://www.javamex.com/tutorials/random_numbers/xorshift.shtml.
* @author cogmission
*/
public class UniversalRandom extends Random {
/** serial version */
private static final long serialVersionUID = 1L;
Expand Down Expand Up @@ -44,6 +76,9 @@ public long getSeed() {
return seed;
}

/*
* Internal method used for testing
*/
private int[] sampleWithPrintout(TIntArrayList choices, int[] selectedIndices, List<Integer> collectedRandoms) {
TIntArrayList choiceSupply = new TIntArrayList(choices);
int upperBound = choices.size();
Expand Down Expand Up @@ -105,10 +140,14 @@ public int nextInt(int bound) {
r = (int)((bound * (long)r) >> 31);
else {
r = r % bound;
// for (int u = r;
// u - (r = u % bound) + m < 0;
// u = next(31))
// ;
/*
THIS CODE IS COMMENTED TO WORK IDENTICALLY WITH THE PYTHON VERSION
for (int u = r;
u - (r = u % bound) + m < 0;
u = next(31))
;
*/
}
//System.out.println("nextInt(" + bound + "): " + r);
return r;
Expand Down
21 changes: 21 additions & 0 deletions src/test/java/org/numenta/nupic/QuickTest.java
Original file line number Diff line number Diff line change
@@ -1,3 +1,24 @@
/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2016, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic;

import static org.numenta.nupic.algorithms.Anomaly.KEY_MODE;
Expand Down
132 changes: 131 additions & 1 deletion src/test/java/org/numenta/nupic/algorithms/SpatialPoolerTest.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,30 @@
/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2016, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.algorithms;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.util.Arrays;
import java.util.stream.IntStream;
Expand All @@ -11,6 +33,7 @@
import org.numenta.nupic.Connections;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.Parameters.KEY;
import org.numenta.nupic.algorithms.SpatialPooler.InvalidSPParamValueException;
import org.numenta.nupic.model.Pool;
import org.numenta.nupic.util.AbstractSparseBinaryMatrix;
import org.numenta.nupic.util.ArrayUtils;
Expand Down Expand Up @@ -235,7 +258,6 @@ public void testOverlapsOutput() {

double[] boostedOverlaps = cn.getBoostedOverlaps();
int[] overlaps = cn.getOverlaps();
System.out.println("out = " + Arrays.toString(activeArray));

for(int i = 0;i < cn.getNumColumns();i++) {
assertEquals(expOutput[i], overlaps[i]);
Expand Down Expand Up @@ -1897,4 +1919,112 @@ public void testGetNeighborsND() {
assertTrue(sbm.all(mask));
assertFalse(sbm.any(neg));
}

@Test
public void testInit() {
setupParameters();
parameters.setNumActiveColumnsPerInhArea(0);
parameters.setLocalAreaDensity(0);

Connections c = new Connections();
parameters.apply(c);

SpatialPooler sp = new SpatialPooler();

// Local Area Density cannot be 0
try {
sp.init(c);
fail();
}catch(Exception e) {
assertEquals("Inhibition parameters are invalid", e.getMessage());
assertEquals(InvalidSPParamValueException.class, e.getClass());
}

// Local Area Density can't be above 0.5
parameters.setLocalAreaDensity(0.51);
c = new Connections();
parameters.apply(c);
try {
sp.init(c);
fail();
}catch(Exception e) {
assertEquals("Inhibition parameters are invalid", e.getMessage());
assertEquals(InvalidSPParamValueException.class, e.getClass());
}

// Local Area Density should be sane here
parameters.setLocalAreaDensity(0.5);
c = new Connections();
parameters.apply(c);
try {
sp.init(c);
}catch(Exception e) {
fail();
}

// Num columns cannot be 0
parameters.set(KEY.COLUMN_DIMENSIONS, new int[] { 0 });
c = new Connections();
parameters.apply(c);
try {
sp.init(c);
fail();
}catch(Exception e) {
assertEquals("Invalid number of columns: 0", e.getMessage());
assertEquals(InvalidSPParamValueException.class, e.getClass());
}

// Reset column dims
parameters.set(KEY.COLUMN_DIMENSIONS, new int[] { 5 });

// Num columns cannot be 0
parameters.set(KEY.INPUT_DIMENSIONS, new int[] { 0 });
c = new Connections();
parameters.apply(c);
try {
sp.init(c);
fail();
}catch(Exception e) {
assertEquals("Invalid number of inputs: 0", e.getMessage());
assertEquals(InvalidSPParamValueException.class, e.getClass());
}
}

@Test
public void testComputeInputMismatch() {
setupParameters();
parameters.set(KEY.INPUT_DIMENSIONS, new int[] { 2, 4 });
parameters.setColumnDimensions(new int[] { 5, 1 });

Connections c = new Connections();
parameters.apply(c);

int misMatchedDims = 6; // not 8
SpatialPooler sp = new SpatialPooler();
sp.init(c);
try {
sp.compute(c, new int[misMatchedDims], new int[25], true, true);
fail();
}catch(Exception e) {
assertEquals("Input array must be same size as the defined number"
+ " of inputs: From Params: 8, From Input Vector: 6", e.getMessage());
assertEquals(InvalidSPParamValueException.class, e.getClass());
}


// Now Do the right thing
parameters.set(KEY.INPUT_DIMENSIONS, new int[] { 2, 4 });
parameters.setColumnDimensions(new int[] { 5, 1 });

c = new Connections();
parameters.apply(c);

int matchedDims = 8; // same as input dimension multiplied, above
sp.init(c);
try {
sp.compute(c, new int[matchedDims], new int[25], true, true);
}catch(Exception e) {
fail();
}
}
}
Loading

0 comments on commit 7c9a9a4

Please sign in to comment.