Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
635 changes: 316 additions & 319 deletions hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaBackend.java

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion hat/backends/ffi/cuda/src/main/native/cpp/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ PtxSource *CudaBackend::nvcc(const CudaSource *cudaSource) {
// create var/cuda directory
std::string localDirectory = "./var/cuda";
std::filesystem::create_directories(localDirectory);
// create temp file for cuda generarated code
// create temp file for cuda generated code
const uint64_t time = timeSinceEpochMillisec();
const std::string ptxPath = tmpFileName(time, localDirectory, ".ptx");
const std::string cudaPath = tmpFileName(time, localDirectory, ".cu");
Expand Down Expand Up @@ -338,6 +338,12 @@ void CudaBackend::computeStart() {
queue->computeStart();
}

std::string* CudaBackend::getDeviceVendor() {
// The CUDA Backend is owned by NVIDIA. Thus, no need to query
auto *vendor = new std::string("NVIDIA");
return reinterpret_cast<std::string *>(vendor->data());
}

bool CudaBackend::getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength) {
if (config->traceCalls) {
std::cout << "getBufferFromDeviceIfDirty(" << std::hex << reinterpret_cast<long>(memorySegment) << "," <<
Expand Down
20 changes: 13 additions & 7 deletions hat/backends/ffi/cuda/src/main/native/include/cuda_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,32 +104,36 @@ class CudaSource final :public Text {

class CudaBackend final : public Backend {
public:
class CudaQueue final : public Backend::Queue {
class CudaQueue final : public Backend::Queue {
public:
std::thread::id streamCreationThread;
CUstream cuStream;

explicit CudaQueue(Backend *backend);

void init();

void wait() override;

void release() override;
void release() override;

void computeStart() override;
void computeStart() override;

void computeEnd() override;
void computeEnd() override;

void copyToDevice(Buffer *buffer) override;
void copyToDevice(Buffer *buffer) override;

void copyFromDevice(Buffer *buffer) override;
void copyFromDevice(Buffer *buffer) override;

int estimateThreadsPerBlock(int dimensions);

int estimateThreadsPerBlock(int dimensions, int globalSizePerDimension, int localSize);

void dispatch(KernelContext *kernelContext, CompilationUnit::Kernel *kernel) override;


~CudaQueue() override;
};
};

class CudaBuffer final : public Buffer {
public:
Expand Down Expand Up @@ -188,6 +192,8 @@ class CudaQueue final : public Backend::Queue {

explicit CudaBackend(int mode);

std::string *getDeviceVendor() override;

~CudaBackend() override;
static CudaBackend * of(long backendHandle);
static CudaBackend * of(Backend *backend);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
*/
package hat.backend.ffi;


import hat.ComputeContext;
import hat.Config;
import hat.KernelContext;
Expand All @@ -36,13 +35,13 @@
public class MockBackend extends FFIBackend {

public MockBackend(Arena arena, MethodHandles.Lookup lookup) {
super(arena,lookup,"mock_backend", Config.fromIntBits(0));
super(arena, lookup, "mock_backend", Config.fromIntBits(0));
}

@Override
public void computeContextHandoff(ComputeContext computeContext) {
System.out.println("Mock backend received closed closure");
computeContext.computeCallGraph().callDag.entryPoint.funcOp((injectBufferTracking(config(),lookup(),computeContext.computeCallGraph().callDag.entryPoint.funcOp())));
computeContext.computeCallGraph().callDag.entryPoint.funcOp((injectBufferTracking(config(), lookup(), computeContext.computeCallGraph().callDag.entryPoint.funcOp())));
}

@Override
Expand Down
4 changes: 4 additions & 0 deletions hat/backends/ffi/mock/src/main/native/cpp/mock_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ class MockBackend final : public Backend {
std::cout << "mock compute start()" << std::endl;
}

std::string *getDeviceVendor() override {
return new std::string("Mock Vendor");
}

CompilationUnit *compile(int len, char *source) override {
std::cout << "mock compileProgram()" << std::endl;
size_t srcLen = ::strlen(source);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
*/
package hat.backend.ffi;


import hat.ComputeContext;
import hat.Config;
import hat.KernelContext;
Expand All @@ -36,26 +35,28 @@

public class OpenCLBackend extends C99FFIBackend {
public OpenCLBackend(Config config) {
super(Arena.global(), MethodHandles.lookup(),"opencl_backend", config);
super(Arena.global(), MethodHandles.lookup(), "opencl_backend", config);
}

public OpenCLBackend() {
this(Config.fromEnvOrProperty());
}

@Override
public void computeContextHandoff(ComputeContext computeContext) {
computeContext.computeCallGraph().callDag.entryPoint.funcOp(injectBufferTracking(config(),lookup(),computeContext.computeCallGraph().callDag.entryPoint.funcOp()));
computeContext.computeCallGraph().callDag.entryPoint.funcOp(injectBufferTracking(config(), lookup(), computeContext.computeCallGraph().callDag.entryPoint.funcOp()));
}

@Override
public void dispatchKernel(KernelCallGraph kernelCallGraph, KernelContext kernelContext, Object... args) {
CompiledKernel compiledKernel = kernelCallGraphCompiledCodeMap.computeIfAbsent(kernelCallGraph, (_) -> {
String code = createC99(kernelCallGraph, args);
CompiledKernel compiledKernel = kernelCallGraphCompiledCodeMap.computeIfAbsent(kernelCallGraph, (KernelCallGraph _) -> {
String code = createC99(kernelCallGraph, args);
if (config().showCode()) {
System.out.println(code);
IO.println(code);
}
var compilationUnit = backendBridge.compile(code);
if (compilationUnit.ok()) {
var kernel = compilationUnit.getKernel( kernelCallGraph.callDag.entryPoint.method().getName());
var kernel = compilationUnit.getKernel(kernelCallGraph.callDag.entryPoint.method().getName());
return new CompiledKernel(this, kernelCallGraph, kernel, args);
} else {
// TODO: We should capture the log from OpenCL and provide as exception message
Expand All @@ -65,8 +66,9 @@ public void dispatchKernel(KernelCallGraph kernelCallGraph, KernelContext kernel
compiledKernel.dispatch(kernelContext, args);
}

String createC99(KernelCallGraph kernelCallGraph, Object[] args){
return createCode(kernelCallGraph, new OpenCLHATKernelBuilder(kernelCallGraph, new ScopedCodeBuilderContext(kernelCallGraph.lookup(),kernelCallGraph.callDag.entryPoint.funcOp())), args);
String createC99(KernelCallGraph kernelCallGraph, Object[] args) {
kernelCallGraph.setDeviceVendor(backendBridge.getDeviceVendor());
return createCode(kernelCallGraph, new OpenCLHATKernelBuilder(kernelCallGraph, new ScopedCodeBuilderContext(kernelCallGraph.lookup(), kernelCallGraph.callDag.entryPoint.funcOp())), args);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import hat.types.BF16;
import hat.types.F16;
import optkl.OpHelper;
import jdk.incubator.code.Op;
import optkl.codebuilders.CodeBuilder;
import jdk.incubator.code.Value;
import jdk.incubator.code.dialect.core.CoreOp;
Expand All @@ -50,7 +51,7 @@
public class OpenCLHATKernelBuilder extends C99HATKernelBuilder<OpenCLHATKernelBuilder> {

protected OpenCLHATKernelBuilder(KernelCallGraph kernelCallGraph, ScopedCodeBuilderContext scopedCodeBuilderContext) {
super(kernelCallGraph,scopedCodeBuilderContext);
super(kernelCallGraph, scopedCodeBuilderContext);
}

@Override
Expand Down Expand Up @@ -222,6 +223,7 @@ public OpenCLHATKernelBuilder hatF16ToFloatConvOp( HATF16Op.HATF16ToFloatConvOp

// Mapping between API function names and OpenCL intrinsics for the math operations
private static final Map<String, String> MATH_FUNCTIONS = new HashMap<>();

static {
MATH_FUNCTIONS.put("maxf", "max");
MATH_FUNCTIONS.put("maxd", "max");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,11 @@ void OpenCLBackend::computeEnd() {
}
}

std::string* OpenCLBackend::getDeviceVendor() {
const PlatformInfo platformInfo(this);
return new std::string(platformInfo.vendorName);
}

OpenCLBackend::OpenCLProgram *OpenCLBackend::compileProgram(OpenCLSource &openclSource) {
return compileProgram(&openclSource);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ class OpenCLBackend final : public Backend {

void computeEnd() override;

std::string *getDeviceVendor() override;

bool getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength) override;

void shortDeviceInfo() override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
*/
package hat.backend.ffi;


import hat.Config;
import hat.backend.Backend;
import hat.buffer.ArgArray;
import optkl.ifacemapper.MappableIface;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandles;
import java.util.HashMap;
import java.util.Map;
Expand All @@ -55,16 +55,19 @@ public static class KernelBridge {
final FFILib.VoidHandleMethodPtr releaseKernel_MPtr;
String name;
final FFILib.LongHandleLongAddressMethodPtr ndrange_MPtr;

KernelBridge(CompilationUnitBridge compilationUnitBridge, String name, long handle) {
this.compilationUnitBridge = compilationUnitBridge;
this.handle = handle;
this.releaseKernel_MPtr = compilationUnitBridge.backendBridge.ffiLib.voidHandleFunc("releaseKernel");
this.ndrange_MPtr = compilationUnitBridge.backendBridge.ffiLib.longHandleLongAddressFunc("ndrange");
this.name = name;
}

public void ndRange(ArgArray argArray) {
this.ndrange_MPtr.invoke(handle, MappableIface.getMemorySegment(argArray));
}

void release() {
releaseKernel_MPtr.invoke(handle);
}
Expand All @@ -86,12 +89,15 @@ void release() {
this.compilationUnitOK_MPtr = backendBridge.ffiLib.booleanHandleFunc("compilationUnitOK");
this.getKernel_MPtr = backendBridge.ffiLib.longHandleIntAddressFunc("getKernel");
}

void release() {
this.releaseCompilationUnit_MPtr.invoke(handle);
}

boolean ok() {
return this.compilationUnitOK_MPtr.invoke(handle);
}

public KernelBridge getKernel(String kernelName) {
return kernels.computeIfAbsent(kernelName, _ ->
new KernelBridge(this, kernelName,
Expand All @@ -111,6 +117,9 @@ public KernelBridge getKernel(String kernelName) {

final FFILib.VoidHandleMethodPtr showDeviceInfo_MPtr;
final FFILib.BooleanHandleAddressLongMethodPtr getBufferFromDeviceIfDirty_MPtr;
final FFILib.StringFunctionMethodPtr getVendorFunction;
final FFILib.StringFunctionLengthMethodPtr stringFunctionLength;

BackendBridge(FFILib ffiLib, Config config) {
this.ffiLib = ffiLib;
this.getBackend_MPtr = ffiLib.longHandleIntFunc("getBackend");
Expand All @@ -122,10 +131,13 @@ public KernelBridge getKernel(String kernelName) {
this.showDeviceInfo_MPtr = ffiLib.voidHandleFunc("showDeviceInfo");
this.computeStart_MPtr = ffiLib.voidHandleFunc("computeStart");
this.computeEnd_MPtr = ffiLib.voidHandleFunc("computeEnd");
this.getVendorFunction = ffiLib.stringHandleFunc("getDeviceVendor");
this.stringFunctionLength = ffiLib.stringFunctionLengthMethodPtr("getStringLength");
this.getBufferFromDeviceIfDirty_MPtr = ffiLib.booleanHandleAddressLongFunc("getBufferFromDeviceIfDirty");
}

void release() {}
void release() {
}

public long getBackend(int configBits) {
return getBackend_MPtr.invoke(configBits);
Expand All @@ -142,20 +154,29 @@ public CompilationUnitBridge compile(String source) {
return compilationUnit(compilationUnitHandle, source);
}

public Vendor getDeviceVendor() {
MemorySegment vendorNameSegment = getVendorFunction.invoke(handle);
long sizeString = stringFunctionLength.invoke(vendorNameSegment);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We created a memorySegment just for the Vendor Name?

Maybe we can create a MemorSegment with some structure. To receive other vendor info later.

So maybe a MemorySegment containing a list of Name/Value/Type triples

We can a MappedMemorySegment then for all values....

byte[] content = vendorNameSegment.reinterpret(sizeString).toArray(ValueLayout.JAVA_BYTE);
return Vendor.of(new String(content));
}

public MappableIface getBufferFromDeviceIfDirty(MappableIface buffer) {
MemorySegment memorySegment = MappableIface.getMemorySegment(buffer);
if (!getBufferFromDeviceIfDirty_MPtr.invoke(handle, memorySegment, memorySegment.byteSize())){
if (!getBufferFromDeviceIfDirty_MPtr.invoke(handle, memorySegment, memorySegment.byteSize())) {
throw new IllegalStateException("Failed to get buffer from backend");
}
return buffer;

}

public void computeStart() {
computeStart_MPtr.invoke(handle);
}

public void computeEnd() {
computeEnd_MPtr.invoke(handle);
}

public void showDeviceInfo() {
showDeviceInfo_MPtr.invoke(handle);
}
Expand All @@ -164,8 +185,8 @@ public void showDeviceInfo() {
public final FFILib ffiLib;
public final BackendBridge backendBridge;

public FFIBackendDriver(Arena arena, MethodHandles.Lookup lookup,String libName, Config config) {
super(arena,lookup,config);
protected FFIBackendDriver(Arena arena, MethodHandles.Lookup lookup, String libName, Config config) {
super(arena, lookup, config);
this.ffiLib = new FFILib(libName);
this.backendBridge = new BackendBridge(ffiLib, config);
}
Expand Down
Loading