Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions opennlp-api/src/main/java/opennlp/tools/ml/ArrayMath.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,55 @@
*/
public class ArrayMath {

// Operating system name from the "os.name" system property; "Unknown" when the property is unset.
private static final String OS_NAME = System.getProperty("os.name", "Unknown");
// CPU architecture from the "os.arch" system property; "Unknown" when the property is unset.
private static final String OS_ARCH = System.getProperty("os.arch", "Unknown");
// True when the reported operating system name begins with "Mac OS X".
private static final boolean MAC_OS_X = OS_NAME.startsWith("Mac OS X");

/**
 * Platform heuristic used to choose the fused multiply-add code path.
 * <p>
 * The decision is based solely on the reported architecture and operating
 * system name; the CPU itself is not probed.
 *
 * @return {@code true} if the architecture is {@code aarch64} on anything
 *         other than Mac OS X, or if the architecture is {@code amd64};
 *         {@code false} otherwise.
 */
private static boolean hasHWVectorFMA() {
  // aarch64 is accepted everywhere except on Mac OS X
  final boolean arm64OutsideMac = OS_ARCH.equals("aarch64") && !MAC_OS_X;
  // amd64 is always accepted
  final boolean amd64 = OS_ARCH.equals("amd64");
  return arm64OutsideMac || amd64;
}

public static double innerProduct(double[] vecA, double[] vecB) {
if (vecA == null || vecB == null || vecA.length != vecB.length)
return Double.NaN;

double product = 0.0;
for (int i = 0; i < vecA.length; i++) {
product += vecA[i] * vecB[i];
if (hasHWVectorFMA()) {
double product = 0;
int i = 0;

// unroll, in case the arrays are large enough
if (vecA.length > 32) {
double acc1 = 0, acc2 = 0, acc3 = 0, acc4 = 0;
int upperBound = vecA.length & ~(4 - 1);
for (; i < upperBound; i += 4) {
acc1 = StrictMath.fma(vecA[i], vecB[i], acc1);
acc2 = StrictMath.fma(vecA[i + 1], vecB[i + 1], acc2);
acc3 = StrictMath.fma(vecA[i + 2], vecB[i + 2], acc3);
acc4 = StrictMath.fma(vecA[i + 3], vecB[i + 3], acc4);
}
product += acc1 + acc2 + acc3 + acc4;
}

for (; i < vecA.length; i++) {
product = StrictMath.fma(vecA[i], vecB[i], product);
}
return product;
} else {
double product = 0.0;
for (int i = 0; i < vecA.length; i++) {
product += vecA[i] * vecB[i];
}
return product;
}
return product;
}

/**
Expand Down