@@ -27,41 +27,24 @@
import hat.callgraph.KernelCallGraph;
import hat.codebuilders.C99HATKernelBuilder;
import hat.dialect.HATF16Op;
import hat.dialect.HATTensorOp;
import hat.dialect.HATVectorOp;
import hat.types.F16;
import hat.types.Tensor;
import jdk.incubator.code.dialect.core.CoreOp;
import jdk.incubator.code.dialect.java.ClassType;
import jdk.incubator.code.dialect.java.FieldRef;
import jdk.incubator.code.dialect.java.JavaOp;
import optkl.OpHelper;
import optkl.codebuilders.CodeBuilder;
import optkl.codebuilders.ScopedCodeBuilderContext;
import hat.types.BF16;
import jdk.incubator.code.Op;
import jdk.incubator.code.Value;
import jdk.incubator.code.dialect.java.PrimitiveType;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SequencedSet;

import static optkl.OpHelper.Invoke.invoke;

public class CudaHATKernelBuilder extends C99HATKernelBuilder<CudaHATKernelBuilder> {

protected CudaHATKernelBuilder(KernelCallGraph kernelCallGraph, ScopedCodeBuilderContext scopedCodeBuilderContext) {
super(kernelCallGraph, scopedCodeBuilderContext);
}

@Override
protected CudaHATKernelBuilder hatWarpSize() {
return constant("32");
}

private CudaHATKernelBuilder half2float() {
return id("__half2float");
}
@@ -175,8 +158,7 @@ public CudaHATKernelBuilder defines() {
.includeSys("cuda_fp16.h", "cuda_bf16.h")
.hashDefine("BFLOAT16", _ -> keyword("__nv_bfloat16"))
.typedefSingleValueStruct("F16", "half")
.typedefSingleValueStruct("BF16", "BFLOAT16")
.includeSys("mma.h"); // only enable if tensor views are used
.typedefSingleValueStruct("BF16", "BFLOAT16");
}
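/*
 * A sketch of the CUDA preamble this defines() chain would emit, assuming
 * typedefSingleValueStruct wraps the payload type in a one-field struct
 * (the field name "value" is illustrative, not confirmed by this diff):
 *
 *   #include <cuda_fp16.h>
 *   #include <cuda_bf16.h>
 *   #define BFLOAT16 __nv_bfloat16
 *   typedef struct { half value; } F16;
 *   typedef struct { BFLOAT16 value; } BF16;
 */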

@Override
@@ -192,7 +174,7 @@ public CudaHATKernelBuilder hatVectorStoreOp(HATVectorOp.HATVectorStoreView hatVectorStoreView) {
paren(_ -> {
ampersand().recurseResultOrThrow(dest);
either(hatVectorStoreView instanceof HATVectorOp.Shared, CodeBuilder::dot, CodeBuilder::rarrow);
-id(ARRAY).sbrace(_ -> recurseResultOrThrow(index));
+id("array").sbrace(_ -> recurseResultOrThrow(index));
});
sbrace(_ -> intConstZero());
sp().equals().sp();
Expand Down Expand Up @@ -266,7 +248,7 @@ public CudaHATKernelBuilder hatVectorLoadOp(HATVectorOp.HATVectorLoadOp hatVecto
paren(_ -> {
ampersand(); recurseResultOrThrow(source);
either(hatVectorLoadOp instanceof HATVectorOp.Shared, CodeBuilder::dot, CodeBuilder::rarrow);
-id(ARRAY).sbrace(_ -> recurseResultOrThrow(index));
+id("array").sbrace(_ -> recurseResultOrThrow(index));
});
sbrace(_ -> intConstZero());
return self();
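/*
 * A sketch of the accesses the two methods above build, assuming the elided
 * lines at the start of each hunk emit a vector-type cast (float4, buf, v,
 * and i are illustrative names); `.` is chosen for shared-memory views and
 * `->` otherwise:
 *
 *   ((float4 *)(&buf->array[i]))[0] = v;    // hatVectorStoreOp
 *   v = ((float4 *)(&buf->array[i]))[0];    // hatVectorLoadOp
 */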
Expand Down Expand Up @@ -440,267 +422,4 @@ private CudaHATKernelBuilder generateFloat16ConversionToFloat(Class<?> float16Cl
protected String mapMathIntrinsic(String hatMathIntrinsicName) {
return MATH_FUNCTIONS.getOrDefault(hatMathIntrinsicName, hatMathIntrinsicName);
}

public static final String WMMA_MEM_COL_MAJOR = "nvcuda::wmma::mem_col_major";
public static final String WMMA_MEM_ROW_MAJOR = "nvcuda::wmma::mem_row_major";
public static final String WMMA_STORE_TENSOR = "nvcuda::wmma::store_matrix_sync";
public static final String WMMA_LOAD_TENSOR = "nvcuda::wmma::load_matrix_sync";
public static final String WMMA_MMA_TENSOR = "nvcuda::wmma::mma_sync";
public static final String WMMA_FILL_TENSOR = "nvcuda::wmma::fill_fragment";
public static final String WMMA_COL_MAJOR = "nvcuda::wmma::col_major";
public static final String WMMA_ROW_MAJOR = "nvcuda::wmma::row_major";
public static final String WMMA_FRAGMENT_BASE = "nvcuda::wmma::fragment<nvcuda::wmma::";
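/*
 * These constants name the CUDA warp matrix multiply-accumulate (WMMA) API
 * from <mma.h>. The builders below target the canonical WMMA sequence
 * (fragment and pointer names here are illustrative):
 *
 *   nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> a_frag;
 *   nvcuda::wmma::fill_fragment(c_frag, 0.0f);
 *   nvcuda::wmma::load_matrix_sync(a_frag, ptrA, lda);
 *   nvcuda::wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
 *   nvcuda::wmma::store_matrix_sync(ptrC, c_frag, ldc, nvcuda::wmma::mem_col_major);
 */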

@Override
public CudaHATKernelBuilder hatTensorVarOp(HATTensorOp.TensorVarOp tensorVarOp) {
recurse(OpHelper.asResultOrThrow(tensorVarOp.operands().getFirst()).op());
sp().id(tensorVarOp.varName());
return self();
}

private CudaHATKernelBuilder generateCreateTensor(List<Integer> shape, String matrixOrder, String type, Value access) {
id(WMMA_FRAGMENT_BASE)
.id(matrixOrder)
.comma().sp()
.intValue(shape.getFirst())
.comma().sp()
.intValue(shape.get(1))
.comma().sp()
.intValue(shape.get(2))
.comma().sp()
.type(type);

if (matrixOrder.equals(TENSOR_ACC)) {
gt();
} else { // infer the tensor layout from the last parameter
if (access.declaringElement() instanceof JavaOp.InvokeOp invokeOp) {
// Expecting an invokeOp
var invoke = invoke(scopedCodeBuilderContext().lookup(), invokeOp);
comma();
if (invoke.resultTypeIs(Tensor.ColumMajor.class)) {
id(WMMA_COL_MAJOR);
} else if (invoke.resultTypeIs(Tensor.RowMajor.class)) {
id(WMMA_ROW_MAJOR);
} else {
throw new CUDACodeGenException("[Error]");
}
gt();
}
}
return self();
}
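/*
 * For example, for a 16x16x16 half-precision matrix_a tile in column-major
 * order this method would emit (hatTensorVarOp then appends the variable
 * name, giving a full declaration such as):
 *
 *   nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> a_frag;
 */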

private static final String TENSOR_MATRIX_A = "matrix_a";
private static final String TENSOR_MATRIX_B = "matrix_b";
private static final String TENSOR_ACC = "accumulator";

private String getMatrixOrder(Value valueParameter) {
if (valueParameter instanceof Op.Result r && r.op() instanceof JavaOp.FieldAccessOp.FieldLoadOp fieldLoadOp) {
FieldRef fieldRef = fieldLoadOp.fieldReference();
return switch (fieldRef.name()) {
case "FIRST" -> TENSOR_MATRIX_A;
case "SECOND" -> TENSOR_MATRIX_B;
default -> TENSOR_ACC;
};
}
// returning null here would NPE later in generateCreateTensor
throw new CUDACodeGenException("Expected a FieldLoadOp to derive the matrix order, but found: " + valueParameter);
}

@Override
public CudaHATKernelBuilder hatTensorCreateOp(HATTensorOp.TensorCreateOp tensorCreateOp) {
// infer first parameter
List<Value> operands = tensorCreateOp.operands();
Value first = operands.getFirst();
// The first operand gives us the matrix order or accumulator
String matrixOrder = getMatrixOrder(first);

// Second parameter: analysis of the shape
List<Integer> shape = new ArrayList<>();

Value second = operands.get(1);
if (second.declaringElement() instanceof JavaOp.InvokeOp invokeOp) {
List<Value> shapeOperands = invokeOp.operands();
for (Value shapeOperand : shapeOperands) {
if (shapeOperand.declaringElement() instanceof CoreOp.ConstantOp constantOp) {
shape.add((int) constantOp.value());
} else {
throw new CUDACodeGenException("Error: expected to find a ConstantOp, but found a " + shapeOperand.declaringElement().getClass());
}
}
} else {
throw new CUDACodeGenException("InvokeOp expected, but found: " + second.declaringElement().getClass());
}
if (shape.size() != 3) {
throw new CUDACodeGenException("Shape must have three values");
}

// The third parameter is the element type; the first implementation supports `half` and `float`.
// It arrives as another constant carrying the type.
Value classOperand = operands.get(2);
Object klass = null;
if (classOperand.declaringElement() instanceof CoreOp.ConstantOp constantOp) {
klass = constantOp.value();
}

String type = "";
if (klass instanceof ClassType classType && classType.toClassName().equals(F16.class.getCanonicalName())) {
type = "half";
} else if (klass instanceof PrimitiveType primitiveType && primitiveType.equals(PrimitiveType.FLOAT)) {
type = "float";
}

Value access = operands.getLast();
return generateCreateTensor(shape, matrixOrder, type, access);
}
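/*
 * Accumulator fragments take no layout parameter (the TENSOR_ACC branch in
 * generateCreateTensor closes the template right after the element type), e.g.:
 *
 *   nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> c_frag;
 */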

@Override
public CudaHATKernelBuilder hatTensorVarLoadOp(HATTensorOp.TensorVarLoadOp hatTensorVarLoadOp) {
Value operand = hatTensorVarLoadOp.operands().getFirst();
if (operand instanceof Op.Result r && r.op() instanceof HATTensorOp.TensorVarOp tensorVarOp) {
varName(tensorVarOp.varName());
} else {
throw new CUDACodeGenException("[ERROR] Expected HATTensorVarOp");
}
return self();
}

@Override
public CudaHATKernelBuilder hatTensorFillOp(HATTensorOp.TensorFillOp tensorFillOp) {
id(WMMA_FILL_TENSOR).paren( _-> {
List<Value> operands = tensorFillOp.operands();
recurseResultOrThrow(operands.getFirst())
.comma()
.recurseResultOrThrow(operands.get(1));
});
return self();
}
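/*
 * Example of code being generated (zero-initialising an accumulator; the
 * fragment name and fill value come from the op's two operands):
 *
 *   nvcuda::wmma::fill_fragment(c_frag, 0.0f);
 */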

@Override
public CUDACodeGenException launchBackendException(String message) {
return new CUDACodeGenException(message);
}

@Override
public CudaHATKernelBuilder hatTensorMMAOp(HATTensorOp.TensorMMAOp tensorMMAOp) {
id(WMMA_MMA_TENSOR).paren( _-> commaSeparated(tensorMMAOp.operands(), this::recurseResultOrThrow));
return self();
}
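/*
 * Example of code being generated; mma_sync computes d = a * b + c, so a
 * typical call with an accumulator reused as input and output looks like:
 *
 *   nvcuda::wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
 */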

private CudaHATKernelBuilder generateLoadTensor(HATTensorOp.TensorLoadOp tensorLoadOp, boolean isColumnMajor, String tensorName) {
// First operand is the reference to global memory
List<Value> operands = tensorLoadOp.operands();
Value reference = operands.getFirst();
id(WMMA_LOAD_TENSOR)
.paren(_ -> {
id(tensorName).comma();
paren(_ -> type("half").asterisk());
recurseResultOrThrow(reference);
rarrow().id(ARRAY)
.sp().plus().sp()
.indexForTensor(isColumnMajor, operands.get(1), operands.get(2), operands.get(3))
.comma();
recurseResultOrThrow(operands.get(3));
});

return self();
}
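/*
 * indexForTensor is assumed to emit `i + j * ld` for column-major layouts,
 * matching the Javadoc example below (aRow + aCol * lda), and the transposed
 * `i * ld + j` form for row-major ones.
 */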

/**
* Example of code being generated:
*
* <p>
* <code>
* wmma::load_matrix_sync(a_frag, matrix->array + headSize + aRow + aCol * lda, lda);
* </code>
* </p>
*
* @param tensorLoadOp the tensor load operation being lowered to a WMMA load_matrix_sync call
*
* @return {@link CudaHATKernelBuilder}
*/
@Override
public CudaHATKernelBuilder hatTensorLoadOp(HATTensorOp.TensorLoadOp tensorLoadOp) {
// Find the tensor variable name backing the first argument
String tensorName = "";
SequencedSet<Op.Result> uses = tensorLoadOp.result().uses();
HATTensorOp.TensorVarOp tensorVarOp = null;
for (Op.Result result : uses) {
if (result.declaringElement() instanceof HATTensorOp.TensorStoreLoadOp storeLoadOp) {
// obtain first arg from tensorStoreOp
Value first = storeLoadOp.operands().getFirst();
if (first.declaringElement() instanceof HATTensorOp.TensorVarOp varOp) {
tensorVarOp = varOp;
tensorName = tensorVarOp.varName();
}
}
}

boolean isColumnMajor = true;
if (tensorVarOp != null) {
Value value = tensorVarOp.operands().getFirst();
if (value.declaringElement() instanceof HATTensorOp.TensorCreateOp createOp) {
Value tensorLayout = createOp.operands().getLast();
isColumnMajor = isColumnMajor(tensorLayout);
}
}
return generateLoadTensor(tensorLoadOp, isColumnMajor, tensorName);
}

@Override
public CudaHATKernelBuilder hatTensorStoreLoadOp(HATTensorOp.TensorStoreLoadOp hatTensorStoreLoadOp) {
List<Value> operands = hatTensorStoreLoadOp.operands();
recurseResultOrThrow(operands.getLast());
return self();
}

/**
* Example of code being generated:
*
* <p>
* <code>
* store_matrix_sync(matrix->array + cRow + cCol * ldc, c_frag, ldc, wmma::mem_col_major);
* </code>
* </p>
*
* @param operands      the store op's operands: memory reference, row index, column index, source fragment, and leading dimension
* @param isColumnMajor whether the destination matrix is stored column-major
*
* @return {@link CudaHATKernelBuilder}
*/
private CudaHATKernelBuilder generateStoreTensor(List<Value> operands, boolean isColumnMajor) {
Value reference = operands.getFirst();
id(WMMA_STORE_TENSOR).paren(_ -> {
Value iIndex = operands.get(1);
Value jIndex = operands.get(2);
Value tensorToStore = operands.get(3);
Value ldSize = operands.get(4);

recurseResultOrThrow(reference)
.rarrow().id(ARRAY)
.sp().plus().sp()
.indexForTensor(isColumnMajor, iIndex, jIndex, ldSize)
.comma()
.recurseResultOrThrow(tensorToStore)
.comma()
.recurseResultOrThrow(ldSize)
.comma();

if (isColumnMajor) {
id(WMMA_MEM_COL_MAJOR);
} else {
id(WMMA_MEM_ROW_MAJOR);
}
});
return self();
}

@Override
public CudaHATKernelBuilder hatTensorStoreOp(HATTensorOp.TensorStoreOp tensorStoreOp) {
List<Value> operands = tensorStoreOp.operands();
// Access layout is the last operand
final boolean isColumnMajor = isColumnMajor(operands.get(5));
return generateStoreTensor(operands, isColumnMajor);
}

private static final String ARRAY = "array";
}