Skip to content

Commit c28a7c1

Browse files
committed
[GR-21974] More new model preformance fixes: matrix multiplication, round, specials.
PullRequest: fastr/2416
2 parents 5b7d871 + d59db46 commit c28a7c1

File tree

7 files changed

+234
-177
lines changed

7 files changed

+234
-177
lines changed

com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nodes/AsRealNode.java

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,18 @@
2222

2323
import com.oracle.truffle.api.dsl.Cached;
2424
import com.oracle.truffle.api.dsl.Fallback;
25+
import com.oracle.truffle.api.dsl.ImportStatic;
2526
import com.oracle.truffle.api.dsl.Specialization;
2627
import com.oracle.truffle.api.dsl.TypeSystemReference;
28+
import com.oracle.truffle.api.library.CachedLibrary;
29+
import com.oracle.truffle.api.profiles.BranchProfile;
2730
import com.oracle.truffle.r.nodes.unary.CastDoubleNode;
31+
import com.oracle.truffle.r.runtime.DSLConfig;
2832
import com.oracle.truffle.r.runtime.RInternalError;
2933
import com.oracle.truffle.r.runtime.RRuntime;
3034
import com.oracle.truffle.r.runtime.data.RDoubleVector;
3135
import com.oracle.truffle.r.runtime.data.RTypes;
36+
import com.oracle.truffle.r.runtime.data.VectorDataLibrary;
3237
import com.oracle.truffle.r.runtime.data.model.RAbstractAtomicVector;
3338
import com.oracle.truffle.r.runtime.data.RIntVector;
3439

