Skip to content

Commit 69afbf7

Browse files
rbrchen authored and Gary Frost committed
Buffer tagger cleanup
1 parent b446644 commit 69afbf7

35 files changed

Lines changed: 228 additions & 233 deletions

File tree

hat/backends/ffi/shared/src/main/native/cpp/shared.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ long Backend::CompilationUnit::Kernel::ndrange(void *argArray) {
403403
}
404404

405405
auto *buffer = static_cast<Buffer *>(bufferState->vendorPtr);
406-
if (kernelWroteToThisArg || compilationUnit->backend->config->alwaysCopy) {
406+
if (kernelWroteToThisArg && compilationUnit->backend->config->alwaysCopy) {
407407
compilationUnit->backend->queue->copyFromDevice(buffer);
408408
bufferState->state = BufferState::HOST_OWNED;
409409
if (compilationUnit->backend->config->traceCopies || compilationUnit->backend->config->traceEnqueues) {

hat/core/src/main/java/hat/BufferTagger.java

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
package hat;
2727

28+
import hat.phases.HATPhaseUtils;
2829
import jdk.incubator.code.dialect.java.JavaOp;
2930
import optkl.IfaceValue;
3031
import optkl.OpHelper;
@@ -147,23 +148,22 @@ private static void mapBranch(MethodHandles.Lookup lookup, Block.Reference block
147148
}
148149
}
149150

150-
// retrieves "root" value of an op, the origin of the parameter (or value) used by the op
151+
// retrieves "root" value of an op, which is how we track accesses
152+
// we will map the return value of this method to the accessType
151153
private static Value getRootValue(Op op) {
152-
if (op.operands().isEmpty()) {
153-
return op.result();
154-
} else if (op.operands().getFirst() instanceof Block.Parameter param) {
155-
return param;
156-
}
154+
// the op is a field load, an invoke, or something that reduces to one or the other
155+
// first, check if we can retrieve a fieldloadop from the given op
156+
Op fieldOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(op, JavaOp.FieldAccessOp.FieldLoadOp.class);
157+
if (fieldOp != null) return fieldOp.operands().getFirst(); // if so, we use its first operand to map to accesses
157158

158-
while (op.operands().getFirst() instanceof Op.Result result) { // Only first?
159-
op = result.op(); // we are changing our par here I assume intended
160-
if (op.operands().isEmpty()) { // if the "root op" is an invoke
161-
return op.result();
162-
}else{
163-
// or else
164-
}
159+
// we then check if there's an invokeop that has no operands (meaning a shared or private buffer that was created)
160+
// or if there's an invokeop with a parameter as its first operand (this is a global buffer)
161+
Op invokeOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(op, JavaOp.InvokeOp.class);
162+
while (invokeOp != null && !invokeOp.operands().isEmpty()) {
163+
if (invokeOp.operands().getFirst() instanceof Block.Parameter p) return p; // return the parameter that is the global buffer
164+
invokeOp = HATPhaseUtils.findOpInResultFromFirstOperandsOrNull(invokeOp.operands().getFirst().result().op(), JavaOp.InvokeOp.class);
165165
}
166-
return op.operands().getFirst();
166+
return (invokeOp == null) ? null : invokeOp.result(); // return the shared/private buffer invokeop that creates the buffer
167167
}
168168

169169
// updates accessMap
@@ -175,7 +175,7 @@ private static void updateAccessType(Value value, AccessType currentAccess) {
175175
} else if (currentAccess != storedAccess && storedAccess != AccessType.RW) {
176176
accessMap.put(remappedValue, AccessType.RW);
177177
} else {
178-
// or else
178+
// this is the same access type as what's already stored
179179
}
180180
}
181181
}

hat/core/src/main/java/hat/buffer/ArgArray.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,8 @@ static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object...
280280
case Buffer buffer -> {
281281
Annotation[] annotations = parameterAnnotations[i];
282282
AccessType accessType = AccessType.NA;
283-
if (annotations.length > 0) {
284-
for (Annotation annotation : annotations) {
285-
accessType = AccessType.of(annotation);
286-
}
287-
} else {
288-
throw new IllegalArgumentException("Argument " + i + " has no access annotations");
283+
for (Annotation annotation : annotations) {
284+
accessType = AccessType.of(annotation);
289285
}
290286
MemorySegment segment = MappableIface.getMemorySegment(buffer);
291287
arg.variant((byte) '&');

hat/examples/blackscholes/src/main/java/blackscholes/Main.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
import java.util.Random;
3838

3939
import optkl.ifacemapper.MappableIface.RO;
40-
import optkl.ifacemapper.MappableIface.RW;
4140
import optkl.ifacemapper.MappableIface.WO;
4241

4342
import jdk.incubator.code.Reflect;
@@ -46,12 +45,12 @@ public class Main {
4645
static Random rand;
4746

4847
@Reflect
49-
public static void blackScholesKernel(@RO KernelContext kc,
50-
@WO F32Array call,
51-
@WO F32Array put,
52-
@RO F32Array sArray,
53-
@RO F32Array xArray,
54-
@RO F32Array tArray,
48+
public static void blackScholesKernel(KernelContext kc,
49+
F32Array call,
50+
F32Array put,
51+
F32Array sArray,
52+
F32Array xArray,
53+
F32Array tArray,
5554
float r,
5655
float v) {
5756
if (kc.gix < kc.gsx){

hat/examples/dft/src/main/java/dft/Main.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ static ComplexArray create(Accelerator accelerator, int length) {
103103
}
104104

105105
@Reflect
106-
private static void dftKernel(@RW KernelContext kc, @RO ComplexArray input, @WO ComplexArray output) {
106+
private static void dftKernel(KernelContext kc, ComplexArray input, ComplexArray output) {
107107
int size = input.length();
108108
int idx = kc.gix;
109109
if (idx < kc.gsx) {
@@ -130,7 +130,7 @@ private static void dftCompute(@RW ComputeContext cc, @RO ComplexArray input, @W
130130
}
131131

132132
@Reflect
133-
private static void dftPlainKernel(@RW KernelContext kc, @RO F32Array inReal, @RO F32Array inImag, @WO F32Array outReal, @WO F32Array outImag) {
133+
private static void dftPlainKernel(KernelContext kc, F32Array inReal, F32Array inImag, F32Array outReal, F32Array outImag) {
134134
int size = inReal.length();
135135
int idx = kc.gix;
136136
if (idx < kc.gsx) {

hat/examples/flashattention/src/main/java/flashattention/Main.java

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,10 @@ public class Main {
105105
* @param softMaxScale
106106
*/
107107
@Reflect
108-
public static void selfAttentionV2HAT(@RO KernelContext kernelContext,
109-
@RO F32Array Q, @RO F32Array K, @RO F32Array V,
110-
@WO F32Array attentionMatrix, @WO F32Array O,
111-
@RO final int N, @RO final int d, @RO final float softMaxScale) {
108+
public static void selfAttentionV2HAT(KernelContext kernelContext,
109+
F32Array Q, F32Array K, F32Array V,
110+
F32Array attentionMatrix, F32Array O,
111+
final int N, final int d, final float softMaxScale) {
112112
int idx = kernelContext.gix;
113113
if (idx < N) {
114114
// Compute the attention scores: Q * K^T and scale it to sqrt(d) => softMaxScale
@@ -382,9 +382,9 @@ public static int ceilFunction(int N, int blockN) {
382382
* @param softmaxScale
383383
*/
384384
@Reflect
385-
public static void flashAttention(@RO KernelContext kernelContext,
386-
@RO F32Array Q, @RO F32Array K, @RO F32Array V,
387-
@WO F32Array O, @RW F32Array m, @RW F32Array l,
385+
public static void flashAttention(KernelContext kernelContext,
386+
F32Array Q, F32Array K, F32Array V,
387+
F32Array O, F32Array m, F32Array l,
388388
final int N, final int d, final float softmaxScale) {
389389
int bx = kernelContext.bix;
390390
int tid = kernelContext.lix;
@@ -525,9 +525,9 @@ static PrivateF16Array createPrivate() {
525525
}
526526

527527
@Reflect
528-
public static void flashAttentionF16(@RO KernelContext kernelContext,
529-
@RO F16Array Q, @RO F16Array K, @RO F16Array V,
530-
@WO F16Array O, @RW F16Array m, @RW F16Array l,
528+
public static void flashAttentionF16(KernelContext kernelContext,
529+
F16Array Q, F16Array K, F16Array V,
530+
F16Array O, F16Array m, F16Array l,
531531
final int N, final int d, final float softmaxScale) {
532532
int bx = kernelContext.bix;
533533
int tid = kernelContext.lix;

hat/examples/heal/src/main/java/heal/ComputeHeal.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,11 @@ static int blue(int rgb) {
194194

195195
@Reflect
196196
public static void bestFitCore(int id,
197-
@RO S32Array2D s32Array2D,
198-
@RO Box searchArea,
199-
@RO Box selBox,
200-
@RO XYRGBList xyrgbList,
201-
@RW F32Array sumArray) {
197+
S32Array2D s32Array2D,
198+
Box searchArea,
199+
Box selBox,
200+
XYRGBList xyrgbList,
201+
F32Array sumArray) {
202202
int x = searchArea.x1() + id % searchArea.width();
203203
int y = searchArea.y1() + id / searchArea.width();
204204
float sum = 0;
@@ -233,12 +233,12 @@ public static void bestFitCore(int id,
233233
}
234234

235235
@Reflect
236-
public static void bestFitKernel(@RO KernelContext kc,
237-
@RO S32Array2D s32Array2D,
238-
@RO Box searchArea,
239-
@RO Box selectionBox,
240-
@RO XYRGBList xyrgbList,
241-
@RO F32Array sumArray) {
236+
public static void bestFitKernel(KernelContext kc,
237+
S32Array2D s32Array2D,
238+
Box searchArea,
239+
Box selectionBox,
240+
XYRGBList xyrgbList,
241+
F32Array sumArray) {
242242
bestFitCore(kc.gix, s32Array2D, searchArea, selectionBox, xyrgbList, sumArray);
243243
}
244244

hat/examples/life/src/main/java/life/Main.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ int val(__global cellGrid_t *CLWrapCellGrid, int from, int w, int x, int y) {
144144

145145

146146
@Reflect
147-
public static int val(@RO CellGrid grid, int from, int w, int x, int y) {
147+
public static int val(CellGrid grid, int from, int w, int x, int y) {
148148
return grid.cell(((long) y * w) + x + from) & 1;
149149
}
150150

@@ -175,7 +175,7 @@ __kernel void life( __global cellGrid_t *CLWrapCellGrid ,__global control_t *CL
175175
""";
176176

177177
@Reflect
178-
public static void lifePerIdx(int idx, @RW Control control, @RW CellGrid cellGrid) {
178+
public static void lifePerIdx(int idx, Control control, CellGrid cellGrid) {
179179
int w = cellGrid.width();
180180
int h = cellGrid.height();
181181
int from = control.from();
@@ -199,7 +199,7 @@ public static void lifePerIdx(int idx, @RW Control control, @RW CellGrid cellGri
199199
}
200200

201201
@Reflect
202-
public static void life(@RO KernelContext kc, @RO Control control, @RW CellGrid cellGrid) {
202+
public static void life(KernelContext kc, Control control, CellGrid cellGrid) {
203203
if (kc.gix < kc.gsx) {
204204
ComputeLife.lifePerIdx(kc.gix, control, cellGrid);
205205
}

hat/examples/mandel/src/main/java/mandel/Main.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444
public class Main {
4545
@Reflect
46-
public static void mandel(@RO KernelContext kc, @RW S32Array2D s32Array2D, @RO S32Array pallette, float offsetx, float offsety, float scale) {
46+
public static void mandel(KernelContext kc, S32Array2D s32Array2D, S32Array pallette, float offsetx, float offsety, float scale) {
4747
if (kc.gix < kc.gsx) {
4848
float width = s32Array2D.width();
4949
float height = s32Array2D.height();

hat/examples/matmul/src/main/java/matmul/Main.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ public class Main {
8484
* @param size
8585
*/
8686
@Reflect
87-
public static void matrixMultiplyKernel2D(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @WO F32Array matrixC, int size) {
87+
public static void matrixMultiplyKernel2D(KernelContext kc, F32Array matrixA, F32Array matrixB, F32Array matrixC, int size) {
8888
if (kc.gix < kc.gsx) {
8989
if (kc.giy < kc.gsy) {
9090
float acc = 0.0f;
@@ -106,7 +106,7 @@ public static void matrixMultiplyKernel2D(@RO KernelContext kc, @RO F32Array mat
106106
* @param size
107107
*/
108108
@Reflect
109-
public static void matrixMultiplyKernel2DLI(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @WO F32Array matrixC, int size) {
109+
public static void matrixMultiplyKernel2DLI(KernelContext kc, F32Array matrixA, F32Array matrixB, F32Array matrixC, int size) {
110110
if (kc.gix < kc.gsx) {
111111
if (kc.giy < kc.gsy) {
112112
float acc = 0.0f;
@@ -141,7 +141,7 @@ static MyLocalArrayFixedSize createLocal() {
141141
}
142142

143143
@Reflect
144-
public static void matrixMultiplyKernel2DTiling(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @WO F32Array matrixC, int size) {
144+
public static void matrixMultiplyKernel2DTiling(KernelContext kc, F32Array matrixA, F32Array matrixB, F32Array matrixC, int size) {
145145

146146
final int tileSize = 16;
147147
MyLocalArrayFixedSize tileA = MyLocalArrayFixedSize.createLocal();
@@ -254,7 +254,7 @@ static FlatPrivate createPrivate() {
254254
* @param size
255255
*/
256256
@Reflect
257-
public static void matrixMultiplyKernel2DRegisterTiling(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @WO F32Array matrixC, int size) {
257+
public static void matrixMultiplyKernel2DRegisterTiling(KernelContext kc, F32Array matrixA, F32Array matrixB, F32Array matrixC, int size) {
258258

259259
// Configuration for the kernel: Keep in mind that if you change the following parameters,
260260
// also change the scheduling (global and local work sizes).
@@ -376,7 +376,7 @@ public static void matrixMultiplyKernel2DRegisterTiling(@RO KernelContext kc, @R
376376
* @param size
377377
*/
378378
@Reflect
379-
public static void matrixMultiplyKernel2DRegisterTilingVectorized(@RO KernelContext kc, @RO F32ArrayPadded matrixA, @RO F32ArrayPadded matrixB, @WO F32ArrayPadded matrixC, int size) {
379+
public static void matrixMultiplyKernel2DRegisterTilingVectorized(KernelContext kc, F32ArrayPadded matrixA, F32ArrayPadded matrixB, F32ArrayPadded matrixC, int size) {
380380

381381
// Configuration for the kernel: Keep in mind that if you change the following parameters,
382382
// also change the scheduling (global and local work sizes).
@@ -524,7 +524,7 @@ static FlatPrivateHalf createPrivate() {
524524
}
525525

526526
@Reflect
527-
public static void matrixMultiplyKernel2DRegisterTilingHalf(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F16Array matrixC, int size) {
527+
public static void matrixMultiplyKernel2DRegisterTilingHalf(KernelContext kc, F16Array matrixA, F16Array matrixB, F16Array matrixC, int size) {
528528

529529
// Configuration for the kernel: Keep in mind that if you change the following parameters,
530530
// also change the scheduling (global and local work sizes).
@@ -648,7 +648,7 @@ public static float compute(@RO KernelContext kc, @RO F32Array matrixA, @RO F32A
648648
* @param size
649649
*/
650650
@Reflect
651-
public static void matrixMultiplyKernel1D(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @WO F32Array matrixC, int size) {
651+
public static void matrixMultiplyKernel1D(KernelContext kc, F32Array matrixA, F32Array matrixB, F32Array matrixC, int size) {
652652
if (kc.gix < kc.gsx) {
653653
for (int j = 0; j < size; j++) {
654654
float acc = 0.0f;
@@ -664,7 +664,7 @@ public static void matrixMultiplyKernel1D(@RO KernelContext kc, @RO F32Array mat
664664
* 1D Matrix Multiply with function calls passing the kernel context ID. This is just for testing purposes.
665665
*/
666666
@Reflect
667-
public static void matrixMultiplyKernel1DWithFunctionCalls(@RO KernelContext kc, @RO F32Array matrixA, @RO F32Array matrixB, @WO F32Array matrixC, int size) {
667+
public static void matrixMultiplyKernel1DWithFunctionCalls(KernelContext kc, F32Array matrixA, F32Array matrixB, F32Array matrixC, int size) {
668668
if (kc.gix < kc.gsx) {
669669
for (int j = 0; j < size; j++) {
670670
float acc = compute(kc, matrixA, matrixB, size, j);

0 commit comments

Comments
 (0)