Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
635 changes: 316 additions & 319 deletions hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaBackend.java

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion hat/backends/ffi/cuda/src/main/native/cpp/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ PtxSource *CudaBackend::nvcc(const CudaSource *cudaSource) {
// create var/cuda directory
std::string localDirectory = "./var/cuda";
std::filesystem::create_directories(localDirectory);
// create temp file for cuda generarated code
// create temp file for cuda generated code
const uint64_t time = timeSinceEpochMillisec();
const std::string ptxPath = tmpFileName(time, localDirectory, ".ptx");
const std::string cudaPath = tmpFileName(time, localDirectory, ".cu");
Expand Down Expand Up @@ -338,6 +338,12 @@ void CudaBackend::computeStart() {
queue->computeStart();
}

// Returns the vendor name of the device driven by this backend.
//
// The CUDA backend only ever targets NVIDIA devices, so the name is a fixed
// literal and no driver query is made.
//
// Returns a pointer to a heap-allocated std::string holding "NVIDIA". This
// function never frees it; NOTE(review): ownership presumably passes to the
// caller - confirm against the Backend interface and its callers.
std::string* CudaBackend::getDeviceVendor() {
    // Hand back the std::string object itself. The previous code returned
    // vendor->data() (a char*) reinterpret_cast to std::string*, so callers
    // dereferencing the result would read raw character bytes as if they were
    // a std::string object (undefined behaviour), and the real object could
    // never be deleted.
    return new std::string("NVIDIA");
}

bool CudaBackend::getBufferFromDeviceIfDirty(void *memorySegment, long memorySegmentLength) {
if (config->traceCalls) {
std::cout << "getBufferFromDeviceIfDirty(" << std::hex << reinterpret_cast<long>(memorySegment) << "," <<
Expand Down
20 changes: 13 additions & 7 deletions hat/backends/ffi/cuda/src/main/native/include/cuda_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,32 +104,36 @@ class CudaSource final :public Text {

class CudaBackend final : public Backend {
public:
class CudaQueue final : public Backend::Queue {
class CudaQueue final : public Backend::Queue {
public:
std::thread::id streamCreationThread;
CUstream cuStream;

explicit CudaQueue(Backend *backend);

void init();

void wait() override;

void release() override;
void release() override;

void computeStart() override;
void computeStart() override;

void computeEnd() override;
void computeEnd() override;

void copyToDevice(Buffer *buffer) override;
void copyToDevice(Buffer *buffer) override;

void copyFromDevice(Buffer *buffer) override;
void copyFromDevice(Buffer *buffer) override;

int estimateThreadsPerBlock(int dimensions);

int estimateThreadsPerBlock(int dimensions, int globalSizePerDimension, int localSize);

void dispatch(KernelContext *kernelContext, CompilationUnit::Kernel *kernel) override;


~CudaQueue() override;
};
};

class CudaBuffer final : public Buffer {
public:
Expand Down Expand Up @@ -188,6 +192,8 @@ class CudaQueue final : public Backend::Queue {

explicit CudaBackend(int mode);

std::string *getDeviceVendor() override;

~CudaBackend() override;
static CudaBackend * of(long backendHandle);
static CudaBackend * of(Backend *backend);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
*/
package hat.backend.ffi;


import hat.ComputeContext;
import hat.Config;
import hat.KernelContext;
Expand All @@ -36,13 +35,13 @@
public class MockBackend extends FFIBackend {

public MockBackend(Arena arena, MethodHandles.Lookup lookup) {
super(arena,lookup,"mock_backend", Config.fromIntBits(0));
super(arena, lookup, "mock_backend", Config.fromIntBits(0));
}

@Override
public void computeContextHandoff(ComputeContext computeContext) {
System.out.println("Mock backend received closed closure");
computeContext.computeCallGraph().callDag.entryPoint.funcOp((injectBufferTracking(config(),lookup(),computeContext.computeCallGraph().callDag.entryPoint.funcOp())));
computeContext.computeCallGraph().callDag.entryPoint.funcOp((injectBufferTracking(config(), lookup(), computeContext.computeCallGraph().callDag.entryPoint.funcOp())));
}

@Override
Expand Down
4 changes: 4 additions & 0 deletions hat/backends/ffi/mock/src/main/native/cpp/mock_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ class MockBackend final : public Backend {
std::cout << "mock compute start()" << std::endl;
}

// Reports the vendor name for the mock backend: a freshly heap-allocated
// std::string containing the fixed text "Mock Vendor". Nothing in this
// method releases the allocation.
std::string *getDeviceVendor() override {
    auto *vendorName = new std::string("Mock Vendor");
    return vendorName;
}

CompilationUnit *compile(int len, char *source) override {
std::cout << "mock compileProgram()" << std::endl;
size_t srcLen = ::strlen(source);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
*/
package hat.backend.ffi;


import hat.ComputeContext;
import hat.Config;
import hat.KernelContext;
Expand All @@ -36,26 +35,28 @@

public class OpenCLBackend extends C99FFIBackend {
public OpenCLBackend(Config config) {
super(Arena.global(), MethodHandles.lookup(),"opencl_backend", config);
super(Arena.global(), MethodHandles.lookup(), "opencl_backend", config);
}

/**
 * Creates an OpenCL backend without an explicit configuration.
 *
 * <p>Delegates to {@link #OpenCLBackend(Config)} with the configuration
 * returned by {@code Config.fromEnvOrProperty()}.
 */
public OpenCLBackend() {
this(Config.fromEnvOrProperty());
}

@Override
public void computeContextHandoff(ComputeContext computeContext) {
computeContext.computeCallGraph().callDag.entryPoint.funcOp(injectBufferTracking(config(),lookup(),computeContext.computeCallGraph().callDag.entryPoint.funcOp()));
computeContext.computeCallGraph().callDag.entryPoint.funcOp(injectBufferTracking(config(), lookup(), computeContext.computeCallGraph().callDag.entryPoint.funcOp()));
}

@Override
public void dispatchKernel(KernelCallGraph kernelCallGraph, KernelContext kernelContext, Object... args) {
CompiledKernel compiledKernel = kernelCallGraphCompiledCodeMap.computeIfAbsent(kernelCallGraph, (_) -> {
String code = createC99(kernelCallGraph, args);
CompiledKernel compiledKernel = kernelCallGraphCompiledCodeMap.computeIfAbsent(kernelCallGraph, (KernelCallGraph _) -> {
String code = createC99(kernelCallGraph, args);
if (config().showCode()) {
System.out.println(code);
IO.println(code);
}
var compilationUnit = backendBridge.compile(code);
if (compilationUnit.ok()) {
var kernel = compilationUnit.getKernel( kernelCallGraph.callDag.entryPoint.method().getName());
var kernel = compilationUnit.getKernel(kernelCallGraph.callDag.entryPoint.method().getName());
return new CompiledKernel(this, kernelCallGraph, kernel, args);
} else {
// TODO: We should capture the log from OpenCL and provide as exception message
Expand All @@ -65,8 +66,9 @@ public void dispatchKernel(KernelCallGraph kernelCallGraph, KernelContext kernel
compiledKernel.dispatch(kernelContext, args);
}

String createC99(KernelCallGraph kernelCallGraph, Object[] args){
return createCode(kernelCallGraph, new OpenCLHATKernelBuilder(kernelCallGraph, new ScopedCodeBuilderContext(kernelCallGraph.lookup(),kernelCallGraph.callDag.entryPoint.funcOp())), args);
String createC99(KernelCallGraph kernelCallGraph, Object[] args) {
kernelCallGraph.setDeviceVendor(backendBridge.getDeviceVendor());
return createCode(kernelCallGraph, new OpenCLHATKernelBuilder(kernelCallGraph, new ScopedCodeBuilderContext(kernelCallGraph.lookup(), kernelCallGraph.callDag.entryPoint.funcOp())), args);
}

}
Loading