@@ -38,6 +43,7 @@
3843
* return {@code NA}.
3944
*/
4045
@TypeSystemReference(RTypes.class)
46+
@ImportStatic(DSLConfig.class)
4147
public abstract class AsRealNode extends FFIUpCallNode.Arg1 {
4248

4349
public abstract double execute(Object obj);
@@ -48,35 +54,55 @@ protected double asReal(double obj) {
4854
}
4955

5056
@Specialization
51-
protected double asReal(int obj) {
52-
return RRuntime.isNA(obj) ? RRuntime.DOUBLE_NA : obj;
57+
protected double asReal(int obj,
58+
@Cached BranchProfile naBranchProfile) {
59+
if (RRuntime.isNA(obj)) {
60+
naBranchProfile.enter();
61+
return RRuntime.DOUBLE_NA;
62+
} else {
63+
return obj;
64+
}
5365
}
5466

55-
@Specialization
56-
protected double asReal(RDoubleVector obj) {
57-
if (obj.getLength() == 0) {
67+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
68+
protected double asReal(RDoubleVector obj,
69+
@Cached BranchProfile naBranchProfile,
70+
@CachedLibrary("obj.getData()") VectorDataLibrary dataLib) {
71+
Object data = obj.getData();
72+
if (dataLib.getLength(data) == 0) {
73+
naBranchProfile.enter();
5874
return RRuntime.DOUBLE_NA;
5975
}
60-
return obj.getDataAt(0);
76+
return dataLib.getDoubleAt(data, 0);
6177
}
6278

63-
@Specialization
64-
protected double asReal(RIntVector obj) {
65-
if (obj.getLength() == 0) {
79+
@Specialization(limit = "getTypedVectorDataLibraryCacheSize()")
80+
protected double asReal(RIntVector obj,
81+
@Cached BranchProfile naBranchProfile,
82+
@CachedLibrary("obj.getData()") VectorDataLibrary dataLib) {
83+
Object data = obj.getData();
84+
if (dataLib.getLength(data) == 0) {
85+
naBranchProfile.enter();
86+
return RRuntime.DOUBLE_NA;
87+
}
88+
int result = dataLib.getIntAt(data, 0);
89+
if (RRuntime.isNA(result)) {
90+
naBranchProfile.enter();
6691
return RRuntime.DOUBLE_NA;
92+
} else {
93+
return result;
6794
}
68-
int result = obj.getDataAt(0);
69-
return RRuntime.isNA(result) ? RRuntime.DOUBLE_NA : result;
7095
}
7196

7297
@Specialization(guards = "obj.getLength() > 0")
7398
protected double asReal(RAbstractAtomicVector obj,
99+
@CachedLibrary(limit = "getCacheSize(2)") VectorDataLibrary dataLib,
74100
@Cached("createNonPreserving()") CastDoubleNode castDoubleNode) {
75101
Object castObj = castDoubleNode.executeDouble(obj);
76102
if (castObj instanceof Double) {
77103
return (double) castObj;
78104
} else if (castObj instanceof RDoubleVector) {
79-
return ((RDoubleVector) castObj).getDataAt(0);
105+
return dataLib.getDoubleAt(((RDoubleVector) castObj).getData(), 0);
80106
} else {
81107
throw RInternalError.shouldNotReachHere();
82108
}

com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/CrossprodCommon.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
import com.oracle.truffle.api.CompilerDirectives;
3131
import com.oracle.truffle.api.dsl.Cached;
3232
import com.oracle.truffle.api.dsl.Specialization;
33+
import com.oracle.truffle.api.library.CachedLibrary;
3334
import com.oracle.truffle.api.profiles.ConditionProfile;
35+
import com.oracle.truffle.r.runtime.data.VectorDataLibrary;
3436
import com.oracle.truffle.r.runtime.data.nodes.ExtractListElement;
3537
import com.oracle.truffle.r.runtime.data.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode;
3638
import com.oracle.truffle.r.runtime.data.nodes.attributes.SpecialAttributesFunctions.GetDimNamesAttributeNode;
@@ -80,8 +82,10 @@ public static CrossprodCommon createTCrossprod() {
8082
return CrossprodCommonNodeGen.create(false);
8183
}
8284

83-
@Specialization(guards = {"getXDimsNode.isMatrix(x)", "getYDimsNode.isMatrix(y)"})
85+
@Specialization(guards = {"getXDimsNode.isMatrix(x)", "getYDimsNode.isMatrix(y)"}, limit = "getTypedVectorDataLibraryCacheSize()")
8486
protected RDoubleVector crossprod(RDoubleVector x, RDoubleVector y,
87+
@CachedLibrary("x.getData()") VectorDataLibrary xDataLib,
88+
@CachedLibrary("y.getData()") VectorDataLibrary yDataLib,
8589
@Cached("create()") GetDimAttributeNode getXDimsNode,
8690
@Cached("create()") GetDimAttributeNode getYDimsNode) {
8791
int[] xDims = getXDimsNode.getDimensions(x);
@@ -90,7 +94,8 @@ protected RDoubleVector crossprod(RDoubleVector x, RDoubleVector y,
9094
int xCols = transposeX ? xDims[0] : xDims[1];
9195
int yRows = transposeX ? yDims[0] : yDims[1];
9296
int yCols = transposeX ? yDims[1] : yDims[0];
93-
RDoubleVector result = matMult.doubleMatrixMultiply(x, y, xRows, xCols, yRows, yCols, getXRowStride(xDims[0]), getXColStride(xDims[0]), getYRowStride(yDims[0]), getYColStride(yDims[0]),
97+
RDoubleVector result = matMult.doubleMatrixMultiply(xDataLib, x.getData(), x, yDataLib, y.getData(), y, xRows, xCols, yRows, yCols, getXRowStride(xDims[0]), getXColStride(xDims[0]),
98+
getYRowStride(yDims[0]), getYColStride(yDims[0]),
9499
false);
95100
return copyDimNames(x, y, result);
96101
}
@@ -100,8 +105,9 @@ protected Object crossprod(RAbstractVector x, RAbstractVector y) {
100105
return copyDimNames(x, y, (RAbstractVector) matMult.executeObject(transposeX(x), transposeY(y)));
101106
}
102107

103-
@Specialization(guards = "getDimsNode.isMatrix(x)")
108+
@Specialization(guards = "getDimsNode.isMatrix(x)", limit = "getTypedVectorDataLibraryCacheSize()")
104109
protected RDoubleVector crossprodDoubleMatrix(RDoubleVector x, @SuppressWarnings("unused") RNull y,
110+
@CachedLibrary("x.getData()") VectorDataLibrary xDataLib,
105111
@Cached("create()") GetReadonlyData.Double getReadonlyData,
106112
@Cached("create()") GetDimAttributeNode getDimsNode,
107113
@Cached("create()") GetDimAttributeNode getResultDimsNode) {
@@ -110,8 +116,10 @@ protected RDoubleVector crossprodDoubleMatrix(RDoubleVector x, @SuppressWarnings
110116
int xCols = transposeX ? xDims[0] : xDims[1];
111117
int yRows = transposeX ? xDims[0] : xDims[1];
112118
int yCols = transposeX ? xDims[1] : xDims[0];
119+
Object xData = x.getData();
113120
RDoubleVector result = mirror(
114-
matMult.doubleMatrixMultiply(x, x, xRows, xCols, yRows, yCols, getXRowStride(xDims[0]), getXColStride(xDims[0]), getYRowStride(xDims[0]), getYColStride(xDims[0]), true),
121+
matMult.doubleMatrixMultiply(xDataLib, xData, x, xDataLib, xData, x, xRows, xCols, yRows, yCols, getXRowStride(xDims[0]), getXColStride(xDims[0]), getYRowStride(xDims[0]),
122+
getYColStride(xDims[0]), true),
115123
getResultDimsNode,
116124
getReadonlyData);
117125
return copyDimNames(x, x, result);

0 commit comments

Comments
 (0)