Skip to content

Commit d59db46

Browse files
committed
Matrix multiplication and crossproduct for double values: convert to use Truffle libs
1 parent 50ea6a2 commit d59db46

File tree

2 files changed

+90
-62
lines changed

2 files changed

+90
-62
lines changed

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/CrossprodCommon.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
import com.oracle.truffle.api.CompilerDirectives;
3131
import com.oracle.truffle.api.dsl.Cached;
3232
import com.oracle.truffle.api.dsl.Specialization;
33+
import com.oracle.truffle.api.library.CachedLibrary;
3334
import com.oracle.truffle.api.profiles.ConditionProfile;
35+
import com.oracle.truffle.r.runtime.data.VectorDataLibrary;
3436
import com.oracle.truffle.r.runtime.data.nodes.ExtractListElement;
3537
import com.oracle.truffle.r.runtime.data.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode;
3638
import com.oracle.truffle.r.runtime.data.nodes.attributes.SpecialAttributesFunctions.GetDimNamesAttributeNode;
@@ -80,8 +82,10 @@ public static CrossprodCommon createTCrossprod() {
8082
return CrossprodCommonNodeGen.create(false);
8183
}
8284

83-
@Specialization(guards = {"getXDimsNode.isMatrix(x)", "getYDimsNode.isMatrix(y)"})
85+
@Specialization(guards = {"getXDimsNode.isMatrix(x)", "getYDimsNode.isMatrix(y)"}, limit = "getTypedVectorDataLibraryCacheSize()")
8486
protected RDoubleVector crossprod(RDoubleVector x, RDoubleVector y,
87+
@CachedLibrary("x.getData()") VectorDataLibrary xDataLib,
88+
@CachedLibrary("y.getData()") VectorDataLibrary yDataLib,
8589
@Cached("create()") GetDimAttributeNode getXDimsNode,
8690
@Cached("create()") GetDimAttributeNode getYDimsNode) {
8791
int[] xDims = getXDimsNode.getDimensions(x);
@@ -90,7 +94,8 @@ protected RDoubleVector crossprod(RDoubleVector x, RDoubleVector y,
9094
int xCols = transposeX ? xDims[0] : xDims[1];
9195
int yRows = transposeX ? yDims[0] : yDims[1];
9296
int yCols = transposeX ? yDims[1] : yDims[0];
93-
RDoubleVector result = matMult.doubleMatrixMultiply(x, y, xRows, xCols, yRows, yCols, getXRowStride(xDims[0]), getXColStride(xDims[0]), getYRowStride(yDims[0]), getYColStride(yDims[0]),
97+
RDoubleVector result = matMult.doubleMatrixMultiply(xDataLib, x.getData(), x, yDataLib, y.getData(), y, xRows, xCols, yRows, yCols, getXRowStride(xDims[0]), getXColStride(xDims[0]),
98+
getYRowStride(yDims[0]), getYColStride(yDims[0]),
9499
false);
95100
return copyDimNames(x, y, result);
96101
}
@@ -100,8 +105,9 @@ protected Object crossprod(RAbstractVector x, RAbstractVector y) {
100105
return copyDimNames(x, y, (RAbstractVector) matMult.executeObject(transposeX(x), transposeY(y)));
101106
}
102107

103-
@Specialization(guards = "getDimsNode.isMatrix(x)")
108+
@Specialization(guards = "getDimsNode.isMatrix(x)", limit = "getTypedVectorDataLibraryCacheSize()")
104109
protected RDoubleVector crossprodDoubleMatrix(RDoubleVector x, @SuppressWarnings("unused") RNull y,
110+
@CachedLibrary("x.getData()") VectorDataLibrary xDataLib,
105111
@Cached("create()") GetReadonlyData.Double getReadonlyData,
106112
@Cached("create()") GetDimAttributeNode getDimsNode,
107113
@Cached("create()") GetDimAttributeNode getResultDimsNode) {
@@ -110,8 +116,10 @@ protected RDoubleVector crossprodDoubleMatrix(RDoubleVector x, @SuppressWarnings
110116
int xCols = transposeX ? xDims[0] : xDims[1];
111117
int yRows = transposeX ? xDims[0] : xDims[1];
112118
int yCols = transposeX ? xDims[1] : xDims[0];
119+
Object xData = x.getData();
113120
RDoubleVector result = mirror(
114-
matMult.doubleMatrixMultiply(x, x, xRows, xCols, yRows, yCols, getXRowStride(xDims[0]), getXColStride(xDims[0]), getYRowStride(xDims[0]), getYColStride(xDims[0]), true),
121+
matMult.doubleMatrixMultiply(xDataLib, xData, x, xDataLib, xData, x, xRows, xCols, yRows, yCols, getXRowStride(xDims[0]), getXColStride(xDims[0]), getYRowStride(xDims[0]),
122+
getYColStride(xDims[0]), true),
115123
getResultDimsNode,
116124
getReadonlyData);
117125
return copyDimNames(x, x, result);

0 commit comments

Comments
 (0)