3030import com .oracle .truffle .api .CompilerDirectives ;
3131import com .oracle .truffle .api .dsl .Cached ;
3232import com .oracle .truffle .api .dsl .Specialization ;
33+ import com .oracle .truffle .api .library .CachedLibrary ;
3334import com .oracle .truffle .api .profiles .ConditionProfile ;
35+ import com .oracle .truffle .r .runtime .data .VectorDataLibrary ;
3436import com .oracle .truffle .r .runtime .data .nodes .ExtractListElement ;
3537import com .oracle .truffle .r .runtime .data .nodes .attributes .SpecialAttributesFunctions .GetDimAttributeNode ;
3638import com .oracle .truffle .r .runtime .data .nodes .attributes .SpecialAttributesFunctions .GetDimNamesAttributeNode ;
@@ -80,8 +82,10 @@ public static CrossprodCommon createTCrossprod() {
8082 return CrossprodCommonNodeGen .create (false );
8183 }
8284
83- @ Specialization (guards = {"getXDimsNode.isMatrix(x)" , "getYDimsNode.isMatrix(y)" })
85+ @ Specialization (guards = {"getXDimsNode.isMatrix(x)" , "getYDimsNode.isMatrix(y)" }, limit = "getTypedVectorDataLibraryCacheSize()" )
8486 protected RDoubleVector crossprod (RDoubleVector x , RDoubleVector y ,
87+ @ CachedLibrary ("x.getData()" ) VectorDataLibrary xDataLib ,
88+ @ CachedLibrary ("y.getData()" ) VectorDataLibrary yDataLib ,
8589 @ Cached ("create()" ) GetDimAttributeNode getXDimsNode ,
8690 @ Cached ("create()" ) GetDimAttributeNode getYDimsNode ) {
8791 int [] xDims = getXDimsNode .getDimensions (x );
@@ -90,7 +94,8 @@ protected RDoubleVector crossprod(RDoubleVector x, RDoubleVector y,
9094 int xCols = transposeX ? xDims [0 ] : xDims [1 ];
9195 int yRows = transposeX ? yDims [0 ] : yDims [1 ];
9296 int yCols = transposeX ? yDims [1 ] : yDims [0 ];
93- RDoubleVector result = matMult .doubleMatrixMultiply (x , y , xRows , xCols , yRows , yCols , getXRowStride (xDims [0 ]), getXColStride (xDims [0 ]), getYRowStride (yDims [0 ]), getYColStride (yDims [0 ]),
97+ RDoubleVector result = matMult .doubleMatrixMultiply (xDataLib , x .getData (), x , yDataLib , y .getData (), y , xRows , xCols , yRows , yCols , getXRowStride (xDims [0 ]), getXColStride (xDims [0 ]),
98+ getYRowStride (yDims [0 ]), getYColStride (yDims [0 ]),
9499 false );
95100 return copyDimNames (x , y , result );
96101 }
@@ -100,8 +105,9 @@ protected Object crossprod(RAbstractVector x, RAbstractVector y) {
100105 return copyDimNames (x , y , (RAbstractVector ) matMult .executeObject (transposeX (x ), transposeY (y )));
101106 }
102107
103- @ Specialization (guards = "getDimsNode.isMatrix(x)" )
108+ @ Specialization (guards = "getDimsNode.isMatrix(x)" , limit = "getTypedVectorDataLibraryCacheSize()" )
104109 protected RDoubleVector crossprodDoubleMatrix (RDoubleVector x , @ SuppressWarnings ("unused" ) RNull y ,
110+ @ CachedLibrary ("x.getData()" ) VectorDataLibrary xDataLib ,
105111 @ Cached ("create()" ) GetReadonlyData .Double getReadonlyData ,
106112 @ Cached ("create()" ) GetDimAttributeNode getDimsNode ,
107113 @ Cached ("create()" ) GetDimAttributeNode getResultDimsNode ) {
@@ -110,8 +116,10 @@ protected RDoubleVector crossprodDoubleMatrix(RDoubleVector x, @SuppressWarnings
110116 int xCols = transposeX ? xDims [0 ] : xDims [1 ];
111117 int yRows = transposeX ? xDims [0 ] : xDims [1 ];
112118 int yCols = transposeX ? xDims [1 ] : xDims [0 ];
119+ Object xData = x .getData ();
113120 RDoubleVector result = mirror (
114- matMult .doubleMatrixMultiply (x , x , xRows , xCols , yRows , yCols , getXRowStride (xDims [0 ]), getXColStride (xDims [0 ]), getYRowStride (xDims [0 ]), getYColStride (xDims [0 ]), true ),
121+ matMult .doubleMatrixMultiply (xDataLib , xData , x , xDataLib , xData , x , xRows , xCols , yRows , yCols , getXRowStride (xDims [0 ]), getXColStride (xDims [0 ]), getYRowStride (xDims [0 ]),
122+ getYColStride (xDims [0 ]), true ),
115123 getResultDimsNode ,
116124 getReadonlyData );
117125 return copyDimNames (x , x , result );
0 commit comments