diff --git a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CUDACodeGenException.java b/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CUDACodeGenException.java deleted file mode 100644 index b13f2cebe8f..00000000000 --- a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CUDACodeGenException.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2026, Oracle and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. Oracle designates this - * particular file as subject to the "Classpath" exception as provided - * by Oracle in the LICENSE file that accompanied this code. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. - */ -package hat.backend.ffi; - -import optkl.exceptions.CodeGenException; - -public class CUDACodeGenException extends CodeGenException { - - protected CUDACodeGenException(String message) { - super(message); - } -} diff --git a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java b/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java index a580f9632c3..1445c780955 100644 --- a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java +++ b/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java @@ -27,15 +27,8 @@ import hat.callgraph.KernelCallGraph; import hat.codebuilders.C99HATKernelBuilder; import hat.dialect.HATF16Op; -import hat.dialect.HATTensorOp; import hat.dialect.HATVectorOp; import hat.types.F16; -import hat.types.Tensor; -import jdk.incubator.code.dialect.core.CoreOp; -import jdk.incubator.code.dialect.java.ClassType; -import jdk.incubator.code.dialect.java.FieldRef; -import jdk.incubator.code.dialect.java.JavaOp; -import optkl.OpHelper; import optkl.codebuilders.CodeBuilder; import optkl.codebuilders.ScopedCodeBuilderContext; import hat.types.BF16; @@ -43,13 +36,8 @@ import jdk.incubator.code.Value; import jdk.incubator.code.dialect.java.PrimitiveType; -import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.SequencedSet; - -import static optkl.OpHelper.Invoke.invoke; public class CudaHATKernelBuilder extends C99HATKernelBuilder { @@ -57,11 +45,6 @@ protected CudaHATKernelBuilder(KernelCallGraph kernelCallGraph, ScopedCodeBuilde super(kernelCallGraph, scopedCodeBuilderContext); } - @Override - protected CudaHATKernelBuilder hatWarpSize() { - return constant("32"); - } - private CudaHATKernelBuilder half2float() { return id("__half2float"); } @@ -175,8 +158,7 @@ public CudaHATKernelBuilder defines() { .includeSys("cuda_fp16.h", "cuda_bf16.h") .hashDefine("BFLOAT16", _ -> 
keyword("__nv_bfloat16")) .typedefSingleValueStruct("F16", "half") - .typedefSingleValueStruct("BF16", "BFLOAT16") - .includeSys("mma.h"); // only enable if tensor views are used + .typedefSingleValueStruct("BF16", "BFLOAT16"); } @Override @@ -192,7 +174,7 @@ public CudaHATKernelBuilder hatVectorStoreOp(HATVectorOp.HATVectorStoreView hatV paren(_ -> { ampersand().recurseResultOrThrow(dest); either(hatVectorStoreView instanceof HATVectorOp.Shared, CodeBuilder::dot, CodeBuilder::rarrow); - id(ARRAY).sbrace(_ -> recurseResultOrThrow(index)); + id("array").sbrace(_ -> recurseResultOrThrow(index)); }); sbrace(_ -> intConstZero()); sp().equals().sp(); @@ -266,7 +248,7 @@ public CudaHATKernelBuilder hatVectorLoadOp(HATVectorOp.HATVectorLoadOp hatVecto paren(_ -> { ampersand();recurseResultOrThrow(source); either(hatVectorLoadOp instanceof HATVectorOp.Shared, CodeBuilder::dot, CodeBuilder::rarrow); - id(ARRAY).sbrace(_ -> recurseResultOrThrow(index)); + id("array").sbrace(_ -> recurseResultOrThrow(index)); }); sbrace(_ -> intConstZero()); return self(); @@ -440,267 +422,4 @@ private CudaHATKernelBuilder generateFloat16ConversionToFloat(Class float16Cl protected String mapMathIntrinsic(String hatMathIntrinsicName) { return MATH_FUNCTIONS.getOrDefault(hatMathIntrinsicName, hatMathIntrinsicName); } - - public static final String WMMA_MEM_COL_MAJOR = "nvcuda::wmma::mem_col_major"; - public static final String WMMA_MEM_ROW_MAJOR = "nvcuda::wmma::mem_row_major"; - public static final String WMMA_STORE_TENSOR = "nvcuda::wmma::store_matrix_sync"; - public static final String WMMA_LOAD_TENSOR = "nvcuda::wmma::load_matrix_sync"; - public static final String WMMA_MMA_TENSOR = "nvcuda::wmma::mma_sync"; - public static final String WMMA_FILL_TENSOR = "nvcuda::wmma::fill_fragment"; - public static final String WMMA_COL_MAJOR = "nvcuda::wmma::col_major"; - public static final String WMMA_ROW_MAJOR = "nvcuda::wmma::row_major"; - public static final String WMMA_FRAGMENT_BASE = "nvcuda::wmma::fragment shape, String matrixOrder, String type, Value access) { - id(WMMA_FRAGMENT_BASE) - .id(matrixOrder) - .comma().sp() - .intValue(shape.getFirst()) - .comma().sp() - .intValue(shape.get(1)) - .comma().sp() - .intValue(shape.get(2)) - .comma().sp() - .type(type); - - if (matrixOrder.equals(TENSOR_ACC)) { - gt(); - } else {// infer from the last parameter - if (access.declaringElement() instanceof JavaOp.InvokeOp invokeOp) { - // Expecting an invokeOp - var invoke = invoke(scopedCodeBuilderContext().lookup(), invokeOp); - comma(); - if (invoke.resultTypeIs(Tensor.ColumMajor.class)) { - id(WMMA_COL_MAJOR); - } else if (invoke.resultTypeIs(Tensor.RowMajor.class)) { - id(WMMA_ROW_MAJOR); - } else { - throw new CUDACodeGenException("[Error]"); - } - gt(); - } - } - return self(); - } - - private static final String TENSOR_MATRIX_A = "matrix_a"; - private static final String TENSOR_MATRIX_B = "matrix_b"; - private final String TENSOR_ACC = "accumulator"; - - private String getMatrixOrder(Value valueParameter) { - if (valueParameter instanceof Op.Result r && r.op() instanceof JavaOp.FieldAccessOp.FieldLoadOp fieldLoadOp) { - FieldRef fieldRef = fieldLoadOp.fieldReference(); - return switch (fieldRef.name()) { - case "FIRST" -> TENSOR_MATRIX_A; - case "SECOND" -> TENSOR_MATRIX_B; - default -> TENSOR_ACC; - }; - } - return null; - } - - @Override - public CudaHATKernelBuilder hatTensorCreateOp(HATTensorOp.TensorCreateOp tensorCreateOp) { - // infer first parameter - List operands = tensorCreateOp.operands(); - Value first = 
operands.getFirst(); - // The first operand gives us the matrix order or accumulator - String matrixOrder = getMatrixOrder(first); - - // Second parameters: analysis of the shape - List shape = new ArrayList<>(); - - Value second = operands.get(1); - if (second.declaringElement() instanceof JavaOp.InvokeOp invokeOp) { - List shapeOperands = invokeOp.operands(); - for (Value shapeOperand : shapeOperands) { - if (shapeOperand.declaringElement() instanceof CoreOp.ConstantOp constantOp) { - shape.add((int) constantOp.value()); - } else { - throw new CUDACodeGenException("Error: expected to find a ConstantOp, but found a " + shapeOperand.declaringElement().getClass()); - } - } - } else { - throw new CUDACodeGenException("InvokeOp expected, but found: " + second.declaringElement().getClass()); - } - if (shape.size() != 3) { - throw new CUDACodeGenException("Shape must have three values"); - } - - // The third parameter is the type. It could be `half` or `float` as first implementation - // This parameter is another constant with the type - Value classOperand = operands.get(2); - Object klass = null; - if (classOperand.declaringElement() instanceof CoreOp.ConstantOp constantOp) { - klass = constantOp.value(); - } - - String type = ""; - if (klass instanceof ClassType classType && classType.toClassName().equals(F16.class.getCanonicalName())) { - type = "half"; - } else if (klass instanceof PrimitiveType primitiveType && primitiveType.equals(PrimitiveType.FLOAT)) { - type = "float"; - } - - Value access = operands.getLast(); - return generateCreateTensor(shape, matrixOrder, type, access); - } - - @Override - public CudaHATKernelBuilder hatTensorVarLoadOp(HATTensorOp.TensorVarLoadOp hatTensorVarLoadOp) { - Value operand = hatTensorVarLoadOp.operands().getFirst(); - if (operand instanceof Op.Result r && r.op() instanceof HATTensorOp.TensorVarOp tensorVarOp) { - varName(tensorVarOp.varName()); - } else { - throw new CUDACodeGenException("[ERROR] Expected HATTensorVarOp"); - } - return self(); - } - - @Override - public CudaHATKernelBuilder hatTensorFillOp(HATTensorOp.TensorFillOp tensorFillOp) { - id(WMMA_FILL_TENSOR).paren( _-> { - List operands = tensorFillOp.operands(); - recurseResultOrThrow(operands.getFirst()) - .comma() - .recurseResultOrThrow(operands.get(1)); - }); - return self(); - } - - @Override - public CUDACodeGenException launchBackendException(String message) { - return new CUDACodeGenException(message); - } - - @Override - public CudaHATKernelBuilder hatTensorMMAOp(HATTensorOp.TensorMMAOp tensorMMAOp) { - id(WMMA_MMA_TENSOR).paren( _-> commaSeparated(tensorMMAOp.operands(), this::recurseResultOrThrow)); - return self(); - } - - private CudaHATKernelBuilder generateLoadTensor(HATTensorOp.TensorLoadOp tensorLoadOp, boolean isColumnMajor, String tensorName) { - // First operand is the reference to global memory - List operands = tensorLoadOp.operands(); - Value reference = operands.getFirst(); - id(WMMA_LOAD_TENSOR) - .paren(_ -> { - id(tensorName).comma(); - paren(_ -> type("half").asterisk()); - recurseResultOrThrow(reference); - rarrow().id(ARRAY) - .sp().plus().sp() - .indexForTensor(isColumnMajor, operands.get(1), operands.get(2), operands.get(3)) - .comma(); - recurseResultOrThrow(operands.get(3)); - }); - - return self(); - } - - /** - * Example of code being generated: - * - *
- * - * wmma::load_matrix_sync(a_frag, matrix->array + headSize + aRow + aCol * lda, lda); - * - *
- * - * @param tensorLoadOp - * - * @return {@link CudaHATKernelBuilder} - */ - @Override - public CudaHATKernelBuilder hatTensorLoadOp(HATTensorOp.TensorLoadOp tensorLoadOp) { - // Find name tensor of the first argument - String tensorName = ""; - SequencedSet uses = tensorLoadOp.result().uses(); - HATTensorOp.TensorVarOp tensorVarOp = null; - for (Op.Result result : uses) { - if (result.declaringElement() instanceof HATTensorOp.TensorStoreLoadOp storeLoadOp) { - // obtain first arg from tensorStoreOp - Value first = storeLoadOp.operands().getFirst(); - if (first.declaringElement() instanceof HATTensorOp.TensorVarOp varOp) { - tensorVarOp = varOp; - tensorName = tensorVarOp.varName(); - } - } - } - - boolean isColumnMajor = true; - if (tensorVarOp != null) { - Value value = tensorVarOp.operands().getFirst(); - if (value.declaringElement() instanceof HATTensorOp.TensorCreateOp createOp) { - Value tensorLayout = createOp.operands().getLast(); - isColumnMajor = isColumnMajor(tensorLayout); - } - } - return generateLoadTensor(tensorLoadOp, isColumnMajor, tensorName); - } - - @Override - public CudaHATKernelBuilder hatTensorStoreLoadOp(HATTensorOp.TensorStoreLoadOp hatTensorStoreLoadOp) { - List operands = hatTensorStoreLoadOp.operands(); - recurseResultOrThrow(operands.getLast()); - return self(); - } - - /** - * Example of code being generated: - * - *
- * - * store_matrix_sync(matrix->array + cRow + cCol * ldc, c_frag, ldc, wmma::mem_col_major); - * - *
- * - * @param operands - * @param isColumnMajor - * - * @return {@link CudaHATKernelBuilder} - */ - private CudaHATKernelBuilder generateStoreTensor(List operands, boolean isColumnMajor) { - Value reference = operands.getFirst(); - id(WMMA_STORE_TENSOR).paren(_ -> { - Value iIndex = operands.get(1); - Value jIndex = operands.get(2); - Value tensorToStore = operands.get(3); - Value ldSize = operands.get(4); - - recurseResultOrThrow(reference) - .rarrow().id(ARRAY) - .sp().plus().sp() - .indexForTensor(isColumnMajor, iIndex, jIndex, ldSize) - .comma() - .recurseResultOrThrow(tensorToStore) - .comma() - .recurseResultOrThrow(ldSize) - .comma(); - - if (isColumnMajor) { - id(WMMA_MEM_COL_MAJOR); - } else { - id(WMMA_MEM_ROW_MAJOR); - } - }); - return self(); - } - - @Override - public CudaHATKernelBuilder hatTensorStoreOp(HATTensorOp.TensorStoreOp tensorStoreOp) { - List operands = tensorStoreOp.operands(); - // Access layout is the last operand - final boolean isColumnMajor = isColumnMajor(operands.get(5)); - return generateStoreTensor(operands, isColumnMajor); - } - - private static final String ARRAY = "array"; } diff --git a/hat/backends/ffi/cuda/src/main/native/cpp/cuda_backend_queue.cpp b/hat/backends/ffi/cuda/src/main/native/cpp/cuda_backend_queue.cpp index 966e97e6e60..5346d5ca965 100644 --- a/hat/backends/ffi/cuda/src/main/native/cpp/cuda_backend_queue.cpp +++ b/hat/backends/ffi/cuda/src/main/native/cpp/cuda_backend_queue.cpp @@ -151,31 +151,14 @@ void CudaBackend::CudaQueue::dispatch(KernelContext *kernelContext, CompilationU int threadsPerBlockY = estimateThreadsPerBlock(kernelContext->dimensions, kernelContext->gsy, kernelContext->lsy); int threadsPerBlockZ = estimateThreadsPerBlock(kernelContext->dimensions, kernelContext->gsz, kernelContext->lsz); - - int warpFactor[3] = {1, 1, 1}; - if (kernelContext->wsx) { - warpFactor[0] = 32; - } - if (kernelContext->wsy) { - warpFactor[1] = 32; - } - if (kernelContext->wsz) { - warpFactor[2] = 32; - } - - int globalSize[3] = {kernelContext->gsx, kernelContext->gsy, kernelContext->gsz}; - globalSize[0] = kernelContext->tlx? ((kernelContext->gsx + kernelContext->tlx - 1) / kernelContext->tlx) * warpFactor[0]: kernelContext->gsx; - globalSize[1] = kernelContext->tly? ((kernelContext->gsy + kernelContext->tly - 1) / kernelContext->tly) * warpFactor[1]: kernelContext->gsy; - globalSize[2] = kernelContext->tlz? 
((kernelContext->gsz + kernelContext->tlz - 1) / kernelContext->tlz) * warpFactor[2]: kernelContext->gsz; - - int blocksPerGridX = ((globalSize[0] + threadsPerBlockX - 1) / threadsPerBlockX); + int blocksPerGridX = (kernelContext->gsx + threadsPerBlockX - 1) / threadsPerBlockX; int blocksPerGridY = 1; int blocksPerGridZ = 1; if (kernelContext->dimensions > 1) { - blocksPerGridY = ((globalSize[1] + threadsPerBlockY - 1) / threadsPerBlockY); + blocksPerGridY = (kernelContext->gsy + threadsPerBlockY - 1) / threadsPerBlockY; } if (kernelContext->dimensions > 2) { - blocksPerGridZ = ((globalSize[2] + threadsPerBlockZ - 1) / threadsPerBlockZ); + blocksPerGridZ = (kernelContext->gsz + threadsPerBlockZ - 1) / threadsPerBlockZ; } // Enable debug information with info: HAT=INFO diff --git a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLCodeGenException.java b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLCodeGenException.java deleted file mode 100644 index b6159e9915c..00000000000 --- a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLCodeGenException.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2026, Oracle and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. Oracle designates this - * particular file as subject to the "Classpath" exception as provided - * by Oracle in the LICENSE file that accompanied this code. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. 
- */ -package hat.backend.ffi; - -import optkl.exceptions.CodeGenException; - -public class OpenCLCodeGenException extends CodeGenException { - - protected OpenCLCodeGenException(String message) { - super(message); - } -} diff --git a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java index 7e73f40bc79..df30b251247 100644 --- a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java +++ b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java @@ -27,25 +27,15 @@ import hat.callgraph.KernelCallGraph; import hat.codebuilders.C99HATKernelBuilder; import hat.dialect.HATF16Op; -import hat.dialect.HATTensorOp; import hat.dialect.HATVectorOp; import hat.types.BF16; import hat.types.F16; -import optkl.OpHelper; import optkl.codebuilders.CodeBuilder; -import jdk.incubator.code.Value; -import jdk.incubator.code.dialect.core.CoreOp; -import jdk.incubator.code.dialect.java.ClassType; -import jdk.incubator.code.dialect.java.JavaOp; -import jdk.incubator.code.dialect.java.PrimitiveType; import optkl.codebuilders.ScopedCodeBuilderContext; import jdk.incubator.code.Op; -import optkl.exceptions.CodeGenException; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.Random; public class OpenCLHATKernelBuilder extends C99HATKernelBuilder { @@ -53,11 +43,6 @@ protected OpenCLHATKernelBuilder(KernelCallGraph kernelCallGraph, ScopedCodeBuil super(kernelCallGraph,scopedCodeBuilderContext); } - @Override - protected OpenCLHATKernelBuilder hatWarpSize() { - return constant("1"); - } - public OpenCLHATKernelBuilder vstore(int dims) { return id("vstore" + dims); } @@ -98,7 +83,7 @@ public OpenCLHATKernelBuilder defines() { .when(kernelCallGraph.accessedKernelContextFields.contains("bsz"), _->hashDefine("HAT_BSZ", _ -> paren(_ -> id("get_num_groups").paren(_ -> intConstTwo())))) .when(!kernelCallGraph.accessedFP16Classes.isEmpty(), _->maxMacro("MAX_HAT")) .when(!kernelCallGraph.accessedFP16Classes.isEmpty(), _->minMacro("MIN_HAT")) - .when(kernelCallGraph.usesBarrier || kernelCallGraph.useTensors, _ ->hashDefine("HAT_BARRIER", _ -> id("barrier").oparen().id("CLK_LOCAL_MEM_FENCE").cparen())) + .when(kernelCallGraph.usesBarrier, _ ->hashDefine("HAT_BARRIER", _ -> id("barrier").oparen().id("CLK_LOCAL_MEM_FENCE").cparen())) /*.when(callgraphState.usesFp16,_->*/.hashDefine("BFLOAT16", _ -> keyword("ushort"))//) /*.when(callgraphState.usesFp16,_->*/.typedefSingleValueStruct("F16", "half")//) /*.when(callgraphState.usesFp16,_->*/.typedefSingleValueStruct("BF16", "BFLOAT16")//) @@ -213,7 +198,7 @@ public OpenCLHATKernelBuilder hatF16ToFloatConvOp( HATF16Op.HATF16ToFloatConvOp } else if (!hatF16ToFloatConvOp.wasFloat()) { dot(); } else{ - throw new OpenCLCodeGenException("Can we get here"); + throw new RuntimeException("Can we get here"); } id("value"); }); @@ -254,662 +239,4 @@ public OpenCLHATKernelBuilder hatF16ToFloatConvOp( HATF16Op.HATF16ToFloatConvOp protected String mapMathIntrinsic(String hatMathIntrinsicName) { return MATH_FUNCTIONS.getOrDefault(hatMathIntrinsicName, hatMathIntrinsicName); } - - @Override - public OpenCLCodeGenException launchBackendException(String message) { - throw new OpenCLCodeGenException(message); - } - - @Override - public OpenCLHATKernelBuilder hatTensorVarOp(HATTensorOp.TensorVarOp tensorVarOp) { - recurse(OpHelper.asResultOrThrow(tensorVarOp.operands().getFirst()).op()); - // We don't 
need to generate the name at this point, but rather during tensor create. - // That's the place we know all information, including type, shape, and name - return self(); - } - - @Override - public OpenCLHATKernelBuilder hatTensorCreateOp(HATTensorOp.TensorCreateOp tensorCreateOp) { - List operands = tensorCreateOp.operands(); - - // Second parameters: analysis of the shape - int[] shape = new int[3]; - Value second = operands.get(1); - if (second.declaringElement() instanceof JavaOp.InvokeOp invokeOp) { - List shapeOperands = invokeOp.operands(); - for (int i = 0; i < shapeOperands.size(); i++) { - Value shapeOperand = shapeOperands.get(i); - if (shapeOperand.declaringElement() instanceof CoreOp.ConstantOp constantOp) { - shape[i] = (int) constantOp.value(); - } - } - } - - // The third parameter is the type. It could be `half` or `float` as first implementation - // This parameter is another constant with the type - Value classOperand = operands.get(2); - Object klass = null; - if (classOperand.declaringElement() instanceof CoreOp.ConstantOp constantOp) { - klass = constantOp.value(); - } - - var tensorVarValue = tensorCreateOp.result().uses().getFirst(); - String varTensorName = null; - if (tensorVarValue.declaringElement() instanceof HATTensorOp.TensorVarOp tensorVarOp) { - varTensorName = tensorVarOp.varName(); - } - final int size = shape[0] * shape[1]; - if (tensorCreateOp.operands().size() > 3) { - // Share memory only for the input tiles (tensors) - // The accumulator is stored in private memory - HAT_LOCAL_MEM().sp(); - } - - switch (klass) { - case ClassType classType when classType.toClassName().equals(F16.class.getCanonicalName()) -> f16Type(); - case PrimitiveType primitiveType when primitiveType.equals(PrimitiveType.FLOAT) -> type("float"); - case null, default -> throw new OpenCLCodeGenException("[ERROR] Codegen. 
Type " + klass + " not expected"); - } - sp().varName(varTensorName).sbrace(_-> constant(Integer.toString(size))); - return self(); - } - - static HATTensorOp.TensorVarOp findTensorVarOp(Value varLoadOp) { - return switch (varLoadOp.declaringElement()) { - case HATTensorOp.TensorVarLoadOp tensorVarLoadOp -> findTensorVarOp(tensorVarLoadOp.operands().getFirst()); - case CoreOp.VarAccessOp.VarLoadOp varLoadOp2 -> findTensorVarOp(varLoadOp2.operands().getFirst()); - case HATTensorOp.TensorVarOp tensorVarOp -> tensorVarOp; - case null, default -> null; - }; - } - - static float getValueConstantTensor(Value v) { - if ((v instanceof Op.Result r && r.op() instanceof CoreOp.ConstantOp constant)) { - Object valueConstant = constant.value(); - return (float) valueConstant; - - } else if (v instanceof Op.Result r) { - return getValueConstantTensor(r.op().operands().getFirst()); - } - return -1.0f; - } - - private int[] getShapeFromTensorCreateValue(Value tensorCreateValue) { - if (tensorCreateValue.declaringElement() instanceof HATTensorOp.TensorCreateOp tensorCreateOp) { - // Second parameters: analysis of the shape - int[] shape = new int[3]; - Value second = tensorCreateOp.operands().get(1); - if (second.declaringElement() instanceof JavaOp.InvokeOp invokeOp) { - List shapeOperands = invokeOp.operands(); - for (int i = 0; i < shapeOperands.size(); i++) { - Value shapeOperand = shapeOperands.get(i); - if (shapeOperand.declaringElement() instanceof CoreOp.ConstantOp constantOp) { - shape[i] = (int) constantOp.value(); - } - } - } - return shape; - } - return new int[]{}; - } - - private int[] getShapeFromTensorVarOp(HATTensorOp.TensorVarOp tensorVarOp) { - Value tensorCreateValueOp = tensorVarOp.operands().getFirst(); - if (tensorCreateValueOp.declaringElement() instanceof HATTensorOp.TensorCreateOp tensorCreateOp) { - // Second parameters: analysis of the shape - int[] shape = new int[3]; - Value second = tensorCreateOp.operands().get(1); - if (second.declaringElement() instanceof JavaOp.InvokeOp invokeOp) { - List shapeOperands = invokeOp.operands(); - for (int i = 0; i < shapeOperands.size(); i++) { - Value shapeOperand = shapeOperands.get(i); - if (shapeOperand.declaringElement() instanceof CoreOp.ConstantOp constantOp) { - shape[i] = (int) constantOp.value(); - } - } - } - return shape; - } - return new int[]{}; - } - - private boolean isColumnMajorFromVarOp(HATTensorOp.TensorVarOp tensorVarOp) { - Value tensorCreateValueOp = tensorVarOp.operands().getFirst(); - if (tensorCreateValueOp.declaringElement() instanceof HATTensorOp.TensorCreateOp tensorCreateOp) { - // Parameter 3 defines the access layout - Value valueLayout = tensorCreateOp.operands().get(3); - return isColumnMajor(valueLayout); - } - return false; - } - - private String generateVariableName(String prefix) { - String vocab = "abcdefghijklmnopqrstuvxyz"; - Random r = new Random(); - StringBuilder varA = new StringBuilder(prefix); - for (int i = 0; i < 3; i++) { - varA.append(vocab.charAt(r.nextInt(vocab.length()))); - } - return varA.toString(); - } - - private static final String INDEX_PREFIX = "index_$"; - - /** - * Code example being generated: - * - *
- * - * for (int m = 0; m < " + shape[0] + "; m++) { - * for (int n = 0; n < " + shape[1] + "; n++) { - * tensorVarOp.varName() + "[m * " + shape[0] + " + n] = " + initValue + "f;" + "}" + "}"); - * - *
- * - * @param from - * @param to - * @param tensorVarOp - * @param initValue - * - * @return {@link OpenCLHATKernelBuilder} - */ - private OpenCLHATKernelBuilder emitForLoopWithBound(int from, int to, HATTensorOp.TensorVarOp tensorVarOp, float initValue) { - String prefix = INDEX_PREFIX; - String varA = generateVariableName(prefix); - String varB = generateVariableName(prefix); - forKeyword().sp().paren(_ -> { - s32Type().sp().id(varA).assign().intValue(from).semicolon(); - id(varA).sp().lt().sp().intValue(to).semicolon(); - id(varA).plusplus(); - }).sp().brace(_ -> { - in().nl().forKeyword().sp().paren(_ -> { - s32Type().sp().id(varB).assign().intValue(from).semicolon(); - id(varB).sp().lt().sp().intValue(to).semicolon(); - id(varB).plusplus(); - }).sp().in(); - - brace(_ -> nl() - .id(tensorVarOp.varName()) - .sbrace(_ -> - id(varA).mul() - .id(Integer.toString(to)) - .plus() - .id(varB)) - .assign() - .constant(Float.toString(initValue)).id("f") - .semicolon().nl()).out().out(); - }); - return self(); - } - - private OpenCLHATKernelBuilder emitTensorFill(int[] shape, HATTensorOp.TensorVarOp tensorVarOp, float initValue) { - - return emitForLoopWithBound(0, shape[0], tensorVarOp, initValue); - } - - /** - * Code example being generated: - * - *
- * - * for (int m = 0; m < SHAPE_1; m++) - * for (int n = 0; n < SHAPE_2; n++) - * tensor[m * SHAPE_1 + n] = initValue; - * - *
- * - * @param tensorFillOp - * - * @return {@link OpenCLHATKernelBuilder} - */ - @Override - public OpenCLHATKernelBuilder hatTensorFillOp(HATTensorOp.TensorFillOp tensorFillOp) { - - // 1. Access to the variable name - var tensorValue = tensorFillOp.operands().getFirst(); - HATTensorOp.TensorVarOp tensorVarOp = findTensorVarOp(tensorValue); - if (tensorVarOp == null) { - throw new OpenCLCodeGenException("[Error][Codegen] Expected a tensorVarOp, but found `null` instead"); - } - - // 2. Access the shape - // Second parameters: analysis of the shape - Value tensorAccDecl = tensorVarOp.operands().getFirst(); - int[] shape = getShapeFromTensorCreateValue(tensorAccDecl); - - // 3. Access the layout - var tensorInitValue = tensorFillOp.operands().get(1); - float initValue = getValueConstantTensor(tensorInitValue); - - emitTensorFill(shape, tensorVarOp, initValue); - return self(); - } - - @Override - public OpenCLHATKernelBuilder hatTensorVarLoadOp(HATTensorOp.TensorVarLoadOp hatTensorVarLoadOp) { - Value operand = hatTensorVarLoadOp.operands().getFirst(); - if (operand instanceof Op.Result r && r.op() instanceof HATTensorOp.TensorVarOp tensorVarOp) { - varName(tensorVarOp.varName()); - } else { - throw new OpenCLCodeGenException("[ERROR] Expected HATTensorVarOp"); - } - return self(); - } - - /** - * Example of code being generated: - * - *
- * - * for (int m = 0; m < WMMA_M; m++) { - for (int n = 0; n < WMMA_N; n++) { - * float sum = acc[m][n]; - * for (int k = 0; k < WMMA_K; k++) { - * F16_t ha = a_frag[m * WMMA_M + k]; - * F16_t hb = b_frag[k * WMMA_M + n]; - * F16_t result = (F16_t){(ha.value * hb.value)}; - * sum += (float)(result.value); - * } - * acc[m][n] = sum; - * } - * } - * - *
- * - * @param shape - * @param tensorA - * @param tensorB - * @param tensorC - * @param result - * - * @return {@link OpenCLHATKernelBuilder} - */ - private OpenCLHATKernelBuilder generateTensorMMA(int[] shape, HATTensorOp.TensorVarOp tensorA, HATTensorOp.TensorVarOp tensorB, HATTensorOp.TensorVarOp tensorC, HATTensorOp.TensorVarOp result) { - String prefix = INDEX_PREFIX; - String varA = generateVariableName(prefix); - String varB = generateVariableName(prefix); - String varC = generateVariableName(prefix); - String acc = generateVariableName("sum_"); - final int from = 0; - final int to = shape[0]; - - forKeyword().sp().paren(_ -> { - s32Type().sp().id(varA).assign().intValue(from).semicolon(); - id(varA).sp().lt().sp().intValue(to).semicolon(); - id(varA).plusplus(); - }).sp().brace(_ -> { - in().nl().forKeyword().sp().paren(_ -> { - s32Type().sp().id(varB).assign().intValue(from).semicolon(); - id(varB).sp().lt().sp().intValue(to).semicolon(); - id(varB).plusplus(); - }).in(); - - brace(_ -> { - nl().f32Type().sp().id(acc).assign().id(tensorC.varName()).sbrace( _-> { - id(varA).mul().id(Integer.toString(shape[0])).sp().plus().id(varB); - }).semicolon().nl(); - - forKeyword().sp().paren(_ -> { - s32Type().sp().id(varC).assign().intValue(from).semicolon(); - id(varC).sp().lt().sp().intValue(to).semicolon(); - id(varC).plusplus(); - }).sp().in(); - - brace(_ -> { - nl(); - String ha = generateVariableName("ha_"); - String hb = generateVariableName("hb_"); - String resultTensor = generateVariableName("h_res_"); - f16Type().sp().id(ha).assign().id(tensorA.varName()).sbrace( _ -> id(varA).mul().id(Integer.toString(shape[0])).sp().plus().id(varC)).semicolon().nl(); - f16Type().sp().id(hb).assign().id(tensorB.varName()).sbrace( _ -> id(varC).mul().id(Integer.toString(shape[0])).sp().plus().id(varB)).semicolon().nl(); - f16Type().sp().id(resultTensor).assign().paren( _ -> f16Type()).brace( _ -> paren( _ -> id(ha).dot().id("value").mul().id(hb).dot().id("value"))).semicolon().nl(); - id(acc).sp().plusEquals().cast( _ -> f32Type()).paren( _-> id(resultTensor).dot().id("value")).semicolon().nl(); - }).nl().out(); - - id(result.varName()).sbrace( _ -> id(varA).sp().mul().sp().id(Integer.toString(shape[0])).sp().plus().sp().id(varB)).assign().id(acc).semicolon().nl(); - - }).semicolon().nl(); - - }).out().out(); - return self(); - } - - @Override - public OpenCLHATKernelBuilder hatTensorMMAOp(HATTensorOp.TensorMMAOp tensorMMAOp) { - var resulTensorValue = tensorMMAOp.operands().getFirst(); - var tensorAValue = tensorMMAOp.operands().get(1); - var tensorBValue = tensorMMAOp.operands().get(2); - var tensorCValue = tensorMMAOp.operands().get(3); - var tensorA = findTensorVarOp(tensorAValue); - var tensorB = findTensorVarOp(tensorBValue); - var tensorC = findTensorVarOp(tensorCValue); - var tensorResult = findTensorVarOp(resulTensorValue); - if (tensorA == null || tensorB == null || tensorC == null || tensorResult == null) { - throw new OpenCLCodeGenException("[Error][CodeGen] Expected a tensorValue, but found `null` instead"); - } - int[] shape = getShapeFromTensorVarOp(tensorA); - return generateTensorMMA(shape, tensorA, tensorB, tensorC, tensorResult); - } - - - @Override - public OpenCLHATKernelBuilder hatTensorStoreLoadOp(HATTensorOp.TensorStoreLoadOp storeLoadOp) { - List operands = storeLoadOp.operands(); - if (operands.getLast() instanceof Op.Result r) { - recurse(r.op()); - } - return self(); - } - - private HATTensorOp.TensorVarOp findTensorVarOp(HATTensorOp.TensorLoadOp tensorLoadOp) { - var 
tensorStoreLoadValue = tensorLoadOp.result().uses().getFirst(); - if (tensorStoreLoadValue.declaringElement() instanceof HATTensorOp.TensorStoreLoadOp tensorStoreLoadOp) { - Value first = tensorStoreLoadOp.operands().getFirst(); - if (first.declaringElement() instanceof HATTensorOp.TensorVarOp tensorVarOp) { - return tensorVarOp; - } else { - return null; - } - } else { - return null; - } - } - - /** - * Code example being generated: - * - *
- * - * for (int m = 0; m < WMMA_M; m++) { - * int rowA = aRow + m; - * for (int n = 0; n < WMMA_N; n++) { - * int colA = aCol + n; - * int idxA = rowA + colA * lda; - * HAT_GLOBAL_MEM F16Impl_t* ha = &matrixA->array[idxA]; - * F16_t r = (F16_t){ha->value}; - * tensorA[m * WMMA_M + n] = r; - * } - * } - * - *
- * - * @param shape - * @param iIndexValue - * @param jIndexValue - * @param isColumnMajor - * @param leadingDimension - * @param ptrValue - * @param tensorVarOp - * - * @return {@link OpenCLHATKernelBuilder} - */ - private OpenCLHATKernelBuilder generateTensorLoad(int[] shape, Value iIndexValue, Value jIndexValue, boolean isColumnMajor, Value leadingDimension, Value ptrValue, HATTensorOp.TensorVarOp tensorVarOp) { - - String prefix = INDEX_PREFIX; - String varA = generateVariableName(prefix); - String varB = generateVariableName(prefix); - final int to = shape[0]; - final int from = 0; - - forKeyword().sp().paren(_ -> { - s32Type().sp().id(varA).assign().intValue(from).semicolon(); - id(varA).sp().lt().sp().intValue(to).semicolon(); - id(varA).plusplus(); - }).in(); - - String row = generateVariableName("row_"); - - brace(_ -> { - nl().s32Type().sp().id(row).assign(); - - if (iIndexValue instanceof Op.Result r) { - recurse(r.op()); - } - plus().id(varA).semicolon().nl(); - - forKeyword().sp().paren(_ -> { - s32Type().sp().id(varB).assign().intValue(from).semicolon(); - id(varB).sp().lt().sp().intValue(to).semicolon(); - id(varB).plusplus(); - }).sp().in(); - - String col = generateVariableName("col_"); - - brace(_ -> { - nl().s32Type().sp().id(col).assign(); - - if (jIndexValue instanceof Op.Result r) { - recurse(r.op()); - } - plus().id(varB).semicolon().nl(); - - String index = generateVariableName(INDEX_PREFIX); - s32Type().sp().id(index).assign(); - - String aVal = row; - String bVal = col; - if (isColumnMajor) { - aVal = col; - bVal = row; - } - - id(aVal).sp().mul().sp(); - if (leadingDimension instanceof Op.Result r) { - recurse(r.op()); - } - sp().plus().id(bVal).semicolon().nl(); - - - // TODO: We assume a load from global memory. In - // future version, we will process loads from other - // memory regions of the accelerator - - String ha = generateVariableName("ha_"); - id("HAT_GLOBAL_MEM F16Impl_t").asterisk().sp().id(ha).assign().ampersand(); - - if (ptrValue instanceof Op.Result r) { - recurse(r.op()); - } - rarrow().id("array").sbrace( _ -> id(index)).semicolon().nl(); - - String r = generateVariableName("r_"); - f16Type().sp().id(r).assign().cast( _ -> f16Type()).brace( _-> id(ha).rarrow().id("value")).semicolon().nl(); - - // store into the acc - emitText(tensorVarOp.varName()).sbrace( _ -> id(varA).sp().mul().id(Integer.toString(shape[0])).sp().plus().id(varB)); - equals().sp().id(r).semicolon().nl(); - }).out(); - }).out(); - return self(); - } - - /** - * Code example being generated: - * - *
- * - * for (int m = 0; m < WMMA_M; m++) { - int rowB = bRow + m; - * for (int n = 0; n < WMMA_N; n++) { - * int colB = bCol + n; - * int idxB = rowB + colB * ldb; - * HAT_GLOBAL_MEM F16Impl_t* hb = &matrixB->array[idxB]; - * F16_t r = (F16_t){hb->value}; - * b_frag[m * WMMA_M + n] = r; - * } - * } - * - *
- * - * @param tensorLoadOp - * - * @return {@link OpenCLHATKernelBuilder} - */ - @Override - public OpenCLHATKernelBuilder hatTensorLoadOp(HATTensorOp.TensorLoadOp tensorLoadOp) { - - List operands = tensorLoadOp.operands(); - var ptrValue = operands.getFirst(); - var iIndexValue = operands.get(1); - var jIndexValue = operands.get(2); - var leadingDimension = operands.get(3); - HATTensorOp.TensorVarOp tensorVarOp = findTensorVarOp(tensorLoadOp); - int[] shape; - boolean isColumnMajor; - if (tensorVarOp != null) { - shape = getShapeFromTensorVarOp(tensorVarOp); - isColumnMajor = isColumnMajorFromVarOp(tensorVarOp); - } else { - throw new OpenCLCodeGenException("[Error][CodeGen] Expected to see an instance of tensorVarOp but `null` found"); - } - generateTensorLoad(shape, iIndexValue, jIndexValue, isColumnMajor, leadingDimension, ptrValue, tensorVarOp); - HAT_BARRIER(); - return self(); - } - - /** - * Example of code being generated: - * - *
- * - * for (int m = 0; m < WMMA_M; m++) { - * `int rowC = cRow + m; - * for (int n = 0; n < WMMA_N; n++) { - * int colC = cCol + n; - * int idxC = (cRow) + (cCol) * ldc; - * matrixC->array[idxC] = acc[m * 16 + n]; - * } - * } - * - *
- * - * @param shape - * @param iIndexValue - * @param jIndexValue - * @param isColumnMajor - * @param leadingDimension - * @param ptrValue - * @param tensorVarOp - * - * @return {@link OpenCLHATKernelBuilder} - */ - private OpenCLHATKernelBuilder generateTensorStore(int[] shape, Value iIndexValue, Value jIndexValue, boolean isColumnMajor, Value leadingDimension, Value ptrValue, HATTensorOp.TensorVarOp tensorVarOp) { - String prefix = INDEX_PREFIX; - String varA = generateVariableName(prefix); - String varB = generateVariableName(prefix); - final int to = shape[0]; - final int from = 0; - - forKeyword().sp().paren(_ -> { - s32Type().sp().id(varA).assign().intValue(from).semicolon(); - id(varA).sp().lt().sp().intValue(to).semicolon(); - id(varA).plusplus(); - }).in(); - - String row = generateVariableName("row_"); - - brace(_ -> { - nl().s32Type().sp().id(row).assign(); - - if (iIndexValue instanceof Op.Result r) { - recurse(r.op()); - } - plus().id(varA).semicolon().nl(); - - forKeyword().sp().paren(_ -> { - s32Type().sp().id(varB).assign().intValue(from).semicolon(); - id(varB).sp().lt().sp().intValue(to).semicolon(); - id(varB).plusplus(); - }).sp().in(); - - String col = generateVariableName("col_"); - - brace(_ -> { - nl().s32Type().sp().id(col).assign(); - - if (jIndexValue instanceof Op.Result r) { - recurse(r.op()); - } - plus().id(varB).semicolon().nl(); - - String index = generateVariableName(INDEX_PREFIX); - s32Type().sp().id(index).assign(); - - String aVal = row; - String bVal = col; - if (isColumnMajor) { - aVal = col; - bVal = row; - } - - id(aVal).sp().mul().sp(); - if (leadingDimension instanceof Op.Result r) { - recurse(r.op()); - } - sp().plus().id(bVal).semicolon().nl(); - - // TODO: We assume a load from global memory. In - // future version, we will process loads from other - // memory regions of the accelerator - if (ptrValue instanceof Op.Result r) { - recurse(r.op()); - } - rarrow().id("array").sbrace( _ -> id(index)).assign(); - id(tensorVarOp.varName()).sbrace( _ -> id(varA).mul().id(Integer.toString(shape[0])).plus().id(varB)); - semicolon().nl(); - }).out(); - }).out(); - return self(); - } - - /** - * Code example being generated: - * - *
- * - * for (int m = 0; m < WMMA_M; m++) { - * `int rowC = cRow + m; - * for (int n = 0; n < WMMA_N; n++) { - * int colC = cCol + n; - * int idxC = (cRow) + (cCol) * ldc; - * matrixC->array[idxC] = acc[m * 16 + n]; - * } - * } - * - *
- * - * @param tensorStoreOp - * - * @return {@link OpenCLHATKernelBuilder} - */ - @Override - public OpenCLHATKernelBuilder hatTensorStoreOp(HATTensorOp.TensorStoreOp tensorStoreOp) { - // 1. We need the global ptr - // 2. We need the indexes (i, j) - // 3. We need leading dimension - // 4. We need the name of the tensor - // 5. We need the shape - // 6. We need the access layout - - List operands = tensorStoreOp.operands(); - var ptrValue = operands.getFirst(); - var iIndexValue = operands.get(1); - var jIndexValue = operands.get(2); - var tensorValue = operands.get(3); - var leadingDimension = operands.get(4); - - HATTensorOp.TensorVarOp tensorVarOp = findTensorVarOp(tensorValue); - if (tensorVarOp == null) { - throw new OpenCLCodeGenException("[Error][CodeGen] Expected to find a tensorVarOp, but `null` instead."); - } - - int[] shape = getShapeFromTensorVarOp(tensorVarOp); - - Value accessLayout = operands.get(5); - final boolean isColumnMajor = isColumnMajor(accessLayout); - - generateTensorStore(shape, iIndexValue, jIndexValue, isColumnMajor, leadingDimension, ptrValue, tensorVarOp); - return self(); - } - } diff --git a/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend_queue.cpp b/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend_queue.cpp index c958340671f..fa5bad18c5d 100644 --- a/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend_queue.cpp +++ b/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend_queue.cpp @@ -239,43 +239,6 @@ OpenCLBackend::OpenCLQueue::~OpenCLQueue() { delete []events; } -void printWarningLocalGroupResized(const size_t local_work_size[]) { - std::cout << "[Warning] Thread-Block size got automatically resized: [" << local_work_size[0] << "," << local_work_size[1] << "," << local_work_size[2] << "]" << std::endl; -} - -void checkThreadBlockFits(OpenCLBackend *backend, const KernelContext *kernelContext, const size_t global_work_size[], size_t *local_work_size) { - const PlatformInfo platformInfo(backend); - size_t max_group_size = platformInfo.deviceInfo.maxWorkGroupSize; - size_t totalThreads = kernelContext->lsx * kernelContext->lsy * kernelContext->lsz; - - // Adjust depending on the total number of threads in the local-work-group - while (totalThreads > max_group_size) { - // Here just a simple heuristic, starting with the first dimension, 16, 4, 1 local group sizxe - if (local_work_size[0] >= 16) { - local_work_size[0] /= 2; - } else if (local_work_size[1] >= 4) { - local_work_size[1] /= 2; - } else if (local_work_size[2] >= 1) { - local_work_size[2] /= 2; - } - totalThreads = local_work_size[0] * local_work_size[1] * local_work_size[2]; - if (backend->config->info) { - printWarningLocalGroupResized(local_work_size); - } - } - - // Adjust also depending on the global size. 
We can't launch more threads as local work than global work for - // each dimension - for (int i = 0; i < 3; i++) { - while (local_work_size[i] > global_work_size[i]) { - local_work_size[i] /= 2; - if (backend->config->info) { - printWarningLocalGroupResized(local_work_size); - } - } - } -} - void OpenCLBackend::OpenCLQueue::dispatch(KernelContext *kernelContext, CompilationUnit::Kernel *kernel) { size_t numDimensions = kernelContext->dimensions; @@ -291,23 +254,6 @@ void OpenCLBackend::OpenCLQueue::dispatch(KernelContext *kernelContext, Compilat static_cast(kernelContext->lsz), }; - if (kernelContext->tlx > 0) { - global_work_size[0] /= kernelContext->tlx; - } - if (kernelContext->tly > 0) { - global_work_size[1] /= kernelContext->tly; - } - if (kernelContext->tlz > 0) { - global_work_size[2] /= kernelContext->tlz; - } - - // In the OpenCL backend, we don't currently support warp-sizes to be able to run with OpenCL 1.2 (Apple) - // The CUDA backend supports warp-sizes - - // Check the local-sizes fit - auto backendInstance = dynamic_cast(this->backend); - checkThreadBlockFits(backendInstance, kernelContext, global_work_size, local_work_size); - if (backend->config->info) { backend->shortDeviceInfo(); std::cout << "[INFO] OpenCLBackend::OpenCLQueue::dispatch" << std::endl; diff --git a/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/C99FFIBackend.java b/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/C99FFIBackend.java index 977956118a1..9c309e0e9ad 100644 --- a/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/C99FFIBackend.java +++ b/hat/backends/ffi/shared/src/main/java/hat/backend/ffi/C99FFIBackend.java @@ -130,46 +130,6 @@ public void dispatch(KernelContext kernelContext, Object[] args) { kernelBufferContext.lsz(0); } - // Set Tile - kernelBufferContext.tlx(0); - kernelBufferContext.tly(0); - kernelBufferContext.tlz(0); - if (kernelContext.ndRange.hasTile()) { - switch (kernelContext.ndRange.tile()) { - case NDRange.Tile1D tile1D -> kernelBufferContext.tlx(tile1D.x()); - case NDRange.Tile2D tile2D -> { - kernelBufferContext.tlx(tile2D.x()); - kernelBufferContext.tly(tile2D.y()); - } - case NDRange.Tile3D tile3D -> { - kernelBufferContext.tlx(tile3D.x()); - kernelBufferContext.tly(tile3D.y()); - kernelBufferContext.tlz(tile3D.z()); - } - case null, default -> throw new IllegalArgumentException("Unknown global range " + kernelContext.ndRange.tile().getClass()); - } - } - - // Set warp - kernelBufferContext.wsx(false); - kernelBufferContext.wsy(false); - kernelBufferContext.wsz(false); - if (kernelContext.ndRange.hasWarp()) { - switch (kernelContext.ndRange.warp()) { - case NDRange.Warp1D warp1D -> kernelBufferContext.wsx(warp1D.x()); - case NDRange.Warp2D warp2D -> { - kernelBufferContext.wsx(warp2D.x()); - kernelBufferContext.wsy(warp2D.y()); - } - case NDRange.Warp3D warp3D -> { - kernelBufferContext.wsx(warp3D.x()); - kernelBufferContext.wsy(warp3D.y()); - kernelBufferContext.wsz(warp3D.z()); - } - case null, default -> throw new IllegalArgumentException("Unknown global range " + kernelContext.ndRange.warp().getClass()); - } - } - args[0] = this.kernelBufferContext; ArgArray.update(argArray, kernelCallGraph, args); kernelBridge.ndRange(this.argArray); diff --git a/hat/backends/ffi/shared/src/main/native/include/shared.h b/hat/backends/ffi/shared/src/main/native/include/shared.h index 82021d8ba62..2c302802a1e 100644 --- a/hat/backends/ffi/shared/src/main/native/include/shared.h +++ b/hat/backends/ffi/shared/src/main/native/include/shared.h @@ -402,16 +402,6 @@ class 
KernelContext { int bsx; int bsy; int bsz; - - // Tile Size - int tlx; - int tly; - int tlz; - - // Warp sizes - bool wsx; - bool wsy; - bool wsz; }; class Backend { diff --git a/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLJExtractedHATKernelBuilder.java b/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLJExtractedHATKernelBuilder.java index 30241004a98..2c9aa24628e 100644 --- a/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLJExtractedHATKernelBuilder.java +++ b/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLJExtractedHATKernelBuilder.java @@ -27,7 +27,6 @@ import hat.callgraph.KernelCallGraph; import hat.codebuilders.C99HATKernelBuilder; import hat.dialect.HATF16Op; -import hat.dialect.HATTensorOp; import hat.dialect.HATVectorOp; import optkl.codebuilders.CodeBuilder; import optkl.codebuilders.ScopedCodeBuilderContext; @@ -43,11 +42,6 @@ protected OpenCLJExtractedHATKernelBuilder(KernelCallGraph kernelCallGraph, Scop super(kernelCallGraph,scopedCodeBuilderContext); } - @Override - protected OpenCLJExtractedHATKernelBuilder hatWarpSize() { - return constant("1"); - } - @Override public OpenCLJExtractedHATKernelBuilder defines() { return self() @@ -248,44 +242,4 @@ public OpenCLJExtractedHATKernelBuilder hatF16ToFloatConvOp( HATF16Op.HATF16ToFl protected String mapMathIntrinsic(String hatMathIntrinsicName) { return MATH_FUNCTIONS.getOrDefault(hatMathIntrinsicName, hatMathIntrinsicName); } - - @Override - public OpenCLJExtractedHATKernelBuilder hatTensorVarOp(HATTensorOp.TensorVarOp tensorVarOp) { - return blockComment("Not supported yet"); - } - - @Override - public OpenCLJExtractedHATKernelBuilder hatTensorCreateOp(HATTensorOp.TensorCreateOp tensorCreateOp) { - return blockComment("Not supported yet"); - } - - @Override - public OpenCLJExtractedHATKernelBuilder hatTensorFillOp(HATTensorOp.TensorFillOp tensorFillOp) { - return blockComment("Not supported yet"); - } - - @Override - public OpenCLJExtractedHATKernelBuilder hatTensorVarLoadOp(HATTensorOp.TensorVarLoadOp hatTensorVarLoadOp) { - return blockComment("Not supported yet"); - } - - @Override - public OpenCLJExtractedHATKernelBuilder hatTensorMMAOp(HATTensorOp.TensorMMAOp tensorMMAOp) { - return blockComment("Not supported yet"); - } - - @Override - public OpenCLJExtractedHATKernelBuilder hatTensorStoreLoadOp(HATTensorOp.TensorStoreLoadOp $) { - return blockComment("Not supported yet"); - } - - @Override - public OpenCLJExtractedHATKernelBuilder hatTensorLoadOp(HATTensorOp.TensorLoadOp $) { - return blockComment("Not supported yet"); - } - - @Override - public OpenCLJExtractedHATKernelBuilder hatTensorStoreOp(HATTensorOp.TensorStoreOp $) { - return blockComment("Not supported yet"); - } } diff --git a/hat/core/src/main/java/hat/KernelContext.java b/hat/core/src/main/java/hat/KernelContext.java index 5e549a232cb..b1f36f0c1bc 100644 --- a/hat/core/src/main/java/hat/KernelContext.java +++ b/hat/core/src/main/java/hat/KernelContext.java @@ -68,9 +68,6 @@ public class KernelContext { public int bsy; public int bsz; - // Warp size - public int wrs; - final int dimensions; public final NDRange ndRange; diff --git a/hat/core/src/main/java/hat/NDRange.java b/hat/core/src/main/java/hat/NDRange.java index 02110769416..470dce1c9b3 100644 --- a/hat/core/src/main/java/hat/NDRange.java +++ b/hat/core/src/main/java/hat/NDRange.java @@ -41,16 +41,8 @@ public interface NDRange { Global global(); - Tile tile(); - - Warp warp(); - boolean 
hasLocal(); - boolean hasTile(); - - boolean hasWarp(); - sealed interface Dim permits Marker1D, Marker2D, Marker3D { default int dimension() { return switch (this) { @@ -88,22 +80,7 @@ sealed interface M3D extends Marker3D { int z(); } - sealed interface B1D extends Marker1D { - boolean x(); - } - - sealed interface B2D extends Marker2D { - boolean x(); - boolean y(); - } - - sealed interface B3D extends Marker3D { - boolean x(); - boolean y(); - boolean z(); - } - - sealed interface Range permits Global, Local, Tile, Warp { + sealed interface Range permits Global, Local { } @@ -115,14 +92,6 @@ sealed interface Local extends Range { } - sealed interface Tile extends Range { - - } - - sealed interface Warp extends Range { - - } - sealed interface Global1D extends M1D, Global { record Impl(int x) implements Global1D { } @@ -188,103 +157,21 @@ static Local3D of(int x, int y, int z) { Local3D EMPTY = Local3D.of(0, 0, 0); } - sealed interface Tile1D extends M1D, Tile { - record Impl(int x) implements Tile1D { - } - - static Tile1D of(int x) { - return new Impl(x); - } - - Tile1D EMPTY = Tile1D.of(0); - } - - sealed interface Tile2D extends M2D, Tile { - record Impl(int x, int y) implements Tile2D { - } - - static Tile2D of(int x, int y) { - return new Impl(x, y); - } - - Tile2D EMPTY = Tile2D.of(0, 0); - - } - - sealed interface Tile3D extends M3D, Tile { - record Impl(int x, int y, int z) implements Tile3D { - } - - static Tile3D of(int x, int y, int z) { - return new Impl(x, y, z); - } - - Tile3D EMPTY = Tile3D.of(0, 0, 0); - } - - sealed interface Warp1D extends B1D, Warp { - record Impl(boolean x) implements Warp1D { - } - - static Warp1D of(boolean x) { - return new Impl(x); - } - - Warp1D EMPTY = Warp1D.of(false); - } - - sealed interface Warp2D extends B2D, Warp { - record Impl(boolean x, boolean y) implements Warp2D { - } - - static Warp2D of(boolean x, boolean y) { - return new Impl(x, y); - } - - Warp2D EMPTY = Warp2D.of(false, false); - } - - sealed interface Warp3D extends B3D, Warp { - record Impl(boolean x, boolean y, boolean z) implements Warp3D { - } - - static Warp3D of(boolean x, boolean y, boolean z) { - return new Impl(x, y, z); - } - - Warp3D EMPTY = Warp3D.of(false, false, false); - } - - sealed interface NDRange1D extends NDRange, Marker1D { @Override default boolean hasLocal() { return local() != Local1D.EMPTY; } - @Override - default boolean hasTile() { - return tile() != Tile1D.EMPTY; - } - - @Override - default boolean hasWarp() { - return warp() != Warp1D.EMPTY; - } - - record Impl(Global1D global, Local1D local, Tile1D tile, Warp1D warp) implements NDRange1D { - } - - static NDRange1D of(Global1D global, Local1D local, Tile1D tile, Warp1D warp) { - return new Impl( global, local, tile, warp); + record Impl(Global1D global, Local1D local) implements NDRange1D { } static NDRange1D of(Global1D global, Local1D local) { - return new Impl(global, local, Tile1D.EMPTY, Warp1D.EMPTY); + return new Impl(global, local); } static NDRange1D of(Global1D global) { - return new Impl(global, Local1D.EMPTY, Tile1D.EMPTY, Warp1D.EMPTY); + return new Impl(global, Local1D.EMPTY); } } @@ -302,29 +189,15 @@ default boolean hasLocal() { return local() != Local2D.EMPTY; } - @Override - default boolean hasTile() { - return tile() != Tile2D.EMPTY; - } - - @Override - default boolean hasWarp() { - return warp() != Warp2D.EMPTY; - } - - record Impl(Global2D global, Local2D local, Tile2D tile, Warp2D warp) implements NDRange2D { - } - - static NDRange2D of(Global2D global, Local2D local, Tile2D 
tile, Warp2D warp) { - return new Impl(global, local, tile, warp); + record Impl(Global2D global, Local2D local) implements NDRange2D { } static NDRange2D of(Global2D global, Local2D local) { - return new Impl(global, local, Tile2D.EMPTY, Warp2D.EMPTY); + return new Impl(global, local); } static NDRange2D of(Global2D global) { - return new Impl(global, Local2D.EMPTY, Tile2D.EMPTY, Warp2D.EMPTY); + return new Impl(global, Local2D.EMPTY); } } @@ -342,29 +215,15 @@ default boolean hasLocal() { return local() != Local3D.EMPTY; } - @Override - default boolean hasTile() { - return tile() != Tile3D.EMPTY; - } - - @Override - default boolean hasWarp() { - return warp() != Warp3D.EMPTY; - } - - record Impl(Global3D global, Local3D local, Tile3D tile, Warp3D warp) implements NDRange3D { - } - - static NDRange3D of(Global3D global, Local3D local, Tile3D tile, Warp3D warp) { - return new Impl(global, local, tile, warp); + record Impl(Global3D global, Local3D local) implements NDRange3D { } static NDRange3D of(Global3D global, Local3D local) { - return new Impl(global, local, Tile3D.EMPTY, Warp3D.EMPTY); + return new Impl(global, local); } static NDRange3D of(Global3D global) { - return new Impl(global, Local3D.EMPTY, Tile3D.EMPTY, Warp3D.EMPTY); + return new Impl(global, Local3D.EMPTY); } } diff --git a/hat/core/src/main/java/hat/buffer/ArgArray.java b/hat/core/src/main/java/hat/buffer/ArgArray.java index ed1ace00b24..4b548f1f457 100644 --- a/hat/core/src/main/java/hat/buffer/ArgArray.java +++ b/hat/core/src/main/java/hat/buffer/ArgArray.java @@ -307,7 +307,6 @@ static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object... buf.access(accessType.value); } else { // otherwise, we rely on the buffer-tagger to set the accessor - //buf.access(accessType.value); // TODO: Temporary for tensors - remove this before the merge buf.access(bufferAccessList.get(i).value); } } diff --git a/hat/core/src/main/java/hat/buffer/KernelBufferContext.java b/hat/core/src/main/java/hat/buffer/KernelBufferContext.java index 4e478343cf2..271807821a5 100644 --- a/hat/core/src/main/java/hat/buffer/KernelBufferContext.java +++ b/hat/core/src/main/java/hat/buffer/KernelBufferContext.java @@ -41,8 +41,6 @@ default void schema(){ lsx(); lsy(); lsz(); // local sizes bix(); biy(); biz(); // block index bsx(); bsy(); bsz(); // block sizes - tlx(); tly(); tlz(); // tile sizes - wsx(); wsy(); wsz(); // warp sizes } Schema schema = Schema.of(KernelBufferContext.class); @@ -109,22 +107,6 @@ default void schema(){ int bsz(); void bsz(int bsz); - // Tile size - int tlx(); - void tlx(int tlx); - int tly(); - void tly(int tly); - int tlz(); - void tlz(int tlz); - - // Warp Size - boolean wsx(); - void wsx(boolean wsx); - boolean wsy(); - void wsy(boolean wsy); - boolean wsz(); - void wsz(boolean wsz); - static KernelBufferContext createDefault(ArenaAndLookupCarrier cc) { KernelBufferContext kernelBufferContext = BoundSchema.of(cc ,schema).allocate(); @@ -155,14 +137,6 @@ static KernelBufferContext createDefault(ArenaAndLookupCarrier cc) { kernelBufferContext.bsy(0); kernelBufferContext.bsz(0); - kernelBufferContext.tlx(0); - kernelBufferContext.tly(0); - kernelBufferContext.tlz(0); - - kernelBufferContext.wsx(false); - kernelBufferContext.wsy(false); - kernelBufferContext.wsz(false); - return kernelBufferContext; } } diff --git a/hat/core/src/main/java/hat/callgraph/KernelCallGraph.java b/hat/core/src/main/java/hat/callgraph/KernelCallGraph.java index 006fce8d904..7a1c0316ad5 100644 --- 
a/hat/core/src/main/java/hat/callgraph/KernelCallGraph.java +++ b/hat/core/src/main/java/hat/callgraph/KernelCallGraph.java @@ -29,7 +29,6 @@ import hat.device.NonMappableIface; import hat.phases.HATTier; import hat.types.S16ImplOfF16; -import hat.types.Tensor; import jdk.incubator.code.CodeTransformer; import jdk.incubator.code.Op; import jdk.incubator.code.CodeType; @@ -74,7 +73,6 @@ public class KernelCallGraph implements LookupCarrier { public final Set> accessedVecClasses; public final Set> accessedFP16Classes; public boolean usesBarrier; - public boolean useTensors; public boolean usesAtomics; public final Set accessedKernelContextFields; @@ -114,8 +112,6 @@ public class KernelCallGraph implements LookupCarrier { var inlinedEntryPoint = ssaFunc; this.usesBarrier = OpHelper.Invoke.stream(lookup(), inlinedEntryPoint) .anyMatch(invoke -> invoke.refIs(KernelContext.class) && invoke.named("barrier")); - this.useTensors = OpHelper.Invoke.stream(lookup(), inlinedEntryPoint) - .anyMatch(invoke -> invoke.refIs(Tensor.class) && invoke.named("load")); this.accessedKernelContextFields = new HashSet<>(OpHelper.FieldAccess.stream(lookup(), inlinedEntryPoint) .filter(fieldAccess -> fieldAccess.refType(KernelContext.class)).map(OpHelper.FieldAccess::name).toList() ); diff --git a/hat/core/src/main/java/hat/codebuilders/C99HATCodeBuilder.java b/hat/core/src/main/java/hat/codebuilders/C99HATCodeBuilder.java index 81fb1897d04..d35aa8250c3 100644 --- a/hat/core/src/main/java/hat/codebuilders/C99HATCodeBuilder.java +++ b/hat/core/src/main/java/hat/codebuilders/C99HATCodeBuilder.java @@ -25,7 +25,6 @@ package hat.codebuilders; import hat.dialect.HATF16Op; -import hat.dialect.HATTensorOp; import hat.dialect.HATVectorOp; import hat.dialect.HATMemoryVarOp; import optkl.codebuilders.C99CodeBuilder; @@ -71,10 +70,4 @@ public final T varName(HATF16Op.HATF16VarOp hatF16VarOp) { id(hatF16VarOp.varName()); return self(); } - - public final T varName(HATTensorOp.TensorVarOp tensorVarOp) { - id(tensorVarOp.varName()); - return self(); - } - } diff --git a/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java b/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java index 107e4f11604..87068140863 100644 --- a/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java +++ b/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java @@ -33,14 +33,11 @@ import hat.dialect.HATMemoryDefOp; import hat.dialect.HATMemoryVarOp; import hat.dialect.HATPtrOp; -import hat.dialect.HATTensorOp; import hat.dialect.HATThreadOp; import hat.dialect.HATVectorOp; import hat.types.BF16; import hat.types.F16; -import hat.types.Tensor; import jdk.incubator.code.dialect.java.ClassType; -import jdk.incubator.code.dialect.java.FieldRef; import jdk.incubator.code.dialect.java.JavaOp; import jdk.incubator.code.dialect.java.JavaType; import jdk.incubator.code.dialect.java.PrimitiveType; @@ -49,7 +46,6 @@ import jdk.incubator.code.Value; import optkl.OpHelper; import optkl.codebuilders.ScopedCodeBuilderContext; -import optkl.exceptions.CodeGenException; import optkl.ifacemapper.BoundSchema; import optkl.ifacemapper.Schema; import jdk.incubator.code.Op; @@ -59,7 +55,6 @@ import jdk.incubator.code.dialect.core.CoreOp; import optkl.codebuilders.CodeBuilder; -import java.lang.invoke.MethodHandles; import java.util.List; import java.util.SequencedSet; import java.util.concurrent.ThreadLocalRandom; @@ -195,12 +190,6 @@ public final T HAT_BSZ() { return id("HAT_BSZ"); } - public final T HAT_WARP_SIZE() { - return 
hatWarpSize(); - } - - protected abstract T hatWarpSize(); - @Override public final T hatThreadIdOp( HATThreadOp threadOp) { return (switch (threadOp) { @@ -222,7 +211,6 @@ public final T hatThreadIdOp( HATThreadOp threadOp) { case HATThreadOp.HAT_BS.HAT_BSX _ -> HAT_BSX(); case HATThreadOp.HAT_BS.HAT_BSY _ -> HAT_BSY(); case HATThreadOp.HAT_BS.HAT_BSZ _ -> HAT_BSZ(); - case HATThreadOp.HAT_WARP_SIZE _ -> HAT_WARP_SIZE(); }); } @@ -753,7 +741,6 @@ public final T varLoadOp( CoreOp.VarAccessOp.VarLoadOp varLoadOp) { case HATVectorOp.HATVectorLoadOp $ -> varName($); case HATVectorOp.HATVectorBinaryOp $ -> varName($); case HATF16Op.HATF16VarOp $ -> varName($); - case HATTensorOp.TensorVarOp $ -> varName($); case null, default -> { } } @@ -778,7 +765,6 @@ public final T varStoreOp( CoreOp.VarAccessOp.VarStoreOp varStoreOp) { case HATMemoryVarOp.HATPrivateVarOp hatPrivateVarOp -> varName(hatPrivateVarOp); case HATMemoryVarOp.HATLocalVarOp hatLocalVarOp -> varName(hatLocalVarOp); case HATVectorOp.HATVectorVarOp hatVectorVarOp -> varName(hatVectorVarOp); - case HATTensorOp.TensorVarOp hattensorVarOp -> varName(hattensorVarOp); case null, default -> throw new IllegalStateException("What type of varStoreOp is this?"); } equals().parenthesisIfNeeded( varStoreOp, ((Op.Result)varStoreOp.operands().get(1)).op()); @@ -927,39 +913,4 @@ private void generateMathIntrinsicOperation(Invoke invoke) { } protected abstract String mapMathIntrinsic(String name); - - protected T indexForTensor(boolean isColumnMajor, Value iIndex, Value jIndex, Value ldSize) { - Value a = isColumnMajor ? iIndex : jIndex; - Value b = isColumnMajor ? jIndex : iIndex; - - if (a instanceof Op.Result r) { - recurse(r.op()); - } - plus(); - oparen(); - if (b instanceof Op.Result r) { - recurse(r.op()); - } - mul(); - if (ldSize instanceof Op.Result r) { - recurse(r.op()); - } - cparen(); - return self(); - } - - protected boolean isColumnMajor(Value tensorLayout) { - if (tensorLayout.declaringElement() instanceof JavaOp.InvokeOp invokeOp) { - var invoke = invoke(scopedCodeBuilderContext().lookup(), invokeOp); - if (invoke.resultTypeIs(Tensor.ColumMajor.class)) { - return true; - } else if (invoke.resultTypeIs(Tensor.RowMajor.class)) { - return false; - } else { - throw new RuntimeException("[Error]"); - } - } - return false; - } - } diff --git a/hat/core/src/main/java/hat/codebuilders/HATOpDispatcher.java b/hat/core/src/main/java/hat/codebuilders/HATOpDispatcher.java index a799dcd593c..c7c33a07d51 100644 --- a/hat/core/src/main/java/hat/codebuilders/HATOpDispatcher.java +++ b/hat/core/src/main/java/hat/codebuilders/HATOpDispatcher.java @@ -30,7 +30,6 @@ import hat.dialect.HATMemoryVarOp; import hat.dialect.HATOp; import hat.dialect.HATPtrOp; -import hat.dialect.HATTensorOp; import hat.dialect.HATThreadOp; import hat.dialect.HATVectorOp; import jdk.incubator.code.Op; @@ -87,23 +86,6 @@ public interface HATOpDispatcher hatPtrLengthOp($); case HATF16Op.HATF16ToFloatConvOp $ -> hatF16ToFloatConvOp($); case HATMemoryDefOp.HATMemoryLoadOp $ -> hatMemoryLoadOp($); - case HATTensorOp.TensorVarOp $ -> hatTensorVarOp($); - case HATTensorOp.TensorCreateOp $ -> hatTensorCreateOp($); - case HATTensorOp.TensorVarLoadOp $ -> hatTensorVarLoadOp($); - case HATTensorOp.TensorFillOp $ -> hatTensorFillOp($); - case HATTensorOp.TensorMMAOp $ -> hatTensorMMAOp($); - case HATTensorOp.TensorStoreLoadOp $ -> hatTensorStoreLoadOp($); - case HATTensorOp.TensorLoadOp $ -> hatTensorLoadOp($); - case HATTensorOp.TensorStoreOp $ -> hatTensorStoreOp($); default -> throw 
new IllegalStateException("handle nesting of hat op " + op); } } else { diff --git a/hat/core/src/main/java/hat/dialect/HATOp.java b/hat/core/src/main/java/hat/dialect/HATOp.java index 10eeab7363c..01e87589666 100644 --- a/hat/core/src/main/java/hat/dialect/HATOp.java +++ b/hat/core/src/main/java/hat/dialect/HATOp.java @@ -30,7 +30,7 @@ import java.util.List; -public abstract sealed class HATOp extends Op permits HATBarrierOp, HATF16Op, HATMemoryDefOp, HATMemoryVarOp, HATPtrOp, HATTensorOp, HATThreadOp, HATVectorOp { +public abstract sealed class HATOp extends Op permits HATBarrierOp, HATF16Op, HATMemoryDefOp, HATMemoryVarOp, HATPtrOp, HATThreadOp, HATVectorOp { protected HATOp(List operands) { super(operands); } diff --git a/hat/core/src/main/java/hat/dialect/HATTensorOp.java b/hat/core/src/main/java/hat/dialect/HATTensorOp.java deleted file mode 100644 index a43cf18c391..00000000000 --- a/hat/core/src/main/java/hat/dialect/HATTensorOp.java +++ /dev/null @@ -1,301 +0,0 @@ -/* - * Copyright (c) 2026, Oracle and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. Oracle designates this - * particular file as subject to the "Classpath" exception as provided - * by Oracle in the LICENSE file that accompanied this code. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. 
- */ -package hat.dialect; - -import jdk.incubator.code.CodeContext; -import jdk.incubator.code.CodeTransformer; -import jdk.incubator.code.Op; -import jdk.incubator.code.CodeType; -import jdk.incubator.code.Value; -import jdk.incubator.code.dialect.core.VarType; -import optkl.util.ops.Precedence; -import optkl.util.ops.StatementLikeOp; -import optkl.util.ops.VarLikeOp; - -import java.util.List; -import java.util.Map; - -public abstract sealed class HATTensorOp extends HATOp { - - protected HATTensorOp(List operands) { - super(operands); - } - - protected HATTensorOp(Op that, CodeContext cc) { - super(that, cc); - } - - public static final class TensorVarOp extends HATTensorOp implements VarLikeOp, StatementLikeOp { - - private final VarType codeType; - private final String varName; - - public TensorVarOp(String varName, VarType codeType, List operands) { - super(operands); - this.varName = varName; - this.codeType = codeType; - } - - public TensorVarOp(TensorVarOp op, CodeContext copyContext) { - super(op, copyContext); - this.varName = op.varName; - this.codeType = op.codeType; - } - - @Override - public Op transform(CodeContext copyContext, CodeTransformer opTransformer) { - return new TensorVarOp(this, copyContext); - } - - @Override - public CodeType resultType() { - return codeType; - } - - @Override - public Map externalize() { - return Map.of("hat.dialect.TensorVarOp." + varName, codeType); - } - - @Override - public String varName() { - return varName; - } - } - - public static final class TensorCreateOp extends HATTensorOp implements Precedence.Invoke { - - private final CodeType codeType; - - public TensorCreateOp(CodeType codeType, List operands) { - super(operands); - this.codeType = codeType; - } - - public TensorCreateOp(TensorCreateOp op, CodeContext copyContext) { - super(op, copyContext); - this.codeType = op.codeType; - } - - @Override - public Op transform(CodeContext copyContext, CodeTransformer opTransformer) { - return new TensorCreateOp(this, copyContext); - } - - @Override - public CodeType resultType() { - return codeType; - } - - @Override - public String externalizeOpName() { - return "hat.dialect.Tensor.CreateOp"; - } - } - - public static final class TensorVarLoadOp extends HATTensorOp implements Precedence.LoadOrConv { - - private final CodeType codeType; - - public TensorVarLoadOp(CodeType codeType, List operands) { - super(operands); - this.codeType = codeType; - - } - - public TensorVarLoadOp(TensorVarLoadOp op, CodeContext copyContext) { - super(op, copyContext); - this.codeType = op.codeType; - } - - @Override - public Op transform(CodeContext copyContext, CodeTransformer opTransformer) { - return new TensorVarLoadOp(this, copyContext); - } - - @Override - public CodeType resultType() { - return codeType; - } - - @Override - public String externalizeOpName() { - return "hat.dialect.TensorVarLoadOp"; - } - } - - public static final class TensorFillOp extends HATTensorOp implements Precedence.Invoke { - - private final CodeType codeType; - - public TensorFillOp(CodeType codeType, List operands) { - super(operands); - this.codeType = codeType; - } - - public TensorFillOp(TensorFillOp op, CodeContext copyContext) { - super(op, copyContext); - this.codeType = op.codeType; - } - - @Override - public Op transform(CodeContext copyContext, CodeTransformer opTransformer) { - return new TensorFillOp(this, copyContext); - } - - @Override - public CodeType resultType() { - return codeType; - } - - @Override - public String externalizeOpName() { - return 
"hat.dialect.Tensor.fill"; - } - } - - public static final class TensorMMAOp extends HATTensorOp implements Precedence.Invoke { - - private final CodeType codeType; - - public TensorMMAOp(CodeType codeType, List operands) { - super(operands); - this.codeType = codeType; - } - - public TensorMMAOp(TensorMMAOp op, CodeContext copyContext) { - super(op, copyContext); - this.codeType = op.codeType; - } - - @Override - public Op transform(CodeContext copyContext, CodeTransformer opTransformer) { - return new TensorMMAOp(this, copyContext); - } - - @Override - public CodeType resultType() { - return codeType; - } - - @Override - public String externalizeOpName() { - return "hat.dialect.Tensor.MMA"; - } - } - - public static final class TensorStoreLoadOp extends HATTensorOp implements Precedence.Store { - - private final CodeType codeType; - - public TensorStoreLoadOp(CodeType codeType, List operands) { - super(operands); - this.codeType = codeType; - - } - - public TensorStoreLoadOp(TensorStoreLoadOp op, CodeContext copyContext) { - super(op, copyContext); - this.codeType = op.codeType; - } - - @Override - public Op transform(CodeContext copyContext, CodeTransformer opTransformer) { - return new TensorStoreLoadOp(this, copyContext); - } - - @Override - public CodeType resultType() { - return codeType; - } - - @Override - public String externalizeOpName() { - return "hat.dialect.TensorStoreLoadOp"; - } - } - - public static final class TensorLoadOp extends HATTensorOp implements Precedence.LoadOrConv { - - private final CodeType codeType; - - public TensorLoadOp(CodeType codeType, List operands) { - super(operands); - this.codeType = codeType; - } - - public TensorLoadOp(TensorLoadOp op, CodeContext copyContext) { - super(op, copyContext); - this.codeType = op.codeType; - } - - @Override - public Op transform(CodeContext copyContext, CodeTransformer opTransformer) { - return new TensorLoadOp(this, copyContext); - } - - @Override - public CodeType resultType() { - return codeType; - } - - @Override - public String externalizeOpName() { - return "hat.dialect.TensorLoadOp"; - } - } - - public static final class TensorStoreOp extends HATTensorOp implements Precedence.LoadOrConv { - - private final CodeType codeType; - - public TensorStoreOp(CodeType codeType, List operands) { - super(operands); - this.codeType = codeType; - } - - public TensorStoreOp(TensorStoreOp op, CodeContext copyContext) { - super(op, copyContext); - this.codeType = op.codeType; - } - - @Override - public Op transform(CodeContext copyContext, CodeTransformer opTransformer) { - return new TensorStoreOp(this, copyContext); - } - - @Override - public CodeType resultType() { - return codeType; - } - - @Override - public String externalizeOpName() { - return "hat.dialect.TensorStoreOp"; - } - } - - -} diff --git a/hat/core/src/main/java/hat/dialect/HATThreadOp.java b/hat/core/src/main/java/hat/dialect/HATThreadOp.java index 3364b8af449..8c0a8e9a568 100644 --- a/hat/core/src/main/java/hat/dialect/HATThreadOp.java +++ b/hat/core/src/main/java/hat/dialect/HATThreadOp.java @@ -69,27 +69,10 @@ public static HATThreadOp create(String name) { case "bsx" -> new HATThreadOp.HAT_BS.HAT_BSX(); case "bsy" -> new HATThreadOp.HAT_BS.HAT_BSY(); case "bsz" -> new HATThreadOp.HAT_BS.HAT_BSZ(); - case "wrs" -> new HATThreadOp.HAT_WARP_SIZE(); default -> throw new RuntimeException("[ERROR] Illegal/unsupported parallel construct: " + name); }; } - public static final class HAT_WARP_SIZE extends HATThreadOp { - - public HAT_WARP_SIZE() { - 
super(List.of()); - } - - public HAT_WARP_SIZE(HAT_WARP_SIZE op, CodeContext codeContext) { - super(op, codeContext); - } - - @Override - public Op transform(CodeContext codeContext, CodeTransformer codeTransformer) { - return new HAT_WARP_SIZE(this, codeContext); - } - } - public abstract static sealed class HAT_LI extends HATThreadOp { protected HAT_LI() { diff --git a/hat/core/src/main/java/hat/phases/HATPhase.java b/hat/core/src/main/java/hat/phases/HATPhase.java index e89f7a96fe6..256a3d4031e 100644 --- a/hat/core/src/main/java/hat/phases/HATPhase.java +++ b/hat/core/src/main/java/hat/phases/HATPhase.java @@ -29,6 +29,6 @@ import java.lang.invoke.MethodHandles; public sealed interface HATPhase - permits HATArrayViewPhase, HATBarrierPhase, HATFP16Phase, HATMathLibPhase, HATMemoryPhase, HATTensorsPhase, HATThreadsPhase, HATVectorPhase, HATVectorSelectPhase, HATVectorStorePhase, HATWarpSizePhase { + permits HATArrayViewPhase, HATBarrierPhase, HATFP16Phase, HATMathLibPhase, HATMemoryPhase, HATThreadsPhase, HATVectorPhase, HATVectorSelectPhase, HATVectorStorePhase { CoreOp.FuncOp transform(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp); } diff --git a/hat/core/src/main/java/hat/phases/HATTensorsPhase.java b/hat/core/src/main/java/hat/phases/HATTensorsPhase.java deleted file mode 100644 index 8008ddfaa31..00000000000 --- a/hat/core/src/main/java/hat/phases/HATTensorsPhase.java +++ /dev/null @@ -1,218 +0,0 @@ -/* - * Copyright (c) 2026, Oracle and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. Oracle designates this - * particular file as subject to the "Classpath" exception as provided - * by Oracle in the LICENSE file that accompanied this code. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. 
- */ -package hat.phases; - -import hat.types.Tensor; -import jdk.incubator.code.Block; -import jdk.incubator.code.Op; -import jdk.incubator.code.Value; -import jdk.incubator.code.dialect.core.CoreOp; -import jdk.incubator.code.dialect.java.JavaOp; -import optkl.OpHelper; -import optkl.Trxfmr; - -import java.lang.invoke.MethodHandles; -import java.util.HashSet; -import java.util.List; -import java.util.Objects; -import java.util.Set; -import java.util.function.BiConsumer; - -import static hat.dialect.HATTensorOp.TensorCreateOp; -import static hat.dialect.HATTensorOp.TensorFillOp; -import static hat.dialect.HATTensorOp.TensorLoadOp; -import static hat.dialect.HATTensorOp.TensorMMAOp; -import static hat.dialect.HATTensorOp.TensorStoreLoadOp; -import static hat.dialect.HATTensorOp.TensorStoreOp; -import static hat.dialect.HATTensorOp.TensorVarLoadOp; -import static hat.dialect.HATTensorOp.TensorVarOp; - -public record HATTensorsPhase() implements HATPhase { - - private interface TensorTransformer { - - void transform(Block.Builder blockBuilder, Op op); - - default void replaceOp(Block.Builder blockBuilder, Op oldOp, Op newOp) { - newOp.setLocation(oldOp.location()); - Op.Result newOpResult = blockBuilder.op(newOp); - blockBuilder.context().mapValue(oldOp.result(), newOpResult); - } - - } - - private static class TensorView implements TensorTransformer { - - @Override - public void transform(Block.Builder blockBuilder, Op op) { - List operands = blockBuilder.context().getValues(op.operands()); - switch (op) { - case CoreOp.VarOp varOp -> replaceOp(blockBuilder, varOp, new TensorVarOp(varOp.varName(), varOp.resultType(), operands)); - case JavaOp.InvokeOp invokeOp -> replaceOp(blockBuilder, invokeOp, new TensorCreateOp(invokeOp.resultType(), operands)); - default -> blockBuilder.op(op); - } - } - } - - private static class TensorFill implements TensorTransformer { - - @Override - public void transform(Block.Builder blockBuilder, Op op) { - List operands = blockBuilder.context().getValues(op.operands()); - switch (op) { - case CoreOp.VarAccessOp.VarLoadOp loadOp -> replaceOp(blockBuilder, loadOp, new TensorVarLoadOp(loadOp.resultType(), operands)); - case JavaOp.InvokeOp invokeOp -> replaceOp(blockBuilder, invokeOp, new TensorFillOp(invokeOp.resultType(), operands)); - default -> blockBuilder.op(op); - } - } - } - - private static class TensorMMA implements TensorTransformer { - - @Override - public void transform(Block.Builder blockBuilder, Op op) { - List operands = blockBuilder.context().getValues(op.operands()); - switch (op) { - case CoreOp.VarAccessOp.VarLoadOp loadOp -> replaceOp(blockBuilder, loadOp, new TensorVarLoadOp(loadOp.resultType(), operands)); - case JavaOp.InvokeOp invokeOp -> replaceOp(blockBuilder, invokeOp, new TensorMMAOp(invokeOp.resultType(), operands)); - default -> blockBuilder.op(op); - } - } - } - - private static class TensorLoad implements TensorTransformer { - - @Override - public void transform(Block.Builder blockBuilder, Op op) { - List operands = blockBuilder.context().getValues(op.operands()); - switch (op) { - case CoreOp.VarAccessOp.VarStoreOp storeOp -> - replaceOp(blockBuilder, storeOp, new TensorStoreLoadOp(storeOp.resultType(), operands)); - case JavaOp.InvokeOp invokeOp -> - replaceOp(blockBuilder, invokeOp, new TensorLoadOp(invokeOp.resultType(), operands)); - default -> blockBuilder.op(op); - } - } - } - - private static class TensorStore implements TensorTransformer { - - @Override - public void transform(Block.Builder blockBuilder, Op op) { - if 
(Objects.requireNonNull(op) instanceof JavaOp.InvokeOp invokeOp) { - replaceOp(blockBuilder, invokeOp, new TensorStoreOp(invokeOp.resultType(), blockBuilder.context().getValues(invokeOp.operands()))); - } else { - blockBuilder.op(op); - } - } - } - - private CoreOp.FuncOp transformWithPredicate(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp, BiConsumer function, Set opsToProcess ) { - return Trxfmr.of(lookup, funcOp).transform(opsToProcess::contains, (blockBuilder, op) -> { - function.accept(blockBuilder, op); - return blockBuilder; - }).funcOp(); - } - - private CoreOp.FuncOp createTensors(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) { - Set opsToProcess = new HashSet<>(); - OpHelper.Invoke.stream(lookup, funcOp) - .filter(invoke -> !invoke.returnsVoid()) - .filter(invoke -> invoke.refIs(Tensor.class)) - .filter(invoke -> invoke.name().equals("create") || invoke.name().equals("of")) - .forEach( invoke -> { - opsToProcess.add(invoke.op()); - invoke.op().result().uses().stream() - .filter(result -> (result.op() instanceof CoreOp.VarOp)) - .map(result -> (CoreOp.VarOp) result.op()) - .findFirst() - .ifPresent(opsToProcess::add); - }); - - return transformWithPredicate(lookup, funcOp, new TensorView()::transform, opsToProcess); - } - - private Set filterOps(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp, String methodIntrinsicName) { - Set opsToProcess = new HashSet<>(); - OpHelper.Invoke.stream(lookup, funcOp) - .filter(OpHelper.Invoke::returnsVoid) - .filter(invoke -> invoke.refIs(Tensor.class)) - .filter(invoke -> invoke.name().equals(methodIntrinsicName)) - .forEach(invoke -> { - opsToProcess.add(invoke.op()); - Value varLoadValue = invoke.op().operands().getFirst(); - if (varLoadValue.declaringElement() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { - opsToProcess.add(varLoadOp); - } - }); - return opsToProcess; - } - - private CoreOp.FuncOp fillTensors(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) { - Set opsToProcess = filterOps(lookup, funcOp, "fill"); - return transformWithPredicate(lookup, funcOp, new TensorFill()::transform, opsToProcess); - } - - private CoreOp.FuncOp mmaTensor(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) { - Set opsToProcess = filterOps(lookup, funcOp, "mma"); - return transformWithPredicate(lookup, funcOp, new TensorMMA()::transform, opsToProcess); - } - - private CoreOp.FuncOp tensorLoad(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) { - Set opsToProcess = new HashSet<>(); - OpHelper.Invoke.stream(lookup, funcOp) - .filter(invoke -> !invoke.returnsVoid()) - .filter(invoke -> invoke.refIs(Tensor.class)) - .filter(invoke -> invoke.name().equals("load")) - .forEach(invoke -> { - opsToProcess.add(invoke.op()); - invoke.op().result().uses().stream() - .filter(result -> (result.op() instanceof CoreOp.VarAccessOp.VarStoreOp)) - .map(result -> (CoreOp.VarAccessOp.VarStoreOp) result.op()) - .forEach(opsToProcess::add); - }); - return transformWithPredicate(lookup, funcOp, new TensorLoad()::transform, opsToProcess); - } - - private CoreOp.FuncOp tensorStoreOp(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) { - Set opsToProcess = new HashSet<>(); - OpHelper.Invoke.stream(lookup, funcOp) - .filter(OpHelper.Invoke::returnsVoid) - .filter(invoke -> invoke.refIs(Tensor.class)) - .filter(invoke -> invoke.name().equals("store")) - .forEach(invoke -> opsToProcess.add(invoke.op())); - return transformWithPredicate(lookup, funcOp, new TensorStore()::transform, opsToProcess); - } - - @Override - public CoreOp.FuncOp transform(MethodHandles.Lookup 
lookup, CoreOp.FuncOp funcOp) { - funcOp = createTensors(lookup, funcOp); - funcOp = fillTensors(lookup, funcOp); - funcOp = mmaTensor(lookup, funcOp); - funcOp = tensorLoad(lookup, funcOp); - funcOp = tensorStoreOp(lookup, funcOp); - return funcOp; - } -} diff --git a/hat/core/src/main/java/hat/phases/HATTier.java b/hat/core/src/main/java/hat/phases/HATTier.java index e21a67c226e..8b1378e9e74 100644 --- a/hat/core/src/main/java/hat/phases/HATTier.java +++ b/hat/core/src/main/java/hat/phases/HATTier.java @@ -43,7 +43,6 @@ public class HATTier { // ID's /thread access new HATThreadsPhase(), - new HATWarpSizePhase(), // MathLib phase new HATMathLibPhase(), // views for vector types @@ -60,10 +59,7 @@ public class HATTier { // Vector Select individual lines new HATVectorSelectPhase(), // F16 type - new HATFP16Phase(), - - // Tensors - new HATTensorsPhase() + new HATFP16Phase() ); public static void transform(List phases, MethodHandles.Lookup lookup, FuncOpCarrier funcOpCarrier, boolean showCompilationPhases){ diff --git a/hat/core/src/main/java/hat/phases/HATWarpSizePhase.java b/hat/core/src/main/java/hat/phases/HATWarpSizePhase.java deleted file mode 100644 index 59a0062ad11..00000000000 --- a/hat/core/src/main/java/hat/phases/HATWarpSizePhase.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright (c) 2026, Oracle and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. Oracle designates this - * particular file as subject to the "Classpath" exception as provided - * by Oracle in the LICENSE file that accompanied this code. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. 
- */ -package hat.phases; - -import hat.KernelContext; -import hat.dialect.HATThreadOp; -import jdk.incubator.code.CodeElement; -import jdk.incubator.code.dialect.core.CoreOp; -import optkl.OpHelper; -import optkl.Trxfmr; - -import java.lang.invoke.MethodHandles; -import java.util.HashSet; -import java.util.Set; - -import static optkl.OpHelper.FieldAccess.fieldAccess; - -public record HATWarpSizePhase() implements HATPhase { - - @Override - public CoreOp.FuncOp transform(MethodHandles.Lookup lookup, CoreOp.FuncOp funcOp) { - Set> varAccessesToBeRemoved = new HashSet<>(); - return Trxfmr.of(lookup, funcOp) - .transform(c -> { - if (fieldAccess(lookup, c.op()) instanceof OpHelper.FieldAccess.Instance fieldAccess - && fieldAccess.refType(KernelContext.class) && fieldAccess.nameMatchesRegex("wrs")) { - varAccessesToBeRemoved.add(fieldAccess.instanceVarAccess().op()); - c.replace(HATThreadOp.create(fieldAccess.name())); - }}) - .remap(varAccessesToBeRemoved) - .remove(varAccessesToBeRemoved::contains) - .funcOp(); - } -} diff --git a/hat/core/src/main/java/hat/types/Tensor.java b/hat/core/src/main/java/hat/types/Tensor.java deleted file mode 100644 index fda512fc6c9..00000000000 --- a/hat/core/src/main/java/hat/types/Tensor.java +++ /dev/null @@ -1,74 +0,0 @@ -package hat.types; - -import hat.buffer.F16Array; -import hat.buffer.F32Array; -import hat.buffer.F32ArrayPadded; -import optkl.IfaceValue; - -// Tensors are immutable -public record Tensor(int first, Shape shape, Class klass, Access tensorAccess) implements IfaceValue { - - public static final int FIRST = 0; - public static final int SECOND = 1; - public static final int ACC = 2; - - public static Shape shape(int dim1, int dim2, int dim3) { - return new Shape(dim1, dim2, dim3); - } - - public static Tensor create(int first, Shape shape, Class klass, final Access tensorAccess) { - return new Tensor(first, shape, klass, tensorAccess); - } - - public static Tensor create(int first, Shape shape, Class klass) { - return new Tensor(first, shape, klass, null); - } - - // Do we do a = fill(a, v)? or void fill(a, v)? 
- public static void fill(Tensor acc, float value) { - } - - public static void mma(Tensor result, Tensor tensorA, Tensor tensorB, Tensor acc) { - } - - public static Tensor load(F16Array matrix, int i, int j, int ld) { - return null; - } - - public static void store(F32Array matrix, int i, int j, Tensor resultTensor, int ld, Access tensorAccess) { - } - - public static void store(F32ArrayPadded matrix, int i, int j, Tensor resultTensor, int ld, Access tensorAccess) { - } - - public record Shape(int x, int y, int z) { - } - - public static class Accessor { - public static final int ROW_MAJOR = 0; - public static final int COL_MAJOR = 1; - public static final int NOT_DEFINED = -1; - - private Accessor() { - } - } - - public interface Access { - - } - - public record ColumMajor() implements Access { - } - - public record RowMajor() implements Access { - } - - public static ColumMajor ofColumnMajor() { - return new ColumMajor(); - } - - public static RowMajor ofRowMajor() { - return new RowMajor(); - } - -} diff --git a/hat/examples/matmul/src/main/java/matmul/Main.java b/hat/examples/matmul/src/main/java/matmul/Main.java index 178d62b5cd8..d46dbf0079e 100644 --- a/hat/examples/matmul/src/main/java/matmul/Main.java +++ b/hat/examples/matmul/src/main/java/matmul/Main.java @@ -30,7 +30,6 @@ import hat.KernelContext; import hat.NDRange.Global2D; import hat.NDRange.Local2D; -import hat.annotations.Kernel; import hat.backend.Backend; import hat.examples.common.HATExampleException; import hat.examples.common.ParseArgs; diff --git a/hat/tests/src/main/java/hat/test/TestTensors.java b/hat/tests/src/main/java/hat/test/TestTensors.java deleted file mode 100644 index bd3dddbada0..00000000000 --- a/hat/tests/src/main/java/hat/test/TestTensors.java +++ /dev/null @@ -1,382 +0,0 @@ -/* - * Copyright (c) 2026, Oracle and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * This code is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License version 2 only, as - * published by the Free Software Foundation. Oracle designates this - * particular file as subject to the "Classpath" exception as provided - * by Oracle in the LICENSE file that accompanied this code. - * - * This code is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - * version 2 for more details (a copy is included in the LICENSE file that - * accompanied this code). - * - * You should have received a copy of the GNU General Public License version - * 2 along with this work; if not, write to the Free Software Foundation, - * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. - * - * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA - * or visit www.oracle.com if you need additional information or have any - * questions. 
- */ -package hat.test; - -import hat.Accelerator; -import hat.ComputeContext; -import hat.KernelContext; -import hat.NDRange.Tile2D; -import hat.backend.Backend; -import hat.buffer.F16Array; -import hat.buffer.F32Array; -import hat.buffer.F32ArrayPadded; -import hat.test.annotation.HatTest; -import hat.test.exceptions.HATAssertionError; -import hat.test.exceptions.HATAsserts; -import hat.test.exceptions.HATExpectedPrecisionError; -import hat.types.F16; -import hat.types.Tensor; -import jdk.incubator.code.Reflect; - -import java.lang.invoke.MethodHandles; -import java.util.Random; - -import static hat.NDRange.Global2D; -import static hat.NDRange.Local2D; -import static hat.NDRange.NDRange2D; -import static hat.NDRange.Warp2D; -import static optkl.ifacemapper.MappableIface.RO; -import static optkl.ifacemapper.MappableIface.WO; - -/** - * Check tensor operations in HAT. How to run? - * - *
<pre>
- * - * HAT=SHOW_CODE java -cp hat/job.jar hat.java test ffi-cuda hat.test.TestTensors - * HAT=SHOW_CODE java -cp hat/job.jar hat.java test ffi-opencl hat.test.TestTensors - * - *
</pre>
- * - */ -public class TestTensors { - - @Reflect - public static void mxmTensorsColumnMajor(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int size) { - final int WMMA_M = 16; - final int WMMA_N = 16; - final int WMMA_K = 16; - int warpM = kc.gix / kc.wrs; - int warpN = kc.giy; - - final int lda = 1024; - final int ldb = 1024; - final int ldc = 1024; - - Tensor tensorA = Tensor.create(Tensor.FIRST, Tensor.shape(16, 16, 16), F16.class, Tensor.ofColumnMajor()); - Tensor tensorB = Tensor.create(Tensor.SECOND, Tensor.shape(16, 16, 16), F16.class, Tensor.ofColumnMajor()); - Tensor acc = Tensor.create(Tensor.ACC, Tensor.shape(16, 16, 16), float.class); - - Tensor.fill(acc, 0.0f); - - for (int i = 0; i < size; i += WMMA_K) { - int aRow = warpM * WMMA_M; - int aCol = i; - - int bRow = i; - int bCol = warpN * WMMA_N; - - if (aRow < lda && aCol < lda && bRow < ldb && bCol < ldb) { - - tensorA = Tensor.load(matrixA, aRow, aCol, lda); - tensorB = Tensor.load(matrixB, bRow, bCol, ldb); - - // acc = tensorA * tensorB + acc - Tensor.mma(acc, tensorA, tensorB, acc); - } - } - int cRow = warpM * WMMA_M; - int cCol = warpN * WMMA_N; - Tensor.store(matrixC, cRow, cCol, acc, ldc, Tensor.ofColumnMajor()); - } - - @Reflect - public static void mxmTensorsColumnMajor(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int globalSize) { - // The total number of threads is calculated as follows: - // [ (size / tile), (size / tile) ] - // If warpSize > 1, then each dimension using warp operations is multiplied by the value of the warp-size. This is architecture dependent, but the - // HAT runtime and HAT JIT compiler handle this automatically. - - var ndRange = NDRange2D.of(Global2D.of(globalSize, globalSize), - Local2D.of(128, 4), - Tile2D.of(16, 16), - Warp2D.of(true, false)); - - cc.dispatchKernel(ndRange, kc -> mxmTensorsColumnMajor(kc, matrixA, matrixB, matrixC, globalSize)); - } - - @Reflect - public static void mxmTensorsRowColumnMajor(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int size) { - final int WMMA_M = 16; - final int WMMA_N = 16; - final int WMMA_K = 16; - int warpM = kc.gix / kc.wrs; - int warpN = kc.giy; - - final int lda = 1024; - final int ldb = 1024; - final int ldc = 1024; - - Tensor tensorA = Tensor.create(Tensor.FIRST, Tensor.shape(16, 16, 16), F16.class, Tensor.ofRowMajor()); - Tensor tensorB = Tensor.create(Tensor.SECOND, Tensor.shape(16, 16, 16), F16.class, Tensor.ofColumnMajor()); - Tensor acc = Tensor.create(Tensor.ACC, Tensor.shape(16, 16, 16), float.class); - - Tensor.fill(acc, 0.0f); - - for (int i = 0; i < size; i += WMMA_K) { - int aRow = warpM * WMMA_M; - int aCol = i; - - int bRow = i; - int bCol = warpN * WMMA_N; - - if (aRow < lda && aCol < lda && bRow < ldb && bCol < ldb) { - - tensorA = Tensor.load(matrixA, aRow, aCol, lda); - tensorB = Tensor.load(matrixB, bRow, bCol, ldb); - - // acc = tensorA * tensorB + acc - Tensor.mma(acc, tensorA, tensorB, acc); - } - } - int cRow = warpM * WMMA_M; - int cCol = warpN * WMMA_N; - Tensor.store(matrixC, cRow, cCol, acc, ldc, Tensor.ofColumnMajor()); - } - - @Reflect - public static void mxmTensorsRowColumnMajor(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32Array matrixC, int globalSize) { - // The total number of threads is calculated as follows: - // [ (size / tile), (size / tile) ] - // If warpSize > 1, then each dimension using warp operations is multiplied by the value of the warp-size. 
This is architecture dependent, but the - // HAT runtime and HAT JIT compiler handle this automatically. - - var ndRange = NDRange2D.of(Global2D.of(globalSize, globalSize), - Local2D.of(128, 4), - Tile2D.of(16, 16), - Warp2D.of(true, false)); - - cc.dispatchKernel(ndRange, kc -> mxmTensorsRowColumnMajor(kc, matrixA, matrixB, matrixC, globalSize)); - } - - @Reflect - public static void mxmTensorsRowMajor(@RO KernelContext kc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32ArrayPadded matrixC, int size) { - final int WMMA_M = 16; - final int WMMA_N = 16; - final int WMMA_K = 16; - int warpM = kc.gix / kc.wrs; - int warpN = kc.giy; - - final int lda = 1024; - final int ldb = 1024; - final int ldc = 1024; - - Tensor tensorA = Tensor.create(Tensor.FIRST, Tensor.shape(16, 16, 16), F16.class, Tensor.ofRowMajor()); - Tensor tensorB = Tensor.create(Tensor.SECOND, Tensor.shape(16, 16, 16), F16.class, Tensor.ofRowMajor()); - Tensor acc = Tensor.create(Tensor.ACC, Tensor.shape(16, 16, 16), float.class); - - Tensor.fill(acc, 0.0f); - - for (int i = 0; i < size; i += WMMA_K) { - int aRow = warpM * WMMA_M; - int aCol = i; - - int bRow = i; - int bCol = warpN * WMMA_N; - - if (aRow < lda && aCol < lda && bRow < ldb && bCol < ldb) { - - tensorA = Tensor.load(matrixA, aRow, aCol, lda); - tensorB = Tensor.load(matrixB, bRow, bCol, ldb); - - // acc = tensorA * tensorB + acc - Tensor.mma(acc, tensorA, tensorB, acc); - } - } - int cRow = warpM * WMMA_M; - int cCol = warpN * WMMA_N; - Tensor.store(matrixC, cRow, cCol, acc, ldc, Tensor.ofRowMajor()); - } - - @Reflect - public static void mxmTensorsRowMajor(@RO ComputeContext cc, @RO F16Array matrixA, @RO F16Array matrixB, @WO F32ArrayPadded matrixC, int globalSize) { - var ndRange = NDRange2D.of( - Global2D.of(globalSize, globalSize), - Local2D.of(128, 4), - Tile2D.of(16, 16), - Warp2D.of(true, false)); - cc.dispatchKernel(ndRange, kc -> mxmTensorsRowMajor(kc, matrixA, matrixB, matrixC, globalSize)); - } - - private static void runSequentialColMajor(F16Array matrixA, F16Array matrixB, F32Array matrixC, final int size) { - for (int i = 0; i < size; i++) { - for (int j = 0; j < size; j++) { - float sum = 0.0f; - for (int k = 0; k < size; k++) { - F16 a = matrixA.array((long) k * size + i); - F16 b = matrixB.array((long) j * size + k); - F16 mul = F16.mul(a, b); - sum += F16.f16ToFloat(mul); - } - matrixC.array((long) j * size + i, sum); - } - } - } - - private static void runSequentialRowAndColMajor(F16Array matrixA, F16Array matrixB, F32Array matrixC, final int size) { - for (int i = 0; i < size; i++) { - for (int j = 0; j < size; j++) { - float sum = 0.0f; - for (int k = 0; k < size; k++) { - F16 a = matrixA.array((long) i * size + k); - F16 b = matrixB.array((long) j * size + k); - F16 mul = F16.mul(a, b); - sum += F16.f16ToFloat(mul); - } - matrixC.array((long) j * size + i, sum); - } - } - } - - private static void runSequentialRowMajor(F16Array matrixA, F16Array matrixB, F32Array matrixC, final int size) { - for (int i = 0; i < size; i++) { - for (int j = 0; j < size; j++) { - float sum = 0.0f; - for (int k = 0; k < size; k++) { - F16 a = matrixA.array((long) i * size + k); - F16 b = matrixB.array((long) k * size + j); - F16 mul = F16.mul(a, b); - sum += F16.f16ToFloat(mul); - } - matrixC.array((long) i * size + j, sum); - } - } - } - - @HatTest - @Reflect - public void testTensor01() { - var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST); - final int size = 1024; - - F16Array matrixAHalf = F16Array.create(accelerator, size * size); 
- F16Array matrixBHalf = F16Array.create(accelerator, size * size); - F32Array matrixC = F32Array.create(accelerator, size * size); - F32Array resultSequential = F32Array.create(accelerator, size * size); - - Random r = new Random(19); - for (int j = 0; j < matrixAHalf.length(); j++) { - matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value()); - matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value()); - } - - for (int i = 0; i < 10; i++) { - accelerator.compute(cc -> mxmTensorsColumnMajor(cc, matrixAHalf, matrixBHalf, matrixC, size)); - } - - runSequentialColMajor(matrixAHalf, matrixBHalf, resultSequential, size); - - for (int i = 0; i < size; i++) { - for (int j = 0; j < size; j++) { - final int index = j * size + i; - float expectedValue = resultSequential.array(index); - float gotValue = matrixC.array(index); - try { - HATAsserts.assertEquals(expectedValue, gotValue, 0.1f); - } catch (HATAssertionError e) { - throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue); - } - } - } - } - - @HatTest - @Reflect - public void testTensor02() { - var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST); - final int size = 1024; - - F16Array matrixAHalf = F16Array.create(accelerator, size * size); - F16Array matrixBHalf = F16Array.create(accelerator, size * size); - F32Array matrixC = F32Array.create(accelerator, size * size); - F32Array resultSequential = F32Array.create(accelerator, size * size); - - Random r = new Random(19); - for (int j = 0; j < matrixAHalf.length(); j++) { - matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value()); - matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value()); - } - - for (int i = 0; i < 10; i++) { - accelerator.compute(cc -> mxmTensorsRowColumnMajor(cc, matrixAHalf, matrixBHalf, matrixC, size)); - } - - runSequentialRowAndColMajor(matrixAHalf, matrixBHalf, resultSequential, size); - - for (int i = 0; i < size; i++) { - for (int j = 0; j < size; j++) { - final int index = j * size + i; - float expectedValue = resultSequential.array(index); - float gotValue = matrixC.array(index); - try { - HATAsserts.assertEquals(expectedValue, gotValue, 0.1f); - } catch (HATAssertionError e) { - throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue); - } - } - } - } - - @HatTest - @Reflect - public void testTensor03() { - - // To be able to run tensor-matmul in a row-major layout, we need to add padding. - // Thus, the result matrix must be of type F32ArrayPadded. 
- - var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST); - final int size = 1024; - - F16Array matrixAHalf = F16Array.create(accelerator, size * size); - F16Array matrixBHalf = F16Array.create(accelerator, size * size); - F32ArrayPadded matrixC = F32ArrayPadded.create(accelerator, size * size); - F32Array resultSequential = F32Array.create(accelerator, size * size); - - Random r = new Random(19); - for (int j = 0; j < matrixAHalf.length(); j++) { - matrixAHalf.array(j).value(F16.floatToF16(r.nextFloat()).value()); - matrixBHalf.array(j).value(F16.floatToF16(r.nextFloat()).value()); - } - - for (int i = 0; i < 10; i++) { - accelerator.compute(cc -> mxmTensorsRowMajor(cc, matrixAHalf, matrixBHalf, matrixC, size)); - } - - runSequentialRowMajor(matrixAHalf, matrixBHalf, resultSequential, size); - - for (int i = 0; i < size; i++) { - for (int j = 0; j < size; j++) { - final int index = j * size + i; - float expectedValue = resultSequential.array(index); - float gotValue = matrixC.array(index); - try { - HATAsserts.assertEquals(expectedValue, gotValue, 0.1f); - } catch (HATAssertionError e) { - throw new HATExpectedPrecisionError("Expected: " + expectedValue + " but got " + gotValue); - } - } - } - } -}