
Commit bf71d53

Simplify CreateKernel overloads, move shaders.h to utils/
1 parent f22c664 · commit bf71d53

5 files changed: +34 -91 lines

CMakeLists.txt

+2 -2

@@ -72,15 +72,15 @@ endif()
 # Build the library target (libgpu)

 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
-set(SRC_LIB gpu.h nn/shaders.h utils/array_utils.h utils/logging.h)
+set(SRC_LIB gpu.h utils/shaders.h utils/array_utils.h utils/logging.h)
 add_library(gpu SHARED ${SRC_LIB})
 set_target_properties(gpu PROPERTIES LINKER_LANGUAGE CXX)

 # For additional targets see directories under `examples/`, which have their own CMakeLists.txt

 # Test of basic kernels

-set(SRC_TESTS utils/test_kernels.cpp gpu.h nn/shaders.h utils/array_utils.h utils/logging.h)
+set(SRC_TESTS utils/test_kernels.cpp gpu.h utils/shaders.h utils/array_utils.h utils/logging.h)
 add_executable(run_tests ${SRC_TESTS})
 target_link_libraries(run_tests PRIVATE ${LIBDL} ${CMAKE_DL_LIBS} webgpu)
 target_include_directories(run_tests PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})

README.md

+19 -34

@@ -30,9 +30,10 @@ Here's an GELU kernel implemented (based on the CUDA implementation of
 invoked from the host using this library.

 ```
+#include "gpu.h"
 #include <array>
 #include <cstdio>
-#include "gpu.h"
+#include <future>

 using namespace gpu; // CreateContext, CreateTensor, CreateKernel,
                      // CreateShader, DispatchKernel, Wait, ToCPU
@@ -42,20 +43,22 @@ static const char *kGelu = R"(
 const GELU_SCALING_FACTOR: f32 = 0.7978845608028654; // sqrt(2.0 / PI)
 @group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
 @group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> dummy: array<{{precision}}>;
 @compute @workgroup_size({{workgroupSize}})
 fn main(
     @builtin(global_invocation_id) GlobalInvocationID: vec3<u32>) {
     let i: u32 = GlobalInvocationID.x;
     if (i < arrayLength(&inp)) {
         let x: f32 = inp[i];
-        // select is more stable for larger values of x
+        // select is more stable than tanh for large x
         out[i] = select(0.5 * x * (1.0 + tanh(GELU_SCALING_FACTOR
-            * (x + .044715 * x * x * x))), x, x > 10.0);
+                 * (x + .044715 * x * x * x))), x, x > 10.0);
     }
 }
 )";

 int main(int argc, char **argv) {
+  printf("\nHello, gpu.cpp\n\n");
   Context ctx = CreateContext();
   static constexpr size_t N = 3072;
   std::array<float, N> inputArr, outputArr;
@@ -64,25 +67,28 @@ int main(int argc, char **argv) {
   }
   Tensor input = CreateTensor(ctx, Shape{N}, kf32, inputArr.data());
   Tensor output = CreateTensor(ctx, Shape{N}, kf32);
-  Kernel op = CreateKernel(ctx, CreateShader(kGelu, 256, kf32), input, output);
-  DispatchKernel(ctx, op);
-  Wait(ctx, op.future);
+  std::promise<void> promise;
+  std::future<void> future = promise.get_future();
+  Kernel op = CreateKernel(ctx, CreateShader(kGelu, 256, kf32), TensorList{input, output},
+                           /* nthreads */ {N, 1, 1});
+  DispatchKernel(ctx, op, promise);
+  Wait(ctx, future);
   ToCPU(ctx, output, outputArr.data(), sizeof(outputArr));
-  for (int i = 0; i < 10; ++i) {
+  for (int i = 0; i < 32; ++i) {
     printf("out[%d] : gelu(%.2f) = %.2f\n", i, inputArr[i], outputArr[i]);
   }
   printf("...\n\n");
   return 0;
 }
 ```

+This example is available in `examples/hello_world/run.cpp`.
+
 For those curious about what happens under the hood with the raw WebGPU API,
 the equivalent functionality is implemented using the WebGPU C API in
 `examples/webgpu_intro/run.cpp`.

-## Quick Start: Building and Running
-
-*Tutorial App*
+## Quick Start: Dependencies and Installation

 The only dependency of this library is a WebGPU implementation. Currently we
 recommend using the Dawn backend until further testing, but we plan to support
@@ -93,32 +99,11 @@ you can install cmake using [homebrew](https://brew.sh/) with: `brew install
 cmake`. On Ubuntu, you can install cmake using `apt-get` with: `sudo apt-get
 install cmake`.

-The build is handled by cmake. Some useful common cmake invocations are wrapped
-in the convenience Makefile. To start you can try building a terminal demo
-tutorial which also tests the functionality of the library, this builds the
-demo tutorial in `run.cpp`:
-
-```
-make demo
-```
-
-You should see an introductory message:
-```
-____ _____ __ __ _________ ____
-/ __ `/ __ \/ / / // ___/ __ \/ __ \
-/ /_/ / /_/ / /_/ // /__/ /_/ / /_/ /
-\__, / .___/\__,_(_)___/ .___/ .___/
-/____/_/ /_/ /_/
-
-================================================================================
-
-Welcome!
---------
+## Quick Start: Building and Running

-This program is a brief intro to the gpu.cpp library.
-...

-```
+The build is handled by cmake. Some useful common cmake invocations are wrapped
+in the convenience Makefile.

 The first time you build and run the project, it will download the WebGPU
 backend implementation (Dawn by default) and build it which may take a few
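
The README example above switches to an explicit promise/future pair around dispatch. As a minimal sketch (not repository code; the helper name is hypothetical, and only the `DispatchKernel(ctx, op, promise)` and `Wait(ctx, future)` calls are taken from the diff), the pattern can be wrapped like this:

```
#include <future>
#include "gpu.h"

// Hypothetical helper (not part of gpu.h): wraps the dispatch-and-wait
// pattern shown in the updated README example.
void DispatchAndWait(gpu::Context &ctx, gpu::Kernel &op) {
  std::promise<void> promise;                    // completion signal owned by the caller
  std::future<void> future = promise.get_future();
  gpu::DispatchKernel(ctx, op, promise);         // submit the kernel asynchronously
  gpu::Wait(ctx, future);                        // block until the GPU signals completion
}
```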

gpu.h

+10 -52

@@ -206,13 +206,11 @@ struct KernelPool {
   KernelPool(Context *ctx) : ctx(ctx), data() {}
   Context *ctx;
   std::set<Kernel *> data;
-  // std::set<MultiKernel *> multiData;
   ~KernelPool() {
     // Note : Some kernel resources such as commandBuffer are harvested by
     // queue submission, explicitly destroying readback and callback buffers
     // produces runtime errors.
     data.clear();
-    // multiData.clear();
   }
 };

@@ -664,8 +662,6 @@ void ResetCommandBuffer(WGPUDevice &device, const Shape &nThreads, Kernel &op) {
     op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
     check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__);
   }
-  // op.promise = std::promise<void>();
-  // op.future = op.promise.get_future();
 }

 /**
@@ -800,7 +796,6 @@ Kernel CreateKernel(Context &ctx, const ShaderCode &shader,
       .entries = bindGroupEntries.data(),
   };
   op.bindGroup = wgpuDeviceCreateBindGroup(device, &bindGroupDesc);
-
   {
     WGPUPipelineLayoutDescriptor pipelineLayoutDesc = {
         .bindGroupLayoutCount = 1,
@@ -833,42 +828,6 @@ Kernel CreateKernel(Context &ctx, const ShaderCode &shader,
   return op;
 }

-/**
- * @brief Overload which wraps the CreateKernel factory function to create a
- * kernel on the GPU with a statically determined ParamsType instead of casting
- * params to a void pointer. paramSize is then determined by the size of the
- * ParamsType.
- *
- * @param[in] ctx Context instance to manage the kernel
- * @param[in] shader Shader code for the kernel
- * @param[in] inputs A span of input tensors as a pointer
- * @param[in] numInputs Number of input tensors, effectively the size of the
- * *inputs span.
- * @param[in] output Output tensor for the kernel
- * @param[in] nThreads Shape of the workgroup size for the kernel, must be of
- * rank 3.
- * @param[in] params Optional parameters for the kernel. If the kernel does not
- * have any parameters, use NoParam.
- * @example Kernel kernel = CreateKernel(ctx, shader, inputs, numInputs, output,
- * nThreads, params);
- */
-template <typename ParamsType = NoParam>
-Kernel CreateKernel(Context &ctx, const ShaderCode &shader,
-                    const Tensor *inputs, size_t numInputs,
-                    const Shape &nThreads,
-                    const ParamsType &params = ParamsType{}) {
-  if constexpr (!IsNoParam<ParamsType>) {
-    log(kDefLog, kInfo, "Using params of size %d bytes", sizeof(ParamsType));
-    return CreateKernel(ctx, shader, inputs, numInputs, nThreads,
-                        reinterpret_cast<const void *>(&params),
-                        sizeof(ParamsType));
-  } else {
-    log(kDefLog, kInfo, "No params");
-    return CreateKernel(ctx, shader, inputs, numInputs, nThreads,
-                        nullptr, 0);
-  }
-}
-
 /**
  * @brief Overload which wraps the CreateKernel factory function to create a
  * kernel on the GPU. This overload uses takes a static collection of input
@@ -892,17 +851,16 @@ Kernel CreateKernel(Context &ctx, const ShaderCode &shader,
                     const TensorList<numInputs> &inputs,
                     const Shape &nThreads,
                     const ParamsType &params = ParamsType{}) {
-  // first .data gets the array, second .data() gets the pointer
-  return CreateKernel<ParamsType>(ctx, shader, inputs.data.data(), numInputs,
-                                  nThreads, params);
-}
-
-// Convenience wrapper: specialization for single input passed by reference
-template <typename ParamsType = NoParam>
-Kernel CreateKernel(Context &ctx, const ShaderCode &shader, const Tensor &input,
-                    const Shape &nThreads,
-                    const ParamsType &params = ParamsType{}) {
-  return CreateKernel(ctx, shader, &input, 1, nThreads, params);
+  if constexpr (!IsNoParam<ParamsType>) {
+    log(kDefLog, kInfo, "Using params of size %d bytes", sizeof(ParamsType));
+    return CreateKernel(ctx, shader, inputs.data.data(), numInputs, nThreads,
+                        reinterpret_cast<const void *>(&params),
+                        sizeof(ParamsType));
+  } else {
+    log(kDefLog, kInfo, "No params");
+    return CreateKernel(ctx, shader, inputs.data.data(), numInputs, nThreads,
+                        nullptr, 0);
+  }
 }

 /**
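
For reference, here is a caller-side sketch of the single parametrized `CreateKernel` overload this commit keeps (the `TensorList` form). It is only a sketch: the `ScaleParams` struct, the shader-template argument, and the helper name are illustrative placeholders rather than repository code; only the call shape and the params/NoParam forwarding mirror the diff above.

```
#include "gpu.h"

using namespace gpu;

// Illustrative params struct; any trivially copyable type works, since the
// overload forwards &params and sizeof(ParamsType) to the void-pointer factory.
struct ScaleParams {
  float scale;
};

// Hypothetical helper showing the consolidated call shape: inputs always go
// through a TensorList (CTAD infers numInputs); omit `params` to take the
// NoParam branch shown in the diff above.
Kernel MakeScaleKernel(Context &ctx, Tensor &input, Tensor &output,
                       const char *shaderTemplate, size_t n) {
  ScaleParams params = {2.0f};
  return CreateKernel(ctx, CreateShader(shaderTemplate, 256, kf32),
                      TensorList{input, output},
                      /* nthreads */ Shape{n, 1, 1}, params);
}
```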

nn/shaders.h → utils/shaders.h

File renamed without changes.

utils/test_kernels.cpp

+3 -3

@@ -3,11 +3,11 @@
 #include <memory>
 #include <random>

-#include "array_utils.h"
 #include "gpu.h"
-#include "nn/shaders.h"
-#include "reference_impls.h"
+#include "utils/array_utils.h"
+#include "utils/reference_impls.h"
 #include "utils/logging.h"
+#include "utils/shaders.h"

 using namespace gpu;
