From ba3271755fe3dfea4ae0b92f1f026c75077ff31c Mon Sep 17 00:00:00 2001
From: buttfa <1662332017@qq.com>
Date: Fri, 27 Sep 2024 14:37:28 +0800
Subject: [PATCH 01/44] feat: platform-aware lib rule, header-only gpu.hpp,
 install/uninstall rules

1. Modified the 'lib' rule in the root Makefile so that it builds the
   gpu.cpp dynamic library with the suffix appropriate to the host system
   (.so on Linux, .dylib on macOS).
2. Added a build.py script that expands the headers included by gpu.hpp, so
   that the generated build/gpu.hpp is a genuinely header-only source file.
3. Added 'install' and 'uninstall' rules for researchers who would rather
   not deal with packaging, letting them start doing GPU computation with
   gpu.cpp quickly.
---
 Makefile | 47 +++++++++++++++++++++++++++++++++++++++++++++--
 build.py | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 77 insertions(+), 2 deletions(-)
 create mode 100644 build.py

diff --git a/Makefile b/Makefile
index 3f69c6e..35c3b0f 100644
--- a/Makefile
+++ b/Makefile
@@ -19,8 +19,51 @@ pch:
 	mkdir -p build && $(CXX) -std=c++17 $(INCLUDES) -x c++-header gpu.hpp -o build/gpu.hpp.pch
 
 # TODO(avh): change extension based on platform
-lib:
-	mkdir -p build && $(CXX) -std=c++17 $(INCLUDES) -L$(LIBDIR) -ldawn -ldl -shared -fPIC gpu.cpp -o build/libgpucpp.dylib
+# Get the current OS name
+OS = $(shell uname | tr -d '\n')
+LIB_PATH ?= /usr/lib
+HEADER_PATH ?= /usr/include
+# Set the specific variables for each platform
+ifeq ($(OS), Linux)
+	OS_TYPE ?= Linux
+
+	GPU_CPP_LIB_NAME ?= libgpucpp.so
+	DAWN_LIB_NAME ?= libdawn.so
+else ifeq ($(OS), Darwin)
+	OS_TYPE ?= macOS
+
+	GPU_CPP_LIB_NAME ?= libgpucpp.dylib
+	DAWN_LIB_NAME ?= libdawn.dylib
+else
+	OS_TYPE ?= unknown
+endif
+
+lib: check-clang dawnlib
+ifneq ($(OS_TYPE), unknown)
+	mkdir -p build && $(CXX) -std=c++17 $(INCLUDES) -L$(LIBDIR) -ldawn -ldl -shared -fPIC gpu.cpp -o build/$(GPU_CPP_LIB_NAME)
+	python3 build.py
+	cp third_party/lib/$(DAWN_LIB_NAME) build/
+else
+	@echo "Unsupported operating system"
+endif
+
+install:
+ifneq ($(OS_TYPE), unknown)
+	cp build/$(GPU_CPP_LIB_NAME) $(LIB_PATH)
+	cp build/$(DAWN_LIB_NAME) $(LIB_PATH)
+	cp build/gpu.hpp $(HEADER_PATH)
+else
+	@echo "Unsupported operating system"
+endif
+
+uninstall:
+ifneq ($(OS_TYPE), unknown)
+	rm $(LIB_PATH)/$(GPU_CPP_LIB_NAME)
+	rm $(LIB_PATH)/$(DAWN_LIB_NAME)
+	rm $(HEADER_PATH)/gpu.hpp
+else
+	@echo "Unsupported operating system"
+endif
 
 examples/hello_world/build/hello_world: check-clang dawnlib examples/hello_world/run.cpp check-linux-vulkan
 	$(LIBSPEC) && cd examples/hello_world && make build/hello_world && ./build/hello_world

diff --git a/build.py b/build.py
new file mode 100644
index 0000000..ffb5e0d
--- /dev/null
+++ b/build.py
@@ -0,0 +1,32 @@
+# Dictionary of header files and their relative paths
+header_files = {
+    "#include \"webgpu/webgpu.h\"": "third_party/headers/webgpu/webgpu.h",
+    "#include \"numeric_types/half.hpp\"": "numeric_types/half.hpp",
+    "#include \"utils/logging.hpp\"": "utils/logging.hpp"
+}
+
+def main():
+    # File paths
+    source_file_path = "gpu.hpp"
+    output_file_path = "build/gpu.hpp"
+
+    # Open source file and read contents
+    with open(source_file_path, "r") as source:
+        file_contents = source.read()
+
+    # Iterate over the header files
+    for key, value in header_files.items():
+
+        # Replace each include directive with the contents of its header
+        with open(value, "r") as header_file:
+            header_file_contents = header_file.read()
+            file_contents = file_contents.replace(key, header_file_contents)
+
+    # Open output file
+    with open(output_file_path, "w") as output:
+        # Write contents to output file
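+        # Note: this is a plain textual substitution; it assumes each include
+        # directive listed in header_files appears verbatim in gpu.hpp and
+        # that the inlined headers are self-contained.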
output.write(file_contents) + +if __name__ == "__main__": + main() \ No newline at end of file From 7addf836cfd28dd66d461dcff8e4a47c8b91b5d1 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Sun, 13 Oct 2024 10:32:00 +0900 Subject: [PATCH 02/44] Add kShaderMatmul2DTiling in kernels.h --- experimental/kernels/kernels.h | 98 ++++++++++ experimental/kernels/ops.cpp | 62 +++++-- .../unittest_llmc/unittest_kernels.cpp | 171 ++++++++++++++++-- 3 files changed, 299 insertions(+), 32 deletions(-) diff --git a/experimental/kernels/kernels.h b/experimental/kernels/kernels.h index 212c075..1bce081 100644 --- a/experimental/kernels/kernels.h +++ b/experimental/kernels/kernels.h @@ -309,6 +309,104 @@ fn main(@builtin(global_invocation_id) global_id : vec3) { } } } + +)"; + + +static const char *kShaderMatmul2DTiling = R"( +@group(0) @binding(0) var inp : array<{{precision}}>; +@group(0) @binding(1) var weight : array<{{precision}}>; +@group(0) @binding(2) var bias : array<{{precision}}>; +@group(0) @binding(3) var out : array<{{precision}}>; +@group(0) @binding(4) var params : Params; +struct Params { + B: u32, + T: u32, + C: u32, + OC: u32, +}; +var tileInp: array<{{precision}}, {{BT}} * {{BC}}>; +var tileWeight: array<{{precision}}, {{BOC}} * {{BC}}>; + +@compute @workgroup_size({{workgroupSize}}) +fn main( + @builtin(local_invocation_id) localID : vec3, + @builtin(workgroup_id) groupid : vec3) { + let B : u32 = params.B; + let T : u32 = params.T; + let C : u32 = params.C; + let OC : u32 = params.OC; + + var localT: array<{{precision}}, {{TT}}>; + var localOC: array<{{precision}}, {{TOC}}>; + + let outB: u32 = groupid.x; + let outT: u32 = groupid.y; + let outOC: u32 = groupid.z; + let numThread: u32 = ({{BT}} * {{BOC}}) / ({{TT}} * {{TOC}}); + + // position of the first c element computed by the thread + let threadRow: u32 = (localID.x / ({{BOC}} / {{TOC}})) * {{TT}}; + let threadCol: u32 = (localID.x % ({{BOC}} / {{TOC}})) * {{TOC}}; + + // inpPtr and weightPtr are the starting positions of the tiles in a and b, + // incremented in the bkidx loop. + // outPtr is the starting position of the tile in c which is fixed. 
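+  // For the instantiation used in ops.cpp (BT=64, BOC=64, BC=8, hence
+  // TT=TOC=8): each workgroup owns a 64x64 tile of out, runs
+  // (64*64)/(8*8)=64 threads, and each thread accumulates its own 8x8
+  // sub-tile in threadResults while BTxBC and BOCxBC tiles are staged in
+  // workgroup memory on each bkidx step.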
+ + var inpPtr = (outB * T + outT * {{BT}}) * C; // BTC + var weightPtr = outOC * {{BOC}} * C; //OCC + var threadResults: array<{{precision}}, {{TT}} * {{TOC}}>; + let outPtr = (outB * T + outT * {{BT}}) * OC + outOC * {{BOC}}; //BTOC + let biasPtr = outOC * {{BOC}}; + + for (var bkidx: u32 = 0; bkidx < C; bkidx += {{BC}}) { + // Load BC x BOC by numThread(BT * BOC / (TT * TOC)) + // The number of iteration == BC * BOC / (BT * BOC / (TT * TOC)) + for (var idx: u32 = 0; idx < {{NUM_TILEW}}; idx++) { + tileWeight[localID.x + idx * numThread] = weight[weightPtr + ((localID.x + idx * numThread) / {{BC}}) * C + ((localID.x + idx * numThread) % {{BC}})]; + } + weightPtr += {{BC}}; + + // Load tile + // Load BT x BC by numThread(BT * BOC / (TT * TOC)) + // The number of iteration == BT * BC / (BT * BOC / (TT * TOC)) + for (var idx: u32 = 0; idx < {{NUM_TILEI}}; idx++) { + tileInp[localID.x + idx * numThread] = inp[inpPtr + ((localID.x + idx * numThread) / {{BC}}) * C + (localID.x + idx * numThread) % {{BC}}]; + } + inpPtr += {{BC}}; + + workgroupBarrier(); + // Compute tile + for (var dotIdx: u32 = 0; dotIdx < {{BC}}; dotIdx = dotIdx + 1) { + for (var idx: u32 = 0; idx < {{TT}}; idx++) { + localT[idx] = tileInp[(threadRow + idx) * {{BC}} + dotIdx]; + } + for (var idx: u32 = 0; idx < {{TOC}}; idx++) { + localOC[idx] = tileWeight[(threadCol + idx) * {{BC}} + dotIdx]; + } + for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) { + for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) { + threadResults[resIdxT * {{TOC}} + resIdxOC] += localT[resIdxT] * localOC[resIdxOC]; + } + } + } + workgroupBarrier(); + } + + if (arrayLength(&bias) == 1) { + for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) { + for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) { + out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC]; + } + } + } else { + for (var resIdxT: u32 = 0; resIdxT < {{TT}}; resIdxT++) { + for (var resIdxOC: u32 = 0; resIdxOC < {{TOC}}; resIdxOC++) { + out[outPtr + (threadRow + resIdxT) * OC + threadCol + resIdxOC] = threadResults[resIdxT * {{TOC}} + resIdxOC] + bias[biasPtr + threadCol + resIdxOC]; + } + } + } +} )"; static const char *kShaderMatmulBackward = R"( diff --git a/experimental/kernels/ops.cpp b/experimental/kernels/ops.cpp index 67fc679..48edf9b 100644 --- a/experimental/kernels/ops.cpp +++ b/experimental/kernels/ops.cpp @@ -6,6 +6,7 @@ #include "kernels.h" #include "ops.hpp" +#include "experimental/wgsl.h" // loopUnrolling using namespace gpu; @@ -178,18 +179,55 @@ void matmul_forward(Context& ctx, float* out, std::promise promise; std::future future = promise.get_future(); assert ( (b*t) % 256 == 0 ); - Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); - dispatchKernel(ctx, op, promise); - wait(ctx, future); + int version = 1; + if (version == 1){ + static constexpr size_t BT = 64; + static constexpr size_t BC = 8; + static constexpr size_t BOC = 64; + static constexpr size_t TT = BT / BC; + static constexpr size_t TOC = BOC / BC; + size_t num_threads = BT * BOC / (TT * TOC); + Shape wgSize = {num_threads, 1, 1}; // This is the same as BK * BK. 
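+    // (BT * BOC) / (TT * TOC) = (64 * 64) / (8 * 8) = 64 threads per
+    // workgroup, one per TTxTOC sub-tile of the BTxBOC output tile; the
+    // "BK * BK" note above presumably refers to the block constants of the
+    // classic tiled-SGEMM formulation this kernel follows.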
+ Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; + + std::string codeString(kShaderMatmul2DTiling); + replaceAll(codeString, {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(BT)}, + {"{{BC}}", toString(BC)}, + {"{{BOC}}", toString(BOC)}, + {"{{TT}}", toString(TT)}, + {"{{TOC}}", toString(TOC)}, + {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, + {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} + }); + std::string unrolledCode = loopUnrolling(codeString); + Kernel op = createKernel(ctx, {unrolledCode, wgSize, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + nWorkgroups, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }); + dispatchKernel(ctx, op, promise); + wait(ctx, future); + toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + } else { + Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }); + dispatchKernel(ctx, op, promise); + wait(ctx, future); + } toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); } diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp index 37cdcaf..9278163 100644 --- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp +++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp @@ -5,6 +5,7 @@ #include "kernels.h" #include "unittest_llmc/unittest_kernels.h" +#include "experimental/wgsl.h" // loopUnrolling using namespace gpu; // createContext, createTensor, createKernel, // createShader, dispatchKernel, wait, toCPU @@ -51,6 +52,28 @@ using namespace gpu; // createContext, createTensor, createKernel, } \ } +struct DurationTime { + std::chrono::high_resolution_clock::time_point start; + std::chrono::high_resolution_clock::time_point end; + std::chrono::microseconds duration; + std::string src; + bool verbose; + + inline DurationTime(const std::string& src, bool verbose = true) { + this->src = src; + this->verbose = verbose; + start = std::chrono::high_resolution_clock::now(); + } + + inline ~DurationTime() { + end = std::chrono::high_resolution_clock::now(); + duration = std::chrono::duration_cast(end - start); + if (this->verbose) { + printf("Duration(%s): %.1f microseconds\n", src.c_str(), static_cast(duration.count())); + } + } +}; + void ENCODER_FORWARD_GPU(float* out, int* inp, float* wte, float* wpe, int B, int T, int C){ @@ -202,9 +225,25 @@ void LAYERNORM_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, toCPU(ctx, dbias_t, dbias, c * sizeof(float)); } +void matmul_forward_dummy(float* out, + const float* inp, const float* weight, const float* bias, + int B, int T, int C, int OC); + void MATMUL_FORWARD_GPU(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC){ + int version = 2; + bool verbose = false; + bool debug = false; + float *out_exp; + DurationTime duration("matmul_forward_gpu with creating context", verbose); + if (verbose) { + printf("matmul forward: B=%d, T=%d, C=%d, OC=%d, bias=%d\n", B, T, C, OC, bias != NULL); + } + if (debug) { + out_exp = new float[B*T*OC]; + matmul_forward_dummy(out_exp, inp, weight, bias, B, T, C, OC); + } struct MatmulParams { uint32_t B; uint32_t T; @@ -221,26 +260,118 @@ void MATMUL_FORWARD_GPU(float* out, .requiredLimits = &requiredLimits }); - Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32, inp); - 
Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32, weight); - Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32, bias); - Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); - std::promise promise; - std::future future = promise.get_future(); - assert ( (b*t) % 256 == 0 ); - Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); - dispatchKernel(ctx, op, promise); - wait(ctx, future); - toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + { + DurationTime duration("matmul_forward_gpu: before creating tensors", verbose); + Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32, inp); + Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32, weight); + Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32, bias); + Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); + std::promise promise; + std::future future = promise.get_future(); + + if (version == 2) { + DurationTime duration("matmul_forward_gpu: after creating tensors", verbose); + static constexpr size_t BT = 64; + static constexpr size_t BC = 16; + static constexpr size_t BOC = 64; + static constexpr size_t TT = BT / BC; + static constexpr size_t TOC = BOC / BC; + static constexpr size_t num_threads = BT * BOC / (TT * TOC); + Shape wgSize = {num_threads, 1, 1}; + Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; + + std::string codeString(kShaderMatmul2DTiling); + replaceAll(codeString, {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(BT)}, + {"{{BC}}", toString(BC)}, + {"{{BOC}}", toString(BOC)}, + {"{{TT}}", toString(TT)}, + {"{{TOC}}", toString(TOC)}, + {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, + {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} + }); + std::string unrolledCode = loopUnrolling(codeString); + { + DurationTime duration("matmul_forward_gpu: before creating kernels", verbose); + + Kernel op = createKernel(ctx, {unrolledCode, wgSize, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + nWorkgroups, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }); + { + DurationTime duration("matmul_forward_gpu without creating context", verbose); + dispatchKernel(ctx, op, promise); + wait(ctx, future); + toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + } + } + } else if (version == 1) { + Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }); + { + DurationTime duration("matmul_forward_gpu without creating context", verbose); + dispatchKernel(ctx, op, promise); + wait(ctx, future); + toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + } + } else { + DurationTime duration("matmul_forward_cpu", verbose); + matmul_forward_dummy(out, inp, weight, bias, B, T, C, OC); + } + } + + if (debug) { // compare out with out_exp. 
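+    // Compare against the CPU reference with an absolute tolerance of 1e-2;
+    // on the first mismatch, dump the top-left 4x4 corners of inp, weight,
+    // out, and out_exp, then exit.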
+ for (int i = 0; i < B*T*OC; i++) { + if (fabs(out[i] - out_exp[i]) > 1e-2) { + printf("matmul forward: out[%d] = %f, out_exp[%d] = %f\n", i, out[i], i, out_exp[i]); + //Dump the first 4 x 4 elements by table, at first output out, then output out_exp + printf("inp:\n"); + for (int j = 0; j < 4; j++) { + for (int k = 0; k < 4; k++) { + printf("%f ", inp[j * C + k]); + } + printf("\n"); + } + printf("weight:\n"); + for (int j = 0; j < 4; j++) { + for (int k = 0; k < 4; k++) { + printf("%f ", weight[j * OC + k]); + } + printf("\n"); + } + printf("out:\n"); + for (int j = 0; j < 4; j++) { + for (int k = 0; k < 4; k++) { + printf("%f ", out[j * OC + k]); + } + printf("\n"); + } + printf("out_exp:\n"); + for (int j = 0; j < 4; j++) { + for (int k = 0; k < 4; k++) { + printf("%f ", out_exp[j * OC + k]); + } + printf("\n"); + } + exit(1); + } + } + delete[] out_exp; + } } void MATMUL_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, From da1f32d2ce43001b774b05158d560f402173b33e Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Sun, 13 Oct 2024 10:51:59 +0900 Subject: [PATCH 03/44] Reduce matmul-kernel creation time --- experimental/kernels/Makefile | 2 +- experimental/kernels/ops.cpp | 46 +++++++++++++++++++---------------- gpu.hpp | 4 ++- 3 files changed, 29 insertions(+), 23 deletions(-) diff --git a/experimental/kernels/Makefile b/experimental/kernels/Makefile index c233ef5..3f6f763 100644 --- a/experimental/kernels/Makefile +++ b/experimental/kernels/Makefile @@ -95,7 +95,7 @@ build/train_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin build/ops.o: ops.cpp ops.hpp kernels.h llm.c mkdir -p build && $(CXX) $(CXXFLAGS) -c -o $@ $< -build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c +build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c gpt2_webgpu.cpp ops.cpp mkdir -p build $(CC) $(CXXFLAGS) -Illm.c $(LDFLAGS) -o $@ gpt2_webgpu.cpp ops.cpp diff --git a/experimental/kernels/ops.cpp b/experimental/kernels/ops.cpp index 48edf9b..b8ab6ac 100644 --- a/experimental/kernels/ops.cpp +++ b/experimental/kernels/ops.cpp @@ -157,6 +157,29 @@ void layernorm_backward(Context& ctx, float* dinp, float* dweight, float* dbias, toCPU(ctx, dbias_t, dbias, c * sizeof(float)); } +static constexpr size_t MATMUL_BT = 64; +static constexpr size_t MATMUL_BC = 8; +static constexpr size_t MATMUL_BOC = 64; +static constexpr size_t MATMUL_TT = MATMUL_BT / MATMUL_BC; +static constexpr size_t MATMUL_TOC = MATMUL_BOC / MATMUL_BC; +static size_t MATMUL_num_threads = MATMUL_BT * MATMUL_BOC / (MATMUL_TT * MATMUL_TOC); +static Shape MATMUL_wgSize = {MATMUL_num_threads, 1, 1}; +static std::string kShaderMatmul2DTiling_(kShaderMatmul2DTiling); +static std::string kShaderMatmul2D(loopUnrolling( + replaceAll(kShaderMatmul2DTiling_, + {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(MATMUL_BT)}, + {"{{BC}}", toString(MATMUL_BC)}, + {"{{BOC}}", toString(MATMUL_BOC)}, + {"{{TT}}", toString(MATMUL_TT)}, + {"{{TOC}}", toString(MATMUL_TOC)}, + {"{{NUM_TILEI}}", toString(MATMUL_BT * MATMUL_BC / MATMUL_num_threads)}, + {"{{NUM_TILEW}}", toString(MATMUL_BOC * MATMUL_BC / MATMUL_num_threads)} + }) + ) + ); + + void matmul_forward(Context& ctx, float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC){ @@ -181,27 +204,8 @@ void matmul_forward(Context& ctx, float* out, assert ( (b*t) % 256 == 0 ); int version = 1; if (version == 1){ - static constexpr size_t BT = 64; - static constexpr size_t BC = 8; - static constexpr size_t BOC = 64; - static constexpr size_t TT = BT / BC; - static constexpr size_t 
TOC = BOC / BC; - size_t num_threads = BT * BOC / (TT * TOC); - Shape wgSize = {num_threads, 1, 1}; // This is the same as BK * BK. - Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; - - std::string codeString(kShaderMatmul2DTiling); - replaceAll(codeString, {{"{{precision}}", toString(kf32)}, - {"{{BT}}", toString(BT)}, - {"{{BC}}", toString(BC)}, - {"{{BOC}}", toString(BOC)}, - {"{{TT}}", toString(TT)}, - {"{{TOC}}", toString(TOC)}, - {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, - {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} - }); - std::string unrolledCode = loopUnrolling(codeString); - Kernel op = createKernel(ctx, {unrolledCode, wgSize, kf32}, + Shape nWorkgroups = {b, cdiv(T, MATMUL_BT), cdiv(OC, MATMUL_BOC)}; + Kernel op = createKernel(ctx, {kShaderMatmul2D, MATMUL_wgSize, kf32}, Bindings{inp_i, weight_i, bias_i, out_o}, nWorkgroups, /* params */ diff --git a/gpu.hpp b/gpu.hpp index 83fc94b..8b68463 100644 --- a/gpu.hpp +++ b/gpu.hpp @@ -415,12 +415,14 @@ struct KernelCode { * @endcode * "f32"}}); */ -inline void +inline const std::string replaceAll(std::string &str, const std::vector> &reps) { for (const auto &rep : reps) { replaceAll(str, rep.first, rep.second); } + + return str; } /** From dd2a25f530dabb784456e44cf5b1f5049f626bea Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Mon, 14 Oct 2024 04:24:37 +0900 Subject: [PATCH 04/44] Change Kernel to shared_ptr to support cached kernels --- experimental/kernels/Makefile | 4 +- experimental/kernels/ops.cpp | 34 ++++- .../unittest_llmc/unittest_kernels.cpp | 135 ++++++++++-------- gpu.hpp | 135 +++++++++++++----- 4 files changed, 215 insertions(+), 93 deletions(-) diff --git a/experimental/kernels/Makefile b/experimental/kernels/Makefile index 3f6f763..5817e23 100644 --- a/experimental/kernels/Makefile +++ b/experimental/kernels/Makefile @@ -79,7 +79,7 @@ endef build/test_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin mkdir -p build $(call preprocess_file) - $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o + $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o -g build/test_gpt2_with_metal_profiler: llm.c build/unittest_kernels.o gpt2_124M.bin mkdir -p build @@ -90,7 +90,7 @@ build/test_gpt2_with_metal_profiler: llm.c build/unittest_kernels.o gpt2_124M.bi build/train_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin mkdir -p build $(call preprocess_file) - $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o + $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o -g build/ops.o: ops.cpp ops.hpp kernels.h llm.c mkdir -p build && $(CXX) $(CXXFLAGS) -c -o $@ $< diff --git a/experimental/kernels/ops.cpp b/experimental/kernels/ops.cpp index b8ab6ac..b4b7555 100644 --- a/experimental/kernels/ops.cpp +++ b/experimental/kernels/ops.cpp @@ -179,10 +179,34 @@ static std::string kShaderMatmul2D(loopUnrolling( ) ); +struct DurationTime { + std::chrono::high_resolution_clock::time_point start; + std::chrono::high_resolution_clock::time_point end; + std::chrono::microseconds duration; + std::string src; + bool verbose; + + inline DurationTime(const std::string& src, bool verbose = true) { + this->src = src; + this->verbose = verbose; + start = std::chrono::high_resolution_clock::now(); + } + + inline ~DurationTime() { + end = std::chrono::high_resolution_clock::now(); + duration = std::chrono::duration_cast(end - start); + if (this->verbose) { + printf("Duration(%s): %.1f microseconds\n", src.c_str(), 
static_cast(duration.count())); + } + } +}; + void matmul_forward(Context& ctx, float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC){ + bool verbose = false; + DurationTime duration("matmul_forward_gpu", verbose); struct MatmulParams { uint32_t B; uint32_t T; @@ -204,6 +228,7 @@ void matmul_forward(Context& ctx, float* out, assert ( (b*t) % 256 == 0 ); int version = 1; if (version == 1){ + DurationTime duration("matmul_forward_gpu without creating tensors", verbose); Shape nWorkgroups = {b, cdiv(T, MATMUL_BT), cdiv(OC, MATMUL_BOC)}; Kernel op = createKernel(ctx, {kShaderMatmul2D, MATMUL_wgSize, kf32}, Bindings{inp_i, weight_i, bias_i, out_o}, @@ -215,9 +240,12 @@ void matmul_forward(Context& ctx, float* out, static_cast(c), static_cast(oc) }); - dispatchKernel(ctx, op, promise); - wait(ctx, future); - toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + { + DurationTime duration("matmul_forward_gpu without creating kernel", verbose); + dispatchKernel(ctx, op, promise); + wait(ctx, future); + toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + } } else { Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, Bindings{inp_i, weight_i, bias_i, out_o}, diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp index 9278163..4593bf5 100644 --- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp +++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include "kernels.h" #include "unittest_llmc/unittest_kernels.h" @@ -229,6 +230,31 @@ void matmul_forward_dummy(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC); +static WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; +static Context ctx = createContext({},{},{ + .requiredLimits = &requiredLimits + }); + +static constexpr size_t BT = 64; +static constexpr size_t BC = 16; +static constexpr size_t BOC = 64; +static constexpr size_t TT = BT / BC; +static constexpr size_t TOC = BOC / BC; +static constexpr size_t num_threads = BT * BOC / (TT * TOC); +static Shape wgSize = {num_threads, 1, 1}; + +static std::string codeString(kShaderMatmul2DTiling); +static std::string unrolledCode = loopUnrolling(replaceAll(codeString, {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(BT)}, + {"{{BC}}", toString(BC)}, + {"{{BOC}}", toString(BOC)}, + {"{{TT}}", toString(TT)}, + {"{{TOC}}", toString(TOC)}, + {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, + {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} + })); + + void MATMUL_FORWARD_GPU(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC){ @@ -255,55 +281,52 @@ void MATMUL_FORWARD_GPU(float* out, unsigned long c = static_cast(C); unsigned long oc = static_cast(OC); setLogLevel(kError); - WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; - Context ctx = createContext({},{},{ - .requiredLimits = &requiredLimits - }); { DurationTime duration("matmul_forward_gpu: before creating tensors", verbose); - Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32, weight); - Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32, bias); - Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); + // Generate the key of the cache by arguments. 
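+    // Note: the key encodes only the shapes (B, T, C, OC); whether bias is
+    // null is fixed on the first call for a given shape and baked into the
+    // cached kernel's bindings.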
+ std::string key = std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; + Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32); + Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32); + Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); + op = createKernel(ctx, {unrolledCode, wgSize, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + nWorkgroups, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }, + nullptr, + key.c_str() + ); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_i = ctx.pool.data[op->buffers[0]]; + Tensor& weight_i = ctx.pool.data[op->buffers[1]]; + Tensor& bias_i = ctx.pool.data[op->buffers[2]]; + Tensor& out_o = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, inp_i); + toGPU(ctx, weight, weight_i); + if (bias != NULL) { + toGPU(ctx, bias, bias_i); + } + std::promise promise; std::future future = promise.get_future(); if (version == 2) { DurationTime duration("matmul_forward_gpu: after creating tensors", verbose); - static constexpr size_t BT = 64; - static constexpr size_t BC = 16; - static constexpr size_t BOC = 64; - static constexpr size_t TT = BT / BC; - static constexpr size_t TOC = BOC / BC; - static constexpr size_t num_threads = BT * BOC / (TT * TOC); - Shape wgSize = {num_threads, 1, 1}; - Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; - - std::string codeString(kShaderMatmul2DTiling); - replaceAll(codeString, {{"{{precision}}", toString(kf32)}, - {"{{BT}}", toString(BT)}, - {"{{BC}}", toString(BC)}, - {"{{BOC}}", toString(BOC)}, - {"{{TT}}", toString(TT)}, - {"{{TOC}}", toString(TOC)}, - {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, - {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} - }); - std::string unrolledCode = loopUnrolling(codeString); { DurationTime duration("matmul_forward_gpu: before creating kernels", verbose); - - Kernel op = createKernel(ctx, {unrolledCode, wgSize, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - nWorkgroups, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); { DurationTime duration("matmul_forward_gpu without creating context", verbose); dispatchKernel(ctx, op, promise); @@ -311,23 +334,23 @@ void MATMUL_FORWARD_GPU(float* out, toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); } } - } else if (version == 1) { - Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); - { - DurationTime duration("matmul_forward_gpu without creating context", verbose); - dispatchKernel(ctx, op, promise); - wait(ctx, future); - toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); - } + // } else if (version == 1) { + // Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, + // Bindings{inp_i, weight_i, bias_i, out_o}, + // /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + // /* params */ + // MatmulParams{ + // static_cast(b), + // static_cast(t), + // static_cast(c), + // static_cast(oc) + // }); + // { + // DurationTime duration("matmul_forward_gpu without creating context", verbose); + // 
dispatchKernel(ctx, op, promise); + // wait(ctx, future); + // toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); + // } } else { DurationTime duration("matmul_forward_cpu", verbose); matmul_forward_dummy(out, inp, weight, bias, B, T, C, OC); diff --git a/gpu.hpp b/gpu.hpp index 8b68463..9c34d78 100644 --- a/gpu.hpp +++ b/gpu.hpp @@ -454,7 +454,7 @@ struct CopyData { * The struct members can be divided into "consumed upon dispatch" * (commandBuffer) and reusable ahead-of-time setup (all other members). */ -struct Kernel { +struct RawKernel { std::unique_ptr buffers; // non-owning std::unique_ptr bufferSizes; size_t numBindings; @@ -462,8 +462,11 @@ struct Kernel { WGPUBindGroup bindGroup; // persists between submission WGPUComputePipeline computePipeline; // persists between submission WGPUCommandBuffer commandBuffer; // destroyed upon submission + bool used; }; +typedef std::shared_ptr Kernel; + /** * @brief A struct to package the result of a WGSL code compilation. @@ -483,7 +486,7 @@ struct CompilationInfo { * @return True if lhs < rhs, false otherwise */ inline bool operator<(const Kernel &lhs, const Kernel &rhs) { - return lhs.commandBuffer < rhs.commandBuffer; + return lhs->commandBuffer < rhs->commandBuffer; } /** @@ -494,7 +497,7 @@ inline bool operator<(const Kernel &lhs, const Kernel &rhs) { struct KernelPool { inline KernelPool(Context *ctx) : ctx(ctx), data() {} Context *ctx; - std::set data; + std::unordered_map data; inline ~KernelPool() { // Note : Some kernel resources such as commandBuffer are harvested by // queue submission, explicitly destroying readback and callback buffers @@ -1080,6 +1083,57 @@ void toCPU(Context &ctx, Tensor &tensor, std::array &data) { toCPU(ctx, tensor, data.data(), sizeof(data)); } +inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data, + size_t size) { + uint64_t bufferSize = size; + CopyData op; + op.future = op.promise.get_future(); + { + WGPUBufferDescriptor readbackBufferDescriptor = { + .usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead, + .size = bufferSize, + }; + op.readbackBuffer = + wgpuDeviceCreateBuffer(ctx.device, &readbackBufferDescriptor); + } + { + WGPUCommandEncoder commandEncoder; + WGPUComputePassEncoder computePassEncoder; + commandEncoder = wgpuDeviceCreateCommandEncoder(ctx.device, nullptr); + wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, buffer, 0, + op.readbackBuffer, 0, bufferSize); + op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr); + check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__); + } + wgpuQueueSubmit(ctx.queue, 1, &op.commandBuffer); + CallbackData callbackData = {op.readbackBuffer, bufferSize, data, &op.promise, + &op.future}; + wgpuQueueOnSubmittedWorkDone( + ctx.queue, + [](WGPUQueueWorkDoneStatus status, void *callbackData) { + check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done", + __FILE__, __LINE__); + const auto *data = static_cast(callbackData); + wgpuBufferMapAsync( + data->buffer, WGPUMapMode_Read, 0, data->bufferSize, + [](WGPUBufferMapAsyncStatus status, void *captureData) { + const auto *data = static_cast(captureData); + check(status == WGPUBufferMapAsyncStatus_Success, + "Map readbackBuffer", __FILE__, __LINE__); + const void *mappedData = wgpuBufferGetConstMappedRange( + data->buffer, /*offset=*/0, data->bufferSize); + check(mappedData, "Get mapped range", __FILE__, __LINE__); + memcpy(data->output, mappedData, data->bufferSize); + wgpuBufferUnmap(data->buffer); + data->promise->set_value(); + }, + callbackData); + }, + 
&callbackData); + wait(ctx, op.future); +} + + /** * @brief Copies data from CPU memory to a GPU buffer. The toGPU overloads are * effectively a convenience wrapper around the WebGPU API call @@ -1126,8 +1180,8 @@ inline void toGPU(Context &ctx, Params ¶ms, Kernel &op) { // TODO(avh): Maintain params metadata in Kernel and check for consistency. // If a kernel does not have parameters this will quietly overwrite // the last buffer in the bind group with the parameters buffer. - if (op.numBindings > 0) { - wgpuQueueWriteBuffer(ctx.queue, op.buffers[op.numBindings - 1], 0, + if (op->numBindings > 0) { + wgpuQueueWriteBuffer(ctx.queue, op->buffers[op->numBindings - 1], 0, static_cast(¶ms), sizeof(params)); } } @@ -1150,14 +1204,15 @@ inline void resetCommandBuffer(WGPUDevice &device, Kernel &op) { wgpuDeviceCreateCommandEncoder(device, nullptr); WGPUComputePassEncoder computePassEncoder = wgpuCommandEncoderBeginComputePass(commandEncoder, nullptr); - wgpuComputePassEncoderSetPipeline(computePassEncoder, op.computePipeline); - wgpuComputePassEncoderSetBindGroup(computePassEncoder, 0, op.bindGroup, 0, + wgpuComputePassEncoderSetPipeline(computePassEncoder, op->computePipeline); + wgpuComputePassEncoderSetBindGroup(computePassEncoder, 0, op->bindGroup, 0, nullptr); wgpuComputePassEncoderDispatchWorkgroups( - computePassEncoder, op.totalWorkgroups[0], op.totalWorkgroups[1], - op.totalWorkgroups[2]); + computePassEncoder, op->totalWorkgroups[0], op->totalWorkgroups[1], + op->totalWorkgroups[2]); wgpuComputePassEncoderEnd(computePassEncoder); - op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr); + op->commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr); + op->used = false; } } @@ -1219,11 +1274,19 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, const size_t *viewOffsets, const Shape &totalWorkgroups, const void *params = nullptr, size_t paramsSize = 0, - CompilationInfo* compilationInfo = nullptr) { + CompilationInfo* compilationInfo = nullptr, + const char* cacheKey = nullptr) { + // Create a cache key by the pointer values of the data bindings and the kernel code + if (cacheKey != nullptr && ctx.kernelPool.data.find(cacheKey) != ctx.kernelPool.data.end()) { + LOG(kDefLog, kInfo, "Kernel cache hit"); + return ctx.kernelPool.data[cacheKey]; + } + assert(totalWorkgroups.rank == 3); WGPUDevice device = ctx.device; WGPUQueue queue = ctx.queue; - Kernel op; + Kernel op(new RawKernel()); + // paramIndex is the index into bgLayoutEntries for the parameters buffer If // there are no parameters for the kernel, paramsSize == 0 and paramIndex is // effectively undefined (== -1) @@ -1236,9 +1299,9 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, // op.buffers, op.bufferSizes and // bgLayoutEntries } - op.buffers = std::make_unique(numBindings); - op.bufferSizes = std::make_unique(numBindings); - op.numBindings = numBindings; + op->buffers = std::make_unique(numBindings); + op->bufferSizes = std::make_unique(numBindings); + op->numBindings = numBindings; std::vector bgLayoutEntries(numBindings); // Create layout entries for input buffers for (size_t i = 0; i < numTensors; ++i) { @@ -1272,8 +1335,8 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, WGPUBindGroupLayout bgLayout = wgpuDeviceCreateBindGroupLayout(device, &bgLayoutDesc); for (size_t i = 0; i < numTensors; ++i) { - op.buffers[i] = dataBindings[i].data.buffer; - op.bufferSizes[i] = dataBindings[i].data.size; + op->buffers[i] = dataBindings[i].data.buffer; + 
op->bufferSizes[i] = dataBindings[i].data.size; } // Create a buffer for the Params struct if (paramsSize > 0) { @@ -1282,9 +1345,9 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, .size = paramsSize, .mappedAtCreation = false, }; - op.buffers[paramIndex] = wgpuDeviceCreateBuffer(device, ¶msBufferDesc); - op.bufferSizes[paramIndex] = paramsSize; - wgpuQueueWriteBuffer(queue, op.buffers[paramIndex], 0, params, paramsSize); + op->buffers[paramIndex] = wgpuDeviceCreateBuffer(device, ¶msBufferDesc); + op->bufferSizes[paramIndex] = paramsSize; + wgpuQueueWriteBuffer(queue, op->buffers[paramIndex], 0, params, paramsSize); LOG(kDefLog, kTrace, "Params buffer written"); } else { LOG(kDefLog, kTrace, "No params buffer needed"); @@ -1293,9 +1356,9 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, for (size_t i = 0; i < numTensors; ++i) { bindGroupEntries[i] = WGPUBindGroupEntry{ .binding = static_cast(i), - .buffer = op.buffers[i], + .buffer = op->buffers[i], .offset = viewOffsets[i], - .size = op.bufferSizes[i], + .size = op->bufferSizes[i], }; } if (paramsSize > 0) { @@ -1303,7 +1366,7 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, LOG(kDefLog, kInfo, "paramIndex: %d", paramIndex); bindGroupEntries[paramIndex] = WGPUBindGroupEntry{ .binding = static_cast(paramIndex), - .buffer = op.buffers[paramIndex], + .buffer = op->buffers[paramIndex], .offset = 0, .size = paramsSize, }; @@ -1314,7 +1377,7 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, .entryCount = static_cast(numBindings), .entries = bindGroupEntries.data(), }; - op.bindGroup = wgpuDeviceCreateBindGroup(device, &bindGroupDesc); + op->bindGroup = wgpuDeviceCreateBindGroup(device, &bindGroupDesc); WGPUPipelineLayoutDescriptor pipelineLayoutDesc = { .bindGroupLayoutCount = 1, @@ -1336,12 +1399,13 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, computePipelineDesc.compute.entryPoint = code.entryPoint.c_str(); computePipelineDesc.label = code.label.c_str(); - op.computePipeline = + op->computePipeline = wgpuDeviceCreateComputePipeline(device, &computePipelineDesc); - op.totalWorkgroups = {totalWorkgroups[0], totalWorkgroups[1], totalWorkgroups[2]}; + op->totalWorkgroups = {totalWorkgroups[0], totalWorkgroups[1], totalWorkgroups[2]}; resetCommandBuffer(device, op); - ctx.kernelPool.data.insert(&op); - + if (cacheKey != nullptr) + ctx.kernelPool.data[cacheKey]=op; + WGPUCompilationInfoCallback cb = [](WGPUCompilationInfoRequestStatus status, WGPUCompilationInfo const *compilationInfo, void *userData) { @@ -1396,17 +1460,20 @@ Kernel createKernel(Context &ctx, const KernelCode &code, const Bindings &dataBindings, const Shape &totalWorkgroups, const ParamsType ¶ms = ParamsType{}, - CompilationInfo* compilationInfo = nullptr + CompilationInfo* compilationInfo = nullptr, + const char* cacheKey = nullptr ) { if constexpr (!IsNoParam) { return createKernel(ctx, code, dataBindings.data.data(), numInputs, dataBindings.viewOffsets.data(), totalWorkgroups, reinterpret_cast(¶ms), - sizeof(ParamsType), compilationInfo); + sizeof(ParamsType), compilationInfo, + cacheKey); } else { return createKernel(ctx, code, dataBindings.data.data(), numInputs, dataBindings.viewOffsets.data(), totalWorkgroups, nullptr, - 0, compilationInfo); + 0, compilationInfo, + cacheKey); } } @@ -1431,7 +1498,11 @@ Kernel createKernel(Context &ctx, const KernelCode &code, inline void dispatchKernel(Context &ctx, Kernel &kernel, std::promise &promise) { // Submit the command buffer - 
wgpuQueueSubmit(ctx.queue, 1, &kernel.commandBuffer); + if (kernel->used) { + resetCommandBuffer(ctx.device, kernel); + } + wgpuQueueSubmit(ctx.queue, 1, &kernel->commandBuffer); + kernel->used = true; wgpuQueueOnSubmittedWorkDone( ctx.queue, [](WGPUQueueWorkDoneStatus status, void *data) { From 9fff8cd39a9e89a7b4633f97b788a67c232d0658 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Wed, 16 Oct 2024 17:43:16 +0900 Subject: [PATCH 05/44] Add caches for ops.cpp --- experimental/kernels/ops.cpp | 639 ++++++++++++++++++++++++----------- gpu.hpp | 5 + 2 files changed, 448 insertions(+), 196 deletions(-) diff --git a/experimental/kernels/ops.cpp b/experimental/kernels/ops.cpp index b4b7555..6390314 100644 --- a/experimental/kernels/ops.cpp +++ b/experimental/kernels/ops.cpp @@ -23,27 +23,39 @@ void encoder_forward(Context& ctx, float* out, uint32_t C; }; setLogLevel(kError); - printf("Creating tensors\n"); - printf("Creating input tensor\%pn", inp); - Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp); - printf("Created input tensor\n"); - Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32, wte); - printf("Created wte tensor\n"); - Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32, wpe); - printf("Created wpe tensor\n"); - Tensor output = createTensor(ctx, Shape{b * t * c}, kf32); - printf("Created tensors\n"); + // Generate the key of the cache by arguments. + std::string key = "encoder_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, Shape{b * t}, ki32); + Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32); + Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32); + Tensor output = createTensor(ctx, Shape{b * t * c}, kf32); + op = createKernel(ctx, {kShaderEncoder, 256, kf32}, + Bindings{input, wte_t, wpe_t, output}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + EncoderParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& wte_t = ctx.pool.data[op->buffers[1]]; + Tensor& wpe_t = ctx.pool.data[op->buffers[2]]; + Tensor& output = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, input); + toGPU(ctx, wte, wte_t); + toGPU(ctx, wpe, wpe_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderEncoder, 256, kf32}, - Bindings{input, wte_t, wpe_t, output}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - EncoderParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, out, b * t * c * sizeof(float)); @@ -62,21 +74,38 @@ void encoder_backward(Context& ctx, float* dwte, float* dwpe, uint32_t C; }; setLogLevel(kError); - Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32, dwte); - Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32, dwpe); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp); + // Generate the key of the cache by arguments. 
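+  // Cached kernels keep their tensors alive in ctx.pool, so each call
+  // re-uploads inputs with toGPU and reads results back with toCPU instead
+  // of recreating tensors.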
+ std::string key = "encoder_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32); + Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor input = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32}, + Bindings{dwte_t, dwpe_t, dout_t, input}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + EncoderParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dwte_t = ctx.pool.data[op->buffers[0]]; + Tensor& dwpe_t = ctx.pool.data[op->buffers[1]]; + Tensor& dout_t = ctx.pool.data[op->buffers[2]]; + Tensor& input = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, input); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32}, - Bindings{dwte_t, dwpe_t, dout_t, input}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - EncoderParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dwte_t, dwte, v * c * sizeof(float)); @@ -95,23 +124,43 @@ void layernorm_forward(Context& ctx, float* out, float* mean, float* rstd, uint32_t C; }; setLogLevel(kError); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight); - Tensor bias_t = createTensor(ctx, Shape{c}, kf32, bias); - Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); - Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); - Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + // Generate the key of the cache by arguments. 
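+  // The op-name prefix keeps keys distinct across ops that share the same
+  // (B, T, C) signature.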
+ std::string key = "layernorm_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{c}, kf32); + Tensor bias_t = createTensor(ctx, Shape{c}, kf32); + Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + op = createKernel(ctx, {kShaderLayerNorm, 256, kf32}, + Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + LayerNormParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_t = ctx.pool.data[op->buffers[0]]; + Tensor& weight_t = ctx.pool.data[op->buffers[1]]; + Tensor& bias_t = ctx.pool.data[op->buffers[2]]; + Tensor& out_t = ctx.pool.data[op->buffers[3]]; + Tensor& mean_t = ctx.pool.data[op->buffers[4]]; + Tensor& rstd_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + toGPU(ctx, bias, bias_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderLayerNorm, 256, kf32}, - Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - LayerNormParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, out_t, out, b * t * c * sizeof(float)); @@ -131,25 +180,49 @@ void layernorm_backward(Context& ctx, float* dinp, float* dweight, float* dbias, uint32_t C; }; setLogLevel(kError); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp); - Tensor dweight_t = createTensor(ctx, Shape{c}, kf32, dweight); - Tensor dbias_t = createTensor(ctx, Shape{c}, kf32, dbias); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight); - Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32, mean); - Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32, rstd); + // Generate the key of the cache by arguments. 
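+  // Eight bindings (dinp, dweight, dbias, dout, inp, weight, mean, rstd)
+  // are fetched back from ctx.pool below via op->buffers[0..7].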
+ std::string key = "layernorm_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor dweight_t = createTensor(ctx, Shape{c}, kf32); + Tensor dbias_t = createTensor(ctx, Shape{c}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{c}, kf32); + Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32}, + Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + LayerNormParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dweight_t = ctx.pool.data[op->buffers[1]]; + Tensor& dbias_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + Tensor& mean_t = ctx.pool.data[op->buffers[6]]; + Tensor& rstd_t = ctx.pool.data[op->buffers[7]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + toGPU(ctx, mean, mean_t); + toGPU(ctx, rstd, rstd_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32}, - Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - LayerNormParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float)); @@ -219,46 +292,51 @@ void matmul_forward(Context& ctx, float* out, unsigned long oc = static_cast(OC); setLogLevel(kError); - Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32, weight); - Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32, bias); - Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); + // Generate the key of the cache by arguments. + std::string key = "matmul_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Shape nWorkgroups = {b, cdiv(T, MATMUL_BT), cdiv(OC, MATMUL_BOC)}; + Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32); + Tensor bias_i = bias == NULL ? 
createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32); + Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); + op = createKernel(ctx, {kShaderMatmul2D, MATMUL_wgSize, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + nWorkgroups, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_i = ctx.pool.data[op->buffers[0]]; + Tensor& weight_i = ctx.pool.data[op->buffers[1]]; + Tensor& bias_i = ctx.pool.data[op->buffers[2]]; + Tensor& out_o = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, inp_i); + toGPU(ctx, weight, weight_i); + if (bias != NULL) { + toGPU(ctx, bias, bias_i); + } + std::promise promise; std::future future = promise.get_future(); - assert ( (b*t) % 256 == 0 ); - int version = 1; - if (version == 1){ + { DurationTime duration("matmul_forward_gpu without creating tensors", verbose); - Shape nWorkgroups = {b, cdiv(T, MATMUL_BT), cdiv(OC, MATMUL_BOC)}; - Kernel op = createKernel(ctx, {kShaderMatmul2D, MATMUL_wgSize, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - nWorkgroups, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); { DurationTime duration("matmul_forward_gpu without creating kernel", verbose); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); } - } else { - Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); - dispatchKernel(ctx, op, promise); - wait(ctx, future); } toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); } @@ -277,24 +355,44 @@ void matmul_backward(Context& ctx, float* dinp, float* dweight, float* dbias, unsigned long c = static_cast(C); unsigned long oc = static_cast(OC); setLogLevel(kError); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp); - Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32, dweight); - Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32, dbias); - Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32, weight); + // Generate the key of the cache by arguments. 
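+  // OC is part of the key because the weight, dweight, and dout shapes
+  // depend on it.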
+ std::string key = "matmul_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32); + Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32); + op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32}, + Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dweight_t = ctx.pool.data[op->buffers[1]]; + Tensor& dbias_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32}, - Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float)); @@ -316,22 +414,38 @@ void attention_forward(Context& ctx, float* out, float* preatt, float* att, unsigned long c = static_cast(C); unsigned long nh = static_cast(NH); setLogLevel(kError); - Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp); - Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, preatt); - Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att); - Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + // Generate the key of the cache by arguments. 
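+  // NH is part of the key; preatt and att are B*NH*T*T scratch tensors
+  // allocated once per shape and reused across calls.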
+ std::string key = "attention_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32); + Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + op = createKernel(ctx, {kShaderAttention, 256, kf32}, + Bindings{inp_t, preatt_t, att_t, out_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + AttentionParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(nh) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_t = ctx.pool.data[op->buffers[0]]; + Tensor& preatt_t = ctx.pool.data[op->buffers[1]]; + Tensor& att_t = ctx.pool.data[op->buffers[2]]; + Tensor& out_t = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, inp_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderAttention, 256, kf32}, - Bindings{inp_t, preatt_t, att_t, out_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - AttentionParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(nh) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, preatt_t, preatt, b * nh * t * t * sizeof(float)); @@ -353,24 +467,44 @@ void attention_backward(Context& ctx, float* dinp, float* dpreatt, float* datt, unsigned long c = static_cast(C); unsigned long nh = static_cast(NH); setLogLevel(kError); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, dinp); - Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, dpreatt); - Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, datt); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp); - Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att); + // Generate the key of the cache by arguments. 
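+  // The backward pass also re-uploads the forward activations (inp, att)
+  // alongside dout before dispatch.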
+ std::string key = "attention_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32); + Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32); + Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32}, + Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + AttentionParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(nh) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dpreatt_t = ctx.pool.data[op->buffers[1]]; + Tensor& datt_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& att_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, att, att_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32}, - Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - AttentionParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(nh) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * 3 * sizeof(float)); @@ -381,13 +515,28 @@ void attention_backward(Context& ctx, float* dinp, float* dpreatt, float* datt, void gelu_forward(Context& ctx, float* out, float* inp, int n) { unsigned long N = static_cast(n); setLogLevel(kError); - Tensor input = createTensor(ctx, Shape{N}, kf32, inp); - Tensor output = createTensor(ctx, Shape{N}, kf32); + // Generate the key of the cache by arguments. + std::string key = "gelu_forward_" + std::to_string(n); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, Shape{N}, kf32); + Tensor output = createTensor(ctx, Shape{N}, kf32); + op = createKernel(ctx, {kShaderGelu, 256, kf32}, + Bindings{input, output}, + /* nWorkgroups */ {cdiv(N, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& output = ctx.pool.data[op->buffers[1]]; + + toGPU(ctx, inp, input); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderGelu, 256, kf32}, - Bindings{input, output}, - /* nWorkgroups */ {cdiv(N, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, out, N * sizeof(float)); @@ -396,14 +545,31 @@ void gelu_forward(Context& ctx, float* out, float* inp, int n) { void gelu_backward(Context& ctx, float* dinp, float* inp, float* dout, int N){ unsigned long n = static_cast(N); setLogLevel(kError); - Tensor inp_i = createTensor(ctx, Shape{n}, kf32, inp); - Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout); - Tensor dinp_o = createTensor(ctx, Shape{n}, kf32, dinp); + // Generate the key of the cache by arguments. 
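
Once the kernel is cached, the steady-state call path is short. Stripped to
its skeleton for gelu_forward (buffer indices follow the Bindings order used
at creation; angle brackets restored as noted earlier):

    Tensor& input  = ctx.pool.data[op->buffers[0]];  // Bindings{input, output}
    Tensor& output = ctx.pool.data[op->buffers[1]];
    toGPU(ctx, inp, input);            // inputs must be refreshed on every call
    std::promise<void> promise;
    std::future<void> future = promise.get_future();
    dispatchKernel(ctx, op, promise);
    wait(ctx, future);
    toCPU(ctx, output, out, N * sizeof(float));      // copy the result back
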
+ std::string key = "gelu_backward_" + std::to_string(N); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_i = createTensor(ctx, Shape{n}, kf32); + Tensor dout_i = createTensor(ctx, Shape{n}, kf32); + Tensor dinp_o = createTensor(ctx, Shape{n}, kf32); + op = createKernel(ctx, {kShaderGeluBackward, 256, kf32}, + Bindings{inp_i, dout_i, dinp_o}, + /* nWorkgroups */ {cdiv(n, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_i = ctx.pool.data[op->buffers[0]]; + Tensor& dout_i = ctx.pool.data[op->buffers[1]]; + Tensor& dinp_o = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, inp, inp_i); + toGPU(ctx, dout, dout_i); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderGeluBackward, 256, kf32}, - Bindings{inp_i, dout_i, dinp_o}, - /* nWorkgroups */ {cdiv(n, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_o, dinp, n * sizeof(float)); @@ -412,14 +578,31 @@ void gelu_backward(Context& ctx, float* dinp, float* inp, float* dout, int N){ void residual_forward(Context& ctx, float* out, float* inp1, float* inp2, int N){ unsigned long n = static_cast(N); setLogLevel(kError); - Tensor inp1_i = createTensor(ctx, Shape{n}, kf32, inp1); - Tensor inp2_i = createTensor(ctx, Shape{n}, kf32, inp2); - Tensor out_o = createTensor(ctx, Shape{n}, kf32); + // Generate the key of the cache by arguments. + std::string key = "residual_forward_" + std::to_string(N); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp1_i = createTensor(ctx, Shape{n}, kf32); + Tensor inp2_i = createTensor(ctx, Shape{n}, kf32); + Tensor out_o = createTensor(ctx, Shape{n}, kf32); + op = createKernel(ctx, {kShaderResidual, 256, kf32}, + Bindings{inp1_i, inp2_i, out_o}, + /* nWorkgroups */ {cdiv(n, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp1_i = ctx.pool.data[op->buffers[0]]; + Tensor& inp2_i = ctx.pool.data[op->buffers[1]]; + Tensor& out_o = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, inp1, inp1_i); + toGPU(ctx, inp2, inp2_i); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderResidual, 256, kf32}, - Bindings{inp1_i, inp2_i, out_o}, - /* nWorkgroups */ {cdiv(n, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, out_o, out, n * sizeof(float)); @@ -428,14 +611,30 @@ void residual_forward(Context& ctx, float* out, float* inp1, float* inp2, int N) void residual_backward(Context& ctx, float* dinp1, float* dinp2, float* dout, int N){ unsigned long n = static_cast(N); setLogLevel(kError); - Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout); - Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32, dinp1); - Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32, dinp2); + // Generate the key of the cache by arguments. 
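
All of these dispatches size their grids with cdiv. Assuming it is the usual
ceiling division, a minimal equivalent with a worked size:

    // Minimal equivalent of cdiv under the usual ceiling-division reading.
    inline size_t cdiv(size_t n, size_t d) { return (n + d - 1) / d; }
    // Example: n = 1000 elements, workgroups of 256 threads:
    // cdiv(1000, 256) = 4 workgroups -> 1024 threads, with the last 24
    // presumably masked off by a bounds check in the shader.
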
+ std::string key = "residual_backward_" + std::to_string(N); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dout_i = createTensor(ctx, Shape{n}, kf32); + Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32); + Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32); + op = createKernel(ctx, {kShaderResidualBackward, 256, kf32}, + Bindings{dout_i, dinp1_o, dinp2_o}, + /* nWorkgroups */ {cdiv(n, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dout_i = ctx.pool.data[op->buffers[0]]; + Tensor& dinp1_o = ctx.pool.data[op->buffers[1]]; + Tensor& dinp2_o = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, dout, dout_i); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderResidualBackward, 256, kf32}, - Bindings{dout_i, dinp1_o, dinp2_o}, - /* nWorkgroups */ {cdiv(n, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp1_o, dinp1, n * sizeof(float)); @@ -452,14 +651,28 @@ void softmax_forward(Context& ctx, float* probs, float* logits, int B, int T, in uint32_t t = static_cast(T); uint32_t c = static_cast(V); uint32_t cp = static_cast(Vp); - Tensor input = createTensor(ctx, {b * t, cp}, kf32, logits); - Tensor output = createTensor(ctx, {b * t, cp}, kf32); + // Generate the key of the cache by arguments. + std::string key = "softmax_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, {b * t, cp}, kf32); + Tensor output = createTensor(ctx, {b * t, cp}, kf32); + assert( (B*T) % 256 == 0); + op = createKernel( + ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output}, + Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp}, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& output = ctx.pool.data[op->buffers[1]]; + + toGPU(ctx, logits, input); + std::promise promise; std::future future = promise.get_future(); - assert( (B*T) % 256 == 0); - Kernel op = createKernel( - ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output}, - Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, probs, sizeof(float)*b*t*cp); @@ -477,20 +690,36 @@ void crossentropy_forward(Context& ctx, float* losses, unsigned long t = static_cast(T); unsigned long vp = static_cast(Vp); setLogLevel(kError); - Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32, losses); - Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs); - Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets); + // Generate the key of the cache by arguments. 
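
softmax_forward is the one dispatch above that asserts on its grid instead of
relying on such a guard: with cdiv(B * T, 256) workgroups of 256 threads and
b * t rows recorded in SoftmaxParam, B * T must divide evenly. A worked
example:

    // Worked example of the softmax grid constraint (sketch).
    int B = 8, T = 32;           // B * T = 256 rows of logits
    assert((B * T) % 256 == 0);  // rows fill the workgroups exactly, no tail
    // cdiv(B * T, 256) = 1 workgroup of 256 threads, presumably one per row.
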
+ std::string key = "crossentropy_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(Vp); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32); + Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32}, + Bindings{losses_t, probs_t, targets_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + CrossEntropyParams{ + static_cast(b), + static_cast(t), + static_cast(vp) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& losses_t = ctx.pool.data[op->buffers[0]]; + Tensor& probs_t = ctx.pool.data[op->buffers[1]]; + Tensor& targets_t = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, probs, probs_t); + toGPU(ctx, targets, targets_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32}, - Bindings{losses_t, probs_t, targets_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - CrossEntropyParams{ - static_cast(b), - static_cast(t), - static_cast(vp) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, losses_t, losses, b * t * sizeof(float)); @@ -510,22 +739,40 @@ void crossentropy_softmax_backward(Context& ctx, float* dlogits, unsigned long v = static_cast(V); unsigned long vp = static_cast(Vp); setLogLevel(kError); - Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32, dlogits); - Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32, dlosses); - Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs); - Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets); + // Generate the key of the cache by arguments. 
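
The params structs mirror uniform data on the shader side; their definitions
do not appear in this excerpt, but the braced initializers above fix the field
order, and the explicit casts suggest fixed-width 32-bit fields. A
hypothetical definition consistent with the usage:

    // Hypothetical definition, inferred from the initializer above.
    struct CrossEntropyParams {
      uint32_t B;   // batch size
      uint32_t T;   // sequence length
      uint32_t VP;  // padded vocabulary size
    };
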
+ std::string key = "crossentropy_softmax_backward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32); + Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32); + Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32}, + Bindings{dlogits_t, dlosses_t, probs_t, targets_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + CrossEntropySoftmaxBackwardParams{ + static_cast(b), + static_cast(t), + static_cast(v), + static_cast(vp) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dlogits_t = ctx.pool.data[op->buffers[0]]; + Tensor& dlosses_t = ctx.pool.data[op->buffers[1]]; + Tensor& probs_t = ctx.pool.data[op->buffers[2]]; + Tensor& targets_t = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, dlosses, dlosses_t); + toGPU(ctx, probs, probs_t); + toGPU(ctx, targets, targets_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32}, - Bindings{dlogits_t, dlosses_t, probs_t, targets_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - CrossEntropySoftmaxBackwardParams{ - static_cast(b), - static_cast(t), - static_cast(v), - static_cast(vp) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dlogits_t, dlogits, b * t * vp * sizeof(float)); diff --git a/gpu.hpp b/gpu.hpp index 9c34d78..8f8bd9b 100644 --- a/gpu.hpp +++ b/gpu.hpp @@ -1175,6 +1175,11 @@ inline void toGPU(Context &ctx, const half *data, Tensor &tensor) { tensor.data.size); } +inline void toGPU(Context &ctx, const int *data, Tensor &tensor) { + wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data, + tensor.data.size); +} + template inline void toGPU(Context &ctx, Params ¶ms, Kernel &op) { // TODO(avh): Maintain params metadata in Kernel and check for consistency. 
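
The new overload closes a gap for integer buffers: token ids and targets are
int arrays, which the existing float and half overloads do not cover. A usage
sketch with hypothetical names:

    // Hypothetical usage: upload int32 target ids with the new overload.
    std::vector<int> targets(B * T, 0);
    Tensor targets_t = createTensor(ctx, Shape{(unsigned long)(B * T)}, ki32);
    toGPU(ctx, targets.data(), targets_t);  // writes tensor.data.size bytes
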
From b9fb38bb6d4f8e5abc8783c11f3c7ea6e5d059a2 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Wed, 16 Oct 2024 18:10:59 +0900 Subject: [PATCH 06/44] Add caches for unittests.cpp --- .../unittest_llmc/unittest_kernels.cpp | 652 ++++++++++++------ 1 file changed, 446 insertions(+), 206 deletions(-) diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp index 4593bf5..a883124 100644 --- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp +++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp @@ -75,6 +75,11 @@ struct DurationTime { } }; +static WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; +static Context ctx = createContext({},{},{ + .requiredLimits = &requiredLimits + }); + void ENCODER_FORWARD_GPU(float* out, int* inp, float* wte, float* wpe, int B, int T, int C){ @@ -88,25 +93,40 @@ void ENCODER_FORWARD_GPU(float* out, uint32_t C; }; setLogLevel(kError); - WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; - Context ctx = createContext({},{},{ - .requiredLimits = &requiredLimits - }); - Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp); - Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32, wte); - Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32, wpe); - Tensor output = createTensor(ctx, Shape{b * t * c}, kf32); + + // Generate the key of the cache by arguments. + std::string key = "ENCODER_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, Shape{b * t}, ki32); + Tensor wte_t = createTensor(ctx, Shape{v, c}, kf32); + Tensor wpe_t = createTensor(ctx, Shape{t, c}, kf32); + Tensor output = createTensor(ctx, Shape{b * t * c}, kf32); + op = createKernel(ctx, {kShaderEncoder, 256, kf32}, + Bindings{input, wte_t, wpe_t, output}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + EncoderParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& wte_t = ctx.pool.data[op->buffers[1]]; + Tensor& wpe_t = ctx.pool.data[op->buffers[2]]; + Tensor& output = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, input); + toGPU(ctx, wte, wte_t); + toGPU(ctx, wpe, wpe_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderEncoder, 256, kf32}, - Bindings{input, wte_t, wpe_t, output}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - EncoderParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, out, b * t * c * sizeof(float)); @@ -125,25 +145,39 @@ void ENCODER_BACKWARD_GPU(float* dwte, float* dwpe, uint32_t C; }; setLogLevel(kError); - WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; - Context ctx = createContext({},{},{ - .requiredLimits = &requiredLimits - }); - Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32, dwte); - Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32, dwpe); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor input = createTensor(ctx, Shape{b * t}, ki32, inp); + + // Generate the key of the cache by arguments. 
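
The conversions in this patch would be pointless if each wrapper kept its own
context: the per-call createContext() calls being removed here each stood up a
fresh device whose tensor and kernel pools start empty. The file-scope context
added above is created once per process, so every wrapper shares one device,
one queue, and one warm kernel cache. Restated on its own:

    // As added above: one shared context, with 1 GB buffer limits, for the
    // whole test binary (LIMITS_BUFFER_SIZE_1GB comes from the same codebase).
    static WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB;
    static Context ctx = createContext({}, {}, {.requiredLimits = &requiredLimits});
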
+ std::string key = "ENCODER_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dwte_t = createTensor(ctx, Shape{v, c}, kf32); + Tensor dwpe_t = createTensor(ctx, Shape{t, c}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor input = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32}, + Bindings{dwte_t, dwpe_t, dout_t, input}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + EncoderParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dwte_t = ctx.pool.data[op->buffers[0]]; + Tensor& dwpe_t = ctx.pool.data[op->buffers[1]]; + Tensor& dout_t = ctx.pool.data[op->buffers[2]]; + Tensor& input = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, input); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderEncoderBackward, 256, kf32}, - Bindings{dwte_t, dwpe_t, dout_t, input}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - EncoderParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dwte_t, dwte, v * c * sizeof(float)); @@ -162,24 +196,44 @@ void LAYERNORM_FORWARD_GPU(float* out, float* mean, float* rstd, uint32_t C; }; setLogLevel(kError); - Context ctx = createContext(); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight); - Tensor bias_t = createTensor(ctx, Shape{c}, kf32, bias); - Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); - Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); - Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + + // Generate the key of the cache by arguments. 
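
The op->buffers[i] lookups are positional and must track the Bindings order
used at creation; a reordering at either site would silently bind the wrong
tensors. A named index makes the contract explicit (hypothetical, shown for
the encoder-backward bindings above):

    // Hypothetical named indices; must mirror
    // Bindings{dwte_t, dwpe_t, dout_t, input} in order.
    enum EncoderBackwardBuf { kDwte = 0, kDwpe, kDout, kInput };
    Tensor& dout_t = ctx.pool.data[op->buffers[kDout]];
    Tensor& input  = ctx.pool.data[op->buffers[kInput]];
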
+ std::string key = "LAYERNORM_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{c}, kf32); + Tensor bias_t = createTensor(ctx, Shape{c}, kf32); + Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + op = createKernel(ctx, {kShaderLayerNorm, 256, kf32}, + Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + LayerNormParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_t = ctx.pool.data[op->buffers[0]]; + Tensor& weight_t = ctx.pool.data[op->buffers[1]]; + Tensor& bias_t = ctx.pool.data[op->buffers[2]]; + Tensor& out_t = ctx.pool.data[op->buffers[3]]; + Tensor& mean_t = ctx.pool.data[op->buffers[4]]; + Tensor& rstd_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + toGPU(ctx, bias, bias_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderLayerNorm, 256, kf32}, - Bindings{inp_t, weight_t, bias_t, out_t, mean_t, rstd_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - LayerNormParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, out_t, out, b * t * c * sizeof(float)); @@ -199,26 +253,50 @@ void LAYERNORM_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, uint32_t C; }; setLogLevel(kError); - Context ctx = createContext(); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp); - Tensor dweight_t = createTensor(ctx, Shape{c}, kf32, dweight); - Tensor dbias_t = createTensor(ctx, Shape{c}, kf32, dbias); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{c}, kf32, weight); - Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32, mean); - Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32, rstd); + + // Generate the key of the cache by arguments. 
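
For the backward wrappers that follow, buffer reuse has a subtlety that the
later "Fix bugs" patch addresses: presumably because the backward shaders
accumulate into the gradient buffers rather than overwrite them, a cache hit
must refresh those buffers too, or gradients from the previous call leak into
the next one. A sketch of the extra uploads that patch adds:

    // Seed the gradient buffers before dispatch so a cached kernel does not
    // accumulate onto stale device data (as the follow-up fix does below).
    toGPU(ctx, dinp, dinp_t);
    toGPU(ctx, dweight, dweight_t);
    toGPU(ctx, dbias, dbias_t);
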
+ std::string key = "LAYERNORM_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor dweight_t = createTensor(ctx, Shape{c}, kf32); + Tensor dbias_t = createTensor(ctx, Shape{c}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{c}, kf32); + Tensor mean_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor rstd_t = createTensor(ctx, Shape{b * t}, kf32); + op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32}, + Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + LayerNormParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dweight_t = ctx.pool.data[op->buffers[1]]; + Tensor& dbias_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + Tensor& mean_t = ctx.pool.data[op->buffers[6]]; + Tensor& rstd_t = ctx.pool.data[op->buffers[7]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + toGPU(ctx, mean, mean_t); + toGPU(ctx, rstd, rstd_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderLayerNormBackward, 256, kf32}, - Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t, mean_t, rstd_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - LayerNormParams{ - static_cast(b), - static_cast(t), - static_cast(c) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float)); @@ -230,28 +308,24 @@ void matmul_forward_dummy(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC); -static WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; -static Context ctx = createContext({},{},{ - .requiredLimits = &requiredLimits - }); -static constexpr size_t BT = 64; -static constexpr size_t BC = 16; -static constexpr size_t BOC = 64; -static constexpr size_t TT = BT / BC; -static constexpr size_t TOC = BOC / BC; -static constexpr size_t num_threads = BT * BOC / (TT * TOC); -static Shape wgSize = {num_threads, 1, 1}; - -static std::string codeString(kShaderMatmul2DTiling); -static std::string unrolledCode = loopUnrolling(replaceAll(codeString, {{"{{precision}}", toString(kf32)}, - {"{{BT}}", toString(BT)}, - {"{{BC}}", toString(BC)}, - {"{{BOC}}", toString(BOC)}, - {"{{TT}}", toString(TT)}, - {"{{TOC}}", toString(TOC)}, - {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, - {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} +static constexpr size_t MATMUL_BT = 64; +static constexpr size_t MATMUL_BC = 16; +static constexpr size_t MATMUL_BOC = 64; +static constexpr size_t MATMUL_TT = MATMUL_BT / MATMUL_BC; +static constexpr size_t MATMUL_TOC = MATMUL_BOC / MATMUL_BC; +static constexpr size_t MATMUL_num_threads = MATMUL_BT * MATMUL_BOC / (MATMUL_TT * MATMUL_TOC); +static Shape MATMUL_wgSize = {MATMUL_num_threads, 1, 1}; + +static std::string MATMUL_codeString(kShaderMatmul2DTiling); +static std::string MATMUL_unrolledCode = 
loopUnrolling(replaceAll(MATMUL_codeString, {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(MATMUL_BT)}, + {"{{BC}}", toString(MATMUL_BC)}, + {"{{BOC}}", toString(MATMUL_BOC)}, + {"{{TT}}", toString(MATMUL_TT)}, + {"{{TOC}}", toString(MATMUL_TOC)}, + {"{{NUM_TILEI}}", toString(MATMUL_BT * MATMUL_BC / MATMUL_num_threads)}, + {"{{NUM_TILEW}}", toString(MATMUL_BOC * MATMUL_BC / MATMUL_num_threads)} })); @@ -285,15 +359,15 @@ void MATMUL_FORWARD_GPU(float* out, { DurationTime duration("matmul_forward_gpu: before creating tensors", verbose); // Generate the key of the cache by arguments. - std::string key = std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); + std::string key = "MATMUL_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); Kernel op; if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { - Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; + Shape nWorkgroups = {b, cdiv(T, MATMUL_BT), cdiv(OC, MATMUL_BOC)}; Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32); Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32); Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32); Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); - op = createKernel(ctx, {unrolledCode, wgSize, kf32}, + op = createKernel(ctx, {MATMUL_unrolledCode, MATMUL_wgSize, kf32}, Bindings{inp_i, weight_i, bias_i, out_o}, nWorkgroups, /* params */ @@ -411,28 +485,45 @@ void MATMUL_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, unsigned long c = static_cast(C); unsigned long oc = static_cast(OC); setLogLevel(kError); - WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; - Context ctx = createContext({},{},{ - .requiredLimits = &requiredLimits - }); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32, dinp); - Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32, dweight); - Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32, dbias); - Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32, inp); - Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32, weight); + + // Generate the key of the cache by arguments. 
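
Spelled out with the values used in this file (MATMUL_BT = 64, MATMUL_BC = 16,
MATMUL_BOC = 64), the tile bookkeeping above works out as follows:

    // Tile arithmetic for the 2D-tiling matmul, unprefixed for brevity.
    constexpr size_t BT = 64, BC = 16, BOC = 64;       // block tile dimensions
    constexpr size_t TT  = BT / BC;                    // = 4, per-thread tile
    constexpr size_t TOC = BOC / BC;                   // = 4, per-thread tile
    constexpr size_t num_threads = BT * BOC / (TT * TOC);  // = 256 per workgroup
    static_assert(num_threads == 256, "one workgroup of 256 threads");
    // Per k-step staging into workgroup memory:
    //   NUM_TILEI = BT * BC / num_threads = 4 inp elements per thread
    //   NUM_TILEW = BOC * BC / num_threads = 4 weight elements per thread
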
+ std::string key = "MATMUL_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor dweight_t = createTensor(ctx, Shape{oc * c}, kf32); + Tensor dbias_t = createTensor(ctx, Shape{oc}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * oc}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor weight_t = createTensor(ctx, Shape{oc * c}, kf32); + op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32}, + Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dweight_t = ctx.pool.data[op->buffers[1]]; + Tensor& dbias_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, weight, weight_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderMatmulBackward, 256, kf32}, - Bindings{dinp_t, dweight_t, dbias_t, dout_t, inp_t, weight_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * sizeof(float)); @@ -454,23 +545,39 @@ void ATTENTION_FORWARD_GPU(float* out, float* preatt, float* att, unsigned long c = static_cast(C); unsigned long nh = static_cast(NH); setLogLevel(kError); - Context ctx = createContext(); - Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp); - Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, preatt); - Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att); - Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + + // Generate the key of the cache by arguments. 
+ std::string key = "ATTENTION_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32); + Tensor preatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor out_t = createTensor(ctx, Shape{b * t * c}, kf32); + op = createKernel(ctx, {kShaderAttention, 256, kf32}, + Bindings{inp_t, preatt_t, att_t, out_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + AttentionParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(nh) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_t = ctx.pool.data[op->buffers[0]]; + Tensor& preatt_t = ctx.pool.data[op->buffers[1]]; + Tensor& att_t = ctx.pool.data[op->buffers[2]]; + Tensor& out_t = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, inp, inp_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderAttention, 256, kf32}, - Bindings{inp_t, preatt_t, att_t, out_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - AttentionParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(nh) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, preatt_t, preatt, b * nh * t * t * sizeof(float)); @@ -492,25 +599,45 @@ void ATTENTION_BACKWARD_GPU(float* dinp, float* dpreatt, float* datt, unsigned long c = static_cast(C); unsigned long nh = static_cast(NH); setLogLevel(kError); - Context ctx = createContext(); - Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, dinp); - Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, dpreatt); - Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, datt); - Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32, dout); - Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32, inp); - Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32, att); + + // Generate the key of the cache by arguments. 
+ std::string key = "ATTENTION_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(NH); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dinp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32); + Tensor dpreatt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor datt_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + Tensor dout_t = createTensor(ctx, Shape{b * t * c}, kf32); + Tensor inp_t = createTensor(ctx, Shape{b * t * c * 3}, kf32); + Tensor att_t = createTensor(ctx, Shape{b * nh * t * t}, kf32); + op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32}, + Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + AttentionParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(nh) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dinp_t = ctx.pool.data[op->buffers[0]]; + Tensor& dpreatt_t = ctx.pool.data[op->buffers[1]]; + Tensor& datt_t = ctx.pool.data[op->buffers[2]]; + Tensor& dout_t = ctx.pool.data[op->buffers[3]]; + Tensor& inp_t = ctx.pool.data[op->buffers[4]]; + Tensor& att_t = ctx.pool.data[op->buffers[5]]; + + toGPU(ctx, dout, dout_t); + toGPU(ctx, inp, inp_t); + toGPU(ctx, att, att_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderAttentionBackward, 256, kf32}, - Bindings{dinp_t, dpreatt_t, datt_t, dout_t, inp_t, att_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - AttentionParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(nh) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_t, dinp, b * t * c * 3 * sizeof(float)); @@ -521,14 +648,29 @@ void ATTENTION_BACKWARD_GPU(float* dinp, float* dpreatt, float* datt, void GELU_FORWARD_GPU(float* out, float* inp, int n) { unsigned long N = static_cast(n); setLogLevel(kError); - Context ctx = createContext(); - Tensor input = createTensor(ctx, Shape{N}, kf32, inp); - Tensor output = createTensor(ctx, Shape{N}, kf32); + + // Generate the key of the cache by arguments. + std::string key = "GELU_FORWARD_GPU_" + std::to_string(n); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, Shape{N}, kf32); + Tensor output = createTensor(ctx, Shape{N}, kf32); + op = createKernel(ctx, {kShaderGelu, 256, kf32}, + Bindings{input, output}, + /* nWorkgroups */ {cdiv(N, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& output = ctx.pool.data[op->buffers[1]]; + + toGPU(ctx, inp, input); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderGelu, 256, kf32}, - Bindings{input, output}, - /* nWorkgroups */ {cdiv(N, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, out, N * sizeof(float)); @@ -537,15 +679,32 @@ void GELU_FORWARD_GPU(float* out, float* inp, int n) { void GELU_BACKWARD_GPU(float* dinp, float* inp, float* dout, int N){ unsigned long n = static_cast(N); setLogLevel(kError); - Context ctx = createContext(); - Tensor inp_i = createTensor(ctx, Shape{n}, kf32, inp); - Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout); - Tensor dinp_o = createTensor(ctx, Shape{n}, kf32, dinp); + + // Generate the key of the cache by arguments. 
+ std::string key = "GELU_BACKWARD_GPU_" + std::to_string(N); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp_i = createTensor(ctx, Shape{n}, kf32); + Tensor dout_i = createTensor(ctx, Shape{n}, kf32); + Tensor dinp_o = createTensor(ctx, Shape{n}, kf32); + op = createKernel(ctx, {kShaderGeluBackward, 256, kf32}, + Bindings{inp_i, dout_i, dinp_o}, + /* nWorkgroups */ {cdiv(n, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp_i = ctx.pool.data[op->buffers[0]]; + Tensor& dout_i = ctx.pool.data[op->buffers[1]]; + Tensor& dinp_o = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, inp, inp_i); + toGPU(ctx, dout, dout_i); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderGeluBackward, 256, kf32}, - Bindings{inp_i, dout_i, dinp_o}, - /* nWorkgroups */ {cdiv(n, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp_o, dinp, n * sizeof(float)); @@ -554,15 +713,32 @@ void GELU_BACKWARD_GPU(float* dinp, float* inp, float* dout, int N){ void RESIDUAL_FORWARD_GPU(float* out, float* inp1, float* inp2, int N){ unsigned long n = static_cast(N); setLogLevel(kError); - Context ctx = createContext(); - Tensor inp1_i = createTensor(ctx, Shape{n}, kf32, inp1); - Tensor inp2_i = createTensor(ctx, Shape{n}, kf32, inp2); - Tensor out_o = createTensor(ctx, Shape{n}, kf32); + + // Generate the key of the cache by arguments. + std::string key = "RESIDUAL_FORWARD_GPU_" + std::to_string(N); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor inp1_i = createTensor(ctx, Shape{n}, kf32); + Tensor inp2_i = createTensor(ctx, Shape{n}, kf32); + Tensor out_o = createTensor(ctx, Shape{n}, kf32); + op = createKernel(ctx, {kShaderResidual, 256, kf32}, + Bindings{inp1_i, inp2_i, out_o}, + /* nWorkgroups */ {cdiv(n, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& inp1_i = ctx.pool.data[op->buffers[0]]; + Tensor& inp2_i = ctx.pool.data[op->buffers[1]]; + Tensor& out_o = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, inp1, inp1_i); + toGPU(ctx, inp2, inp2_i); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderResidual, 256, kf32}, - Bindings{inp1_i, inp2_i, out_o}, - /* nWorkgroups */ {cdiv(n, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, out_o, out, n * sizeof(float)); @@ -571,15 +747,31 @@ void RESIDUAL_FORWARD_GPU(float* out, float* inp1, float* inp2, int N){ void RESIDUAL_BACKWARD_GPU(float* dinp1, float* dinp2, float* dout, int N){ unsigned long n = static_cast(N); setLogLevel(kError); - Context ctx = createContext(); - Tensor dout_i = createTensor(ctx, Shape{n}, kf32, dout); - Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32, dinp1); - Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32, dinp2); + + // Generate the key of the cache by arguments. 
+ std::string key = "RESIDUAL_BACKWARD_GPU_" + std::to_string(N); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dout_i = createTensor(ctx, Shape{n}, kf32); + Tensor dinp1_o = createTensor(ctx, Shape{n}, kf32); + Tensor dinp2_o = createTensor(ctx, Shape{n}, kf32); + op = createKernel(ctx, {kShaderResidualBackward, 256, kf32}, + Bindings{dout_i, dinp1_o, dinp2_o}, + /* nWorkgroups */ {cdiv(n, 256), 1, 1}, + nullptr, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dout_i = ctx.pool.data[op->buffers[0]]; + Tensor& dinp1_o = ctx.pool.data[op->buffers[1]]; + Tensor& dinp2_o = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, dout, dout_i); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderResidualBackward, 256, kf32}, - Bindings{dout_i, dinp1_o, dinp2_o}, - /* nWorkgroups */ {cdiv(n, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dinp1_o, dinp1, n * sizeof(float)); @@ -596,15 +788,29 @@ void SOFTMAX_FORWARD_GPU(float* probs, float* logits, int B, int T, int V, int V uint32_t t = static_cast(T); uint32_t c = static_cast(V); uint32_t cp = static_cast(Vp); - Context ctx = createContext(); - Tensor input = createTensor(ctx, {b * t, cp}, kf32, logits); - Tensor output = createTensor(ctx, {b * t, cp}, kf32); + + // Generate the key of the cache by arguments. + std::string key = "SOFTMAX_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor input = createTensor(ctx, {b * t, cp}, kf32); + Tensor output = createTensor(ctx, {b * t, cp}, kf32); + assert( (B*T) % 256 == 0); + op = createKernel( + ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output}, + Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp}, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& input = ctx.pool.data[op->buffers[0]]; + Tensor& output = ctx.pool.data[op->buffers[1]]; + + toGPU(ctx, logits, input); + std::promise promise; std::future future = promise.get_future(); - assert( (B*T) % 256 == 0); - Kernel op = createKernel( - ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output}, - Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, probs, sizeof(float)*b*t*cp); @@ -622,21 +828,37 @@ void CROSSENTROPY_FORWARD_GPU(float* losses, unsigned long t = static_cast(T); unsigned long vp = static_cast(Vp); setLogLevel(kError); - Context ctx = createContext(); - Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32, losses); - Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs); - Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets); + + // Generate the key of the cache by arguments. 
+ std::string key = "CROSSENTROPY_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(Vp); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor losses_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32); + Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32}, + Bindings{losses_t, probs_t, targets_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + CrossEntropyParams{ + static_cast(b), + static_cast(t), + static_cast(vp) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& losses_t = ctx.pool.data[op->buffers[0]]; + Tensor& probs_t = ctx.pool.data[op->buffers[1]]; + Tensor& targets_t = ctx.pool.data[op->buffers[2]]; + + toGPU(ctx, probs, probs_t); + toGPU(ctx, targets, targets_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32}, - Bindings{losses_t, probs_t, targets_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - CrossEntropyParams{ - static_cast(b), - static_cast(t), - static_cast(vp) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, losses_t, losses, b * t * sizeof(float)); @@ -656,23 +878,41 @@ void CROSSENTROPY_SOFTMAX_BACKWARD_GPU(float* dlogits, unsigned long v = static_cast(V); unsigned long vp = static_cast(Vp); setLogLevel(kError); - Context ctx = createContext(); - Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32, dlogits); - Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32, dlosses); - Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32, probs); - Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32, targets); + + // Generate the key of the cache by arguments. 
+ std::string key = "CROSSENTROPY_SOFTMAX_BACKWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(V) + "_" + std::to_string(Vp); + Kernel op; + if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { + Tensor dlogits_t = createTensor(ctx, Shape{b * t * vp}, kf32); + Tensor dlosses_t = createTensor(ctx, Shape{b * t}, kf32); + Tensor probs_t = createTensor(ctx, Shape{b * t * vp}, kf32); + Tensor targets_t = createTensor(ctx, Shape{b * t}, ki32); + op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32}, + Bindings{dlogits_t, dlosses_t, probs_t, targets_t}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + CrossEntropySoftmaxBackwardParams{ + static_cast(b), + static_cast(t), + static_cast(v), + static_cast(vp) + }, + nullptr, + key.c_str()); + } else { + op = ctx.kernelPool.data[key]; + } + Tensor& dlogits_t = ctx.pool.data[op->buffers[0]]; + Tensor& dlosses_t = ctx.pool.data[op->buffers[1]]; + Tensor& probs_t = ctx.pool.data[op->buffers[2]]; + Tensor& targets_t = ctx.pool.data[op->buffers[3]]; + + toGPU(ctx, dlosses, dlosses_t); + toGPU(ctx, probs, probs_t); + toGPU(ctx, targets, targets_t); + std::promise promise; std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32}, - Bindings{dlogits_t, dlosses_t, probs_t, targets_t}, - /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - /* params */ - CrossEntropySoftmaxBackwardParams{ - static_cast(b), - static_cast(t), - static_cast(v), - static_cast(vp) - }); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, dlogits_t, dlogits, b * t * vp * sizeof(float)); From efb87eebce70acaf5dcc5b9a2a0d67339e719142 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Wed, 16 Oct 2024 19:59:28 +0900 Subject: [PATCH 07/44] Fix bugs --- experimental/kernels/ops.cpp | 18 +++++++++++++++++ .../unittest_llmc/unittest_kernels.cpp | 20 ++++++++++++++++++- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/experimental/kernels/ops.cpp b/experimental/kernels/ops.cpp index 6390314..3d5f528 100644 --- a/experimental/kernels/ops.cpp +++ b/experimental/kernels/ops.cpp @@ -101,6 +101,8 @@ void encoder_backward(Context& ctx, float* dwte, float* dwpe, Tensor& dout_t = ctx.pool.data[op->buffers[2]]; Tensor& input = ctx.pool.data[op->buffers[3]]; + toGPU(ctx, dwte, dwte_t); + toGPU(ctx, dwpe, dwpe_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, input); @@ -215,6 +217,9 @@ void layernorm_backward(Context& ctx, float* dinp, float* dweight, float* dbias, Tensor& mean_t = ctx.pool.data[op->buffers[6]]; Tensor& rstd_t = ctx.pool.data[op->buffers[7]]; + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dweight, dweight_t); + toGPU(ctx, dbias, dbias_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, inp_t); toGPU(ctx, weight, weight_t); @@ -387,6 +392,9 @@ void matmul_backward(Context& ctx, float* dinp, float* dweight, float* dbias, Tensor& inp_t = ctx.pool.data[op->buffers[4]]; Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dweight, dweight_t); + toGPU(ctx, dbias, dbias_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, inp_t); toGPU(ctx, weight, weight_t); @@ -443,6 +451,8 @@ void attention_forward(Context& ctx, float* out, float* preatt, float* att, Tensor& out_t = ctx.pool.data[op->buffers[3]]; toGPU(ctx, inp, inp_t); + toGPU(ctx, preatt, preatt_t); + toGPU(ctx, att, att_t); std::promise promise; std::future future = promise.get_future(); @@ -499,6 +509,9 @@ void attention_backward(Context& ctx, float* dinp, 
float* dpreatt, float* datt, Tensor& inp_t = ctx.pool.data[op->buffers[4]]; Tensor& att_t = ctx.pool.data[op->buffers[5]]; + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dpreatt, dpreatt_t); + toGPU(ctx, datt, datt_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, inp_t); toGPU(ctx, att, att_t); @@ -567,6 +580,7 @@ void gelu_backward(Context& ctx, float* dinp, float* inp, float* dout, int N){ toGPU(ctx, inp, inp_i); toGPU(ctx, dout, dout_i); + toGPU(ctx, dinp, dinp_o); std::promise promise; std::future future = promise.get_future(); @@ -632,6 +646,8 @@ void residual_backward(Context& ctx, float* dinp1, float* dinp2, float* dout, in Tensor& dinp2_o = ctx.pool.data[op->buffers[2]]; toGPU(ctx, dout, dout_i); + toGPU(ctx, dinp1, dinp1_o); + toGPU(ctx, dinp2, dinp2_o); std::promise promise; std::future future = promise.get_future(); @@ -715,6 +731,7 @@ void crossentropy_forward(Context& ctx, float* losses, Tensor& probs_t = ctx.pool.data[op->buffers[1]]; Tensor& targets_t = ctx.pool.data[op->buffers[2]]; + toGPU(ctx, losses, losses_t); toGPU(ctx, probs, probs_t); toGPU(ctx, targets, targets_t); @@ -767,6 +784,7 @@ void crossentropy_softmax_backward(Context& ctx, float* dlogits, Tensor& probs_t = ctx.pool.data[op->buffers[2]]; Tensor& targets_t = ctx.pool.data[op->buffers[3]]; + toGPU(ctx, dlogits, dlogits_t); toGPU(ctx, dlosses, dlosses_t); toGPU(ctx, probs, probs_t); toGPU(ctx, targets, targets_t); diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp index a883124..a455780 100644 --- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp +++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp @@ -173,6 +173,8 @@ void ENCODER_BACKWARD_GPU(float* dwte, float* dwpe, Tensor& dout_t = ctx.pool.data[op->buffers[2]]; Tensor& input = ctx.pool.data[op->buffers[3]]; + toGPU(ctx, dwte, dwte_t); + toGPU(ctx, dwpe, dwpe_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, input); @@ -289,6 +291,9 @@ void LAYERNORM_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, Tensor& mean_t = ctx.pool.data[op->buffers[6]]; Tensor& rstd_t = ctx.pool.data[op->buffers[7]]; + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dweight, dweight_t); + toGPU(ctx, dbias, dbias_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, inp_t); toGPU(ctx, weight, weight_t); @@ -518,6 +523,9 @@ void MATMUL_BACKWARD_GPU(float* dinp, float* dweight, float* dbias, Tensor& inp_t = ctx.pool.data[op->buffers[4]]; Tensor& weight_t = ctx.pool.data[op->buffers[5]]; + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dweight, dweight_t); + toGPU(ctx, dbias, dbias_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, inp_t); toGPU(ctx, weight, weight_t); @@ -575,6 +583,8 @@ void ATTENTION_FORWARD_GPU(float* out, float* preatt, float* att, Tensor& out_t = ctx.pool.data[op->buffers[3]]; toGPU(ctx, inp, inp_t); + toGPU(ctx, preatt, preatt_t); + toGPU(ctx, att, att_t); std::promise promise; std::future future = promise.get_future(); @@ -632,10 +642,13 @@ void ATTENTION_BACKWARD_GPU(float* dinp, float* dpreatt, float* datt, Tensor& inp_t = ctx.pool.data[op->buffers[4]]; Tensor& att_t = ctx.pool.data[op->buffers[5]]; + toGPU(ctx, dinp, dinp_t); + toGPU(ctx, dpreatt, dpreatt_t); + toGPU(ctx, datt, datt_t); toGPU(ctx, dout, dout_t); toGPU(ctx, inp, inp_t); toGPU(ctx, att, att_t); - + std::promise promise; std::future future = promise.get_future(); dispatchKernel(ctx, op, promise); @@ -702,6 +715,7 @@ void GELU_BACKWARD_GPU(float* dinp, float* inp, float* dout, int N){ toGPU(ctx, inp, inp_i); toGPU(ctx, dout, dout_i); + 
toGPU(ctx, dinp, dinp_o); std::promise promise; std::future future = promise.get_future(); @@ -769,6 +783,8 @@ void RESIDUAL_BACKWARD_GPU(float* dinp1, float* dinp2, float* dout, int N){ Tensor& dinp2_o = ctx.pool.data[op->buffers[2]]; toGPU(ctx, dout, dout_i); + toGPU(ctx, dinp1, dinp1_o); + toGPU(ctx, dinp2, dinp2_o); std::promise promise; std::future future = promise.get_future(); @@ -854,6 +870,7 @@ void CROSSENTROPY_FORWARD_GPU(float* losses, Tensor& probs_t = ctx.pool.data[op->buffers[1]]; Tensor& targets_t = ctx.pool.data[op->buffers[2]]; + toGPU(ctx, losses, losses_t); toGPU(ctx, probs, probs_t); toGPU(ctx, targets, targets_t); @@ -907,6 +924,7 @@ void CROSSENTROPY_SOFTMAX_BACKWARD_GPU(float* dlogits, Tensor& probs_t = ctx.pool.data[op->buffers[2]]; Tensor& targets_t = ctx.pool.data[op->buffers[3]]; + toGPU(ctx, dlogits, dlogits_t); toGPU(ctx, dlosses, dlosses_t); toGPU(ctx, probs, probs_t); toGPU(ctx, targets, targets_t); From c474fba4560a942a0a21c1323d83b5799c277fad Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Wed, 16 Oct 2024 20:07:16 +0900 Subject: [PATCH 08/44] Remove global variables of kernels --- experimental/kernels/ops.cpp | 50 ++++++++++--------- .../unittest_llmc/unittest_kernels.cpp | 43 ++++++++-------- 2 files changed, 47 insertions(+), 46 deletions(-) diff --git a/experimental/kernels/ops.cpp b/experimental/kernels/ops.cpp index 3d5f528..0e9c076 100644 --- a/experimental/kernels/ops.cpp +++ b/experimental/kernels/ops.cpp @@ -235,28 +235,6 @@ void layernorm_backward(Context& ctx, float* dinp, float* dweight, float* dbias, toCPU(ctx, dbias_t, dbias, c * sizeof(float)); } -static constexpr size_t MATMUL_BT = 64; -static constexpr size_t MATMUL_BC = 8; -static constexpr size_t MATMUL_BOC = 64; -static constexpr size_t MATMUL_TT = MATMUL_BT / MATMUL_BC; -static constexpr size_t MATMUL_TOC = MATMUL_BOC / MATMUL_BC; -static size_t MATMUL_num_threads = MATMUL_BT * MATMUL_BOC / (MATMUL_TT * MATMUL_TOC); -static Shape MATMUL_wgSize = {MATMUL_num_threads, 1, 1}; -static std::string kShaderMatmul2DTiling_(kShaderMatmul2DTiling); -static std::string kShaderMatmul2D(loopUnrolling( - replaceAll(kShaderMatmul2DTiling_, - {{"{{precision}}", toString(kf32)}, - {"{{BT}}", toString(MATMUL_BT)}, - {"{{BC}}", toString(MATMUL_BC)}, - {"{{BOC}}", toString(MATMUL_BOC)}, - {"{{TT}}", toString(MATMUL_TT)}, - {"{{TOC}}", toString(MATMUL_TOC)}, - {"{{NUM_TILEI}}", toString(MATMUL_BT * MATMUL_BC / MATMUL_num_threads)}, - {"{{NUM_TILEW}}", toString(MATMUL_BOC * MATMUL_BC / MATMUL_num_threads)} - }) - ) - ); - struct DurationTime { std::chrono::high_resolution_clock::time_point start; std::chrono::high_resolution_clock::time_point end; @@ -301,12 +279,36 @@ void matmul_forward(Context& ctx, float* out, std::string key = "matmul_forward_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); Kernel op; if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { - Shape nWorkgroups = {b, cdiv(T, MATMUL_BT), cdiv(OC, MATMUL_BOC)}; + constexpr size_t BT = 64; + constexpr size_t BC = 8; + constexpr size_t BOC = 64; + constexpr size_t TT = BT / BC; + constexpr size_t TOC = BOC / BC; + size_t num_threads = BT * BOC / (TT * TOC); + Shape wgSize = {num_threads, 1, 1}; + Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; + + std::string kShaderMatmul2DTiling_(kShaderMatmul2DTiling); + std::string kShaderMatmul2D(loopUnrolling( + replaceAll(kShaderMatmul2DTiling_, + {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(BT)}, + {"{{BC}}", 
toString(BC)}, + {"{{BOC}}", toString(BOC)}, + {"{{TT}}", toString(TT)}, + {"{{TOC}}", toString(TOC)}, + {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, + {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} + }) + ) + ); + Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32); Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32); Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32); Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); - op = createKernel(ctx, {kShaderMatmul2D, MATMUL_wgSize, kf32}, + + op = createKernel(ctx, {kShaderMatmul2D, wgSize, kf32}, Bindings{inp_i, weight_i, bias_i, out_o}, nWorkgroups, /* params */ diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp index a455780..d58432e 100644 --- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp +++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp @@ -314,26 +314,6 @@ void matmul_forward_dummy(float* out, int B, int T, int C, int OC); -static constexpr size_t MATMUL_BT = 64; -static constexpr size_t MATMUL_BC = 16; -static constexpr size_t MATMUL_BOC = 64; -static constexpr size_t MATMUL_TT = MATMUL_BT / MATMUL_BC; -static constexpr size_t MATMUL_TOC = MATMUL_BOC / MATMUL_BC; -static constexpr size_t MATMUL_num_threads = MATMUL_BT * MATMUL_BOC / (MATMUL_TT * MATMUL_TOC); -static Shape MATMUL_wgSize = {MATMUL_num_threads, 1, 1}; - -static std::string MATMUL_codeString(kShaderMatmul2DTiling); -static std::string MATMUL_unrolledCode = loopUnrolling(replaceAll(MATMUL_codeString, {{"{{precision}}", toString(kf32)}, - {"{{BT}}", toString(MATMUL_BT)}, - {"{{BC}}", toString(MATMUL_BC)}, - {"{{BOC}}", toString(MATMUL_BOC)}, - {"{{TT}}", toString(MATMUL_TT)}, - {"{{TOC}}", toString(MATMUL_TOC)}, - {"{{NUM_TILEI}}", toString(MATMUL_BT * MATMUL_BC / MATMUL_num_threads)}, - {"{{NUM_TILEW}}", toString(MATMUL_BOC * MATMUL_BC / MATMUL_num_threads)} - })); - - void MATMUL_FORWARD_GPU(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC){ @@ -367,12 +347,31 @@ void MATMUL_FORWARD_GPU(float* out, std::string key = "MATMUL_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); Kernel op; if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { - Shape nWorkgroups = {b, cdiv(T, MATMUL_BT), cdiv(OC, MATMUL_BOC)}; + constexpr size_t BT = 64; + constexpr size_t BC = 16; + constexpr size_t BOC = 64; + constexpr size_t TT = BT / BC; + constexpr size_t TOC = BOC / BC; + constexpr size_t num_threads = BT * BOC / (TT * TOC); + Shape wgSize = {num_threads, 1, 1}; + + std::string codeString(kShaderMatmul2DTiling); + std::string unrolledCode = loopUnrolling(replaceAll(codeString, {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(BT)}, + {"{{BC}}", toString(BC)}, + {"{{BOC}}", toString(BOC)}, + {"{{TT}}", toString(TT)}, + {"{{TOC}}", toString(TOC)}, + {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, + {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} + })); + + Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32); Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32); Tensor bias_i = bias == NULL ? 
createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32); Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); - op = createKernel(ctx, {MATMUL_unrolledCode, MATMUL_wgSize, kf32}, + op = createKernel(ctx, {unrolledCode, wgSize, kf32}, Bindings{inp_i, weight_i, bias_i, out_o}, nWorkgroups, /* params */ From 438be22df671de669ebeae58698e3560ab5c7b17 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Wed, 16 Oct 2024 21:09:44 +0900 Subject: [PATCH 09/44] Fix the matmul of version 1 in unittests --- .../unittest_llmc/unittest_kernels.cpp | 129 +++++++++--------- 1 file changed, 61 insertions(+), 68 deletions(-) diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp index d58432e..311d895 100644 --- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp +++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp @@ -321,7 +321,7 @@ void MATMUL_FORWARD_GPU(float* out, bool verbose = false; bool debug = false; float *out_exp; - DurationTime duration("matmul_forward_gpu with creating context", verbose); + DurationTime duration("matmul_forward_gpu with preparing a kernel", verbose); if (verbose) { printf("matmul forward: B=%d, T=%d, C=%d, OC=%d, bias=%d\n", B, T, C, OC, bias != NULL); } @@ -341,49 +341,65 @@ void MATMUL_FORWARD_GPU(float* out, unsigned long oc = static_cast(OC); setLogLevel(kError); - { - DurationTime duration("matmul_forward_gpu: before creating tensors", verbose); + if (version == 2 || version == 1) { // Generate the key of the cache by arguments. std::string key = "MATMUL_FORWARD_GPU_" + std::to_string(B) + "_" + std::to_string(T) + "_" + std::to_string(C) + "_" + std::to_string(OC); Kernel op; if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) { - constexpr size_t BT = 64; - constexpr size_t BC = 16; - constexpr size_t BOC = 64; - constexpr size_t TT = BT / BC; - constexpr size_t TOC = BOC / BC; - constexpr size_t num_threads = BT * BOC / (TT * TOC); - Shape wgSize = {num_threads, 1, 1}; - - std::string codeString(kShaderMatmul2DTiling); - std::string unrolledCode = loopUnrolling(replaceAll(codeString, {{"{{precision}}", toString(kf32)}, - {"{{BT}}", toString(BT)}, - {"{{BC}}", toString(BC)}, - {"{{BOC}}", toString(BOC)}, - {"{{TT}}", toString(TT)}, - {"{{TOC}}", toString(TOC)}, - {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, - {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} - })); - - Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; Tensor inp_i = createTensor(ctx, Shape{b * t * c}, kf32); Tensor weight_i = createTensor(ctx, Shape{oc * c}, kf32); Tensor bias_i = bias == NULL ? 
createTensor(ctx, Shape{1}, kf32) : createTensor(ctx, Shape{oc}, kf32); Tensor out_o = createTensor(ctx, Shape{b * t * oc}, kf32); - op = createKernel(ctx, {unrolledCode, wgSize, kf32}, - Bindings{inp_i, weight_i, bias_i, out_o}, - nWorkgroups, - /* params */ - MatmulParams{ - static_cast(b), - static_cast(t), - static_cast(c), - static_cast(oc) - }, - nullptr, - key.c_str() - ); + + if (version == 2) { + constexpr size_t BT = 64; + constexpr size_t BC = 16; + constexpr size_t BOC = 64; + constexpr size_t TT = BT / BC; + constexpr size_t TOC = BOC / BC; + constexpr size_t num_threads = BT * BOC / (TT * TOC); + Shape wgSize = {num_threads, 1, 1}; + + std::string codeString(kShaderMatmul2DTiling); + std::string unrolledCode = loopUnrolling(replaceAll(codeString, {{"{{precision}}", toString(kf32)}, + {"{{BT}}", toString(BT)}, + {"{{BC}}", toString(BC)}, + {"{{BOC}}", toString(BOC)}, + {"{{TT}}", toString(TT)}, + {"{{TOC}}", toString(TOC)}, + {"{{NUM_TILEI}}", toString(BT * BC / num_threads)}, + {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)} + })); + + Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)}; + op = createKernel(ctx, {unrolledCode, wgSize, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + nWorkgroups, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }, + nullptr, + key.c_str() + ); + } else { + op = createKernel(ctx, {kShaderMatmul, 256, kf32}, + Bindings{inp_i, weight_i, bias_i, out_o}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + MatmulParams{ + static_cast(b), + static_cast(t), + static_cast(c), + static_cast(oc) + }, + nullptr, + key.c_str() + ); + } } else { op = ctx.kernelPool.data[key]; } @@ -400,39 +416,16 @@ void MATMUL_FORWARD_GPU(float* out, std::promise promise; std::future future = promise.get_future(); - - if (version == 2) { - DurationTime duration("matmul_forward_gpu: after creating tensors", verbose); - { - DurationTime duration("matmul_forward_gpu: before creating kernels", verbose); - { - DurationTime duration("matmul_forward_gpu without creating context", verbose); - dispatchKernel(ctx, op, promise); - wait(ctx, future); - toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); - } - } - // } else if (version == 1) { - // Kernel op = createKernel(ctx, {kShaderMatmul, 256, kf32}, - // Bindings{inp_i, weight_i, bias_i, out_o}, - // /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, - // /* params */ - // MatmulParams{ - // static_cast(b), - // static_cast(t), - // static_cast(c), - // static_cast(oc) - // }); - // { - // DurationTime duration("matmul_forward_gpu without creating context", verbose); - // dispatchKernel(ctx, op, promise); - // wait(ctx, future); - // toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); - // } - } else { - DurationTime duration("matmul_forward_cpu", verbose); - matmul_forward_dummy(out, inp, weight, bias, B, T, C, OC); + + { + DurationTime duration("matmul_forward_gpu", verbose); + dispatchKernel(ctx, op, promise); + wait(ctx, future); + toCPU(ctx, out_o, out, b * t * oc * sizeof(float)); } + } else { + DurationTime duration("matmul_forward_cpu", verbose); + matmul_forward_dummy(out, inp, weight, bias, B, T, C, OC); } if (debug) { // compare out with out_exp. 
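Before moving on, it is worth spelling out the caching idiom that patches 07-09 converge on: build a string key from the problem shape, create the kernel and its tensors only on a cache miss, and on a hit recover the bound tensors back out of the cached kernel. Below is a condensed sketch of that idiom using the same gpu.cpp calls that appear in these diffs (the function name and key prefix are illustrative; error handling, the tiled shader variant, and the version switch are omitted, and the op->buffers lookup assumes bindings are stored in declaration order, as the backward ops above do):

#include "gpu.hpp"
#include "kernels.h"   // kShaderMatmul, as in the unittest sources above
#include <cstdint>
#include <future>
#include <string>

using namespace gpu;

void matmul_forward_cached(Context& ctx, float* out, const float* inp,
                           const float* weight, const float* bias,
                           int B, int T, int C, int OC) {
  struct MatmulParams { uint32_t B, T, C, OC; };
  // One kernel per (B, T, C, OC): shader compilation and bind-group setup
  // happen only the first time a given shape is seen.
  std::string key = "matmul_forward_" + std::to_string(B) + "_" +
                    std::to_string(T) + "_" + std::to_string(C) + "_" +
                    std::to_string(OC);
  Kernel op;
  if (ctx.kernelPool.data.find(key) == ctx.kernelPool.data.end()) {
    Tensor inp_i = createTensor(ctx, Shape{size_t(B) * T * C}, kf32);
    Tensor weight_i = createTensor(ctx, Shape{size_t(OC) * C}, kf32);
    Tensor bias_i = bias == NULL ? createTensor(ctx, Shape{1}, kf32)
                                 : createTensor(ctx, Shape{size_t(OC)}, kf32);
    Tensor out_o = createTensor(ctx, Shape{size_t(B) * T * OC}, kf32);
    op = createKernel(ctx, {kShaderMatmul, 256, kf32},
                      Bindings{inp_i, weight_i, bias_i, out_o},
                      /* nWorkgroups */ {cdiv(size_t(B) * T, 256), 1, 1},
                      MatmulParams{uint32_t(B), uint32_t(T),
                                   uint32_t(C), uint32_t(OC)},
                      nullptr, key.c_str());  // registers the kernel under `key`
  } else {
    op = ctx.kernelPool.data[key];            // reuse the compiled pipeline
  }
  // On both paths the bound tensors are looked up through the kernel, the
  // same way the backward ops above recover ctx.pool.data[op->buffers[i]].
  Tensor& inp_i = ctx.pool.data[op->buffers[0]];
  Tensor& weight_i = ctx.pool.data[op->buffers[1]];
  Tensor& bias_i = ctx.pool.data[op->buffers[2]];
  Tensor& out_o = ctx.pool.data[op->buffers[3]];
  toGPU(ctx, inp, inp_i);
  toGPU(ctx, weight, weight_i);
  if (bias != NULL) { toGPU(ctx, bias, bias_i); }
  std::promise<void> promise;
  std::future<void> future = promise.get_future();
  dispatchKernel(ctx, op, promise);
  wait(ctx, future);
  toCPU(ctx, out_o, out, size_t(B) * T * OC * sizeof(float));
}

Keying on the shape means a model with many distinct (B, T, C, OC) combinations compiles several pipelines, but each one is compiled exactly once rather than on every call.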
From dd145acd76a00d399ea93760dd837c7927b87cd7 Mon Sep 17 00:00:00 2001
From: Junji Hashimoto
Date: Thu, 17 Oct 2024 02:00:18 +0900
Subject: [PATCH 10/44] Add a duration timer around matmul_forward_dummy to
 compare the GPU implementation with the CPU one

---
 experimental/kernels/unittest_llmc/unittest_kernels.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/experimental/kernels/unittest_llmc/unittest_kernels.cpp b/experimental/kernels/unittest_llmc/unittest_kernels.cpp
index 311d895..d037eac 100644
--- a/experimental/kernels/unittest_llmc/unittest_kernels.cpp
+++ b/experimental/kernels/unittest_llmc/unittest_kernels.cpp
@@ -327,7 +327,10 @@ void MATMUL_FORWARD_GPU(float* out,
   }
   if (debug) {
     out_exp = new float[B*T*OC];
-    matmul_forward_dummy(out_exp, inp, weight, bias, B, T, C, OC);
+    {
+      DurationTime duration("matmul_forward_cpu", verbose);
+      matmul_forward_dummy(out_exp, inp, weight, bias, B, T, C, OC);
+    }
   }
   struct MatmulParams {
     uint32_t B;
@@ -421,8 +424,8 @@ void MATMUL_FORWARD_GPU(float* out,
       DurationTime duration("matmul_forward_gpu", verbose);
       dispatchKernel(ctx, op, promise);
       wait(ctx, future);
-      toCPU(ctx, out_o, out, b * t * oc * sizeof(float));
     }
+    toCPU(ctx, out_o, out, b * t * oc * sizeof(float));
   } else {
     DurationTime duration("matmul_forward_cpu", verbose);
     matmul_forward_dummy(out, inp, weight, bias, B, T, C, OC);

From 590f2571710ebca733e316b96e1360d502df36a7 Mon Sep 17 00:00:00 2001
From: Junji Hashimoto
Date: Fri, 18 Oct 2024 16:09:51 +0900
Subject: [PATCH 11/44] Add wgpuBufferRelease for CopyData

---
 gpu.hpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/gpu.hpp b/gpu.hpp
index 8f8bd9b..2db9f66 100644
--- a/gpu.hpp
+++ b/gpu.hpp
@@ -1065,6 +1065,9 @@ inline void toCPU(Context &ctx, Tensor &tensor, void *data, size_t bufferSize) {
     check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__);
   }
   toCPU(ctx, tensor, data, bufferSize, op);
+  if (op.readbackBuffer) {
+    wgpuBufferRelease(op.readbackBuffer);
+  }
 }

 /**
@@ -1131,6 +1134,9 @@ inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data,
       },
       &callbackData);
   wait(ctx, op.future);
+  if (op.readbackBuffer) {
+    wgpuBufferRelease(op.readbackBuffer);
+  }
 }


From 6c52b98dd80ac174b46ba1b439f74956a5392d01 Mon Sep 17 00:00:00 2001
From: Junji Hashimoto
Date: Fri, 18 Oct 2024 17:26:43 +0900
Subject: [PATCH 12/44] Add wgpuCommandBufferRelease after calling
 wgpuQueueSubmit

---
 gpu.hpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/gpu.hpp b/gpu.hpp
index 2db9f66..4cca2a1 100644
--- a/gpu.hpp
+++ b/gpu.hpp
@@ -1002,6 +1002,7 @@ inline void wait(Context &ctx, std::future<void> &future) {
 inline void toCPU(Context &ctx, Tensor &tensor, void *data, size_t bufferSize,
                   CopyData &op) {
   wgpuQueueSubmit(ctx.queue, 1, &op.commandBuffer);
+  wgpuCommandBufferRelease(op.commandBuffer);
   CallbackData callbackData = {op.readbackBuffer, bufferSize, data, &op.promise,
                                &op.future};
   wgpuQueueOnSubmittedWorkDone(
@@ -1109,6 +1110,7 @@ inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data,
     check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__);
   }
   wgpuQueueSubmit(ctx.queue, 1, &op.commandBuffer);
+  wgpuCommandBufferRelease(op.commandBuffer);
   CallbackData callbackData = {op.readbackBuffer, bufferSize, data, &op.promise,
                                &op.future};
   wgpuQueueOnSubmittedWorkDone(
@@ -1513,6 +1515,7 @@ inline void dispatchKernel(Context &ctx, Kernel &kernel,
     resetCommandBuffer(ctx.device, kernel);
   }
   wgpuQueueSubmit(ctx.queue, 1, &kernel->commandBuffer);
+  wgpuCommandBufferRelease(kernel->commandBuffer);
   kernel->used = true;
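  // Note: releasing right after wgpuQueueSubmit is safe because the queue
  // holds its own reference to the submitted command buffer, and command
  // buffers are single-use; dropping the application's handle here avoids
  // leaking one buffer per dispatch without affecting the in-flight work.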
 wgpuQueueOnSubmittedWorkDone(
     ctx.queue,

From a6140e0d83bf6c27de6d1264dd1b15f3abcaf67f Mon Sep 17 00:00:00 2001
From: Junji Hashimoto
Date: Fri, 18 Oct 2024 17:31:16 +0900
Subject: [PATCH 13/44] Add wgpuCommandEncoderRelease after calling
 wgpuCommandEncoderFinish

---
 gpu.hpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/gpu.hpp b/gpu.hpp
index 4cca2a1..9262dcb 100644
--- a/gpu.hpp
+++ b/gpu.hpp
@@ -1063,6 +1063,7 @@ inline void toCPU(Context &ctx, Tensor &tensor, void *data, size_t bufferSize) {
     wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, tensor.data.buffer, 0,
                                          op.readbackBuffer, 0, bufferSize);
     op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
+    wgpuCommandEncoderRelease(commandEncoder);
     check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__);
   }
   toCPU(ctx, tensor, data, bufferSize, op);
@@ -1107,6 +1108,7 @@ inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data,
     wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, buffer, 0,
                                          op.readbackBuffer, 0, bufferSize);
     op.commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
+    wgpuCommandEncoderRelease(commandEncoder);
     check(op.commandBuffer, "Create command buffer", __FILE__, __LINE__);
   }
   wgpuQueueSubmit(ctx.queue, 1, &op.commandBuffer);
@@ -1225,6 +1227,7 @@ inline void resetCommandBuffer(WGPUDevice &device, Kernel &op) {
                                      op->totalWorkgroups[2]);
     wgpuComputePassEncoderEnd(computePassEncoder);
     op->commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
+    wgpuCommandEncoderRelease(commandEncoder);
     op->used = false;
   }
 }

From 30ed026e1f62e70ca1c9bfbd5a4541f8e76845ea Mon Sep 17 00:00:00 2001
From: Junji Hashimoto
Date: Fri, 18 Oct 2024 17:41:48 +0900
Subject: [PATCH 14/44] Add wgpuComputePassEncoderRelease after calling
 wgpuComputePassEncoderEnd

---
 gpu.hpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/gpu.hpp b/gpu.hpp
index 9262dcb..941656e 100644
--- a/gpu.hpp
+++ b/gpu.hpp
@@ -1058,7 +1058,6 @@ inline void toCPU(Context &ctx, Tensor &tensor, void *data, size_t bufferSize) {
   }
   {
     WGPUCommandEncoder commandEncoder;
-    WGPUComputePassEncoder computePassEncoder;
     commandEncoder = wgpuDeviceCreateCommandEncoder(ctx.device, nullptr);
     wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, tensor.data.buffer, 0,
                                          op.readbackBuffer, 0, bufferSize);
@@ -1103,7 +1102,6 @@ inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data,
   }
   {
     WGPUCommandEncoder commandEncoder;
-    WGPUComputePassEncoder computePassEncoder;
     commandEncoder = wgpuDeviceCreateCommandEncoder(ctx.device, nullptr);
     wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, buffer, 0,
                                          op.readbackBuffer, 0, bufferSize);
@@ -1226,6 +1224,7 @@ inline void resetCommandBuffer(WGPUDevice &device, Kernel &op) {
       computePassEncoder, op->totalWorkgroups[0], op->totalWorkgroups[1],
       op->totalWorkgroups[2]);
   wgpuComputePassEncoderEnd(computePassEncoder);
+  wgpuComputePassEncoderRelease(computePassEncoder);
   op->commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr);
   wgpuCommandEncoderRelease(commandEncoder);
   op->used = false;

From 0a9437f0c11c204b139f7df7f13d3baff44e21f0 Mon Sep 17 00:00:00 2001
From: buttfa <1662332017@qq.com>
Date: Sat, 19 Oct 2024 22:11:21 +0800
Subject: [PATCH 15/44] chore: Add a check-os target and improve the
 description of how to install dependencies.
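With these targets in place, the intended workflow on a supported system is sketched below (a usage example, not part of the diff; it assumes the default LIB_PATH=/usr/lib and HEADER_PATH=/usr/include, and the install step typically needs root):

    make check-all      # runs check-os, check-clang, check-cmake, check-python
    make lib            # builds build/libgpucpp.so (.dylib on macOS) and build/gpu.hpp
    sudo make install   # copies the library, libdawn, and gpu.hpp into system paths
    sudo make uninstall # removes the installed files again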
--- Makefile | 47 +++++++++++++++++++++-------------------------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/Makefile b/Makefile index 35c3b0f..063dda7 100644 --- a/Makefile +++ b/Makefile @@ -21,49 +21,35 @@ pch: # TODO(avh): change extension based on platform # Get the current OS name OS = $(shell uname | tr -d '\n') +# Set the specific variables for each platform LIB_PATH ?= /usr/lib HEADER_PATH ?= /usr/include -# Set the specific variables for each platform ifeq ($(OS), Linux) - OS_TYPE ?= Linux - - GPU_CPP_LIB_NAME ?= libgpucpp.so - DAWN_LIB_NAME ?= libdawn.so +OS_TYPE ?= Linux +GPU_CPP_LIB_NAME ?= libgpucpp.so +DAWN_LIB_NAME ?= libdawn.so else ifeq ($(OS), Darwin) - OS_TYPE ?= macOS - - GPU_CPP_LIB_NAME ?= libgpucpp.dylib - DAWN_LIB_NAME ?= libdawn.dylib +OS_TYPE ?= macOS +GPU_CPP_LIB_NAME ?= libgpucpp.dylib +DAWN_LIB_NAME ?= libdawn.dylib else - OS_TYPE ?= unknown +OS_TYPE ?= unknown endif lib: check-clang dawnlib -ifneq ($(OS_TYPE), unknown) mkdir -p build && $(CXX) -std=c++17 $(INCLUDES) -L$(LIBDIR) -ldawn -ldl -shared -fPIC gpu.cpp -o build/$(GPU_CPP_LIB_NAME) python3 build.py cp third_party/lib/$(DAWN_LIB_NAME) build/ -else - @echo "Unsupported operating system" -endif install: -ifneq ($(OS_TYPE), unknown) cp build/$(GPU_CPP_LIB_NAME) $(LIB_PATH) cp build/$(DAWN_LIB_NAME) $(LIB_PATH) cp build/gpu.hpp $(HEADER_PATH) -else - @echo "Unsupported operating system" -endif uninstall: -ifneq ($(OS_TYPE), unknown) rm $(LIB_PATH)/$(GPU_CPP_LIB_NAME) rm $(LIB_PATH)/$(DAWN_LIB_NAME) rm $(HEADER_PATH)/gpu.hpp -else - @echo "Unsupported operating system" -endif examples/hello_world/build/hello_world: check-clang dawnlib examples/hello_world/run.cpp check-linux-vulkan $(LIBSPEC) && cd examples/hello_world && make build/hello_world && ./build/hello_world @@ -139,15 +125,24 @@ clean-all: # Checks ################################################################################ +# Check all +check-all: check-os check-clang check-cmake check-python + +# check the os +check-os: +ifeq ($(OS_TYPE), unknown) +$(error Unsupported operating system) +endif + # check for the existence of clang++ and cmake check-clang: - @command -v clang++ >/dev/null 2>&1 || { echo >&2 "Please install clang++ with 'sudo apt-get install clang' or 'brew install llvm'"; exit 1; } + @command -v clang++ >/dev/null 2>&1 || { echo -e >&2 "Clang++ is not installed. Please install clang++ to continue.\nOn Debian / Ubuntu: 'sudo apt-get install clang' or 'brew install llvm'\nOn Centos: 'sudo yum install clang'"; exit 1; } check-cmake: - @command -v cmake >/dev/null 2>&1 || { echo >&2 "Please install cmake with 'sudo apt-get install cmake' or 'brew install cmake'"; exit 1; } + @command -v cmake >/dev/null 2>&1 || { echo -e >&2 "Cmake is not installed. Please install cmake to continue.\nOn Debian / Ubuntu: 'sudo apt-get install cmake' or 'brew install cmake'\nOn Centos: 'sudo yum install cmake'"; exit 1; } check-python: - @command -v python3 >/dev/null 2>&1 || { echo >&2 "Python needs to be installed and in your path."; exit 1; } + @command -v python3 >/dev/null 2>&1 || { echo -e >&2 "Python is not installed. Please install python to continue.\nOn Debian / Ubuntu: 'sudo apt-get install python'\nOn Centos: 'sudo yum install python'"; exit 1; } check-linux-vulkan: @echo "Checking system type and Vulkan availability..." @@ -156,7 +151,7 @@ check-linux-vulkan: echo "Vulkan is installed."; \ vulkaninfo; \ else \ - echo "Vulkan is not installed. Please install Vulkan drivers to continue. 
On Debian / Ubuntu: sudo apt install libvulkan1 mesa-vulkan-drivers vulkan-tools"; \
+			echo -e "Vulkan is not installed. Please install Vulkan drivers to continue.\nOn Debian / Ubuntu: 'sudo apt install libvulkan1 mesa-vulkan-drivers vulkan-tools'.\nOn Centos: 'sudo yum install vulkan vulkan-tools.'"; \
 			exit 1; \
 		fi \
 	else \

From d4eb571b18e04c7fe19d8b1dd8662b99d140372c Mon Sep 17 00:00:00 2001
From: Junji Hashimoto
Date: Mon, 21 Oct 2024 11:22:48 +0900
Subject: [PATCH 16/44] Add AoT versions of the ops

---
 experimental/kernels/gpt2_webgpu_aot.cpp | 799 +++++++++++++++++++++++
 experimental/kernels/ops_aot.cpp         | 356 ++++++++++
 experimental/kernels/ops_aot.hpp         | 108 +++
 3 files changed, 1263 insertions(+)
 create mode 100644 experimental/kernels/gpt2_webgpu_aot.cpp
 create mode 100644 experimental/kernels/ops_aot.cpp
 create mode 100644 experimental/kernels/ops_aot.hpp

diff --git a/experimental/kernels/gpt2_webgpu_aot.cpp b/experimental/kernels/gpt2_webgpu_aot.cpp
new file mode 100644
index 0000000..0c136f7
--- /dev/null
+++ b/experimental/kernels/gpt2_webgpu_aot.cpp
@@ -0,0 +1,799 @@
+#include "gpu.hpp"
+#include "ops.hpp"
+/*
+This file trains the GPT-2 model.
+This version is the clean, minimal, reference. As such:
+- it runs on CPU.
+- it does not make the code too complex; it is readable.
+- it does not use any processor-specific instructions, intrinsics and such.
+- it _does_ use a few OpenMP pragmas because this is a large speedup at very low cost
+There will be other versions of this code that specialize it and make it fast.
+*/
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <ctype.h>
+#include <unistd.h>
+#include <stdint.h>
+#include <assert.h>
+#include <math.h>
+#include <string.h>
+#include <time.h>
+#include <vector>
+#ifdef OMP
+#include <omp.h>
+#endif
+// our own utilities
+// defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck
+#include "llmc/utils.h"
+// defines: tokenizer_init, tokenizer_decode, tokenizer_free
+#include "llmc/tokenizer.h"
+// defines: dataloader_init, dataloader_reset, dataloader_next_batch, dataloader_free
+#include "llmc/dataloader.h"
+
+using namespace gpu;
+
+// ----------------------------------------------------------------------------
+// GPT-2 model definition
+
+typedef struct {
+    int max_seq_len; // max sequence length, e.g. 1024
+    int vocab_size; // vocab size, e.g. 50257
+    int padded_vocab_size; // padded to e.g. %128==0, 50304
+    int num_layers; // number of layers, e.g. 12
+    int num_heads; // number of heads in attention, e.g. 12
+    int channels; // number of channels, e.g.
768 +} GPT2Config; + +// the parameters of the model +#define NUM_PARAMETER_TENSORS 16 +#define NUM_PARAMETER_LAYERS 12 +typedef struct { + Tensor wte; // (V, C) + Tensor wpe; // (maxT, C) + std::vector ln1w; // (L, C) + std::vector ln1b; // (L, C) + std::vector qkvw; // (L, 3*C, C) + std::vector qkvb; // (L, 3*C) + std::vector attprojw; // (L, C, C) + std::vector attprojb; // (L, C) + std::vector ln2w; // (L, C) + std::vector ln2b; // (L, C) + std::vector fcw; // (L, 4*C, C) + std::vector fcb; // (L, 4*C) + std::vector fcprojw; // (L, C, 4*C) + std::vector fcprojb; // (L, C) + Tensor lnfw; // (C) + Tensor lnfb; // (C) +} ParameterTensors; + +void fill_in_parameter_sizes(size_t* param_sizes, GPT2Config config) { + size_t Vp = config.padded_vocab_size; + size_t C = config.channels; + size_t maxT = config.max_seq_len; + size_t L = config.num_layers; + param_sizes[0] = Vp * C; // wte + param_sizes[1] = maxT * C; // wpe + param_sizes[2] = L * C; // ln1w + param_sizes[3] = L * C; // ln1b + param_sizes[4] = L * (3 * C) * C; // qkvw + param_sizes[5] = L * (3 * C); // qkvb + param_sizes[6] = L * C * C; // attprojw + param_sizes[7] = L * C; // attprojb + param_sizes[8] = L * C; // ln2w + param_sizes[9] = L * C; // ln2b + param_sizes[10] = L * (4 * C) * C; // fcw + param_sizes[11] = L * (4 * C); // fcb + param_sizes[12] = L * C * (4 * C); // fcprojw + param_sizes[13] = L * C; // fcprojb + param_sizes[14] = C; // lnfw + param_sizes[15] = C; // lnfb +} + +// allocate memory for the parameters and point the individual tensors to the right places +float* malloc_and_point_parameters(ParameterTensors* params, size_t* param_sizes) { + size_t num_parameters = 0; + for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) { + num_parameters += param_sizes[i]; + } + // malloc all parameters all at once + float* params_memory = (float*)mallocCheck(num_parameters * sizeof(float)); + // assign all the tensors + float** ptrs[] = { + ¶ms->wte, ¶ms->wpe, ¶ms->ln1w, ¶ms->ln1b, ¶ms->qkvw, ¶ms->qkvb, + ¶ms->attprojw, ¶ms->attprojb, ¶ms->ln2w, ¶ms->ln2b, ¶ms->fcw, ¶ms->fcb, + ¶ms->fcprojw, ¶ms->fcprojb, ¶ms->lnfw, ¶ms->lnfb + }; + float* params_memory_iterator = params_memory; + for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) { + *(ptrs[i]) = params_memory_iterator; + params_memory_iterator += param_sizes[i]; + } + return params_memory; +} + + +#define NUM_ACTIVATION_TENSORS 23 +typedef struct { + Tensor encoded; // (B, T, C) + std::vector ln1; // (L, B, T, C) + std::vector ln1_mean; // (L, B, T) + std::vector ln1_rstd; // (L, B, T) + std::vector qkv; // (L, B, T, 3*C) + std::vector atty; // (L, B, T, C) + std::vector preatt; // (L, B, NH, T, T) + std::vector att; // (L, B, NH, T, T) + std::vector attproj; // (L, B, T, C) + std::vector residual2; // (L, B, T, C) + std::vector ln2; // (L, B, T, C) + std::vector ln2_mean; // (L, B, T) + std::vector ln2_rstd; // (L, B, T) + std::vector fch; // (L, B, T, 4*C) + std::vector fch_gelu; // (L, B, T, 4*C) + std::vector fcproj; // (L, B, T, C) + std::vector residual3; // (L, B, T, C) + Tensor lnf; // (B, T, C) + Tensor lnf_mean; // (B, T) + Tensor lnf_rstd; // (B, T) + Tensor logits; // (B, T, V) + Tensor probs; // (B, T, V) + Tensor losses; // (B, T) +} ActivationTensors; + +typedef struct { + Kernel encoder_forward; + std::vector layernorm_forward; + std::vector qkv_projection_forward; + std::vector attention_forward; + std::vector attention_projection_forward; + std::vector residual_forward; + std::vector ff_up_forward; + std::vector gelu_forward; + std::vector ff_down_forward; + 
std::vector residual2_forward; + Kernel layernorm_final_forward; + Kernel matmul_final_forward; + Kernel softmax_final_forward; + std::vector crossentropy_forward; + + Kernel crossentropy_softmax_backward; + Kernel matmul_final_backward; + Kernel layernorm_final_backward; + std::vector residual2_backward; + std::vector ff_down_backward; + std::vector gelu_backward; + std::vector ff_up_backward; + std::vector layernorm2_backward; + std::vector attention_projection_backward; + std::vector attention_backward; + std::vector qkv_projection_backward; + std::vector layernorm1_backward; + Kernel encoder_backward; +} Kernels; + +void fill_in_activation_sizes(size_t* act_sizes, GPT2Config config, int B, int T) { + size_t C = config.channels; + size_t NH = config.num_heads; + size_t L = config.num_layers; + size_t Vp = config.padded_vocab_size; + act_sizes[0] = B * T * C; // encoded + act_sizes[1] = L * B * T * C; // ln1 + act_sizes[2] = L * B * T; // ln1_mean + act_sizes[3] = L * B * T; // ln1_rstd + act_sizes[4] = L * B * T * 3 * C; // qkv + act_sizes[5] = L * B * T * C; // atty + act_sizes[6] = L * B * NH * T * T; // preatt + act_sizes[7] = L * B * NH * T * T; // att + act_sizes[8] = L * B * T * C; // attproj + act_sizes[9] = L * B * T * C; // residual2 + act_sizes[10] = L * B * T * C; // ln2 + act_sizes[11] = L * B * T; // ln2_mean + act_sizes[12] = L * B * T; // ln2_rstd + act_sizes[13] = L * B * T * 4 * C; // fch + act_sizes[14] = L * B * T * 4 * C; // fch_gelu + act_sizes[15] = L * B * T * C; // fcproj + act_sizes[16] = L * B * T * C; // residual3 + act_sizes[17] = B * T * C; // lnf + act_sizes[18] = B * T; // lnf_mean + act_sizes[19] = B * T; // lnf_rstd + act_sizes[20] = B * T * Vp; // logits + act_sizes[21] = B * T * Vp; // probs + act_sizes[22] = B * T; // losses +} + +float* malloc_and_point_activations(ActivationTensors* acts, size_t* act_sizes) { + size_t num_activations = 0; + for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { + num_activations += act_sizes[i]; + } + float* acts_memory = (float*)mallocCheck(num_activations * sizeof(float)); + float** ptrs[] = { + &acts->encoded, &acts->ln1, &acts->ln1_mean, &acts->ln1_rstd, &acts->qkv, &acts->atty, + &acts->preatt, &acts->att, &acts->attproj, &acts->residual2, &acts->ln2, &acts->ln2_mean, + &acts->ln2_rstd, &acts->fch, &acts->fch_gelu, &acts->fcproj, &acts->residual3, &acts->lnf, + &acts->lnf_mean, &acts->lnf_rstd, &acts->logits, &acts->probs, &acts->losses + }; + float* acts_memory_iterator = acts_memory; + for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { + *(ptrs[i]) = acts_memory_iterator; + acts_memory_iterator += act_sizes[i]; + } + return acts_memory; +} + +struct GPUParameters { + Tensor data[NUM_PARAMETER_TENSORS]; +}; + +struct GPUActivations { + Tensor data[NUM_ACTIVATION_TENSORS]; +}; + + +void gpu_alloc(Context& ctx, Tensor* tensors, size_t* sizes, size_t n) { + for (size_t i = 0; i < n; i++) { + tensors[i] = createTensor(ctx, Shape{sizes[i]}, kf32); + } +} + +typedef struct { + GPT2Config config; + // the weights (parameters) of the model, and their sizes + ParameterTensors params; + GPUParameters params_; // TODO(avh): eventually this replaces params + size_t param_sizes[NUM_PARAMETER_TENSORS]; + float* params_memory; + size_t num_parameters; + // gradients of the weights + ParameterTensors grads; + float* grads_memory; + // buffers for the AdamW optimizer + float* m_memory; + float* v_memory; + // the activations of the model, and their sizes + ActivationTensors acts; + GPUActivations acts_; // TODO(avh): 
eventually this replaces params + size_t act_sizes[NUM_ACTIVATION_TENSORS]; + float* acts_memory; + size_t num_activations; + // gradients of the activations + ActivationTensors grads_acts; + float* grads_acts_memory; + // other run state configuration + int batch_size; // the batch size (B) of current forward pass + int seq_len; // the sequence length (T) of current forward pass + int* inputs; // the input tokens for the current forward pass + int* targets; // the target tokens for the current forward pass + float mean_loss; // after a forward pass with targets, will be populated with the mean loss +} GPT2; + +void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoint_path) { + + // read in model from a checkpoint file + FILE *model_file = fopenCheck(checkpoint_path, "rb"); + int model_header[256]; + freadCheck(model_header, sizeof(int), 256, model_file); + if (model_header[0] != 20240326) { printf("Bad magic model file\n"); exit(1); } + if (model_header[1] != 3) { + printf("Bad version in model file\n"); + printf("---> HINT: try to re-run `python train_gpt2.py`\n"); + exit(1); + } + + // read in hyperparameters + size_t maxT, V, Vp, L, NH, C; // size_t to prevent int overflow + model->config.max_seq_len = maxT = model_header[2]; + model->config.vocab_size = V = model_header[3]; +#ifdef __EMSCRIPTEN__ + model->config.num_layers = L = 12; // TODO(avh): Debugging only hack - revert this +#else + model->config.num_layers = L = model_header[4]; +#endif + model->config.num_heads = NH = model_header[5]; + model->config.channels = C = model_header[6]; + model->config.padded_vocab_size = Vp = model_header[7]; + printf("[GPT-2]\n"); + printf("max_seq_len: %zu\n", maxT); + printf("vocab_size: %zu\n", V); + printf("padded_vocab_size: %zu\n", Vp); + printf("num_layers: %zu\n", L); + printf("num_heads: %zu\n", NH); + printf("channels: %zu\n", C); + + // allocate space for all the parameters and read them in + fill_in_parameter_sizes(model->param_sizes, model->config); + + // count the number of parameters + size_t num_parameters = 0; + for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) { + num_parameters += model->param_sizes[i]; + } + printf("num_parameters: %zu\n", num_parameters); + model->num_parameters = num_parameters; + + // read in all the parameters from file + model->params_memory = malloc_and_point_parameters(&model->params, model->param_sizes); + freadCheck(model->params_memory, sizeof(float), num_parameters, model_file); + fcloseCheck(model_file); + + // other inits + model->acts_memory = NULL; + model->grads_memory = NULL; + model->m_memory = NULL; + model->v_memory = NULL; + model->grads_acts_memory = NULL; + model->inputs = NULL; + model->targets = NULL; + model->batch_size = 0; + model->seq_len = 0; + model->mean_loss = -1.0f; // -1.0f will designate no loss + + // TODO(avh): this is just a resource test for now, eventually deprecate CPU allocations + gpu_alloc(ctx, model->params_.data, model->param_sizes, NUM_PARAMETER_TENSORS); + +} + + +void gpt2_forward(Context& ctx, GPT2 *model, int* inputs, int* targets, size_t B, size_t T) { + // targets are optional and could be NULL + + // ensure the model was initialized or error out + if (model->params_memory == NULL) { + printf("Error: model was not initialized properly.\n"); + exit(1); + } + + // convenience parameters (size_t to help prevent int overflow) + size_t V = model->config.vocab_size; + size_t Vp = model->config.padded_vocab_size; + size_t L = model->config.num_layers; + size_t NH = 
model->config.num_heads; + size_t C = model->config.channels; + + // validate inputs, all indices must be in the range [0, V) + for(int i = 0; i < B * T; i++) { + assert(0 <= inputs[i] && inputs[i] < V); + if (targets != NULL) { + assert(0 <= targets[i] && targets[i] < V); + } + } + + // allocate space for all the activations if needed (done here, lazily) + if(model->acts_memory == NULL) { + // record the current B,T as well + model->batch_size = B; + model->seq_len = T; + // and now allocate the space + fill_in_activation_sizes(model->act_sizes, model->config, B, T); + // TODO(avh): this is just a resource test for now, eventually deprecate CPU allocations + gpu_alloc(ctx, model->acts_.data, model->act_sizes, NUM_PARAMETER_TENSORS); + size_t num_activations = 0; + for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { + num_activations += model->act_sizes[i]; + } + printf("num_activations: %zu\n", num_activations); + model->num_activations = num_activations; + printf("Allocating %.2f MB for activations\n", num_activations * sizeof(float) / (1024.0f * 1024.0f)); + model->acts_memory = malloc_and_point_activations(&model->acts, model->act_sizes); + // also create memory for caching inputs and targets + model->inputs = (int*)mallocCheck(B * T * sizeof(int)); + model->targets = (int*)mallocCheck(B * T * sizeof(int)); // might be unused if we never have targets but it's small + } else { + // validate B,T is consistent with how we've allocated the memory before + // in principle we could get more clever here in the future, for now this is safest + if (B != model->batch_size || T != model->seq_len) { + printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, (int)B, (int)T); + exit(EXIT_FAILURE); + } + } + + printf("Cache inputs/targets\n"); + // cache the inputs/targets + memcpy(model->inputs, inputs, B * T * sizeof(int)); + if (targets != NULL) { + memcpy(model->targets, targets, B * T * sizeof(int)); + } + + printf("Forward pass\n"); + // forward pass + ParameterTensors params = model->params; // for brevity + ActivationTensors acts = model->acts; + float* residual; + printf("Encoding\n"); + printf("inputs[0] = %d\n", inputs[0]); + encoder_forward(ctx, acts.encoded, inputs, params.wte, params.wpe, B, T, C); // encoding goes into residual[0] + for (int l = 0; l < L; l++) { + printf("Forward Pass Layer %d\n", l); + + residual = l == 0 ? 
acts.encoded : acts.residual3 + (l-1) * B * T * C; + + // get the pointers of the weights for this layer + float* l_ln1w = params.ln1w + l * C; + float* l_ln1b = params.ln1b + l * C; + float* l_qkvw = params.qkvw + l * 3*C * C; + float* l_qkvb = params.qkvb + l * 3*C; + float* l_attprojw = params.attprojw + l * C * C; + float* l_attprojb = params.attprojb + l * C; + float* l_ln2w = params.ln2w + l * C; + float* l_ln2b = params.ln2b + l * C; + float* l_fcw = params.fcw + l * 4*C * C; + float* l_fcb = params.fcb + l * 4*C; + float* l_fcprojw = params.fcprojw + l * C * 4*C; + float* l_fcprojb = params.fcprojb + l * C; + + // get the pointers of the activations for this layer + float* l_ln1 = acts.ln1 + l * B * T * C; + float* l_ln1_mean = acts.ln1_mean + l * B * T; + float* l_ln1_rstd = acts.ln1_rstd + l * B * T; + float* l_qkv = acts.qkv + l * B * T * 3*C; + float* l_atty = acts.atty + l * B * T * C; + float* l_preatt = acts.preatt + l * B * NH * T * T; + float* l_att = acts.att + l * B * NH * T * T; + float* l_attproj = acts.attproj + l * B * T * C; + float* l_residual2 = acts.residual2 + l * B * T * C; + float* l_ln2 = acts.ln2 + l * B * T * C; + float* l_ln2_mean = acts.ln2_mean + l * B * T; + float* l_ln2_rstd = acts.ln2_rstd + l * B * T; + float* l_fch = acts.fch + l * B * T * 4*C; + float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C; + float* l_fcproj = acts.fcproj + l * B * T * C; + float* l_residual3 = acts.residual3 + l * B * T * C; + + // now do the forward pass + printf(" [Forward] : LayerNorm1\n"); + layernorm_forward(ctx, l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C); + printf(" [Forward] : QKV Projection\n"); + matmul_forward(ctx, l_qkv, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C); + printf(" [Forward] : Attention\n"); + attention_forward(ctx, l_atty, l_preatt, l_att, l_qkv, B, T, C, NH); + printf(" [Forward] : Attention Projection\n"); + matmul_forward(ctx, l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C); + printf(" [Forward] : Residual1\n"); + residual_forward(ctx, l_residual2, residual, l_attproj, B*T*C); + printf(" [Forward] : LayerNorm2\n"); + layernorm_forward(ctx, l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C); + printf(" [Forward] : FF Up\n"); + matmul_forward(ctx, l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C); + printf(" [Forward] : GELU\n"); + gelu_forward(ctx, l_fch_gelu, l_fch, B*T*4*C); + printf(" [Forward] : FF Down\n"); + matmul_forward(ctx, l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C); + printf(" [Forward] : Residual2\n"); + residual_forward(ctx, l_residual3, l_residual2, l_fcproj, B*T*C); + } + residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 + layernorm_forward(ctx, acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C); + matmul_forward(ctx, acts.logits, acts.lnf, params.wte, NULL, B, T, C, Vp); + softmax_forward(ctx, acts.probs, acts.logits, B, T, V, Vp); + + printf("Crossentropy\n"); + // also forward the cross-entropy loss function if we have the targets + if (targets != NULL) { + crossentropy_forward(ctx, model->acts.losses, model->acts.probs, targets, B, T, Vp); + // for convenience also evaluate the mean loss + float mean_loss = 0.0f; + for (int i=0; iacts.losses[i]; } + mean_loss /= B*T; + model->mean_loss = mean_loss; + } else { + // if we don't have targets, we don't have a loss + model->mean_loss = -1.0f; + } + printf("Forward pass done\n"); +} + +void gpt2_zero_grad(GPT2 *model) { + if(model->grads_memory != NULL) { 
memset(model->grads_memory, 0, model->num_parameters * sizeof(float)); } + if(model->grads_acts_memory != NULL) { memset(model->grads_acts_memory, 0, model->num_activations * sizeof(float)); } +} + +void gpt2_backward(Context& ctx, GPT2 *model) { + printf("Backward pass\n"); + + // double check we forwarded previously, with targets + if (model->mean_loss == -1.0f) { + printf("Error: must forward with targets before backward\n"); + exit(1); + } + + // lazily allocate the memory for gradients of the weights and activations, if needed + if (model->grads_memory == NULL) { + printf("Allocating %.2f MB for gradients\n", model->num_parameters * sizeof(float) / (1024.0f * 1024.0f)); + model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_sizes); + model->grads_acts_memory = malloc_and_point_activations(&model->grads_acts, model->act_sizes); + gpt2_zero_grad(model); + } + + // convenience shortcuts (and size_t to help prevent int overflow) + size_t B = model->batch_size; + size_t T = model->seq_len; + size_t V = model->config.vocab_size; + size_t Vp = model->config.padded_vocab_size; + size_t L = model->config.num_layers; + size_t NH = model->config.num_heads; + size_t C = model->config.channels; + + // backward pass: go in the reverse order of the forward pass, and call backward() functions + ParameterTensors params = model->params; // for brevity + ParameterTensors grads = model->grads; + ActivationTensors acts = model->acts; + ActivationTensors grads_acts = model->grads_acts; + + // we kick off the chain rule by filling in dlosses with 1.0f/(B*T) + // technically this is a small, inline backward() pass of calculating + // total, final loss as the mean over all losses over all (B,T) positions in the batch + float dloss_mean = 1.0f / (B*T); + for (int i = 0; i < B*T; i++) { grads_acts.losses[i] = dloss_mean; } + + crossentropy_softmax_backward(ctx, grads_acts.logits, grads_acts.losses, acts.probs, model->targets, B, T, V, Vp); + matmul_backward(ctx, grads_acts.lnf, grads.wte, NULL, grads_acts.logits, acts.lnf, params.wte, B, T, C, Vp); + float* residual = acts.residual3 + (L-1) * B * T * C; // last layer's residual + float* dresidual = grads_acts.residual3 + (L-1) * B * T * C; // write to last layer's residual + layernorm_backward(ctx, dresidual, grads.lnfw, grads.lnfb, grads_acts.lnf, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C); + + for (int l = L-1; l >= 0; l--) { + printf("Backward Pass Layer %d\n", l); + + residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C; + dresidual = l == 0 ? 
grads_acts.encoded : grads_acts.residual3 + (l-1) * B * T * C; + + // get the pointers of the weights for this layer + float* l_ln1w = params.ln1w + l * C; + float* l_qkvw = params.qkvw + l * 3*C * C; + float* l_attprojw = params.attprojw + l * C * C; + float* l_ln2w = params.ln2w + l * C; + float* l_fcw = params.fcw + l * 4*C * C; + float* l_fcprojw = params.fcprojw + l * C * 4*C; + // get the pointers of the gradients of the weights for this layer + float* dl_ln1w = grads.ln1w + l * C; + float* dl_ln1b = grads.ln1b + l * C; + float* dl_qkvw = grads.qkvw + l * 3*C * C; + float* dl_qkvb = grads.qkvb + l * 3*C; + float* dl_attprojw = grads.attprojw + l * C * C; + float* dl_attprojb = grads.attprojb + l * C; + float* dl_ln2w = grads.ln2w + l * C; + float* dl_ln2b = grads.ln2b + l * C; + float* dl_fcw = grads.fcw + l * 4*C * C; + float* dl_fcb = grads.fcb + l * 4*C; + float* dl_fcprojw = grads.fcprojw + l * C * 4*C; + float* dl_fcprojb = grads.fcprojb + l * C; + // get the pointers of the activations for this layer + float* l_ln1 = acts.ln1 + l * B * T * C; + float* l_ln1_mean = acts.ln1_mean + l * B * T; + float* l_ln1_rstd = acts.ln1_rstd + l * B * T; + float* l_qkv = acts.qkv + l * B * T * 3*C; + float* l_atty = acts.atty + l * B * T * C; + float* l_att = acts.att + l * B * NH * T * T; + float* l_residual2 = acts.residual2 + l * B * T * C; + float* l_ln2 = acts.ln2 + l * B * T * C; + float* l_ln2_mean = acts.ln2_mean + l * B * T; + float* l_ln2_rstd = acts.ln2_rstd + l * B * T; + float* l_fch = acts.fch + l * B * T * 4*C; + float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C; + // get the pointers of the gradients of the activations for this layer + float* dl_ln1 = grads_acts.ln1 + l * B * T * C; + float* dl_qkv = grads_acts.qkv + l * B * T * 3*C; + float* dl_atty = grads_acts.atty + l * B * T * C; + float* dl_preatt = grads_acts.preatt + l * B * NH * T * T; + float* dl_att = grads_acts.att + l * B * NH * T * T; + float* dl_attproj = grads_acts.attproj + l * B * T * C; + float* dl_residual2 = grads_acts.residual2 + l * B * T * C; + float* dl_ln2 = grads_acts.ln2 + l * B * T * C; + float* dl_fch = grads_acts.fch + l * B * T * 4*C; + float* dl_fch_gelu = grads_acts.fch_gelu + l * B * T * 4*C; + float* dl_fcproj = grads_acts.fcproj + l * B * T * C; + float* dl_residual3 = grads_acts.residual3 + l * B * T * C; + + // backprop this layer + printf(" [Backward] : Residual2\n"); + residual_backward(ctx, dl_residual2, dl_fcproj, dl_residual3, B*T*C); + printf(" [Backward] : FF Down \n"); + matmul_backward(ctx, dl_fch_gelu, dl_fcprojw, dl_fcprojb, dl_fcproj, l_fch_gelu, l_fcprojw, B, T, 4*C, C); + printf(" [Backward] : GELU\n"); + gelu_backward(ctx, dl_fch, l_fch, dl_fch_gelu, B*T*4*C); + printf(" [Backward] : FF Up\n"); + matmul_backward(ctx, dl_ln2, dl_fcw, dl_fcb, dl_fch, l_ln2, l_fcw, B, T, C, 4*C); + printf(" [Backward] : LayerNorm2\n"); + layernorm_backward(ctx, dl_residual2, dl_ln2w, dl_ln2b, dl_ln2, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C); + printf(" [Backward] : Residual1\n"); + residual_backward(ctx, dresidual, dl_attproj, dl_residual2, B*T*C); + printf(" [Backward] : Attention Projection\n"); + matmul_backward(ctx, dl_atty, dl_attprojw, dl_attprojb, dl_attproj, l_atty, l_attprojw, B, T, C, C); + printf(" [Backward] : Attention\n"); + attention_backward(ctx, dl_qkv, dl_preatt, dl_att, dl_atty, l_qkv, l_att, B, T, C, NH); + printf(" [Backward] : QKV Projection\n"); + matmul_backward(ctx, dl_ln1, dl_qkvw, dl_qkvb, dl_qkv, l_ln1, l_qkvw, B, T, C, 3*C); + printf(" [Backward] : 
LayerNorm1\n"); + layernorm_backward(ctx, dresidual, dl_ln1w, dl_ln1b, dl_ln1, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C); + } + encoder_backward(ctx, grads.wte, grads.wpe, grads_acts.encoded, model->inputs, B, T, C); +} + +void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t) { + // reference: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html + + // lazily allocate the memory for m_memory and v_memory + if (model->m_memory == NULL) { + model->m_memory = (float*)calloc(model->num_parameters, sizeof(float)); + model->v_memory = (float*)calloc(model->num_parameters, sizeof(float)); + } + + for (size_t i = 0; i < model->num_parameters; i++) { + float param = model->params_memory[i]; + float grad = model->grads_memory[i]; + + // update the first moment (momentum) + float m = beta1 * model->m_memory[i] + (1.0f - beta1) * grad; + // update the second moment (RMSprop) + float v = beta2 * model->v_memory[i] + (1.0f - beta2) * grad * grad; + // bias-correct both moments + float m_hat = m / (1.0f - powf(beta1, t)); + float v_hat = v / (1.0f - powf(beta2, t)); + + // update + model->m_memory[i] = m; + model->v_memory[i] = v; + model->params_memory[i] -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param); + } +} + +void gpt2_free(GPT2 *model) { + free(model->params_memory); + free(model->grads_memory); + free(model->m_memory); + free(model->v_memory); + free(model->acts_memory); + free(model->grads_acts_memory); + free(model->inputs); + free(model->targets); +} + +#ifndef TESTING +// if we are TESTING (see test_gpt2.c), we'll skip the int main below +// ---------------------------------------------------------------------------- +// sampler + +unsigned int random_u32(uint64_t *state) { + // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A + *state ^= *state >> 12; + *state ^= *state << 25; + *state ^= *state >> 27; + return (*state * 0x2545F4914F6CDD1Dull) >> 32; +} +float random_f32(uint64_t *state) { // random float32 in [0,1) + return (random_u32(state) >> 8) / 16777216.0f; +} + +int sample_mult(float* probabilities, int n, float coin) { + // sample index from probabilities (they must sum to 1!) + // coin is a random number in [0, 1), usually from random_f32() + float cdf = 0.0f; + for (int i = 0; i < n; i++) { + cdf += probabilities[i]; + if (coin < cdf) { + return i; + } + } + return n - 1; // in case of rounding errors +} + +// ---------------------------------------------------------------------------- +// main training loop +int main() { + + setLogLevel(kWarn); + + printf("Creating GPU context\n"); + WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB; + gpu::Context ctx = gpu::createContext({}, {}, { + .requiredLimits = &requiredLimits + }); + // gpu::Context ctx = gpu::createContext(); + + // build the GPT-2 model from a checkpoint + GPT2 model; + gpt2_build_from_checkpoint(ctx, &model, "gpt2_124M.bin"); + + // build the DataLoaders from tokens files. for now use tiny_shakespeare if available, else tiny_stories + const char* tiny_stories_train = "dev/data/tinystories/TinyStories_train.bin"; + const char* tiny_stories_val = "dev/data/tinystories/TinyStories_val.bin"; + const char* tiny_shakespeare_train = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin"; + const char* tiny_shakespeare_val = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin"; + const char* train_tokens = access(tiny_shakespeare_train, F_OK) != -1 ? 
tiny_shakespeare_train : tiny_stories_train; + const char* val_tokens = access(tiny_shakespeare_val, F_OK) != -1 ? tiny_shakespeare_val : tiny_stories_val; + constexpr int B = 4; // batch size 4 (i.e. 4 independent token sequences will be trained on) + constexpr int T = 64; // sequence length 64 (i.e. each sequence is 64 tokens long). must be <= maxT, which is 1024 for GPT-2 + DataLoader train_loader, val_loader; + dataloader_init(&train_loader, train_tokens, B, T, 0, 1, 1); + dataloader_init(&val_loader, val_tokens, B, T, 0, 1, 0); + printf("train dataset num_batches: %zu\n", train_loader.num_tokens / (B*T)); + printf("val dataset num_batches: %zu\n", val_loader.num_tokens / (B*T)); + int val_num_batches = 5; + + // build the Tokenizer + Tokenizer tokenizer; + tokenizer_init(&tokenizer, "gpt2_tokenizer.bin"); + + // some memory for generating samples from the model + uint64_t rng_state = 1337; + int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int)); + const int genT = 64; // number of steps of inference we will do + + + // train + struct timespec start, end; + printf("Starting training\n"); + for (int step = 0; step <= 40; step++) { + printf("Step %d\n", step); + + // once in a while estimate the validation loss + if (step % 10 == 0) { + float val_loss = 0.0f; + dataloader_reset(&val_loader); + for (int i = 0; i < val_num_batches; i++) { + dataloader_next_batch(&val_loader); + gpt2_forward(ctx, &model, val_loader.inputs, val_loader.targets, B, T); + val_loss += model.mean_loss; + } + val_loss /= val_num_batches; + printf("val loss %f\n", val_loss); + } + + // once in a while do model inference to print generated text + if (step > 0 && step % 20 == 0) { + // fill up gen_tokens with the GPT2_EOT, which kicks off the generation + for(int i = 0; i < B * T; ++i) { + gen_tokens[i] = tokenizer.eot_token; + } + // now sample from the model autoregressively + printf("generating:\n---\n"); + for (int t = 1; t < genT; t++) { + // note that inference is very wasteful here because for each token + // we re-calculate the forward pass for all of (B,T) positions from scratch + // but the inference here is just for sanity checking anyway + // and we can maybe optimize a bit more later, with careful tests + gpt2_forward(ctx, &model, gen_tokens, NULL, B, T); + // furthermore, below we're only using b=0 (i.e. 
the first row) of all B rows + // we're in principle running B "inference streams" in parallel here + // but only using position 0 + // get the Vp-dimensional vector probs[0, t-1, :] + float* probs = model.acts.probs + (t-1) * model.config.padded_vocab_size; + float coin = random_f32(&rng_state); + // note we're only sampling from the first V elements, ignoring padding + // (the probabilities in the padded region should be zero anyway) + int next_token = sample_mult(probs, model.config.vocab_size, coin); + gen_tokens[t] = next_token; + // print the generated token, either using the Tokenizer or a fallback + if (tokenizer.init_ok) { + const char* token_str = tokenizer_decode(&tokenizer, next_token); + safe_printf(token_str); + } else { + // fall back to printing the token id + printf("%d ", next_token); + } + fflush(stdout); + } + printf("\n---\n"); + } + + // do a training step + clock_gettime(CLOCK_MONOTONIC, &start); + dataloader_next_batch(&train_loader); + gpt2_forward(ctx, &model, train_loader.inputs, train_loader.targets, B, T); + gpt2_zero_grad(&model); + gpt2_backward(ctx, &model); + gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, step+1); + clock_gettime(CLOCK_MONOTONIC, &end); + double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9; + printf("step %d: train loss %f (took %f ms)\n", step, model.mean_loss, time_elapsed_s * 1000); + } + + // free + dataloader_free(&train_loader); + dataloader_free(&val_loader); + tokenizer_free(&tokenizer); + gpt2_free(&model); + free(gen_tokens); + return 0; +} +#endif diff --git a/experimental/kernels/ops_aot.cpp b/experimental/kernels/ops_aot.cpp new file mode 100644 index 0000000..f4ce9c0 --- /dev/null +++ b/experimental/kernels/ops_aot.cpp @@ -0,0 +1,356 @@ +#include "gpu.hpp" +#include +#include +#include +#include + +#include "kernels.h" +#include "ops_aot.hpp" +#include "experimental/wgsl.h" // loopUnrolling + +using namespace gpu; + +Kernel encoder_forward(Context& ctx, Tensor& out, + Tensor& inp, Tensor& wte, Tensor& wpe, + int B, int T, int C){ + unsigned long b = static_cast(B); + unsigned long t = static_cast(T); + unsigned long c = static_cast(C); + unsigned long v = VOCAB_SIZE; + struct EncoderParams { + uint32_t B; + uint32_t T; + uint32_t C; + }; + setLogLevel(kError); + return createKernel(ctx, {kShaderEncoder, 256, kf32}, + Bindings{inp, wte, wpe, out}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + EncoderParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }); +} + +Kernel encoder_backward(Context& ctx, Tensor& dwte, Tensor& dwpe, + Tensor& dout, Tensor& inp, + int B, int T, int C) { + unsigned long b = static_cast(B); + unsigned long t = static_cast(T); + unsigned long c = static_cast(C); + unsigned long v = VOCAB_SIZE; + struct EncoderParams { + uint32_t B; + uint32_t T; + uint32_t C; + }; + setLogLevel(kError); + return createKernel(ctx, {kShaderEncoderBackward, 256, kf32}, + Bindings{dwte, dwpe, dout, inp}, + /* nWorkgroups */ {cdiv(b * t, 256), 1, 1}, + /* params */ + EncoderParams{ + static_cast(b), + static_cast(t), + static_cast(c) + }); +} + +Kernel layernorm_forward(Context& ctx, Tensor& out, Tensor& mean, Tensor& rstd, + Tensor& inp, Tensor& weight, Tensor& bias, + int B, int T, int C){ + unsigned long b = static_cast(B); + unsigned long t = static_cast(T); + unsigned long c = static_cast(C); + struct LayerNormParams { + uint32_t B; + uint32_t T; + uint32_t C; + }; + setLogLevel(kError); + return createKernel(ctx, {kShaderLayerNorm, 256, kf32}, + 
Bindings{inp, weight, bias, out, mean, rstd},
+                     /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                     /* params */
+                     LayerNormParams{
+                       static_cast<uint32_t>(b),
+                       static_cast<uint32_t>(t),
+                       static_cast<uint32_t>(c)
+                     });
+}
+
+Kernel layernorm_backward(Context& ctx, Tensor& dinp, Tensor& dweight, Tensor& dbias,
+                          Tensor& dout, Tensor& inp, Tensor& weight, Tensor& mean, Tensor& rstd,
+                          int B, int T, int C){
+  unsigned long b = static_cast<unsigned long>(B);
+  unsigned long t = static_cast<unsigned long>(T);
+  unsigned long c = static_cast<unsigned long>(C);
+  struct LayerNormParams {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+  };
+  setLogLevel(kError);
+  return createKernel(ctx, {kShaderLayerNormBackward, 256, kf32},
+                     Bindings{dinp, dweight, dbias, dout, inp, weight, mean, rstd},
+                     /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                     /* params */
+                     LayerNormParams{
+                       static_cast<uint32_t>(b),
+                       static_cast<uint32_t>(t),
+                       static_cast<uint32_t>(c)
+                     });
+}
+
+struct DurationTime {
+  std::chrono::high_resolution_clock::time_point start;
+  std::chrono::high_resolution_clock::time_point end;
+  std::chrono::microseconds duration;
+  std::string src;
+  bool verbose;
+
+  inline DurationTime(const std::string& src, bool verbose = true) {
+    this->src = src;
+    this->verbose = verbose;
+    start = std::chrono::high_resolution_clock::now();
+  }
+
+  inline ~DurationTime() {
+    end = std::chrono::high_resolution_clock::now();
+    duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
+    if (this->verbose) {
+      printf("Duration(%s): %.1f microseconds\n", src.c_str(), static_cast<double>(duration.count()));
+    }
+  }
+};
+
+
+Kernel matmul_forward(Context& ctx, Tensor& out,
+                      const Tensor& inp, const Tensor& weight, const Tensor& bias,
+                      int B, int T, int C, int OC){
+  bool verbose = false;
+  DurationTime duration("matmul_forward_gpu", verbose);
+  struct MatmulParams {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t OC;
+  };
+  unsigned long b = static_cast<unsigned long>(B);
+  unsigned long t = static_cast<unsigned long>(T);
+  unsigned long c = static_cast<unsigned long>(C);
+  unsigned long oc = static_cast<unsigned long>(OC);
+  setLogLevel(kError);
+
+  constexpr size_t BT = 64;
+  constexpr size_t BC = 8;
+  constexpr size_t BOC = 64;
+  constexpr size_t TT = BT / BC;
+  constexpr size_t TOC = BOC / BC;
+  size_t num_threads = BT * BOC / (TT * TOC);
+  Shape wgSize = {num_threads, 1, 1};
+  Shape nWorkgroups = {b, cdiv(T, BT), cdiv(OC, BOC)};
+
+  std::string kShaderMatmul2DTiling_(kShaderMatmul2DTiling);
+  std::string kShaderMatmul2D(loopUnrolling(
+      replaceAll(kShaderMatmul2DTiling_,
+                 {{"{{precision}}", toString(kf32)},
+                  {"{{BT}}", toString(BT)},
+                  {"{{BC}}", toString(BC)},
+                  {"{{BOC}}", toString(BOC)},
+                  {"{{TT}}", toString(TT)},
+                  {"{{TOC}}", toString(TOC)},
+                  {"{{NUM_TILEI}}", toString(BT * BC / num_threads)},
+                  {"{{NUM_TILEW}}", toString(BOC * BC / num_threads)}
+                 })
+      )
+    );
+
+  return createKernel(ctx, {kShaderMatmul2D, wgSize, kf32},
+                     Bindings{inp, weight, bias, out},
+                     nWorkgroups,
+                     /* params */
+                     MatmulParams{
+                       static_cast<uint32_t>(b),
+                       static_cast<uint32_t>(t),
+                       static_cast<uint32_t>(c),
+                       static_cast<uint32_t>(oc)
+                     });
+}
+
+Kernel matmul_backward(Context& ctx, Tensor& dinp, Tensor& dweight, Tensor& dbias,
+                       const Tensor& dout, const Tensor& inp, const Tensor& weight,
+                       int B, int T, int C, int OC){
+  struct MatmulParams {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t OC;
+  };
+  unsigned long b = static_cast<unsigned long>(B);
+  unsigned long t = static_cast<unsigned long>(T);
+  unsigned long c = static_cast<unsigned long>(C);
+  unsigned long oc = static_cast<unsigned long>(OC);
+  setLogLevel(kError);
+  return createKernel(ctx, {kShaderMatmulBackward, 256, kf32},
+                     Bindings{dinp, dweight, dbias, dout, inp, weight},
+                     /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                     /* params */
+                     MatmulParams{
+                       static_cast<uint32_t>(b),
+                       static_cast<uint32_t>(t),
+                       static_cast<uint32_t>(c),
+                       static_cast<uint32_t>(oc)
+                     });
+}
+
+Kernel attention_forward(Context& ctx, Tensor& out, Tensor& preatt, Tensor& att,
+                         Tensor& inp,
+                         int B, int T, int C, int NH){
+  struct AttentionParams {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t NH;
+  };
+  unsigned long b = static_cast<unsigned long>(B);
+  unsigned long t = static_cast<unsigned long>(T);
+  unsigned long c = static_cast<unsigned long>(C);
+  unsigned long nh = static_cast<unsigned long>(NH);
+  setLogLevel(kError);
+  return createKernel(ctx, {kShaderAttention, 256, kf32},
+                     Bindings{inp, preatt, att, out},
+                     /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                     /* params */
+                     AttentionParams{
+                       static_cast<uint32_t>(b),
+                       static_cast<uint32_t>(t),
+                       static_cast<uint32_t>(c),
+                       static_cast<uint32_t>(nh)
+                     });
+}
+
+Kernel attention_backward(Context& ctx, Tensor& dinp, Tensor& dpreatt, Tensor& datt,
+                          Tensor& dout, Tensor& inp, Tensor& att,
+                          int B, int T, int C, int NH){
+  struct AttentionParams {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t NH;
+  };
+  unsigned long b = static_cast<unsigned long>(B);
+  unsigned long t = static_cast<unsigned long>(T);
+  unsigned long c = static_cast<unsigned long>(C);
+  unsigned long nh = static_cast<unsigned long>(NH);
+  setLogLevel(kError);
+  return createKernel(ctx, {kShaderAttentionBackward, 256, kf32},
+                     Bindings{dinp, dpreatt, datt, dout, inp, att},
+                     /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                     /* params */
+                     AttentionParams{
+                       static_cast<uint32_t>(b),
+                       static_cast<uint32_t>(t),
+                       static_cast<uint32_t>(c),
+                       static_cast<uint32_t>(nh)
+                     });
+}
+
+Kernel gelu_forward(Context& ctx, Tensor& out, Tensor& inp, int n) {
+  unsigned long N = static_cast<unsigned long>(n);
+  setLogLevel(kError);
+  return createKernel(ctx, {kShaderGelu, 256, kf32},
+                     Bindings{inp, out},
+                     /* nWorkgroups */ {cdiv(N, 256), 1, 1});
+}
+
+Kernel gelu_backward(Context& ctx, Tensor& dinp, Tensor& inp, Tensor& dout, int N){
+  unsigned long n = static_cast<unsigned long>(N);
+  setLogLevel(kError);
+  return createKernel(ctx, {kShaderGeluBackward, 256, kf32},
+                     Bindings{inp, dout, dinp},
+                     /* nWorkgroups */ {cdiv(n, 256), 1, 1});
+}
+
+Kernel residual_forward(Context& ctx, Tensor& out, Tensor& inp1, Tensor& inp2, int N){
+  unsigned long n = static_cast<unsigned long>(N);
+  setLogLevel(kError);
+  return createKernel(ctx, {kShaderResidual, 256, kf32},
+                     Bindings{inp1, inp2, out},
+                     /* nWorkgroups */ {cdiv(n, 256), 1, 1});
+}
+
+Kernel residual_backward(Context& ctx, Tensor& dinp1, Tensor& dinp2, Tensor& dout, int N){
+  unsigned long n = static_cast<unsigned long>(N);
+  setLogLevel(kError);
+  return createKernel(ctx, {kShaderResidualBackward, 256, kf32},
+                     Bindings{dout, dinp1, dinp2},
+                     /* nWorkgroups */ {cdiv(n, 256), 1, 1});
+}
+
+Kernel softmax_forward(Context& ctx, Tensor& probs, Tensor& logits, int B, int T, int V, int Vp) {
+  struct SoftmaxParam {
+    uint32_t N;
+    uint32_t C;
+    uint32_t Cp;
+  };
+  uint32_t b = static_cast<uint32_t>(B);
+  uint32_t t = static_cast<uint32_t>(T);
+  uint32_t c = static_cast<uint32_t>(V);
+  uint32_t cp = static_cast<uint32_t>(Vp);
+  assert( (B*T) % 256 == 0);
+  return createKernel(
+      ctx, {kShaderSoftmax1, 256, kf32}, Bindings{logits, probs},
+      Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp});
+}
+
+Kernel crossentropy_forward(Context& ctx, Tensor& losses,
+                            Tensor& probs, Tensor& targets,
+                            int B, int T, int Vp){
+  struct CrossEntropyParams {
+    uint32_t B;
+    uint32_t T;
+    uint32_t VP;
+  };
+  unsigned long b = static_cast<unsigned long>(B);
+  unsigned long t = static_cast<unsigned long>(T);
+  unsigned long vp = static_cast<unsigned long>(Vp);
+  setLogLevel(kError);
+  return createKernel(ctx, {kShaderCrossEntropyForward, 256, kf32},
+                     Bindings{losses, probs, targets},
+                     /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                     /* params */
+                     CrossEntropyParams{
+                       static_cast<uint32_t>(b),
+                       static_cast<uint32_t>(t),
+                       static_cast<uint32_t>(vp)
+                     });
+}
+
+Kernel crossentropy_softmax_backward(Context& ctx, Tensor& dlogits,
+                                     Tensor& dlosses, Tensor& probs, Tensor& targets,
+                                     int B, int T, int V, int Vp){
+  struct CrossEntropySoftmaxBackwardParams {
+    uint32_t B;
+    uint32_t T;
+    uint32_t V;
+    uint32_t VP;
+  };
+  unsigned long b = static_cast<unsigned long>(B);
+  unsigned long t = static_cast<unsigned long>(T);
+  unsigned long v = static_cast<unsigned long>(V);
+  unsigned long vp = static_cast<unsigned long>(Vp);
+  setLogLevel(kError);
+  return createKernel(ctx, {kShaderCrossEntropySoftmaxBackward, 256, kf32},
+                     Bindings{dlogits, dlosses, probs, targets},
+                     /* nWorkgroups */ {cdiv(b * t, 256), 1, 1},
+                     /* params */
+                     CrossEntropySoftmaxBackwardParams{
+                       static_cast<uint32_t>(b),
+                       static_cast<uint32_t>(t),
+                       static_cast<uint32_t>(v),
+                       static_cast<uint32_t>(vp)
+                     });
+}
diff --git a/experimental/kernels/ops_aot.hpp b/experimental/kernels/ops_aot.hpp
new file mode 100644
index 0000000..8ec6d8e
--- /dev/null
+++ b/experimental/kernels/ops_aot.hpp
@@ -0,0 +1,108 @@
+#ifndef OPS_H
+#define OPS_H
+
+#include "gpu.hpp"
+
+using namespace gpu;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define VOCAB_SIZE 50257
+
+// See https://github.com/google/dawn/blob/a8fbe981a86cb59536e2de423d2013a82d9b54a0/src/dawn/native/Limits.cpp
+#define LIMITS_BUFFER_SIZE_1GB { \
+    .nextInChain = nullptr, \
+    .limits = { \
+      .maxTextureDimension1D=8192, \
+      .maxTextureDimension2D=8192, \
+      .maxTextureDimension3D=2048, \
+      .maxTextureArrayLayers=256, \
+      .maxBindGroups=4, \
+      .maxBindGroupsPlusVertexBuffers=24, \
+      .maxBindingsPerBindGroup=1000, \
+      .maxDynamicUniformBuffersPerPipelineLayout=8, \
+      .maxDynamicStorageBuffersPerPipelineLayout=4, \
+      .maxSampledTexturesPerShaderStage=16, \
+      .maxSamplersPerShaderStage=16, \
+      .maxStorageBuffersPerShaderStage=8, \
+      .maxStorageTexturesPerShaderStage=4, \
+      .maxUniformBuffersPerShaderStage=12, \
+      .maxUniformBufferBindingSize=65536, \
+      .maxStorageBufferBindingSize=1073741824, \
+      .minUniformBufferOffsetAlignment=256, \
+      .minStorageBufferOffsetAlignment=256, \
+      .maxVertexBuffers=8, \
+      .maxBufferSize=0x80000000, \
+      .maxVertexAttributes=16, \
+      .maxVertexBufferArrayStride=2048, \
+      .maxInterStageShaderComponents=64, \
+      .maxInterStageShaderVariables=16, \
+      .maxColorAttachments=8, \
+      .maxColorAttachmentBytesPerSample=32, \
+      .maxComputeWorkgroupStorageSize=16384, \
+      .maxComputeInvocationsPerWorkgroup=256, \
+      .maxComputeWorkgroupSizeX=256, \
+      .maxComputeWorkgroupSizeY=256, \
+      .maxComputeWorkgroupSizeZ=64, \
+      .maxComputeWorkgroupsPerDimension=65535 \
+    } \
+  }
+
+
+Kernel encoder_forward(Context& ctx, Tensor& out,
+                       Tensor& inp, Tensor& wte, Tensor& wpe,
+                       int B, int T, int C);
+
+Kernel encoder_backward(Context& ctx, Tensor& dwte, Tensor& dwpe,
+                        Tensor& dout, Tensor& inp,
+                        int B, int T, int C);
+
+Kernel layernorm_forward(Context& ctx, Tensor& out, Tensor& mean, Tensor& rstd,
+                         Tensor& inp, Tensor& weight, Tensor& bias,
+                         int B, int T, int C);
+
+Kernel layernorm_backward(Context& ctx, Tensor& dinp, Tensor& dweight, Tensor& dbias,
+                          Tensor& dout, Tensor& inp, Tensor& weight, Tensor& mean, Tensor& rstd,
+                          int B, int T, int C);
+
+Kernel matmul_forward(Context& ctx, Tensor& out,
+                      const Tensor& inp, const Tensor& weight, const Tensor& bias,
+                      int B, int T, int C, int OC);
+
+Kernel matmul_backward(Context& ctx, Tensor& dinp, Tensor& dweight, Tensor& dbias,
+                       const Tensor& dout, const Tensor& inp, const Tensor& weight,
+                       int B, int T, int C, int OC);
+
+Kernel attention_forward(Context& ctx, Tensor& out, Tensor& preatt, Tensor& att,
+                         Tensor& inp,
+                         int
B, int T, int C, int NH); + +Kernel attention_backward(Context& ctx, Tensor& dinp, Tensor& dpreatt, Tensor& datt, + Tensor& dout, Tensor& inp, Tensor& att, + int B, int T, int C, int NH); + +Kernel gelu_forward(Context& ctx, Tensor& out, Tensor& inp, int N); + +Kernel gelu_backward(Context& ctx, Tensor& dinp, Tensor& inp, Tensor& dout, int N); + +Kernel residual_forward(Context& ctx, Tensor& out, Tensor& inp1, Tensor& inp2, int N); + +Kernel residual_backward(Context& ctx, Tensor& dinp1, Tensor& dinp2, Tensor& dout, int N); + +Kernel softmax_forward(Context& ctx, Tensor& probs, Tensor& logits, int B, int T, int V, int Vp); + +Kernel crossentropy_forward(Context& ctx, Tensor& losses, + Tensor& probs, Tensor& targets, + int B, int T, int Vp); + +Kernel crossentropy_softmax_backward(Context& ctx, Tensor& dlogits, + Tensor& dlosses, Tensor& probs, Tensor& targets, + int B, int T, int V, int Vp); + +#ifdef __cplusplus +} +#endif + +#endif // OPS_H From 43e4ac0e594cdb00a71c5ced3170441b94b73ba2 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Tue, 22 Oct 2024 01:31:35 +0900 Subject: [PATCH 17/44] Update --- experimental/kernels/Makefile | 8 +- experimental/kernels/gpt2_webgpu_aot.cpp | 586 +++++++++++++++-------- experimental/kernels/ops_aot.hpp | 8 - 3 files changed, 393 insertions(+), 209 deletions(-) diff --git a/experimental/kernels/Makefile b/experimental/kernels/Makefile index 5817e23..7430a71 100644 --- a/experimental/kernels/Makefile +++ b/experimental/kernels/Makefile @@ -99,6 +99,10 @@ build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c gpt2_webgpu.cpp ops.cpp mkdir -p build $(CC) $(CXXFLAGS) -Illm.c $(LDFLAGS) -o $@ gpt2_webgpu.cpp ops.cpp +build/gpt2_webgpu_aot: llm.c gpt2_124M.bin llm.c gpt2_webgpu_aot.cpp ops_aot.cpp + mkdir -p build + $(CC) $(CXXFLAGS) -Illm.c $(LDFLAGS) -o $@ gpt2_webgpu_aot.cpp ops_aot.cpp + build/gpt2_webgpu.html: check-emsdk gpt2_webgpu.cpp term.html llm.c em++ gpt2_webgpu.cpp ops.cpp \ --preload-file gpt2_tokenizer.bin@/gpt2_tokenizer.bin \ @@ -116,8 +120,8 @@ watch-web: watch-native: ls *.cpp *.c *.hpp *.h | entr -s "rm -f build/gpt2_webgpu && rm -f build/ops.o && make build/gpt2_webgpu" -run-native: build/gpt2_webgpu - . $(GPUCPP)/source && ./build/gpt2_webgpu +run-native: build/gpt2_webgpu_aot + . $(GPUCPP)/source && ./build/gpt2_webgpu_aot # server: build/train_gpt2.html build/test_gpt2.html build/gpt2_gpucpp.html server: build/gpt2_webgpu.html diff --git a/experimental/kernels/gpt2_webgpu_aot.cpp b/experimental/kernels/gpt2_webgpu_aot.cpp index 0c136f7..e0a1d54 100644 --- a/experimental/kernels/gpt2_webgpu_aot.cpp +++ b/experimental/kernels/gpt2_webgpu_aot.cpp @@ -1,5 +1,5 @@ #include "gpu.hpp" -#include "ops.hpp" +#include "ops_aot.hpp" /* This file trains the GPT-2 model. This version is the clean, minimal, reference. 
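
A note on the structure this patch moves to: with ops_aot.hpp, every op
builds its Kernel once at setup time, and the training loop only dispatches.
A minimal sketch of that ahead-of-time pattern, assuming the gpu.hpp API
exactly as it is used in this file (createKernel inside the op, then
dispatchKernel/wait with a std::promise<void> completion signal; nSteps is
illustrative):

    Kernel k = gelu_forward(ctx, out, inp, n);    // pipeline built once
    for (int step = 0; step < nSteps; ++step) {
      std::promise<void> done;
      std::future<void> f = done.get_future();
      dispatchKernel(ctx, k, done);               // re-dispatch; no recompilation
      wait(ctx, f);
    }
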
As such: @@ -91,25 +91,25 @@ void fill_in_parameter_sizes(size_t* param_sizes, GPT2Config config) { } // allocate memory for the parameters and point the individual tensors to the right places -float* malloc_and_point_parameters(ParameterTensors* params, size_t* param_sizes) { - size_t num_parameters = 0; - for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) { - num_parameters += param_sizes[i]; - } - // malloc all parameters all at once - float* params_memory = (float*)mallocCheck(num_parameters * sizeof(float)); - // assign all the tensors - float** ptrs[] = { - ¶ms->wte, ¶ms->wpe, ¶ms->ln1w, ¶ms->ln1b, ¶ms->qkvw, ¶ms->qkvb, - ¶ms->attprojw, ¶ms->attprojb, ¶ms->ln2w, ¶ms->ln2b, ¶ms->fcw, ¶ms->fcb, - ¶ms->fcprojw, ¶ms->fcprojb, ¶ms->lnfw, ¶ms->lnfb - }; - float* params_memory_iterator = params_memory; - for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) { - *(ptrs[i]) = params_memory_iterator; - params_memory_iterator += param_sizes[i]; +void malloc_and_point_parameters(Context& ctx, ParameterTensors* params, size_t* param_sizes) { + params->wte = createTensor(ctx, Shape{param_sizes[0]}, kf32); + params->wpe = createTensor(ctx, Shape{param_sizes[1]}, kf32); + for(int l = 0; l < NUM_PARAMETER_LAYERS; l++) { + params->ln1w.push_back(createTensor(ctx, Shape{param_sizes[2]/NUM_PARAMETER_LAYERS}, kf32)); + params->ln1b.push_back(createTensor(ctx, Shape{param_sizes[3]/NUM_PARAMETER_LAYERS}, kf32)); + params->qkvw.push_back(createTensor(ctx, Shape{param_sizes[4]/NUM_PARAMETER_LAYERS}, kf32)); + params->qkvb.push_back(createTensor(ctx, Shape{param_sizes[5]/NUM_PARAMETER_LAYERS}, kf32)); + params->attprojw.push_back(createTensor(ctx, Shape{param_sizes[6]/NUM_PARAMETER_LAYERS}, kf32)); + params->attprojb.push_back(createTensor(ctx, Shape{param_sizes[7]/NUM_PARAMETER_LAYERS}, kf32)); + params->ln2w.push_back(createTensor(ctx, Shape{param_sizes[8]/NUM_PARAMETER_LAYERS}, kf32)); + params->ln2b.push_back(createTensor(ctx, Shape{param_sizes[9]/NUM_PARAMETER_LAYERS}, kf32)); + params->fcw.push_back(createTensor(ctx, Shape{param_sizes[10]/NUM_PARAMETER_LAYERS}, kf32)); + params->fcb.push_back(createTensor(ctx, Shape{param_sizes[11]/NUM_PARAMETER_LAYERS}, kf32)); + params->fcprojw.push_back(createTensor(ctx, Shape{param_sizes[12]/NUM_PARAMETER_LAYERS}, kf32)); + params->fcprojb.push_back(createTensor(ctx, Shape{param_sizes[13]/NUM_PARAMETER_LAYERS}, kf32)); } - return params_memory; + params->lnfw = createTensor(ctx, Shape{param_sizes[14]}, kf32); + params->lnfb = createTensor(ctx, Shape{param_sizes[15]}, kf32); } @@ -154,7 +154,7 @@ typedef struct { Kernel layernorm_final_forward; Kernel matmul_final_forward; Kernel softmax_final_forward; - std::vector crossentropy_forward; + Kernel crossentropy_forward; Kernel crossentropy_softmax_backward; Kernel matmul_final_backward; @@ -201,24 +201,32 @@ void fill_in_activation_sizes(size_t* act_sizes, GPT2Config config, int B, int T act_sizes[22] = B * T; // losses } -float* malloc_and_point_activations(ActivationTensors* acts, size_t* act_sizes) { - size_t num_activations = 0; - for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { - num_activations += act_sizes[i]; +void malloc_and_point_activations(Context& ctx, ActivationTensors* acts, size_t* act_sizes) { + acts->encoded = createTensor(ctx, Shape{act_sizes[0]}, kf32); + for (int l = 0; l < NUM_PARAMETER_LAYERS; l++) { + acts->ln1.push_back(createTensor(ctx, Shape{act_sizes[1]/NUM_PARAMETER_LAYERS}, kf32)); + acts->ln1_mean.push_back(createTensor(ctx, Shape{act_sizes[2]/NUM_PARAMETER_LAYERS}, kf32)); + 
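
// Note: the push_back pattern in malloc_and_point_activations repeats for
// every per-layer tensor; a hypothetical helper that captures it
// (pushLayerTensors is a sketch, not part of the patch; createTensor, Shape,
// and kf32 are the gpu.hpp calls used above):
//
//   static void pushLayerTensors(Context& ctx, std::vector<Tensor>& dst,
//                                size_t totalSize, int numLayers) {
//     for (int l = 0; l < numLayers; ++l) {
//       dst.push_back(createTensor(ctx, Shape{totalSize / numLayers}, kf32));
//     }
//   }
//
//   // e.g. pushLayerTensors(ctx, acts->ln1_rstd, act_sizes[3], NUM_PARAMETER_LAYERS);
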
acts->ln1_rstd.push_back(createTensor(ctx, Shape{act_sizes[3]/NUM_PARAMETER_LAYERS}, kf32)); + acts->qkv.push_back(createTensor(ctx, Shape{act_sizes[4]/NUM_PARAMETER_LAYERS}, kf32)); + acts->atty.push_back(createTensor(ctx, Shape{act_sizes[5]/NUM_PARAMETER_LAYERS}, kf32)); + acts->preatt.push_back(createTensor(ctx, Shape{act_sizes[6]/NUM_PARAMETER_LAYERS}, kf32)); + acts->att.push_back(createTensor(ctx, Shape{act_sizes[7]/NUM_PARAMETER_LAYERS}, kf32)); + acts->attproj.push_back(createTensor(ctx, Shape{act_sizes[8]/NUM_PARAMETER_LAYERS}, kf32)); + acts->residual2.push_back(createTensor(ctx, Shape{act_sizes[9]/NUM_PARAMETER_LAYERS}, kf32)); + acts->ln2.push_back(createTensor(ctx, Shape{act_sizes[10]/NUM_PARAMETER_LAYERS}, kf32)); + acts->ln2_mean.push_back(createTensor(ctx, Shape{act_sizes[11]/NUM_PARAMETER_LAYERS}, kf32)); + acts->ln2_rstd.push_back(createTensor(ctx, Shape{act_sizes[12]/NUM_PARAMETER_LAYERS}, kf32)); + acts->fch.push_back(createTensor(ctx, Shape{act_sizes[13]/NUM_PARAMETER_LAYERS}, kf32)); + acts->fch_gelu.push_back(createTensor(ctx, Shape{act_sizes[14]/NUM_PARAMETER_LAYERS}, kf32)); + acts->fcproj.push_back(createTensor(ctx, Shape{act_sizes[15]/NUM_PARAMETER_LAYERS}, kf32)); + acts->residual3.push_back(createTensor(ctx, Shape{act_sizes[16]/NUM_PARAMETER_LAYERS}, kf32)); } - float* acts_memory = (float*)mallocCheck(num_activations * sizeof(float)); - float** ptrs[] = { - &acts->encoded, &acts->ln1, &acts->ln1_mean, &acts->ln1_rstd, &acts->qkv, &acts->atty, - &acts->preatt, &acts->att, &acts->attproj, &acts->residual2, &acts->ln2, &acts->ln2_mean, - &acts->ln2_rstd, &acts->fch, &acts->fch_gelu, &acts->fcproj, &acts->residual3, &acts->lnf, - &acts->lnf_mean, &acts->lnf_rstd, &acts->logits, &acts->probs, &acts->losses - }; - float* acts_memory_iterator = acts_memory; - for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { - *(ptrs[i]) = acts_memory_iterator; - acts_memory_iterator += act_sizes[i]; - } - return acts_memory; + acts->lnf = createTensor(ctx, Shape{act_sizes[17]}, kf32); + acts->lnf_mean = createTensor(ctx, Shape{act_sizes[18]}, kf32); + acts->lnf_rstd = createTensor(ctx, Shape{act_sizes[19]}, kf32); + acts->logits = createTensor(ctx, Shape{act_sizes[20]}, kf32); + acts->probs = createTensor(ctx, Shape{act_sizes[21]}, kf32); + acts->losses = createTensor(ctx, Shape{act_sizes[22]}, kf32); } struct GPUParameters { @@ -240,7 +248,6 @@ typedef struct { GPT2Config config; // the weights (parameters) of the model, and their sizes ParameterTensors params; - GPUParameters params_; // TODO(avh): eventually this replaces params size_t param_sizes[NUM_PARAMETER_TENSORS]; float* params_memory; size_t num_parameters; @@ -252,7 +259,6 @@ typedef struct { float* v_memory; // the activations of the model, and their sizes ActivationTensors acts; - GPUActivations acts_; // TODO(avh): eventually this replaces params size_t act_sizes[NUM_ACTIVATION_TENSORS]; float* acts_memory; size_t num_activations; @@ -262,13 +268,16 @@ typedef struct { // other run state configuration int batch_size; // the batch size (B) of current forward pass int seq_len; // the sequence length (T) of current forward pass - int* inputs; // the input tokens for the current forward pass - int* targets; // the target tokens for the current forward pass + Tensor inputs; // the input tokens for the current forward pass + Tensor targets; // the target tokens for the current forward pass float mean_loss; // after a forward pass with targets, will be populated with the mean loss + + // kernels + Kernels kernels; } 
GPT2; void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoint_path) { - + printf("Building GPT-2 model from checkpoint '%s'\n", checkpoint_path); // read in model from a checkpoint file FILE *model_file = fopenCheck(checkpoint_path, "rb"); int model_header[256]; @@ -302,7 +311,6 @@ void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoin // allocate space for all the parameters and read them in fill_in_parameter_sizes(model->param_sizes, model->config); - // count the number of parameters size_t num_parameters = 0; for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) { @@ -312,29 +320,65 @@ void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoin model->num_parameters = num_parameters; // read in all the parameters from file - model->params_memory = malloc_and_point_parameters(&model->params, model->param_sizes); + malloc_and_point_parameters(ctx, &model->params, model->param_sizes); + model->params_memory = (float*)mallocCheck(num_parameters * sizeof(float)); freadCheck(model->params_memory, sizeof(float), num_parameters, model_file); fcloseCheck(model_file); + // transfer to GPU memory + float* iter = model->params_memory; + toGPU(ctx, iter, model->params.wte); + iter += model->param_sizes[0]; + toGPU(ctx, iter, model->params.wpe); + iter += model->param_sizes[1]; + for (int l = 0; l < L; l++) { + toGPU(ctx, iter, model->params.ln1w[l]); + iter += model->param_sizes[2]/L; + toGPU(ctx, iter, model->params.ln1b[l]); + iter += model->param_sizes[3]/L; + toGPU(ctx, iter, model->params.qkvw[l]); + iter += model->param_sizes[4]/L; + toGPU(ctx, iter, model->params.qkvb[l]); + iter += model->param_sizes[5]/L; + toGPU(ctx, iter, model->params.attprojw[l]); + iter += model->param_sizes[6]/L; + toGPU(ctx, iter, model->params.attprojb[l]); + iter += model->param_sizes[7]/L; + toGPU(ctx, iter, model->params.ln2w[l]); + iter += model->param_sizes[8]/L; + toGPU(ctx, iter, model->params.ln2b[l]); + iter += model->param_sizes[9]/L; + toGPU(ctx, iter, model->params.fcw[l]); + iter += model->param_sizes[10]/L; + toGPU(ctx, iter, model->params.fcb[l]); + iter += model->param_sizes[11]/L; + toGPU(ctx, iter, model->params.fcprojw[l]); + iter += model->param_sizes[12]/L; + toGPU(ctx, iter, model->params.fcprojb[l]); + iter += model->param_sizes[13]/L; + } + toGPU(ctx, iter, model->params.lnfw); + iter += model->param_sizes[14]; + toGPU(ctx, iter, model->params.lnfb); + iter += model->param_sizes[15]; + + // other inits model->acts_memory = NULL; model->grads_memory = NULL; model->m_memory = NULL; model->v_memory = NULL; model->grads_acts_memory = NULL; - model->inputs = NULL; - model->targets = NULL; model->batch_size = 0; model->seq_len = 0; model->mean_loss = -1.0f; // -1.0f will designate no loss - // TODO(avh): this is just a resource test for now, eventually deprecate CPU allocations - gpu_alloc(ctx, model->params_.data, model->param_sizes, NUM_PARAMETER_TENSORS); + printf("Model build complete\n"); } -void gpt2_forward(Context& ctx, GPT2 *model, int* inputs, int* targets, size_t B, size_t T) { +void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, size_t B, size_t T) { // targets are optional and could be NULL // ensure the model was initialized or error out @@ -350,13 +394,13 @@ void gpt2_forward(Context& ctx, GPT2 *model, int* inputs, int* targets, size_t B size_t NH = model->config.num_heads; size_t C = model->config.channels; - // validate inputs, all indices must be in the range [0, V) - for(int i = 0; i 
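
// The parameter-transfer loop in gpt2_build_from_checkpoint above walks a
// cursor through the flat checkpoint buffer, uploading one tensor at a time.
// A hypothetical helper for the recurring two-line pair (uploadChunk is a
// sketch, not part of the patch; toGPU is used with this exact signature in
// this file):
//
//   static float* uploadChunk(Context& ctx, float* iter, Tensor& dst, size_t n) {
//     toGPU(ctx, iter, dst);  // copy this tensor's n floats to its GPU buffer
//     return iter + n;        // advance the CPU-side cursor
//   }
//
//   // e.g. iter = uploadChunk(ctx, iter, model->params.wte, model->param_sizes[0]);
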
< B * T; i++) { - assert(0 <= inputs[i] && inputs[i] < V); - if (targets != NULL) { - assert(0 <= targets[i] && targets[i] < V); - } - } + // // validate inputs, all indices must be in the range [0, V) + // for(int i = 0; i < B * T; i++) { + // assert(0 <= inputs[i] && inputs[i] < V); + // if (targets != NULL) { + // assert(0 <= targets[i] && targets[i] < V); + // } + // } // allocate space for all the activations if needed (done here, lazily) if(model->acts_memory == NULL) { @@ -365,8 +409,8 @@ void gpt2_forward(Context& ctx, GPT2 *model, int* inputs, int* targets, size_t B model->seq_len = T; // and now allocate the space fill_in_activation_sizes(model->act_sizes, model->config, B, T); + // TODO(avh): this is just a resource test for now, eventually deprecate CPU allocations - gpu_alloc(ctx, model->acts_.data, model->act_sizes, NUM_PARAMETER_TENSORS); size_t num_activations = 0; for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { num_activations += model->act_sizes[i]; @@ -374,10 +418,12 @@ void gpt2_forward(Context& ctx, GPT2 *model, int* inputs, int* targets, size_t B printf("num_activations: %zu\n", num_activations); model->num_activations = num_activations; printf("Allocating %.2f MB for activations\n", num_activations * sizeof(float) / (1024.0f * 1024.0f)); - model->acts_memory = malloc_and_point_activations(&model->acts, model->act_sizes); + malloc_and_point_activations(ctx, &model->acts, model->act_sizes); // also create memory for caching inputs and targets - model->inputs = (int*)mallocCheck(B * T * sizeof(int)); - model->targets = (int*)mallocCheck(B * T * sizeof(int)); // might be unused if we never have targets but it's small + //model->inputs = (int*)mallocCheck(B * T * sizeof(int)); + //model->targets = (int*)mallocCheck(B * T * sizeof(int)); // might be unused if we never have targets but it's small + model->inputs = createTensor(ctx, Shape{B * T}, ki32); + model->targets = createTensor(ctx, Shape{B * T}, ki32); } else { // validate B,T is consistent with how we've allocated the memory before // in principle we could get more clever here in the future, for now this is safest @@ -386,99 +432,202 @@ void gpt2_forward(Context& ctx, GPT2 *model, int* inputs, int* targets, size_t B exit(EXIT_FAILURE); } } - - printf("Cache inputs/targets\n"); - // cache the inputs/targets - memcpy(model->inputs, inputs, B * T * sizeof(int)); - if (targets != NULL) { - memcpy(model->targets, targets, B * T * sizeof(int)); + // create all kernels ahead of time + if (model->kernels.encoder_forward == nullptr) { + printf("Creating Kernels\n"); + Kernels& kernels = model->kernels; + kernels.layernorm_forward.resize(L); + kernels.layernorm1_backward.resize(L); + kernels.qkv_projection_forward.resize(L); + kernels.qkv_projection_backward.resize(L); + kernels.attention_forward.resize(L); + kernels.attention_backward.resize(L); + kernels.attention_projection_forward.resize(L); + kernels.attention_projection_backward.resize(L); + kernels.residual_forward.resize(L); + kernels.residual2_forward.resize(L); + kernels.residual2_backward.resize(L); + kernels.ff_up_forward.resize(L); + kernels.ff_up_backward.resize(L); + kernels.gelu_forward.resize(L); + kernels.gelu_backward.resize(L); + kernels.ff_down_forward.resize(L); + kernels.ff_down_backward.resize(L); + for (int l = 0; l < L; ++l) { + kernels.layernorm_forward[l] = layernorm_forward(ctx, model->acts.ln1[l], model->acts.ln1_mean[l], model->acts.ln1_rstd[l], + /*input=*/ model->acts.residual3[l], /*weight=*/ model->params.ln1w[l], /*bias=*/ 
model->params.ln1b[l], + B, T, C); + kernels.qkv_projection_forward[l] = matmul_forward(ctx, model->acts.qkv[l], model->acts.ln1[l], model->params.qkvw[l], model->params.qkvb[l], B, T, C, 3*C); + kernels.attention_forward[l] = attention_forward(ctx, model->acts.atty[l], model->acts.preatt[l], model->acts.att[l], model->acts.qkv[l], B, T, C, NH); + kernels.attention_projection_forward[l] = matmul_forward(ctx, model->acts.attproj[l], model->acts.atty[l], model->params.attprojw[l], model->params.attprojb[l], B, T, C, C); + kernels.residual_forward[l] = residual_forward(ctx, model->acts.residual2[l], model->acts.residual3[l], model->acts.attproj[l], B*T*C); + kernels.ff_up_forward[l] = matmul_forward(ctx, model->acts.fch[l], model->acts.ln2[l], model->params.fcw[l], model->params.fcb[l], B, T, C, 4*C); + kernels.gelu_forward[l] = gelu_forward(ctx, model->acts.fch_gelu[l], model->acts.fch[l], B*T*4*C); + kernels.ff_down_forward[l] = matmul_forward(ctx, model->acts.fcproj[l], model->acts.fch_gelu[l], model->params.fcw[l], model->params.fcb[l], B, T, 4*C, C); + kernels.residual2_forward[l] = residual_forward(ctx, model->acts.residual3[l], model->acts.residual2[l], model->acts.fcproj[l], B*T*C); + } + kernels.crossentropy_forward = crossentropy_forward(ctx, model->acts.losses, model->acts.probs, targets, B, T, Vp); + + kernels.encoder_forward = encoder_forward(ctx, model->acts.encoded, inputs, model->params.wte, model->params.wpe, B, T, C); // encoding goes into residual[0] + kernels.encoder_backward = encoder_backward(ctx, model->params.wte, model->params.wpe, model->acts.encoded, inputs, B, T, C); + kernels.layernorm_final_forward = layernorm_forward(ctx, model->acts.lnf, model->acts.lnf_mean, model->acts.lnf_rstd, + /*input=*/ model->acts.residual3[L-1], /*weight=*/ model->params.lnfw, /*bias=*/ model->params.lnfb, + B, T, C); + Tensor nullTensor = createTensor(ctx, Shape{1}, kf32); + kernels.matmul_final_forward = matmul_forward(ctx, model->acts.logits, model->acts.lnf, model->params.wte, nullTensor, B, T, C, Vp); + kernels.softmax_final_forward = softmax_forward(ctx, model->acts.probs, model->acts.logits, B, T, V, Vp); + kernels.crossentropy_softmax_backward = crossentropy_softmax_backward(ctx, model->acts.logits, model->acts.losses, model->acts.probs, targets, B, T, V, Vp); + kernels.matmul_final_backward = matmul_backward(ctx, model->acts.lnf, model->params.wte, nullTensor, model->acts.logits, + model->acts.lnf, model->params.wte, B, T, C, Vp); + kernels.layernorm_final_backward = layernorm_backward(ctx, model->acts.residual3[L-1], model->params.lnfw, model->params.lnfb, + model->acts.lnf, model->acts.residual3[L-1], model->params.lnfw, + model->acts.lnf_mean, model->acts.lnf_rstd, B, T, C); + printf("Created Kernels\n"); } + printf("Cache inputs/targets\n"); printf("Forward pass\n"); // forward pass ParameterTensors params = model->params; // for brevity ActivationTensors acts = model->acts; float* residual; printf("Encoding\n"); - printf("inputs[0] = %d\n", inputs[0]); - encoder_forward(ctx, acts.encoded, inputs, params.wte, params.wpe, B, T, C); // encoding goes into residual[0] + //printf("inputs[0] = %d\n", inputs[0]); + // encoder_forward(ctx, acts.encoded, inputs, params.wte, params.wpe, B, T, C); // encoding goes into residual[0] + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.encoder_forward, promise); + wait(ctx, future); + } for (int l = 0; l < L; l++) { printf("Forward Pass Layer %d\n", l); - residual = l == 0 ? 
acts.encoded : acts.residual3 + (l-1) * B * T * C; - - // get the pointers of the weights for this layer - float* l_ln1w = params.ln1w + l * C; - float* l_ln1b = params.ln1b + l * C; - float* l_qkvw = params.qkvw + l * 3*C * C; - float* l_qkvb = params.qkvb + l * 3*C; - float* l_attprojw = params.attprojw + l * C * C; - float* l_attprojb = params.attprojb + l * C; - float* l_ln2w = params.ln2w + l * C; - float* l_ln2b = params.ln2b + l * C; - float* l_fcw = params.fcw + l * 4*C * C; - float* l_fcb = params.fcb + l * 4*C; - float* l_fcprojw = params.fcprojw + l * C * 4*C; - float* l_fcprojb = params.fcprojb + l * C; - - // get the pointers of the activations for this layer - float* l_ln1 = acts.ln1 + l * B * T * C; - float* l_ln1_mean = acts.ln1_mean + l * B * T; - float* l_ln1_rstd = acts.ln1_rstd + l * B * T; - float* l_qkv = acts.qkv + l * B * T * 3*C; - float* l_atty = acts.atty + l * B * T * C; - float* l_preatt = acts.preatt + l * B * NH * T * T; - float* l_att = acts.att + l * B * NH * T * T; - float* l_attproj = acts.attproj + l * B * T * C; - float* l_residual2 = acts.residual2 + l * B * T * C; - float* l_ln2 = acts.ln2 + l * B * T * C; - float* l_ln2_mean = acts.ln2_mean + l * B * T; - float* l_ln2_rstd = acts.ln2_rstd + l * B * T; - float* l_fch = acts.fch + l * B * T * 4*C; - float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C; - float* l_fcproj = acts.fcproj + l * B * T * C; - float* l_residual3 = acts.residual3 + l * B * T * C; - // now do the forward pass printf(" [Forward] : LayerNorm1\n"); - layernorm_forward(ctx, l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C); + // layernorm_forward(ctx, l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.layernorm_forward[l], promise); + wait(ctx, future); + } printf(" [Forward] : QKV Projection\n"); - matmul_forward(ctx, l_qkv, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C); + // matmul_forward(ctx, l_qkv, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.qkv_projection_forward[l], promise); + wait(ctx, future); + } printf(" [Forward] : Attention\n"); - attention_forward(ctx, l_atty, l_preatt, l_att, l_qkv, B, T, C, NH); + // attention_forward(ctx, l_atty, l_preatt, l_att, l_qkv, B, T, C, NH); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.attention_forward[l], promise); + wait(ctx, future); + } printf(" [Forward] : Attention Projection\n"); - matmul_forward(ctx, l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C); + // matmul_forward(ctx, l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.attention_projection_forward[l], promise); + wait(ctx, future); + } printf(" [Forward] : Residual1\n"); - residual_forward(ctx, l_residual2, residual, l_attproj, B*T*C); + // residual_forward(ctx, l_residual2, residual, l_attproj, B*T*C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.residual_forward[l], promise); + wait(ctx, future); + } printf(" [Forward] : LayerNorm2\n"); - layernorm_forward(ctx, l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C); + // layernorm_forward(ctx, l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C); + { + std::promise 
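
// Every dispatch below repeats the same promise/future handshake. A minimal
// wrapper (dispatchSync is a sketch, not in the patch; dispatchKernel and
// wait are assumed to take arguments exactly as used here, with a
// std::promise<void> completion signal):
//
//   static void dispatchSync(Context& ctx, Kernel& k) {
//     std::promise<void> promise;
//     std::future<void> future = promise.get_future();
//     dispatchKernel(ctx, k, promise);
//     wait(ctx, future);   // block until the GPU work completes
//   }
//
//   // e.g. dispatchSync(ctx, model->kernels.qkv_projection_forward[l]);
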
promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.layernorm2_backward[l], promise); + wait(ctx, future); + } printf(" [Forward] : FF Up\n"); - matmul_forward(ctx, l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C); + // matmul_forward(ctx, l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.ff_up_forward[l], promise); + wait(ctx, future); + } printf(" [Forward] : GELU\n"); - gelu_forward(ctx, l_fch_gelu, l_fch, B*T*4*C); + // gelu_forward(ctx, l_fch_gelu, l_fch, B*T*4*C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.gelu_forward[l], promise); + wait(ctx, future); + } printf(" [Forward] : FF Down\n"); - matmul_forward(ctx, l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C); + // matmul_forward(ctx, l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.ff_down_forward[l], promise); + wait(ctx, future); + } printf(" [Forward] : Residual2\n"); - residual_forward(ctx, l_residual3, l_residual2, l_fcproj, B*T*C); + // residual_forward(ctx, l_residual3, l_residual2, l_fcproj, B*T*C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.residual2_forward[l], promise); + wait(ctx, future); + } + } + // residual = acts.residual3.data() + (L-1) * B * T * C; // last residual is in residual3 + // layernorm_forward(ctx, acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.layernorm_final_forward, promise); + wait(ctx, future); + } + // matmul_forward(ctx, acts.logits, acts.lnf, params.wte, NULL, B, T, C, Vp); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.matmul_final_forward, promise); + wait(ctx, future); + } + // softmax_forward(ctx, acts.probs, acts.logits, B, T, V, Vp); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.softmax_final_forward, promise); + wait(ctx, future); } - residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 - layernorm_forward(ctx, acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C); - matmul_forward(ctx, acts.logits, acts.lnf, params.wte, NULL, B, T, C, Vp); - softmax_forward(ctx, acts.probs, acts.logits, B, T, V, Vp); printf("Crossentropy\n"); // also forward the cross-entropy loss function if we have the targets - if (targets != NULL) { - crossentropy_forward(ctx, model->acts.losses, model->acts.probs, targets, B, T, Vp); + // if (targets != NULL) { + // crossentropy_forward(ctx, model->acts.losses, model->acts.probs, targets, B, T, Vp); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.crossentropy_forward, promise); + wait(ctx, future); + } // for convenience also evaluate the mean loss float mean_loss = 0.0f; - for (int i=0; iacts.losses[i]; } + //toCPU(ctx, model->acts_.data[22], model->acts.losses.data, model->act_sizes[22] * sizeof(float)); + for (int i=0; iacts.losses.data[i]; } mean_loss /= B*T; model->mean_loss = mean_loss; - } else { - // if we don't have targets, we don't have a loss - model->mean_loss = -1.0f; - } + 
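
// The loss readback above averages the (B*T) per-token losses on the host.
// A standalone restatement (meanLoss is a sketch; toCPU with this signature
// appears later in the series, copying bytes from a Tensor into a host
// buffer):
//
//   float meanLoss(Context& ctx, Tensor& losses, float* buf, size_t n) {
//     toCPU(ctx, losses, buf, n * sizeof(float));  // GPU -> host copy
//     float sum = 0.0f;
//     for (size_t i = 0; i < n; ++i) { sum += buf[i]; }
//     return sum / static_cast<float>(n);
//   }
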
// } else { + // // if we don't have targets, we don't have a loss + // model->mean_loss = -1.0f; + // } printf("Forward pass done\n"); } @@ -499,8 +648,8 @@ void gpt2_backward(Context& ctx, GPT2 *model) { // lazily allocate the memory for gradients of the weights and activations, if needed if (model->grads_memory == NULL) { printf("Allocating %.2f MB for gradients\n", model->num_parameters * sizeof(float) / (1024.0f * 1024.0f)); - model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_sizes); - model->grads_acts_memory = malloc_and_point_activations(&model->grads_acts, model->act_sizes); + malloc_and_point_parameters(&model->grads, model->param_sizes); + malloc_and_point_activations(&model->grads_acts, model->act_sizes); gpt2_zero_grad(model); } @@ -523,90 +672,124 @@ void gpt2_backward(Context& ctx, GPT2 *model) { // technically this is a small, inline backward() pass of calculating // total, final loss as the mean over all losses over all (B,T) positions in the batch float dloss_mean = 1.0f / (B*T); - for (int i = 0; i < B*T; i++) { grads_acts.losses[i] = dloss_mean; } - - crossentropy_softmax_backward(ctx, grads_acts.logits, grads_acts.losses, acts.probs, model->targets, B, T, V, Vp); - matmul_backward(ctx, grads_acts.lnf, grads.wte, NULL, grads_acts.logits, acts.lnf, params.wte, B, T, C, Vp); - float* residual = acts.residual3 + (L-1) * B * T * C; // last layer's residual - float* dresidual = grads_acts.residual3 + (L-1) * B * T * C; // write to last layer's residual - layernorm_backward(ctx, dresidual, grads.lnfw, grads.lnfb, grads_acts.lnf, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C); + for (int i = 0; i < B*T; i++) { grads_acts.losses.data[i] = dloss_mean; } + toGPU(ctx, grads_acts.losses.data, model->acts_.data[22]); + + // crossentropy_softmax_backward(ctx, grads_acts.logits, grads_acts.losses, acts.probs, model->targets, B, T, V, Vp); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.crossentropy_softmax_backward, promise); + wait(ctx, future); + } + // matmul_backward(ctx, grads_acts.lnf, grads.wte, NULL, grads_acts.logits, acts.lnf, params.wte, B, T, C, Vp); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.matmul_final_backward, promise); + wait(ctx, future); + } + // layernorm_backward(ctx, dresidual, grads.lnfw, grads.lnfb, grads_acts.lnf, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.layernorm_final_backward, promise); + wait(ctx, future); + } for (int l = L-1; l >= 0; l--) { printf("Backward Pass Layer %d\n", l); - - residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C; - dresidual = l == 0 ? 
grads_acts.encoded : grads_acts.residual3 + (l-1) * B * T * C; - - // get the pointers of the weights for this layer - float* l_ln1w = params.ln1w + l * C; - float* l_qkvw = params.qkvw + l * 3*C * C; - float* l_attprojw = params.attprojw + l * C * C; - float* l_ln2w = params.ln2w + l * C; - float* l_fcw = params.fcw + l * 4*C * C; - float* l_fcprojw = params.fcprojw + l * C * 4*C; - // get the pointers of the gradients of the weights for this layer - float* dl_ln1w = grads.ln1w + l * C; - float* dl_ln1b = grads.ln1b + l * C; - float* dl_qkvw = grads.qkvw + l * 3*C * C; - float* dl_qkvb = grads.qkvb + l * 3*C; - float* dl_attprojw = grads.attprojw + l * C * C; - float* dl_attprojb = grads.attprojb + l * C; - float* dl_ln2w = grads.ln2w + l * C; - float* dl_ln2b = grads.ln2b + l * C; - float* dl_fcw = grads.fcw + l * 4*C * C; - float* dl_fcb = grads.fcb + l * 4*C; - float* dl_fcprojw = grads.fcprojw + l * C * 4*C; - float* dl_fcprojb = grads.fcprojb + l * C; - // get the pointers of the activations for this layer - float* l_ln1 = acts.ln1 + l * B * T * C; - float* l_ln1_mean = acts.ln1_mean + l * B * T; - float* l_ln1_rstd = acts.ln1_rstd + l * B * T; - float* l_qkv = acts.qkv + l * B * T * 3*C; - float* l_atty = acts.atty + l * B * T * C; - float* l_att = acts.att + l * B * NH * T * T; - float* l_residual2 = acts.residual2 + l * B * T * C; - float* l_ln2 = acts.ln2 + l * B * T * C; - float* l_ln2_mean = acts.ln2_mean + l * B * T; - float* l_ln2_rstd = acts.ln2_rstd + l * B * T; - float* l_fch = acts.fch + l * B * T * 4*C; - float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C; - // get the pointers of the gradients of the activations for this layer - float* dl_ln1 = grads_acts.ln1 + l * B * T * C; - float* dl_qkv = grads_acts.qkv + l * B * T * 3*C; - float* dl_atty = grads_acts.atty + l * B * T * C; - float* dl_preatt = grads_acts.preatt + l * B * NH * T * T; - float* dl_att = grads_acts.att + l * B * NH * T * T; - float* dl_attproj = grads_acts.attproj + l * B * T * C; - float* dl_residual2 = grads_acts.residual2 + l * B * T * C; - float* dl_ln2 = grads_acts.ln2 + l * B * T * C; - float* dl_fch = grads_acts.fch + l * B * T * 4*C; - float* dl_fch_gelu = grads_acts.fch_gelu + l * B * T * 4*C; - float* dl_fcproj = grads_acts.fcproj + l * B * T * C; - float* dl_residual3 = grads_acts.residual3 + l * B * T * C; - // backprop this layer printf(" [Backward] : Residual2\n"); - residual_backward(ctx, dl_residual2, dl_fcproj, dl_residual3, B*T*C); + // residual_backward(ctx, dl_residual2, dl_fcproj, dl_residual3, B*T*C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.residual2_backward[l], promise); + wait(ctx, future); + } printf(" [Backward] : FF Down \n"); - matmul_backward(ctx, dl_fch_gelu, dl_fcprojw, dl_fcprojb, dl_fcproj, l_fch_gelu, l_fcprojw, B, T, 4*C, C); + // matmul_backward(ctx, dl_fch_gelu, dl_fcprojw, dl_fcprojb, dl_fcproj, l_fch_gelu, l_fcprojw, B, T, 4*C, C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.ff_down_backward[l], promise); + wait(ctx, future); + } printf(" [Backward] : GELU\n"); - gelu_backward(ctx, dl_fch, l_fch, dl_fch_gelu, B*T*4*C); + // gelu_backward(ctx, dl_fch, l_fch, dl_fch_gelu, B*T*4*C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.gelu_backward[l], promise); + wait(ctx, future); + } printf(" [Backward] : FF Up\n"); - matmul_backward(ctx, dl_ln2, dl_fcw, dl_fcb, dl_fch, 
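
// For reference, each matmul_backward below produces the three gradients of
// out = inp @ weight^T + bias, with inp (B*T, C), weight (OC, C), and
// out (B*T, OC) in llm.c's layout:
//
//   dinp    += dout @ weight        // shape (B*T, C)
//   dweight += dout^T @ inp         // shape (OC, C)
//   dbias   += column sums of dout  // shape (OC)
//
// The WGSL kernel is assumed to mirror this decomposition.
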
l_ln2, l_fcw, B, T, C, 4*C); + // matmul_backward(ctx, dl_ln2, dl_fcw, dl_fcb, dl_fch, l_ln2, l_fcw, B, T, C, 4*C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.ff_up_backward[l], promise); + wait(ctx, future); + } printf(" [Backward] : LayerNorm2\n"); - layernorm_backward(ctx, dl_residual2, dl_ln2w, dl_ln2b, dl_ln2, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C); + // layernorm_backward(ctx, dl_residual2, dl_ln2w, dl_ln2b, dl_ln2, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.layernorm2_backward[l], promise); + wait(ctx, future); + } printf(" [Backward] : Residual1\n"); - residual_backward(ctx, dresidual, dl_attproj, dl_residual2, B*T*C); + // residual_backward(ctx, dresidual, dl_attproj, dl_residual2, B*T*C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.residual_forward[l], promise); + wait(ctx, future); + } printf(" [Backward] : Attention Projection\n"); - matmul_backward(ctx, dl_atty, dl_attprojw, dl_attprojb, dl_attproj, l_atty, l_attprojw, B, T, C, C); + // matmul_backward(ctx, dl_atty, dl_attprojw, dl_attprojb, dl_attproj, l_atty, l_attprojw, B, T, C, C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.attention_projection_backward[l], promise); + wait(ctx, future); + } printf(" [Backward] : Attention\n"); - attention_backward(ctx, dl_qkv, dl_preatt, dl_att, dl_atty, l_qkv, l_att, B, T, C, NH); + // attention_backward(ctx, dl_qkv, dl_preatt, dl_att, dl_atty, l_qkv, l_att, B, T, C, NH); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.attention_backward[l], promise); + wait(ctx, future); + } printf(" [Backward] : QKV Projection\n"); - matmul_backward(ctx, dl_ln1, dl_qkvw, dl_qkvb, dl_qkv, l_ln1, l_qkvw, B, T, C, 3*C); + // matmul_backward(ctx, dl_ln1, dl_qkvw, dl_qkvb, dl_qkv, l_ln1, l_qkvw, B, T, C, 3*C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.qkv_projection_backward[l], promise); + wait(ctx, future); + } printf(" [Backward] : LayerNorm1\n"); - layernorm_backward(ctx, dresidual, dl_ln1w, dl_ln1b, dl_ln1, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C); + // layernorm_backward(ctx, dresidual, dl_ln1w, dl_ln1b, dl_ln1, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.layernorm1_backward[l], promise); + wait(ctx, future); + } + } + // encoder_backward(ctx, grads.wte, grads.wpe, grads_acts.encoded, model->inputs, B, T, C); + { + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, model->kernels.encoder_backward, promise); + wait(ctx, future); } - encoder_backward(ctx, grads.wte, grads.wpe, grads_acts.encoded, model->inputs, B, T, C); + toCPU(ctx, model->params_.data[0], model->grads.wte.data, model->param_sizes[0] * sizeof(float)); + toCPU(ctx, model->params_.data[1], model->grads.wpe.data, model->param_sizes[1] * sizeof(float)); } void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t) { @@ -635,6 +818,8 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo model->v_memory[i] = v; model->params_memory[i] -= 
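
// The parameter update below is AdamW (decoupled weight decay). Restated per
// parameter theta with gradient g at step t, matching the code's variables:
//
//   m     = beta1 * m + (1 - beta1) * g          // first moment
//   v     = beta2 * v + (1 - beta2) * g * g      // second moment
//   m_hat = m / (1 - beta1^t)                    // bias correction
//   v_hat = v / (1 - beta2^t)
//   theta -= learning_rate * (m_hat / (sqrt(v_hat) + eps) + weight_decay * theta)
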
learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param); } + toGPU(ctx, model->params_memory, model->params_.data[0]); + toGPU(ctx, model->params_memory + model->param_sizes[0], model->params_.data[1]); } void gpt2_free(GPT2 *model) { @@ -688,9 +873,11 @@ int main() { gpu::Context ctx = gpu::createContext({}, {}, { .requiredLimits = &requiredLimits }); - // gpu::Context ctx = gpu::createContext(); + +Continue! - // build the GPT-2 model from a checkpoint +```cpp + // build the GPT-2 model from a checkpoint GPT2 model; gpt2_build_from_checkpoint(ctx, &model, "gpt2_124M.bin"); @@ -719,12 +906,11 @@ int main() { int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int)); const int genT = 64; // number of steps of inference we will do - // train struct timespec start, end; printf("Starting training\n"); for (int step = 0; step <= 40; step++) { - printf("Step %d\n", step); + printf("Step %d\n", step); // once in a while estimate the validation loss if (step % 10 == 0) { @@ -757,7 +943,9 @@ int main() { // we're in principle running B "inference streams" in parallel here // but only using position 0 // get the Vp-dimensional vector probs[0, t-1, :] - float* probs = model.acts.probs + (t-1) * model.config.padded_vocab_size; + float* probs = model.acts.probs.data + (t-1) * model.config.padded_vocab_size; + toCPU(ctx, model.acts_.data[21], probs, (t-1) * model.config.padded_vocab_size * sizeof(float)); + float coin = random_f32(&rng_state); // note we're only sampling from the first V elements, ignoring padding // (the probabilities in the padded region should be zero anyway) diff --git a/experimental/kernels/ops_aot.hpp b/experimental/kernels/ops_aot.hpp index 8ec6d8e..5db9ff7 100644 --- a/experimental/kernels/ops_aot.hpp +++ b/experimental/kernels/ops_aot.hpp @@ -5,10 +5,6 @@ using namespace gpu; -#ifdef __cplusplus -extern "C" { -#endif - #define VOCAB_SIZE 50257 // See https://github.com/google/dawn/blob/a8fbe981a86cb59536e2de423d2013a82d9b54a0/src/dawn/native/Limits.cpp @@ -101,8 +97,4 @@ Kernel crossentropy_softmax_backward(Context& ctx, Tensor& dlogits, Tensor& dlosses, Tensor& probs, Tensor& targets, int B, int T, int V, int Vp); -#ifdef __cplusplus -} -#endif - #endif // OPS_H From 1d8e43577deda96675dee68960b3dc315670f4ca Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Tue, 22 Oct 2024 13:08:06 +0900 Subject: [PATCH 18/44] Update --- experimental/kernels/gpt2_webgpu_aot.cpp | 31 +++++++++++++++--------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/experimental/kernels/gpt2_webgpu_aot.cpp b/experimental/kernels/gpt2_webgpu_aot.cpp index e0a1d54..c95ad3d 100644 --- a/experimental/kernels/gpt2_webgpu_aot.cpp +++ b/experimental/kernels/gpt2_webgpu_aot.cpp @@ -271,6 +271,9 @@ typedef struct { Tensor inputs; // the input tokens for the current forward pass Tensor targets; // the target tokens for the current forward pass float mean_loss; // after a forward pass with targets, will be populated with the mean loss + float* mean_loss_buffer; + + Tensor nullTensor; // kernels Kernels kernels; @@ -372,6 +375,8 @@ void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoin model->batch_size = 0; model->seq_len = 0; model->mean_loss = -1.0f; // -1.0f will designate no loss + // Allocate B * C buffer for mean loss + model->mean_loss_buffer = (float*)mallocCheck(sizeof(float) * model->batch_size * model->seq_len); printf("Model build complete\n"); @@ -474,6 +479,7 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si 
/*input=*/ model->acts.residual3[L-1], /*weight=*/ model->params.lnfw, /*bias=*/ model->params.lnfb, B, T, C); Tensor nullTensor = createTensor(ctx, Shape{1}, kf32); + model->nullTensor = nullTensor; kernels.matmul_final_forward = matmul_forward(ctx, model->acts.logits, model->acts.lnf, model->params.wte, nullTensor, B, T, C, Vp); kernels.softmax_final_forward = softmax_forward(ctx, model->acts.probs, model->acts.logits, B, T, V, Vp); kernels.crossentropy_softmax_backward = crossentropy_softmax_backward(ctx, model->acts.logits, model->acts.losses, model->acts.probs, targets, B, T, V, Vp); @@ -829,8 +835,9 @@ void gpt2_free(GPT2 *model) { free(model->v_memory); free(model->acts_memory); free(model->grads_acts_memory); - free(model->inputs); - free(model->targets); + // free(model->inputs); + // free(model->targets); + free(model->mean_loss_buffer); } #ifndef TESTING @@ -874,9 +881,6 @@ int main() { .requiredLimits = &requiredLimits }); -Continue! - -```cpp // build the GPT-2 model from a checkpoint GPT2 model; gpt2_build_from_checkpoint(ctx, &model, "gpt2_124M.bin"); @@ -903,11 +907,14 @@ Continue! // some memory for generating samples from the model uint64_t rng_state = 1337; - int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int)); + // int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int)); const int genT = 64; // number of steps of inference we will do // train struct timespec start, end; + Tensor inputs = createTensor(ctx, Shape{B, T}, ki32); + Tensor targets = createTensor(ctx, Shape{B, T}, ki32); + Tensor gen_tokens = createTensor(ctx, Shape{B, T}, ki32); printf("Starting training\n"); for (int step = 0; step <= 40; step++) { printf("Step %d\n", step); @@ -918,7 +925,9 @@ Continue! dataloader_reset(&val_loader); for (int i = 0; i < val_num_batches; i++) { dataloader_next_batch(&val_loader); - gpt2_forward(ctx, &model, val_loader.inputs, val_loader.targets, B, T); + toGPU(ctx, val_loader.inputs, inputs); + toGPU(ctx, val_loader.targets, targets); + gpt2_forward(ctx, &model, inputs, targets, B, T); val_loss += model.mean_loss; } val_loss /= val_num_batches; @@ -928,9 +937,7 @@ Continue! // once in a while do model inference to print generated text if (step > 0 && step % 20 == 0) { // fill up gen_tokens with the GPT2_EOT, which kicks off the generation - for(int i = 0; i < B * T; ++i) { - gen_tokens[i] = tokenizer.eot_token; - } + toGPU(ctx, tokenizer.eot_token, gen_tokens); // now sample from the model autoregressively printf("generating:\n---\n"); for (int t = 1; t < genT; t++) { @@ -938,7 +945,7 @@ Continue! // we re-calculate the forward pass for all of (B,T) positions from scratch // but the inference here is just for sanity checking anyway // and we can maybe optimize a bit more later, with careful tests - gpt2_forward(ctx, &model, gen_tokens, NULL, B, T); + gpt2_forward(ctx, &model, gen_tokens, model.nullTensor, B, T); // furthermore, below we're only using b=0 (i.e. the first row) of all B rows // we're in principle running B "inference streams" in parallel here // but only using position 0 @@ -981,7 +988,7 @@ Continue! 
dataloader_free(&val_loader); tokenizer_free(&tokenizer); gpt2_free(&model); - free(gen_tokens); + // free(gen_tokens); return 0; } #endif From 49859306cf44d4db408f949b1ea61068d91557f7 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Wed, 23 Oct 2024 02:46:41 +0900 Subject: [PATCH 19/44] Update --- experimental/kernels/gpt2_webgpu_aot.cpp | 128 +++++++++++++++++++---- 1 file changed, 107 insertions(+), 21 deletions(-) diff --git a/experimental/kernels/gpt2_webgpu_aot.cpp b/experimental/kernels/gpt2_webgpu_aot.cpp index c95ad3d..cd474f9 100644 --- a/experimental/kernels/gpt2_webgpu_aot.cpp +++ b/experimental/kernels/gpt2_webgpu_aot.cpp @@ -272,6 +272,7 @@ typedef struct { Tensor targets; // the target tokens for the current forward pass float mean_loss; // after a forward pass with targets, will be populated with the mean loss float* mean_loss_buffer; + float* probs_buffer; Tensor nullTensor; @@ -377,6 +378,7 @@ void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoin model->mean_loss = -1.0f; // -1.0f will designate no loss // Allocate B * C buffer for mean loss model->mean_loss_buffer = (float*)mallocCheck(sizeof(float) * model->batch_size * model->seq_len); + model->probs_buffer = (float*)mallocCheck(sizeof(float) * model->batch_size * model->seq_len * Vp); printf("Model build complete\n"); @@ -616,7 +618,8 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si printf("Crossentropy\n"); // also forward the cross-entropy loss function if we have the targets - // if (targets != NULL) { + // When targets's shape is (1), it means we don't have targets + if (targets.shape[0] != 1) { // crossentropy_forward(ctx, model->acts.losses, model->acts.probs, targets, B, T, Vp); { std::promise promise; @@ -627,13 +630,14 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si // for convenience also evaluate the mean loss float mean_loss = 0.0f; //toCPU(ctx, model->acts_.data[22], model->acts.losses.data, model->act_sizes[22] * sizeof(float)); - for (int i=0; iacts.losses.data[i]; } + toCPU(ctx, model->acts.losses, model->mean_loss_buffer, B*T * sizeof(float)); + for (int i=0; imean_loss_buffer[i]; } mean_loss /= B*T; model->mean_loss = mean_loss; - // } else { - // // if we don't have targets, we don't have a loss - // model->mean_loss = -1.0f; - // } + } else { + // if we don't have targets, we don't have a loss + model->mean_loss = -1.0f; + } printf("Forward pass done\n"); } @@ -654,8 +658,8 @@ void gpt2_backward(Context& ctx, GPT2 *model) { // lazily allocate the memory for gradients of the weights and activations, if needed if (model->grads_memory == NULL) { printf("Allocating %.2f MB for gradients\n", model->num_parameters * sizeof(float) / (1024.0f * 1024.0f)); - malloc_and_point_parameters(&model->grads, model->param_sizes); - malloc_and_point_activations(&model->grads_acts, model->act_sizes); + malloc_and_point_parameters(ctx, &model->grads, model->param_sizes); + malloc_and_point_activations(ctx, &model->grads_acts, model->act_sizes); gpt2_zero_grad(model); } @@ -678,8 +682,9 @@ void gpt2_backward(Context& ctx, GPT2 *model) { // technically this is a small, inline backward() pass of calculating // total, final loss as the mean over all losses over all (B,T) positions in the batch float dloss_mean = 1.0f / (B*T); - for (int i = 0; i < B*T; i++) { grads_acts.losses.data[i] = dloss_mean; } - toGPU(ctx, grads_acts.losses.data, model->acts_.data[22]); + for (int i = 0; i < B*T; i++) { model->mean_loss_buffer[i] = 
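
// This revision replaces the old host-side `targets != NULL` test with a
// sentinel tensor: a Shape{1} Tensor stands in for "no targets". Sketch of
// the convention exactly as the code above uses it:
//
//   bool hasTargets = (targets.shape[0] != 1);  // Shape{1} marks the sentinel
//   if (hasTargets) {
//     // dispatch crossentropy_forward, read losses back, average them
//   } else {
//     model->mean_loss = -1.0f;                 // no loss available
//   }
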
dloss_mean; } + toGPU(ctx, model->mean_loss_buffer, model->acts.losses); + //toGPU(ctx, grads_acts.losses.data, model->acts_.data[22]); // crossentropy_softmax_backward(ctx, grads_acts.logits, grads_acts.losses, acts.probs, model->targets, B, T, V, Vp); { @@ -794,11 +799,11 @@ void gpt2_backward(Context& ctx, GPT2 *model) { dispatchKernel(ctx, model->kernels.encoder_backward, promise); wait(ctx, future); } - toCPU(ctx, model->params_.data[0], model->grads.wte.data, model->param_sizes[0] * sizeof(float)); - toCPU(ctx, model->params_.data[1], model->grads.wpe.data, model->param_sizes[1] * sizeof(float)); + // toCPU(ctx, model->params_.data[0], model->grads.wte.data, model->param_sizes[0] * sizeof(float)); + // toCPU(ctx, model->params_.data[1], model->grads.wpe.data, model->param_sizes[1] * sizeof(float)); } -void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t) { +void gpt2_update(Context& ctx, GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t) { // reference: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html // lazily allocate the memory for m_memory and v_memory @@ -807,6 +812,45 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo model->v_memory = (float*)calloc(model->num_parameters, sizeof(float)); } + // Copy the parameters to the CPU + float* iter = model->params_memory; + toCPU(ctx, model->params.wte, iter, model->param_sizes[0] * sizeof(float)); + iter += model->param_sizes[0]; + toCPU(ctx, model->params.wpe, iter, model->param_sizes[1] * sizeof(float)); + iter += model->param_sizes[1]; + size_t L = model->config.num_layers; + for (int l = 0; l < L; l++) { + toCPU(ctx, model->params.ln1w[l], iter, model->param_sizes[2]/L * sizeof(float)); + iter += model->param_sizes[2]/L; + toCPU(ctx, model->params.ln1b[l], iter, model->param_sizes[3]/L * sizeof(float)); + iter += model->param_sizes[3]/L; + toCPU(ctx, model->params.qkvw[l], iter, model->param_sizes[4]/L * sizeof(float)); + iter += model->param_sizes[4]/L; + toCPU(ctx, model->params.qkvb[l], iter, model->param_sizes[5]/L * sizeof(float)); + iter += model->param_sizes[5]/L; + toCPU(ctx, model->params.attprojw[l], iter, model->param_sizes[6]/L * sizeof(float)); + iter += model->param_sizes[6]/L; + toCPU(ctx, model->params.attprojb[l], iter, model->param_sizes[7]/L * sizeof(float)); + iter += model->param_sizes[7]/L; + toCPU(ctx, model->params.ln2w[l], iter, model->param_sizes[8]/L * sizeof(float)); + iter += model->param_sizes[8]/L; + toCPU(ctx, model->params.ln2b[l], iter, model->param_sizes[9]/L * sizeof(float)); + iter += model->param_sizes[9]/L; + toCPU(ctx, model->params.fcw[l], iter, model->param_sizes[10]/L * sizeof(float)); + iter += model->param_sizes[10]/L; + toCPU(ctx, model->params.fcb[l], iter, model->param_sizes[11]/L * sizeof(float)); + iter += model->param_sizes[11]/L; + toCPU(ctx, model->params.fcprojw[l], iter, model->param_sizes[12]/L * sizeof(float)); + iter += model->param_sizes[12]/L; + toCPU(ctx, model->params.fcprojb[l], iter, model->param_sizes[13]/L * sizeof(float)); + iter += model->param_sizes[13]/L; + } + toCPU(ctx, model->params.lnfw, iter, model->param_sizes[14] * sizeof(float)); + iter += model->param_sizes[14]; + toCPU(ctx, model->params.lnfb, iter, model->param_sizes[15] * sizeof(float)); + iter += model->param_sizes[15]; + + for (size_t i = 0; i < model->num_parameters; i++) { float param = model->params_memory[i]; float grad = 
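
// Data flow of gpt2_update in this revision: parameters round-trip through
// the host each step. Assumed sequence, restating the code around this point
// rather than adding behavior:
//
//   1. toCPU(...) every parameter tensor into the flat params_memory buffer
//   2. run the AdamW loop over params_memory / grads_memory on the CPU
//      (grads_memory is assumed to hold host-side gradients at this point)
//   3. toGPU(...) every updated parameter tensor back to the device
//
// Keeping the optimizer on the CPU keeps this step simple at the cost of two
// full parameter copies per training step.
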
model->grads_memory[i]; @@ -824,8 +868,43 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo model->v_memory[i] = v; model->params_memory[i] -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param); } - toGPU(ctx, model->params_memory, model->params_.data[0]); - toGPU(ctx, model->params_memory + model->param_sizes[0], model->params_.data[1]); + // toGPU(ctx, model->params_memory, model->params_.data[0]); + // toGPU(ctx, model->params_memory + model->param_sizes[0], model->params_.data[1]); + iter = model->params_memory; + toGPU(ctx, iter, model->params.wte); + iter += model->param_sizes[0]; + toGPU(ctx, iter, model->params.wpe); + iter += model->param_sizes[1]; + for (int l = 0; l < L; l++) { + toGPU(ctx, iter, model->params.ln1w[l]); + iter += model->param_sizes[2]/L; + toGPU(ctx, iter, model->params.ln1b[l]); + iter += model->param_sizes[3]/L; + toGPU(ctx, iter, model->params.qkvw[l]); + iter += model->param_sizes[4]/L; + toGPU(ctx, iter, model->params.qkvb[l]); + iter += model->param_sizes[5]/L; + toGPU(ctx, iter, model->params.attprojw[l]); + iter += model->param_sizes[6]/L; + toGPU(ctx, iter, model->params.attprojb[l]); + iter += model->param_sizes[7]/L; + toGPU(ctx, iter, model->params.ln2w[l]); + iter += model->param_sizes[8]/L; + toGPU(ctx, iter, model->params.ln2b[l]); + iter += model->param_sizes[9]/L; + toGPU(ctx, iter, model->params.fcw[l]); + iter += model->param_sizes[10]/L; + toGPU(ctx, iter, model->params.fcb[l]); + iter += model->param_sizes[11]/L; + toGPU(ctx, iter, model->params.fcprojw[l]); + iter += model->param_sizes[12]/L; + toGPU(ctx, iter, model->params.fcprojb[l]); + iter += model->param_sizes[13]/L; + } + toGPU(ctx, iter, model->params.lnfw); + iter += model->param_sizes[14]; + toGPU(ctx, iter, model->params.lnfb); + iter += model->param_sizes[15]; } void gpt2_free(GPT2 *model) { @@ -915,6 +994,7 @@ int main() { Tensor inputs = createTensor(ctx, Shape{B, T}, ki32); Tensor targets = createTensor(ctx, Shape{B, T}, ki32); Tensor gen_tokens = createTensor(ctx, Shape{B, T}, ki32); + int* gen_tokens_cpu = (int*)mallocCheck(B * T * sizeof(int)); printf("Starting training\n"); for (int step = 0; step <= 40; step++) { printf("Step %d\n", step); @@ -937,7 +1017,10 @@ int main() { // once in a while do model inference to print generated text if (step > 0 && step % 20 == 0) { // fill up gen_tokens with the GPT2_EOT, which kicks off the generation - toGPU(ctx, tokenizer.eot_token, gen_tokens); + for(int i = 0; i < B * T; ++i) { + gen_tokens_cpu[i] = tokenizer.eot_token; + } + toGPU(ctx, gen_tokens_cpu, gen_tokens); // now sample from the model autoregressively printf("generating:\n---\n"); for (int t = 1; t < genT; t++) { @@ -950,14 +1033,15 @@ int main() { // we're in principle running B "inference streams" in parallel here // but only using position 0 // get the Vp-dimensional vector probs[0, t-1, :] - float* probs = model.acts.probs.data + (t-1) * model.config.padded_vocab_size; - toCPU(ctx, model.acts_.data[21], probs, (t-1) * model.config.padded_vocab_size * sizeof(float)); + toCPU(ctx, model.acts.probs, model.probs_buffer, B * T * model.config.padded_vocab_size * sizeof(float)); + float* probs = model.probs_buffer + (t-1) * model.config.padded_vocab_size; float coin = random_f32(&rng_state); // note we're only sampling from the first V elements, ignoring padding // (the probabilities in the padded region should be zero anyway) int next_token = sample_mult(probs, model.config.vocab_size, coin); - gen_tokens[t] = 
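
// sample_mult above draws one token index from the first V entries of a
// probability row, given a uniform coin in [0,1). A standalone sketch of
// that routine (standard inverse-CDF sampling, matching how it is called
// here):
//
//   int sample_mult(const float* probs, int n, float coin) {
//     float cdf = 0.0f;
//     for (int i = 0; i < n; ++i) {
//       cdf += probs[i];
//       if (coin < cdf) { return i; }
//     }
//     return n - 1;  // guard against rounding leaving coin >= cdf
//   }
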
next_token;
+            gen_tokens_cpu[t] = next_token;
+            toGPU(ctx, gen_tokens_cpu, gen_tokens);
             // print the generated token, either using the Tokenizer or a fallback
             if (tokenizer.init_ok) {
                 const char* token_str = tokenizer_decode(&tokenizer, next_token);
@@ -974,10 +1058,12 @@ int main() {
         // do a training step
         clock_gettime(CLOCK_MONOTONIC, &start);
         dataloader_next_batch(&train_loader);
-        gpt2_forward(ctx, &model, train_loader.inputs, train_loader.targets, B, T);
+        toGPU(ctx, train_loader.inputs, inputs);
+        toGPU(ctx, train_loader.targets, targets);
+        gpt2_forward(ctx, &model, inputs, targets, B, T);
         gpt2_zero_grad(&model);
         gpt2_backward(ctx, &model);
-        gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, step+1);
+        gpt2_update(ctx, &model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, step+1);
         clock_gettime(CLOCK_MONOTONIC, &end);
         double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;
         printf("step %d: train loss %f (took %f ms)\n", step, model.mean_loss, time_elapsed_s * 1000);

From f3e0dbca692d6d3aee0323854256bfa408a231bf Mon Sep 17 00:00:00 2001
From: Junji Hashimoto
Date: Wed, 30 Oct 2024 17:48:42 +0900
Subject: [PATCH 20/44] Add summation kernels

---
 experimental/kernels/Makefile   |   4 +
 experimental/kernels/kernels.h  |  72 ++++++
 experimental/kernels/reduce.cpp | 415 ++++++++++++++++++++++++++++++++
 gpu.hpp                         |   8 +
 4 files changed, 499 insertions(+)
 create mode 100644 experimental/kernels/reduce.cpp

diff --git a/experimental/kernels/Makefile b/experimental/kernels/Makefile
index c233ef5..90da2fe 100644
--- a/experimental/kernels/Makefile
+++ b/experimental/kernels/Makefile
@@ -29,6 +29,10 @@ endif

 default: run-native

+build/reduce: reduce.cpp kernels.h
+	$(CC) $(CFLAGS) $(CXXFLAGS) $(LDFLAGS) -o $@ $<
+	$(LIBSPEC) && build/reduce
+
 run_llm.c: ./build/test_gpt2 dawnlib
 	$(LIBSPEC) && $<

diff --git a/experimental/kernels/kernels.h b/experimental/kernels/kernels.h
index 212c075..b8a08f8 100644
--- a/experimental/kernels/kernels.h
+++ b/experimental/kernels/kernels.h
@@ -683,6 +683,78 @@ fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
     }
   }
 }
 )";

+static const char *kSum = R"(
+@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
+var<workgroup> buffer: array<{{precision}}, 1024>;
+@compute @workgroup_size({{workgroupSize}})
+fn main(
+    @builtin(global_invocation_id) globalID : vec3<u32>,
+    @builtin(local_invocation_id) localID : vec3<u32>,
+    @builtin(workgroup_id) groupid : vec3<u32>,
+    @builtin(num_workgroups) numGroups : vec3<u32>) {
+  let blockSize3d: vec3<u32> = vec3<u32>({{workgroupSize}});
+  let blockSize: u32 = blockSize3d.x;
+  let threadId: u32 = localID.x;
+  let blockId: u32 = groupid.x + groupid.y * numGroups.x;
+  let blockStart = blockId * blockSize * 2 + threadId;
+
+  buffer[threadId] = inp[blockStart] + inp[blockStart + blockSize];
+  workgroupBarrier();
+  var stride: u32 = blockSize / 2;
+
+  if (blockSize >= 1024 && threadId < 512) {
+    buffer[threadId] += buffer[threadId + 512];
+  }
+  workgroupBarrier();
+
+  if (blockSize >= 512 && threadId < 256) {
+    buffer[threadId] += buffer[threadId + 256];
+  }
+  workgroupBarrier();
+
+  if (blockSize >= 256 && threadId < 128) {
+    buffer[threadId] += buffer[threadId + 128];
+  }
+  workgroupBarrier();
+
+  if (threadId < 64) {
+    buffer[threadId] += buffer[threadId + 64];
+  }
+  workgroupBarrier();
+
+  if (threadId < 32) {
+    buffer[threadId] += buffer[threadId + 32];
+  }
+  workgroupBarrier();
+
+  if (threadId < 16) {
+    buffer[threadId] += buffer[threadId + 16];
+  }
+  workgroupBarrier();
+
+  if (threadId < 8)
{ + buffer[threadId] += buffer[threadId + 8]; + } + workgroupBarrier(); + + if (threadId < 4) { + buffer[threadId] += buffer[threadId + 4]; + } + workgroupBarrier(); + + if (threadId < 2) { + buffer[threadId] += buffer[threadId + 2]; + } + workgroupBarrier(); + + if (threadId == 0) { + buffer[0] += buffer[1]; + out[blockId] = buffer[0]; + } +} +)"; + } // namespace gpu #endif // KERNELS_H diff --git a/experimental/kernels/reduce.cpp b/experimental/kernels/reduce.cpp new file mode 100644 index 0000000..13c6c40 --- /dev/null +++ b/experimental/kernels/reduce.cpp @@ -0,0 +1,415 @@ +#include "gpu.hpp" +#include +#include +#include +#include +#include +#include "utils/array_utils.hpp" // show, isclose, randn, randint +#include "kernels.h" + +using namespace gpu; + +#define LIMITS { \ + .nextInChain = nullptr, \ + .limits = { \ + .maxTextureDimension1D=8192, \ + .maxTextureDimension2D=8192, \ + .maxTextureDimension3D=2048, \ + .maxTextureArrayLayers=256, \ + .maxBindGroups=4, \ + .maxBindGroupsPlusVertexBuffers=24, \ + .maxBindingsPerBindGroup=1000, \ + .maxDynamicUniformBuffersPerPipelineLayout=8, \ + .maxDynamicStorageBuffersPerPipelineLayout=4, \ + .maxSampledTexturesPerShaderStage=16, \ + .maxSamplersPerShaderStage=16, \ + .maxStorageBuffersPerShaderStage=8, \ + .maxStorageTexturesPerShaderStage=4, \ + .maxUniformBuffersPerShaderStage=12, \ + .maxUniformBufferBindingSize=65536, \ + .maxStorageBufferBindingSize=1073741824, \ + .minUniformBufferOffsetAlignment=256, \ + .minStorageBufferOffsetAlignment=256, \ + .maxVertexBuffers=8, \ + .maxBufferSize=0x80000000, \ + .maxVertexAttributes=16, \ + .maxVertexBufferArrayStride=2048, \ + .maxInterStageShaderComponents=64, \ + .maxInterStageShaderVariables=16, \ + .maxColorAttachments=8, \ + .maxColorAttachmentBytesPerSample=32, \ + .maxComputeWorkgroupStorageSize=16384, \ + .maxComputeInvocationsPerWorkgroup=1024, \ + .maxComputeWorkgroupSizeX=1024, \ + .maxComputeWorkgroupSizeY=1024, \ + .maxComputeWorkgroupSizeZ=64, \ + .maxComputeWorkgroupsPerDimension=65535 \ + } \ + } + + +struct DurationTime { + std::chrono::high_resolution_clock::time_point start; + std::chrono::high_resolution_clock::time_point end; + std::chrono::microseconds duration; + std::string src; + bool verbose; + int num; + + inline DurationTime(const std::string& src, bool verbose = true, int num = 1) { + this->src = src; + this->verbose = verbose; + this->num = num; + start = std::chrono::high_resolution_clock::now(); + } + + inline ~DurationTime() { + end = std::chrono::high_resolution_clock::now(); + duration = std::chrono::duration_cast(end - start); + if (this->verbose) { + printf("Duration(%s): %.1f microseconds\n", src.c_str(), static_cast(duration.count()) / static_cast(num)); + } + } +}; + +static const char *kSumVersion1 = R"( +@group(0) @binding(0) var inp: array<{{precision}}>; +@group(0) @binding(1) var out: array<{{precision}}>; +var buffer: array<{{precision}}, 1024>; +@compute @workgroup_size({{workgroupSize}}) +fn main( + @builtin(local_invocation_id) localID : vec3, + @builtin(workgroup_id) groupid : vec3, + @builtin(num_workgroups) numGroups : vec3) { + let blockSize3d: vec3 = vec3({{workgroupSize}}); + let blockSize: u32 = blockSize3d.x; + let threadId: u32 = localID.x; + let blockId: u32 = groupid.x + groupid.y * numGroups.x; + let blockStart = blockId * blockSize * 2 + threadId; + + buffer[threadId] = inp[blockStart] + inp[blockStart + blockSize]; + workgroupBarrier(); + + for (var stride: u32 = blockSize / 2; stride > 0; stride /= 2) { + if (threadId < 
stride) { + buffer[threadId] += buffer[threadId + stride]; + } + workgroupBarrier(); + } + + if (threadId == 0) { + out[blockId] = buffer[0]; + } +} +)"; + +static const char *kSumVersion2 = R"( +@group(0) @binding(0) var inp: array<{{precision}}>; +@group(0) @binding(1) var out: array<{{precision}}>; +var buffer: array<{{precision}}, 1024>; +@compute @workgroup_size({{workgroupSize}}) +fn main( + @builtin(global_invocation_id) globalID : vec3, + @builtin(local_invocation_id) localID : vec3, + @builtin(workgroup_id) groupid : vec3, + @builtin(num_workgroups) numGroups : vec3) { + let blockSize3d: vec3 = vec3({{workgroupSize}}); + let blockSize: u32 = blockSize3d.x; + let threadId: u32 = localID.x; + let blockId: u32 = groupid.x + groupid.y * numGroups.x; + let n: u32 = arrayLength(&inp); + let blockStart = blockId * blockSize * 2 + threadId; + + buffer[threadId] = inp[blockStart] + inp[blockStart + blockSize]; + workgroupBarrier(); + var stride: u32 = blockSize / 2; + + if (threadId < stride) { + buffer[threadId] += buffer[threadId + stride]; + } + workgroupBarrier(); + + stride /= 2; // 1/4 + if (threadId < stride) { + buffer[threadId] += buffer[threadId + stride]; + } + workgroupBarrier(); + + stride /= 2; // 1/8 + if (threadId < stride) { + buffer[threadId] += buffer[threadId + stride]; + } + workgroupBarrier(); + + stride /= 2; // 1/16 + if (threadId < stride) { + buffer[threadId] += buffer[threadId + stride]; + } + workgroupBarrier(); + + stride /= 2; // 1/32 + if (threadId < stride) { + buffer[threadId] += buffer[threadId + stride]; + } + workgroupBarrier(); + + stride /= 2; // 1/64 + if (threadId < stride) { + buffer[threadId] += buffer[threadId + stride]; + } + workgroupBarrier(); + + stride /= 2; // 1/128 + if (threadId < stride) { + buffer[threadId] += buffer[threadId + stride]; + } + workgroupBarrier(); + + stride /= 2; // 1/256 + if (threadId < stride) { + buffer[threadId] += buffer[threadId + stride]; + } + workgroupBarrier(); + + stride /= 2; // 1/512 + if (threadId < stride) { + buffer[threadId] += buffer[threadId + stride]; + } + workgroupBarrier(); + + stride /= 2; // 1/1024 + if (threadId < stride) { + buffer[threadId] += buffer[threadId + stride]; + } + + if (threadId == 0) { + out[blockId] = buffer[0]; + } +} +)"; + +static const char *kSum2d = R"( +@group(0) @binding(0) var inp: array<{{precision}}>; +@group(0) @binding(1) var out: array<{{precision}}>; +@group(0) @binding(2) var params : Params; +struct Params { + N: u32, + C: u32, +}; +var buffer: array<{{precision}}, 1024>; +@compute @workgroup_size({{workgroupSize}}) +fn main( + @builtin(global_invocation_id) globalID : vec3, + @builtin(local_invocation_id) localID : vec3, + @builtin(workgroup_id) groupid : vec3, + @builtin(num_workgroups) numGroups : vec3) { + let blockSize3d: vec3 = vec3({{workgroupSize}}); + let blockSize: u32 = blockSize3d.x; + let threadId: u32 = localID.x; + let blockId: u32 = groupid.x + groupid.y * numGroups.x; + let blockStart = blockId * blockSize * 2 + threadId; + + buffer[threadId] = inp[blockStart] + inp[blockStart + blockSize]; + workgroupBarrier(); + var stride: u32 = blockSize / 2; + + if (blockSize >= 1024 && threadId < 512) { + buffer[threadId] += buffer[threadId + 512]; + } + workgroupBarrier(); + + if (blockSize >= 512 && threadId < 256) { + buffer[threadId] += buffer[threadId + 256]; + } + workgroupBarrier(); + + if (blockSize >= 256 && threadId < 128) { + buffer[threadId] += buffer[threadId + 128]; + } + workgroupBarrier(); + + if (threadId < 64) { + buffer[threadId] += 
buffer[threadId + 64]; + } + workgroupBarrier(); + + if (threadId < 32) { + buffer[threadId] += buffer[threadId + 32]; + } + workgroupBarrier(); + + if (threadId < 16) { + buffer[threadId] += buffer[threadId + 16]; + } + workgroupBarrier(); + + if (threadId < 8) { + buffer[threadId] += buffer[threadId + 8]; + } + workgroupBarrier(); + + if (threadId < 4) { + buffer[threadId] += buffer[threadId + 4]; + } + workgroupBarrier(); + + if (threadId < 2) { + buffer[threadId] += buffer[threadId + 2]; + } + workgroupBarrier(); + + if (threadId == 0) { + buffer[0] += buffer[1]; + out[blockId] = buffer[0]; + } +} +)"; + +float sum_cpu(const float* data, size_t size) { + float result = 0; + for (size_t i = 0; i < size; ++i) { + result += data[i]; + } + return result; +} + +Kernel createSumKernel(Context& ctx, Tensor& input, Tensor& output, size_t size) { + uint32_t num_threads = 1024; + uint32_t num_blocks = ((size + num_threads -1) / num_threads); + uint32_t size_x = 32768u < num_blocks ? 32768u : num_blocks; + uint32_t size_y = size_x == 32768u ? num_blocks / 32768u : 1; + size_x /= 2; + size_x = size_x < 1 ? 1 : size_x; + // print size_x, size_y + // printf("size_x: %u, size_y: %u, num_blocks: %u\n", size_x, size_y, num_blocks); + return createKernel(ctx, {kSum, num_threads, kf32}, Bindings{input, output}, {size_x, size_y, 1}); +} + +float sum_gpu(Context& ctx, const float* data, const float* buffer, size_t size) { + WGPURequiredLimits requiredLimits = LIMITS; + uint32_t num_threads = 1024; + int nSum = round(log2(size) / log2(num_threads)); + int input_size = size; + unsigned long output_size = size; + std::vector outputs; + std::vector ops; + outputs.push_back(createTensor(ctx, Shape{std::max(size, static_cast(1024*2))}, kf32)); + for(int i=size,j=0;i>0;i/=num_threads,j++){ + output_size = (output_size + num_threads - 1) / num_threads; + outputs.push_back(createTensor(ctx, Shape{std::max(output_size, static_cast(1024*2))}, kf32)); + ops.push_back(createSumKernel(ctx, outputs[j], outputs[j+1], input_size)); + // printf("size: %d\n", input_size); + input_size = output_size; + } + toGPU(ctx, data, outputs[0], size * sizeof(float)); + + + { + for(int i=size,j=0;i>0;i/=num_threads,j++){ + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, ops[j], promise); + wait(ctx, future); + resetCommandBuffer(ctx.device, ops[j]); + } + } + + { + int nIter = 100; + DurationTime dt("GPU", true, nIter); + for (int t = 0; t < nIter; t++){ + for(int i=size,j=0;i>0;i/=num_threads,j++){ + std::promise promise; + std::future future = promise.get_future(); + dispatchKernel(ctx, ops[j], promise); + wait(ctx, future); + resetCommandBuffer(ctx.device, ops[j]); + } + } + } + + float r = 0; + toCPU(ctx, outputs[outputs.size()-1], (void*)buffer, 4 * sizeof(float)); + + return buffer[0]; +} + +// float sum_gpu2d(Context& ctx, const float* data, const float* buffer, size_t size_x, size_t size_y) { +// WGPURequiredLimits requiredLimits = LIMITS; +// Tensor input = createTensor(ctx, Shape{size}, kf32, data); +// Tensor output = createTensor(ctx, Shape{size}, kf32); +// uint32_t num_threads = 1024; +// uint32_t num_blocks = ((size_x + num_threads -1) / num_threads); +// printf("size: %u, size_x: %u, size_y: %u\n", size, size_x, size_y); +// Kernel op = createKernel(ctx, {kSum, num_threads, kf32}, Bindings{input, output}, {size_x, size_y, 1}); +// +// { +// for (int i = 0; i < 100; ++i){ +// DurationTime dt("GPU"); +// std::promise promise; +// std::future future = promise.get_future(); +// 
dispatchKernel(ctx, op, promise); +// wait(ctx, future); +// resetCommandBuffer(ctx.device, op); +// } +// } +// +// float r = 0; +// toCPU(ctx, output, (void*)buffer, num_blocks * sizeof(float)); +// +// for (int i = 0; i < num_blocks; i++){ +// r+=buffer[i]; +// } +// return r; +// } + +int main(int argc, char **argv) { + static constexpr size_t M = 4096*2; + static constexpr size_t N = 4096*2; + static constexpr size_t BUF_SIZE = 16; + std::unique_ptr inputArr = std::make_unique(M * N); + std::unique_ptr buffer = std::make_unique(BUF_SIZE); + std::mt19937 gen(314159); + printf("Initializing %zu values\n", M*N); + randn(inputArr.get(), M*N, gen); + // for(int i=0;i= 1e-0f) { + printf("Error: diff = %.6f\n", diff); + } else { + printf("Success: diff = %.6f\n", diff); + } + + printf("Computed %zu values of kSum(x)\n\n", M*N); + return 0; +} diff --git a/gpu.hpp b/gpu.hpp index 83fc94b..8047646 100644 --- a/gpu.hpp +++ b/gpu.hpp @@ -1119,6 +1119,14 @@ inline void toGPU(Context &ctx, const half *data, Tensor &tensor) { tensor.data.size); } +inline void toGPU(Context &ctx, const float *data, Tensor &tensor, size_t size) { + wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data, size); +} + +inline void toGPU(Context &ctx, const half *data, Tensor &tensor, size_t size) { + wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data, size); +} + template inline void toGPU(Context &ctx, Params ¶ms, Kernel &op) { // TODO(avh): Maintain params metadata in Kernel and check for consistency. From f956f2b78b08bbfd0212efa53b3d1fd90d9b0941 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Thu, 31 Oct 2024 04:14:51 +0900 Subject: [PATCH 21/44] Add SumKernel --- experimental/kernels/reduce.cpp | 61 ++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/experimental/kernels/reduce.cpp b/experimental/kernels/reduce.cpp index 13c6c40..e1bc387 100644 --- a/experimental/kernels/reduce.cpp +++ b/experimental/kernels/reduce.cpp @@ -285,51 +285,55 @@ Kernel createSumKernel(Context& ctx, Tensor& input, Tensor& output, size_t size) return createKernel(ctx, {kSum, num_threads, kf32}, Bindings{input, output}, {size_x, size_y, 1}); } -float sum_gpu(Context& ctx, const float* data, const float* buffer, size_t size) { - WGPURequiredLimits requiredLimits = LIMITS; - uint32_t num_threads = 1024; - int nSum = round(log2(size) / log2(num_threads)); - int input_size = size; - unsigned long output_size = size; +struct SumKernel { std::vector outputs; std::vector ops; - outputs.push_back(createTensor(ctx, Shape{std::max(size, static_cast(1024*2))}, kf32)); - for(int i=size,j=0;i>0;i/=num_threads,j++){ - output_size = (output_size + num_threads - 1) / num_threads; - outputs.push_back(createTensor(ctx, Shape{std::max(output_size, static_cast(1024*2))}, kf32)); - ops.push_back(createSumKernel(ctx, outputs[j], outputs[j+1], input_size)); - // printf("size: %d\n", input_size); - input_size = output_size; - } - toGPU(ctx, data, outputs[0], size * sizeof(float)); - - - { + SumKernel(Context& ctx, size_t size) { + uint32_t num_threads = 1024; + int nSum = round(log2(size) / log2(num_threads)); + int input_size = size; + unsigned long output_size = size; + outputs.push_back(createTensor(ctx, Shape{std::max(size, static_cast(num_threads*2))}, kf32)); for(int i=size,j=0;i>0;i/=num_threads,j++){ + output_size = (output_size + num_threads - 1) / num_threads; + outputs.push_back(createTensor(ctx, Shape{std::max(output_size, static_cast(num_threads*2))}, kf32)); + 
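The sum_gpu driver above dispatches one kernel per reduction pass until a single partial sum remains. As a worked sketch of the sizes involved, assuming 1024 threads and two inputs per thread as in kSum (later revisions in this series fold the same factor of two into the pass sizing):

#include <cstdio>

// Pass-count arithmetic for a chained tree reduction: each block of 1024
// threads folds 2048 inputs into one partial sum, so the 8192*8192 test
// above collapses in three dispatches (67108864 -> 32768 -> 16 -> 1).
int main() {
  const unsigned long elemsPerBlock = 1024ul * 2ul;
  unsigned long n = 8192ul * 8192ul;
  int pass = 0;
  while (n > 1) {
    unsigned long blocks = (n + elemsPerBlock - 1) / elemsPerBlock;
    printf("pass %d: %lu elements -> %lu partial sums\n", ++pass, n, blocks);
    n = blocks;  // partial sums become the next pass's input
  }
  return 0;
}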
ops.push_back(createSumKernel(ctx, outputs[j], outputs[j+1], input_size)); + input_size = output_size; + } + } + void dispatchKernel(Context& ctx) { + for(int i=0;i promise; std::future future = promise.get_future(); - dispatchKernel(ctx, ops[j], promise); + gpu::dispatchKernel(ctx, ops[i], promise); wait(ctx, future); - resetCommandBuffer(ctx.device, ops[j]); + resetCommandBuffer(ctx.device, ops[i]); } } + void toGPU(Context& ctx, const float* data, size_t size) { + gpu::toGPU(ctx, data, outputs[0], size); + } + void toCPU(Context& ctx, float* data, size_t size) { + gpu::toCPU(ctx, outputs[outputs.size()-1], data, size); + } +}; + +float sum_gpu(Context& ctx, const float* data, float* buffer, size_t size) { + WGPURequiredLimits requiredLimits = LIMITS; + SumKernel sumKernel(ctx, size); + sumKernel.toGPU(ctx, data, size * sizeof(float)); + sumKernel.dispatchKernel(ctx); { int nIter = 100; DurationTime dt("GPU", true, nIter); for (int t = 0; t < nIter; t++){ - for(int i=size,j=0;i>0;i/=num_threads,j++){ - std::promise promise; - std::future future = promise.get_future(); - dispatchKernel(ctx, ops[j], promise); - wait(ctx, future); - resetCommandBuffer(ctx.device, ops[j]); - } + sumKernel.dispatchKernel(ctx); } } float r = 0; - toCPU(ctx, outputs[outputs.size()-1], (void*)buffer, 4 * sizeof(float)); + sumKernel.toCPU(ctx, buffer, 4 * sizeof(float)); return buffer[0]; } @@ -363,6 +367,7 @@ float sum_gpu(Context& ctx, const float* data, const float* buffer, size_t size) // return r; // } + int main(int argc, char **argv) { static constexpr size_t M = 4096*2; static constexpr size_t N = 4096*2; From c13833fd07bac1659ef530eec4dd3558a56233d9 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Sun, 3 Nov 2024 16:13:33 +0900 Subject: [PATCH 22/44] Add SumKernel2d --- experimental/kernels/reduce.cpp | 280 +++++++++++++++++++++----------- 1 file changed, 187 insertions(+), 93 deletions(-) diff --git a/experimental/kernels/reduce.cpp b/experimental/kernels/reduce.cpp index e1bc387..38cb6a7 100644 --- a/experimental/kernels/reduce.cpp +++ b/experimental/kernels/reduce.cpp @@ -199,68 +199,37 @@ struct Params { var buffer: array<{{precision}}, 1024>; @compute @workgroup_size({{workgroupSize}}) fn main( - @builtin(global_invocation_id) globalID : vec3, @builtin(local_invocation_id) localID : vec3, @builtin(workgroup_id) groupid : vec3, @builtin(num_workgroups) numGroups : vec3) { + let N : u32 = params.N; + let C : u32 = params.C; let blockSize3d: vec3 = vec3({{workgroupSize}}); let blockSize: u32 = blockSize3d.x; let threadId: u32 = localID.x; let blockId: u32 = groupid.x + groupid.y * numGroups.x; - let blockStart = blockId * blockSize * 2 + threadId; - - buffer[threadId] = inp[blockStart] + inp[blockStart + blockSize]; - workgroupBarrier(); - var stride: u32 = blockSize / 2; - - if (blockSize >= 1024 && threadId < 512) { - buffer[threadId] += buffer[threadId + 512]; - } - workgroupBarrier(); - - if (blockSize >= 512 && threadId < 256) { - buffer[threadId] += buffer[threadId + 256]; - } - workgroupBarrier(); - if (blockSize >= 256 && threadId < 128) { - buffer[threadId] += buffer[threadId + 128]; - } - workgroupBarrier(); - - if (threadId < 64) { - buffer[threadId] += buffer[threadId + 64]; - } - workgroupBarrier(); - - if (threadId < 32) { - buffer[threadId] += buffer[threadId + 32]; - } - workgroupBarrier(); - - if (threadId < 16) { - buffer[threadId] += buffer[threadId + 16]; - } - workgroupBarrier(); - - if (threadId < 8) { - buffer[threadId] += buffer[threadId + 8]; - } - workgroupBarrier(); - - 
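Whether unrolled (kSum above) or written as a loop (kSumVersion1, and the kSum2d rewrite in this hunk), the workgroup body is the same tree reduction. Modeled on the CPU it is just the following (a sketch; a power-of-two length is assumed):

#include <cstdio>
#include <vector>

// Halve the active range each step, folding the upper half into the lower;
// after log2(n) steps element 0 holds the total.
float treeReduce(std::vector<float> buf) {
  for (size_t stride = buf.size() / 2; stride > 0; stride /= 2) {
    for (size_t i = 0; i < stride; i++) {
      buf[i] += buf[i + stride];  // on the GPU this runs across threads,
    }                             // separated by workgroupBarrier()
  }
  return buf[0];
}

int main() {
  std::vector<float> xs(2048, 1.0f);
  printf("%.1f\n", treeReduce(xs));  // prints 2048.0
  return 0;
}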
if (threadId < 4) { - buffer[threadId] += buffer[threadId + 4]; - } - workgroupBarrier(); - - if (threadId < 2) { - buffer[threadId] += buffer[threadId + 2]; - } - workgroupBarrier(); - - if (threadId == 0) { - buffer[0] += buffer[1]; - out[blockId] = buffer[0]; + for (var i: u32 = 0; i= N) { + } else if(blockStart + blockSize >= N) { + buffer[threadId] = inp[blockStart * C + i]; + } else { + buffer[threadId] = inp[blockStart * C + i] + inp[(blockStart + blockSize) * C + i]; + } + workgroupBarrier(); + + for (var stride: u32 = blockSize / 2; stride > 0; stride /= 2) { + if (threadId < stride) { + buffer[threadId] += buffer[threadId + stride]; + } + workgroupBarrier(); + } + + if (threadId == 0) { + out[blockId * C + i] = buffer[0]; + } + workgroupBarrier(); } } )"; @@ -273,33 +242,100 @@ float sum_cpu(const float* data, size_t size) { return result; } -Kernel createSumKernel(Context& ctx, Tensor& input, Tensor& output, size_t size) { - uint32_t num_threads = 1024; +void sum_cpu_2d(const float* data, float* out, size_t size0, size_t size1) { + float result = 0; + for (size_t j = 0; j < size1; ++j) { + out[j] = 0; + } + for (size_t i = 0; i < size0; ++i) { + for (size_t j = 0; j < size1; ++j) { + out[j] += data[(i * size1) + j]; + } + } +} + +Kernel createSumKernel(Context& ctx, Tensor& input, Tensor& output, size_t size, uint32_t num_threads = 1024) { uint32_t num_blocks = ((size + num_threads -1) / num_threads); uint32_t size_x = 32768u < num_blocks ? 32768u : num_blocks; uint32_t size_y = size_x == 32768u ? num_blocks / 32768u : 1; size_x /= 2; size_x = size_x < 1 ? 1 : size_x; // print size_x, size_y - // printf("size_x: %u, size_y: %u, num_blocks: %u\n", size_x, size_y, num_blocks); + printf("size_x: %u, size_y: %u, num_blocks: %u\n", size_x, size_y, num_blocks); return createKernel(ctx, {kSum, num_threads, kf32}, Bindings{input, output}, {size_x, size_y, 1}); } +Kernel createSumKernel2d(Context& ctx, Tensor& input, Tensor& output, size_t size0, size_t size1, uint32_t num_threads = 1024) { + struct Params { + uint32_t N; + uint32_t C; + }; + uint32_t num_blocks = ((size0 + num_threads -1) / num_threads); + uint32_t size_x = num_blocks; + uint32_t size_y = size1; + size_x /= 2; + size_x = size_x < 1 ? 
1 : size_x; + printf("size_x: %u, size_y: %u, num_blocks: %u\n", size_x, size_y, num_blocks); + return createKernel(ctx, + {kSum2d, num_threads, kf32}, + Bindings{input, output}, + {size_x, size_y, 1}, + Params{ + static_cast(size0), + static_cast(size1), + }); +} + struct SumKernel { std::vector outputs; std::vector ops; - SumKernel(Context& ctx, size_t size) { - uint32_t num_threads = 1024; - int nSum = round(log2(size) / log2(num_threads)); + SumKernel(Context& ctx, size_t size, uint32_t num_threads = 1024) { int input_size = size; unsigned long output_size = size; outputs.push_back(createTensor(ctx, Shape{std::max(size, static_cast(num_threads*2))}, kf32)); - for(int i=size,j=0;i>0;i/=num_threads,j++){ - output_size = (output_size + num_threads - 1) / num_threads; + for(int j=0;output_size>1;j++){ + output_size = (output_size + (num_threads * 2) - 1) / (num_threads * 2); outputs.push_back(createTensor(ctx, Shape{std::max(output_size, static_cast(num_threads*2))}, kf32)); - ops.push_back(createSumKernel(ctx, outputs[j], outputs[j+1], input_size)); + ops.push_back(createSumKernel(ctx, outputs[j], outputs[j+1], input_size, num_threads)); + input_size = output_size; + } + } + void dispatchKernel(Context& ctx) { + for(int i=0;i promise; + std::future future = promise.get_future(); + gpu::dispatchKernel(ctx, ops[i], promise); + wait(ctx, future); + resetCommandBuffer(ctx.device, ops[i]); + } + } + void toGPU(Context& ctx, const float* data, size_t size) { + gpu::toGPU(ctx, data, outputs[0], size); + } + void toCPU(Context& ctx, float* data, size_t size) { + gpu::toCPU(ctx, outputs[outputs.size()-1], data, size); + } +}; + +struct SumKernel2d { + std::vector outputs; + std::vector ops; + bool debug; + SumKernel2d(Context& ctx, size_t size0, size_t size1, uint32_t num_threads = 1024) { + debug = false; + int input_size = size0; + unsigned long output_size = size0; + outputs.push_back(createTensor(ctx, Shape{std::max(size0, static_cast(num_threads*2)),size1}, kf32)); + for(int j=0;output_size>1;j++){ + output_size = (output_size + (num_threads * 2) - 1) / (num_threads * 2); + if (debug) + printf("size0: %d, num_threads: %d, output_size: %d\n", size0, num_threads, output_size); + outputs.push_back(createTensor(ctx, Shape{std::max(output_size, static_cast(num_threads*2)), size1}, kf32)); + ops.push_back(createSumKernel2d(ctx, outputs[j], outputs[j+1], input_size, size1, num_threads)); input_size = output_size; } + if (debug) + printf("ops.size(): %d\n", ops.size()); } void dispatchKernel(Context& ctx) { for(int i=0;i buffer = std::make_unique(8); + for(int i=0;i promise; -// std::future future = promise.get_future(); -// dispatchKernel(ctx, op, promise); -// wait(ctx, future); -// resetCommandBuffer(ctx.device, op); -// } -// } -// -// float r = 0; -// toCPU(ctx, output, (void*)buffer, num_blocks * sizeof(float)); -// -// for (int i = 0; i < num_blocks; i++){ -// r+=buffer[i]; -// } -// return r; -// } +void sum_gpu_2d(Context& ctx, const float* data, float* out, size_t size0, size_t size1) { + WGPURequiredLimits requiredLimits = LIMITS; + SumKernel2d sumKernel(ctx, size0, size1); + sumKernel.toGPU(ctx, data, size0 * size1 * sizeof(float)); + sumKernel.dispatchKernel(ctx); + + { + int nIter = 3; + DurationTime dt("GPU", true, nIter); + for (int t = 0; t < nIter; t++){ + sumKernel.dispatchKernel(ctx); + } + } + sumKernel.toCPU(ctx, out, size1 * sizeof(float)); +} -int main(int argc, char **argv) { +int main_1d(int argc, char **argv) { static constexpr size_t M = 4096*2; static constexpr size_t N = 
4096*2; static constexpr size_t BUF_SIZE = 16; @@ -389,7 +423,6 @@ int main(int argc, char **argv) { gpu::Context ctx = gpu::createContext({}, {}, { .requiredLimits = &requiredLimits }); - Tensor input = createTensor(ctx, Shape{M*N}, kf32, inputArr.get()); printf("Start testing sum(x) on %zu values\n", M*N); cpu_result = sum_cpu(inputArr.get(), M*N); @@ -418,3 +451,64 @@ int main(int argc, char **argv) { printf("Computed %zu values of kSum(x)\n\n", M*N); return 0; } + +int main_2d(int argc, char **argv) { + static constexpr size_t M = 4096; + static constexpr size_t N = 4096; + std::unique_ptr inputArr = std::make_unique(M * N); + std::unique_ptr outputCpuArr = std::make_unique(N); + std::unique_ptr outputGpuArr = std::make_unique(N); + std::mt19937 gen(314159); + printf("Initializing %zu values\n", M*N); + randn(inputArr.get(), M*N, gen); + for(int i=0;i= 1e-0f) { + printf("Error: diff = %.6f\n", diff); + } else { + printf("Success: diff = %.6f\n", diff); + } + + return 0; +} + +int main(int argc, char **argv) { + printf("================================\n"); + printf("Start testing reduce-1d\n"); + main_1d(argc,argv); + printf("================================\n"); + printf("Start testing reduce-2d\n"); + main_2d(argc,argv); + return 0; +} From c9b7018b2c50148c8cacacd273a66cf829ef2c44 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Sun, 3 Nov 2024 14:51:19 -0500 Subject: [PATCH 23/44] fix printf format codes --- experimental/kernels/reduce.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/experimental/kernels/reduce.cpp b/experimental/kernels/reduce.cpp index 38cb6a7..46460df 100644 --- a/experimental/kernels/reduce.cpp +++ b/experimental/kernels/reduce.cpp @@ -329,13 +329,13 @@ struct SumKernel2d { for(int j=0;output_size>1;j++){ output_size = (output_size + (num_threads * 2) - 1) / (num_threads * 2); if (debug) - printf("size0: %d, num_threads: %d, output_size: %d\n", size0, num_threads, output_size); + printf("size0: %zu, num_threads: %d, output_size: %lu\n", size0, num_threads, output_size); outputs.push_back(createTensor(ctx, Shape{std::max(output_size, static_cast(num_threads*2)), size1}, kf32)); ops.push_back(createSumKernel2d(ctx, outputs[j], outputs[j+1], input_size, size1, num_threads)); input_size = output_size; } if (debug) - printf("ops.size(): %d\n", ops.size()); + printf("ops.size(): %zu\n", ops.size()); } void dispatchKernel(Context& ctx) { for(int i=0;i Date: Tue, 5 Nov 2024 01:41:44 +0900 Subject: [PATCH 24/44] Add a flag to disable bardward-pass --- experimental/kernels/gpt2_webgpu_aot.cpp | 30 +++++++++++++++--------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/experimental/kernels/gpt2_webgpu_aot.cpp b/experimental/kernels/gpt2_webgpu_aot.cpp index cd474f9..2190a7a 100644 --- a/experimental/kernels/gpt2_webgpu_aot.cpp +++ b/experimental/kernels/gpt2_webgpu_aot.cpp @@ -278,6 +278,7 @@ typedef struct { // kernels Kernels kernels; + bool backward_enabled; } GPT2; void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoint_path) { @@ -379,6 +380,7 @@ void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoin // Allocate B * C buffer for mean loss model->mean_loss_buffer = (float*)mallocCheck(sizeof(float) * model->batch_size * model->seq_len); model->probs_buffer = (float*)mallocCheck(sizeof(float) * model->batch_size * model->seq_len * Vp); + model->backward_enabled = false; printf("Model build complete\n"); @@ -476,7 +478,8 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& 
inputs, Tensor& targets, si kernels.crossentropy_forward = crossentropy_forward(ctx, model->acts.losses, model->acts.probs, targets, B, T, Vp); kernels.encoder_forward = encoder_forward(ctx, model->acts.encoded, inputs, model->params.wte, model->params.wpe, B, T, C); // encoding goes into residual[0] - kernels.encoder_backward = encoder_backward(ctx, model->params.wte, model->params.wpe, model->acts.encoded, inputs, B, T, C); + if(model->backward_enabled) + kernels.encoder_backward = encoder_backward(ctx, model->params.wte, model->params.wpe, model->acts.encoded, inputs, B, T, C); kernels.layernorm_final_forward = layernorm_forward(ctx, model->acts.lnf, model->acts.lnf_mean, model->acts.lnf_rstd, /*input=*/ model->acts.residual3[L-1], /*weight=*/ model->params.lnfw, /*bias=*/ model->params.lnfb, B, T, C); @@ -484,12 +487,15 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si model->nullTensor = nullTensor; kernels.matmul_final_forward = matmul_forward(ctx, model->acts.logits, model->acts.lnf, model->params.wte, nullTensor, B, T, C, Vp); kernels.softmax_final_forward = softmax_forward(ctx, model->acts.probs, model->acts.logits, B, T, V, Vp); - kernels.crossentropy_softmax_backward = crossentropy_softmax_backward(ctx, model->acts.logits, model->acts.losses, model->acts.probs, targets, B, T, V, Vp); - kernels.matmul_final_backward = matmul_backward(ctx, model->acts.lnf, model->params.wte, nullTensor, model->acts.logits, - model->acts.lnf, model->params.wte, B, T, C, Vp); - kernels.layernorm_final_backward = layernorm_backward(ctx, model->acts.residual3[L-1], model->params.lnfw, model->params.lnfb, - model->acts.lnf, model->acts.residual3[L-1], model->params.lnfw, - model->acts.lnf_mean, model->acts.lnf_rstd, B, T, C); + if(model->backward_enabled) + kernels.crossentropy_softmax_backward = crossentropy_softmax_backward(ctx, model->acts.logits, model->acts.losses, model->acts.probs, targets, B, T, V, Vp); + if(model->backward_enabled) + kernels.matmul_final_backward = matmul_backward(ctx, model->acts.lnf, model->params.wte, nullTensor, model->acts.logits, + model->acts.lnf, model->params.wte, B, T, C, Vp); + if(model->backward_enabled) + kernels.layernorm_final_backward = layernorm_backward(ctx, model->acts.residual3[L-1], model->params.lnfw, model->params.lnfb, + model->acts.lnf, model->acts.residual3[L-1], model->params.lnfw, + model->acts.lnf_mean, model->acts.lnf_rstd, B, T, C); printf("Created Kernels\n"); } @@ -557,7 +563,7 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si { std::promise promise; std::future future = promise.get_future(); - dispatchKernel(ctx, model->kernels.layernorm2_backward[l], promise); + dispatchKernel(ctx, model->kernels.layernorm_forward[l], promise); wait(ctx, future); } printf(" [Forward] : FF Up\n"); @@ -1061,9 +1067,11 @@ int main() { toGPU(ctx, train_loader.inputs, inputs); toGPU(ctx, train_loader.targets, targets); gpt2_forward(ctx, &model, inputs, targets, B, T); - gpt2_zero_grad(&model); - gpt2_backward(ctx, &model); - gpt2_update(ctx, &model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, step+1); + if (model.backward_enabled) { + gpt2_zero_grad(&model); + gpt2_backward(ctx, &model); + gpt2_update(ctx, &model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, step+1); + } clock_gettime(CLOCK_MONOTONIC, &end); double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9; printf("step %d: train loss %f (took %f ms)\n", step, model.mean_loss, time_elapsed_s * 1000); From 
6be7e1e0462d819712bf43570f55854a54fbc8f0 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Sat, 16 Nov 2024 14:17:59 +0900 Subject: [PATCH 25/44] Fix the bug of memory allocation --- experimental/kernels/Makefile | 4 ++-- experimental/kernels/gpt2_webgpu_aot.cpp | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/experimental/kernels/Makefile b/experimental/kernels/Makefile index 7430a71..aa34e97 100644 --- a/experimental/kernels/Makefile +++ b/experimental/kernels/Makefile @@ -16,7 +16,7 @@ CXXFLAGS=-std=c++17 -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -I. -Iunittest_l CFLAGS=-Ofast -march=native -I. -Iunittest_llmc # CFLAGS=-O2 -march=native -I. -Iunittest_llmc -LDFLAGS=$(STDLIB) -L$(GPUCPP)/third_party/lib -ldl -ldawn +LDFLAGS=$(STDLIB) -L$(GPUCPP)/third_party/lib -ldl -ldawn -fsanitize=address FLAGS=$(CXXFLAGS) $(LDFLAGS) ifeq ($(shell [ -d /opt/homebrew/opt/libomp/lib ] && echo "exists"), exists) @@ -101,7 +101,7 @@ build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c gpt2_webgpu.cpp ops.cpp build/gpt2_webgpu_aot: llm.c gpt2_124M.bin llm.c gpt2_webgpu_aot.cpp ops_aot.cpp mkdir -p build - $(CC) $(CXXFLAGS) -Illm.c $(LDFLAGS) -o $@ gpt2_webgpu_aot.cpp ops_aot.cpp + $(CC) $(CXXFLAGS) -Illm.c $(LDFLAGS) -o $@ gpt2_webgpu_aot.cpp ops_aot.cpp -g build/gpt2_webgpu.html: check-emsdk gpt2_webgpu.cpp term.html llm.c em++ gpt2_webgpu.cpp ops.cpp \ diff --git a/experimental/kernels/gpt2_webgpu_aot.cpp b/experimental/kernels/gpt2_webgpu_aot.cpp index 2190a7a..1e7043f 100644 --- a/experimental/kernels/gpt2_webgpu_aot.cpp +++ b/experimental/kernels/gpt2_webgpu_aot.cpp @@ -377,9 +377,8 @@ void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoin model->batch_size = 0; model->seq_len = 0; model->mean_loss = -1.0f; // -1.0f will designate no loss - // Allocate B * C buffer for mean loss - model->mean_loss_buffer = (float*)mallocCheck(sizeof(float) * model->batch_size * model->seq_len); - model->probs_buffer = (float*)mallocCheck(sizeof(float) * model->batch_size * model->seq_len * Vp); + model->mean_loss_buffer = NULL; + model->probs_buffer = NULL; model->backward_enabled = false; printf("Model build complete\n"); @@ -418,6 +417,8 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si model->seq_len = T; // and now allocate the space fill_in_activation_sizes(model->act_sizes, model->config, B, T); + model->mean_loss_buffer = (float*)mallocCheck(sizeof(float) * model->batch_size * model->seq_len); + model->probs_buffer = (float*)mallocCheck(sizeof(float) * model->batch_size * model->seq_len * Vp); // TODO(avh): this is just a resource test for now, eventually deprecate CPU allocations size_t num_activations = 0; @@ -635,7 +636,6 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si } // for convenience also evaluate the mean loss float mean_loss = 0.0f; - //toCPU(ctx, model->acts_.data[22], model->acts.losses.data, model->act_sizes[22] * sizeof(float)); toCPU(ctx, model->acts.losses, model->mean_loss_buffer, B*T * sizeof(float)); for (int i=0; imean_loss_buffer[i]; } mean_loss /= B*T; From f629a335aef7f8f27cc27d1811f3c09679e3bfb5 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Sun, 17 Nov 2024 04:36:43 +0900 Subject: [PATCH 26/44] Remove NUM_PARAMETER_LAYERS --- experimental/kernels/gpt2_webgpu_aot.cpp | 113 ++++++++++++++--------- 1 file changed, 67 insertions(+), 46 deletions(-) diff --git a/experimental/kernels/gpt2_webgpu_aot.cpp b/experimental/kernels/gpt2_webgpu_aot.cpp index 1e7043f..966fb7a 
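The allocation fix in PATCH 25 above is worth spelling out: gpt2_build_from_checkpoint runs while batch_size and seq_len are still zero, so the original mallocCheck calls sized both buffers at zero bytes and every later write ran past the allocation, which is presumably what the newly added -fsanitize=address flag was used to catch. The corrected pattern in illustrative form (the names mirror the patch, the struct is trimmed down):

#include <cstdlib>

struct Model {
  int batch_size = 0, seq_len = 0;
  float *mean_loss_buffer = nullptr;  // NULL until the first forward pass
};

void forward(Model *m, int B, int T) {
  if (m->mean_loss_buffer == nullptr) {
    m->batch_size = B;  // record the real dimensions first
    m->seq_len = T;
    m->mean_loss_buffer =
        (float *)calloc((size_t)B * T, sizeof(float));  // then allocate
  }
  // ... run the network and reduce losses into mean_loss_buffer ...
}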
100644 --- a/experimental/kernels/gpt2_webgpu_aot.cpp +++ b/experimental/kernels/gpt2_webgpu_aot.cpp @@ -47,7 +47,6 @@ typedef struct { // the parameters of the model #define NUM_PARAMETER_TENSORS 16 -#define NUM_PARAMETER_LAYERS 12 typedef struct { Tensor wte; // (V, C) Tensor wpe; // (maxT, C) @@ -91,22 +90,36 @@ void fill_in_parameter_sizes(size_t* param_sizes, GPT2Config config) { } // allocate memory for the parameters and point the individual tensors to the right places -void malloc_and_point_parameters(Context& ctx, ParameterTensors* params, size_t* param_sizes) { +void malloc_and_point_parameters(Context& ctx, GPT2Config config, ParameterTensors* params, size_t* param_sizes) { + size_t L = config.num_layers; params->wte = createTensor(ctx, Shape{param_sizes[0]}, kf32); params->wpe = createTensor(ctx, Shape{param_sizes[1]}, kf32); - for(int l = 0; l < NUM_PARAMETER_LAYERS; l++) { - params->ln1w.push_back(createTensor(ctx, Shape{param_sizes[2]/NUM_PARAMETER_LAYERS}, kf32)); - params->ln1b.push_back(createTensor(ctx, Shape{param_sizes[3]/NUM_PARAMETER_LAYERS}, kf32)); - params->qkvw.push_back(createTensor(ctx, Shape{param_sizes[4]/NUM_PARAMETER_LAYERS}, kf32)); - params->qkvb.push_back(createTensor(ctx, Shape{param_sizes[5]/NUM_PARAMETER_LAYERS}, kf32)); - params->attprojw.push_back(createTensor(ctx, Shape{param_sizes[6]/NUM_PARAMETER_LAYERS}, kf32)); - params->attprojb.push_back(createTensor(ctx, Shape{param_sizes[7]/NUM_PARAMETER_LAYERS}, kf32)); - params->ln2w.push_back(createTensor(ctx, Shape{param_sizes[8]/NUM_PARAMETER_LAYERS}, kf32)); - params->ln2b.push_back(createTensor(ctx, Shape{param_sizes[9]/NUM_PARAMETER_LAYERS}, kf32)); - params->fcw.push_back(createTensor(ctx, Shape{param_sizes[10]/NUM_PARAMETER_LAYERS}, kf32)); - params->fcb.push_back(createTensor(ctx, Shape{param_sizes[11]/NUM_PARAMETER_LAYERS}, kf32)); - params->fcprojw.push_back(createTensor(ctx, Shape{param_sizes[12]/NUM_PARAMETER_LAYERS}, kf32)); - params->fcprojb.push_back(createTensor(ctx, Shape{param_sizes[13]/NUM_PARAMETER_LAYERS}, kf32)); + + params->ln1w.resize(L); + params->ln1b.resize(L); + params->qkvw.resize(L); + params->qkvb.resize(L); + params->attprojw.resize(L); + params->attprojb.resize(L); + params->ln2w.resize(L); + params->ln2b.resize(L); + params->fcw.resize(L); + params->fcb.resize(L); + params->fcprojw.resize(L); + params->fcprojb.resize(L); + for(int l = 0; l < L ; l++) { + params->ln1w[l] = createTensor(ctx, Shape{param_sizes[2]/config.num_layers}, kf32); + params->ln1b[l] = createTensor(ctx, Shape{param_sizes[3]/config.num_layers}, kf32); + params->qkvw[l] = createTensor(ctx, Shape{param_sizes[4]/config.num_layers}, kf32); + params->qkvb[l] = createTensor(ctx, Shape{param_sizes[5]/config.num_layers}, kf32); + params->attprojw[l] = createTensor(ctx, Shape{param_sizes[6]/config.num_layers}, kf32); + params->attprojb[l] = createTensor(ctx, Shape{param_sizes[7]/config.num_layers}, kf32); + params->ln2w[l] = createTensor(ctx, Shape{param_sizes[8]/config.num_layers}, kf32); + params->ln2b[l] = createTensor(ctx, Shape{param_sizes[9]/config.num_layers}, kf32); + params->fcw[l] = createTensor(ctx, Shape{param_sizes[10]/config.num_layers}, kf32); + params->fcb[l] = createTensor(ctx, Shape{param_sizes[11]/config.num_layers}, kf32); + params->fcprojw[l] = createTensor(ctx, Shape{param_sizes[12]/config.num_layers}, kf32); + params->fcprojb[l] = createTensor(ctx, Shape{param_sizes[13]/config.num_layers}, kf32); } params->lnfw = createTensor(ctx, Shape{param_sizes[14]}, kf32); params->lnfb = 
createTensor(ctx, Shape{param_sizes[15]}, kf32); @@ -201,25 +214,42 @@ void fill_in_activation_sizes(size_t* act_sizes, GPT2Config config, int B, int T act_sizes[22] = B * T; // losses } -void malloc_and_point_activations(Context& ctx, ActivationTensors* acts, size_t* act_sizes) { +void malloc_and_point_activations(Context& ctx, GPT2Config config, ActivationTensors* acts, size_t* act_sizes) { + size_t L = config.num_layers; acts->encoded = createTensor(ctx, Shape{act_sizes[0]}, kf32); - for (int l = 0; l < NUM_PARAMETER_LAYERS; l++) { - acts->ln1.push_back(createTensor(ctx, Shape{act_sizes[1]/NUM_PARAMETER_LAYERS}, kf32)); - acts->ln1_mean.push_back(createTensor(ctx, Shape{act_sizes[2]/NUM_PARAMETER_LAYERS}, kf32)); - acts->ln1_rstd.push_back(createTensor(ctx, Shape{act_sizes[3]/NUM_PARAMETER_LAYERS}, kf32)); - acts->qkv.push_back(createTensor(ctx, Shape{act_sizes[4]/NUM_PARAMETER_LAYERS}, kf32)); - acts->atty.push_back(createTensor(ctx, Shape{act_sizes[5]/NUM_PARAMETER_LAYERS}, kf32)); - acts->preatt.push_back(createTensor(ctx, Shape{act_sizes[6]/NUM_PARAMETER_LAYERS}, kf32)); - acts->att.push_back(createTensor(ctx, Shape{act_sizes[7]/NUM_PARAMETER_LAYERS}, kf32)); - acts->attproj.push_back(createTensor(ctx, Shape{act_sizes[8]/NUM_PARAMETER_LAYERS}, kf32)); - acts->residual2.push_back(createTensor(ctx, Shape{act_sizes[9]/NUM_PARAMETER_LAYERS}, kf32)); - acts->ln2.push_back(createTensor(ctx, Shape{act_sizes[10]/NUM_PARAMETER_LAYERS}, kf32)); - acts->ln2_mean.push_back(createTensor(ctx, Shape{act_sizes[11]/NUM_PARAMETER_LAYERS}, kf32)); - acts->ln2_rstd.push_back(createTensor(ctx, Shape{act_sizes[12]/NUM_PARAMETER_LAYERS}, kf32)); - acts->fch.push_back(createTensor(ctx, Shape{act_sizes[13]/NUM_PARAMETER_LAYERS}, kf32)); - acts->fch_gelu.push_back(createTensor(ctx, Shape{act_sizes[14]/NUM_PARAMETER_LAYERS}, kf32)); - acts->fcproj.push_back(createTensor(ctx, Shape{act_sizes[15]/NUM_PARAMETER_LAYERS}, kf32)); - acts->residual3.push_back(createTensor(ctx, Shape{act_sizes[16]/NUM_PARAMETER_LAYERS}, kf32)); + acts->ln1.resize(L); + acts->ln1_mean.resize(L); + acts->ln1_rstd.resize(L); + acts->qkv.resize(L); + acts->atty.resize(L); + acts->preatt.resize(L); + acts->att.resize(L); + acts->attproj.resize(L); + acts->residual2.resize(L); + acts->ln2.resize(L); + acts->ln2_mean.resize(L); + acts->ln2_rstd.resize(L); + acts->fch.resize(L); + acts->fch_gelu.resize(L); + acts->fcproj.resize(L); + acts->residual3.resize(L); + for (int l = 0; l < L; l++) { + acts->ln1[l] = createTensor(ctx, Shape{act_sizes[1]/config.num_layers}, kf32); + acts->ln1_mean[l] = createTensor(ctx, Shape{act_sizes[2]/config.num_layers}, kf32); + acts->ln1_rstd[l] = createTensor(ctx, Shape{act_sizes[3]/config.num_layers}, kf32); + acts->qkv[l] = createTensor(ctx, Shape{act_sizes[4]/config.num_layers}, kf32); + acts->atty[l] = createTensor(ctx, Shape{act_sizes[5]/config.num_layers}, kf32); + acts->preatt[l] = createTensor(ctx, Shape{act_sizes[6]/config.num_layers}, kf32); + acts->att[l] = createTensor(ctx, Shape{act_sizes[7]/config.num_layers}, kf32); + acts->attproj[l] = createTensor(ctx, Shape{act_sizes[8]/config.num_layers}, kf32); + acts->residual2[l] = createTensor(ctx, Shape{act_sizes[9]/config.num_layers}, kf32); + acts->ln2[l] = createTensor(ctx, Shape{act_sizes[10]/config.num_layers}, kf32); + acts->ln2_mean[l] = createTensor(ctx, Shape{act_sizes[11]/config.num_layers}, kf32); + acts->ln2_rstd[l] = createTensor(ctx, Shape{act_sizes[12]/config.num_layers}, kf32); + acts->fch[l] = createTensor(ctx, 
Shape{act_sizes[13]/config.num_layers}, kf32); + acts->fch_gelu[l] = createTensor(ctx, Shape{act_sizes[14]/config.num_layers}, kf32); + acts->fcproj[l] = createTensor(ctx, Shape{act_sizes[15]/config.num_layers}, kf32); + acts->residual3[l] = createTensor(ctx, Shape{act_sizes[16]/config.num_layers}, kf32); } acts->lnf = createTensor(ctx, Shape{act_sizes[17]}, kf32); acts->lnf_mean = createTensor(ctx, Shape{act_sizes[18]}, kf32); @@ -229,15 +259,6 @@ void malloc_and_point_activations(Context& ctx, ActivationTensors* acts, size_t* acts->losses = createTensor(ctx, Shape{act_sizes[22]}, kf32); } -struct GPUParameters { - Tensor data[NUM_PARAMETER_TENSORS]; -}; - -struct GPUActivations { - Tensor data[NUM_ACTIVATION_TENSORS]; -}; - - void gpu_alloc(Context& ctx, Tensor* tensors, size_t* sizes, size_t n) { for (size_t i = 0; i < n; i++) { tensors[i] = createTensor(ctx, Shape{sizes[i]}, kf32); @@ -325,7 +346,7 @@ void gpt2_build_from_checkpoint(Context& ctx, GPT2 *model, const char* checkpoin model->num_parameters = num_parameters; // read in all the parameters from file - malloc_and_point_parameters(ctx, &model->params, model->param_sizes); + malloc_and_point_parameters(ctx, model->config, &model->params, model->param_sizes); model->params_memory = (float*)mallocCheck(num_parameters * sizeof(float)); freadCheck(model->params_memory, sizeof(float), num_parameters, model_file); fcloseCheck(model_file); @@ -428,7 +449,7 @@ void gpt2_forward(Context& ctx, GPT2 *model, Tensor& inputs, Tensor& targets, si printf("num_activations: %zu\n", num_activations); model->num_activations = num_activations; printf("Allocating %.2f MB for activations\n", num_activations * sizeof(float) / (1024.0f * 1024.0f)); - malloc_and_point_activations(ctx, &model->acts, model->act_sizes); + malloc_and_point_activations(ctx, model->config, &model->acts, model->act_sizes); // also create memory for caching inputs and targets //model->inputs = (int*)mallocCheck(B * T * sizeof(int)); //model->targets = (int*)mallocCheck(B * T * sizeof(int)); // might be unused if we never have targets but it's small @@ -664,8 +685,8 @@ void gpt2_backward(Context& ctx, GPT2 *model) { // lazily allocate the memory for gradients of the weights and activations, if needed if (model->grads_memory == NULL) { printf("Allocating %.2f MB for gradients\n", model->num_parameters * sizeof(float) / (1024.0f * 1024.0f)); - malloc_and_point_parameters(ctx, &model->grads, model->param_sizes); - malloc_and_point_activations(ctx, &model->grads_acts, model->act_sizes); + malloc_and_point_parameters(ctx, model->config, &model->grads, model->param_sizes); + malloc_and_point_activations(ctx, model->config, &model->grads_acts, model->act_sizes); gpt2_zero_grad(model); } From 6e3a240822aedd48762e5dbe032422db70f690e4 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Wed, 25 Dec 2024 22:35:55 +0900 Subject: [PATCH 27/44] Add python bindings --- bindings/python/Makefile | 26 ++++++++ bindings/python/gpu_cpp.cpp | 112 ++++++++++++++++++++++++++++++++ bindings/python/test_gpu_cpp.py | 39 +++++++++++ 3 files changed, 177 insertions(+) create mode 100644 bindings/python/Makefile create mode 100644 bindings/python/gpu_cpp.cpp create mode 100644 bindings/python/test_gpu_cpp.py diff --git a/bindings/python/Makefile b/bindings/python/Makefile new file mode 100644 index 0000000..fa70468 --- /dev/null +++ b/bindings/python/Makefile @@ -0,0 +1,26 @@ +CXX=clang++ +PYTHON=python3 +GPUCPP ?= $(PWD)/../.. +LIBDIR ?= $(GPUCPP)/third_party/lib +LIBSPEC ?= . 
$(GPUCPP)/source + +ifeq ($(shell $(CXX) -std=c++17 -x c++ -E -include array - < /dev/null > /dev/null 2>&1 ; echo $$?),0) + STDLIB := +else + STDLIB := -stdlib=libc++ +endif + +FLAGS=-shared -fPIC -std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib -ldawn \ + `python3 -m pybind11 --includes` \ + `python3-config --include --ldflags --embed` + +SUFFIX=$(shell $(PYTHON)-config --extension-suffix) + +gpu_cpp$(SUFFIX): gpu_cpp.cpp + $(CXX) $(FLAGS) -o $@ $< + install_name_tool -change @rpath/libdawn.dylib $(LIBDIR)/libdawn.dylib gpu_cpp$(SUFFIX) + +test: test_gpu_cpp.py gpu_cpp$(SUFFIX) + $(PYTHON) test_gpu_cpp.py + +.PHONY: test diff --git a/bindings/python/gpu_cpp.cpp b/bindings/python/gpu_cpp.cpp new file mode 100644 index 0000000..a09d577 --- /dev/null +++ b/bindings/python/gpu_cpp.cpp @@ -0,0 +1,112 @@ +#include "gpu.hpp" +#include +#include +#include + +using namespace gpu; + +#include +#include +#include + +namespace py = pybind11; + +Shape vector_to_shape(const std::vector &dims) { + switch(dims.size()){ + case 1: + return Shape{(unsigned long)dims[0]}; + break; + case 2: + return Shape{(unsigned long)dims[0],(unsigned long)dims[1]}; + break; + case 3: + return Shape{(unsigned long)dims[0],(unsigned long)dims[1],(unsigned long)dims[2]}; + break; + case 4: + return Shape{(unsigned long)dims[0],(unsigned long)dims[1],(unsigned long)dims[2],(unsigned long)dims[3]}; + break; + case 5: + return Shape{(unsigned long)dims[0],(unsigned long)dims[1],(unsigned long)dims[2],(unsigned long)dims[3],(unsigned long)dims[4]}; + break; + } + return Shape{0}; +} + +Context* py_createContext() { + return new Context(createContext()); +} + +KernelCode* py_createKernelCode(const std::string &pData, size_t workgroupSize, int precision) { + return new KernelCode(pData, workgroupSize, (NumType)precision); +} + +Kernel* py_createKernel(Context *ctx, const KernelCode *code, + // const Tensor *dataBindings, size_t numTensors, + const py::list& dataBindings_py, + // const size_t *viewOffsets, + const py::list& viewOffsets_py, + const std::vector &totalWorkgroups){ + std::vector bindings; + for (auto item : dataBindings_py) { + bindings.push_back(item.cast()); + } + std::vector viewOffsets; + for (auto item : viewOffsets_py) { + viewOffsets.push_back(item.cast()); + } + return new Kernel(createKernel(*ctx, *code, bindings.data(), bindings.size(), viewOffsets.data(), vector_to_shape(totalWorkgroups))); +} + +Tensor* py_createTensor(Context *ctx, const std::vector &dims, int dtype) { + return new Tensor(createTensor(*ctx, vector_to_shape(dims), (NumType)dtype)); +} + +py::array_t py_toCPU_float(Context *ctx, Tensor* tensor) { + auto result = py::array_t(tensor->data.size/sizeof(float)); + py::buffer_info buf = result.request(); + toCPU(*ctx, *tensor, static_cast(buf.ptr), tensor->data.size); + return result; +} + + +void py_toGPU_float(Context *ctx, py::array_t array, Tensor *tensor) { + py::buffer_info buf = array.request(); + float *ptr = static_cast(buf.ptr); + toGPU(*ctx, ptr, *tensor); +} + + +struct GpuAsync { + std::promise promise; + std::future future ; + GpuAsync(): future(promise.get_future()){ + } +}; + +GpuAsync* py_dispatchKernel(Context *ctx, Kernel *kernel) { + auto async = new GpuAsync(); + dispatchKernel(*ctx, *kernel, async->promise); + return async; +} + +void py_wait(Context *ctx, GpuAsync* async) { + wait(*ctx, async->future); +} + +PYBIND11_MODULE(gpu_cpp, m) { + m.doc() = "gpu.cpp plugin"; + py::class_(m, "Context"); + py::class_(m, "Tensor"); + 
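For orientation, the Python sequence create_context, create_tensor, create_kernel, dispatch_kernel, wait registered in this module maps directly onto the C++ API. A sketch of the equivalent native flow (kGelu stands in for the WGSL string used in the test below; it is not defined in this file):

#include "gpu.hpp"
#include <future>

using namespace gpu;

int main() {
  Context ctx = createContext();
  Tensor input  = createTensor(ctx, Shape{12}, kf32);
  Tensor output = createTensor(ctx, Shape{12}, kf32);
  // KernelCode{source, workgroupSize, precision}, as wrapped by the binding
  Kernel k = createKernel(ctx, {kGelu, 256, kf32},
                          Bindings{input, output}, {12, 1, 1});
  std::promise<void> p;          // GpuAsync wraps exactly this pair
  std::future<void> f = p.get_future();
  dispatchKernel(ctx, k, p);
  wait(ctx, f);
  return 0;
}

The GpuAsync struct exists because pybind11 needs a single owning object to hand back to Python: the promise/future pair cannot be copied, so it is heap-allocated and returned with take_ownership.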
py::class_(m, "Kernel"); + py::class_(m, "KernelCode"); + py::class_(m, "GpuAsync"); + m.def("create_context", &py_createContext, py::return_value_policy::take_ownership); + m.def("create_tensor", &py_createTensor, py::return_value_policy::take_ownership); + m.def("create_kernel", &py_createKernel, py::return_value_policy::take_ownership); + m.def("create_kernel_code", &py_createKernelCode, py::return_value_policy::take_ownership); + m.def("dispatch_kernel", &py_dispatchKernel, py::return_value_policy::take_ownership); + m.def("wait", &py_wait, "Wait for GPU"); + m.def("to_cpu_float", &py_toCPU_float); + m.def("to_gpu_float", &py_toGPU_float); + m.attr("kf32") = (int)kf32; +} diff --git a/bindings/python/test_gpu_cpp.py b/bindings/python/test_gpu_cpp.py new file mode 100644 index 0000000..ad50c6a --- /dev/null +++ b/bindings/python/test_gpu_cpp.py @@ -0,0 +1,39 @@ +import gpu_cpp as gpu +import numpy as np + +ctx = gpu.create_context() + +N = 12 + +input = gpu.create_tensor(ctx, [N], gpu.kf32) +output = gpu.create_tensor(ctx, [N], gpu.kf32) +kernel_code = gpu.create_kernel_code( + """ + const GELU_SCALING_FACTOR: f32 = 0.7978845608028654; // sqrt(2.0 / PI) + @group(0) @binding(0) var inp: array<{{precision}}>; + @group(0) @binding(1) var out: array<{{precision}}>; + @group(0) @binding(1) var dummy: array<{{precision}}>; + @compute @workgroup_size({{workgroupSize}}) + fn main( + @builtin(global_invocation_id) GlobalInvocationID: vec3) { + let i: u32 = GlobalInvocationID.x; + if (i < arrayLength(&inp)) { + let x: f32 = inp[i]; + out[i] = select(0.5 * x * (1.0 + tanh(GELU_SCALING_FACTOR + * (x + .044715 * x * x * x))), x, x > 10.0); + } + } + """, + 256, + gpu.kf32 + ) + +kernel = gpu.create_kernel(ctx, kernel_code, [input, output], [0,0], [12,1,1]) + +gpu.to_gpu_float(ctx, np.array([1,2,3,4,1,2,3,4,1,2,3,4],np.float32), input) + +gpu_async = gpu.dispatch_kernel(ctx, kernel); + +gpu.wait(ctx, gpu_async); + +print(gpu.to_cpu_float(ctx, output)) From 3228b1b432a1f96dcb8501c9c6107897cd233560 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Sat, 28 Dec 2024 15:54:30 +0900 Subject: [PATCH 28/44] Add haskell bindings --- bindings/haskell/CHANGELOG.md | 5 + bindings/haskell/Makefile | 3 + bindings/haskell/app/Main.hs | 37 +++++ bindings/haskell/gpu-cpp.cabal | 49 +++++++ bindings/haskell/src/GpuCpp.hs | 207 +++++++++++++++++++++++++++ bindings/haskell/src/GpuCpp/Types.hs | 40 ++++++ bindings/haskell/test/Main.hs | 49 +++++++ bindings/python/gpu_cpp.cpp | 1 - 8 files changed, 390 insertions(+), 1 deletion(-) create mode 100644 bindings/haskell/CHANGELOG.md create mode 100644 bindings/haskell/Makefile create mode 100644 bindings/haskell/app/Main.hs create mode 100644 bindings/haskell/gpu-cpp.cabal create mode 100644 bindings/haskell/src/GpuCpp.hs create mode 100644 bindings/haskell/src/GpuCpp/Types.hs create mode 100644 bindings/haskell/test/Main.hs diff --git a/bindings/haskell/CHANGELOG.md b/bindings/haskell/CHANGELOG.md new file mode 100644 index 0000000..d20679e --- /dev/null +++ b/bindings/haskell/CHANGELOG.md @@ -0,0 +1,5 @@ +# Revision history for gpu-cpp + +## 0.1.0.0 -- 2024-12-28 + +* First version. diff --git a/bindings/haskell/Makefile b/bindings/haskell/Makefile new file mode 100644 index 0000000..7ca37a0 --- /dev/null +++ b/bindings/haskell/Makefile @@ -0,0 +1,3 @@ +all: + cabal configure --extra-include-dirs=$(PWD)/../.. --extra-include-dirs=$(PWD)/../../third_party/headers --extra-lib-dirs=$(PWD)/../../third_party/lib + cabal build . 
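Both the Python test above and the Haskell example that follows drive the same GELU kernel. A CPU reference for sanity-checking their output, mirroring the WGSL (the constant is sqrt(2/pi); select() in the shader passes x through unchanged when x > 10):

#include <cmath>
#include <cstdio>

float gelu_ref(float x) {
  const float k = 0.7978845608028654f;  // GELU_SCALING_FACTOR
  if (x > 10.0f) {
    return x;  // matches the select(..., x, x > 10.0) branch
  }
  return 0.5f * x * (1.0f + tanhf(k * (x + 0.044715f * x * x * x)));
}

int main() {
  const float in[] = {1.0f, 2.0f, 3.0f, 4.0f};
  for (float x : in) {
    printf("%f\n", gelu_ref(x));  // compare against to_cpu_float / toCpu
  }
  return 0;
}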
diff --git a/bindings/haskell/app/Main.hs b/bindings/haskell/app/Main.hs new file mode 100644 index 0000000..ba1ae6d --- /dev/null +++ b/bindings/haskell/app/Main.hs @@ -0,0 +1,37 @@ +module Main where + +import GpuCpp.Types +import GpuCpp +import qualified Data.Vector.Storable as V +import Foreign.C.Types + +main :: IO () +main = do + context <- createContext + input <- createTensor context [12] kf32 + output <- createTensor context [12] kf32 + kernelCode <- createKernelCode + ( + "const GELU_SCALING_FACTOR: f32 = 0.7978845608028654; // sqrt(2.0 / PI)\n" <> + "@group(0) @binding(0) var inp: array<{{precision}}>;\n" <> + "@group(0) @binding(1) var out: array<{{precision}}>;\n" <> + "@group(0) @binding(1) var dummy: array<{{precision}}>;\n" <> + "@compute @workgroup_size({{workgroupSize}})\n" <> + "fn main(\n" <> + " @builtin(global_invocation_id) GlobalInvocationID: vec3) {\n" <> + " let i: u32 = GlobalInvocationID.x;\n" <> + " if (i < arrayLength(&inp)) {\n" <> + " let x: f32 = inp[i];\n" <> + " out[i] = select(0.5 * x * (1.0 + tanh(GELU_SCALING_FACTOR \n" <> + " * (x + .044715 * x * x * x))), x, x > 10.0);\n" <> + " }\n" <> + "}\n" + ) + 256 + kf32 + kernel <- createKernel context kernelCode [input, output] [0,0] [12,1,1] + toGpu context (V.fromList [1 :: CFloat,2,3,4,1,2,3,4,1,2,3,4]) input + async <- dispatchKernel context kernel + wait context async + vec <- toCpu context output :: IO (V.Vector CFloat) + print vec diff --git a/bindings/haskell/gpu-cpp.cabal b/bindings/haskell/gpu-cpp.cabal new file mode 100644 index 0000000..39bab54 --- /dev/null +++ b/bindings/haskell/gpu-cpp.cabal @@ -0,0 +1,49 @@ +cabal-version: 3.0 +name: gpu-cpp +version: 0.1.0.0 +license: BSD-3-Clause +author: Junji Hashimoto +maintainer: junji.hashimoto@gmail.com +category: Math +build-type: Simple + +extra-doc-files: CHANGELOG.md + +common warnings + ghc-options: -Wall + +library + import: warnings + exposed-modules: GpuCpp + , GpuCpp.Types + build-depends: base ^>=4.18.1.0 + , inline-c + , inline-c-cpp + , containers + , template-haskell + , safe-exceptions + , vector + hs-source-dirs: src + default-language: Haskell2010 + ghc-options: -optcxx-std=c++17 + extra-libraries: dawn + +executable gpu-cpp + import: warnings + main-is: Main.hs + build-depends: base ^>=4.18.1.0 + , gpu-cpp + , vector + hs-source-dirs: app + default-language: Haskell2010 + +test-suite gpu-cpp-test + import: warnings + default-language: Haskell2010 + type: exitcode-stdio-1.0 + hs-source-dirs: test + main-is: Main.hs + build-depends: base ^>=4.18.1.0 + , gpu-cpp + , vector + , hspec diff --git a/bindings/haskell/src/GpuCpp.hs b/bindings/haskell/src/GpuCpp.hs new file mode 100644 index 0000000..2177ecf --- /dev/null +++ b/bindings/haskell/src/GpuCpp.hs @@ -0,0 +1,207 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FlexibleInstances #-} + +module GpuCpp where + +import qualified Language.C.Inline.Cpp as C +import qualified Language.C.Inline.Cpp.Unsafe as C +import qualified Language.C.Inline.Context as C +import Foreign.C.String +import Foreign.C.Types +import GHC.Int +import GHC.ForeignPtr(mallocPlainForeignPtrBytes) +import Foreign +import Control.Monad (forM_) +import GpuCpp.Types +import Control.Exception.Safe (bracket) +import qualified Data.Vector.Storable as V + +C.context $ C.cppCtx <> mempty { 
+
+C.include "<gpu.hpp>"
+C.include "<future>"
+C.include "<vector>"
+
+[C.emitBlock|
+struct GpuAsync {
+  std::promise<void> promise;
+  std::future<void> future;
+  GpuAsync(): future(promise.get_future()){
+  }
+};
+
+gpu::Shape vector_to_shape(const std::vector<int64_t> &dims) {
+  switch(dims.size()){
+  case 1:
+    return gpu::Shape{(unsigned long)dims[0]};
+    break;
+  case 2:
+    return gpu::Shape{(unsigned long)dims[0],(unsigned long)dims[1]};
+    break;
+  case 3:
+    return gpu::Shape{(unsigned long)dims[0],(unsigned long)dims[1],(unsigned long)dims[2]};
+    break;
+  case 4:
+    return gpu::Shape{(unsigned long)dims[0],(unsigned long)dims[1],(unsigned long)dims[2],(unsigned long)dims[3]};
+    break;
+  case 5:
+    return gpu::Shape{(unsigned long)dims[0],(unsigned long)dims[1],(unsigned long)dims[2],(unsigned long)dims[3],(unsigned long)dims[4]};
+    break;
+  }
+  return gpu::Shape{0};
+}
+|]
+
+kf32 :: CInt
+kf32 = [C.pure| int { (int)gpu::kf32 } |]
+
+createContext :: IO (ForeignPtr Context)
+createContext =
+  [C.throwBlock| gpu::Context* { return new gpu::Context(gpu::createContext()); }|] >>=
+  newForeignPtr
+  [C.funPtr| void deleteContext(gpu::Context* ptr) { delete ptr; }|]
+
+createKernelCode :: String -> CInt -> CInt -> IO (ForeignPtr KernelCode)
+createKernelCode kernelString workgroupSize precision =
+  withCString kernelString $ \pData ->
+    [C.throwBlock| gpu::KernelCode* { return new gpu::KernelCode($(char* pData), $(int workgroupSize), (gpu::NumType)$(int precision)); }|] >>=
+    newForeignPtr
+    [C.funPtr| void deleteKernelCode(gpu::KernelCode* ptr) { delete ptr; }|]
+
+dispatchKernel :: ForeignPtr Context -> ForeignPtr Kernel -> IO (ForeignPtr GpuAsync)
+dispatchKernel context kernel =
+  withForeignPtr context $ \c ->
+  withForeignPtr kernel $ \k ->
+    [C.throwBlock| GpuAsync* {
+       auto async = new GpuAsync();
+       gpu::dispatchKernel(*$(gpu::Context* c), *$(gpu::Kernel* k), async->promise);
+       return async; }|] >>=
+    newForeignPtr
+    [C.funPtr| void deleteGpuAsync(GpuAsync* ptr) { delete ptr; }|]
+
+wait :: ForeignPtr Context -> ForeignPtr GpuAsync -> IO ()
+wait context async =
+  withForeignPtr context $ \c ->
+  withForeignPtr async $ \a ->
+    [C.throwBlock| void {
+       gpu::wait(*$(gpu::Context* c), $(GpuAsync* a)->future);
+    }|]
+
+instance WithVector CInt Int64 where
+  withVector shape func =
+    bracket
+      (do
+        let len = fromIntegral $ length shape
+        vec <- [C.throwBlock| std::vector<int64_t>* {
+                  return new std::vector<int64_t>($(int len));
+               }|]
+        ptr <- [C.throwBlock| int64_t* {
+                  return $(std::vector<int64_t>* vec)->data();
+               }|]
+        pokeArray ptr (map fromIntegral shape)
+        return vec
+      )
+      (\vec -> [C.block| void { delete $(std::vector<int64_t>* vec); }|])
+      (\vec -> func vec)
+
+instance WithVector CInt CSize where
+  withVector shape func =
+    bracket
+      (do
+        let len = fromIntegral $ length shape
+        vec <- [C.throwBlock| std::vector<size_t>* {
+                  return new std::vector<size_t>($(int len));
+               }|]
+        ptr <- [C.throwBlock| size_t* {
+                  return $(std::vector<size_t>* vec)->data();
+               }|]
+        pokeArray ptr (map fromIntegral shape)
+        return vec
+      )
+      (\vec -> [C.block| void { delete $(std::vector<size_t>* vec); }|])
+      (\vec -> func vec)
+
+instance WithVector (Ptr Tensor) Tensor where
+  withVector ptrs func =
+    bracket (do
+              vec <- [C.throwBlock| std::vector<gpu::Tensor>* { return new std::vector<gpu::Tensor>(); }|]
+              forM_ ptrs $ \ptr ->
+                [C.throwBlock| void { $(std::vector<gpu::Tensor>* vec)->push_back(*$(gpu::Tensor* ptr)); }|]
+              return vec
+            )
+            (\vec -> [C.block| void { delete $(std::vector<gpu::Tensor>* vec); }|])
+            (\vec -> func vec)
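+
+-- Note (added for exposition): withForeignPtrs below nests withForeignPtr
+-- brackets recursively, so every ForeignPtr in the list stays alive while the
+-- C++ call that receives the raw pointers runs; none can be finalized mid-call.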
+withForeignPtrs :: [ForeignPtr a] -> ([Ptr a] -> IO b) -> IO b
+withForeignPtrs [] func = func []
+withForeignPtrs (x:xs) func =
+  withForeignPtr x $ \x' ->
+    withForeignPtrs xs $ \xs' ->
+      func (x':xs')
+
+createKernel :: ForeignPtr Context -> ForeignPtr KernelCode -> [ForeignPtr Tensor] -> [Int] -> [Int] -> IO (ForeignPtr Kernel)
+createKernel context kernelCode dataBindings viewOffsets totalWorkgroups =
+  withForeignPtr context $ \c ->
+  withForeignPtr kernelCode $ \k ->
+  withForeignPtrs dataBindings $ \b ->
+  withVector b $ \b' ->
+  withVector @CInt (map fromIntegral viewOffsets) $ \v ->
+  withVector @CInt (map fromIntegral totalWorkgroups) $ \w ->
+    [C.throwBlock| gpu::Kernel* {
+       return new gpu::Kernel(gpu::createKernel(
+         *$(gpu::Context* c),
+         *$(gpu::KernelCode* k),
+         $(std::vector<gpu::Tensor>* b')->data(),
+         $(std::vector<gpu::Tensor>* b')->size(),
+         $(std::vector<size_t>* v)->data(),
+         vector_to_shape(*$(std::vector<int64_t>* w))));
+    }|] >>=
+    newForeignPtr
+    [C.funPtr| void deleteKernel(gpu::Kernel* ptr) { delete ptr; }|]
+
+createTensor :: ForeignPtr Context -> [CInt] -> CInt -> IO (ForeignPtr Tensor)
+createTensor context shape dtype =
+  withVector shape $ \s ->
+  withForeignPtr context $ \c ->
+    [C.throwBlock| gpu::Tensor* {
+       return new gpu::Tensor(gpu::createTensor(*$(gpu::Context* c), vector_to_shape(*$(std::vector<int64_t>* s)), (gpu::NumType)$(int dtype)));
+    }|] >>=
+    newForeignPtr
+    [C.funPtr| void deleteTensor(gpu::Tensor* ptr) { delete ptr; }|]
+
+createVector :: forall a. Storable a => Int -> IO (V.Vector a)
+createVector n = do
+  ptr <- mallocPlainForeignPtrBytes (n * sizeOf (undefined :: a))
+  return $ V.unsafeFromForeignPtr ptr 0 n
+
+instance GpuStorable CFloat where
+  toGpu context array tensor =
+    withForeignPtr context $ \c ->
+    withForeignPtr tensor $ \t ->
+    V.unsafeWith array $ \ptr ->
+      [C.throwBlock| void {
+         gpu::toGPU(*$(gpu::Context* c), $(float* ptr), *$(gpu::Tensor* t));
+      }|]
+  toCpu context tensor =
+    withForeignPtr context $ \c ->
+    withForeignPtr tensor $ \t -> do
+      (size :: CInt) <- [C.block| int {
+          size_t u = sizeof(float);
+          size_t len = $(gpu::Tensor* t)->data.size;
+          return len/u;
+      }|]
+      array <- createVector (fromIntegral size)
+      V.unsafeWith array $ \ptr ->
+        [C.throwBlock| void {
+           gpu::toCPU(*$(gpu::Context* c), *$(gpu::Tensor* t), $(float* ptr), $(int size) * sizeof(float));
+        }|]
+      return array
diff --git a/bindings/haskell/src/GpuCpp/Types.hs b/bindings/haskell/src/GpuCpp/Types.hs
new file mode 100644
index 0000000..3905aa7
--- /dev/null
+++ b/bindings/haskell/src/GpuCpp/Types.hs
@@ -0,0 +1,40 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+
+module GpuCpp.Types where
+
+import qualified Language.C.Types as C
+import qualified Language.Haskell.TH.Lib as TH
+import qualified Data.Map as Map
+import Foreign
+import qualified Data.Vector.Storable as V
+
+data Context
+data Tensor
+data Kernel
+data KernelCode
+data GpuAsync
+data StdVector a
+
+typeTable :: Map.Map C.TypeSpecifier TH.TypeQ
+typeTable = Map.fromList [
+    (C.TypeName "gpu::Context", [t|Context|])
+  , (C.TypeName "gpu::Tensor", [t|Tensor|])
+  , (C.TypeName "gpu::Kernel", [t|Kernel|])
+  , (C.TypeName "gpu::KernelCode", [t|KernelCode|])
+  , (C.TypeName "GpuAsync", [t|GpuAsync|])
+  , (C.TypeName "std::vector", [t|StdVector|])
+  ]
+
+class WithVector a b where
+  withVector :: [a] -> (Ptr (StdVector b) -> IO c) -> IO c
+
+class GpuStorable a where
+  toGpu :: ForeignPtr Context -> V.Vector a -> ForeignPtr Tensor -> IO ()
+  toCpu :: ForeignPtr Context -> ForeignPtr Tensor -> IO (V.Vector a)
+
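[Editorial note] The GpuStorable instance above is a thin wrapper over gpu::toGPU and gpu::toCPU from gpu.hpp. For comparison, a minimal sketch of the same write-then-read round trip directly against the C++ API (using only calls that appear elsewhere in this patch series; error handling and cleanup omitted):

// Sketch of the round trip the Haskell test performs, written against gpu.hpp.
#include <array>
#include <cstdio>
#include "gpu.hpp"

int main() {
  gpu::Context ctx = gpu::createContext();
  gpu::Tensor t = gpu::createTensor(ctx, gpu::Shape{12}, gpu::kf32);
  std::array<float, 12> in{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}, out{};
  gpu::toGPU(ctx, in.data(), t);               // host -> device
  gpu::toCPU(ctx, t, out.data(), sizeof(out)); // device -> host (blocks until done)
  std::printf("%f\n", out[0]);                 // 1.000000
  return 0;
}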
diff --git a/bindings/haskell/test/Main.hs b/bindings/haskell/test/Main.hs
new file mode 100644
index 0000000..d66e5c1
--- /dev/null
+++ b/bindings/haskell/test/Main.hs
@@ -0,0 +1,49 @@
+module Main (main) where
+
+import Test.Hspec
+import GpuCpp.Types
+import GpuCpp
+import qualified Data.Vector.Storable as V
+import Foreign.C.Types
+
+gelu :: String
+gelu = "const GELU_SCALING_FACTOR: f32 = 0.7978845608028654; // sqrt(2.0 / PI)\n" <>
+       "@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;\n" <>
+       "@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;\n" <>
+       "@group(0) @binding(1) var<storage, read_write> dummy: array<{{precision}}>;\n" <>
+       "@compute @workgroup_size({{workgroupSize}})\n" <>
+       "fn main(\n" <>
+       "  @builtin(global_invocation_id) GlobalInvocationID: vec3<u32>) {\n" <>
+       "    let i: u32 = GlobalInvocationID.x;\n" <>
+       "    if (i < arrayLength(&inp)) {\n" <>
+       "        let x: f32 = inp[i];\n" <>
+       "        out[i] = select(0.5 * x * (1.0 + tanh(GELU_SCALING_FACTOR\n" <>
+       "                 * (x + .044715 * x * x * x))), x, x > 10.0);\n" <>
+       "    }\n" <>
+       "}\n"
+
+main :: IO ()
+main = do
+  hspec $ do
+    describe "toCPU and toGPU" $ do
+      it "writes and reads back" $ do
+        context <- createContext
+        input <- createTensor context [12] kf32
+        toGpu context (V.fromList [1 :: CFloat,2,3,4,1,2,3,4,1,2,3,4]) input
+        output <- toCpu context input :: IO (V.Vector CFloat)
+        V.toList output `shouldBe` [1,2,3,4,1,2,3,4,1,2,3,4]
+    describe "call kernel" $ do
+      it "gelu" $ do
+        context <- createContext
+        input <- createTensor context [12] kf32
+        output <- createTensor context [12] kf32
+        kernelCode <- createKernelCode gelu 256 kf32
+        kernel <- createKernel context kernelCode [input, output] [0,0] [12,1,1]
+        toGpu context (V.fromList [1 :: CFloat,2,3,4,1,2,3,4,1,2,3,4]) input
+        async <- dispatchKernel context kernel
+        wait context async
+        vec <- toCpu context output :: IO (V.Vector CFloat)
+        V.toList (V.zipWith (\a b -> abs (a - b))
+                    vec
+                    (V.fromList [0.841192,1.9545977,2.9963627,3.9999297,0.841192,1.9545977,2.9963627,3.9999297,0.841192,1.9545977,2.9963627,3.9999297]))
+          `shouldSatisfy` all (< 0.001)
diff --git a/bindings/python/gpu_cpp.cpp b/bindings/python/gpu_cpp.cpp
index a09d577..51f7c9e 100644
--- a/bindings/python/gpu_cpp.cpp
+++ b/bindings/python/gpu_cpp.cpp
@@ -75,7 +75,6 @@ void py_toGPU_float(Context *ctx, py::array_t<float> array, Tensor *tensor) {
   toGPU(*ctx, ptr, *tensor);
 }
-
 struct GpuAsync {
   std::promise<void> promise;
   std::future<void> future;

From 4669791a71c574f40c419d9dae1dc785c7afaab8 Mon Sep 17 00:00:00 2001
From: austinvhuang
Date: Tue, 28 Jan 2025 16:16:08 -0500
Subject: [PATCH 29/44] migrate to updated dawn commit
 556f960f44690b3b808c779c08b44d48d4292925

- main library builds and examples run successfully, TODOs are code cleanup +
  experimental + non-make builds
---
 Makefile                            |   14 +-
 examples/Makefile                   |    2 +-
 examples/float16/Makefile           |    4 +-
 examples/gpu_puzzles/Makefile       |    4 +-
 examples/hello_world/Makefile       |    4 +-
 examples/matmul/Makefile            |    4 +-
 examples/physics/Makefile           |    2 +-
 examples/render/Makefile            |    2 +-
 examples/shadertui/Makefile         |    2 +-
 examples/transpose/Makefile         |    2 +-
 gpu.hpp                             |  328 +--
 third_party/headers/webgpu/webgpu.h | 3099 +++++++++++++--------------
 12 files changed, 1702 insertions(+), 1765 deletions(-)

diff --git a/Makefile b/Makefile
index 063dda7..8e5d67b 100644
--- a/Makefile
+++ b/Makefile
@@ -27,17 +27,17 @@ HEADER_PATH ?= /usr/include
 ifeq ($(OS), Linux)
     OS_TYPE ?= Linux
 
     GPU_CPP_LIB_NAME ?= libgpucpp.so
-    DAWN_LIB_NAME ?= libdawn.so
+    DAWN_LIB_NAME ?= libwebgpu_dawn.so
 else ifeq ($(OS), Darwin)
OS_TYPE ?= macOS GPU_CPP_LIB_NAME ?= libgpucpp.dylib -DAWN_LIB_NAME ?= libdawn.dylib +DAWN_LIB_NAME ?= libwebgpu_dawn.dylib else OS_TYPE ?= unknown endif lib: check-clang dawnlib - mkdir -p build && $(CXX) -std=c++17 $(INCLUDES) -L$(LIBDIR) -ldawn -ldl -shared -fPIC gpu.cpp -o build/$(GPU_CPP_LIB_NAME) + mkdir -p build && $(CXX) -std=c++17 $(INCLUDES) -L$(LIBDIR) -lwebgpu_dawn -ldl -shared -fPIC gpu.cpp -o build/$(GPU_CPP_LIB_NAME) python3 build.py cp third_party/lib/$(DAWN_LIB_NAME) build/ @@ -54,7 +54,7 @@ uninstall: examples/hello_world/build/hello_world: check-clang dawnlib examples/hello_world/run.cpp check-linux-vulkan $(LIBSPEC) && cd examples/hello_world && make build/hello_world && ./build/hello_world -dawnlib: $(if $(wildcard third_party/lib/libdawn.so third_party/lib/libdawn.dylib),,run_setup) +dawnlib: $(if $(wildcard third_party/lib/libwebgpu_dawn.so third_party/lib/libwebgpu_dawn.dylib),,run_setup) run_setup: check-python python3 setup.py @@ -71,7 +71,7 @@ all: dawnlib check-clang check-linux-vulkan lib pch # Test 16-bit floating point type test-half: dawnlib check-clang - $(LIBSPEC) && clang++ -std=c++17 $(INCLUDES) numeric_types/half.cpp -L$(LIBDIR) -ldawn -ldl -o build/half && ./build/half + $(LIBSPEC) && clang++ -std=c++17 $(INCLUDES) numeric_types/half.cpp -L$(LIBDIR) -lwebgpu_dawn -ldl -o build/half && ./build/half docs: Doxyfile doxygen Doxyfile @@ -102,7 +102,7 @@ all-cmake: check-clang check-cmake ################################################################################ clean-dawnlib: - rm -f third_party/lib/libdawn.so third_party/lib/libdawn.dylib + rm -f third_party/lib/libwebgpu_dawn.so third_party/lib/libwebgpu_dawn.dylib clean: read -r -p "This will delete the contents of build/*. Are you sure? [CTRL-C to abort] " response && rm -rf build/* @@ -119,7 +119,7 @@ clean: rm -f build/half clean-all: - read -r -p "This will delete the contents of build/* and third_party/*. Are you sure? [CTRL-C to abort] " response && rm -rf build/* third_party/fetchcontent/* third_party/gpu-build third_party/gpu-subbuild third_party/gpu-src third_party/lib/libdawn.so third_party/lib/libdawn.dylib + read -r -p "This will delete the contents of build/* and third_party/*. Are you sure? 
[CTRL-C to abort] " response && rm -rf build/* third_party/fetchcontent/* third_party/gpu-build third_party/gpu-subbuild third_party/gpu-src third_party/lib/libwebgpu_dawn.so third_party/lib/libwebgpu_dawn.dylib ################################################################################ # Checks diff --git a/examples/Makefile b/examples/Makefile index 6420619..3036e22 100644 --- a/examples/Makefile +++ b/examples/Makefile @@ -14,7 +14,7 @@ else endif FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib -LFLAGS=-ldl -ldawn +LFLAGS=-ldl -lwebgpu_dawn .PHONY: default all_release all_debug dawnlib run_setup check-python .PHONY: $(addsuffix _release, $(TARGETS)) diff --git a/examples/float16/Makefile b/examples/float16/Makefile index 54835d9..51e895a 100644 --- a/examples/float16/Makefile +++ b/examples/float16/Makefile @@ -9,12 +9,12 @@ ifeq ($(shell $(CXX) -std=c++17 -x c++ -E -include array - < /dev/null > /dev/nu else STDLIB := -stdlib=libc++ endif -FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -ldawn +FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -lwebgpu_dawn run: ./build/$(TARGET) dawnlib $(LIBSPEC) && ./build/$(TARGET) -dawnlib: $(if $(wildcard $(GPUCPP)/third_party/lib/libdawn.so $(GPUCPP)/third_party/lib/libdawn.dylib),,run_setup) +dawnlib: $(if $(wildcard $(GPUCPP)/third_party/lib/libwebgpu_dawn.so $(GPUCPP)/third_party/lib/libwebgpu_dawn.dylib),,run_setup) run_setup: check-python cd $(GPUCPP) && python3 setup.py diff --git a/examples/gpu_puzzles/Makefile b/examples/gpu_puzzles/Makefile index 849240c..90dfc2d 100644 --- a/examples/gpu_puzzles/Makefile +++ b/examples/gpu_puzzles/Makefile @@ -9,8 +9,8 @@ ifeq ($(shell $(CXX) -std=c++17 -x c++ -E -include array - < /dev/null > /dev/nu else STDLIB := -stdlib=libc++ endif -FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -ldawn -FLAGS_KEY=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib key.cpp -ldl -ldawn +FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -lwebgpu_dawn +FLAGS_KEY=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib key.cpp -ldl -lwebgpu_dawn run: ./build/$(TARGET) $(LIBSPEC) && ./build/$(TARGET) diff --git a/examples/hello_world/Makefile b/examples/hello_world/Makefile index 085c7ea..7e64553 100644 --- a/examples/hello_world/Makefile +++ b/examples/hello_world/Makefile @@ -9,12 +9,12 @@ ifeq ($(shell $(CXX) -std=c++17 -x c++ -E -include array - < /dev/null > /dev/nu else STDLIB := -stdlib=libc++ endif -FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -ldawn +FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -lwebgpu_dawn run: ./build/$(TARGET) dawnlib $(LIBSPEC) && ./build/$(TARGET) -dawnlib: $(if $(wildcard $(GPUCPP)/third_party/lib/libdawn.so $(GPUCPP)/third_party/lib/libdawn.dylib),,run_setup) +dawnlib: $(if $(wildcard $(GPUCPP)/third_party/lib/libwebgpu_dawn.so $(GPUCPP)/third_party/lib/libwebgpu_dawn.dylib),,run_setup) run_setup: check-python cd $(GPUCPP) && python3 setup.py diff --git a/examples/matmul/Makefile b/examples/matmul/Makefile index 78d3c0e..03cd20e 100644 --- a/examples/matmul/Makefile +++ 
b/examples/matmul/Makefile @@ -10,7 +10,7 @@ ifeq ($(shell $(CXX) -std=c++17 -x c++ -E -include array - < /dev/null > /dev/nu else STDLIB := -stdlib=libc++ endif -FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -ldawn +FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -lwebgpu_dawn run: ./build/$(TARGET) $(LIBSPEC) && ./build/$(TARGET) @@ -28,7 +28,7 @@ build/$(TARGET): run.cpp build/$(TARGET)_with_metal_profiler: run.cpp mkdir -p build && $(CXX) $(FLAGS) -o ./build/$(TARGET)_with_metal_profiler $(GPUCPP)/experimental/profiler/metal.mm -framework metal -framework Foundation -DMETAL_PROFILER -g - install_name_tool -change @rpath/libdawn.dylib $(GPUCPP)/third_party/lib/libdawn.dylib ./build/$(TARGET)_with_metal_profiler + install_name_tool -change @rpath/libwebgpu_dawn.dylib $(GPUCPP)/third_party/lib/libwebgpu_dawn.dylib ./build/$(TARGET)_with_metal_profiler watch: @command -v entr >/dev/null 2>&1 || { echo >&2 "Please install entr with 'brew install entr' or 'sudo apt-get install entr'"; exit 1; } diff --git a/examples/physics/Makefile b/examples/physics/Makefile index 7cdd3f5..10cfb13 100644 --- a/examples/physics/Makefile +++ b/examples/physics/Makefile @@ -9,7 +9,7 @@ ifeq ($(shell $(CXX) -std=c++17 -x c++ -E -include array - < /dev/null > /dev/nu else STDLIB := -stdlib=libc++ endif -FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -ldawn +FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -lwebgpu_dawn run: ./build/$(TARGET) $(LIBSPEC) && ./build/$(TARGET) diff --git a/examples/render/Makefile b/examples/render/Makefile index 552bbf0..d07048c 100644 --- a/examples/render/Makefile +++ b/examples/render/Makefile @@ -9,7 +9,7 @@ ifeq ($(shell $(CXX) -std=c++17 -x c++ -E -include array - < /dev/null > /dev/nu else STDLIB := -stdlib=libc++ endif -FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -ldawn +FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -lwebgpu_dawn run: ./build/$(TARGET) $(LIBSPEC) && ./build/$(TARGET) diff --git a/examples/shadertui/Makefile b/examples/shadertui/Makefile index 82daef1..81c740b 100644 --- a/examples/shadertui/Makefile +++ b/examples/shadertui/Makefile @@ -10,7 +10,7 @@ ifeq ($(shell $(CXX) -std=c++17 -x c++ -E -include array - < /dev/null > /dev/nu else STDLIB := -stdlib=libc++ endif -FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -ldawn +FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -lwebgpu_dawn run: ./build/$(TARGET) diff --git a/examples/transpose/Makefile b/examples/transpose/Makefile index dca2fb6..1495c96 100644 --- a/examples/transpose/Makefile +++ b/examples/transpose/Makefile @@ -10,7 +10,7 @@ ifeq ($(shell $(CXX) -std=c++17 -x c++ -E -include array - < /dev/null > /dev/nu else STDLIB := -stdlib=libc++ endif -FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -ldawn +FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -lwebgpu_dawn run: ./build/$(TARGET) $(LIBSPEC) && ./build/$(TARGET) diff --git a/gpu.hpp b/gpu.hpp index 
b1dd1cb..d3641f7 100644
--- a/gpu.hpp
+++ b/gpu.hpp
@@ -26,8 +26,6 @@
 #ifdef USE_DAWN_API
 #include "dawn/native/DawnNative.h"
-
-typedef WGPUBufferUsage WGPUBufferUsageFlags;
 #endif
 
 namespace gpu {
@@ -37,7 +35,7 @@ namespace gpu {
  */
 struct Array {
   WGPUBuffer buffer;
-  WGPUBufferUsageFlags usage;
+  WGPUBufferUsage usage;
   size_t size; // in bytes
 };
 
@@ -580,7 +578,7 @@ struct Context {
 inline Tensor
 createTensor(TensorPool &pool, WGPUDevice &device, const Shape &shape,
              NumType dtype,
-             WGPUBufferUsageFlags usage = WGPUBufferUsage_Storage |
+             WGPUBufferUsage usage = WGPUBufferUsage_Storage |
                                           WGPUBufferUsage_CopyDst |
                                           WGPUBufferUsage_CopySrc) {
   LOG(kDefLog, kTrace, "Creating tensor");
@@ -603,7 +601,7 @@ createTensor(TensorPool &pool, WGPUDevice &device, const Shape &shape,
  * the GPU with a given shape and data type.
  *
  * Instead of taking the TensorPool and raw WebGPU API WGPUDevice and
- * WGPUBufferUsageFlags arguments, this is a convenience wrapper around the
+ * WGPUBufferUsage arguments, this is a convenience wrapper around the
 * core createTensor function which has default usage flags for a storage
 * buffer, and also takes in the Context object.
 *
@@ -782,8 +780,6 @@ inline Context createContext(const WGPUInstanceDescriptor &desc = {},
   Context context;
   {
 #ifdef __EMSCRIPTEN__
-    // Emscripten does not support the instance descriptor
-    // and throws an assertion error if it is not nullptr.
     context.instance = wgpuCreateInstance(nullptr);
 #else
     context.instance = wgpuCreateInstance(&desc);
@@ -798,30 +794,37 @@ inline Context createContext(const WGPUInstanceDescriptor &desc = {},
     bool requestEnded = false;
   };
   AdapterData adapterData;
+
   auto onAdapterRequestEnded = [](WGPURequestAdapterStatus status,
-                                  WGPUAdapter adapter, char const *message,
-                                  void *pUserData) {
-    AdapterData &adapterData = *reinterpret_cast<AdapterData *>(pUserData);
+                                  WGPUAdapter adapter, WGPUStringView message,
+                                  void *pUserData, void *) {
+    AdapterData &adapterData = *reinterpret_cast<AdapterData *>(pUserData);
 #ifdef __EMSCRIPTEN__
-    if (status != WGPURequestAdapterStatus_Success) {
-      LOG(kDefLog, kError, "Could not get WebGPU adapter: %s", message);
-      LOG(kDefLog, kError,
-          "\n\nA common reason is that the browser does not have WebGPU "
-          "enabled, particularly on Linux.\n"
-          "- Open `chrome://flags/` in the browser and make sure "
-          "\"WebGPU Support\" is enabled.\n"
-          "- Chrome is launched with vulkan enabled. From the command line "
-          "launch chrome as `google-chrome --enable-features=Vulkan`\n");
-    }
+    if (status != WGPURequestAdapterStatus_Success) {
+      LOG(kDefLog, kError, "Could not get WebGPU adapter: %.*s",
+          static_cast<int>(message.length), message.data);
+      LOG(kDefLog, kError,
+          "\n\nA common reason is that the browser does not have WebGPU "
+          "enabled, particularly on Linux.\n"
+          "- Open `chrome://flags/` in the browser and make sure "
+          "\"WebGPU Support\" is enabled.\n"
+          "- Chrome is launched with vulkan enabled. From the command line "
+          "launch chrome as `google-chrome --enable-features=Vulkan`\n");
+    }
 #endif
-    check(status == WGPURequestAdapterStatus_Success,
-          "Request WebGPU adapter", __FILE__, __LINE__);
-    adapterData.adapter = adapter;
-    adapterData.requestEnded = true;
+    check(status == WGPURequestAdapterStatus_Success,
+          "Request WebGPU adapter", __FILE__, __LINE__);
+    adapterData.adapter = adapter;
+    adapterData.requestEnded = true;
   };
-  wgpuInstanceRequestAdapter(context.instance, &adapterOpts,
-                             onAdapterRequestEnded, (void *)&adapterData);
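+  // Note (added for exposition): with the updated Dawn API, callbacks are no
+  // longer passed as a bare function pointer plus a single userdata pointer.
+  // They are wrapped in a *CallbackInfo struct carrying a WGPUCallbackMode and
+  // two userdata slots (userdata1/userdata2); the same pattern recurs below
+  // for device requests, buffer mapping, and queue-work-done notification.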
+
+  WGPURequestAdapterCallbackInfo callbackInfo = {
+    .mode = WGPUCallbackMode_AllowSpontaneous,
+    .callback = onAdapterRequestEnded,
+    .userdata1 = &adapterData,
+    .userdata2 = nullptr
+  };
+  wgpuInstanceRequestAdapter(context.instance, &adapterOpts, callbackInfo);
 
   while (!adapterData.requestEnded) {
     processEvents(context.instance);
@@ -837,19 +840,27 @@ inline Context createContext(const WGPUInstanceDescriptor &desc = {},
     bool requestEnded = false;
   };
   DeviceData devData;
+
   auto onDeviceRequestEnded = [](WGPURequestDeviceStatus status,
-                                 WGPUDevice device, char const *message,
-                                 void *pUserData) {
-    DeviceData &devData = *reinterpret_cast<DeviceData *>(pUserData);
-    check(status == WGPURequestDeviceStatus_Success,
-          "Could not get WebGPU device.", __FILE__, __LINE__);
-    LOG(kDefLog, kTrace, "Device Request succeeded %x",
-        static_cast<void *>(device));
-    devData.device = device;
-    devData.requestEnded = true;
+                                 WGPUDevice device, WGPUStringView message,
+                                 void *pUserData, void *) {
+    DeviceData &devData = *reinterpret_cast<DeviceData *>(pUserData);
+    check(status == WGPURequestDeviceStatus_Success,
+          "Could not get WebGPU device.", __FILE__, __LINE__);
+    LOG(kDefLog, kTrace, "Device Request succeeded %x",
+        static_cast<void *>(device));
+    devData.device = device;
+    devData.requestEnded = true;
   };
-  wgpuAdapterRequestDevice(context.adapter, &devDescriptor,
-                           onDeviceRequestEnded, (void *)&devData);
+
+  WGPURequestDeviceCallbackInfo deviceCallbackInfo = {
+    .mode = WGPUCallbackMode_AllowSpontaneous,
+    .callback = onDeviceRequestEnded,
+    .userdata1 = &devData,
+    .userdata2 = nullptr
+  };
+  wgpuAdapterRequestDevice(context.adapter, &devDescriptor, deviceCallbackInfo);
+
+  LOG(kDefLog, kInfo, "Waiting for device request to end");
   while (!devData.requestEnded) {
     processEvents(context.instance);
@@ -857,13 +868,18 @@ inline Context createContext(const WGPUInstanceDescriptor &desc = {},
   LOG(kDefLog, kInfo, "Device request ended");
   assert(devData.requestEnded);
   context.device = devData.device;
-  wgpuDeviceSetUncapturedErrorCallback(
-      context.device,
-      [](WGPUErrorType type, char const *message, void *devData) {
-        LOG(kDefLog, kError, "Device uncaptured error: %s", message);
-        throw std::runtime_error("Device uncaptured exception.");
+
+  WGPULoggingCallbackInfo loggingCallbackInfo = {
+    .callback = [](WGPULoggingType type, WGPUStringView message, void* userdata1, void* userdata2) {
+      LOG(kDefLog, kError, "Device logging callback: %.*s", (int)message.length, message.data);
+      if (type == WGPULoggingType_Error) {
+        throw std::runtime_error("Device error logged.");
+      }
     },
-      nullptr);
+    .userdata1 = nullptr,
+    .userdata2 = nullptr
+  };
+  wgpuDeviceSetLoggingCallback(context.device, loggingCallbackInfo);
  }
  context.queue = wgpuDeviceGetQueue(context.device);
  return context;
@@ -947,19 +963,28 @@ inline Context createContextByGpuIdx(int gpuIdx,
     bool requestEnded = false;
   };
   DeviceData devData;
+
   auto onDeviceRequestEnded = [](WGPURequestDeviceStatus status,
-                                 WGPUDevice device, char const *message,
-                                 void *pUserData) {
-    DeviceData &devData = *reinterpret_cast<DeviceData *>(pUserData);
-    check(status == WGPURequestDeviceStatus_Success,
-          "Could not get WebGPU device.", __FILE__, __LINE__);
-    LOG(kDefLog, kTrace, "Device Request succeeded %x",
-        static_cast<void *>(device));
-    devData.device = device;
-    devData.requestEnded = true;
+                                 WGPUDevice device, WGPUStringView message,
+                                 void *pUserData, void *) {
+    DeviceData &devData = *reinterpret_cast<DeviceData *>(pUserData);
+    check(status == WGPURequestDeviceStatus_Success,
+          "Could not get WebGPU device.", __FILE__, __LINE__);
+    LOG(kDefLog, kTrace, "Device Request succeeded %x",
+        static_cast<void *>(device));
+    devData.device = device;
+    devData.requestEnded = true;
+};
+
+  WGPURequestDeviceCallbackInfo deviceCallbackInfo = {
+    .mode = WGPUCallbackMode_AllowSpontaneous,
+    .callback = onDeviceRequestEnded,
+    .userdata1 = &devData,
+    .userdata2 = nullptr
   };
-  wgpuAdapterRequestDevice(context.adapter, &devDescriptor,
-                           onDeviceRequestEnded, (void *)&devData);
+  wgpuAdapterRequestDevice(context.adapter, &devDescriptor, deviceCallbackInfo);
+
+  LOG(kDefLog, kInfo, "Waiting for device request to end");
   while (!devData.requestEnded) {
     processEvents(context.instance);
@@ -967,13 +992,19 @@ inline Context createContextByGpuIdx(int gpuIdx,
   LOG(kDefLog, kInfo, "Device request ended");
   assert(devData.requestEnded);
   context.device = devData.device;
-  wgpuDeviceSetUncapturedErrorCallback(
-      context.device,
-      [](WGPUErrorType type, char const *message, void *devData) {
-        LOG(kDefLog, kError, "Device uncaptured error: %s", message);
-        throw std::runtime_error("Device uncaptured exception.");
+
+  WGPULoggingCallbackInfo loggingCallbackInfo = {
+    .callback = [](WGPULoggingType type, WGPUStringView message, void* userdata1, void* userdata2) {
+      LOG(kDefLog, kError, "Device logging callback: %.*s", (int)message.length, message.data);
+      if (type == WGPULoggingType_Error) {
+        throw std::runtime_error("Device error logged.");
+      }
     },
-      nullptr);
+    .userdata1 = nullptr,
+    .userdata2 = nullptr
+  };
+  wgpuDeviceSetLoggingCallback(context.device, loggingCallbackInfo);
+
  }
  context.queue = wgpuDeviceGetQueue(context.device);
  return context;
@@ -1005,28 +1036,36 @@ inline void toCPU(Context &ctx, Tensor &tensor, void *data, size_t bufferSize,
   wgpuCommandBufferRelease(op.commandBuffer);
   CallbackData callbackData = {op.readbackBuffer, bufferSize, data, &op.promise,
                                &op.future};
-  wgpuQueueOnSubmittedWorkDone(
-      ctx.queue,
-      [](WGPUQueueWorkDoneStatus status, void *callbackData) {
-        check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done",
-              __FILE__, __LINE__);
-        const auto *data = static_cast<CallbackData *>(callbackData);
-        wgpuBufferMapAsync(
-            data->buffer, WGPUMapMode_Read, 0, data->bufferSize,
-            [](WGPUBufferMapAsyncStatus status, void *captureData) {
-              const auto *data = static_cast<CallbackData *>(captureData);
-              check(status == WGPUBufferMapAsyncStatus_Success,
-                    "Map readbackBuffer", __FILE__, __LINE__);
-              const void *mappedData = wgpuBufferGetConstMappedRange(
-                  data->buffer, /*offset=*/0, data->bufferSize);
-              check(mappedData, "Get mapped range", __FILE__, __LINE__);
-              memcpy(data->output, mappedData, data->bufferSize);
-              wgpuBufferUnmap(data->buffer);
-              data->promise->set_value();
-            },
-            callbackData);
-      },
-      &callbackData);
+
+  WGPUQueueWorkDoneCallbackInfo workDoneCallbackInfo = {
+    .mode = WGPUCallbackMode_AllowSpontaneous,
+    .callback = [](WGPUQueueWorkDoneStatus status, void* userdata1, void* userdata2) {
+      check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done",
+            __FILE__, __LINE__);
+      const auto *data = static_cast<CallbackData *>(userdata1);
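+      // Exposition note: the readback is a two-stage async chain. The queue's
+      // work-done callback fires first; only then is the readback buffer
+      // mapped via a second callback-info struct, and the mapped bytes are
+      // memcpy'd out before the promise is fulfilled.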
+      WGPUBufferMapCallbackInfo mapCallbackInfo = {
+        .mode = WGPUCallbackMode_AllowSpontaneous,
+        .callback = [](WGPUMapAsyncStatus status, WGPUStringView message, void* userdata1, void* userdata2) {
+          const auto *data = static_cast<CallbackData *>(userdata1);
+          check(status == WGPUMapAsyncStatus_Success,
+                "Map readbackBuffer", __FILE__, __LINE__);
+          const void *mappedData = wgpuBufferGetConstMappedRange(
+              data->buffer, /*offset=*/0, data->bufferSize);
+          check(mappedData, "Get mapped range", __FILE__, __LINE__);
+          memcpy(data->output, mappedData, data->bufferSize);
+          wgpuBufferUnmap(data->buffer);
+          data->promise->set_value();
+        },
+        .userdata1 = const_cast<CallbackData *>(data),
+        .userdata2 = nullptr
+      };
+      wgpuBufferMapAsync(data->buffer, WGPUMapMode_Read, 0, data->bufferSize, mapCallbackInfo);
+    },
+    .userdata1 = &callbackData,
+    .userdata2 = nullptr
+  };
+  wgpuQueueOnSubmittedWorkDone(ctx.queue, workDoneCallbackInfo);
+
   wait(ctx, op.future);
 }
 
@@ -1113,28 +1152,36 @@ inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data,
   wgpuCommandBufferRelease(op.commandBuffer);
   CallbackData callbackData = {op.readbackBuffer, bufferSize, data, &op.promise,
                                &op.future};
-  wgpuQueueOnSubmittedWorkDone(
-      ctx.queue,
-      [](WGPUQueueWorkDoneStatus status, void *callbackData) {
-        check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done",
-              __FILE__, __LINE__);
-        const auto *data = static_cast<CallbackData *>(callbackData);
-        wgpuBufferMapAsync(
-            data->buffer, WGPUMapMode_Read, 0, data->bufferSize,
-            [](WGPUBufferMapAsyncStatus status, void *captureData) {
-              const auto *data = static_cast<CallbackData *>(captureData);
-              check(status == WGPUBufferMapAsyncStatus_Success,
-                    "Map readbackBuffer", __FILE__, __LINE__);
-              const void *mappedData = wgpuBufferGetConstMappedRange(
-                  data->buffer, /*offset=*/0, data->bufferSize);
-              check(mappedData, "Get mapped range", __FILE__, __LINE__);
-              memcpy(data->output, mappedData, data->bufferSize);
-              wgpuBufferUnmap(data->buffer);
-              data->promise->set_value();
-            },
-            callbackData);
-      },
-      &callbackData);
+
+  WGPUQueueWorkDoneCallbackInfo workDoneCallbackInfo = {
+    .mode = WGPUCallbackMode_AllowSpontaneous,
+    .callback = [](WGPUQueueWorkDoneStatus status, void* userdata1, void* userdata2) {
+      check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done",
+            __FILE__, __LINE__);
+      const auto *data = static_cast<CallbackData *>(userdata1);
+      WGPUBufferMapCallbackInfo mapCallbackInfo = {
+        .mode = WGPUCallbackMode_AllowSpontaneous,
+        .callback = [](WGPUMapAsyncStatus status, WGPUStringView message, void* userdata1, void* userdata2) {
+          const auto *data = static_cast<CallbackData *>(userdata1);
+          check(status == WGPUMapAsyncStatus_Success,
+                "Map readbackBuffer", __FILE__, __LINE__);
+          const void *mappedData = wgpuBufferGetConstMappedRange(
+              data->buffer, /*offset=*/0, data->bufferSize);
+          check(mappedData, "Get mapped range", __FILE__, __LINE__);
+          memcpy(data->output, mappedData, data->bufferSize);
+          wgpuBufferUnmap(data->buffer);
+          data->promise->set_value();
+        },
+        .userdata1 = const_cast<CallbackData *>(data),
+        .userdata2 = nullptr
+      };
+      wgpuBufferMapAsync(data->buffer, WGPUMapMode_Read, 0, data->bufferSize, mapCallbackInfo);
+    },
+    .userdata1 = &callbackData,
+    .userdata2 = nullptr
+  };
+  wgpuQueueOnSubmittedWorkDone(ctx.queue, workDoneCallbackInfo);
+
   wait(ctx, op.future);
   if (op.readbackBuffer) {
     wgpuBufferRelease(op.readbackBuffer);
@@ -1412,47 +1459,59 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code,
   };
   WGPUPipelineLayout pipelineLayout =
       wgpuDeviceCreatePipelineLayout(device, &pipelineLayoutDesc);
-  WGPUShaderModuleWGSLDescriptor wgslDesc = {
-      .code = code.data.c_str(),
+
+  WGPUShaderSourceWGSL wgslDesc = {
+    .chain = {.sType = WGPUSType_ShaderSourceWGSL},
+    .code = {.data = code.data.c_str(), .length = code.data.length()}
   };
-  wgslDesc.chain.sType = WGPUSType_ShaderModuleWGSLDescriptor;
+
   WGPUShaderModuleDescriptor shaderModuleDesc = {};
   shaderModuleDesc.nextInChain = &wgslDesc.chain;
-  shaderModuleDesc.label = code.label.c_str();
+  shaderModuleDesc.label = {code.label.c_str(), code.label.length()};
+
   WGPUComputePipelineDescriptor computePipelineDesc = {};
   computePipelineDesc.layout = pipelineLayout;
   computePipelineDesc.compute.module =
       wgpuDeviceCreateShaderModule(device, &shaderModuleDesc);
-  computePipelineDesc.compute.entryPoint = code.entryPoint.c_str();
-  computePipelineDesc.label = code.label.c_str();
+  computePipelineDesc.compute.entryPoint = {code.entryPoint.c_str(), code.entryPoint.length()};
+  computePipelineDesc.label = {code.label.c_str(), code.label.length()};
+
   op->computePipeline =
       wgpuDeviceCreateComputePipeline(device, &computePipelineDesc);
   op->totalWorkgroups = {totalWorkgroups[0], totalWorkgroups[1], totalWorkgroups[2]};
   resetCommandBuffer(device, op);
   if (cacheKey != nullptr)
     ctx.kernelPool.data[cacheKey]=op;
-
-  WGPUCompilationInfoCallback cb =
-      [](WGPUCompilationInfoRequestStatus status,
-         WGPUCompilationInfo const *compilationInfo, void *userData) {
-        CompilationInfo *result = static_cast<CompilationInfo *>(userData);
-        if (compilationInfo && result) {
+
+  auto compilationInfoCallback = [](WGPUCompilationInfoRequestStatus status,
+                                    WGPUCompilationInfo const *compilationInfo,
+                                    void *userdata1, void *userdata2) {
+    CompilationInfo *result = static_cast<CompilationInfo *>(userdata1);
+    if (compilationInfo && result) {
       result->status = status;
       for (uint32_t i = 0; i < compilationInfo->messageCount; ++i) {
-        printf("Message %d: %s\n", i, compilationInfo->messages[i].message);
-        result->messages.push_back(compilationInfo->messages[i].message);
-        result->lineNums.push_back(compilationInfo->messages[i].lineNum);
-        result->linePos.push_back(compilationInfo->messages[i].linePos);
+        printf("Message %d: %.*s\n", i,
+               static_cast<int>(compilationInfo->messages[i].message.length),
+               compilationInfo->messages[i].message.data);
+        result->messages.push_back(std::string(
+            compilationInfo->messages[i].message.data,
+            compilationInfo->messages[i].message.length));
+        result->lineNums.push_back(compilationInfo->messages[i].lineNum);
+        result->linePos.push_back(compilationInfo->messages[i].linePos);
       }
       result->finished = true;
-        } else {
+    } else {
       LOG(kDefLog, kTrace, "No compilation info or result");
-        }
-      };
+    }
+  };
 
-  wgpuShaderModuleGetCompilationInfo(
-      computePipelineDesc.compute.module, cb, static_cast<void *>(compilationInfo));
+  WGPUCompilationInfoCallbackInfo compilationCallbackInfo = {
+    .mode = WGPUCallbackMode_AllowSpontaneous,
+    .callback = compilationInfoCallback,
+    .userdata1 = static_cast<void *>(compilationInfo),
+    .userdata2 = nullptr
+  };
+  wgpuShaderModuleGetCompilationInfo(computePipelineDesc.compute.module,
+                                     compilationCallbackInfo);
 
   while (compilationInfo && !compilationInfo->finished) {
     processEvents(ctx.instance);
@@ -1524,22 +1583,25 @@
  */
 inline void dispatchKernel(Context &ctx, Kernel &kernel,
                            std::promise<void> &promise) {
-  // Submit the command buffer
   if (kernel->used) {
     resetCommandBuffer(ctx.device, kernel);
   }
   wgpuQueueSubmit(ctx.queue, 1, &kernel->commandBuffer);
   wgpuCommandBufferRelease(kernel->commandBuffer);
   kernel->used = true;
-  wgpuQueueOnSubmittedWorkDone(
-      ctx.queue,
-      [](WGPUQueueWorkDoneStatus status, void *data) {
-        check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done",
-              __FILE__, __LINE__);
-        auto *promise = static_cast<std::promise<void> *>(data);
-        promise->set_value();
+
+  WGPUQueueWorkDoneCallbackInfo workDoneCallbackInfo = {
+    .mode = WGPUCallbackMode_AllowSpontaneous,
+    .callback = [](WGPUQueueWorkDoneStatus status, void* userdata1, void* userdata2) {
+      check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done",
+            __FILE__, __LINE__);
+      auto *promise = static_cast<std::promise<void> *>(userdata1);
+      promise->set_value();
     },
-      &promise);
+    .userdata1 = &promise,
+    .userdata2 = nullptr
+  };
+  wgpuQueueOnSubmittedWorkDone(ctx.queue, workDoneCallbackInfo);
 }
 
 } // namespace gpu
diff --git a/third_party/headers/webgpu/webgpu.h b/third_party/headers/webgpu/webgpu.h
index 505e4bb..b36a758 100644
--- a/third_party/headers/webgpu/webgpu.h
+++ b/third_party/headers/webgpu/webgpu.h
@@ -27,12 +27,20 @@
 // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 #ifdef __EMSCRIPTEN__
 #error "Do not include this header. Emscripten already provides headers needed for WebGPU."
 #endif
+
 #ifndef WEBGPU_H_
 #define WEBGPU_H_
 
+#define WGPU_BREAKING_CHANGE_STRING_VIEW_LABELS
+#define WGPU_BREAKING_CHANGE_STRING_VIEW_OUTPUT_STRUCTS
+#define WGPU_BREAKING_CHANGE_STRING_VIEW_CALLBACKS
+#define WGPU_BREAKING_CHANGE_FUTURE_CALLBACK_TYPES
+#define WGPU_BREAKING_CHANGE_LOGGING_CALLBACK_TYPE
+
 #if defined(WGPU_SHARED_LIBRARY)
 #    if defined(_WIN32)
 #        if defined(WGPU_IMPLEMENTATION)
@@ -67,6 +75,8 @@
 #define WGPU_NULLABLE
 #endif
 
+#define WGPU_BREAKING_CHANGE_DROP_DESCRIPTOR
+
 #include <stdint.h>
 #include <stddef.h>
 
@@ -89,10 +99,11 @@
 #define WGPU_LIMIT_U64_UNDEFINED UINT64_MAX
 #define WGPU_MIP_LEVEL_COUNT_UNDEFINED UINT32_MAX
 #define WGPU_QUERY_SET_INDEX_UNDEFINED UINT32_MAX
+#define WGPU_STRLEN SIZE_MAX
 #define WGPU_WHOLE_MAP_SIZE SIZE_MAX
 #define WGPU_WHOLE_SIZE UINT64_MAX
 
-typedef uint32_t WGPUFlags;
+typedef uint64_t WGPUFlags;
 typedef uint32_t WGPUBool;
 
 typedef struct WGPUAdapterImpl* WGPUAdapter WGPU_OBJECT_ATTRIBUTE;
@@ -119,52 +130,40 @@ typedef struct WGPUSharedBufferMemoryImpl* WGPUSharedBufferMemory WGPU_OBJECT_AT
 typedef struct WGPUSharedFenceImpl* WGPUSharedFence WGPU_OBJECT_ATTRIBUTE;
 typedef struct WGPUSharedTextureMemoryImpl* WGPUSharedTextureMemory WGPU_OBJECT_ATTRIBUTE;
 typedef struct WGPUSurfaceImpl* WGPUSurface WGPU_OBJECT_ATTRIBUTE;
-typedef struct WGPUSwapChainImpl* WGPUSwapChain WGPU_OBJECT_ATTRIBUTE;
 typedef struct WGPUTextureImpl* WGPUTexture WGPU_OBJECT_ATTRIBUTE;
 typedef struct WGPUTextureViewImpl* WGPUTextureView WGPU_OBJECT_ATTRIBUTE;
 
 // Structure forward declarations
-struct WGPUAdapterInfo;
-struct WGPUAdapterProperties;
+struct WGPUINTERNAL_HAVE_EMDAWNWEBGPU_HEADER;
 struct WGPUAdapterPropertiesD3D;
+struct WGPUAdapterPropertiesSubgroups;
 struct WGPUAdapterPropertiesVk;
 struct WGPUBindGroupEntry;
 struct WGPUBlendComponent;
 struct WGPUBufferBindingLayout;
-struct WGPUBufferDescriptor;
 struct WGPUBufferHostMappedPointer;
-struct WGPUBufferMapCallbackInfo;
 struct WGPUColor;
 struct WGPUColorTargetStateExpandResolveTextureDawn;
-struct WGPUCommandBufferDescriptor;
-struct WGPUCommandEncoderDescriptor;
-struct WGPUCompilationInfoCallbackInfo;
-struct WGPUCompilationMessage;
 struct WGPUComputePassTimestampWrites;
-struct WGPUConstantEntry;
 struct WGPUCopyTextureForBrowserOptions;
-struct WGPUCreateComputePipelineAsyncCallbackInfo;
-struct WGPUCreateRenderPipelineAsyncCallbackInfo;
struct WGPUDawnWGSLBlocklist; struct WGPUDawnAdapterPropertiesPowerPreference; struct WGPUDawnBufferDescriptorErrorInfoFromWireClient; -struct WGPUDawnCacheDeviceDescriptor; -struct WGPUDawnComputePipelineFullSubgroups; +struct WGPUDawnDrmFormatProperties; struct WGPUDawnEncoderInternalUsageDescriptor; +struct WGPUDawnExperimentalImmediateDataLimits; struct WGPUDawnExperimentalSubgroupLimits; +struct WGPUDawnFormatCapabilities; struct WGPUDawnRenderPassColorAttachmentRenderToSingleSampled; struct WGPUDawnShaderModuleSPIRVOptionsDescriptor; +struct WGPUDawnTexelCopyBufferRowAlignmentLimits; struct WGPUDawnTextureInternalUsageDescriptor; struct WGPUDawnTogglesDescriptor; struct WGPUDawnWireWGSLControl; -struct WGPUDepthStencilStateDepthWriteDefinedDawn; -struct WGPUDeviceLostCallbackInfo; -struct WGPUDrmFormatProperties; struct WGPUExtent2D; struct WGPUExtent3D; struct WGPUExternalTextureBindingEntry; struct WGPUExternalTextureBindingLayout; -struct WGPUFormatCapabilities; struct WGPUFuture; struct WGPUInstanceFeatures; struct WGPULimits; @@ -172,42 +171,28 @@ struct WGPUMemoryHeapInfo; struct WGPUMultisampleState; struct WGPUOrigin2D; struct WGPUOrigin3D; -struct WGPUPipelineLayoutDescriptor; struct WGPUPipelineLayoutStorageAttachment; -struct WGPUPopErrorScopeCallbackInfo; -struct WGPUPrimitiveDepthClipControl; struct WGPUPrimitiveState; -struct WGPUQuerySetDescriptor; -struct WGPUQueueDescriptor; -struct WGPUQueueWorkDoneCallbackInfo; -struct WGPURenderBundleDescriptor; -struct WGPURenderBundleEncoderDescriptor; struct WGPURenderPassDepthStencilAttachment; -struct WGPURenderPassDescriptorMaxDrawCount; +struct WGPURenderPassDescriptorExpandResolveRect; +struct WGPURenderPassMaxDrawCount; struct WGPURenderPassTimestampWrites; -struct WGPURequestAdapterCallbackInfo; struct WGPURequestAdapterOptions; -struct WGPURequestDeviceCallbackInfo; struct WGPUSamplerBindingLayout; -struct WGPUSamplerDescriptor; -struct WGPUShaderModuleSPIRVDescriptor; -struct WGPUShaderModuleWGSLDescriptor; struct WGPUShaderModuleCompilationOptions; -struct WGPUShaderModuleDescriptor; +struct WGPUShaderSourceSPIRV; struct WGPUSharedBufferMemoryBeginAccessDescriptor; -struct WGPUSharedBufferMemoryDescriptor; struct WGPUSharedBufferMemoryEndAccessState; struct WGPUSharedBufferMemoryProperties; struct WGPUSharedFenceDXGISharedHandleDescriptor; struct WGPUSharedFenceDXGISharedHandleExportInfo; struct WGPUSharedFenceMTLSharedEventDescriptor; struct WGPUSharedFenceMTLSharedEventExportInfo; -struct WGPUSharedFenceDescriptor; struct WGPUSharedFenceExportInfo; +struct WGPUSharedFenceSyncFDDescriptor; +struct WGPUSharedFenceSyncFDExportInfo; struct WGPUSharedFenceVkSemaphoreOpaqueFDDescriptor; struct WGPUSharedFenceVkSemaphoreOpaqueFDExportInfo; -struct WGPUSharedFenceVkSemaphoreSyncFDDescriptor; -struct WGPUSharedFenceVkSemaphoreSyncFDExportInfo; struct WGPUSharedFenceVkSemaphoreZirconHandleDescriptor; struct WGPUSharedFenceVkSemaphoreZirconHandleExportInfo; struct WGPUSharedTextureMemoryD3DSwapchainBeginState; @@ -216,7 +201,6 @@ struct WGPUSharedTextureMemoryEGLImageDescriptor; struct WGPUSharedTextureMemoryIOSurfaceDescriptor; struct WGPUSharedTextureMemoryAHardwareBufferDescriptor; struct WGPUSharedTextureMemoryBeginAccessDescriptor; -struct WGPUSharedTextureMemoryDescriptor; struct WGPUSharedTextureMemoryDmaBufPlane; struct WGPUSharedTextureMemoryEndAccessState; struct WGPUSharedTextureMemoryOpaqueFDDescriptor; @@ -227,76 +211,106 @@ struct WGPUSharedTextureMemoryZirconHandleDescriptor; struct 
WGPUStaticSamplerBindingLayout; struct WGPUStencilFaceState; struct WGPUStorageTextureBindingLayout; +struct WGPUStringView; +struct WGPUSupportedWGSLLanguageFeatures; +struct WGPUSupportedFeatures; struct WGPUSurfaceCapabilities; struct WGPUSurfaceConfiguration; -struct WGPUSurfaceDescriptor; -struct WGPUSurfaceDescriptorFromAndroidNativeWindow; -struct WGPUSurfaceDescriptorFromCanvasHTMLSelector; -struct WGPUSurfaceDescriptorFromMetalLayer; -struct WGPUSurfaceDescriptorFromWaylandSurface; -struct WGPUSurfaceDescriptorFromWindowsHWND; struct WGPUSurfaceDescriptorFromWindowsCoreWindow; struct WGPUSurfaceDescriptorFromWindowsSwapChainPanel; -struct WGPUSurfaceDescriptorFromXlibWindow; +struct WGPUSurfaceSourceXCBWindow; +struct WGPUSurfaceSourceAndroidNativeWindow; +struct WGPUSurfaceSourceMetalLayer; +struct WGPUSurfaceSourceWaylandSurface; +struct WGPUSurfaceSourceWindowsHWND; +struct WGPUSurfaceSourceXlibWindow; struct WGPUSurfaceTexture; -struct WGPUSwapChainDescriptor; struct WGPUTextureBindingLayout; struct WGPUTextureBindingViewDimensionDescriptor; struct WGPUTextureDataLayout; -struct WGPUTextureViewDescriptor; -struct WGPUUncapturedErrorCallbackInfo; struct WGPUVertexAttribute; struct WGPUYCbCrVkDescriptor; struct WGPUAHardwareBufferProperties; +struct WGPUAdapterInfo; struct WGPUAdapterPropertiesMemoryHeaps; struct WGPUBindGroupDescriptor; struct WGPUBindGroupLayoutEntry; struct WGPUBlendState; -struct WGPUCompilationInfo; +struct WGPUBufferDescriptor; +struct WGPUCommandBufferDescriptor; +struct WGPUCommandEncoderDescriptor; +struct WGPUCompilationMessage; struct WGPUComputePassDescriptor; +struct WGPUConstantEntry; +struct WGPUDawnCacheDeviceDescriptor; +struct WGPUDawnDrmFormatCapabilities; struct WGPUDepthStencilState; -struct WGPUDrmFormatCapabilities; +struct WGPUEmscriptenSurfaceSourceCanvasHTMLSelector; struct WGPUExternalTextureDescriptor; struct WGPUFutureWaitInfo; struct WGPUImageCopyBuffer; struct WGPUImageCopyExternalTexture; struct WGPUImageCopyTexture; struct WGPUInstanceDescriptor; +struct WGPUPipelineLayoutDescriptor; struct WGPUPipelineLayoutPixelLocalStorage; -struct WGPUProgrammableStageDescriptor; +struct WGPUQuerySetDescriptor; +struct WGPUQueueDescriptor; +struct WGPURenderBundleDescriptor; +struct WGPURenderBundleEncoderDescriptor; struct WGPURenderPassColorAttachment; struct WGPURenderPassStorageAttachment; struct WGPURequiredLimits; +struct WGPUSamplerDescriptor; +struct WGPUShaderModuleDescriptor; +struct WGPUShaderSourceWGSL; +struct WGPUSharedBufferMemoryDescriptor; +struct WGPUSharedFenceDescriptor; struct WGPUSharedTextureMemoryAHardwareBufferProperties; +struct WGPUSharedTextureMemoryDescriptor; struct WGPUSharedTextureMemoryDmaBufDescriptor; struct WGPUSharedTextureMemoryProperties; struct WGPUSupportedLimits; +struct WGPUSurfaceDescriptor; struct WGPUTextureDescriptor; +struct WGPUTextureViewDescriptor; struct WGPUVertexBufferLayout; struct WGPUBindGroupLayoutDescriptor; struct WGPUColorTargetState; -struct WGPUComputePipelineDescriptor; +struct WGPUCompilationInfo; +struct WGPUComputeState; struct WGPUDeviceDescriptor; struct WGPURenderPassDescriptor; struct WGPURenderPassPixelLocalStorage; struct WGPUVertexState; +struct WGPUComputePipelineDescriptor; struct WGPUFragmentState; struct WGPURenderPipelineDescriptor; typedef enum WGPUWGSLFeatureName { - WGPUWGSLFeatureName_Undefined = 0x00000000, WGPUWGSLFeatureName_ReadonlyAndReadwriteStorageTextures = 0x00000001, WGPUWGSLFeatureName_Packed4x8IntegerDotProduct = 0x00000002, 
WGPUWGSLFeatureName_UnrestrictedPointerParameters = 0x00000003, WGPUWGSLFeatureName_PointerCompositeAccess = 0x00000004, - WGPUWGSLFeatureName_ChromiumTestingUnimplemented = 0x000003E8, - WGPUWGSLFeatureName_ChromiumTestingUnsafeExperimental = 0x000003E9, - WGPUWGSLFeatureName_ChromiumTestingExperimental = 0x000003EA, - WGPUWGSLFeatureName_ChromiumTestingShippedWithKillswitch = 0x000003EB, - WGPUWGSLFeatureName_ChromiumTestingShipped = 0x000003EC, + WGPUWGSLFeatureName_ChromiumTestingUnimplemented = 0x00050000, + WGPUWGSLFeatureName_ChromiumTestingUnsafeExperimental = 0x00050001, + WGPUWGSLFeatureName_ChromiumTestingExperimental = 0x00050002, + WGPUWGSLFeatureName_ChromiumTestingShippedWithKillswitch = 0x00050003, + WGPUWGSLFeatureName_ChromiumTestingShipped = 0x00050004, WGPUWGSLFeatureName_Force32 = 0x7FFFFFFF } WGPUWGSLFeatureName WGPU_ENUM_ATTRIBUTE; - +typedef enum WGPUWGSLLanguageFeatureName { + WGPUWGSLLanguageFeatureName_ReadonlyAndReadwriteStorageTextures = 0x00000001, + WGPUWGSLLanguageFeatureName_Packed4x8IntegerDotProduct = 0x00000002, + WGPUWGSLLanguageFeatureName_UnrestrictedPointerParameters = 0x00000003, + WGPUWGSLLanguageFeatureName_PointerCompositeAccess = 0x00000004, + WGPUWGSLLanguageFeatureName_ChromiumTestingUnimplemented = 0x00050000, + WGPUWGSLLanguageFeatureName_ChromiumTestingUnsafeExperimental = 0x00050001, + WGPUWGSLLanguageFeatureName_ChromiumTestingExperimental = 0x00050002, + WGPUWGSLLanguageFeatureName_ChromiumTestingShippedWithKillswitch = 0x00050003, + WGPUWGSLLanguageFeatureName_ChromiumTestingShipped = 0x00050004, + WGPUWGSLLanguageFeatureName_Force32 = 0x7FFFFFFF +} WGPUWGSLLanguageFeatureName WGPU_ENUM_ATTRIBUTE; typedef enum WGPUAdapterType { WGPUAdapterType_DiscreteGPU = 0x00000001, WGPUAdapterType_IntegratedGPU = 0x00000002, @@ -304,7 +318,6 @@ typedef enum WGPUAdapterType { WGPUAdapterType_Unknown = 0x00000004, WGPUAdapterType_Force32 = 0x7FFFFFFF } WGPUAdapterType WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUAddressMode { WGPUAddressMode_Undefined = 0x00000000, WGPUAddressMode_ClampToEdge = 0x00000001, @@ -312,14 +325,12 @@ typedef enum WGPUAddressMode { WGPUAddressMode_MirrorRepeat = 0x00000003, WGPUAddressMode_Force32 = 0x7FFFFFFF } WGPUAddressMode WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUAlphaMode { WGPUAlphaMode_Opaque = 0x00000001, WGPUAlphaMode_Premultiplied = 0x00000002, WGPUAlphaMode_Unpremultiplied = 0x00000003, WGPUAlphaMode_Force32 = 0x7FFFFFFF } WGPUAlphaMode WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUBackendType { WGPUBackendType_Undefined = 0x00000000, WGPUBackendType_Null = 0x00000001, @@ -332,7 +343,6 @@ typedef enum WGPUBackendType { WGPUBackendType_OpenGLES = 0x00000008, WGPUBackendType_Force32 = 0x7FFFFFFF } WGPUBackendType WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUBlendFactor { WGPUBlendFactor_Undefined = 0x00000000, WGPUBlendFactor_Zero = 0x00000001, @@ -354,7 +364,6 @@ typedef enum WGPUBlendFactor { WGPUBlendFactor_OneMinusSrc1Alpha = 0x00000011, WGPUBlendFactor_Force32 = 0x7FFFFFFF } WGPUBlendFactor WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUBlendOperation { WGPUBlendOperation_Undefined = 0x00000000, WGPUBlendOperation_Add = 0x00000001, @@ -364,43 +373,26 @@ typedef enum WGPUBlendOperation { WGPUBlendOperation_Max = 0x00000005, WGPUBlendOperation_Force32 = 0x7FFFFFFF } WGPUBlendOperation WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUBufferBindingType { - WGPUBufferBindingType_Undefined = 0x00000000, - WGPUBufferBindingType_Uniform = 0x00000001, - WGPUBufferBindingType_Storage = 0x00000002, - WGPUBufferBindingType_ReadOnlyStorage = 0x00000003, + 
WGPUBufferBindingType_BindingNotUsed = 0x00000000, + WGPUBufferBindingType_Undefined = 0x00000001, + WGPUBufferBindingType_Uniform = 0x00000002, + WGPUBufferBindingType_Storage = 0x00000003, + WGPUBufferBindingType_ReadOnlyStorage = 0x00000004, WGPUBufferBindingType_Force32 = 0x7FFFFFFF } WGPUBufferBindingType WGPU_ENUM_ATTRIBUTE; - -typedef enum WGPUBufferMapAsyncStatus { - WGPUBufferMapAsyncStatus_Success = 0x00000000, - WGPUBufferMapAsyncStatus_InstanceDropped = 0x00000001, - WGPUBufferMapAsyncStatus_ValidationError = 0x00000002, - WGPUBufferMapAsyncStatus_Unknown = 0x00000003, - WGPUBufferMapAsyncStatus_DeviceLost = 0x00000004, - WGPUBufferMapAsyncStatus_DestroyedBeforeCallback = 0x00000005, - WGPUBufferMapAsyncStatus_UnmappedBeforeCallback = 0x00000006, - WGPUBufferMapAsyncStatus_MappingAlreadyPending = 0x00000007, - WGPUBufferMapAsyncStatus_OffsetOutOfRange = 0x00000008, - WGPUBufferMapAsyncStatus_SizeOutOfRange = 0x00000009, - WGPUBufferMapAsyncStatus_Force32 = 0x7FFFFFFF -} WGPUBufferMapAsyncStatus WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUBufferMapState { WGPUBufferMapState_Unmapped = 0x00000001, WGPUBufferMapState_Pending = 0x00000002, WGPUBufferMapState_Mapped = 0x00000003, WGPUBufferMapState_Force32 = 0x7FFFFFFF } WGPUBufferMapState WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUCallbackMode { WGPUCallbackMode_WaitAnyOnly = 0x00000001, WGPUCallbackMode_AllowProcessEvents = 0x00000002, WGPUCallbackMode_AllowSpontaneous = 0x00000003, WGPUCallbackMode_Force32 = 0x7FFFFFFF } WGPUCallbackMode WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUCompareFunction { WGPUCompareFunction_Undefined = 0x00000000, WGPUCompareFunction_Never = 0x00000001, @@ -413,23 +405,17 @@ typedef enum WGPUCompareFunction { WGPUCompareFunction_Always = 0x00000008, WGPUCompareFunction_Force32 = 0x7FFFFFFF } WGPUCompareFunction WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUCompilationInfoRequestStatus { - WGPUCompilationInfoRequestStatus_Success = 0x00000000, - WGPUCompilationInfoRequestStatus_InstanceDropped = 0x00000001, - WGPUCompilationInfoRequestStatus_Error = 0x00000002, - WGPUCompilationInfoRequestStatus_DeviceLost = 0x00000003, - WGPUCompilationInfoRequestStatus_Unknown = 0x00000004, + WGPUCompilationInfoRequestStatus_Success = 0x00000001, + WGPUCompilationInfoRequestStatus_InstanceDropped = 0x00000002, WGPUCompilationInfoRequestStatus_Force32 = 0x7FFFFFFF } WGPUCompilationInfoRequestStatus WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUCompilationMessageType { WGPUCompilationMessageType_Error = 0x00000001, WGPUCompilationMessageType_Warning = 0x00000002, WGPUCompilationMessageType_Info = 0x00000003, WGPUCompilationMessageType_Force32 = 0x7FFFFFFF } WGPUCompilationMessageType WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUCompositeAlphaMode { WGPUCompositeAlphaMode_Auto = 0x00000000, WGPUCompositeAlphaMode_Opaque = 0x00000001, @@ -438,18 +424,13 @@ typedef enum WGPUCompositeAlphaMode { WGPUCompositeAlphaMode_Inherit = 0x00000004, WGPUCompositeAlphaMode_Force32 = 0x7FFFFFFF } WGPUCompositeAlphaMode WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUCreatePipelineAsyncStatus { - WGPUCreatePipelineAsyncStatus_Success = 0x00000000, - WGPUCreatePipelineAsyncStatus_InstanceDropped = 0x00000001, - WGPUCreatePipelineAsyncStatus_ValidationError = 0x00000002, - WGPUCreatePipelineAsyncStatus_InternalError = 0x00000003, - WGPUCreatePipelineAsyncStatus_DeviceLost = 0x00000004, - WGPUCreatePipelineAsyncStatus_DeviceDestroyed = 0x00000005, - WGPUCreatePipelineAsyncStatus_Unknown = 0x00000006, + WGPUCreatePipelineAsyncStatus_Success = 0x00000001, + 
WGPUCreatePipelineAsyncStatus_InstanceDropped = 0x00000002, + WGPUCreatePipelineAsyncStatus_ValidationError = 0x00000003, + WGPUCreatePipelineAsyncStatus_InternalError = 0x00000004, WGPUCreatePipelineAsyncStatus_Force32 = 0x7FFFFFFF } WGPUCreatePipelineAsyncStatus WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUCullMode { WGPUCullMode_Undefined = 0x00000000, WGPUCullMode_None = 0x00000001, @@ -457,7 +438,6 @@ typedef enum WGPUCullMode { WGPUCullMode_Back = 0x00000003, WGPUCullMode_Force32 = 0x7FFFFFFF } WGPUCullMode WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUDeviceLostReason { WGPUDeviceLostReason_Unknown = 0x00000001, WGPUDeviceLostReason_Destroyed = 0x00000002, @@ -465,34 +445,34 @@ typedef enum WGPUDeviceLostReason { WGPUDeviceLostReason_FailedCreation = 0x00000004, WGPUDeviceLostReason_Force32 = 0x7FFFFFFF } WGPUDeviceLostReason WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUErrorFilter { WGPUErrorFilter_Validation = 0x00000001, WGPUErrorFilter_OutOfMemory = 0x00000002, WGPUErrorFilter_Internal = 0x00000003, WGPUErrorFilter_Force32 = 0x7FFFFFFF } WGPUErrorFilter WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUErrorType { - WGPUErrorType_NoError = 0x00000000, - WGPUErrorType_Validation = 0x00000001, - WGPUErrorType_OutOfMemory = 0x00000002, - WGPUErrorType_Internal = 0x00000003, - WGPUErrorType_Unknown = 0x00000004, - WGPUErrorType_DeviceLost = 0x00000005, + WGPUErrorType_NoError = 0x00000001, + WGPUErrorType_Validation = 0x00000002, + WGPUErrorType_OutOfMemory = 0x00000003, + WGPUErrorType_Internal = 0x00000004, + WGPUErrorType_Unknown = 0x00000005, WGPUErrorType_Force32 = 0x7FFFFFFF } WGPUErrorType WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUExternalTextureRotation { - WGPUExternalTextureRotation_Rotate0Degrees = 0x00000000, - WGPUExternalTextureRotation_Rotate90Degrees = 0x00000001, - WGPUExternalTextureRotation_Rotate180Degrees = 0x00000002, - WGPUExternalTextureRotation_Rotate270Degrees = 0x00000003, + WGPUExternalTextureRotation_Rotate0Degrees = 0x00000001, + WGPUExternalTextureRotation_Rotate90Degrees = 0x00000002, + WGPUExternalTextureRotation_Rotate180Degrees = 0x00000003, + WGPUExternalTextureRotation_Rotate270Degrees = 0x00000004, WGPUExternalTextureRotation_Force32 = 0x7FFFFFFF } WGPUExternalTextureRotation WGPU_ENUM_ATTRIBUTE; - +typedef enum WGPUFeatureLevel { + WGPUFeatureLevel_Undefined = 0x00000000, + WGPUFeatureLevel_Compatibility = 0x00000001, + WGPUFeatureLevel_Core = 0x00000002, + WGPUFeatureLevel_Force32 = 0x7FFFFFFF +} WGPUFeatureLevel WGPU_ENUM_ATTRIBUTE; typedef enum WGPUFeatureName { - WGPUFeatureName_Undefined = 0x00000000, WGPUFeatureName_DepthClipControl = 0x00000001, WGPUFeatureName_Depth32FloatStencil8 = 0x00000002, WGPUFeatureName_TimestampQuery = 0x00000003, @@ -504,92 +484,93 @@ typedef enum WGPUFeatureName { WGPUFeatureName_RG11B10UfloatRenderable = 0x00000009, WGPUFeatureName_BGRA8UnormStorage = 0x0000000A, WGPUFeatureName_Float32Filterable = 0x0000000B, - WGPUFeatureName_DawnInternalUsages = 0x000003EA, - WGPUFeatureName_DawnMultiPlanarFormats = 0x000003EB, - WGPUFeatureName_DawnNative = 0x000003EC, - WGPUFeatureName_ChromiumExperimentalTimestampQueryInsidePasses = 0x000003EE, - WGPUFeatureName_ImplicitDeviceSynchronization = 0x000003EF, - WGPUFeatureName_SurfaceCapabilities = 0x000003F0, - WGPUFeatureName_TransientAttachments = 0x000003F1, - WGPUFeatureName_MSAARenderToSingleSampled = 0x000003F2, - WGPUFeatureName_DualSourceBlending = 0x000003F3, - WGPUFeatureName_D3D11MultithreadProtected = 0x000003F4, - WGPUFeatureName_ANGLETextureSharing = 0x000003F5, - 
WGPUFeatureName_ChromiumExperimentalSubgroups = 0x000003F6, - WGPUFeatureName_ChromiumExperimentalSubgroupUniformControlFlow = 0x000003F7, - WGPUFeatureName_PixelLocalStorageCoherent = 0x000003F9, - WGPUFeatureName_PixelLocalStorageNonCoherent = 0x000003FA, - WGPUFeatureName_Unorm16TextureFormats = 0x000003FB, - WGPUFeatureName_Snorm16TextureFormats = 0x000003FC, - WGPUFeatureName_MultiPlanarFormatExtendedUsages = 0x000003FD, - WGPUFeatureName_MultiPlanarFormatP010 = 0x000003FE, - WGPUFeatureName_HostMappedPointer = 0x000003FF, - WGPUFeatureName_MultiPlanarRenderTargets = 0x00000400, - WGPUFeatureName_MultiPlanarFormatNv12a = 0x00000401, - WGPUFeatureName_FramebufferFetch = 0x00000402, - WGPUFeatureName_BufferMapExtendedUsages = 0x00000403, - WGPUFeatureName_AdapterPropertiesMemoryHeaps = 0x00000404, - WGPUFeatureName_AdapterPropertiesD3D = 0x00000405, - WGPUFeatureName_AdapterPropertiesVk = 0x00000406, - WGPUFeatureName_R8UnormStorage = 0x00000407, - WGPUFeatureName_FormatCapabilities = 0x00000408, - WGPUFeatureName_DrmFormatCapabilities = 0x00000409, - WGPUFeatureName_Norm16TextureFormats = 0x0000040A, - WGPUFeatureName_MultiPlanarFormatNv16 = 0x0000040B, - WGPUFeatureName_MultiPlanarFormatNv24 = 0x0000040C, - WGPUFeatureName_MultiPlanarFormatP210 = 0x0000040D, - WGPUFeatureName_MultiPlanarFormatP410 = 0x0000040E, - WGPUFeatureName_SharedTextureMemoryVkDedicatedAllocation = 0x0000044C, - WGPUFeatureName_SharedTextureMemoryAHardwareBuffer = 0x0000044D, - WGPUFeatureName_SharedTextureMemoryDmaBuf = 0x0000044E, - WGPUFeatureName_SharedTextureMemoryOpaqueFD = 0x0000044F, - WGPUFeatureName_SharedTextureMemoryZirconHandle = 0x00000450, - WGPUFeatureName_SharedTextureMemoryDXGISharedHandle = 0x00000451, - WGPUFeatureName_SharedTextureMemoryD3D11Texture2D = 0x00000452, - WGPUFeatureName_SharedTextureMemoryIOSurface = 0x00000453, - WGPUFeatureName_SharedTextureMemoryEGLImage = 0x00000454, - WGPUFeatureName_SharedFenceVkSemaphoreOpaqueFD = 0x000004B0, - WGPUFeatureName_SharedFenceVkSemaphoreSyncFD = 0x000004B1, - WGPUFeatureName_SharedFenceVkSemaphoreZirconHandle = 0x000004B2, - WGPUFeatureName_SharedFenceDXGISharedHandle = 0x000004B3, - WGPUFeatureName_SharedFenceMTLSharedEvent = 0x000004B4, - WGPUFeatureName_SharedBufferMemoryD3D12Resource = 0x000004B5, - WGPUFeatureName_StaticSamplers = 0x000004B6, - WGPUFeatureName_YCbCrVulkanSamplers = 0x000004B7, - WGPUFeatureName_ShaderModuleCompilationOptions = 0x000004B8, - WGPUFeatureName_DawnLoadResolveTexture = 0x000004B9, + WGPUFeatureName_Float32Blendable = 0x0000000C, + WGPUFeatureName_Subgroups = 0x0000000D, + WGPUFeatureName_SubgroupsF16 = 0x0000000E, + WGPUFeatureName_DawnInternalUsages = 0x00050000, + WGPUFeatureName_DawnMultiPlanarFormats = 0x00050001, + WGPUFeatureName_DawnNative = 0x00050002, + WGPUFeatureName_ChromiumExperimentalTimestampQueryInsidePasses = 0x00050003, + WGPUFeatureName_ImplicitDeviceSynchronization = 0x00050004, + WGPUFeatureName_ChromiumExperimentalImmediateData = 0x00050005, + WGPUFeatureName_TransientAttachments = 0x00050006, + WGPUFeatureName_MSAARenderToSingleSampled = 0x00050007, + WGPUFeatureName_DualSourceBlending = 0x00050008, + WGPUFeatureName_D3D11MultithreadProtected = 0x00050009, + WGPUFeatureName_ANGLETextureSharing = 0x0005000A, + WGPUFeatureName_PixelLocalStorageCoherent = 0x0005000B, + WGPUFeatureName_PixelLocalStorageNonCoherent = 0x0005000C, + WGPUFeatureName_Unorm16TextureFormats = 0x0005000D, + WGPUFeatureName_Snorm16TextureFormats = 0x0005000E, + WGPUFeatureName_MultiPlanarFormatExtendedUsages = 
0x0005000F, + WGPUFeatureName_MultiPlanarFormatP010 = 0x00050010, + WGPUFeatureName_HostMappedPointer = 0x00050011, + WGPUFeatureName_MultiPlanarRenderTargets = 0x00050012, + WGPUFeatureName_MultiPlanarFormatNv12a = 0x00050013, + WGPUFeatureName_FramebufferFetch = 0x00050014, + WGPUFeatureName_BufferMapExtendedUsages = 0x00050015, + WGPUFeatureName_AdapterPropertiesMemoryHeaps = 0x00050016, + WGPUFeatureName_AdapterPropertiesD3D = 0x00050017, + WGPUFeatureName_AdapterPropertiesVk = 0x00050018, + WGPUFeatureName_R8UnormStorage = 0x00050019, + WGPUFeatureName_DawnFormatCapabilities = 0x0005001A, + WGPUFeatureName_DawnDrmFormatCapabilities = 0x0005001B, + WGPUFeatureName_Norm16TextureFormats = 0x0005001C, + WGPUFeatureName_MultiPlanarFormatNv16 = 0x0005001D, + WGPUFeatureName_MultiPlanarFormatNv24 = 0x0005001E, + WGPUFeatureName_MultiPlanarFormatP210 = 0x0005001F, + WGPUFeatureName_MultiPlanarFormatP410 = 0x00050020, + WGPUFeatureName_SharedTextureMemoryVkDedicatedAllocation = 0x00050021, + WGPUFeatureName_SharedTextureMemoryAHardwareBuffer = 0x00050022, + WGPUFeatureName_SharedTextureMemoryDmaBuf = 0x00050023, + WGPUFeatureName_SharedTextureMemoryOpaqueFD = 0x00050024, + WGPUFeatureName_SharedTextureMemoryZirconHandle = 0x00050025, + WGPUFeatureName_SharedTextureMemoryDXGISharedHandle = 0x00050026, + WGPUFeatureName_SharedTextureMemoryD3D11Texture2D = 0x00050027, + WGPUFeatureName_SharedTextureMemoryIOSurface = 0x00050028, + WGPUFeatureName_SharedTextureMemoryEGLImage = 0x00050029, + WGPUFeatureName_SharedFenceVkSemaphoreOpaqueFD = 0x0005002A, + WGPUFeatureName_SharedFenceSyncFD = 0x0005002B, + WGPUFeatureName_SharedFenceVkSemaphoreZirconHandle = 0x0005002C, + WGPUFeatureName_SharedFenceDXGISharedHandle = 0x0005002D, + WGPUFeatureName_SharedFenceMTLSharedEvent = 0x0005002E, + WGPUFeatureName_SharedBufferMemoryD3D12Resource = 0x0005002F, + WGPUFeatureName_StaticSamplers = 0x00050030, + WGPUFeatureName_YCbCrVulkanSamplers = 0x00050031, + WGPUFeatureName_ShaderModuleCompilationOptions = 0x00050032, + WGPUFeatureName_DawnLoadResolveTexture = 0x00050033, + WGPUFeatureName_DawnPartialLoadResolveTexture = 0x00050034, + WGPUFeatureName_MultiDrawIndirect = 0x00050035, + WGPUFeatureName_ClipDistances = 0x00050036, + WGPUFeatureName_DawnTexelCopyBufferRowAlignment = 0x00050037, + WGPUFeatureName_FlexibleTextureViews = 0x00050038, WGPUFeatureName_Force32 = 0x7FFFFFFF } WGPUFeatureName WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUFilterMode { WGPUFilterMode_Undefined = 0x00000000, WGPUFilterMode_Nearest = 0x00000001, WGPUFilterMode_Linear = 0x00000002, WGPUFilterMode_Force32 = 0x7FFFFFFF } WGPUFilterMode WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUFrontFace { WGPUFrontFace_Undefined = 0x00000000, WGPUFrontFace_CCW = 0x00000001, WGPUFrontFace_CW = 0x00000002, WGPUFrontFace_Force32 = 0x7FFFFFFF } WGPUFrontFace WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUIndexFormat { WGPUIndexFormat_Undefined = 0x00000000, WGPUIndexFormat_Uint16 = 0x00000001, WGPUIndexFormat_Uint32 = 0x00000002, WGPUIndexFormat_Force32 = 0x7FFFFFFF } WGPUIndexFormat WGPU_ENUM_ATTRIBUTE; - typedef enum WGPULoadOp { WGPULoadOp_Undefined = 0x00000000, - WGPULoadOp_Clear = 0x00000001, - WGPULoadOp_Load = 0x00000002, - WGPULoadOp_ExpandResolveTexture = 0x00000003, + WGPULoadOp_Load = 0x00000001, + WGPULoadOp_Clear = 0x00000002, + WGPULoadOp_ExpandResolveTexture = 0x00050003, WGPULoadOp_Force32 = 0x7FFFFFFF } WGPULoadOp WGPU_ENUM_ATTRIBUTE; - typedef enum WGPULoggingType { WGPULoggingType_Verbose = 0x00000001, WGPULoggingType_Info = 0x00000002, @@ -597,44 
+578,44 @@ typedef enum WGPULoggingType { WGPULoggingType_Error = 0x00000004, WGPULoggingType_Force32 = 0x7FFFFFFF } WGPULoggingType WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUMapAsyncStatus { - WGPUMapAsyncStatus_Success = 0x00000000, - WGPUMapAsyncStatus_InstanceDropped = 0x00000001, - WGPUMapAsyncStatus_Error = 0x00000002, - WGPUMapAsyncStatus_Aborted = 0x00000003, - WGPUMapAsyncStatus_Unknown = 0x00000004, + WGPUMapAsyncStatus_Success = 0x00000001, + WGPUMapAsyncStatus_InstanceDropped = 0x00000002, + WGPUMapAsyncStatus_Error = 0x00000003, + WGPUMapAsyncStatus_Aborted = 0x00000004, WGPUMapAsyncStatus_Force32 = 0x7FFFFFFF } WGPUMapAsyncStatus WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUMipmapFilterMode { WGPUMipmapFilterMode_Undefined = 0x00000000, WGPUMipmapFilterMode_Nearest = 0x00000001, WGPUMipmapFilterMode_Linear = 0x00000002, WGPUMipmapFilterMode_Force32 = 0x7FFFFFFF } WGPUMipmapFilterMode WGPU_ENUM_ATTRIBUTE; - +typedef enum WGPUOptionalBool { + WGPUOptionalBool_False = 0x00000000, + WGPUOptionalBool_True = 0x00000001, + WGPUOptionalBool_Undefined = 0x00000002, + WGPUOptionalBool_Force32 = 0x7FFFFFFF +} WGPUOptionalBool WGPU_ENUM_ATTRIBUTE; typedef enum WGPUPopErrorScopeStatus { - WGPUPopErrorScopeStatus_Success = 0x00000000, - WGPUPopErrorScopeStatus_InstanceDropped = 0x00000001, + WGPUPopErrorScopeStatus_Success = 0x00000001, + WGPUPopErrorScopeStatus_InstanceDropped = 0x00000002, + WGPUPopErrorScopeStatus_EmptyStack = 0x00000003, WGPUPopErrorScopeStatus_Force32 = 0x7FFFFFFF } WGPUPopErrorScopeStatus WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUPowerPreference { WGPUPowerPreference_Undefined = 0x00000000, WGPUPowerPreference_LowPower = 0x00000001, WGPUPowerPreference_HighPerformance = 0x00000002, WGPUPowerPreference_Force32 = 0x7FFFFFFF } WGPUPowerPreference WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUPresentMode { - WGPUPresentMode_Fifo = 0x00000000, - WGPUPresentMode_FifoRelaxed = 0x00000001, - WGPUPresentMode_Immediate = 0x00000002, - WGPUPresentMode_Mailbox = 0x00000003, + WGPUPresentMode_Fifo = 0x00000001, + WGPUPresentMode_FifoRelaxed = 0x00000002, + WGPUPresentMode_Immediate = 0x00000003, + WGPUPresentMode_Mailbox = 0x00000004, WGPUPresentMode_Force32 = 0x7FFFFFFF } WGPUPresentMode WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUPrimitiveTopology { WGPUPrimitiveTopology_Undefined = 0x00000000, WGPUPrimitiveTopology_PointList = 0x00000001, @@ -644,138 +625,126 @@ typedef enum WGPUPrimitiveTopology { WGPUPrimitiveTopology_TriangleStrip = 0x00000005, WGPUPrimitiveTopology_Force32 = 0x7FFFFFFF } WGPUPrimitiveTopology WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUQueryType { WGPUQueryType_Occlusion = 0x00000001, WGPUQueryType_Timestamp = 0x00000002, WGPUQueryType_Force32 = 0x7FFFFFFF } WGPUQueryType WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUQueueWorkDoneStatus { - WGPUQueueWorkDoneStatus_Success = 0x00000000, - WGPUQueueWorkDoneStatus_InstanceDropped = 0x00000001, - WGPUQueueWorkDoneStatus_Error = 0x00000002, - WGPUQueueWorkDoneStatus_Unknown = 0x00000003, - WGPUQueueWorkDoneStatus_DeviceLost = 0x00000004, + WGPUQueueWorkDoneStatus_Success = 0x00000001, + WGPUQueueWorkDoneStatus_InstanceDropped = 0x00000002, + WGPUQueueWorkDoneStatus_Error = 0x00000003, WGPUQueueWorkDoneStatus_Force32 = 0x7FFFFFFF } WGPUQueueWorkDoneStatus WGPU_ENUM_ATTRIBUTE; - typedef enum WGPURequestAdapterStatus { - WGPURequestAdapterStatus_Success = 0x00000000, - WGPURequestAdapterStatus_InstanceDropped = 0x00000001, - WGPURequestAdapterStatus_Unavailable = 0x00000002, - WGPURequestAdapterStatus_Error = 0x00000003, - 
WGPURequestAdapterStatus_Unknown = 0x00000004, + WGPURequestAdapterStatus_Success = 0x00000001, + WGPURequestAdapterStatus_InstanceDropped = 0x00000002, + WGPURequestAdapterStatus_Unavailable = 0x00000003, + WGPURequestAdapterStatus_Error = 0x00000004, WGPURequestAdapterStatus_Force32 = 0x7FFFFFFF } WGPURequestAdapterStatus WGPU_ENUM_ATTRIBUTE; - typedef enum WGPURequestDeviceStatus { - WGPURequestDeviceStatus_Success = 0x00000000, - WGPURequestDeviceStatus_InstanceDropped = 0x00000001, - WGPURequestDeviceStatus_Error = 0x00000002, - WGPURequestDeviceStatus_Unknown = 0x00000003, + WGPURequestDeviceStatus_Success = 0x00000001, + WGPURequestDeviceStatus_InstanceDropped = 0x00000002, + WGPURequestDeviceStatus_Error = 0x00000003, WGPURequestDeviceStatus_Force32 = 0x7FFFFFFF } WGPURequestDeviceStatus WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUSType { - WGPUSType_Invalid = 0x00000000, - WGPUSType_SurfaceDescriptorFromMetalLayer = 0x00000001, - WGPUSType_SurfaceDescriptorFromWindowsHWND = 0x00000002, - WGPUSType_SurfaceDescriptorFromXlibWindow = 0x00000003, - WGPUSType_SurfaceDescriptorFromCanvasHTMLSelector = 0x00000004, - WGPUSType_ShaderModuleSPIRVDescriptor = 0x00000005, - WGPUSType_ShaderModuleWGSLDescriptor = 0x00000006, - WGPUSType_PrimitiveDepthClipControl = 0x00000007, - WGPUSType_SurfaceDescriptorFromWaylandSurface = 0x00000008, - WGPUSType_SurfaceDescriptorFromAndroidNativeWindow = 0x00000009, - WGPUSType_SurfaceDescriptorFromWindowsCoreWindow = 0x0000000B, - WGPUSType_ExternalTextureBindingEntry = 0x0000000C, - WGPUSType_ExternalTextureBindingLayout = 0x0000000D, - WGPUSType_SurfaceDescriptorFromWindowsSwapChainPanel = 0x0000000E, - WGPUSType_RenderPassDescriptorMaxDrawCount = 0x0000000F, - WGPUSType_DepthStencilStateDepthWriteDefinedDawn = 0x00000010, - WGPUSType_TextureBindingViewDimensionDescriptor = 0x00000011, - WGPUSType_DawnTextureInternalUsageDescriptor = 0x000003E8, - WGPUSType_DawnEncoderInternalUsageDescriptor = 0x000003EB, - WGPUSType_DawnInstanceDescriptor = 0x000003EC, - WGPUSType_DawnCacheDeviceDescriptor = 0x000003ED, - WGPUSType_DawnAdapterPropertiesPowerPreference = 0x000003EE, - WGPUSType_DawnBufferDescriptorErrorInfoFromWireClient = 0x000003EF, - WGPUSType_DawnTogglesDescriptor = 0x000003F0, - WGPUSType_DawnShaderModuleSPIRVOptionsDescriptor = 0x000003F1, - WGPUSType_RequestAdapterOptionsLUID = 0x000003F2, - WGPUSType_RequestAdapterOptionsGetGLProc = 0x000003F3, - WGPUSType_RequestAdapterOptionsD3D11Device = 0x000003F4, - WGPUSType_DawnRenderPassColorAttachmentRenderToSingleSampled = 0x000003F6, - WGPUSType_RenderPassPixelLocalStorage = 0x000003F7, - WGPUSType_PipelineLayoutPixelLocalStorage = 0x000003F8, - WGPUSType_BufferHostMappedPointer = 0x000003F9, - WGPUSType_DawnExperimentalSubgroupLimits = 0x000003FA, - WGPUSType_AdapterPropertiesMemoryHeaps = 0x000003FB, - WGPUSType_AdapterPropertiesD3D = 0x000003FC, - WGPUSType_AdapterPropertiesVk = 0x000003FD, - WGPUSType_DawnComputePipelineFullSubgroups = 0x000003FE, - WGPUSType_DawnWireWGSLControl = 0x000003FF, - WGPUSType_DawnWGSLBlocklist = 0x00000400, - WGPUSType_DrmFormatCapabilities = 0x00000401, - WGPUSType_ShaderModuleCompilationOptions = 0x00000402, - WGPUSType_ColorTargetStateExpandResolveTextureDawn = 0x00000403, - WGPUSType_SharedTextureMemoryVkDedicatedAllocationDescriptor = 0x0000044D, - WGPUSType_SharedTextureMemoryAHardwareBufferDescriptor = 0x0000044E, - WGPUSType_SharedTextureMemoryDmaBufDescriptor = 0x0000044F, - WGPUSType_SharedTextureMemoryOpaqueFDDescriptor = 0x00000450, - 
WGPUSType_SharedTextureMemoryZirconHandleDescriptor = 0x00000451, - WGPUSType_SharedTextureMemoryDXGISharedHandleDescriptor = 0x00000452, - WGPUSType_SharedTextureMemoryD3D11Texture2DDescriptor = 0x00000453, - WGPUSType_SharedTextureMemoryIOSurfaceDescriptor = 0x00000454, - WGPUSType_SharedTextureMemoryEGLImageDescriptor = 0x00000455, - WGPUSType_SharedTextureMemoryInitializedBeginState = 0x000004B0, - WGPUSType_SharedTextureMemoryInitializedEndState = 0x000004B1, - WGPUSType_SharedTextureMemoryVkImageLayoutBeginState = 0x000004B2, - WGPUSType_SharedTextureMemoryVkImageLayoutEndState = 0x000004B3, - WGPUSType_SharedTextureMemoryD3DSwapchainBeginState = 0x000004B4, - WGPUSType_SharedFenceVkSemaphoreOpaqueFDDescriptor = 0x000004B5, - WGPUSType_SharedFenceVkSemaphoreOpaqueFDExportInfo = 0x000004B6, - WGPUSType_SharedFenceVkSemaphoreSyncFDDescriptor = 0x000004B7, - WGPUSType_SharedFenceVkSemaphoreSyncFDExportInfo = 0x000004B8, - WGPUSType_SharedFenceVkSemaphoreZirconHandleDescriptor = 0x000004B9, - WGPUSType_SharedFenceVkSemaphoreZirconHandleExportInfo = 0x000004BA, - WGPUSType_SharedFenceDXGISharedHandleDescriptor = 0x000004BB, - WGPUSType_SharedFenceDXGISharedHandleExportInfo = 0x000004BC, - WGPUSType_SharedFenceMTLSharedEventDescriptor = 0x000004BD, - WGPUSType_SharedFenceMTLSharedEventExportInfo = 0x000004BE, - WGPUSType_SharedBufferMemoryD3D12ResourceDescriptor = 0x000004BF, - WGPUSType_StaticSamplerBindingLayout = 0x000004C0, - WGPUSType_YCbCrVkDescriptor = 0x000004C1, - WGPUSType_SharedTextureMemoryAHardwareBufferProperties = 0x000004C2, - WGPUSType_AHardwareBufferProperties = 0x000004C3, + WGPUSType_ShaderSourceSPIRV = 0x00000001, + WGPUSType_ShaderSourceWGSL = 0x00000002, + WGPUSType_RenderPassMaxDrawCount = 0x00000003, + WGPUSType_SurfaceSourceMetalLayer = 0x00000004, + WGPUSType_SurfaceSourceWindowsHWND = 0x00000005, + WGPUSType_SurfaceSourceXlibWindow = 0x00000006, + WGPUSType_SurfaceSourceWaylandSurface = 0x00000007, + WGPUSType_SurfaceSourceAndroidNativeWindow = 0x00000008, + WGPUSType_SurfaceSourceXCBWindow = 0x00000009, + WGPUSType_AdapterPropertiesSubgroups = 0x0000000A, + WGPUSType_TextureBindingViewDimensionDescriptor = 0x00020000, + WGPUSType_EmscriptenSurfaceSourceCanvasHTMLSelector = 0x00040000, + WGPUSType_SurfaceDescriptorFromWindowsCoreWindow = 0x00050000, + WGPUSType_ExternalTextureBindingEntry = 0x00050001, + WGPUSType_ExternalTextureBindingLayout = 0x00050002, + WGPUSType_SurfaceDescriptorFromWindowsSwapChainPanel = 0x00050003, + WGPUSType_DawnTextureInternalUsageDescriptor = 0x00050004, + WGPUSType_DawnEncoderInternalUsageDescriptor = 0x00050005, + WGPUSType_DawnInstanceDescriptor = 0x00050006, + WGPUSType_DawnCacheDeviceDescriptor = 0x00050007, + WGPUSType_DawnAdapterPropertiesPowerPreference = 0x00050008, + WGPUSType_DawnBufferDescriptorErrorInfoFromWireClient = 0x00050009, + WGPUSType_DawnTogglesDescriptor = 0x0005000A, + WGPUSType_DawnShaderModuleSPIRVOptionsDescriptor = 0x0005000B, + WGPUSType_RequestAdapterOptionsLUID = 0x0005000C, + WGPUSType_RequestAdapterOptionsGetGLProc = 0x0005000D, + WGPUSType_RequestAdapterOptionsD3D11Device = 0x0005000E, + WGPUSType_DawnRenderPassColorAttachmentRenderToSingleSampled = 0x0005000F, + WGPUSType_RenderPassPixelLocalStorage = 0x00050010, + WGPUSType_PipelineLayoutPixelLocalStorage = 0x00050011, + WGPUSType_BufferHostMappedPointer = 0x00050012, + WGPUSType_DawnExperimentalSubgroupLimits = 0x00050013, + WGPUSType_AdapterPropertiesMemoryHeaps = 0x00050014, + WGPUSType_AdapterPropertiesD3D = 0x00050015, + 
WGPUSType_AdapterPropertiesVk = 0x00050016, + WGPUSType_DawnWireWGSLControl = 0x00050017, + WGPUSType_DawnWGSLBlocklist = 0x00050018, + WGPUSType_DawnDrmFormatCapabilities = 0x00050019, + WGPUSType_ShaderModuleCompilationOptions = 0x0005001A, + WGPUSType_ColorTargetStateExpandResolveTextureDawn = 0x0005001B, + WGPUSType_RenderPassDescriptorExpandResolveRect = 0x0005001C, + WGPUSType_SharedTextureMemoryVkDedicatedAllocationDescriptor = 0x0005001D, + WGPUSType_SharedTextureMemoryAHardwareBufferDescriptor = 0x0005001E, + WGPUSType_SharedTextureMemoryDmaBufDescriptor = 0x0005001F, + WGPUSType_SharedTextureMemoryOpaqueFDDescriptor = 0x00050020, + WGPUSType_SharedTextureMemoryZirconHandleDescriptor = 0x00050021, + WGPUSType_SharedTextureMemoryDXGISharedHandleDescriptor = 0x00050022, + WGPUSType_SharedTextureMemoryD3D11Texture2DDescriptor = 0x00050023, + WGPUSType_SharedTextureMemoryIOSurfaceDescriptor = 0x00050024, + WGPUSType_SharedTextureMemoryEGLImageDescriptor = 0x00050025, + WGPUSType_SharedTextureMemoryInitializedBeginState = 0x00050026, + WGPUSType_SharedTextureMemoryInitializedEndState = 0x00050027, + WGPUSType_SharedTextureMemoryVkImageLayoutBeginState = 0x00050028, + WGPUSType_SharedTextureMemoryVkImageLayoutEndState = 0x00050029, + WGPUSType_SharedTextureMemoryD3DSwapchainBeginState = 0x0005002A, + WGPUSType_SharedFenceVkSemaphoreOpaqueFDDescriptor = 0x0005002B, + WGPUSType_SharedFenceVkSemaphoreOpaqueFDExportInfo = 0x0005002C, + WGPUSType_SharedFenceSyncFDDescriptor = 0x0005002D, + WGPUSType_SharedFenceSyncFDExportInfo = 0x0005002E, + WGPUSType_SharedFenceVkSemaphoreZirconHandleDescriptor = 0x0005002F, + WGPUSType_SharedFenceVkSemaphoreZirconHandleExportInfo = 0x00050030, + WGPUSType_SharedFenceDXGISharedHandleDescriptor = 0x00050031, + WGPUSType_SharedFenceDXGISharedHandleExportInfo = 0x00050032, + WGPUSType_SharedFenceMTLSharedEventDescriptor = 0x00050033, + WGPUSType_SharedFenceMTLSharedEventExportInfo = 0x00050034, + WGPUSType_SharedBufferMemoryD3D12ResourceDescriptor = 0x00050035, + WGPUSType_StaticSamplerBindingLayout = 0x00050036, + WGPUSType_YCbCrVkDescriptor = 0x00050037, + WGPUSType_SharedTextureMemoryAHardwareBufferProperties = 0x00050038, + WGPUSType_AHardwareBufferProperties = 0x00050039, + WGPUSType_DawnExperimentalImmediateDataLimits = 0x0005003A, + WGPUSType_DawnTexelCopyBufferRowAlignmentLimits = 0x0005003B, WGPUSType_Force32 = 0x7FFFFFFF } WGPUSType WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUSamplerBindingType { - WGPUSamplerBindingType_Undefined = 0x00000000, - WGPUSamplerBindingType_Filtering = 0x00000001, - WGPUSamplerBindingType_NonFiltering = 0x00000002, - WGPUSamplerBindingType_Comparison = 0x00000003, + WGPUSamplerBindingType_BindingNotUsed = 0x00000000, + WGPUSamplerBindingType_Undefined = 0x00000001, + WGPUSamplerBindingType_Filtering = 0x00000002, + WGPUSamplerBindingType_NonFiltering = 0x00000003, + WGPUSamplerBindingType_Comparison = 0x00000004, WGPUSamplerBindingType_Force32 = 0x7FFFFFFF } WGPUSamplerBindingType WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUSharedFenceType { - WGPUSharedFenceType_Undefined = 0x00000000, WGPUSharedFenceType_VkSemaphoreOpaqueFD = 0x00000001, - WGPUSharedFenceType_VkSemaphoreSyncFD = 0x00000002, + WGPUSharedFenceType_SyncFD = 0x00000002, WGPUSharedFenceType_VkSemaphoreZirconHandle = 0x00000003, WGPUSharedFenceType_DXGISharedHandle = 0x00000004, WGPUSharedFenceType_MTLSharedEvent = 0x00000005, WGPUSharedFenceType_Force32 = 0x7FFFFFFF } WGPUSharedFenceType WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUStatus { - WGPUStatus_Success = 
0x00000000, - WGPUStatus_Error = 0x00000001, + WGPUStatus_Success = 0x00000001, + WGPUStatus_Error = 0x00000002, WGPUStatus_Force32 = 0x7FFFFFFF } WGPUStatus WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUStencilOperation { WGPUStencilOperation_Undefined = 0x00000000, WGPUStencilOperation_Keep = 0x00000001, @@ -788,44 +757,40 @@ typedef enum WGPUStencilOperation { WGPUStencilOperation_DecrementWrap = 0x00000008, WGPUStencilOperation_Force32 = 0x7FFFFFFF } WGPUStencilOperation WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUStorageTextureAccess { - WGPUStorageTextureAccess_Undefined = 0x00000000, - WGPUStorageTextureAccess_WriteOnly = 0x00000001, - WGPUStorageTextureAccess_ReadOnly = 0x00000002, - WGPUStorageTextureAccess_ReadWrite = 0x00000003, + WGPUStorageTextureAccess_BindingNotUsed = 0x00000000, + WGPUStorageTextureAccess_Undefined = 0x00000001, + WGPUStorageTextureAccess_WriteOnly = 0x00000002, + WGPUStorageTextureAccess_ReadOnly = 0x00000003, + WGPUStorageTextureAccess_ReadWrite = 0x00000004, WGPUStorageTextureAccess_Force32 = 0x7FFFFFFF } WGPUStorageTextureAccess WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUStoreOp { WGPUStoreOp_Undefined = 0x00000000, WGPUStoreOp_Store = 0x00000001, WGPUStoreOp_Discard = 0x00000002, WGPUStoreOp_Force32 = 0x7FFFFFFF } WGPUStoreOp WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUSurfaceGetCurrentTextureStatus { - WGPUSurfaceGetCurrentTextureStatus_Success = 0x00000000, - WGPUSurfaceGetCurrentTextureStatus_Timeout = 0x00000001, - WGPUSurfaceGetCurrentTextureStatus_Outdated = 0x00000002, - WGPUSurfaceGetCurrentTextureStatus_Lost = 0x00000003, - WGPUSurfaceGetCurrentTextureStatus_OutOfMemory = 0x00000004, - WGPUSurfaceGetCurrentTextureStatus_DeviceLost = 0x00000005, - WGPUSurfaceGetCurrentTextureStatus_Error = 0x00000006, + WGPUSurfaceGetCurrentTextureStatus_Success = 0x00000001, + WGPUSurfaceGetCurrentTextureStatus_Timeout = 0x00000002, + WGPUSurfaceGetCurrentTextureStatus_Outdated = 0x00000003, + WGPUSurfaceGetCurrentTextureStatus_Lost = 0x00000004, + WGPUSurfaceGetCurrentTextureStatus_OutOfMemory = 0x00000005, + WGPUSurfaceGetCurrentTextureStatus_DeviceLost = 0x00000006, + WGPUSurfaceGetCurrentTextureStatus_Error = 0x00000007, WGPUSurfaceGetCurrentTextureStatus_Force32 = 0x7FFFFFFF } WGPUSurfaceGetCurrentTextureStatus WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUTextureAspect { WGPUTextureAspect_Undefined = 0x00000000, WGPUTextureAspect_All = 0x00000001, WGPUTextureAspect_StencilOnly = 0x00000002, WGPUTextureAspect_DepthOnly = 0x00000003, - WGPUTextureAspect_Plane0Only = 0x00000004, - WGPUTextureAspect_Plane1Only = 0x00000005, - WGPUTextureAspect_Plane2Only = 0x00000006, + WGPUTextureAspect_Plane0Only = 0x00050000, + WGPUTextureAspect_Plane1Only = 0x00050001, + WGPUTextureAspect_Plane2Only = 0x00050002, WGPUTextureAspect_Force32 = 0x7FFFFFFF } WGPUTextureAspect WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUTextureDimension { WGPUTextureDimension_Undefined = 0x00000000, WGPUTextureDimension_1D = 0x00000001, @@ -833,7 +798,6 @@ typedef enum WGPUTextureDimension { WGPUTextureDimension_3D = 0x00000003, WGPUTextureDimension_Force32 = 0x7FFFFFFF } WGPUTextureDimension WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUTextureFormat { WGPUTextureFormat_Undefined = 0x00000000, WGPUTextureFormat_R8Unorm = 0x00000001, @@ -931,33 +895,32 @@ typedef enum WGPUTextureFormat { WGPUTextureFormat_ASTC12x10UnormSrgb = 0x0000005D, WGPUTextureFormat_ASTC12x12Unorm = 0x0000005E, WGPUTextureFormat_ASTC12x12UnormSrgb = 0x0000005F, - WGPUTextureFormat_R16Unorm = 0x00000060, - WGPUTextureFormat_RG16Unorm = 0x00000061, - 
WGPUTextureFormat_RGBA16Unorm = 0x00000062, - WGPUTextureFormat_R16Snorm = 0x00000063, - WGPUTextureFormat_RG16Snorm = 0x00000064, - WGPUTextureFormat_RGBA16Snorm = 0x00000065, - WGPUTextureFormat_R8BG8Biplanar420Unorm = 0x00000066, - WGPUTextureFormat_R10X6BG10X6Biplanar420Unorm = 0x00000067, - WGPUTextureFormat_R8BG8A8Triplanar420Unorm = 0x00000068, - WGPUTextureFormat_R8BG8Biplanar422Unorm = 0x00000069, - WGPUTextureFormat_R8BG8Biplanar444Unorm = 0x0000006A, - WGPUTextureFormat_R10X6BG10X6Biplanar422Unorm = 0x0000006B, - WGPUTextureFormat_R10X6BG10X6Biplanar444Unorm = 0x0000006C, - WGPUTextureFormat_External = 0x0000006D, + WGPUTextureFormat_R16Unorm = 0x00050000, + WGPUTextureFormat_RG16Unorm = 0x00050001, + WGPUTextureFormat_RGBA16Unorm = 0x00050002, + WGPUTextureFormat_R16Snorm = 0x00050003, + WGPUTextureFormat_RG16Snorm = 0x00050004, + WGPUTextureFormat_RGBA16Snorm = 0x00050005, + WGPUTextureFormat_R8BG8Biplanar420Unorm = 0x00050006, + WGPUTextureFormat_R10X6BG10X6Biplanar420Unorm = 0x00050007, + WGPUTextureFormat_R8BG8A8Triplanar420Unorm = 0x00050008, + WGPUTextureFormat_R8BG8Biplanar422Unorm = 0x00050009, + WGPUTextureFormat_R8BG8Biplanar444Unorm = 0x0005000A, + WGPUTextureFormat_R10X6BG10X6Biplanar422Unorm = 0x0005000B, + WGPUTextureFormat_R10X6BG10X6Biplanar444Unorm = 0x0005000C, + WGPUTextureFormat_External = 0x0005000D, WGPUTextureFormat_Force32 = 0x7FFFFFFF } WGPUTextureFormat WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUTextureSampleType { - WGPUTextureSampleType_Undefined = 0x00000000, - WGPUTextureSampleType_Float = 0x00000001, - WGPUTextureSampleType_UnfilterableFloat = 0x00000002, - WGPUTextureSampleType_Depth = 0x00000003, - WGPUTextureSampleType_Sint = 0x00000004, - WGPUTextureSampleType_Uint = 0x00000005, + WGPUTextureSampleType_BindingNotUsed = 0x00000000, + WGPUTextureSampleType_Undefined = 0x00000001, + WGPUTextureSampleType_Float = 0x00000002, + WGPUTextureSampleType_UnfilterableFloat = 0x00000003, + WGPUTextureSampleType_Depth = 0x00000004, + WGPUTextureSampleType_Sint = 0x00000005, + WGPUTextureSampleType_Uint = 0x00000006, WGPUTextureSampleType_Force32 = 0x7FFFFFFF } WGPUTextureSampleType WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUTextureViewDimension { WGPUTextureViewDimension_Undefined = 0x00000000, WGPUTextureViewDimension_1D = 0x00000001, @@ -968,304 +931,395 @@ typedef enum WGPUTextureViewDimension { WGPUTextureViewDimension_3D = 0x00000006, WGPUTextureViewDimension_Force32 = 0x7FFFFFFF } WGPUTextureViewDimension WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUVertexFormat { - WGPUVertexFormat_Undefined = 0x00000000, - WGPUVertexFormat_Uint8x2 = 0x00000001, - WGPUVertexFormat_Uint8x4 = 0x00000002, - WGPUVertexFormat_Sint8x2 = 0x00000003, - WGPUVertexFormat_Sint8x4 = 0x00000004, - WGPUVertexFormat_Unorm8x2 = 0x00000005, - WGPUVertexFormat_Unorm8x4 = 0x00000006, - WGPUVertexFormat_Snorm8x2 = 0x00000007, - WGPUVertexFormat_Snorm8x4 = 0x00000008, - WGPUVertexFormat_Uint16x2 = 0x00000009, - WGPUVertexFormat_Uint16x4 = 0x0000000A, - WGPUVertexFormat_Sint16x2 = 0x0000000B, - WGPUVertexFormat_Sint16x4 = 0x0000000C, - WGPUVertexFormat_Unorm16x2 = 0x0000000D, - WGPUVertexFormat_Unorm16x4 = 0x0000000E, - WGPUVertexFormat_Snorm16x2 = 0x0000000F, - WGPUVertexFormat_Snorm16x4 = 0x00000010, - WGPUVertexFormat_Float16x2 = 0x00000011, - WGPUVertexFormat_Float16x4 = 0x00000012, - WGPUVertexFormat_Float32 = 0x00000013, - WGPUVertexFormat_Float32x2 = 0x00000014, - WGPUVertexFormat_Float32x3 = 0x00000015, - WGPUVertexFormat_Float32x4 = 0x00000016, - WGPUVertexFormat_Uint32 = 0x00000017, - 
WGPUVertexFormat_Uint32x2 = 0x00000018, - WGPUVertexFormat_Uint32x3 = 0x00000019, - WGPUVertexFormat_Uint32x4 = 0x0000001A, - WGPUVertexFormat_Sint32 = 0x0000001B, - WGPUVertexFormat_Sint32x2 = 0x0000001C, - WGPUVertexFormat_Sint32x3 = 0x0000001D, - WGPUVertexFormat_Sint32x4 = 0x0000001E, - WGPUVertexFormat_Unorm10_10_10_2 = 0x0000001F, + WGPUVertexFormat_Uint8 = 0x00000001, + WGPUVertexFormat_Uint8x2 = 0x00000002, + WGPUVertexFormat_Uint8x4 = 0x00000003, + WGPUVertexFormat_Sint8 = 0x00000004, + WGPUVertexFormat_Sint8x2 = 0x00000005, + WGPUVertexFormat_Sint8x4 = 0x00000006, + WGPUVertexFormat_Unorm8 = 0x00000007, + WGPUVertexFormat_Unorm8x2 = 0x00000008, + WGPUVertexFormat_Unorm8x4 = 0x00000009, + WGPUVertexFormat_Snorm8 = 0x0000000A, + WGPUVertexFormat_Snorm8x2 = 0x0000000B, + WGPUVertexFormat_Snorm8x4 = 0x0000000C, + WGPUVertexFormat_Uint16 = 0x0000000D, + WGPUVertexFormat_Uint16x2 = 0x0000000E, + WGPUVertexFormat_Uint16x4 = 0x0000000F, + WGPUVertexFormat_Sint16 = 0x00000010, + WGPUVertexFormat_Sint16x2 = 0x00000011, + WGPUVertexFormat_Sint16x4 = 0x00000012, + WGPUVertexFormat_Unorm16 = 0x00000013, + WGPUVertexFormat_Unorm16x2 = 0x00000014, + WGPUVertexFormat_Unorm16x4 = 0x00000015, + WGPUVertexFormat_Snorm16 = 0x00000016, + WGPUVertexFormat_Snorm16x2 = 0x00000017, + WGPUVertexFormat_Snorm16x4 = 0x00000018, + WGPUVertexFormat_Float16 = 0x00000019, + WGPUVertexFormat_Float16x2 = 0x0000001A, + WGPUVertexFormat_Float16x4 = 0x0000001B, + WGPUVertexFormat_Float32 = 0x0000001C, + WGPUVertexFormat_Float32x2 = 0x0000001D, + WGPUVertexFormat_Float32x3 = 0x0000001E, + WGPUVertexFormat_Float32x4 = 0x0000001F, + WGPUVertexFormat_Uint32 = 0x00000020, + WGPUVertexFormat_Uint32x2 = 0x00000021, + WGPUVertexFormat_Uint32x3 = 0x00000022, + WGPUVertexFormat_Uint32x4 = 0x00000023, + WGPUVertexFormat_Sint32 = 0x00000024, + WGPUVertexFormat_Sint32x2 = 0x00000025, + WGPUVertexFormat_Sint32x3 = 0x00000026, + WGPUVertexFormat_Sint32x4 = 0x00000027, + WGPUVertexFormat_Unorm10_10_10_2 = 0x00000028, + WGPUVertexFormat_Unorm8x4BGRA = 0x00000029, WGPUVertexFormat_Force32 = 0x7FFFFFFF } WGPUVertexFormat WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUVertexStepMode { WGPUVertexStepMode_Undefined = 0x00000000, - WGPUVertexStepMode_VertexBufferNotUsed = 0x00000001, - WGPUVertexStepMode_Vertex = 0x00000002, - WGPUVertexStepMode_Instance = 0x00000003, + WGPUVertexStepMode_Vertex = 0x00000001, + WGPUVertexStepMode_Instance = 0x00000002, WGPUVertexStepMode_Force32 = 0x7FFFFFFF } WGPUVertexStepMode WGPU_ENUM_ATTRIBUTE; - typedef enum WGPUWaitStatus { - WGPUWaitStatus_Success = 0x00000000, - WGPUWaitStatus_TimedOut = 0x00000001, - WGPUWaitStatus_UnsupportedTimeout = 0x00000002, - WGPUWaitStatus_UnsupportedCount = 0x00000003, - WGPUWaitStatus_UnsupportedMixedSources = 0x00000004, - WGPUWaitStatus_Unknown = 0x00000005, + WGPUWaitStatus_Success = 0x00000001, + WGPUWaitStatus_TimedOut = 0x00000002, + WGPUWaitStatus_Error = 0x00000003, WGPUWaitStatus_Force32 = 0x7FFFFFFF } WGPUWaitStatus WGPU_ENUM_ATTRIBUTE; -typedef enum WGPUBufferUsage { - WGPUBufferUsage_None = 0x00000000, - WGPUBufferUsage_MapRead = 0x00000001, - WGPUBufferUsage_MapWrite = 0x00000002, - WGPUBufferUsage_CopySrc = 0x00000004, - WGPUBufferUsage_CopyDst = 0x00000008, - WGPUBufferUsage_Index = 0x00000010, - WGPUBufferUsage_Vertex = 0x00000020, - WGPUBufferUsage_Uniform = 0x00000040, - WGPUBufferUsage_Storage = 0x00000080, - WGPUBufferUsage_Indirect = 0x00000100, - WGPUBufferUsage_QueryResolve = 0x00000200, - WGPUBufferUsage_Force32 = 0x7FFFFFFF -} WGPUBufferUsage 
WGPU_ENUM_ATTRIBUTE; -typedef WGPUFlags WGPUBufferUsageFlags WGPU_ENUM_ATTRIBUTE; - -typedef enum WGPUColorWriteMask { - WGPUColorWriteMask_None = 0x00000000, - WGPUColorWriteMask_Red = 0x00000001, - WGPUColorWriteMask_Green = 0x00000002, - WGPUColorWriteMask_Blue = 0x00000004, - WGPUColorWriteMask_Alpha = 0x00000008, - WGPUColorWriteMask_All = 0x0000000F, - WGPUColorWriteMask_Force32 = 0x7FFFFFFF -} WGPUColorWriteMask WGPU_ENUM_ATTRIBUTE; -typedef WGPUFlags WGPUColorWriteMaskFlags WGPU_ENUM_ATTRIBUTE; - -typedef enum WGPUHeapProperty { - WGPUHeapProperty_Undefined = 0x00000000, - WGPUHeapProperty_DeviceLocal = 0x00000001, - WGPUHeapProperty_HostVisible = 0x00000002, - WGPUHeapProperty_HostCoherent = 0x00000004, - WGPUHeapProperty_HostUncached = 0x00000008, - WGPUHeapProperty_HostCached = 0x00000010, - WGPUHeapProperty_Force32 = 0x7FFFFFFF -} WGPUHeapProperty WGPU_ENUM_ATTRIBUTE; -typedef WGPUFlags WGPUHeapPropertyFlags WGPU_ENUM_ATTRIBUTE; - -typedef enum WGPUMapMode { - WGPUMapMode_None = 0x00000000, - WGPUMapMode_Read = 0x00000001, - WGPUMapMode_Write = 0x00000002, - WGPUMapMode_Force32 = 0x7FFFFFFF -} WGPUMapMode WGPU_ENUM_ATTRIBUTE; -typedef WGPUFlags WGPUMapModeFlags WGPU_ENUM_ATTRIBUTE; - -typedef enum WGPUShaderStage { - WGPUShaderStage_None = 0x00000000, - WGPUShaderStage_Vertex = 0x00000001, - WGPUShaderStage_Fragment = 0x00000002, - WGPUShaderStage_Compute = 0x00000004, - WGPUShaderStage_Force32 = 0x7FFFFFFF -} WGPUShaderStage WGPU_ENUM_ATTRIBUTE; -typedef WGPUFlags WGPUShaderStageFlags WGPU_ENUM_ATTRIBUTE; - -typedef enum WGPUTextureUsage { - WGPUTextureUsage_None = 0x00000000, - WGPUTextureUsage_CopySrc = 0x00000001, - WGPUTextureUsage_CopyDst = 0x00000002, - WGPUTextureUsage_TextureBinding = 0x00000004, - WGPUTextureUsage_StorageBinding = 0x00000008, - WGPUTextureUsage_RenderAttachment = 0x00000010, - WGPUTextureUsage_TransientAttachment = 0x00000020, - WGPUTextureUsage_StorageAttachment = 0x00000040, - WGPUTextureUsage_Force32 = 0x7FFFFFFF -} WGPUTextureUsage WGPU_ENUM_ATTRIBUTE; -typedef WGPUFlags WGPUTextureUsageFlags WGPU_ENUM_ATTRIBUTE; - -typedef void (*WGPUBufferMapCallback)(WGPUBufferMapAsyncStatus status, void * userdata) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUFlags WGPUBufferUsage; +static const WGPUBufferUsage WGPUBufferUsage_None = 0x0000000000000000; +static const WGPUBufferUsage WGPUBufferUsage_MapRead = 0x0000000000000001; +static const WGPUBufferUsage WGPUBufferUsage_MapWrite = 0x0000000000000002; +static const WGPUBufferUsage WGPUBufferUsage_CopySrc = 0x0000000000000004; +static const WGPUBufferUsage WGPUBufferUsage_CopyDst = 0x0000000000000008; +static const WGPUBufferUsage WGPUBufferUsage_Index = 0x0000000000000010; +static const WGPUBufferUsage WGPUBufferUsage_Vertex = 0x0000000000000020; +static const WGPUBufferUsage WGPUBufferUsage_Uniform = 0x0000000000000040; +static const WGPUBufferUsage WGPUBufferUsage_Storage = 0x0000000000000080; +static const WGPUBufferUsage WGPUBufferUsage_Indirect = 0x0000000000000100; +static const WGPUBufferUsage WGPUBufferUsage_QueryResolve = 0x0000000000000200; +typedef WGPUFlags WGPUColorWriteMask; +static const WGPUColorWriteMask WGPUColorWriteMask_None = 0x0000000000000000; +static const WGPUColorWriteMask WGPUColorWriteMask_Red = 0x0000000000000001; +static const WGPUColorWriteMask WGPUColorWriteMask_Green = 0x0000000000000002; +static const WGPUColorWriteMask WGPUColorWriteMask_Blue = 0x0000000000000004; +static const WGPUColorWriteMask WGPUColorWriteMask_Alpha = 0x0000000000000008; +static const WGPUColorWriteMask 
WGPUColorWriteMask_All = 0x000000000000000F; +typedef WGPUFlags WGPUHeapProperty; +static const WGPUHeapProperty WGPUHeapProperty_DeviceLocal = 0x0000000000000001; +static const WGPUHeapProperty WGPUHeapProperty_HostVisible = 0x0000000000000002; +static const WGPUHeapProperty WGPUHeapProperty_HostCoherent = 0x0000000000000004; +static const WGPUHeapProperty WGPUHeapProperty_HostUncached = 0x0000000000000008; +static const WGPUHeapProperty WGPUHeapProperty_HostCached = 0x0000000000000010; +typedef WGPUFlags WGPUMapMode; +static const WGPUMapMode WGPUMapMode_None = 0x0000000000000000; +static const WGPUMapMode WGPUMapMode_Read = 0x0000000000000001; +static const WGPUMapMode WGPUMapMode_Write = 0x0000000000000002; +typedef WGPUFlags WGPUShaderStage; +static const WGPUShaderStage WGPUShaderStage_None = 0x0000000000000000; +static const WGPUShaderStage WGPUShaderStage_Vertex = 0x0000000000000001; +static const WGPUShaderStage WGPUShaderStage_Fragment = 0x0000000000000002; +static const WGPUShaderStage WGPUShaderStage_Compute = 0x0000000000000004; +typedef WGPUFlags WGPUTextureUsage; +static const WGPUTextureUsage WGPUTextureUsage_None = 0x0000000000000000; +static const WGPUTextureUsage WGPUTextureUsage_CopySrc = 0x0000000000000001; +static const WGPUTextureUsage WGPUTextureUsage_CopyDst = 0x0000000000000002; +static const WGPUTextureUsage WGPUTextureUsage_TextureBinding = 0x0000000000000004; +static const WGPUTextureUsage WGPUTextureUsage_StorageBinding = 0x0000000000000008; +static const WGPUTextureUsage WGPUTextureUsage_RenderAttachment = 0x0000000000000010; +static const WGPUTextureUsage WGPUTextureUsage_TransientAttachment = 0x0000000000000020; +static const WGPUTextureUsage WGPUTextureUsage_StorageAttachment = 0x0000000000000040; typedef void (*WGPUCallback)(void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUCompilationInfoCallback)(WGPUCompilationInfoRequestStatus status, struct WGPUCompilationInfo const * compilationInfo, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUCreateComputePipelineAsyncCallback)(WGPUCreatePipelineAsyncStatus status, WGPUComputePipeline pipeline, char const * message, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUCreateRenderPipelineAsyncCallback)(WGPUCreatePipelineAsyncStatus status, WGPURenderPipeline pipeline, char const * message, void * userdata) WGPU_FUNCTION_ATTRIBUTE; typedef size_t (*WGPUDawnLoadCacheDataFunction)(void const * key, size_t keySize, void * value, size_t valueSize, void * userdata) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUDawnStoreCacheDataFunction)(void const * key, size_t keySize, void const * value, size_t valueSize, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUDeviceLostCallback)(WGPUDeviceLostReason reason, char const * message, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUDeviceLostCallbackNew)(WGPUDevice const * device, WGPUDeviceLostReason reason, char const * message, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUErrorCallback)(WGPUErrorType type, char const * message, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPULoggingCallback)(WGPULoggingType type, char const * message, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUPopErrorScopeCallback)(WGPUPopErrorScopeStatus status, WGPUErrorType type, char const * message, void * userdata) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProc)(void) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUQueueWorkDoneCallback)(WGPUQueueWorkDoneStatus status, void * userdata) 
WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPURequestAdapterCallback)(WGPURequestAdapterStatus status, WGPUAdapter adapter, char const * message, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPURequestDeviceCallback)(WGPURequestDeviceStatus status, WGPUDevice device, char const * message, void * userdata) WGPU_FUNCTION_ATTRIBUTE; // Callback function pointers -typedef void (*WGPUBufferMapCallback2)(WGPUMapAsyncStatus status, char const * message, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUCompilationInfoCallback2)(WGPUCompilationInfoRequestStatus status, struct WGPUCompilationInfo const * compilationInfo, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUCreateComputePipelineAsyncCallback2)(WGPUCreatePipelineAsyncStatus status, WGPUComputePipeline pipeline, char const * message, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUCreateRenderPipelineAsyncCallback2)(WGPUCreatePipelineAsyncStatus status, WGPURenderPipeline pipeline, char const * message, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUPopErrorScopeCallback2)(WGPUPopErrorScopeStatus status, WGPUErrorType type, char const * message, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUQueueWorkDoneCallback2)(WGPUQueueWorkDoneStatus status, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPURequestAdapterCallback2)(WGPURequestAdapterStatus status, WGPUAdapter adapter, char const * message, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPURequestDeviceCallback2)(WGPURequestDeviceStatus status, WGPUDevice device, char const * message, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUBufferMapCallback)(WGPUMapAsyncStatus status, struct WGPUStringView message, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUCompilationInfoCallback)(WGPUCompilationInfoRequestStatus status, struct WGPUCompilationInfo const * compilationInfo, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUCreateComputePipelineAsyncCallback)(WGPUCreatePipelineAsyncStatus status, WGPUComputePipeline pipeline, struct WGPUStringView message, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUCreateRenderPipelineAsyncCallback)(WGPUCreatePipelineAsyncStatus status, WGPURenderPipeline pipeline, struct WGPUStringView message, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUDeviceLostCallback)(WGPUDevice const * device, WGPUDeviceLostReason reason, struct WGPUStringView message, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPULoggingCallback)(WGPULoggingType type, struct WGPUStringView message, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUPopErrorScopeCallback)(WGPUPopErrorScopeStatus status, WGPUErrorType type, struct WGPUStringView message, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUQueueWorkDoneCallback)(WGPUQueueWorkDoneStatus status, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPURequestAdapterCallback)(WGPURequestAdapterStatus status, WGPUAdapter adapter, struct WGPUStringView message, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPURequestDeviceCallback)(WGPURequestDeviceStatus status, WGPUDevice device, struct WGPUStringView message, void* 
userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUUncapturedErrorCallback)(WGPUDevice const * device, WGPUErrorType type, struct WGPUStringView message, void* userdata1, void* userdata2) WGPU_FUNCTION_ATTRIBUTE; typedef struct WGPUChainedStruct { - struct WGPUChainedStruct const * next; + struct WGPUChainedStruct * next; WGPUSType sType; } WGPUChainedStruct WGPU_STRUCTURE_ATTRIBUTE; -typedef struct WGPUChainedStructOut { - struct WGPUChainedStructOut * next; - WGPUSType sType; -} WGPUChainedStructOut WGPU_STRUCTURE_ATTRIBUTE; - #define WGPU_COMMA , -typedef struct WGPUAdapterInfo { - WGPUChainedStructOut * nextInChain; - char const * vendor; - char const * architecture; - char const * device; - char const * description; - WGPUBackendType backendType; - WGPUAdapterType adapterType; - uint32_t vendorID; - uint32_t deviceID; -} WGPUAdapterInfo WGPU_STRUCTURE_ATTRIBUTE; +typedef struct WGPUBufferMapCallbackInfo { + WGPUChainedStruct* nextInChain; + WGPUCallbackMode mode; + WGPUBufferMapCallback callback; + void* userdata1; + void* userdata2; +} WGPUBufferMapCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_ADAPTER_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUAdapterInfo, { \ +#define WGPU_BUFFER_MAP_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUBufferMapCallbackInfo, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.vendor=*/nullptr WGPU_COMMA \ - /*.architecture=*/nullptr WGPU_COMMA \ - /*.device=*/nullptr WGPU_COMMA \ - /*.description=*/nullptr WGPU_COMMA \ - /*.backendType=*/{} WGPU_COMMA \ - /*.adapterType=*/{} WGPU_COMMA \ - /*.vendorID=*/{} WGPU_COMMA \ - /*.deviceID=*/{} WGPU_COMMA \ + /*.mode=*/{} WGPU_COMMA \ + /*.callback=*/nullptr WGPU_COMMA \ + /*.userdata1=*/nullptr WGPU_COMMA \ + /*.userdata2=*/nullptr WGPU_COMMA \ }) -typedef struct WGPUAdapterProperties { - WGPUChainedStructOut * nextInChain; - uint32_t vendorID; - char const * vendorName; - char const * architecture; - uint32_t deviceID; - char const * name; - char const * driverDescription; - WGPUAdapterType adapterType; - WGPUBackendType backendType; - WGPUBool compatibilityMode; -} WGPUAdapterProperties WGPU_STRUCTURE_ATTRIBUTE; +typedef struct WGPUCompilationInfoCallbackInfo { + WGPUChainedStruct* nextInChain; + WGPUCallbackMode mode; + WGPUCompilationInfoCallback callback; + void* userdata1; + void* userdata2; +} WGPUCompilationInfoCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_ADAPTER_PROPERTIES_INIT WGPU_MAKE_INIT_STRUCT(WGPUAdapterProperties, { \ +#define WGPU_COMPILATION_INFO_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUCompilationInfoCallbackInfo, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.vendorID=*/{} WGPU_COMMA \ - /*.vendorName=*/nullptr WGPU_COMMA \ - /*.architecture=*/nullptr WGPU_COMMA \ - /*.deviceID=*/{} WGPU_COMMA \ - /*.name=*/nullptr WGPU_COMMA \ - /*.driverDescription=*/nullptr WGPU_COMMA \ - /*.adapterType=*/{} WGPU_COMMA \ - /*.backendType=*/{} WGPU_COMMA \ - /*.compatibilityMode=*/false WGPU_COMMA \ + /*.mode=*/{} WGPU_COMMA \ + /*.callback=*/nullptr WGPU_COMMA \ + /*.userdata1=*/nullptr WGPU_COMMA \ + /*.userdata2=*/nullptr WGPU_COMMA \ }) -// Can be chained in WGPUAdapterProperties -typedef struct WGPUAdapterPropertiesD3D { - WGPUChainedStructOut chain; - uint32_t shaderModel; -} WGPUAdapterPropertiesD3D WGPU_STRUCTURE_ATTRIBUTE; +typedef struct WGPUCreateComputePipelineAsyncCallbackInfo { + WGPUChainedStruct* nextInChain; + WGPUCallbackMode mode; + WGPUCreateComputePipelineAsyncCallback callback; + void* userdata1; + void* userdata2; +} 
WGPUCreateComputePipelineAsyncCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_ADAPTER_PROPERTIES_D3D_INIT WGPU_MAKE_INIT_STRUCT(WGPUAdapterPropertiesD3D, { \ - /*.chain=*/{} WGPU_COMMA \ - /*.shaderModel=*/{} WGPU_COMMA \ +#define WGPU_CREATE_COMPUTE_PIPELINE_ASYNC_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUCreateComputePipelineAsyncCallbackInfo, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.mode=*/{} WGPU_COMMA \ + /*.callback=*/nullptr WGPU_COMMA \ + /*.userdata1=*/nullptr WGPU_COMMA \ + /*.userdata2=*/nullptr WGPU_COMMA \ }) -// Can be chained in WGPUAdapterProperties -typedef struct WGPUAdapterPropertiesVk { - WGPUChainedStructOut chain; - uint32_t driverVersion; -} WGPUAdapterPropertiesVk WGPU_STRUCTURE_ATTRIBUTE; +typedef struct WGPUCreateRenderPipelineAsyncCallbackInfo { + WGPUChainedStruct* nextInChain; + WGPUCallbackMode mode; + WGPUCreateRenderPipelineAsyncCallback callback; + void* userdata1; + void* userdata2; +} WGPUCreateRenderPipelineAsyncCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_ADAPTER_PROPERTIES_VK_INIT WGPU_MAKE_INIT_STRUCT(WGPUAdapterPropertiesVk, { \ - /*.chain=*/{} WGPU_COMMA \ - /*.driverVersion=*/{} WGPU_COMMA \ +#define WGPU_CREATE_RENDER_PIPELINE_ASYNC_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUCreateRenderPipelineAsyncCallbackInfo, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.mode=*/{} WGPU_COMMA \ + /*.callback=*/nullptr WGPU_COMMA \ + /*.userdata1=*/nullptr WGPU_COMMA \ + /*.userdata2=*/nullptr WGPU_COMMA \ }) -typedef struct WGPUBindGroupEntry { - WGPUChainedStruct const * nextInChain; - uint32_t binding; - WGPU_NULLABLE WGPUBuffer buffer; - uint64_t offset; - uint64_t size; - WGPU_NULLABLE WGPUSampler sampler; - WGPU_NULLABLE WGPUTextureView textureView; -} WGPUBindGroupEntry WGPU_STRUCTURE_ATTRIBUTE; +typedef struct WGPUDeviceLostCallbackInfo { + WGPUChainedStruct* nextInChain; + WGPUCallbackMode mode; + WGPUDeviceLostCallback callback; + void* userdata1; + void* userdata2; +} WGPUDeviceLostCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_BIND_GROUP_ENTRY_INIT WGPU_MAKE_INIT_STRUCT(WGPUBindGroupEntry, { \ +#define WGPU_DEVICE_LOST_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUDeviceLostCallbackInfo, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.binding=*/{} WGPU_COMMA \ - /*.buffer=*/nullptr WGPU_COMMA \ - /*.offset=*/0 WGPU_COMMA \ - /*.size=*/WGPU_WHOLE_SIZE WGPU_COMMA \ - /*.sampler=*/nullptr WGPU_COMMA \ - /*.textureView=*/nullptr WGPU_COMMA \ + /*.mode=*/{} WGPU_COMMA \ + /*.callback=*/nullptr WGPU_COMMA \ + /*.userdata1=*/nullptr WGPU_COMMA \ + /*.userdata2=*/nullptr WGPU_COMMA \ }) -typedef struct WGPUBlendComponent { - WGPUBlendOperation operation; - WGPUBlendFactor srcFactor; - WGPUBlendFactor dstFactor; -} WGPUBlendComponent WGPU_STRUCTURE_ATTRIBUTE; +typedef struct WGPULoggingCallbackInfo { + WGPUChainedStruct* nextInChain; + WGPULoggingCallback callback; + void* userdata1; + void* userdata2; +} WGPULoggingCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_BLEND_COMPONENT_INIT WGPU_MAKE_INIT_STRUCT(WGPUBlendComponent, { \ - /*.operation=*/WGPUBlendOperation_Add WGPU_COMMA \ - /*.srcFactor=*/WGPUBlendFactor_One WGPU_COMMA \ - /*.dstFactor=*/WGPUBlendFactor_Zero WGPU_COMMA \ +#define WGPU_LOGGING_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPULoggingCallbackInfo, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.callback=*/nullptr WGPU_COMMA \ + /*.userdata1=*/nullptr WGPU_COMMA \ + /*.userdata2=*/nullptr WGPU_COMMA \ }) -typedef struct WGPUBufferBindingLayout { - WGPUChainedStruct const * nextInChain; - 
WGPUBufferBindingType type; - WGPUBool hasDynamicOffset; - uint64_t minBindingSize; -} WGPUBufferBindingLayout WGPU_STRUCTURE_ATTRIBUTE; +typedef struct WGPUPopErrorScopeCallbackInfo { + WGPUChainedStruct* nextInChain; + WGPUCallbackMode mode; + WGPUPopErrorScopeCallback callback; + void* userdata1; + void* userdata2; +} WGPUPopErrorScopeCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_BUFFER_BINDING_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUBufferBindingLayout, { \ +#define WGPU_POP_ERROR_SCOPE_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUPopErrorScopeCallbackInfo, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.type=*/WGPUBufferBindingType_Undefined WGPU_COMMA \ - /*.hasDynamicOffset=*/false WGPU_COMMA \ - /*.minBindingSize=*/0 WGPU_COMMA \ + /*.mode=*/{} WGPU_COMMA \ + /*.callback=*/nullptr WGPU_COMMA \ + /*.userdata1=*/nullptr WGPU_COMMA \ + /*.userdata2=*/nullptr WGPU_COMMA \ }) -typedef struct WGPUBufferDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; - WGPUBufferUsageFlags usage; - uint64_t size; - WGPUBool mappedAtCreation; -} WGPUBufferDescriptor WGPU_STRUCTURE_ATTRIBUTE; +typedef struct WGPUQueueWorkDoneCallbackInfo { + WGPUChainedStruct* nextInChain; + WGPUCallbackMode mode; + WGPUQueueWorkDoneCallback callback; + void* userdata1; + void* userdata2; +} WGPUQueueWorkDoneCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_BUFFER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUBufferDescriptor, { \ +#define WGPU_QUEUE_WORK_DONE_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUQueueWorkDoneCallbackInfo, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ - /*.usage=*/{} WGPU_COMMA \ - /*.size=*/{} WGPU_COMMA \ - /*.mappedAtCreation=*/false WGPU_COMMA \ + /*.mode=*/{} WGPU_COMMA \ + /*.callback=*/nullptr WGPU_COMMA \ + /*.userdata1=*/nullptr WGPU_COMMA \ + /*.userdata2=*/nullptr WGPU_COMMA \ }) -// Can be chained in WGPUBufferDescriptor +typedef struct WGPURequestAdapterCallbackInfo { + WGPUChainedStruct* nextInChain; + WGPUCallbackMode mode; + WGPURequestAdapterCallback callback; + void* userdata1; + void* userdata2; +} WGPURequestAdapterCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_REQUEST_ADAPTER_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPURequestAdapterCallbackInfo, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.mode=*/{} WGPU_COMMA \ + /*.callback=*/nullptr WGPU_COMMA \ + /*.userdata1=*/nullptr WGPU_COMMA \ + /*.userdata2=*/nullptr WGPU_COMMA \ +}) + +typedef struct WGPURequestDeviceCallbackInfo { + WGPUChainedStruct* nextInChain; + WGPUCallbackMode mode; + WGPURequestDeviceCallback callback; + void* userdata1; + void* userdata2; +} WGPURequestDeviceCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_REQUEST_DEVICE_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPURequestDeviceCallbackInfo, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.mode=*/{} WGPU_COMMA \ + /*.callback=*/nullptr WGPU_COMMA \ + /*.userdata1=*/nullptr WGPU_COMMA \ + /*.userdata2=*/nullptr WGPU_COMMA \ +}) + +typedef struct WGPUUncapturedErrorCallbackInfo { + WGPUChainedStruct* nextInChain; + WGPUUncapturedErrorCallback callback; + void* userdata1; + void* userdata2; +} WGPUUncapturedErrorCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_UNCAPTURED_ERROR_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUUncapturedErrorCallbackInfo, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.callback=*/nullptr WGPU_COMMA \ + /*.userdata1=*/nullptr WGPU_COMMA \ + /*.userdata2=*/nullptr WGPU_COMMA \ +}) + + +typedef struct 
WGPUINTERNAL_HAVE_EMDAWNWEBGPU_HEADER { + WGPUBool unused; +} WGPUINTERNAL_HAVE_EMDAWNWEBGPU_HEADER WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_INTERNAL_HAVE_EMDAWNWEBGPU_HEADER_INIT WGPU_MAKE_INIT_STRUCT(WGPUINTERNAL_HAVE_EMDAWNWEBGPU_HEADER, { \ + /*.unused=*/false WGPU_COMMA \ +}) + +// Can be chained in WGPUAdapterInfo +typedef struct WGPUAdapterPropertiesD3D { + WGPUChainedStruct chain; + uint32_t shaderModel; +} WGPUAdapterPropertiesD3D WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_ADAPTER_PROPERTIES_D3D_INIT WGPU_MAKE_INIT_STRUCT(WGPUAdapterPropertiesD3D, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_AdapterPropertiesD3D} WGPU_COMMA \ + /*.shaderModel=*/{} WGPU_COMMA \ +}) + +// Can be chained in WGPUAdapterInfo +typedef struct WGPUAdapterPropertiesSubgroups { + WGPUChainedStruct chain; + uint32_t subgroupMinSize; + uint32_t subgroupMaxSize; +} WGPUAdapterPropertiesSubgroups WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_ADAPTER_PROPERTIES_SUBGROUPS_INIT WGPU_MAKE_INIT_STRUCT(WGPUAdapterPropertiesSubgroups, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_AdapterPropertiesSubgroups} WGPU_COMMA \ + /*.subgroupMinSize=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ + /*.subgroupMaxSize=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ +}) + +// Can be chained in WGPUAdapterInfo +typedef struct WGPUAdapterPropertiesVk { + WGPUChainedStruct chain; + uint32_t driverVersion; +} WGPUAdapterPropertiesVk WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_ADAPTER_PROPERTIES_VK_INIT WGPU_MAKE_INIT_STRUCT(WGPUAdapterPropertiesVk, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_AdapterPropertiesVk} WGPU_COMMA \ + /*.driverVersion=*/{} WGPU_COMMA \ +}) + +typedef struct WGPUBindGroupEntry { + WGPUChainedStruct* nextInChain; + uint32_t binding; + WGPU_NULLABLE WGPUBuffer buffer; + uint64_t offset; + uint64_t size; + WGPU_NULLABLE WGPUSampler sampler; + WGPU_NULLABLE WGPUTextureView textureView; +} WGPUBindGroupEntry WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_BIND_GROUP_ENTRY_INIT WGPU_MAKE_INIT_STRUCT(WGPUBindGroupEntry, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.binding=*/{} WGPU_COMMA \ + /*.buffer=*/nullptr WGPU_COMMA \ + /*.offset=*/0 WGPU_COMMA \ + /*.size=*/WGPU_WHOLE_SIZE WGPU_COMMA \ + /*.sampler=*/nullptr WGPU_COMMA \ + /*.textureView=*/nullptr WGPU_COMMA \ +}) + +typedef struct WGPUBlendComponent { + WGPUBlendOperation operation; + WGPUBlendFactor srcFactor; + WGPUBlendFactor dstFactor; +} WGPUBlendComponent WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_BLEND_COMPONENT_INIT WGPU_MAKE_INIT_STRUCT(WGPUBlendComponent, { \ + /*.operation=*/WGPUBlendOperation_Add WGPU_COMMA \ + /*.srcFactor=*/WGPUBlendFactor_One WGPU_COMMA \ + /*.dstFactor=*/WGPUBlendFactor_Zero WGPU_COMMA \ +}) + +typedef struct WGPUBufferBindingLayout { + WGPUChainedStruct* nextInChain; + WGPUBufferBindingType type; + WGPUBool hasDynamicOffset; + uint64_t minBindingSize; +} WGPUBufferBindingLayout WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_BUFFER_BINDING_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUBufferBindingLayout, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.type=*/WGPUBufferBindingType_Uniform WGPU_COMMA \ + /*.hasDynamicOffset=*/false WGPU_COMMA \ + /*.minBindingSize=*/0 WGPU_COMMA \ +}) + +// Can be chained in WGPUBufferDescriptor typedef struct WGPUBufferHostMappedPointer { WGPUChainedStruct chain; void * pointer; @@ -1274,26 +1328,12 @@ typedef struct WGPUBufferHostMappedPointer { } WGPUBufferHostMappedPointer WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_BUFFER_HOST_MAPPED_POINTER_INIT 
WGPU_MAKE_INIT_STRUCT(WGPUBufferHostMappedPointer, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_BufferHostMappedPointer} WGPU_COMMA \ /*.pointer=*/{} WGPU_COMMA \ /*.disposeCallback=*/{} WGPU_COMMA \ /*.userdata=*/{} WGPU_COMMA \ }) -typedef struct WGPUBufferMapCallbackInfo { - WGPUChainedStruct const * nextInChain; - WGPUCallbackMode mode; - WGPUBufferMapCallback callback; - void * userdata; -} WGPUBufferMapCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_BUFFER_MAP_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUBufferMapCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.mode=*/{} WGPU_COMMA \ - /*.callback=*/{} WGPU_COMMA \ - /*.userdata=*/{} WGPU_COMMA \ -}) - typedef struct WGPUColor { double r; double g; @@ -1315,70 +1355,10 @@ typedef struct WGPUColorTargetStateExpandResolveTextureDawn { } WGPUColorTargetStateExpandResolveTextureDawn WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_COLOR_TARGET_STATE_EXPAND_RESOLVE_TEXTURE_DAWN_INIT WGPU_MAKE_INIT_STRUCT(WGPUColorTargetStateExpandResolveTextureDawn, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_ColorTargetStateExpandResolveTextureDawn} WGPU_COMMA \ /*.enabled=*/false WGPU_COMMA \ }) -typedef struct WGPUCommandBufferDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; -} WGPUCommandBufferDescriptor WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_COMMAND_BUFFER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUCommandBufferDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ -}) - -typedef struct WGPUCommandEncoderDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; -} WGPUCommandEncoderDescriptor WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_COMMAND_ENCODER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUCommandEncoderDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ -}) - -typedef struct WGPUCompilationInfoCallbackInfo { - WGPUChainedStruct const * nextInChain; - WGPUCallbackMode mode; - WGPUCompilationInfoCallback callback; - void * userdata; -} WGPUCompilationInfoCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_COMPILATION_INFO_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUCompilationInfoCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.mode=*/{} WGPU_COMMA \ - /*.callback=*/{} WGPU_COMMA \ - /*.userdata=*/nullptr WGPU_COMMA \ -}) - -typedef struct WGPUCompilationMessage { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * message; - WGPUCompilationMessageType type; - uint64_t lineNum; - uint64_t linePos; - uint64_t offset; - uint64_t length; - uint64_t utf16LinePos; - uint64_t utf16Offset; - uint64_t utf16Length; -} WGPUCompilationMessage WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_COMPILATION_MESSAGE_INIT WGPU_MAKE_INIT_STRUCT(WGPUCompilationMessage, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.message=*/nullptr WGPU_COMMA \ - /*.type=*/{} WGPU_COMMA \ - /*.lineNum=*/{} WGPU_COMMA \ - /*.linePos=*/{} WGPU_COMMA \ - /*.offset=*/{} WGPU_COMMA \ - /*.length=*/{} WGPU_COMMA \ - /*.utf16LinePos=*/{} WGPU_COMMA \ - /*.utf16Offset=*/{} WGPU_COMMA \ - /*.utf16Length=*/{} WGPU_COMMA \ -}) - typedef struct WGPUComputePassTimestampWrites { WGPUQuerySet querySet; uint32_t beginningOfPassWriteIndex; @@ -1391,20 +1371,8 @@ typedef struct WGPUComputePassTimestampWrites { /*.endOfPassWriteIndex=*/WGPU_QUERY_SET_INDEX_UNDEFINED WGPU_COMMA \ }) -typedef struct 
WGPUConstantEntry { - WGPUChainedStruct const * nextInChain; - char const * key; - double value; -} WGPUConstantEntry WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_CONSTANT_ENTRY_INIT WGPU_MAKE_INIT_STRUCT(WGPUConstantEntry, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.key=*/{} WGPU_COMMA \ - /*.value=*/{} WGPU_COMMA \ -}) - typedef struct WGPUCopyTextureForBrowserOptions { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPUBool flipY; WGPUBool needsColorSpaceConversion; WGPUAlphaMode srcAlphaMode; @@ -1427,34 +1395,6 @@ typedef struct WGPUCopyTextureForBrowserOptions { /*.internalUsage=*/false WGPU_COMMA \ }) -typedef struct WGPUCreateComputePipelineAsyncCallbackInfo { - WGPUChainedStruct const * nextInChain; - WGPUCallbackMode mode; - WGPUCreateComputePipelineAsyncCallback callback; - void * userdata; -} WGPUCreateComputePipelineAsyncCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_CREATE_COMPUTE_PIPELINE_ASYNC_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUCreateComputePipelineAsyncCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.mode=*/{} WGPU_COMMA \ - /*.callback=*/{} WGPU_COMMA \ - /*.userdata=*/{} WGPU_COMMA \ -}) - -typedef struct WGPUCreateRenderPipelineAsyncCallbackInfo { - WGPUChainedStruct const * nextInChain; - WGPUCallbackMode mode; - WGPUCreateRenderPipelineAsyncCallback callback; - void * userdata; -} WGPUCreateRenderPipelineAsyncCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_CREATE_RENDER_PIPELINE_ASYNC_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUCreateRenderPipelineAsyncCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.mode=*/{} WGPU_COMMA \ - /*.callback=*/{} WGPU_COMMA \ - /*.userdata=*/{} WGPU_COMMA \ -}) - // Can be chained in WGPUInstanceDescriptor typedef struct WGPUDawnWGSLBlocklist { WGPUChainedStruct chain; @@ -1463,19 +1403,19 @@ typedef struct WGPUDawnWGSLBlocklist { } WGPUDawnWGSLBlocklist WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_WGSL_BLOCKLIST_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnWGSLBlocklist, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnWGSLBlocklist} WGPU_COMMA \ /*.blocklistedFeatureCount=*/0 WGPU_COMMA \ /*.blocklistedFeatures=*/{} WGPU_COMMA \ }) -// Can be chained in WGPUAdapterProperties +// Can be chained in WGPUAdapterInfo typedef struct WGPUDawnAdapterPropertiesPowerPreference { - WGPUChainedStructOut chain; + WGPUChainedStruct chain; WGPUPowerPreference powerPreference; } WGPUDawnAdapterPropertiesPowerPreference WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_ADAPTER_PROPERTIES_POWER_PREFERENCE_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnAdapterPropertiesPowerPreference, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnAdapterPropertiesPowerPreference} WGPU_COMMA \ /*.powerPreference=*/WGPUPowerPreference_Undefined WGPU_COMMA \ }) @@ -1486,36 +1426,18 @@ typedef struct WGPUDawnBufferDescriptorErrorInfoFromWireClient { } WGPUDawnBufferDescriptorErrorInfoFromWireClient WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_BUFFER_DESCRIPTOR_ERROR_INFO_FROM_WIRE_CLIENT_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnBufferDescriptorErrorInfoFromWireClient, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnBufferDescriptorErrorInfoFromWireClient} WGPU_COMMA \ /*.outOfMemory=*/false WGPU_COMMA \ }) -// Can be chained in WGPUDeviceDescriptor -typedef struct WGPUDawnCacheDeviceDescriptor { - WGPUChainedStruct chain; - char const * isolationKey; - 
WGPUDawnLoadCacheDataFunction loadDataFunction; - WGPUDawnStoreCacheDataFunction storeDataFunction; - void * functionUserdata; -} WGPUDawnCacheDeviceDescriptor WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_DAWN_CACHE_DEVICE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnCacheDeviceDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ - /*.isolationKey=*/"" WGPU_COMMA \ - /*.loadDataFunction=*/nullptr WGPU_COMMA \ - /*.storeDataFunction=*/nullptr WGPU_COMMA \ - /*.functionUserdata=*/nullptr WGPU_COMMA \ -}) - -// Can be chained in WGPUComputePipelineDescriptor -typedef struct WGPUDawnComputePipelineFullSubgroups { - WGPUChainedStruct chain; - WGPUBool requiresFullSubgroups; -} WGPUDawnComputePipelineFullSubgroups WGPU_STRUCTURE_ATTRIBUTE; +typedef struct WGPUDawnDrmFormatProperties { + uint64_t modifier; + uint32_t modifierPlaneCount; +} WGPUDawnDrmFormatProperties WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_DAWN_COMPUTE_PIPELINE_FULL_SUBGROUPS_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnComputePipelineFullSubgroups, { \ - /*.chain=*/{} WGPU_COMMA \ - /*.requiresFullSubgroups=*/false WGPU_COMMA \ +#define WGPU_DAWN_DRM_FORMAT_PROPERTIES_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnDrmFormatProperties, { \ + /*.modifier=*/{} WGPU_COMMA \ + /*.modifierPlaneCount=*/{} WGPU_COMMA \ }) // Can be chained in WGPUCommandEncoderDescriptor @@ -1525,23 +1447,42 @@ typedef struct WGPUDawnEncoderInternalUsageDescriptor { } WGPUDawnEncoderInternalUsageDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_ENCODER_INTERNAL_USAGE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnEncoderInternalUsageDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnEncoderInternalUsageDescriptor} WGPU_COMMA \ /*.useInternalUsages=*/false WGPU_COMMA \ }) +// Can be chained in WGPUSupportedLimits +typedef struct WGPUDawnExperimentalImmediateDataLimits { + WGPUChainedStruct chain; + uint32_t maxImmediateDataRangeByteSize; +} WGPUDawnExperimentalImmediateDataLimits WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_DAWN_EXPERIMENTAL_IMMEDIATE_DATA_LIMITS_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnExperimentalImmediateDataLimits, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnExperimentalImmediateDataLimits} WGPU_COMMA \ + /*.maxImmediateDataRangeByteSize=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ +}) + // Can be chained in WGPUSupportedLimits typedef struct WGPUDawnExperimentalSubgroupLimits { - WGPUChainedStructOut chain; + WGPUChainedStruct chain; uint32_t minSubgroupSize; uint32_t maxSubgroupSize; } WGPUDawnExperimentalSubgroupLimits WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_EXPERIMENTAL_SUBGROUP_LIMITS_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnExperimentalSubgroupLimits, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnExperimentalSubgroupLimits} WGPU_COMMA \ /*.minSubgroupSize=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ /*.maxSubgroupSize=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ }) +typedef struct WGPUDawnFormatCapabilities { + WGPUChainedStruct* nextInChain; +} WGPUDawnFormatCapabilities WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_DAWN_FORMAT_CAPABILITIES_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnFormatCapabilities, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ +}) + // Can be chained in WGPURenderPassColorAttachment typedef struct WGPUDawnRenderPassColorAttachmentRenderToSingleSampled { WGPUChainedStruct chain; @@ -1549,7 +1490,7 @@ typedef struct WGPUDawnRenderPassColorAttachmentRenderToSingleSampled { } 
WGPUDawnRenderPassColorAttachmentRenderToSingleSampled WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_RENDER_PASS_COLOR_ATTACHMENT_RENDER_TO_SINGLE_SAMPLED_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnRenderPassColorAttachmentRenderToSingleSampled, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnRenderPassColorAttachmentRenderToSingleSampled} WGPU_COMMA \ /*.implicitSampleCount=*/1 WGPU_COMMA \ }) @@ -1560,18 +1501,29 @@ typedef struct WGPUDawnShaderModuleSPIRVOptionsDescriptor { } WGPUDawnShaderModuleSPIRVOptionsDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_SHADER_MODULE_SPIRV_OPTIONS_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnShaderModuleSPIRVOptionsDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnShaderModuleSPIRVOptionsDescriptor} WGPU_COMMA \ /*.allowNonUniformDerivatives=*/false WGPU_COMMA \ }) +// Can be chained in WGPUSupportedLimits +typedef struct WGPUDawnTexelCopyBufferRowAlignmentLimits { + WGPUChainedStruct chain; + uint32_t minTexelCopyBufferRowAlignment; +} WGPUDawnTexelCopyBufferRowAlignmentLimits WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_DAWN_TEXEL_COPY_BUFFER_ROW_ALIGNMENT_LIMITS_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnTexelCopyBufferRowAlignmentLimits, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnTexelCopyBufferRowAlignmentLimits} WGPU_COMMA \ + /*.minTexelCopyBufferRowAlignment=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ +}) + // Can be chained in WGPUTextureDescriptor typedef struct WGPUDawnTextureInternalUsageDescriptor { WGPUChainedStruct chain; - WGPUTextureUsageFlags internalUsage; + WGPUTextureUsage internalUsage; } WGPUDawnTextureInternalUsageDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_TEXTURE_INTERNAL_USAGE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnTextureInternalUsageDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnTextureInternalUsageDescriptor} WGPU_COMMA \ /*.internalUsage=*/WGPUTextureUsage_None WGPU_COMMA \ }) @@ -1587,7 +1539,7 @@ typedef struct WGPUDawnTogglesDescriptor { } WGPUDawnTogglesDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_TOGGLES_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnTogglesDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnTogglesDescriptor} WGPU_COMMA \ /*.enabledToggleCount=*/0 WGPU_COMMA \ /*.enabledToggles=*/{} WGPU_COMMA \ /*.disabledToggleCount=*/0 WGPU_COMMA \ @@ -1603,47 +1555,12 @@ typedef struct WGPUDawnWireWGSLControl { } WGPUDawnWireWGSLControl WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_WIRE_WGSL_CONTROL_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnWireWGSLControl, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnWireWGSLControl} WGPU_COMMA \ /*.enableExperimental=*/false WGPU_COMMA \ /*.enableUnsafe=*/false WGPU_COMMA \ /*.enableTesting=*/false WGPU_COMMA \ }) -// Can be chained in WGPUDepthStencilState -typedef struct WGPUDepthStencilStateDepthWriteDefinedDawn { - WGPUChainedStruct chain; - WGPUBool depthWriteDefined; -} WGPUDepthStencilStateDepthWriteDefinedDawn WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_DEPTH_STENCIL_STATE_DEPTH_WRITE_DEFINED_DAWN_INIT WGPU_MAKE_INIT_STRUCT(WGPUDepthStencilStateDepthWriteDefinedDawn, { \ - /*.chain=*/{} WGPU_COMMA \ - /*.depthWriteDefined=*/{} WGPU_COMMA \ -}) - -typedef struct WGPUDeviceLostCallbackInfo { - WGPUChainedStruct const * 
nextInChain; - WGPUCallbackMode mode; - WGPUDeviceLostCallbackNew callback; - void * userdata; -} WGPUDeviceLostCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_DEVICE_LOST_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUDeviceLostCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.mode=*/WGPUCallbackMode_WaitAnyOnly WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata=*/nullptr WGPU_COMMA \ -}) - -typedef struct WGPUDrmFormatProperties { - uint64_t modifier; - uint32_t modifierPlaneCount; -} WGPUDrmFormatProperties WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_DRM_FORMAT_PROPERTIES_INIT WGPU_MAKE_INIT_STRUCT(WGPUDrmFormatProperties, { \ - /*.modifier=*/{} WGPU_COMMA \ - /*.modifierPlaneCount=*/{} WGPU_COMMA \ -}) - typedef struct WGPUExtent2D { uint32_t width; uint32_t height; @@ -1673,7 +1590,7 @@ typedef struct WGPUExternalTextureBindingEntry { } WGPUExternalTextureBindingEntry WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_EXTERNAL_TEXTURE_BINDING_ENTRY_INIT WGPU_MAKE_INIT_STRUCT(WGPUExternalTextureBindingEntry, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_ExternalTextureBindingEntry} WGPU_COMMA \ /*.externalTexture=*/{} WGPU_COMMA \ }) @@ -1683,15 +1600,7 @@ typedef struct WGPUExternalTextureBindingLayout { } WGPUExternalTextureBindingLayout WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_EXTERNAL_TEXTURE_BINDING_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUExternalTextureBindingLayout, { \ - /*.chain=*/{} WGPU_COMMA \ -}) - -typedef struct WGPUFormatCapabilities { - WGPUChainedStructOut * nextInChain; -} WGPUFormatCapabilities WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_FORMAT_CAPABILITIES_INIT WGPU_MAKE_INIT_STRUCT(WGPUFormatCapabilities, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_ExternalTextureBindingLayout} WGPU_COMMA \ }) typedef struct WGPUFuture { @@ -1703,7 +1612,7 @@ typedef struct WGPUFuture { }) typedef struct WGPUInstanceFeatures { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPUBool timedWaitAnyEnable; size_t timedWaitAnyMaxCount; } WGPUInstanceFeatures WGPU_STRUCTURE_ATTRIBUTE; @@ -1747,6 +1656,10 @@ typedef struct WGPULimits { uint32_t maxComputeWorkgroupSizeY; uint32_t maxComputeWorkgroupSizeZ; uint32_t maxComputeWorkgroupsPerDimension; + uint32_t maxStorageBuffersInVertexStage; + uint32_t maxStorageTexturesInVertexStage; + uint32_t maxStorageBuffersInFragmentStage; + uint32_t maxStorageTexturesInFragmentStage; } WGPULimits WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_LIMITS_INIT WGPU_MAKE_INIT_STRUCT(WGPULimits, { \ @@ -1782,10 +1695,14 @@ typedef struct WGPULimits { /*.maxComputeWorkgroupSizeY=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ /*.maxComputeWorkgroupSizeZ=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ /*.maxComputeWorkgroupsPerDimension=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ + /*.maxStorageBuffersInVertexStage=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ + /*.maxStorageTexturesInVertexStage=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ + /*.maxStorageBuffersInFragmentStage=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ + /*.maxStorageTexturesInFragmentStage=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ }) typedef struct WGPUMemoryHeapInfo { - WGPUHeapPropertyFlags properties; + WGPUHeapProperty properties; uint64_t size; } WGPUMemoryHeapInfo WGPU_STRUCTURE_ATTRIBUTE; @@ -1795,7 +1712,7 @@ typedef struct WGPUMemoryHeapInfo { }) typedef struct WGPUMultisampleState { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; uint32_t 
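// Reading the four new per-stage limits; a sketch assuming a valid `device`
// and wgpuDeviceGetLimits from this header (its return convention has moved
// between snapshots, so the result is ignored here).
static void exampleReadStageLimits(WGPUDevice device) {
    WGPUSupportedLimits supported = {0};   // zero-init; nextInChain stays NULL
    wgpuDeviceGetLimits(device, &supported);
    uint32_t vsStorageBuffers  = supported.limits.maxStorageBuffersInVertexStage;
    uint32_t fsStorageTextures = supported.limits.maxStorageTexturesInFragmentStage;
    (void)vsStorageBuffers;
    (void)fsStorageTextures;
}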
count; uint32_t mask; WGPUBool alphaToCoverageEnabled; @@ -1830,65 +1747,23 @@ typedef struct WGPUOrigin3D { /*.z=*/0 WGPU_COMMA \ }) -typedef struct WGPUPipelineLayoutDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; - size_t bindGroupLayoutCount; - WGPUBindGroupLayout const * bindGroupLayouts; -} WGPUPipelineLayoutDescriptor WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_PIPELINE_LAYOUT_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUPipelineLayoutDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ - /*.bindGroupLayoutCount=*/{} WGPU_COMMA \ - /*.bindGroupLayouts=*/{} WGPU_COMMA \ -}) - typedef struct WGPUPipelineLayoutStorageAttachment { - WGPUChainedStruct const * nextInChain; uint64_t offset; WGPUTextureFormat format; } WGPUPipelineLayoutStorageAttachment WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_PIPELINE_LAYOUT_STORAGE_ATTACHMENT_INIT WGPU_MAKE_INIT_STRUCT(WGPUPipelineLayoutStorageAttachment, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ /*.offset=*/0 WGPU_COMMA \ /*.format=*/{} WGPU_COMMA \ }) -typedef struct WGPUPopErrorScopeCallbackInfo { - WGPUChainedStruct const * nextInChain; - WGPUCallbackMode mode; - WGPUPopErrorScopeCallback callback; - WGPUErrorCallback oldCallback; - void * userdata; -} WGPUPopErrorScopeCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_POP_ERROR_SCOPE_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUPopErrorScopeCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.mode=*/WGPUCallbackMode_WaitAnyOnly WGPU_COMMA \ - /*.callback=*/{} WGPU_COMMA \ - /*.oldCallback=*/{} WGPU_COMMA \ - /*.userdata=*/nullptr WGPU_COMMA \ -}) - -// Can be chained in WGPUPrimitiveState -typedef struct WGPUPrimitiveDepthClipControl { - WGPUChainedStruct chain; - WGPUBool unclippedDepth; -} WGPUPrimitiveDepthClipControl WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_PRIMITIVE_DEPTH_CLIP_CONTROL_INIT WGPU_MAKE_INIT_STRUCT(WGPUPrimitiveDepthClipControl, { \ - /*.chain=*/{} WGPU_COMMA \ - /*.unclippedDepth=*/false WGPU_COMMA \ -}) - typedef struct WGPUPrimitiveState { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPUPrimitiveTopology topology; WGPUIndexFormat stripIndexFormat; WGPUFrontFace frontFace; WGPUCullMode cullMode; + WGPUBool unclippedDepth; } WGPUPrimitiveState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_PRIMITIVE_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUPrimitiveState, { \ @@ -1897,76 +1772,7 @@ typedef struct WGPUPrimitiveState { /*.stripIndexFormat=*/WGPUIndexFormat_Undefined WGPU_COMMA \ /*.frontFace=*/WGPUFrontFace_CCW WGPU_COMMA \ /*.cullMode=*/WGPUCullMode_None WGPU_COMMA \ -}) - -typedef struct WGPUQuerySetDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; - WGPUQueryType type; - uint32_t count; -} WGPUQuerySetDescriptor WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_QUERY_SET_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUQuerySetDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ - /*.type=*/{} WGPU_COMMA \ - /*.count=*/{} WGPU_COMMA \ -}) - -typedef struct WGPUQueueDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; -} WGPUQueueDescriptor WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_QUEUE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUQueueDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ -}) - -typedef struct WGPUQueueWorkDoneCallbackInfo { - WGPUChainedStruct const * nextInChain; - WGPUCallbackMode mode; - WGPUQueueWorkDoneCallback callback; 
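// Sketch of the folded-in depth-clip control: WGPUPrimitiveDepthClipControl is
// removed above, and unclippedDepth now sits directly on WGPUPrimitiveState.
static WGPUPrimitiveState exampleUnclippedPrimitiveState(void) {
    WGPUPrimitiveState prim = WGPU_PRIMITIVE_STATE_INIT;
    prim.unclippedDepth = 1;   // WGPUBool; previously required chaining a separate struct
    return prim;
}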
- void * userdata; -} WGPUQueueWorkDoneCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_QUEUE_WORK_DONE_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUQueueWorkDoneCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.mode=*/{} WGPU_COMMA \ - /*.callback=*/{} WGPU_COMMA \ - /*.userdata=*/{} WGPU_COMMA \ -}) - -typedef struct WGPURenderBundleDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; -} WGPURenderBundleDescriptor WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_RENDER_BUNDLE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderBundleDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ -}) - -typedef struct WGPURenderBundleEncoderDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; - size_t colorFormatCount; - WGPUTextureFormat const * colorFormats; - WGPUTextureFormat depthStencilFormat; - uint32_t sampleCount; - WGPUBool depthReadOnly; - WGPUBool stencilReadOnly; -} WGPURenderBundleEncoderDescriptor WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_RENDER_BUNDLE_ENCODER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderBundleEncoderDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ - /*.colorFormatCount=*/{} WGPU_COMMA \ - /*.colorFormats=*/{} WGPU_COMMA \ - /*.depthStencilFormat=*/WGPUTextureFormat_Undefined WGPU_COMMA \ - /*.sampleCount=*/1 WGPU_COMMA \ - /*.depthReadOnly=*/false WGPU_COMMA \ - /*.stencilReadOnly=*/false WGPU_COMMA \ + /*.unclippedDepth=*/false WGPU_COMMA \ }) typedef struct WGPURenderPassDepthStencilAttachment { @@ -1994,13 +1800,30 @@ typedef struct WGPURenderPassDepthStencilAttachment { }) // Can be chained in WGPURenderPassDescriptor -typedef struct WGPURenderPassDescriptorMaxDrawCount { +typedef struct WGPURenderPassDescriptorExpandResolveRect { + WGPUChainedStruct chain; + uint32_t x; + uint32_t y; + uint32_t width; + uint32_t height; +} WGPURenderPassDescriptorExpandResolveRect WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_RENDER_PASS_DESCRIPTOR_EXPAND_RESOLVE_RECT_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderPassDescriptorExpandResolveRect, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_RenderPassDescriptorExpandResolveRect} WGPU_COMMA \ + /*.x=*/{} WGPU_COMMA \ + /*.y=*/{} WGPU_COMMA \ + /*.width=*/{} WGPU_COMMA \ + /*.height=*/{} WGPU_COMMA \ +}) + +// Can be chained in WGPURenderPassDescriptor +typedef struct WGPURenderPassMaxDrawCount { WGPUChainedStruct chain; uint64_t maxDrawCount; -} WGPURenderPassDescriptorMaxDrawCount WGPU_STRUCTURE_ATTRIBUTE; +} WGPURenderPassMaxDrawCount WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_RENDER_PASS_DESCRIPTOR_MAX_DRAW_COUNT_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderPassDescriptorMaxDrawCount, { \ - /*.chain=*/{} WGPU_COMMA \ +#define WGPU_RENDER_PASS_MAX_DRAW_COUNT_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderPassMaxDrawCount, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_RenderPassMaxDrawCount} WGPU_COMMA \ /*.maxDrawCount=*/50000000 WGPU_COMMA \ }) @@ -2016,114 +1839,32 @@ typedef struct WGPURenderPassTimestampWrites { /*.endOfPassWriteIndex=*/WGPU_QUERY_SET_INDEX_UNDEFINED WGPU_COMMA \ }) -typedef struct WGPURequestAdapterCallbackInfo { - WGPUChainedStruct const * nextInChain; - WGPUCallbackMode mode; - WGPURequestAdapterCallback callback; - void * userdata; -} WGPURequestAdapterCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_REQUEST_ADAPTER_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPURequestAdapterCallbackInfo, { \ - /*.nextInChain=*/nullptr 
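// Sketch of the renamed chained struct (WGPURenderPassDescriptorMaxDrawCount
// became WGPURenderPassMaxDrawCount); the attachments a real render pass needs
// are elided, and wgpuCommandEncoderBeginRenderPass is assumed from this header.
static void exampleMaxDrawCount(WGPUCommandEncoder encoder) {
    WGPURenderPassMaxDrawCount maxDraw = WGPU_RENDER_PASS_MAX_DRAW_COUNT_INIT;
    maxDraw.maxDrawCount = 1000000;        // the default above is 50000000
    WGPURenderPassDescriptor desc = {0};   // a real pass also sets color attachments
    desc.nextInChain = &maxDraw.chain;
    WGPURenderPassEncoder pass = wgpuCommandEncoderBeginRenderPass(encoder, &desc);
    wgpuRenderPassEncoderEnd(pass);
}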
WGPU_COMMA \ - /*.mode=*/{} WGPU_COMMA \ - /*.callback=*/{} WGPU_COMMA \ - /*.userdata=*/{} WGPU_COMMA \ -}) - typedef struct WGPURequestAdapterOptions { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPU_NULLABLE WGPUSurface compatibleSurface; + WGPUFeatureLevel featureLevel; WGPUPowerPreference powerPreference; WGPUBackendType backendType; WGPUBool forceFallbackAdapter; - WGPUBool compatibilityMode; } WGPURequestAdapterOptions WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_REQUEST_ADAPTER_OPTIONS_INIT WGPU_MAKE_INIT_STRUCT(WGPURequestAdapterOptions, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.compatibleSurface=*/nullptr WGPU_COMMA \ - /*.powerPreference=*/WGPUPowerPreference_Undefined WGPU_COMMA \ - /*.backendType=*/WGPUBackendType_Undefined WGPU_COMMA \ - /*.forceFallbackAdapter=*/false WGPU_COMMA \ - /*.compatibilityMode=*/false WGPU_COMMA \ -}) - -typedef struct WGPURequestDeviceCallbackInfo { - WGPUChainedStruct const * nextInChain; - WGPUCallbackMode mode; - WGPURequestDeviceCallback callback; - void * userdata; -} WGPURequestDeviceCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_REQUEST_DEVICE_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPURequestDeviceCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.mode=*/{} WGPU_COMMA \ - /*.callback=*/{} WGPU_COMMA \ - /*.userdata=*/{} WGPU_COMMA \ -}) - -typedef struct WGPUSamplerBindingLayout { - WGPUChainedStruct const * nextInChain; - WGPUSamplerBindingType type; -} WGPUSamplerBindingLayout WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_SAMPLER_BINDING_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUSamplerBindingLayout, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.type=*/WGPUSamplerBindingType_Undefined WGPU_COMMA \ -}) - -typedef struct WGPUSamplerDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; - WGPUAddressMode addressModeU; - WGPUAddressMode addressModeV; - WGPUAddressMode addressModeW; - WGPUFilterMode magFilter; - WGPUFilterMode minFilter; - WGPUMipmapFilterMode mipmapFilter; - float lodMinClamp; - float lodMaxClamp; - WGPUCompareFunction compare; - uint16_t maxAnisotropy; -} WGPUSamplerDescriptor WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_SAMPLER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSamplerDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ - /*.addressModeU=*/WGPUAddressMode_ClampToEdge WGPU_COMMA \ - /*.addressModeV=*/WGPUAddressMode_ClampToEdge WGPU_COMMA \ - /*.addressModeW=*/WGPUAddressMode_ClampToEdge WGPU_COMMA \ - /*.magFilter=*/WGPUFilterMode_Nearest WGPU_COMMA \ - /*.minFilter=*/WGPUFilterMode_Nearest WGPU_COMMA \ - /*.mipmapFilter=*/WGPUMipmapFilterMode_Nearest WGPU_COMMA \ - /*.lodMinClamp=*/0.0f WGPU_COMMA \ - /*.lodMaxClamp=*/32.0f WGPU_COMMA \ - /*.compare=*/WGPUCompareFunction_Undefined WGPU_COMMA \ - /*.maxAnisotropy=*/1 WGPU_COMMA \ -}) - -// Can be chained in WGPUShaderModuleDescriptor -typedef struct WGPUShaderModuleSPIRVDescriptor { - WGPUChainedStruct chain; - uint32_t codeSize; - uint32_t const * code; -} WGPUShaderModuleSPIRVDescriptor WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_SHADER_MODULE_SPIRV_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUShaderModuleSPIRVDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ - /*.codeSize=*/{} WGPU_COMMA \ - /*.code=*/{} WGPU_COMMA \ +#define WGPU_REQUEST_ADAPTER_OPTIONS_INIT WGPU_MAKE_INIT_STRUCT(WGPURequestAdapterOptions, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.compatibleSurface=*/nullptr WGPU_COMMA \ + /*.featureLevel=*/WGPUFeatureLevel_Core WGPU_COMMA \ + 
/*.powerPreference=*/WGPUPowerPreference_Undefined WGPU_COMMA \ + /*.backendType=*/WGPUBackendType_Undefined WGPU_COMMA \ + /*.forceFallbackAdapter=*/false WGPU_COMMA \ }) -// Can be chained in WGPUShaderModuleDescriptor -typedef struct WGPUShaderModuleWGSLDescriptor { - WGPUChainedStruct chain; - char const * code; -} WGPUShaderModuleWGSLDescriptor WGPU_STRUCTURE_ATTRIBUTE; +typedef struct WGPUSamplerBindingLayout { + WGPUChainedStruct* nextInChain; + WGPUSamplerBindingType type; +} WGPUSamplerBindingLayout WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_SHADER_MODULE_WGSL_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUShaderModuleWGSLDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ - /*.code=*/{} WGPU_COMMA \ +#define WGPU_SAMPLER_BINDING_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUSamplerBindingLayout, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.type=*/WGPUSamplerBindingType_Filtering WGPU_COMMA \ }) // Can be chained in WGPUShaderModuleDescriptor @@ -2133,22 +1874,25 @@ typedef struct WGPUShaderModuleCompilationOptions { } WGPUShaderModuleCompilationOptions WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHADER_MODULE_COMPILATION_OPTIONS_INIT WGPU_MAKE_INIT_STRUCT(WGPUShaderModuleCompilationOptions, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_ShaderModuleCompilationOptions} WGPU_COMMA \ /*.strictMath=*/{} WGPU_COMMA \ }) -typedef struct WGPUShaderModuleDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; -} WGPUShaderModuleDescriptor WGPU_STRUCTURE_ATTRIBUTE; +// Can be chained in WGPUShaderModuleDescriptor +typedef struct WGPUShaderSourceSPIRV { + WGPUChainedStruct chain; + uint32_t codeSize; + uint32_t const * code; +} WGPUShaderSourceSPIRV WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_SHADER_MODULE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUShaderModuleDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ +#define WGPU_SHADER_SOURCE_SPIRV_INIT WGPU_MAKE_INIT_STRUCT(WGPUShaderSourceSPIRV, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_ShaderSourceSPIRV} WGPU_COMMA \ + /*.codeSize=*/{} WGPU_COMMA \ + /*.code=*/{} WGPU_COMMA \ }) typedef struct WGPUSharedBufferMemoryBeginAccessDescriptor { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPUBool initialized; size_t fenceCount; WGPUSharedFence const * fences; @@ -2163,18 +1907,8 @@ typedef struct WGPUSharedBufferMemoryBeginAccessDescriptor { /*.signaledValues=*/{} WGPU_COMMA \ }) -typedef struct WGPUSharedBufferMemoryDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; -} WGPUSharedBufferMemoryDescriptor WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_SHARED_BUFFER_MEMORY_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedBufferMemoryDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ -}) - typedef struct WGPUSharedBufferMemoryEndAccessState { - WGPUChainedStructOut * nextInChain; + WGPUChainedStruct* nextInChain; WGPUBool initialized; size_t fenceCount; WGPUSharedFence const * fences; @@ -2190,8 +1924,8 @@ typedef struct WGPUSharedBufferMemoryEndAccessState { }) typedef struct WGPUSharedBufferMemoryProperties { - WGPUChainedStructOut * nextInChain; - WGPUBufferUsageFlags usage; + WGPUChainedStruct* nextInChain; + WGPUBufferUsage usage; uint64_t size; } WGPUSharedBufferMemoryProperties WGPU_STRUCTURE_ATTRIBUTE; @@ -2208,18 +1942,18 @@ typedef struct WGPUSharedFenceDXGISharedHandleDescriptor { } 
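// Sketch: compatibility mode is now requested through featureLevel rather than
// the removed compatibilityMode flag; WGPUFeatureLevel_Compatibility is assumed
// to be the matching enumerant in this header revision.
static WGPURequestAdapterOptions exampleCompatAdapterOptions(void) {
    WGPURequestAdapterOptions opts = WGPU_REQUEST_ADAPTER_OPTIONS_INIT;
    opts.featureLevel = WGPUFeatureLevel_Compatibility;  // the default above is _Core
    return opts;
}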
WGPUSharedFenceDXGISharedHandleDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_DXGI_SHARED_HANDLE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceDXGISharedHandleDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceDXGISharedHandleDescriptor} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) // Can be chained in WGPUSharedFenceExportInfo typedef struct WGPUSharedFenceDXGISharedHandleExportInfo { - WGPUChainedStructOut chain; + WGPUChainedStruct chain; void * handle; } WGPUSharedFenceDXGISharedHandleExportInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_DXGI_SHARED_HANDLE_EXPORT_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceDXGISharedHandleExportInfo, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceDXGISharedHandleExportInfo} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) @@ -2230,33 +1964,23 @@ typedef struct WGPUSharedFenceMTLSharedEventDescriptor { } WGPUSharedFenceMTLSharedEventDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_MTL_SHARED_EVENT_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceMTLSharedEventDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceMTLSharedEventDescriptor} WGPU_COMMA \ /*.sharedEvent=*/{} WGPU_COMMA \ }) // Can be chained in WGPUSharedFenceExportInfo typedef struct WGPUSharedFenceMTLSharedEventExportInfo { - WGPUChainedStructOut chain; + WGPUChainedStruct chain; void * sharedEvent; } WGPUSharedFenceMTLSharedEventExportInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_MTL_SHARED_EVENT_EXPORT_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceMTLSharedEventExportInfo, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceMTLSharedEventExportInfo} WGPU_COMMA \ /*.sharedEvent=*/{} WGPU_COMMA \ }) -typedef struct WGPUSharedFenceDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; -} WGPUSharedFenceDescriptor WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_SHARED_FENCE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ -}) - typedef struct WGPUSharedFenceExportInfo { - WGPUChainedStructOut * nextInChain; + WGPUChainedStruct* nextInChain; WGPUSharedFenceType type; } WGPUSharedFenceExportInfo WGPU_STRUCTURE_ATTRIBUTE; @@ -2266,46 +1990,46 @@ typedef struct WGPUSharedFenceExportInfo { }) // Can be chained in WGPUSharedFenceDescriptor -typedef struct WGPUSharedFenceVkSemaphoreOpaqueFDDescriptor { +typedef struct WGPUSharedFenceSyncFDDescriptor { WGPUChainedStruct chain; int handle; -} WGPUSharedFenceVkSemaphoreOpaqueFDDescriptor WGPU_STRUCTURE_ATTRIBUTE; +} WGPUSharedFenceSyncFDDescriptor WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_SHARED_FENCE_VK_SEMAPHORE_OPAQUE_FD_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceVkSemaphoreOpaqueFDDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ +#define WGPU_SHARED_FENCE_SYNC_FD_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceSyncFDDescriptor, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceSyncFDDescriptor} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) // Can be chained in WGPUSharedFenceExportInfo -typedef struct WGPUSharedFenceVkSemaphoreOpaqueFDExportInfo { - WGPUChainedStructOut chain; +typedef struct WGPUSharedFenceSyncFDExportInfo { + WGPUChainedStruct chain; int 
handle; -} WGPUSharedFenceVkSemaphoreOpaqueFDExportInfo WGPU_STRUCTURE_ATTRIBUTE; +} WGPUSharedFenceSyncFDExportInfo WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_SHARED_FENCE_VK_SEMAPHORE_OPAQUE_FD_EXPORT_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceVkSemaphoreOpaqueFDExportInfo, { \ - /*.chain=*/{} WGPU_COMMA \ +#define WGPU_SHARED_FENCE_SYNC_FD_EXPORT_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceSyncFDExportInfo, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceSyncFDExportInfo} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) // Can be chained in WGPUSharedFenceDescriptor -typedef struct WGPUSharedFenceVkSemaphoreSyncFDDescriptor { +typedef struct WGPUSharedFenceVkSemaphoreOpaqueFDDescriptor { WGPUChainedStruct chain; int handle; -} WGPUSharedFenceVkSemaphoreSyncFDDescriptor WGPU_STRUCTURE_ATTRIBUTE; +} WGPUSharedFenceVkSemaphoreOpaqueFDDescriptor WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_SHARED_FENCE_VK_SEMAPHORE_SYNC_FD_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceVkSemaphoreSyncFDDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ +#define WGPU_SHARED_FENCE_VK_SEMAPHORE_OPAQUE_FD_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceVkSemaphoreOpaqueFDDescriptor, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceVkSemaphoreOpaqueFDDescriptor} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) // Can be chained in WGPUSharedFenceExportInfo -typedef struct WGPUSharedFenceVkSemaphoreSyncFDExportInfo { - WGPUChainedStructOut chain; +typedef struct WGPUSharedFenceVkSemaphoreOpaqueFDExportInfo { + WGPUChainedStruct chain; int handle; -} WGPUSharedFenceVkSemaphoreSyncFDExportInfo WGPU_STRUCTURE_ATTRIBUTE; +} WGPUSharedFenceVkSemaphoreOpaqueFDExportInfo WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_SHARED_FENCE_VK_SEMAPHORE_SYNC_FD_EXPORT_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceVkSemaphoreSyncFDExportInfo, { \ - /*.chain=*/{} WGPU_COMMA \ +#define WGPU_SHARED_FENCE_VK_SEMAPHORE_OPAQUE_FD_EXPORT_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceVkSemaphoreOpaqueFDExportInfo, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceVkSemaphoreOpaqueFDExportInfo} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) @@ -2316,18 +2040,18 @@ typedef struct WGPUSharedFenceVkSemaphoreZirconHandleDescriptor { } WGPUSharedFenceVkSemaphoreZirconHandleDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_VK_SEMAPHORE_ZIRCON_HANDLE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceVkSemaphoreZirconHandleDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceVkSemaphoreZirconHandleDescriptor} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) // Can be chained in WGPUSharedFenceExportInfo typedef struct WGPUSharedFenceVkSemaphoreZirconHandleExportInfo { - WGPUChainedStructOut chain; + WGPUChainedStruct chain; uint32_t handle; } WGPUSharedFenceVkSemaphoreZirconHandleExportInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_VK_SEMAPHORE_ZIRCON_HANDLE_EXPORT_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceVkSemaphoreZirconHandleExportInfo, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceVkSemaphoreZirconHandleExportInfo} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) @@ -2338,7 +2062,7 @@ typedef struct WGPUSharedTextureMemoryD3DSwapchainBeginState { } WGPUSharedTextureMemoryD3DSwapchainBeginState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_D3D_SWAPCHAIN_BEGIN_STATE_INIT 
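// Sketch of the renamed sync-FD export path (formerly the
// WGPUSharedFenceVkSemaphoreSyncFD* pair); wgpuSharedFenceExportInfo and the
// WGPUSharedFenceType_SyncFD enumerant are assumed from this header revision.
static int exampleExportSyncFD(WGPUSharedFence fence) {
    WGPUSharedFenceSyncFDExportInfo fdInfo = WGPU_SHARED_FENCE_SYNC_FD_EXPORT_INFO_INIT;
    WGPUSharedFenceExportInfo info = {0};
    info.nextInChain = &fdInfo.chain;
    wgpuSharedFenceExportInfo(fence, &info);
    return (info.type == WGPUSharedFenceType_SyncFD) ? fdInfo.handle : -1;
}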
WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryD3DSwapchainBeginState, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryD3DSwapchainBeginState} WGPU_COMMA \ /*.isSwapchain=*/false WGPU_COMMA \ }) @@ -2350,7 +2074,7 @@ typedef struct WGPUSharedTextureMemoryDXGISharedHandleDescriptor { } WGPUSharedTextureMemoryDXGISharedHandleDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_DXGI_SHARED_HANDLE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryDXGISharedHandleDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryDXGISharedHandleDescriptor} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ /*.useKeyedMutex=*/{} WGPU_COMMA \ }) @@ -2362,7 +2086,7 @@ typedef struct WGPUSharedTextureMemoryEGLImageDescriptor { } WGPUSharedTextureMemoryEGLImageDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_EGL_IMAGE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryEGLImageDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryEGLImageDescriptor} WGPU_COMMA \ /*.image=*/{} WGPU_COMMA \ }) @@ -2373,7 +2097,7 @@ typedef struct WGPUSharedTextureMemoryIOSurfaceDescriptor { } WGPUSharedTextureMemoryIOSurfaceDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_IO_SURFACE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryIOSurfaceDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryIOSurfaceDescriptor} WGPU_COMMA \ /*.ioSurface=*/{} WGPU_COMMA \ }) @@ -2385,13 +2109,13 @@ typedef struct WGPUSharedTextureMemoryAHardwareBufferDescriptor { } WGPUSharedTextureMemoryAHardwareBufferDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_A_HARDWARE_BUFFER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryAHardwareBufferDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryAHardwareBufferDescriptor} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ /*.useExternalFormat=*/{} WGPU_COMMA \ }) typedef struct WGPUSharedTextureMemoryBeginAccessDescriptor { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPUBool concurrentRead; WGPUBool initialized; size_t fenceCount; @@ -2408,16 +2132,6 @@ typedef struct WGPUSharedTextureMemoryBeginAccessDescriptor { /*.signaledValues=*/{} WGPU_COMMA \ }) -typedef struct WGPUSharedTextureMemoryDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; -} WGPUSharedTextureMemoryDescriptor WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_SHARED_TEXTURE_MEMORY_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ -}) - typedef struct WGPUSharedTextureMemoryDmaBufPlane { int fd; uint64_t offset; @@ -2431,7 +2145,7 @@ typedef struct WGPUSharedTextureMemoryDmaBufPlane { }) typedef struct WGPUSharedTextureMemoryEndAccessState { - WGPUChainedStructOut * nextInChain; + WGPUChainedStruct* nextInChain; WGPUBool initialized; size_t fenceCount; WGPUSharedFence const * fences; @@ -2457,7 +2171,7 @@ typedef struct WGPUSharedTextureMemoryOpaqueFDDescriptor { } WGPUSharedTextureMemoryOpaqueFDDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_OPAQUE_FD_DESCRIPTOR_INIT 
WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryOpaqueFDDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryOpaqueFDDescriptor} WGPU_COMMA \ /*.vkImageCreateInfo=*/{} WGPU_COMMA \ /*.memoryFD=*/{} WGPU_COMMA \ /*.memoryTypeIndex=*/{} WGPU_COMMA \ @@ -2472,7 +2186,7 @@ typedef struct WGPUSharedTextureMemoryVkDedicatedAllocationDescriptor { } WGPUSharedTextureMemoryVkDedicatedAllocationDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_VK_DEDICATED_ALLOCATION_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryVkDedicatedAllocationDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryVkDedicatedAllocationDescriptor} WGPU_COMMA \ /*.dedicatedAllocation=*/{} WGPU_COMMA \ }) @@ -2484,20 +2198,20 @@ typedef struct WGPUSharedTextureMemoryVkImageLayoutBeginState { } WGPUSharedTextureMemoryVkImageLayoutBeginState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_VK_IMAGE_LAYOUT_BEGIN_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryVkImageLayoutBeginState, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryVkImageLayoutBeginState} WGPU_COMMA \ /*.oldLayout=*/{} WGPU_COMMA \ /*.newLayout=*/{} WGPU_COMMA \ }) // Can be chained in WGPUSharedTextureMemoryEndAccessState typedef struct WGPUSharedTextureMemoryVkImageLayoutEndState { - WGPUChainedStructOut chain; + WGPUChainedStruct chain; int32_t oldLayout; int32_t newLayout; } WGPUSharedTextureMemoryVkImageLayoutEndState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_VK_IMAGE_LAYOUT_END_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryVkImageLayoutEndState, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryVkImageLayoutEndState} WGPU_COMMA \ /*.oldLayout=*/{} WGPU_COMMA \ /*.newLayout=*/{} WGPU_COMMA \ }) @@ -2510,7 +2224,7 @@ typedef struct WGPUSharedTextureMemoryZirconHandleDescriptor { } WGPUSharedTextureMemoryZirconHandleDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_ZIRCON_HANDLE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryZirconHandleDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryZirconHandleDescriptor} WGPU_COMMA \ /*.memoryFD=*/{} WGPU_COMMA \ /*.allocationSize=*/{} WGPU_COMMA \ }) @@ -2519,11 +2233,13 @@ typedef struct WGPUSharedTextureMemoryZirconHandleDescriptor { typedef struct WGPUStaticSamplerBindingLayout { WGPUChainedStruct chain; WGPUSampler sampler; + uint32_t sampledTextureBinding; } WGPUStaticSamplerBindingLayout WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_STATIC_SAMPLER_BINDING_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUStaticSamplerBindingLayout, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_StaticSamplerBindingLayout} WGPU_COMMA \ /*.sampler=*/{} WGPU_COMMA \ + /*.sampledTextureBinding=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ }) typedef struct WGPUStencilFaceState { @@ -2541,7 +2257,7 @@ typedef struct WGPUStencilFaceState { }) typedef struct WGPUStorageTextureBindingLayout { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPUStorageTextureAccess access; WGPUTextureFormat format; WGPUTextureViewDimension viewDimension; @@ -2549,14 +2265,44 @@ typedef struct 
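// Sketch of the new pairing field on static samplers: sampledTextureBinding
// names the texture binding the sampler is used with, and the default
// WGPU_LIMIT_U32_UNDEFINED leaves it unpaired. `sampler` is assumed valid.
static WGPUStaticSamplerBindingLayout exampleStaticSampler(WGPUSampler sampler) {
    WGPUStaticSamplerBindingLayout layout = WGPU_STATIC_SAMPLER_BINDING_LAYOUT_INIT;
    layout.sampler = sampler;
    layout.sampledTextureBinding = 1;   // e.g. pairs with @binding(1) in WGSL
    return layout;
}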
WGPUStorageTextureBindingLayout { #define WGPU_STORAGE_TEXTURE_BINDING_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUStorageTextureBindingLayout, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.access=*/WGPUStorageTextureAccess_Undefined WGPU_COMMA \ + /*.access=*/WGPUStorageTextureAccess_WriteOnly WGPU_COMMA \ /*.format=*/WGPUTextureFormat_Undefined WGPU_COMMA \ /*.viewDimension=*/WGPUTextureViewDimension_2D WGPU_COMMA \ }) +typedef struct WGPUStringView { + WGPU_NULLABLE char const * data; + size_t length; +} WGPUStringView WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_STRING_VIEW_INIT WGPU_MAKE_INIT_STRUCT(WGPUStringView, { \ + /*.data=*/nullptr WGPU_COMMA \ + /*.length=*/WGPU_STRLEN WGPU_COMMA \ +}) + +typedef struct WGPUSupportedWGSLLanguageFeatures { + size_t featureCount; + WGPUWGSLLanguageFeatureName const * features; +} WGPUSupportedWGSLLanguageFeatures WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_SUPPORTED_WGSL_LANGUAGE_FEATURES_INIT WGPU_MAKE_INIT_STRUCT(WGPUSupportedWGSLLanguageFeatures, { \ + /*.featureCount=*/{} WGPU_COMMA \ + /*.features=*/{} WGPU_COMMA \ +}) + +typedef struct WGPUSupportedFeatures { + size_t featureCount; + WGPUFeatureName const * features; +} WGPUSupportedFeatures WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_SUPPORTED_FEATURES_INIT WGPU_MAKE_INIT_STRUCT(WGPUSupportedFeatures, { \ + /*.featureCount=*/{} WGPU_COMMA \ + /*.features=*/{} WGPU_COMMA \ +}) + typedef struct WGPUSurfaceCapabilities { - WGPUChainedStructOut * nextInChain; - WGPUTextureUsageFlags usages; + WGPUChainedStruct* nextInChain; + WGPUTextureUsage usages; size_t formatCount; WGPUTextureFormat const * formats; size_t presentModeCount; @@ -2577,10 +2323,10 @@ typedef struct WGPUSurfaceCapabilities { }) typedef struct WGPUSurfaceConfiguration { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPUDevice device; WGPUTextureFormat format; - WGPUTextureUsageFlags usage; + WGPUTextureUsage usage; size_t viewFormatCount; WGPUTextureFormat const * viewFormats; WGPUCompositeAlphaMode alphaMode; @@ -2595,113 +2341,105 @@ typedef struct WGPUSurfaceConfiguration { /*.format=*/{} WGPU_COMMA \ /*.usage=*/WGPUTextureUsage_RenderAttachment WGPU_COMMA \ /*.viewFormatCount=*/0 WGPU_COMMA \ - /*.viewFormats=*/{} WGPU_COMMA \ + /*.viewFormats=*/nullptr WGPU_COMMA \ /*.alphaMode=*/WGPUCompositeAlphaMode_Auto WGPU_COMMA \ /*.width=*/{} WGPU_COMMA \ /*.height=*/{} WGPU_COMMA \ /*.presentMode=*/WGPUPresentMode_Fifo WGPU_COMMA \ }) -typedef struct WGPUSurfaceDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; -} WGPUSurfaceDescriptor WGPU_STRUCTURE_ATTRIBUTE; +// Can be chained in WGPUSurfaceDescriptor +typedef struct WGPUSurfaceDescriptorFromWindowsCoreWindow { + WGPUChainedStruct chain; + void * coreWindow; +} WGPUSurfaceDescriptorFromWindowsCoreWindow WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_SURFACE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ +#define WGPU_SURFACE_DESCRIPTOR_FROM_WINDOWS_CORE_WINDOW_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceDescriptorFromWindowsCoreWindow, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SurfaceDescriptorFromWindowsCoreWindow} WGPU_COMMA \ + /*.coreWindow=*/{} WGPU_COMMA \ }) // Can be chained in WGPUSurfaceDescriptor -typedef struct WGPUSurfaceDescriptorFromAndroidNativeWindow { +typedef struct WGPUSurfaceDescriptorFromWindowsSwapChainPanel { WGPUChainedStruct chain; - void * window; -} 
WGPUSurfaceDescriptorFromAndroidNativeWindow WGPU_STRUCTURE_ATTRIBUTE; + void * swapChainPanel; +} WGPUSurfaceDescriptorFromWindowsSwapChainPanel WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_SURFACE_DESCRIPTOR_FROM_WINDOWS_SWAP_CHAIN_PANEL_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceDescriptorFromWindowsSwapChainPanel, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SurfaceDescriptorFromWindowsSwapChainPanel} WGPU_COMMA \ + /*.swapChainPanel=*/{} WGPU_COMMA \ +}) + +// Can be chained in WGPUSurfaceDescriptor +typedef struct WGPUSurfaceSourceXCBWindow { + WGPUChainedStruct chain; + void * connection; + uint32_t window; +} WGPUSurfaceSourceXCBWindow WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_SURFACE_DESCRIPTOR_FROM_ANDROID_NATIVE_WINDOW_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceDescriptorFromAndroidNativeWindow, { \ - /*.chain=*/{} WGPU_COMMA \ +#define WGPU_SURFACE_SOURCE_XCB_WINDOW_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceSourceXCBWindow, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceXCBWindow} WGPU_COMMA \ + /*.connection=*/{} WGPU_COMMA \ /*.window=*/{} WGPU_COMMA \ }) // Can be chained in WGPUSurfaceDescriptor -typedef struct WGPUSurfaceDescriptorFromCanvasHTMLSelector { +typedef struct WGPUSurfaceSourceAndroidNativeWindow { WGPUChainedStruct chain; - char const * selector; -} WGPUSurfaceDescriptorFromCanvasHTMLSelector WGPU_STRUCTURE_ATTRIBUTE; + void * window; +} WGPUSurfaceSourceAndroidNativeWindow WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_SURFACE_DESCRIPTOR_FROM_CANVAS_HTML_SELECTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceDescriptorFromCanvasHTMLSelector, { \ - /*.chain=*/{} WGPU_COMMA \ - /*.selector=*/{} WGPU_COMMA \ +#define WGPU_SURFACE_SOURCE_ANDROID_NATIVE_WINDOW_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceSourceAndroidNativeWindow, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceAndroidNativeWindow} WGPU_COMMA \ + /*.window=*/{} WGPU_COMMA \ }) // Can be chained in WGPUSurfaceDescriptor -typedef struct WGPUSurfaceDescriptorFromMetalLayer { +typedef struct WGPUSurfaceSourceMetalLayer { WGPUChainedStruct chain; void * layer; -} WGPUSurfaceDescriptorFromMetalLayer WGPU_STRUCTURE_ATTRIBUTE; +} WGPUSurfaceSourceMetalLayer WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_SURFACE_DESCRIPTOR_FROM_METAL_LAYER_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceDescriptorFromMetalLayer, { \ - /*.chain=*/{} WGPU_COMMA \ +#define WGPU_SURFACE_SOURCE_METAL_LAYER_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceSourceMetalLayer, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceMetalLayer} WGPU_COMMA \ /*.layer=*/{} WGPU_COMMA \ }) // Can be chained in WGPUSurfaceDescriptor -typedef struct WGPUSurfaceDescriptorFromWaylandSurface { +typedef struct WGPUSurfaceSourceWaylandSurface { WGPUChainedStruct chain; void * display; void * surface; -} WGPUSurfaceDescriptorFromWaylandSurface WGPU_STRUCTURE_ATTRIBUTE; +} WGPUSurfaceSourceWaylandSurface WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_SURFACE_DESCRIPTOR_FROM_WAYLAND_SURFACE_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceDescriptorFromWaylandSurface, { \ - /*.chain=*/{} WGPU_COMMA \ +#define WGPU_SURFACE_SOURCE_WAYLAND_SURFACE_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceSourceWaylandSurface, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceWaylandSurface} WGPU_COMMA \ /*.display=*/{} WGPU_COMMA \ /*.surface=*/{} WGPU_COMMA \ }) // Can be chained in WGPUSurfaceDescriptor -typedef struct WGPUSurfaceDescriptorFromWindowsHWND { +typedef struct 
WGPUSurfaceSourceWindowsHWND { WGPUChainedStruct chain; void * hinstance; void * hwnd; -} WGPUSurfaceDescriptorFromWindowsHWND WGPU_STRUCTURE_ATTRIBUTE; +} WGPUSurfaceSourceWindowsHWND WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_SURFACE_DESCRIPTOR_FROM_WINDOWS_HWND_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceDescriptorFromWindowsHWND, { \ - /*.chain=*/{} WGPU_COMMA \ +#define WGPU_SURFACE_SOURCE_WINDOWS_HWND_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceSourceWindowsHWND, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceWindowsHWND} WGPU_COMMA \ /*.hinstance=*/{} WGPU_COMMA \ /*.hwnd=*/{} WGPU_COMMA \ }) // Can be chained in WGPUSurfaceDescriptor -typedef struct WGPUSurfaceDescriptorFromWindowsCoreWindow { - WGPUChainedStruct chain; - void * coreWindow; -} WGPUSurfaceDescriptorFromWindowsCoreWindow WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_SURFACE_DESCRIPTOR_FROM_WINDOWS_CORE_WINDOW_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceDescriptorFromWindowsCoreWindow, { \ - /*.chain=*/{} WGPU_COMMA \ - /*.coreWindow=*/{} WGPU_COMMA \ -}) - -// Can be chained in WGPUSurfaceDescriptor -typedef struct WGPUSurfaceDescriptorFromWindowsSwapChainPanel { - WGPUChainedStruct chain; - void * swapChainPanel; -} WGPUSurfaceDescriptorFromWindowsSwapChainPanel WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_SURFACE_DESCRIPTOR_FROM_WINDOWS_SWAP_CHAIN_PANEL_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceDescriptorFromWindowsSwapChainPanel, { \ - /*.chain=*/{} WGPU_COMMA \ - /*.swapChainPanel=*/{} WGPU_COMMA \ -}) - -// Can be chained in WGPUSurfaceDescriptor -typedef struct WGPUSurfaceDescriptorFromXlibWindow { +typedef struct WGPUSurfaceSourceXlibWindow { WGPUChainedStruct chain; void * display; uint64_t window; -} WGPUSurfaceDescriptorFromXlibWindow WGPU_STRUCTURE_ATTRIBUTE; +} WGPUSurfaceSourceXlibWindow WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_SURFACE_DESCRIPTOR_FROM_XLIB_WINDOW_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceDescriptorFromXlibWindow, { \ - /*.chain=*/{} WGPU_COMMA \ +#define WGPU_SURFACE_SOURCE_XLIB_WINDOW_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceSourceXlibWindow, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceXlibWindow} WGPU_COMMA \ /*.display=*/{} WGPU_COMMA \ /*.window=*/{} WGPU_COMMA \ }) @@ -2718,28 +2456,8 @@ typedef struct WGPUSurfaceTexture { /*.status=*/{} WGPU_COMMA \ }) -typedef struct WGPUSwapChainDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; - WGPUTextureUsageFlags usage; - WGPUTextureFormat format; - uint32_t width; - uint32_t height; - WGPUPresentMode presentMode; -} WGPUSwapChainDescriptor WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_SWAP_CHAIN_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSwapChainDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ - /*.usage=*/{} WGPU_COMMA \ - /*.format=*/{} WGPU_COMMA \ - /*.width=*/{} WGPU_COMMA \ - /*.height=*/{} WGPU_COMMA \ - /*.presentMode=*/{} WGPU_COMMA \ -}) - typedef struct WGPUTextureBindingLayout { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPUTextureSampleType sampleType; WGPUTextureViewDimension viewDimension; WGPUBool multisampled; @@ -2747,7 +2465,7 @@ typedef struct WGPUTextureBindingLayout { #define WGPU_TEXTURE_BINDING_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUTextureBindingLayout, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.sampleType=*/WGPUTextureSampleType_Undefined WGPU_COMMA \ + /*.sampleType=*/WGPUTextureSampleType_Float WGPU_COMMA \ /*.viewDimension=*/WGPUTextureViewDimension_2D 
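// Sketch of the renamed surface sources (WGPUSurfaceDescriptorFrom* became
// WGPUSurfaceSource*), shown for the Metal layer; `instance`, `metalLayer`,
// and wgpuInstanceCreateSurface are assumed, and WGPUSurfaceDescriptor keeps
// its nextInChain member in this revision.
static WGPUSurface exampleCreateMetalSurface(WGPUInstance instance, void *metalLayer) {
    WGPUSurfaceSourceMetalLayer source = WGPU_SURFACE_SOURCE_METAL_LAYER_INIT;
    source.layer = metalLayer;
    WGPUSurfaceDescriptor desc = {0};
    desc.nextInChain = &source.chain;
    return wgpuInstanceCreateSurface(instance, &desc);
}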
WGPU_COMMA \ /*.multisampled=*/false WGPU_COMMA \ }) @@ -2759,12 +2477,12 @@ typedef struct WGPUTextureBindingViewDimensionDescriptor { } WGPUTextureBindingViewDimensionDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_TEXTURE_BINDING_VIEW_DIMENSION_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUTextureBindingViewDimensionDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_TextureBindingViewDimensionDescriptor} WGPU_COMMA \ /*.textureBindingViewDimension=*/WGPUTextureViewDimension_Undefined WGPU_COMMA \ }) typedef struct WGPUTextureDataLayout { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; uint64_t offset; uint32_t bytesPerRow; uint32_t rowsPerImage; @@ -2777,42 +2495,6 @@ typedef struct WGPUTextureDataLayout { /*.rowsPerImage=*/WGPU_COPY_STRIDE_UNDEFINED WGPU_COMMA \ }) -typedef struct WGPUTextureViewDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; - WGPUTextureFormat format; - WGPUTextureViewDimension dimension; - uint32_t baseMipLevel; - uint32_t mipLevelCount; - uint32_t baseArrayLayer; - uint32_t arrayLayerCount; - WGPUTextureAspect aspect; -} WGPUTextureViewDescriptor WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_TEXTURE_VIEW_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUTextureViewDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ - /*.format=*/WGPUTextureFormat_Undefined WGPU_COMMA \ - /*.dimension=*/WGPUTextureViewDimension_Undefined WGPU_COMMA \ - /*.baseMipLevel=*/0 WGPU_COMMA \ - /*.mipLevelCount=*/WGPU_MIP_LEVEL_COUNT_UNDEFINED WGPU_COMMA \ - /*.baseArrayLayer=*/0 WGPU_COMMA \ - /*.arrayLayerCount=*/WGPU_ARRAY_LAYER_COUNT_UNDEFINED WGPU_COMMA \ - /*.aspect=*/WGPUTextureAspect_All WGPU_COMMA \ -}) - -typedef struct WGPUUncapturedErrorCallbackInfo { - WGPUChainedStruct const * nextInChain; - WGPUErrorCallback callback; - void * userdata; -} WGPUUncapturedErrorCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_UNCAPTURED_ERROR_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUUncapturedErrorCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata=*/nullptr WGPU_COMMA \ -}) - typedef struct WGPUVertexAttribute { WGPUVertexFormat format; uint64_t offset; @@ -2844,7 +2526,7 @@ typedef struct WGPUYCbCrVkDescriptor { } WGPUYCbCrVkDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_Y_CB_CR_VK_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUYCbCrVkDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_YCbCrVkDescriptor} WGPU_COMMA \ /*.vkFormat=*/0 WGPU_COMMA \ /*.vkYCbCrModel=*/0 WGPU_COMMA \ /*.vkYCbCrRange=*/0 WGPU_COMMA \ @@ -2867,22 +2549,48 @@ typedef struct WGPUAHardwareBufferProperties { /*.yCbCrInfo=*/WGPU_Y_CB_CR_VK_DESCRIPTOR_INIT WGPU_COMMA \ }) -// Can be chained in WGPUAdapterProperties +typedef struct WGPUAdapterInfo { + WGPUChainedStruct* nextInChain; + WGPUStringView vendor; + WGPUStringView architecture; + WGPUStringView device; + WGPUStringView description; + WGPUBackendType backendType; + WGPUAdapterType adapterType; + uint32_t vendorID; + uint32_t deviceID; + WGPUBool compatibilityMode; +} WGPUAdapterInfo WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_ADAPTER_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUAdapterInfo, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.vendor=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + /*.architecture=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + /*.device=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + 
/*.description=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + /*.backendType=*/{} WGPU_COMMA \ + /*.adapterType=*/{} WGPU_COMMA \ + /*.vendorID=*/{} WGPU_COMMA \ + /*.deviceID=*/{} WGPU_COMMA \ + /*.compatibilityMode=*/false WGPU_COMMA \ +}) + +// Can be chained in WGPUAdapterInfo typedef struct WGPUAdapterPropertiesMemoryHeaps { - WGPUChainedStructOut chain; + WGPUChainedStruct chain; size_t heapCount; WGPUMemoryHeapInfo const * heapInfo; } WGPUAdapterPropertiesMemoryHeaps WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_ADAPTER_PROPERTIES_MEMORY_HEAPS_INIT WGPU_MAKE_INIT_STRUCT(WGPUAdapterPropertiesMemoryHeaps, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_AdapterPropertiesMemoryHeaps} WGPU_COMMA \ /*.heapCount=*/{} WGPU_COMMA \ /*.heapInfo=*/{} WGPU_COMMA \ }) typedef struct WGPUBindGroupDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; + WGPUChainedStruct* nextInChain; + WGPUStringView label; WGPUBindGroupLayout layout; size_t entryCount; WGPUBindGroupEntry const * entries; @@ -2890,16 +2598,16 @@ typedef struct WGPUBindGroupDescriptor { #define WGPU_BIND_GROUP_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUBindGroupDescriptor, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.layout=*/{} WGPU_COMMA \ /*.entryCount=*/{} WGPU_COMMA \ /*.entries=*/{} WGPU_COMMA \ }) typedef struct WGPUBindGroupLayoutEntry { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; uint32_t binding; - WGPUShaderStageFlags visibility; + WGPUShaderStage visibility; WGPUBufferBindingLayout buffer; WGPUSamplerBindingLayout sampler; WGPUTextureBindingLayout texture; @@ -2926,34 +2634,126 @@ typedef struct WGPUBlendState { /*.alpha=*/WGPU_BLEND_COMPONENT_INIT WGPU_COMMA \ }) -typedef struct WGPUCompilationInfo { - WGPUChainedStruct const * nextInChain; - size_t messageCount; - WGPUCompilationMessage const * messages; -} WGPUCompilationInfo WGPU_STRUCTURE_ATTRIBUTE; +typedef struct WGPUBufferDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; + WGPUBufferUsage usage; + uint64_t size; + WGPUBool mappedAtCreation; +} WGPUBufferDescriptor WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_COMPILATION_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUCompilationInfo, { \ +#define WGPU_BUFFER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUBufferDescriptor, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.messageCount=*/{} WGPU_COMMA \ - /*.messages=*/{} WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + /*.usage=*/{} WGPU_COMMA \ + /*.size=*/{} WGPU_COMMA \ + /*.mappedAtCreation=*/false WGPU_COMMA \ +}) + +typedef struct WGPUCommandBufferDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; +} WGPUCommandBufferDescriptor WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_COMMAND_BUFFER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUCommandBufferDescriptor, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ +}) + +typedef struct WGPUCommandEncoderDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; +} WGPUCommandEncoderDescriptor WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_COMMAND_ENCODER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUCommandEncoderDescriptor, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ +}) + +typedef struct WGPUCompilationMessage { + WGPUChainedStruct* nextInChain; + WGPUStringView message; + WGPUCompilationMessageType type; + uint64_t 
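// Sketch of the label migration: descriptor labels are now WGPUStringView
// rather than `char const *`; WGPU_STRLEN (the default length in
// WGPU_STRING_VIEW_INIT above) marks NUL-terminated data.
static WGPUCommandEncoderDescriptor exampleLabeledEncoderDescriptor(void) {
    WGPUCommandEncoderDescriptor desc = WGPU_COMMAND_ENCODER_DESCRIPTOR_INIT;
    desc.label.data = "frame encoder";
    desc.label.length = WGPU_STRLEN;   // or an explicit byte count for unterminated data
    return desc;
}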
lineNum; + uint64_t linePos; + uint64_t offset; + uint64_t length; + uint64_t utf16LinePos; + uint64_t utf16Offset; + uint64_t utf16Length; +} WGPUCompilationMessage WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_COMPILATION_MESSAGE_INIT WGPU_MAKE_INIT_STRUCT(WGPUCompilationMessage, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.message=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + /*.type=*/{} WGPU_COMMA \ + /*.lineNum=*/{} WGPU_COMMA \ + /*.linePos=*/{} WGPU_COMMA \ + /*.offset=*/{} WGPU_COMMA \ + /*.length=*/{} WGPU_COMMA \ + /*.utf16LinePos=*/{} WGPU_COMMA \ + /*.utf16Offset=*/{} WGPU_COMMA \ + /*.utf16Length=*/{} WGPU_COMMA \ }) typedef struct WGPUComputePassDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; + WGPUChainedStruct* nextInChain; + WGPUStringView label; WGPU_NULLABLE WGPUComputePassTimestampWrites const * timestampWrites; } WGPUComputePassDescriptor WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_COMPUTE_PASS_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUComputePassDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ - /*.timestampWrites=*/nullptr WGPU_COMMA \ +#define WGPU_COMPUTE_PASS_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUComputePassDescriptor, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + /*.timestampWrites=*/nullptr WGPU_COMMA \ +}) + +typedef struct WGPUConstantEntry { + WGPUChainedStruct* nextInChain; + WGPUStringView key; + double value; +} WGPUConstantEntry WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_CONSTANT_ENTRY_INIT WGPU_MAKE_INIT_STRUCT(WGPUConstantEntry, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.key=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + /*.value=*/{} WGPU_COMMA \ +}) + +// Can be chained in WGPUDeviceDescriptor +typedef struct WGPUDawnCacheDeviceDescriptor { + WGPUChainedStruct chain; + WGPUStringView isolationKey; + WGPUDawnLoadCacheDataFunction loadDataFunction; + WGPUDawnStoreCacheDataFunction storeDataFunction; + void * functionUserdata; +} WGPUDawnCacheDeviceDescriptor WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_DAWN_CACHE_DEVICE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnCacheDeviceDescriptor, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnCacheDeviceDescriptor} WGPU_COMMA \ + /*.isolationKey=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + /*.loadDataFunction=*/nullptr WGPU_COMMA \ + /*.storeDataFunction=*/nullptr WGPU_COMMA \ + /*.functionUserdata=*/nullptr WGPU_COMMA \ +}) + +// Can be chained in WGPUDawnFormatCapabilities +typedef struct WGPUDawnDrmFormatCapabilities { + WGPUChainedStruct chain; + size_t propertiesCount; + WGPUDawnDrmFormatProperties const * properties; +} WGPUDawnDrmFormatCapabilities WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_DAWN_DRM_FORMAT_CAPABILITIES_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnDrmFormatCapabilities, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnDrmFormatCapabilities} WGPU_COMMA \ + /*.propertiesCount=*/{} WGPU_COMMA \ + /*.properties=*/{} WGPU_COMMA \ }) typedef struct WGPUDepthStencilState { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPUTextureFormat format; - WGPUBool depthWriteEnabled; + WGPUOptionalBool depthWriteEnabled; WGPUCompareFunction depthCompare; WGPUStencilFaceState stencilFront; WGPUStencilFaceState stencilBack; @@ -2967,7 +2767,7 @@ typedef struct WGPUDepthStencilState { #define WGPU_DEPTH_STENCIL_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUDepthStencilState, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ /*.format=*/{} 
WGPU_COMMA \ - /*.depthWriteEnabled=*/false WGPU_COMMA \ + /*.depthWriteEnabled=*/WGPUOptionalBool_Undefined WGPU_COMMA \ /*.depthCompare=*/WGPUCompareFunction_Undefined WGPU_COMMA \ /*.stencilFront=*/WGPU_STENCIL_FACE_STATE_INIT WGPU_COMMA \ /*.stencilBack=*/WGPU_STENCIL_FACE_STATE_INIT WGPU_COMMA \ @@ -2978,26 +2778,25 @@ typedef struct WGPUDepthStencilState { /*.depthBiasClamp=*/0.0f WGPU_COMMA \ }) -// Can be chained in WGPUFormatCapabilities -typedef struct WGPUDrmFormatCapabilities { - WGPUChainedStructOut chain; - size_t propertiesCount; - WGPUDrmFormatProperties const * properties; -} WGPUDrmFormatCapabilities WGPU_STRUCTURE_ATTRIBUTE; +// Can be chained in WGPUSurfaceDescriptor +typedef struct WGPUEmscriptenSurfaceSourceCanvasHTMLSelector { + WGPUChainedStruct chain; + WGPUStringView selector; +} WGPUEmscriptenSurfaceSourceCanvasHTMLSelector WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_DRM_FORMAT_CAPABILITIES_INIT WGPU_MAKE_INIT_STRUCT(WGPUDrmFormatCapabilities, { \ - /*.chain=*/{} WGPU_COMMA \ - /*.propertiesCount=*/{} WGPU_COMMA \ - /*.properties=*/{} WGPU_COMMA \ +#define WGPU_EMSCRIPTEN_SURFACE_SOURCE_CANVAS_HTML_SELECTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUEmscriptenSurfaceSourceCanvasHTMLSelector, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_EmscriptenSurfaceSourceCanvasHTMLSelector} WGPU_COMMA \ + /*.selector=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ }) typedef struct WGPUExternalTextureDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; + WGPUChainedStruct* nextInChain; + WGPUStringView label; WGPUTextureView plane0; WGPU_NULLABLE WGPUTextureView plane1; - WGPUOrigin2D visibleOrigin; - WGPUExtent2D visibleSize; + WGPUOrigin2D cropOrigin; + WGPUExtent2D cropSize; + WGPUExtent2D apparentSize; WGPUBool doYuvToRgbConversionOnly; WGPU_NULLABLE float const * yuvToRgbConversionMatrix; float const * srcTransferFunctionParameters; @@ -3009,11 +2808,12 @@ typedef struct WGPUExternalTextureDescriptor { #define WGPU_EXTERNAL_TEXTURE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUExternalTextureDescriptor, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.plane0=*/{} WGPU_COMMA \ /*.plane1=*/nullptr WGPU_COMMA \ - /*.visibleOrigin=*/WGPU_ORIGIN_2D_INIT WGPU_COMMA \ - /*.visibleSize=*/WGPU_EXTENT_2D_INIT WGPU_COMMA \ + /*.cropOrigin=*/WGPU_ORIGIN_2D_INIT WGPU_COMMA \ + /*.cropSize=*/WGPU_EXTENT_2D_INIT WGPU_COMMA \ + /*.apparentSize=*/WGPU_EXTENT_2D_INIT WGPU_COMMA \ /*.doYuvToRgbConversionOnly=*/false WGPU_COMMA \ /*.yuvToRgbConversionMatrix=*/nullptr WGPU_COMMA \ /*.srcTransferFunctionParameters=*/{} WGPU_COMMA \ @@ -3044,7 +2844,7 @@ typedef struct WGPUImageCopyBuffer { }) typedef struct WGPUImageCopyExternalTexture { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPUExternalTexture externalTexture; WGPUOrigin3D origin; WGPUExtent2D naturalSize; @@ -3072,7 +2872,7 @@ typedef struct WGPUImageCopyTexture { }) typedef struct WGPUInstanceDescriptor { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPUInstanceFeatures features; } WGPUInstanceDescriptor WGPU_STRUCTURE_ATTRIBUTE; @@ -3081,6 +2881,22 @@ typedef struct WGPUInstanceDescriptor { /*.features=*/WGPU_INSTANCE_FEATURES_INIT WGPU_COMMA \ }) +typedef struct WGPUPipelineLayoutDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; + size_t bindGroupLayoutCount; + WGPU_NULLABLE WGPUBindGroupLayout const * bindGroupLayouts; + uint32_t 
immediateDataRangeByteSize; +} WGPUPipelineLayoutDescriptor WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_PIPELINE_LAYOUT_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUPipelineLayoutDescriptor, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + /*.bindGroupLayoutCount=*/{} WGPU_COMMA \ + /*.bindGroupLayouts=*/nullptr WGPU_COMMA \ + /*.immediateDataRangeByteSize=*/0 WGPU_COMMA \ +}) + // Can be chained in WGPUPipelineLayoutDescriptor typedef struct WGPUPipelineLayoutPixelLocalStorage { WGPUChainedStruct chain; @@ -3090,30 +2906,70 @@ typedef struct WGPUPipelineLayoutPixelLocalStorage { } WGPUPipelineLayoutPixelLocalStorage WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_PIPELINE_LAYOUT_PIXEL_LOCAL_STORAGE_INIT WGPU_MAKE_INIT_STRUCT(WGPUPipelineLayoutPixelLocalStorage, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_PipelineLayoutPixelLocalStorage} WGPU_COMMA \ /*.totalPixelLocalStorageSize=*/{} WGPU_COMMA \ /*.storageAttachmentCount=*/0 WGPU_COMMA \ /*.storageAttachments=*/{} WGPU_COMMA \ }) -typedef struct WGPUProgrammableStageDescriptor { - WGPUChainedStruct const * nextInChain; - WGPUShaderModule module; - WGPU_NULLABLE char const * entryPoint; - size_t constantCount; - WGPUConstantEntry const * constants; -} WGPUProgrammableStageDescriptor WGPU_STRUCTURE_ATTRIBUTE; +typedef struct WGPUQuerySetDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; + WGPUQueryType type; + uint32_t count; +} WGPUQuerySetDescriptor WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_PROGRAMMABLE_STAGE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUProgrammableStageDescriptor, { \ +#define WGPU_QUERY_SET_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUQuerySetDescriptor, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.module=*/{} WGPU_COMMA \ - /*.entryPoint=*/nullptr WGPU_COMMA \ - /*.constantCount=*/0 WGPU_COMMA \ - /*.constants=*/{} WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + /*.type=*/{} WGPU_COMMA \ + /*.count=*/{} WGPU_COMMA \ +}) + +typedef struct WGPUQueueDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; +} WGPUQueueDescriptor WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_QUEUE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUQueueDescriptor, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ +}) + +typedef struct WGPURenderBundleDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; +} WGPURenderBundleDescriptor WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_RENDER_BUNDLE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderBundleDescriptor, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ +}) + +typedef struct WGPURenderBundleEncoderDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; + size_t colorFormatCount; + WGPUTextureFormat const * colorFormats; + WGPUTextureFormat depthStencilFormat; + uint32_t sampleCount; + WGPUBool depthReadOnly; + WGPUBool stencilReadOnly; +} WGPURenderBundleEncoderDescriptor WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_RENDER_BUNDLE_ENCODER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderBundleEncoderDescriptor, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + /*.colorFormatCount=*/{} WGPU_COMMA \ + /*.colorFormats=*/{} WGPU_COMMA \ + /*.depthStencilFormat=*/WGPUTextureFormat_Undefined WGPU_COMMA \ + /*.sampleCount=*/1 WGPU_COMMA \ + /*.depthReadOnly=*/false WGPU_COMMA \ + /*.stencilReadOnly=*/false WGPU_COMMA \ }) typedef 
struct WGPURenderPassColorAttachment { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPU_NULLABLE WGPUTextureView view; uint32_t depthSlice; WGPU_NULLABLE WGPUTextureView resolveTarget; @@ -3133,7 +2989,7 @@ typedef struct WGPURenderPassColorAttachment { }) typedef struct WGPURenderPassStorageAttachment { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; uint64_t offset; WGPUTextureView storage; WGPULoadOp loadOp; @@ -3151,7 +3007,7 @@ typedef struct WGPURenderPassStorageAttachment { }) typedef struct WGPURequiredLimits { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPULimits limits; } WGPURequiredLimits WGPU_STRUCTURE_ATTRIBUTE; @@ -3160,17 +3016,98 @@ typedef struct WGPURequiredLimits { /*.limits=*/WGPU_LIMITS_INIT WGPU_COMMA \ }) +typedef struct WGPUSamplerDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; + WGPUAddressMode addressModeU; + WGPUAddressMode addressModeV; + WGPUAddressMode addressModeW; + WGPUFilterMode magFilter; + WGPUFilterMode minFilter; + WGPUMipmapFilterMode mipmapFilter; + float lodMinClamp; + float lodMaxClamp; + WGPUCompareFunction compare; + uint16_t maxAnisotropy; +} WGPUSamplerDescriptor WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_SAMPLER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSamplerDescriptor, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + /*.addressModeU=*/WGPUAddressMode_ClampToEdge WGPU_COMMA \ + /*.addressModeV=*/WGPUAddressMode_ClampToEdge WGPU_COMMA \ + /*.addressModeW=*/WGPUAddressMode_ClampToEdge WGPU_COMMA \ + /*.magFilter=*/WGPUFilterMode_Nearest WGPU_COMMA \ + /*.minFilter=*/WGPUFilterMode_Nearest WGPU_COMMA \ + /*.mipmapFilter=*/WGPUMipmapFilterMode_Nearest WGPU_COMMA \ + /*.lodMinClamp=*/0.0f WGPU_COMMA \ + /*.lodMaxClamp=*/32.0f WGPU_COMMA \ + /*.compare=*/WGPUCompareFunction_Undefined WGPU_COMMA \ + /*.maxAnisotropy=*/1 WGPU_COMMA \ +}) + +typedef struct WGPUShaderModuleDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; +} WGPUShaderModuleDescriptor WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_SHADER_MODULE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUShaderModuleDescriptor, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ +}) + +// Can be chained in WGPUShaderModuleDescriptor +typedef struct WGPUShaderSourceWGSL { + WGPUChainedStruct chain; + WGPUStringView code; +} WGPUShaderSourceWGSL WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_SHADER_SOURCE_WGSL_INIT WGPU_MAKE_INIT_STRUCT(WGPUShaderSourceWGSL, { \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_ShaderSourceWGSL} WGPU_COMMA \ + /*.code=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ +}) + +typedef struct WGPUSharedBufferMemoryDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; +} WGPUSharedBufferMemoryDescriptor WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_SHARED_BUFFER_MEMORY_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedBufferMemoryDescriptor, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ +}) + +typedef struct WGPUSharedFenceDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; +} WGPUSharedFenceDescriptor WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_SHARED_FENCE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceDescriptor, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ +}) + // Can be chained in WGPUSharedTextureMemoryProperties typedef struct 
WGPUSharedTextureMemoryAHardwareBufferProperties { - WGPUChainedStructOut chain; + WGPUChainedStruct chain; WGPUYCbCrVkDescriptor yCbCrInfo; } WGPUSharedTextureMemoryAHardwareBufferProperties WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_A_HARDWARE_BUFFER_PROPERTIES_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryAHardwareBufferProperties, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryAHardwareBufferProperties} WGPU_COMMA \ /*.yCbCrInfo=*/WGPU_Y_CB_CR_VK_DESCRIPTOR_INIT WGPU_COMMA \ }) +typedef struct WGPUSharedTextureMemoryDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; +} WGPUSharedTextureMemoryDescriptor WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_SHARED_TEXTURE_MEMORY_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryDescriptor, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ +}) + // Can be chained in WGPUSharedTextureMemoryDescriptor typedef struct WGPUSharedTextureMemoryDmaBufDescriptor { WGPUChainedStruct chain; @@ -3182,7 +3119,7 @@ typedef struct WGPUSharedTextureMemoryDmaBufDescriptor { } WGPUSharedTextureMemoryDmaBufDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_DMA_BUF_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryDmaBufDescriptor, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryDmaBufDescriptor} WGPU_COMMA \ /*.size=*/WGPU_EXTENT_3D_INIT WGPU_COMMA \ /*.drmFormat=*/{} WGPU_COMMA \ /*.drmModifier=*/{} WGPU_COMMA \ @@ -3191,8 +3128,8 @@ typedef struct WGPUSharedTextureMemoryDmaBufDescriptor { }) typedef struct WGPUSharedTextureMemoryProperties { - WGPUChainedStructOut * nextInChain; - WGPUTextureUsageFlags usage; + WGPUChainedStruct* nextInChain; + WGPUTextureUsage usage; WGPUExtent3D size; WGPUTextureFormat format; } WGPUSharedTextureMemoryProperties WGPU_STRUCTURE_ATTRIBUTE; @@ -3205,7 +3142,7 @@ typedef struct WGPUSharedTextureMemoryProperties { }) typedef struct WGPUSupportedLimits { - WGPUChainedStructOut * nextInChain; + WGPUChainedStruct* nextInChain; WGPULimits limits; } WGPUSupportedLimits WGPU_STRUCTURE_ATTRIBUTE; @@ -3214,10 +3151,20 @@ typedef struct WGPUSupportedLimits { /*.limits=*/WGPU_LIMITS_INIT WGPU_COMMA \ }) +typedef struct WGPUSurfaceDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; +} WGPUSurfaceDescriptor WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_SURFACE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceDescriptor, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ +}) + typedef struct WGPUTextureDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; - WGPUTextureUsageFlags usage; + WGPUChainedStruct* nextInChain; + WGPUStringView label; + WGPUTextureUsage usage; WGPUTextureDimension dimension; WGPUExtent3D size; WGPUTextureFormat format; @@ -3229,7 +3176,7 @@ typedef struct WGPUTextureDescriptor { #define WGPU_TEXTURE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUTextureDescriptor, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.usage=*/{} WGPU_COMMA \ /*.dimension=*/WGPUTextureDimension_2D WGPU_COMMA \ /*.size=*/WGPU_EXTENT_3D_INIT WGPU_COMMA \ @@ -3237,7 +3184,33 @@ typedef struct WGPUTextureDescriptor { /*.mipLevelCount=*/1 WGPU_COMMA \ /*.sampleCount=*/1 WGPU_COMMA \ /*.viewFormatCount=*/0 WGPU_COMMA \ - 
/*.viewFormats=*/{} WGPU_COMMA \ + /*.viewFormats=*/nullptr WGPU_COMMA \ +}) + +typedef struct WGPUTextureViewDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; + WGPUTextureFormat format; + WGPUTextureViewDimension dimension; + uint32_t baseMipLevel; + uint32_t mipLevelCount; + uint32_t baseArrayLayer; + uint32_t arrayLayerCount; + WGPUTextureAspect aspect; + WGPUTextureUsage usage; +} WGPUTextureViewDescriptor WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_TEXTURE_VIEW_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUTextureViewDescriptor, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + /*.format=*/WGPUTextureFormat_Undefined WGPU_COMMA \ + /*.dimension=*/WGPUTextureViewDimension_Undefined WGPU_COMMA \ + /*.baseMipLevel=*/0 WGPU_COMMA \ + /*.mipLevelCount=*/WGPU_MIP_LEVEL_COUNT_UNDEFINED WGPU_COMMA \ + /*.baseArrayLayer=*/0 WGPU_COMMA \ + /*.arrayLayerCount=*/WGPU_ARRAY_LAYER_COUNT_UNDEFINED WGPU_COMMA \ + /*.aspect=*/WGPUTextureAspect_All WGPU_COMMA \ + /*.usage=*/WGPUTextureUsage_None WGPU_COMMA \ }) typedef struct WGPUVertexBufferLayout { @@ -3249,30 +3222,30 @@ typedef struct WGPUVertexBufferLayout { #define WGPU_VERTEX_BUFFER_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUVertexBufferLayout, { \ /*.arrayStride=*/{} WGPU_COMMA \ - /*.stepMode=*/WGPUVertexStepMode_Vertex WGPU_COMMA \ + /*.stepMode=*/{} WGPU_COMMA \ /*.attributeCount=*/{} WGPU_COMMA \ /*.attributes=*/{} WGPU_COMMA \ }) typedef struct WGPUBindGroupLayoutDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; + WGPUChainedStruct* nextInChain; + WGPUStringView label; size_t entryCount; WGPUBindGroupLayoutEntry const * entries; } WGPUBindGroupLayoutDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_BIND_GROUP_LAYOUT_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUBindGroupLayoutDescriptor, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.entryCount=*/{} WGPU_COMMA \ /*.entries=*/{} WGPU_COMMA \ }) typedef struct WGPUColorTargetState { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPUTextureFormat format; WGPU_NULLABLE WGPUBlendState const * blend; - WGPUColorWriteMaskFlags writeMask; + WGPUColorWriteMask writeMask; } WGPUColorTargetState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_COLOR_TARGET_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUColorTargetState, { \ @@ -3282,49 +3255,59 @@ typedef struct WGPUColorTargetState { /*.writeMask=*/WGPUColorWriteMask_All WGPU_COMMA \ }) -typedef struct WGPUComputePipelineDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; - WGPU_NULLABLE WGPUPipelineLayout layout; - WGPUProgrammableStageDescriptor compute; -} WGPUComputePipelineDescriptor WGPU_STRUCTURE_ATTRIBUTE; +typedef struct WGPUCompilationInfo { + WGPUChainedStruct* nextInChain; + size_t messageCount; + WGPUCompilationMessage const * messages; +} WGPUCompilationInfo WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_COMPUTE_PIPELINE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUComputePipelineDescriptor, { \ +#define WGPU_COMPILATION_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUCompilationInfo, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ - /*.layout=*/nullptr WGPU_COMMA \ - /*.compute=*/WGPU_PROGRAMMABLE_STAGE_DESCRIPTOR_INIT WGPU_COMMA \ + /*.messageCount=*/{} WGPU_COMMA \ + /*.messages=*/{} WGPU_COMMA \ +}) + +typedef struct WGPUComputeState { + WGPUChainedStruct* nextInChain; + WGPUShaderModule module; + WGPUStringView 
entryPoint; + size_t constantCount; + WGPUConstantEntry const * constants; +} WGPUComputeState WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_COMPUTE_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUComputeState, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.module=*/{} WGPU_COMMA \ + /*.entryPoint=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + /*.constantCount=*/0 WGPU_COMMA \ + /*.constants=*/{} WGPU_COMMA \ }) typedef struct WGPUDeviceDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; + WGPUChainedStruct* nextInChain; + WGPUStringView label; size_t requiredFeatureCount; WGPUFeatureName const * requiredFeatures; WGPU_NULLABLE WGPURequiredLimits const * requiredLimits; WGPUQueueDescriptor defaultQueue; - WGPUDeviceLostCallback deviceLostCallback; - void * deviceLostUserdata; WGPUDeviceLostCallbackInfo deviceLostCallbackInfo; WGPUUncapturedErrorCallbackInfo uncapturedErrorCallbackInfo; } WGPUDeviceDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DEVICE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUDeviceDescriptor, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.requiredFeatureCount=*/0 WGPU_COMMA \ /*.requiredFeatures=*/nullptr WGPU_COMMA \ /*.requiredLimits=*/nullptr WGPU_COMMA \ /*.defaultQueue=*/WGPU_QUEUE_DESCRIPTOR_INIT WGPU_COMMA \ - /*.deviceLostCallback=*/nullptr WGPU_COMMA \ - /*.deviceLostUserdata=*/nullptr WGPU_COMMA \ - /*.deviceLostCallbackInfo=*/WGPU_DEVICE_LOST_CALLBACK_INFO_INIT WGPU_COMMA \ - /*.uncapturedErrorCallbackInfo=*/WGPU_UNCAPTURED_ERROR_CALLBACK_INFO_INIT WGPU_COMMA \ + /*.deviceLostCallbackInfo=*/{} WGPU_COMMA \ + /*.uncapturedErrorCallbackInfo=*/{} WGPU_COMMA \ }) typedef struct WGPURenderPassDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; + WGPUChainedStruct* nextInChain; + WGPUStringView label; size_t colorAttachmentCount; WGPURenderPassColorAttachment const * colorAttachments; WGPU_NULLABLE WGPURenderPassDepthStencilAttachment const * depthStencilAttachment; @@ -3334,7 +3317,7 @@ typedef struct WGPURenderPassDescriptor { #define WGPU_RENDER_PASS_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderPassDescriptor, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.colorAttachmentCount=*/{} WGPU_COMMA \ /*.colorAttachments=*/{} WGPU_COMMA \ /*.depthStencilAttachment=*/nullptr WGPU_COMMA \ @@ -3351,16 +3334,16 @@ typedef struct WGPURenderPassPixelLocalStorage { } WGPURenderPassPixelLocalStorage WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_RENDER_PASS_PIXEL_LOCAL_STORAGE_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderPassPixelLocalStorage, { \ - /*.chain=*/{} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_RenderPassPixelLocalStorage} WGPU_COMMA \ /*.totalPixelLocalStorageSize=*/{} WGPU_COMMA \ /*.storageAttachmentCount=*/0 WGPU_COMMA \ /*.storageAttachments=*/{} WGPU_COMMA \ }) typedef struct WGPUVertexState { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPUShaderModule module; - WGPU_NULLABLE char const * entryPoint; + WGPUStringView entryPoint; size_t constantCount; WGPUConstantEntry const * constants; size_t bufferCount; @@ -3370,17 +3353,31 @@ typedef struct WGPUVertexState { #define WGPU_VERTEX_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUVertexState, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ /*.module=*/{} WGPU_COMMA \ - /*.entryPoint=*/nullptr WGPU_COMMA \ + /*.entryPoint=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ 
/*.constantCount=*/0 WGPU_COMMA \ /*.constants=*/{} WGPU_COMMA \ /*.bufferCount=*/0 WGPU_COMMA \ /*.buffers=*/{} WGPU_COMMA \ }) +typedef struct WGPUComputePipelineDescriptor { + WGPUChainedStruct* nextInChain; + WGPUStringView label; + WGPU_NULLABLE WGPUPipelineLayout layout; + WGPUComputeState compute; +} WGPUComputePipelineDescriptor WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_COMPUTE_PIPELINE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUComputePipelineDescriptor, { \ + /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ + /*.layout=*/nullptr WGPU_COMMA \ + /*.compute=*/WGPU_COMPUTE_STATE_INIT WGPU_COMMA \ +}) + typedef struct WGPUFragmentState { - WGPUChainedStruct const * nextInChain; + WGPUChainedStruct* nextInChain; WGPUShaderModule module; - WGPU_NULLABLE char const * entryPoint; + WGPUStringView entryPoint; size_t constantCount; WGPUConstantEntry const * constants; size_t targetCount; @@ -3390,7 +3387,7 @@ typedef struct WGPUFragmentState { #define WGPU_FRAGMENT_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUFragmentState, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ /*.module=*/{} WGPU_COMMA \ - /*.entryPoint=*/nullptr WGPU_COMMA \ + /*.entryPoint=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.constantCount=*/0 WGPU_COMMA \ /*.constants=*/{} WGPU_COMMA \ /*.targetCount=*/{} WGPU_COMMA \ @@ -3398,8 +3395,8 @@ typedef struct WGPUFragmentState { }) typedef struct WGPURenderPipelineDescriptor { - WGPUChainedStruct const * nextInChain; - WGPU_NULLABLE char const * label; + WGPUChainedStruct* nextInChain; + WGPUStringView label; WGPU_NULLABLE WGPUPipelineLayout layout; WGPUVertexState vertex; WGPUPrimitiveState primitive; @@ -3410,7 +3407,7 @@ typedef struct WGPURenderPipelineDescriptor { #define WGPU_RENDER_PIPELINE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderPipelineDescriptor, { \ /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.label=*/nullptr WGPU_COMMA \ + /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.layout=*/nullptr WGPU_COMMA \ /*.vertex=*/WGPU_VERTEX_STATE_INIT WGPU_COMMA \ /*.primitive=*/WGPU_PRIMITIVE_STATE_INIT WGPU_COMMA \ @@ -3419,125 +3416,41 @@ typedef struct WGPURenderPipelineDescriptor { /*.fragment=*/nullptr WGPU_COMMA \ }) -typedef struct WGPUBufferMapCallbackInfo2 { - WGPUChainedStruct const* nextInChain; - WGPUCallbackMode mode; - WGPUBufferMapCallback2 callback; - void* userdata1; - void* userdata2; -} WGPUBufferMapCallbackInfo2 WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_BUFFER_MAP_CALLBACK_INFO_2_INIT WGPU_MAKE_INIT_STRUCT(WGPUBufferMapCallbackInfo2, { \ - /*.mode=*/WGPUCallbackMode_Undefined WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ -}) - -typedef struct WGPUCompilationInfoCallbackInfo2 { - WGPUChainedStruct const* nextInChain; - WGPUCallbackMode mode; - WGPUCompilationInfoCallback2 callback; - void* userdata1; - void* userdata2; -} WGPUCompilationInfoCallbackInfo2 WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_COMPILATION_INFO_CALLBACK_INFO_2_INIT WGPU_MAKE_INIT_STRUCT(WGPUCompilationInfoCallbackInfo2, { \ - /*.mode=*/WGPUCallbackMode_Undefined WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ -}) - -typedef struct WGPUCreateComputePipelineAsyncCallbackInfo2 { - WGPUChainedStruct const* nextInChain; - WGPUCallbackMode mode; - WGPUCreateComputePipelineAsyncCallback2 callback; - void* userdata1; - void* userdata2; -} WGPUCreateComputePipelineAsyncCallbackInfo2 WGPU_STRUCTURE_ATTRIBUTE; - -#define 
WGPU_CREATE_COMPUTE_PIPELINE_ASYNC_CALLBACK_INFO_2_INIT WGPU_MAKE_INIT_STRUCT(WGPUCreateComputePipelineAsyncCallbackInfo2, { \ - /*.mode=*/WGPUCallbackMode_Undefined WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ -}) - -typedef struct WGPUCreateRenderPipelineAsyncCallbackInfo2 { - WGPUChainedStruct const* nextInChain; - WGPUCallbackMode mode; - WGPUCreateRenderPipelineAsyncCallback2 callback; - void* userdata1; - void* userdata2; -} WGPUCreateRenderPipelineAsyncCallbackInfo2 WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_CREATE_RENDER_PIPELINE_ASYNC_CALLBACK_INFO_2_INIT WGPU_MAKE_INIT_STRUCT(WGPUCreateRenderPipelineAsyncCallbackInfo2, { \ - /*.mode=*/WGPUCallbackMode_Undefined WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ -}) +// WGPURenderPassDescriptorMaxDrawCount is deprecated. +// Use WGPURenderPassMaxDrawCount instead. +typedef WGPURenderPassMaxDrawCount WGPURenderPassDescriptorMaxDrawCount; -typedef struct WGPUPopErrorScopeCallbackInfo2 { - WGPUChainedStruct const* nextInChain; - WGPUCallbackMode mode; - WGPUPopErrorScopeCallback2 callback; - void* userdata1; - void* userdata2; -} WGPUPopErrorScopeCallbackInfo2 WGPU_STRUCTURE_ATTRIBUTE; +// WGPUShaderModuleSPIRVDescriptor is deprecated. +// Use WGPUShaderSourceSPIRV instead. +typedef WGPUShaderSourceSPIRV WGPUShaderModuleSPIRVDescriptor; -#define WGPU_POP_ERROR_SCOPE_CALLBACK_INFO_2_INIT WGPU_MAKE_INIT_STRUCT(WGPUPopErrorScopeCallbackInfo2, { \ - /*.mode=*/WGPUCallbackMode_Undefined WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ -}) +// WGPUShaderModuleWGSLDescriptor is deprecated. +// Use WGPUShaderSourceWGSL instead. +typedef WGPUShaderSourceWGSL WGPUShaderModuleWGSLDescriptor; -typedef struct WGPUQueueWorkDoneCallbackInfo2 { - WGPUChainedStruct const* nextInChain; - WGPUCallbackMode mode; - WGPUQueueWorkDoneCallback2 callback; - void* userdata1; - void* userdata2; -} WGPUQueueWorkDoneCallbackInfo2 WGPU_STRUCTURE_ATTRIBUTE; +// WGPUSurfaceDescriptorFromAndroidNativeWindow is deprecated. +// Use WGPUSurfaceSourceAndroidNativeWindow instead. +typedef WGPUSurfaceSourceAndroidNativeWindow WGPUSurfaceDescriptorFromAndroidNativeWindow; -#define WGPU_QUEUE_WORK_DONE_CALLBACK_INFO_2_INIT WGPU_MAKE_INIT_STRUCT(WGPUQueueWorkDoneCallbackInfo2, { \ - /*.mode=*/WGPUCallbackMode_Undefined WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ -}) +// WGPUSurfaceDescriptorFromMetalLayer is deprecated. +// Use WGPUSurfaceSourceMetalLayer instead. +typedef WGPUSurfaceSourceMetalLayer WGPUSurfaceDescriptorFromMetalLayer; -typedef struct WGPURequestAdapterCallbackInfo2 { - WGPUChainedStruct const* nextInChain; - WGPUCallbackMode mode; - WGPURequestAdapterCallback2 callback; - void* userdata1; - void* userdata2; -} WGPURequestAdapterCallbackInfo2 WGPU_STRUCTURE_ATTRIBUTE; +// WGPUSurfaceDescriptorFromWaylandSurface is deprecated. +// Use WGPUSurfaceSourceWaylandSurface instead. 
+typedef WGPUSurfaceSourceWaylandSurface WGPUSurfaceDescriptorFromWaylandSurface; -#define WGPU_REQUEST_ADAPTER_CALLBACK_INFO_2_INIT WGPU_MAKE_INIT_STRUCT(WGPURequestAdapterCallbackInfo2, { \ - /*.mode=*/WGPUCallbackMode_Undefined WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ -}) +// WGPUSurfaceDescriptorFromWindowsHWND is deprecated. +// Use WGPUSurfaceSourceWindowsHWND instead. +typedef WGPUSurfaceSourceWindowsHWND WGPUSurfaceDescriptorFromWindowsHWND; -typedef struct WGPURequestDeviceCallbackInfo2 { - WGPUChainedStruct const* nextInChain; - WGPUCallbackMode mode; - WGPURequestDeviceCallback2 callback; - void* userdata1; - void* userdata2; -} WGPURequestDeviceCallbackInfo2 WGPU_STRUCTURE_ATTRIBUTE; +// WGPUSurfaceDescriptorFromXcbWindow is deprecated. +// Use WGPUSurfaceSourceXCBWindow instead. +typedef WGPUSurfaceSourceXCBWindow WGPUSurfaceDescriptorFromXcbWindow; -#define WGPU_REQUEST_DEVICE_CALLBACK_INFO_2_INIT WGPU_MAKE_INIT_STRUCT(WGPURequestDeviceCallbackInfo2, { \ - /*.mode=*/WGPUCallbackMode_Undefined WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ -}) +// WGPUSurfaceDescriptorFromXlibWindow is deprecated. +// Use WGPUSurfaceSourceXlibWindow instead. +typedef WGPUSurfaceSourceXlibWindow WGPUSurfaceDescriptorFromXlibWindow; #ifdef __cplusplus extern "C" { @@ -3545,39 +3458,43 @@ extern "C" { #if !defined(WGPU_SKIP_PROCS) -typedef void (*WGPUProcAdapterInfoFreeMembers)(WGPUAdapterInfo value) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcAdapterPropertiesFreeMembers)(WGPUAdapterProperties value) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcAdapterPropertiesMemoryHeapsFreeMembers)(WGPUAdapterPropertiesMemoryHeaps value) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUInstance (*WGPUProcCreateInstance)(WGPUInstanceDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcDrmFormatCapabilitiesFreeMembers)(WGPUDrmFormatCapabilities value) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUStatus (*WGPUProcGetInstanceFeatures)(WGPUInstanceFeatures * features) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUProc (*WGPUProcGetProcAddress)(WGPUDevice device, char const * procName) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcSharedBufferMemoryEndAccessStateFreeMembers)(WGPUSharedBufferMemoryEndAccessState value) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcSharedTextureMemoryEndAccessStateFreeMembers)(WGPUSharedTextureMemoryEndAccessState value) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcSurfaceCapabilitiesFreeMembers)(WGPUSurfaceCapabilities value) WGPU_FUNCTION_ATTRIBUTE; +// TODO(374150686): Remove these Emscripten specific declarations from the +// header once they are fully deprecated. 
+#ifdef __EMSCRIPTEN__ +WGPU_EXPORT WGPUDevice emscripten_webgpu_get_device(void); +#endif + +typedef void (*WGPUProcAdapterInfoFreeMembers)( WGPUAdapterInfo value) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcAdapterPropertiesMemoryHeapsFreeMembers)( WGPUAdapterPropertiesMemoryHeaps value) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUInstance (*WGPUProcCreateInstance)( WGPU_NULLABLE WGPUInstanceDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcDawnDrmFormatCapabilitiesFreeMembers)( WGPUDawnDrmFormatCapabilities value) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUStatus (*WGPUProcGetInstanceFeatures)( WGPUInstanceFeatures * features) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUProc (*WGPUProcGetProcAddress)( WGPUStringView procName) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcSharedBufferMemoryEndAccessStateFreeMembers)( WGPUSharedBufferMemoryEndAccessState value) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcSharedTextureMemoryEndAccessStateFreeMembers)( WGPUSharedTextureMemoryEndAccessState value) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcSupportedWGSLLanguageFeaturesFreeMembers)( WGPUSupportedWGSLLanguageFeatures value) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcSupportedFeaturesFreeMembers)( WGPUSupportedFeatures value) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcSurfaceCapabilitiesFreeMembers)( WGPUSurfaceCapabilities value) WGPU_FUNCTION_ATTRIBUTE; // Procs of Adapter typedef WGPUDevice (*WGPUProcAdapterCreateDevice)(WGPUAdapter adapter, WGPU_NULLABLE WGPUDeviceDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -typedef size_t (*WGPUProcAdapterEnumerateFeatures)(WGPUAdapter adapter, WGPUFeatureName * features) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUStatus (*WGPUProcAdapterGetFormatCapabilities)(WGPUAdapter adapter, WGPUTextureFormat format, WGPUFormatCapabilities * capabilities) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcAdapterGetFeatures)(WGPUAdapter adapter, WGPUSupportedFeatures * features) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUStatus (*WGPUProcAdapterGetFormatCapabilities)(WGPUAdapter adapter, WGPUTextureFormat format, WGPUDawnFormatCapabilities * capabilities) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUStatus (*WGPUProcAdapterGetInfo)(WGPUAdapter adapter, WGPUAdapterInfo * info) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUInstance (*WGPUProcAdapterGetInstance)(WGPUAdapter adapter) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUStatus (*WGPUProcAdapterGetLimits)(WGPUAdapter adapter, WGPUSupportedLimits * limits) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUStatus (*WGPUProcAdapterGetProperties)(WGPUAdapter adapter, WGPUAdapterProperties * properties) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUBool (*WGPUProcAdapterHasFeature)(WGPUAdapter adapter, WGPUFeatureName feature) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcAdapterRequestDevice)(WGPUAdapter adapter, WGPU_NULLABLE WGPUDeviceDescriptor const * descriptor, WGPURequestDeviceCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcAdapterRequestDevice2)(WGPUAdapter adapter, WGPU_NULLABLE WGPUDeviceDescriptor const * options, WGPURequestDeviceCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcAdapterRequestDeviceF)(WGPUAdapter adapter, WGPU_NULLABLE WGPUDeviceDescriptor const * options, WGPURequestDeviceCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUFuture (*WGPUProcAdapterRequestDevice)(WGPUAdapter adapter, WGPU_NULLABLE WGPUDeviceDescriptor const * options, WGPURequestDeviceCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; 
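// A minimal usage sketch of the future-based request path declared just above,
// assuming the unified two-userdata callback shape and the
// WGPU_REQUEST_DEVICE_CALLBACK_INFO_INIT initializer that accompany these
// headers; the names OnDeviceReady and requestedDevice are illustrative only,
// and a valid WGPUAdapter adapter (and its WGPUInstance) are assumed in scope.
//
//   static void OnDeviceReady(WGPURequestDeviceStatus status, WGPUDevice device,
//                             WGPUStringView message, void * userdata1, void * userdata2) {
//     if (status == WGPURequestDeviceStatus_Success) {
//       *(WGPUDevice *)userdata1 = device;  // hand the device back to the requester
//     }
//   }
//
//   WGPUDevice requestedDevice = NULL;
//   WGPURequestDeviceCallbackInfo callbackInfo = WGPU_REQUEST_DEVICE_CALLBACK_INFO_INIT;
//   callbackInfo.mode      = WGPUCallbackMode_AllowSpontaneous;
//   callbackInfo.callback  = OnDeviceReady;
//   callbackInfo.userdata1 = &requestedDevice;
//   WGPUFuture future = wgpuAdapterRequestDevice(adapter, /*options=*/NULL, callbackInfo);
//   // Completion is then driven through the instance, e.g. with
//   // wgpuInstanceProcessEvents or wgpuInstanceWaitAny, replacing the old
//   // fire-and-forget callback + userdata pair removed in this hunk.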
typedef void (*WGPUProcAdapterAddRef)(WGPUAdapter adapter) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcAdapterRelease)(WGPUAdapter adapter) WGPU_FUNCTION_ATTRIBUTE; // Procs of BindGroup -typedef void (*WGPUProcBindGroupSetLabel)(WGPUBindGroup bindGroup, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcBindGroupSetLabel)(WGPUBindGroup bindGroup, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcBindGroupAddRef)(WGPUBindGroup bindGroup) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcBindGroupRelease)(WGPUBindGroup bindGroup) WGPU_FUNCTION_ATTRIBUTE; // Procs of BindGroupLayout -typedef void (*WGPUProcBindGroupLayoutSetLabel)(WGPUBindGroupLayout bindGroupLayout, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcBindGroupLayoutSetLabel)(WGPUBindGroupLayout bindGroupLayout, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcBindGroupLayoutAddRef)(WGPUBindGroupLayout bindGroupLayout) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcBindGroupLayoutRelease)(WGPUBindGroupLayout bindGroupLayout) WGPU_FUNCTION_ATTRIBUTE; @@ -3587,17 +3504,15 @@ typedef void const * (*WGPUProcBufferGetConstMappedRange)(WGPUBuffer buffer, siz typedef WGPUBufferMapState (*WGPUProcBufferGetMapState)(WGPUBuffer buffer) WGPU_FUNCTION_ATTRIBUTE; typedef void * (*WGPUProcBufferGetMappedRange)(WGPUBuffer buffer, size_t offset, size_t size) WGPU_FUNCTION_ATTRIBUTE; typedef uint64_t (*WGPUProcBufferGetSize)(WGPUBuffer buffer) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUBufferUsageFlags (*WGPUProcBufferGetUsage)(WGPUBuffer buffer) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcBufferMapAsync)(WGPUBuffer buffer, WGPUMapModeFlags mode, size_t offset, size_t size, WGPUBufferMapCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcBufferMapAsync2)(WGPUBuffer buffer, WGPUMapModeFlags mode, size_t offset, size_t size, WGPUBufferMapCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcBufferMapAsyncF)(WGPUBuffer buffer, WGPUMapModeFlags mode, size_t offset, size_t size, WGPUBufferMapCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcBufferSetLabel)(WGPUBuffer buffer, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUBufferUsage (*WGPUProcBufferGetUsage)(WGPUBuffer buffer) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUFuture (*WGPUProcBufferMapAsync)(WGPUBuffer buffer, WGPUMapMode mode, size_t offset, size_t size, WGPUBufferMapCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcBufferSetLabel)(WGPUBuffer buffer, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcBufferUnmap)(WGPUBuffer buffer) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcBufferAddRef)(WGPUBuffer buffer) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcBufferRelease)(WGPUBuffer buffer) WGPU_FUNCTION_ATTRIBUTE; // Procs of CommandBuffer -typedef void (*WGPUProcCommandBufferSetLabel)(WGPUCommandBuffer commandBuffer, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcCommandBufferSetLabel)(WGPUCommandBuffer commandBuffer, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcCommandBufferAddRef)(WGPUCommandBuffer commandBuffer) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcCommandBufferRelease)(WGPUCommandBuffer commandBuffer) WGPU_FUNCTION_ATTRIBUTE; @@ -3610,12 +3525,12 @@ typedef void (*WGPUProcCommandEncoderCopyBufferToTexture)(WGPUCommandEncoder com typedef void 
(*WGPUProcCommandEncoderCopyTextureToBuffer)(WGPUCommandEncoder commandEncoder, WGPUImageCopyTexture const * source, WGPUImageCopyBuffer const * destination, WGPUExtent3D const * copySize) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcCommandEncoderCopyTextureToTexture)(WGPUCommandEncoder commandEncoder, WGPUImageCopyTexture const * source, WGPUImageCopyTexture const * destination, WGPUExtent3D const * copySize) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUCommandBuffer (*WGPUProcCommandEncoderFinish)(WGPUCommandEncoder commandEncoder, WGPU_NULLABLE WGPUCommandBufferDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcCommandEncoderInjectValidationError)(WGPUCommandEncoder commandEncoder, char const * message) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcCommandEncoderInsertDebugMarker)(WGPUCommandEncoder commandEncoder, char const * markerLabel) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcCommandEncoderInjectValidationError)(WGPUCommandEncoder commandEncoder, WGPUStringView message) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcCommandEncoderInsertDebugMarker)(WGPUCommandEncoder commandEncoder, WGPUStringView markerLabel) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcCommandEncoderPopDebugGroup)(WGPUCommandEncoder commandEncoder) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcCommandEncoderPushDebugGroup)(WGPUCommandEncoder commandEncoder, char const * groupLabel) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcCommandEncoderPushDebugGroup)(WGPUCommandEncoder commandEncoder, WGPUStringView groupLabel) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcCommandEncoderResolveQuerySet)(WGPUCommandEncoder commandEncoder, WGPUQuerySet querySet, uint32_t firstQuery, uint32_t queryCount, WGPUBuffer destination, uint64_t destinationOffset) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcCommandEncoderSetLabel)(WGPUCommandEncoder commandEncoder, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcCommandEncoderSetLabel)(WGPUCommandEncoder commandEncoder, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcCommandEncoderWriteBuffer)(WGPUCommandEncoder commandEncoder, WGPUBuffer buffer, uint64_t bufferOffset, uint8_t const * data, uint64_t size) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcCommandEncoderWriteTimestamp)(WGPUCommandEncoder commandEncoder, WGPUQuerySet querySet, uint32_t queryIndex) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcCommandEncoderAddRef)(WGPUCommandEncoder commandEncoder) WGPU_FUNCTION_ATTRIBUTE; @@ -3625,11 +3540,11 @@ typedef void (*WGPUProcCommandEncoderRelease)(WGPUCommandEncoder commandEncoder) typedef void (*WGPUProcComputePassEncoderDispatchWorkgroups)(WGPUComputePassEncoder computePassEncoder, uint32_t workgroupCountX, uint32_t workgroupCountY, uint32_t workgroupCountZ) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcComputePassEncoderDispatchWorkgroupsIndirect)(WGPUComputePassEncoder computePassEncoder, WGPUBuffer indirectBuffer, uint64_t indirectOffset) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcComputePassEncoderEnd)(WGPUComputePassEncoder computePassEncoder) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcComputePassEncoderInsertDebugMarker)(WGPUComputePassEncoder computePassEncoder, char const * markerLabel) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcComputePassEncoderInsertDebugMarker)(WGPUComputePassEncoder computePassEncoder, WGPUStringView markerLabel) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcComputePassEncoderPopDebugGroup)(WGPUComputePassEncoder 
computePassEncoder) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcComputePassEncoderPushDebugGroup)(WGPUComputePassEncoder computePassEncoder, char const * groupLabel) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcComputePassEncoderPushDebugGroup)(WGPUComputePassEncoder computePassEncoder, WGPUStringView groupLabel) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcComputePassEncoderSetBindGroup)(WGPUComputePassEncoder computePassEncoder, uint32_t groupIndex, WGPU_NULLABLE WGPUBindGroup group, size_t dynamicOffsetCount, uint32_t const * dynamicOffsets) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcComputePassEncoderSetLabel)(WGPUComputePassEncoder computePassEncoder, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcComputePassEncoderSetLabel)(WGPUComputePassEncoder computePassEncoder, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcComputePassEncoderSetPipeline)(WGPUComputePassEncoder computePassEncoder, WGPUComputePipeline pipeline) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcComputePassEncoderWriteTimestamp)(WGPUComputePassEncoder computePassEncoder, WGPUQuerySet querySet, uint32_t queryIndex) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcComputePassEncoderAddRef)(WGPUComputePassEncoder computePassEncoder) WGPU_FUNCTION_ATTRIBUTE; @@ -3637,7 +3552,7 @@ typedef void (*WGPUProcComputePassEncoderRelease)(WGPUComputePassEncoder compute // Procs of ComputePipeline typedef WGPUBindGroupLayout (*WGPUProcComputePipelineGetBindGroupLayout)(WGPUComputePipeline computePipeline, uint32_t groupIndex) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcComputePipelineSetLabel)(WGPUComputePipeline computePipeline, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcComputePipelineSetLabel)(WGPUComputePipeline computePipeline, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcComputePipelineAddRef)(WGPUComputePipeline computePipeline) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcComputePipelineRelease)(WGPUComputePipeline computePipeline) WGPU_FUNCTION_ATTRIBUTE; @@ -3647,46 +3562,38 @@ typedef WGPUBindGroupLayout (*WGPUProcDeviceCreateBindGroupLayout)(WGPUDevice de typedef WGPUBuffer (*WGPUProcDeviceCreateBuffer)(WGPUDevice device, WGPUBufferDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUCommandEncoder (*WGPUProcDeviceCreateCommandEncoder)(WGPUDevice device, WGPU_NULLABLE WGPUCommandEncoderDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUComputePipeline (*WGPUProcDeviceCreateComputePipeline)(WGPUDevice device, WGPUComputePipelineDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcDeviceCreateComputePipelineAsync)(WGPUDevice device, WGPUComputePipelineDescriptor const * descriptor, WGPUCreateComputePipelineAsyncCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcDeviceCreateComputePipelineAsync2)(WGPUDevice device, WGPUComputePipelineDescriptor const * descriptor, WGPUCreateComputePipelineAsyncCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcDeviceCreateComputePipelineAsyncF)(WGPUDevice device, WGPUComputePipelineDescriptor const * descriptor, WGPUCreateComputePipelineAsyncCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUFuture (*WGPUProcDeviceCreateComputePipelineAsync)(WGPUDevice device, WGPUComputePipelineDescriptor const * descriptor, WGPUCreateComputePipelineAsyncCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUBuffer 
(*WGPUProcDeviceCreateErrorBuffer)(WGPUDevice device, WGPUBufferDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUExternalTexture (*WGPUProcDeviceCreateErrorExternalTexture)(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUShaderModule (*WGPUProcDeviceCreateErrorShaderModule)(WGPUDevice device, WGPUShaderModuleDescriptor const * descriptor, char const * errorMessage) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUShaderModule (*WGPUProcDeviceCreateErrorShaderModule)(WGPUDevice device, WGPUShaderModuleDescriptor const * descriptor, WGPUStringView errorMessage) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUTexture (*WGPUProcDeviceCreateErrorTexture)(WGPUDevice device, WGPUTextureDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUExternalTexture (*WGPUProcDeviceCreateExternalTexture)(WGPUDevice device, WGPUExternalTextureDescriptor const * externalTextureDescriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUPipelineLayout (*WGPUProcDeviceCreatePipelineLayout)(WGPUDevice device, WGPUPipelineLayoutDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUQuerySet (*WGPUProcDeviceCreateQuerySet)(WGPUDevice device, WGPUQuerySetDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPURenderBundleEncoder (*WGPUProcDeviceCreateRenderBundleEncoder)(WGPUDevice device, WGPURenderBundleEncoderDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPURenderPipeline (*WGPUProcDeviceCreateRenderPipeline)(WGPUDevice device, WGPURenderPipelineDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcDeviceCreateRenderPipelineAsync)(WGPUDevice device, WGPURenderPipelineDescriptor const * descriptor, WGPUCreateRenderPipelineAsyncCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcDeviceCreateRenderPipelineAsync2)(WGPUDevice device, WGPURenderPipelineDescriptor const * descriptor, WGPUCreateRenderPipelineAsyncCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcDeviceCreateRenderPipelineAsyncF)(WGPUDevice device, WGPURenderPipelineDescriptor const * descriptor, WGPUCreateRenderPipelineAsyncCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUFuture (*WGPUProcDeviceCreateRenderPipelineAsync)(WGPUDevice device, WGPURenderPipelineDescriptor const * descriptor, WGPUCreateRenderPipelineAsyncCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUSampler (*WGPUProcDeviceCreateSampler)(WGPUDevice device, WGPU_NULLABLE WGPUSamplerDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUShaderModule (*WGPUProcDeviceCreateShaderModule)(WGPUDevice device, WGPUShaderModuleDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUSwapChain (*WGPUProcDeviceCreateSwapChain)(WGPUDevice device, WGPUSurface surface, WGPUSwapChainDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUTexture (*WGPUProcDeviceCreateTexture)(WGPUDevice device, WGPUTextureDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcDeviceDestroy)(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE; -typedef size_t (*WGPUProcDeviceEnumerateFeatures)(WGPUDevice device, WGPUFeatureName * features) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcDeviceForceLoss)(WGPUDevice device, WGPUDeviceLostReason type, char const * message) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcDeviceForceLoss)(WGPUDevice device, WGPUDeviceLostReason type, WGPUStringView message) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUStatus 
(*WGPUProcDeviceGetAHardwareBufferProperties)(WGPUDevice device, void * handle, WGPUAHardwareBufferProperties * properties) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUAdapter (*WGPUProcDeviceGetAdapter)(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUStatus (*WGPUProcDeviceGetAdapterInfo)(WGPUDevice device, WGPUAdapterInfo * adapterInfo) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcDeviceGetFeatures)(WGPUDevice device, WGPUSupportedFeatures * features) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUStatus (*WGPUProcDeviceGetLimits)(WGPUDevice device, WGPUSupportedLimits * limits) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUFuture (*WGPUProcDeviceGetLostFuture)(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUQueue (*WGPUProcDeviceGetQueue)(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUTextureUsageFlags (*WGPUProcDeviceGetSupportedSurfaceUsage)(WGPUDevice device, WGPUSurface surface) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUBool (*WGPUProcDeviceHasFeature)(WGPUDevice device, WGPUFeatureName feature) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUSharedBufferMemory (*WGPUProcDeviceImportSharedBufferMemory)(WGPUDevice device, WGPUSharedBufferMemoryDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUSharedFence (*WGPUProcDeviceImportSharedFence)(WGPUDevice device, WGPUSharedFenceDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUSharedTextureMemory (*WGPUProcDeviceImportSharedTextureMemory)(WGPUDevice device, WGPUSharedTextureMemoryDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcDeviceInjectError)(WGPUDevice device, WGPUErrorType type, char const * message) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcDevicePopErrorScope)(WGPUDevice device, WGPUErrorCallback oldCallback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcDevicePopErrorScope2)(WGPUDevice device, WGPUPopErrorScopeCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcDevicePopErrorScopeF)(WGPUDevice device, WGPUPopErrorScopeCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcDeviceInjectError)(WGPUDevice device, WGPUErrorType type, WGPUStringView message) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUFuture (*WGPUProcDevicePopErrorScope)(WGPUDevice device, WGPUPopErrorScopeCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcDevicePushErrorScope)(WGPUDevice device, WGPUErrorFilter filter) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcDeviceSetDeviceLostCallback)(WGPUDevice device, WGPUDeviceLostCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcDeviceSetLabel)(WGPUDevice device, char const * label) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcDeviceSetLoggingCallback)(WGPUDevice device, WGPULoggingCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcDeviceSetUncapturedErrorCallback)(WGPUDevice device, WGPUErrorCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcDeviceSetLabel)(WGPUDevice device, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcDeviceSetLoggingCallback)(WGPUDevice device, WGPULoggingCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcDeviceTick)(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcDeviceValidateTextureDescriptor)(WGPUDevice device, WGPUTextureDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcDeviceAddRef)(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE; @@ 
-3696,24 +3603,23 @@ typedef void (*WGPUProcDeviceRelease)(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE typedef void (*WGPUProcExternalTextureDestroy)(WGPUExternalTexture externalTexture) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcExternalTextureExpire)(WGPUExternalTexture externalTexture) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcExternalTextureRefresh)(WGPUExternalTexture externalTexture) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcExternalTextureSetLabel)(WGPUExternalTexture externalTexture, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcExternalTextureSetLabel)(WGPUExternalTexture externalTexture, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcExternalTextureAddRef)(WGPUExternalTexture externalTexture) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcExternalTextureRelease)(WGPUExternalTexture externalTexture) WGPU_FUNCTION_ATTRIBUTE; // Procs of Instance typedef WGPUSurface (*WGPUProcInstanceCreateSurface)(WGPUInstance instance, WGPUSurfaceDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef size_t (*WGPUProcInstanceEnumerateWGSLLanguageFeatures)(WGPUInstance instance, WGPUWGSLFeatureName * features) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUBool (*WGPUProcInstanceHasWGSLLanguageFeature)(WGPUInstance instance, WGPUWGSLFeatureName feature) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUStatus (*WGPUProcInstanceGetWGSLLanguageFeatures)(WGPUInstance instance, WGPUSupportedWGSLLanguageFeatures * features) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUBool (*WGPUProcInstanceHasWGSLLanguageFeature)(WGPUInstance instance, WGPUWGSLLanguageFeatureName feature) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcInstanceProcessEvents)(WGPUInstance instance) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcInstanceRequestAdapter)(WGPUInstance instance, WGPU_NULLABLE WGPURequestAdapterOptions const * options, WGPURequestAdapterCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcInstanceRequestAdapter2)(WGPUInstance instance, WGPU_NULLABLE WGPURequestAdapterOptions const * options, WGPURequestAdapterCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcInstanceRequestAdapterF)(WGPUInstance instance, WGPU_NULLABLE WGPURequestAdapterOptions const * options, WGPURequestAdapterCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUFuture (*WGPUProcInstanceRequestAdapter)(WGPUInstance instance, WGPU_NULLABLE WGPURequestAdapterOptions const * options, WGPURequestAdapterCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUWaitStatus (*WGPUProcInstanceWaitAny)(WGPUInstance instance, size_t futureCount, WGPUFutureWaitInfo * futures, uint64_t timeoutNS) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcInstanceAddRef)(WGPUInstance instance) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcInstanceRelease)(WGPUInstance instance) WGPU_FUNCTION_ATTRIBUTE; // Procs of PipelineLayout -typedef void (*WGPUProcPipelineLayoutSetLabel)(WGPUPipelineLayout pipelineLayout, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcPipelineLayoutSetLabel)(WGPUPipelineLayout pipelineLayout, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcPipelineLayoutAddRef)(WGPUPipelineLayout pipelineLayout) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcPipelineLayoutRelease)(WGPUPipelineLayout pipelineLayout) WGPU_FUNCTION_ATTRIBUTE; @@ -3721,17 +3627,15 @@ typedef void (*WGPUProcPipelineLayoutRelease)(WGPUPipelineLayout pipelineLayout) typedef void 
(*WGPUProcQuerySetDestroy)(WGPUQuerySet querySet) WGPU_FUNCTION_ATTRIBUTE; typedef uint32_t (*WGPUProcQuerySetGetCount)(WGPUQuerySet querySet) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUQueryType (*WGPUProcQuerySetGetType)(WGPUQuerySet querySet) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcQuerySetSetLabel)(WGPUQuerySet querySet, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcQuerySetSetLabel)(WGPUQuerySet querySet, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcQuerySetAddRef)(WGPUQuerySet querySet) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcQuerySetRelease)(WGPUQuerySet querySet) WGPU_FUNCTION_ATTRIBUTE; // Procs of Queue typedef void (*WGPUProcQueueCopyExternalTextureForBrowser)(WGPUQueue queue, WGPUImageCopyExternalTexture const * source, WGPUImageCopyTexture const * destination, WGPUExtent3D const * copySize, WGPUCopyTextureForBrowserOptions const * options) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcQueueCopyTextureForBrowser)(WGPUQueue queue, WGPUImageCopyTexture const * source, WGPUImageCopyTexture const * destination, WGPUExtent3D const * copySize, WGPUCopyTextureForBrowserOptions const * options) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcQueueOnSubmittedWorkDone)(WGPUQueue queue, WGPUQueueWorkDoneCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcQueueOnSubmittedWorkDone2)(WGPUQueue queue, WGPUQueueWorkDoneCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcQueueOnSubmittedWorkDoneF)(WGPUQueue queue, WGPUQueueWorkDoneCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcQueueSetLabel)(WGPUQueue queue, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUFuture (*WGPUProcQueueOnSubmittedWorkDone)(WGPUQueue queue, WGPUQueueWorkDoneCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcQueueSetLabel)(WGPUQueue queue, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcQueueSubmit)(WGPUQueue queue, size_t commandCount, WGPUCommandBuffer const * commands) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcQueueWriteBuffer)(WGPUQueue queue, WGPUBuffer buffer, uint64_t bufferOffset, void const * data, size_t size) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcQueueWriteTexture)(WGPUQueue queue, WGPUImageCopyTexture const * destination, void const * data, size_t dataSize, WGPUTextureDataLayout const * dataLayout, WGPUExtent3D const * writeSize) WGPU_FUNCTION_ATTRIBUTE; @@ -3739,7 +3643,7 @@ typedef void (*WGPUProcQueueAddRef)(WGPUQueue queue) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcQueueRelease)(WGPUQueue queue) WGPU_FUNCTION_ATTRIBUTE; // Procs of RenderBundle -typedef void (*WGPUProcRenderBundleSetLabel)(WGPURenderBundle renderBundle, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcRenderBundleSetLabel)(WGPURenderBundle renderBundle, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderBundleAddRef)(WGPURenderBundle renderBundle) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderBundleRelease)(WGPURenderBundle renderBundle) WGPU_FUNCTION_ATTRIBUTE; @@ -3749,12 +3653,12 @@ typedef void (*WGPUProcRenderBundleEncoderDrawIndexed)(WGPURenderBundleEncoder r typedef void (*WGPUProcRenderBundleEncoderDrawIndexedIndirect)(WGPURenderBundleEncoder renderBundleEncoder, WGPUBuffer indirectBuffer, uint64_t indirectOffset) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderBundleEncoderDrawIndirect)(WGPURenderBundleEncoder renderBundleEncoder, 
WGPUBuffer indirectBuffer, uint64_t indirectOffset) WGPU_FUNCTION_ATTRIBUTE; typedef WGPURenderBundle (*WGPUProcRenderBundleEncoderFinish)(WGPURenderBundleEncoder renderBundleEncoder, WGPU_NULLABLE WGPURenderBundleDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcRenderBundleEncoderInsertDebugMarker)(WGPURenderBundleEncoder renderBundleEncoder, char const * markerLabel) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcRenderBundleEncoderInsertDebugMarker)(WGPURenderBundleEncoder renderBundleEncoder, WGPUStringView markerLabel) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderBundleEncoderPopDebugGroup)(WGPURenderBundleEncoder renderBundleEncoder) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcRenderBundleEncoderPushDebugGroup)(WGPURenderBundleEncoder renderBundleEncoder, char const * groupLabel) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcRenderBundleEncoderPushDebugGroup)(WGPURenderBundleEncoder renderBundleEncoder, WGPUStringView groupLabel) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderBundleEncoderSetBindGroup)(WGPURenderBundleEncoder renderBundleEncoder, uint32_t groupIndex, WGPU_NULLABLE WGPUBindGroup group, size_t dynamicOffsetCount, uint32_t const * dynamicOffsets) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderBundleEncoderSetIndexBuffer)(WGPURenderBundleEncoder renderBundleEncoder, WGPUBuffer buffer, WGPUIndexFormat format, uint64_t offset, uint64_t size) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcRenderBundleEncoderSetLabel)(WGPURenderBundleEncoder renderBundleEncoder, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcRenderBundleEncoderSetLabel)(WGPURenderBundleEncoder renderBundleEncoder, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderBundleEncoderSetPipeline)(WGPURenderBundleEncoder renderBundleEncoder, WGPURenderPipeline pipeline) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderBundleEncoderSetVertexBuffer)(WGPURenderBundleEncoder renderBundleEncoder, uint32_t slot, WGPU_NULLABLE WGPUBuffer buffer, uint64_t offset, uint64_t size) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderBundleEncoderAddRef)(WGPURenderBundleEncoder renderBundleEncoder) WGPU_FUNCTION_ATTRIBUTE; @@ -3769,14 +3673,16 @@ typedef void (*WGPUProcRenderPassEncoderDrawIndirect)(WGPURenderPassEncoder rend typedef void (*WGPUProcRenderPassEncoderEnd)(WGPURenderPassEncoder renderPassEncoder) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderPassEncoderEndOcclusionQuery)(WGPURenderPassEncoder renderPassEncoder) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderPassEncoderExecuteBundles)(WGPURenderPassEncoder renderPassEncoder, size_t bundleCount, WGPURenderBundle const * bundles) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcRenderPassEncoderInsertDebugMarker)(WGPURenderPassEncoder renderPassEncoder, char const * markerLabel) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcRenderPassEncoderInsertDebugMarker)(WGPURenderPassEncoder renderPassEncoder, WGPUStringView markerLabel) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcRenderPassEncoderMultiDrawIndexedIndirect)(WGPURenderPassEncoder renderPassEncoder, WGPUBuffer indirectBuffer, uint64_t indirectOffset, uint32_t maxDrawCount, WGPU_NULLABLE WGPUBuffer drawCountBuffer, uint64_t drawCountBufferOffset) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcRenderPassEncoderMultiDrawIndirect)(WGPURenderPassEncoder renderPassEncoder, WGPUBuffer indirectBuffer, uint64_t indirectOffset, uint32_t maxDrawCount, WGPU_NULLABLE WGPUBuffer 
drawCountBuffer, uint64_t drawCountBufferOffset) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderPassEncoderPixelLocalStorageBarrier)(WGPURenderPassEncoder renderPassEncoder) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderPassEncoderPopDebugGroup)(WGPURenderPassEncoder renderPassEncoder) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcRenderPassEncoderPushDebugGroup)(WGPURenderPassEncoder renderPassEncoder, char const * groupLabel) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcRenderPassEncoderPushDebugGroup)(WGPURenderPassEncoder renderPassEncoder, WGPUStringView groupLabel) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderPassEncoderSetBindGroup)(WGPURenderPassEncoder renderPassEncoder, uint32_t groupIndex, WGPU_NULLABLE WGPUBindGroup group, size_t dynamicOffsetCount, uint32_t const * dynamicOffsets) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderPassEncoderSetBlendConstant)(WGPURenderPassEncoder renderPassEncoder, WGPUColor const * color) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderPassEncoderSetIndexBuffer)(WGPURenderPassEncoder renderPassEncoder, WGPUBuffer buffer, WGPUIndexFormat format, uint64_t offset, uint64_t size) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcRenderPassEncoderSetLabel)(WGPURenderPassEncoder renderPassEncoder, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcRenderPassEncoderSetLabel)(WGPURenderPassEncoder renderPassEncoder, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderPassEncoderSetPipeline)(WGPURenderPassEncoder renderPassEncoder, WGPURenderPipeline pipeline) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderPassEncoderSetScissorRect)(WGPURenderPassEncoder renderPassEncoder, uint32_t x, uint32_t y, uint32_t width, uint32_t height) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderPassEncoderSetStencilReference)(WGPURenderPassEncoder renderPassEncoder, uint32_t reference) WGPU_FUNCTION_ATTRIBUTE; @@ -3788,30 +3694,28 @@ typedef void (*WGPUProcRenderPassEncoderRelease)(WGPURenderPassEncoder renderPas // Procs of RenderPipeline typedef WGPUBindGroupLayout (*WGPUProcRenderPipelineGetBindGroupLayout)(WGPURenderPipeline renderPipeline, uint32_t groupIndex) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcRenderPipelineSetLabel)(WGPURenderPipeline renderPipeline, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcRenderPipelineSetLabel)(WGPURenderPipeline renderPipeline, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderPipelineAddRef)(WGPURenderPipeline renderPipeline) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcRenderPipelineRelease)(WGPURenderPipeline renderPipeline) WGPU_FUNCTION_ATTRIBUTE; // Procs of Sampler -typedef void (*WGPUProcSamplerSetLabel)(WGPUSampler sampler, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcSamplerSetLabel)(WGPUSampler sampler, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcSamplerAddRef)(WGPUSampler sampler) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcSamplerRelease)(WGPUSampler sampler) WGPU_FUNCTION_ATTRIBUTE; // Procs of ShaderModule -typedef void (*WGPUProcShaderModuleGetCompilationInfo)(WGPUShaderModule shaderModule, WGPUCompilationInfoCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcShaderModuleGetCompilationInfo2)(WGPUShaderModule shaderModule, WGPUCompilationInfoCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUFuture (*WGPUProcShaderModuleGetCompilationInfoF)(WGPUShaderModule 
shaderModule, WGPUCompilationInfoCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcShaderModuleSetLabel)(WGPUShaderModule shaderModule, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUFuture (*WGPUProcShaderModuleGetCompilationInfo)(WGPUShaderModule shaderModule, WGPUCompilationInfoCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcShaderModuleSetLabel)(WGPUShaderModule shaderModule, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcShaderModuleAddRef)(WGPUShaderModule shaderModule) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcShaderModuleRelease)(WGPUShaderModule shaderModule) WGPU_FUNCTION_ATTRIBUTE; // Procs of SharedBufferMemory -typedef WGPUBool (*WGPUProcSharedBufferMemoryBeginAccess)(WGPUSharedBufferMemory sharedBufferMemory, WGPUBuffer buffer, WGPUSharedBufferMemoryBeginAccessDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUStatus (*WGPUProcSharedBufferMemoryBeginAccess)(WGPUSharedBufferMemory sharedBufferMemory, WGPUBuffer buffer, WGPUSharedBufferMemoryBeginAccessDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUBuffer (*WGPUProcSharedBufferMemoryCreateBuffer)(WGPUSharedBufferMemory sharedBufferMemory, WGPU_NULLABLE WGPUBufferDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUBool (*WGPUProcSharedBufferMemoryEndAccess)(WGPUSharedBufferMemory sharedBufferMemory, WGPUBuffer buffer, WGPUSharedBufferMemoryEndAccessState * descriptor) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUStatus (*WGPUProcSharedBufferMemoryEndAccess)(WGPUSharedBufferMemory sharedBufferMemory, WGPUBuffer buffer, WGPUSharedBufferMemoryEndAccessState * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUStatus (*WGPUProcSharedBufferMemoryGetProperties)(WGPUSharedBufferMemory sharedBufferMemory, WGPUSharedBufferMemoryProperties * properties) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUBool (*WGPUProcSharedBufferMemoryIsDeviceLost)(WGPUSharedBufferMemory sharedBufferMemory) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcSharedBufferMemorySetLabel)(WGPUSharedBufferMemory sharedBufferMemory, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcSharedBufferMemorySetLabel)(WGPUSharedBufferMemory sharedBufferMemory, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcSharedBufferMemoryAddRef)(WGPUSharedBufferMemory sharedBufferMemory) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcSharedBufferMemoryRelease)(WGPUSharedBufferMemory sharedBufferMemory) WGPU_FUNCTION_ATTRIBUTE; @@ -3821,12 +3725,12 @@ typedef void (*WGPUProcSharedFenceAddRef)(WGPUSharedFence sharedFence) WGPU_FUNC typedef void (*WGPUProcSharedFenceRelease)(WGPUSharedFence sharedFence) WGPU_FUNCTION_ATTRIBUTE; // Procs of SharedTextureMemory -typedef WGPUBool (*WGPUProcSharedTextureMemoryBeginAccess)(WGPUSharedTextureMemory sharedTextureMemory, WGPUTexture texture, WGPUSharedTextureMemoryBeginAccessDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUStatus (*WGPUProcSharedTextureMemoryBeginAccess)(WGPUSharedTextureMemory sharedTextureMemory, WGPUTexture texture, WGPUSharedTextureMemoryBeginAccessDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUTexture (*WGPUProcSharedTextureMemoryCreateTexture)(WGPUSharedTextureMemory sharedTextureMemory, WGPU_NULLABLE WGPUTextureDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUBool (*WGPUProcSharedTextureMemoryEndAccess)(WGPUSharedTextureMemory sharedTextureMemory, WGPUTexture texture, 
WGPUSharedTextureMemoryEndAccessState * descriptor) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUStatus (*WGPUProcSharedTextureMemoryEndAccess)(WGPUSharedTextureMemory sharedTextureMemory, WGPUTexture texture, WGPUSharedTextureMemoryEndAccessState * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUStatus (*WGPUProcSharedTextureMemoryGetProperties)(WGPUSharedTextureMemory sharedTextureMemory, WGPUSharedTextureMemoryProperties * properties) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUBool (*WGPUProcSharedTextureMemoryIsDeviceLost)(WGPUSharedTextureMemory sharedTextureMemory) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcSharedTextureMemorySetLabel)(WGPUSharedTextureMemory sharedTextureMemory, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcSharedTextureMemorySetLabel)(WGPUSharedTextureMemory sharedTextureMemory, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcSharedTextureMemoryAddRef)(WGPUSharedTextureMemory sharedTextureMemory) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcSharedTextureMemoryRelease)(WGPUSharedTextureMemory sharedTextureMemory) WGPU_FUNCTION_ATTRIBUTE; @@ -3834,19 +3738,12 @@ typedef void (*WGPUProcSharedTextureMemoryRelease)(WGPUSharedTextureMemory share typedef void (*WGPUProcSurfaceConfigure)(WGPUSurface surface, WGPUSurfaceConfiguration const * config) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUStatus (*WGPUProcSurfaceGetCapabilities)(WGPUSurface surface, WGPUAdapter adapter, WGPUSurfaceCapabilities * capabilities) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcSurfaceGetCurrentTexture)(WGPUSurface surface, WGPUSurfaceTexture * surfaceTexture) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUTextureFormat (*WGPUProcSurfaceGetPreferredFormat)(WGPUSurface surface, WGPUAdapter adapter) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcSurfacePresent)(WGPUSurface surface) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcSurfaceSetLabel)(WGPUSurface surface, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcSurfaceUnconfigure)(WGPUSurface surface) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcSurfaceAddRef)(WGPUSurface surface) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcSurfaceRelease)(WGPUSurface surface) WGPU_FUNCTION_ATTRIBUTE; -// Procs of SwapChain -typedef WGPUTexture (*WGPUProcSwapChainGetCurrentTexture)(WGPUSwapChain swapChain) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUTextureView (*WGPUProcSwapChainGetCurrentTextureView)(WGPUSwapChain swapChain) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcSwapChainPresent)(WGPUSwapChain swapChain) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcSwapChainAddRef)(WGPUSwapChain swapChain) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcSwapChainRelease)(WGPUSwapChain swapChain) WGPU_FUNCTION_ATTRIBUTE; - // Procs of Texture typedef WGPUTextureView (*WGPUProcTextureCreateErrorView)(WGPUTexture texture, WGPU_NULLABLE WGPUTextureViewDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUTextureView (*WGPUProcTextureCreateView)(WGPUTexture texture, WGPU_NULLABLE WGPUTextureViewDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; @@ -3857,14 +3754,14 @@ typedef WGPUTextureFormat (*WGPUProcTextureGetFormat)(WGPUTexture texture) WGPU_ typedef uint32_t (*WGPUProcTextureGetHeight)(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; typedef uint32_t (*WGPUProcTextureGetMipLevelCount)(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; typedef uint32_t (*WGPUProcTextureGetSampleCount)(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUTextureUsageFlags 
(*WGPUProcTextureGetUsage)(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUTextureUsage (*WGPUProcTextureGetUsage)(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; typedef uint32_t (*WGPUProcTextureGetWidth)(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; -typedef void (*WGPUProcTextureSetLabel)(WGPUTexture texture, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcTextureSetLabel)(WGPUTexture texture, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcTextureAddRef)(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcTextureRelease)(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; // Procs of TextureView -typedef void (*WGPUProcTextureViewSetLabel)(WGPUTextureView textureView, char const * label) WGPU_FUNCTION_ATTRIBUTE; +typedef void (*WGPUProcTextureViewSetLabel)(WGPUTextureView textureView, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcTextureViewAddRef)(WGPUTextureView textureView) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcTextureViewRelease)(WGPUTextureView textureView) WGPU_FUNCTION_ATTRIBUTE; @@ -3874,38 +3771,36 @@ typedef void (*WGPUProcTextureViewRelease)(WGPUTextureView textureView) WGPU_FUN #if !defined(WGPU_SKIP_DECLARATIONS) WGPU_EXPORT void wgpuAdapterInfoFreeMembers(WGPUAdapterInfo value) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuAdapterPropertiesFreeMembers(WGPUAdapterProperties value) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuAdapterPropertiesMemoryHeapsFreeMembers(WGPUAdapterPropertiesMemoryHeaps value) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUInstance wgpuCreateInstance(WGPU_NULLABLE WGPUInstanceDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuDrmFormatCapabilitiesFreeMembers(WGPUDrmFormatCapabilities value) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuDawnDrmFormatCapabilitiesFreeMembers(WGPUDawnDrmFormatCapabilities value) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUStatus wgpuGetInstanceFeatures(WGPUInstanceFeatures * features) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUProc wgpuGetProcAddress(WGPU_NULLABLE WGPUDevice device, char const * procName) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUProc wgpuGetProcAddress(WGPUStringView procName) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSharedBufferMemoryEndAccessStateFreeMembers(WGPUSharedBufferMemoryEndAccessState value) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSharedTextureMemoryEndAccessStateFreeMembers(WGPUSharedTextureMemoryEndAccessState value) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuSupportedWGSLLanguageFeaturesFreeMembers(WGPUSupportedWGSLLanguageFeatures value) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuSupportedFeaturesFreeMembers(WGPUSupportedFeatures value) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSurfaceCapabilitiesFreeMembers(WGPUSurfaceCapabilities value) WGPU_FUNCTION_ATTRIBUTE; // Methods of Adapter WGPU_EXPORT WGPUDevice wgpuAdapterCreateDevice(WGPUAdapter adapter, WGPU_NULLABLE WGPUDeviceDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT size_t wgpuAdapterEnumerateFeatures(WGPUAdapter adapter, WGPUFeatureName * features) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUStatus wgpuAdapterGetFormatCapabilities(WGPUAdapter adapter, WGPUTextureFormat format, WGPUFormatCapabilities * capabilities) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuAdapterGetFeatures(WGPUAdapter adapter, WGPUSupportedFeatures * features) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUStatus wgpuAdapterGetFormatCapabilities(WGPUAdapter adapter, WGPUTextureFormat format, 
WGPUDawnFormatCapabilities * capabilities) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUStatus wgpuAdapterGetInfo(WGPUAdapter adapter, WGPUAdapterInfo * info) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUInstance wgpuAdapterGetInstance(WGPUAdapter adapter) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUStatus wgpuAdapterGetLimits(WGPUAdapter adapter, WGPUSupportedLimits * limits) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUStatus wgpuAdapterGetProperties(WGPUAdapter adapter, WGPUAdapterProperties * properties) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUBool wgpuAdapterHasFeature(WGPUAdapter adapter, WGPUFeatureName feature) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuAdapterRequestDevice(WGPUAdapter adapter, WGPU_NULLABLE WGPUDeviceDescriptor const * descriptor, WGPURequestDeviceCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuAdapterRequestDevice2(WGPUAdapter adapter, WGPU_NULLABLE WGPUDeviceDescriptor const * options, WGPURequestDeviceCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuAdapterRequestDeviceF(WGPUAdapter adapter, WGPU_NULLABLE WGPUDeviceDescriptor const * options, WGPURequestDeviceCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUFuture wgpuAdapterRequestDevice(WGPUAdapter adapter, WGPU_NULLABLE WGPUDeviceDescriptor const * options, WGPURequestDeviceCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuAdapterAddRef(WGPUAdapter adapter) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuAdapterRelease(WGPUAdapter adapter) WGPU_FUNCTION_ATTRIBUTE; // Methods of BindGroup -WGPU_EXPORT void wgpuBindGroupSetLabel(WGPUBindGroup bindGroup, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuBindGroupSetLabel(WGPUBindGroup bindGroup, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuBindGroupAddRef(WGPUBindGroup bindGroup) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuBindGroupRelease(WGPUBindGroup bindGroup) WGPU_FUNCTION_ATTRIBUTE; // Methods of BindGroupLayout -WGPU_EXPORT void wgpuBindGroupLayoutSetLabel(WGPUBindGroupLayout bindGroupLayout, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuBindGroupLayoutSetLabel(WGPUBindGroupLayout bindGroupLayout, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuBindGroupLayoutAddRef(WGPUBindGroupLayout bindGroupLayout) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuBindGroupLayoutRelease(WGPUBindGroupLayout bindGroupLayout) WGPU_FUNCTION_ATTRIBUTE; @@ -3915,17 +3810,15 @@ WGPU_EXPORT void const * wgpuBufferGetConstMappedRange(WGPUBuffer buffer, size_t WGPU_EXPORT WGPUBufferMapState wgpuBufferGetMapState(WGPUBuffer buffer) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void * wgpuBufferGetMappedRange(WGPUBuffer buffer, size_t offset, size_t size) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT uint64_t wgpuBufferGetSize(WGPUBuffer buffer) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUBufferUsageFlags wgpuBufferGetUsage(WGPUBuffer buffer) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuBufferMapAsync(WGPUBuffer buffer, WGPUMapModeFlags mode, size_t offset, size_t size, WGPUBufferMapCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuBufferMapAsync2(WGPUBuffer buffer, WGPUMapModeFlags mode, size_t offset, size_t size, WGPUBufferMapCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuBufferMapAsyncF(WGPUBuffer buffer, WGPUMapModeFlags mode, size_t offset, size_t size, WGPUBufferMapCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void 
wgpuBufferSetLabel(WGPUBuffer buffer, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUBufferUsage wgpuBufferGetUsage(WGPUBuffer buffer) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUFuture wgpuBufferMapAsync(WGPUBuffer buffer, WGPUMapMode mode, size_t offset, size_t size, WGPUBufferMapCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuBufferSetLabel(WGPUBuffer buffer, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuBufferUnmap(WGPUBuffer buffer) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuBufferAddRef(WGPUBuffer buffer) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuBufferRelease(WGPUBuffer buffer) WGPU_FUNCTION_ATTRIBUTE; // Methods of CommandBuffer -WGPU_EXPORT void wgpuCommandBufferSetLabel(WGPUCommandBuffer commandBuffer, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuCommandBufferSetLabel(WGPUCommandBuffer commandBuffer, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuCommandBufferAddRef(WGPUCommandBuffer commandBuffer) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuCommandBufferRelease(WGPUCommandBuffer commandBuffer) WGPU_FUNCTION_ATTRIBUTE; @@ -3938,12 +3831,12 @@ WGPU_EXPORT void wgpuCommandEncoderCopyBufferToTexture(WGPUCommandEncoder comman WGPU_EXPORT void wgpuCommandEncoderCopyTextureToBuffer(WGPUCommandEncoder commandEncoder, WGPUImageCopyTexture const * source, WGPUImageCopyBuffer const * destination, WGPUExtent3D const * copySize) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuCommandEncoderCopyTextureToTexture(WGPUCommandEncoder commandEncoder, WGPUImageCopyTexture const * source, WGPUImageCopyTexture const * destination, WGPUExtent3D const * copySize) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUCommandBuffer wgpuCommandEncoderFinish(WGPUCommandEncoder commandEncoder, WGPU_NULLABLE WGPUCommandBufferDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuCommandEncoderInjectValidationError(WGPUCommandEncoder commandEncoder, char const * message) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuCommandEncoderInsertDebugMarker(WGPUCommandEncoder commandEncoder, char const * markerLabel) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuCommandEncoderInjectValidationError(WGPUCommandEncoder commandEncoder, WGPUStringView message) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuCommandEncoderInsertDebugMarker(WGPUCommandEncoder commandEncoder, WGPUStringView markerLabel) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuCommandEncoderPopDebugGroup(WGPUCommandEncoder commandEncoder) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuCommandEncoderPushDebugGroup(WGPUCommandEncoder commandEncoder, char const * groupLabel) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuCommandEncoderPushDebugGroup(WGPUCommandEncoder commandEncoder, WGPUStringView groupLabel) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuCommandEncoderResolveQuerySet(WGPUCommandEncoder commandEncoder, WGPUQuerySet querySet, uint32_t firstQuery, uint32_t queryCount, WGPUBuffer destination, uint64_t destinationOffset) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuCommandEncoderSetLabel(WGPUCommandEncoder commandEncoder, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuCommandEncoderSetLabel(WGPUCommandEncoder commandEncoder, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuCommandEncoderWriteBuffer(WGPUCommandEncoder commandEncoder, WGPUBuffer buffer, uint64_t bufferOffset, uint8_t const * data, uint64_t size) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void 
wgpuCommandEncoderWriteTimestamp(WGPUCommandEncoder commandEncoder, WGPUQuerySet querySet, uint32_t queryIndex) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuCommandEncoderAddRef(WGPUCommandEncoder commandEncoder) WGPU_FUNCTION_ATTRIBUTE; @@ -3953,11 +3846,11 @@ WGPU_EXPORT void wgpuCommandEncoderRelease(WGPUCommandEncoder commandEncoder) WG WGPU_EXPORT void wgpuComputePassEncoderDispatchWorkgroups(WGPUComputePassEncoder computePassEncoder, uint32_t workgroupCountX, uint32_t workgroupCountY, uint32_t workgroupCountZ) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuComputePassEncoderDispatchWorkgroupsIndirect(WGPUComputePassEncoder computePassEncoder, WGPUBuffer indirectBuffer, uint64_t indirectOffset) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuComputePassEncoderEnd(WGPUComputePassEncoder computePassEncoder) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuComputePassEncoderInsertDebugMarker(WGPUComputePassEncoder computePassEncoder, char const * markerLabel) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuComputePassEncoderInsertDebugMarker(WGPUComputePassEncoder computePassEncoder, WGPUStringView markerLabel) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuComputePassEncoderPopDebugGroup(WGPUComputePassEncoder computePassEncoder) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuComputePassEncoderPushDebugGroup(WGPUComputePassEncoder computePassEncoder, char const * groupLabel) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuComputePassEncoderPushDebugGroup(WGPUComputePassEncoder computePassEncoder, WGPUStringView groupLabel) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuComputePassEncoderSetBindGroup(WGPUComputePassEncoder computePassEncoder, uint32_t groupIndex, WGPU_NULLABLE WGPUBindGroup group, size_t dynamicOffsetCount, uint32_t const * dynamicOffsets) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuComputePassEncoderSetLabel(WGPUComputePassEncoder computePassEncoder, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuComputePassEncoderSetLabel(WGPUComputePassEncoder computePassEncoder, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuComputePassEncoderSetPipeline(WGPUComputePassEncoder computePassEncoder, WGPUComputePipeline pipeline) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuComputePassEncoderWriteTimestamp(WGPUComputePassEncoder computePassEncoder, WGPUQuerySet querySet, uint32_t queryIndex) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuComputePassEncoderAddRef(WGPUComputePassEncoder computePassEncoder) WGPU_FUNCTION_ATTRIBUTE; @@ -3965,7 +3858,7 @@ WGPU_EXPORT void wgpuComputePassEncoderRelease(WGPUComputePassEncoder computePas // Methods of ComputePipeline WGPU_EXPORT WGPUBindGroupLayout wgpuComputePipelineGetBindGroupLayout(WGPUComputePipeline computePipeline, uint32_t groupIndex) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuComputePipelineSetLabel(WGPUComputePipeline computePipeline, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuComputePipelineSetLabel(WGPUComputePipeline computePipeline, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuComputePipelineAddRef(WGPUComputePipeline computePipeline) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuComputePipelineRelease(WGPUComputePipeline computePipeline) WGPU_FUNCTION_ATTRIBUTE; @@ -3975,46 +3868,38 @@ WGPU_EXPORT WGPUBindGroupLayout wgpuDeviceCreateBindGroupLayout(WGPUDevice devic WGPU_EXPORT WGPUBuffer wgpuDeviceCreateBuffer(WGPUDevice device, WGPUBufferDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUCommandEncoder 
wgpuDeviceCreateCommandEncoder(WGPUDevice device, WGPU_NULLABLE WGPUCommandEncoderDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUComputePipeline wgpuDeviceCreateComputePipeline(WGPUDevice device, WGPUComputePipelineDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuDeviceCreateComputePipelineAsync(WGPUDevice device, WGPUComputePipelineDescriptor const * descriptor, WGPUCreateComputePipelineAsyncCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuDeviceCreateComputePipelineAsync2(WGPUDevice device, WGPUComputePipelineDescriptor const * descriptor, WGPUCreateComputePipelineAsyncCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuDeviceCreateComputePipelineAsyncF(WGPUDevice device, WGPUComputePipelineDescriptor const * descriptor, WGPUCreateComputePipelineAsyncCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUFuture wgpuDeviceCreateComputePipelineAsync(WGPUDevice device, WGPUComputePipelineDescriptor const * descriptor, WGPUCreateComputePipelineAsyncCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUBuffer wgpuDeviceCreateErrorBuffer(WGPUDevice device, WGPUBufferDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUExternalTexture wgpuDeviceCreateErrorExternalTexture(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUShaderModule wgpuDeviceCreateErrorShaderModule(WGPUDevice device, WGPUShaderModuleDescriptor const * descriptor, char const * errorMessage) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUShaderModule wgpuDeviceCreateErrorShaderModule(WGPUDevice device, WGPUShaderModuleDescriptor const * descriptor, WGPUStringView errorMessage) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUTexture wgpuDeviceCreateErrorTexture(WGPUDevice device, WGPUTextureDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUExternalTexture wgpuDeviceCreateExternalTexture(WGPUDevice device, WGPUExternalTextureDescriptor const * externalTextureDescriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUPipelineLayout wgpuDeviceCreatePipelineLayout(WGPUDevice device, WGPUPipelineLayoutDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUQuerySet wgpuDeviceCreateQuerySet(WGPUDevice device, WGPUQuerySetDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPURenderBundleEncoder wgpuDeviceCreateRenderBundleEncoder(WGPUDevice device, WGPURenderBundleEncoderDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPURenderPipeline wgpuDeviceCreateRenderPipeline(WGPUDevice device, WGPURenderPipelineDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuDeviceCreateRenderPipelineAsync(WGPUDevice device, WGPURenderPipelineDescriptor const * descriptor, WGPUCreateRenderPipelineAsyncCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuDeviceCreateRenderPipelineAsync2(WGPUDevice device, WGPURenderPipelineDescriptor const * descriptor, WGPUCreateRenderPipelineAsyncCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuDeviceCreateRenderPipelineAsyncF(WGPUDevice device, WGPURenderPipelineDescriptor const * descriptor, WGPUCreateRenderPipelineAsyncCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUFuture wgpuDeviceCreateRenderPipelineAsync(WGPUDevice device, WGPURenderPipelineDescriptor const * descriptor, WGPUCreateRenderPipelineAsyncCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT 
WGPUSampler wgpuDeviceCreateSampler(WGPUDevice device, WGPU_NULLABLE WGPUSamplerDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUShaderModule wgpuDeviceCreateShaderModule(WGPUDevice device, WGPUShaderModuleDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUSwapChain wgpuDeviceCreateSwapChain(WGPUDevice device, WGPUSurface surface, WGPUSwapChainDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUTexture wgpuDeviceCreateTexture(WGPUDevice device, WGPUTextureDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuDeviceDestroy(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT size_t wgpuDeviceEnumerateFeatures(WGPUDevice device, WGPUFeatureName * features) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuDeviceForceLoss(WGPUDevice device, WGPUDeviceLostReason type, char const * message) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuDeviceForceLoss(WGPUDevice device, WGPUDeviceLostReason type, WGPUStringView message) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUStatus wgpuDeviceGetAHardwareBufferProperties(WGPUDevice device, void * handle, WGPUAHardwareBufferProperties * properties) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUAdapter wgpuDeviceGetAdapter(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUStatus wgpuDeviceGetAdapterInfo(WGPUDevice device, WGPUAdapterInfo * adapterInfo) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuDeviceGetFeatures(WGPUDevice device, WGPUSupportedFeatures * features) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUStatus wgpuDeviceGetLimits(WGPUDevice device, WGPUSupportedLimits * limits) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUFuture wgpuDeviceGetLostFuture(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUQueue wgpuDeviceGetQueue(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUTextureUsageFlags wgpuDeviceGetSupportedSurfaceUsage(WGPUDevice device, WGPUSurface surface) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUBool wgpuDeviceHasFeature(WGPUDevice device, WGPUFeatureName feature) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUSharedBufferMemory wgpuDeviceImportSharedBufferMemory(WGPUDevice device, WGPUSharedBufferMemoryDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUSharedFence wgpuDeviceImportSharedFence(WGPUDevice device, WGPUSharedFenceDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUSharedTextureMemory wgpuDeviceImportSharedTextureMemory(WGPUDevice device, WGPUSharedTextureMemoryDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuDeviceInjectError(WGPUDevice device, WGPUErrorType type, char const * message) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuDevicePopErrorScope(WGPUDevice device, WGPUErrorCallback oldCallback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuDevicePopErrorScope2(WGPUDevice device, WGPUPopErrorScopeCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuDevicePopErrorScopeF(WGPUDevice device, WGPUPopErrorScopeCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuDeviceInjectError(WGPUDevice device, WGPUErrorType type, WGPUStringView message) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUFuture wgpuDevicePopErrorScope(WGPUDevice device, WGPUPopErrorScopeCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuDevicePushErrorScope(WGPUDevice device, WGPUErrorFilter filter) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuDeviceSetDeviceLostCallback(WGPUDevice device, 
WGPUDeviceLostCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuDeviceSetLabel(WGPUDevice device, char const * label) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuDeviceSetLoggingCallback(WGPUDevice device, WGPULoggingCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuDeviceSetUncapturedErrorCallback(WGPUDevice device, WGPUErrorCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuDeviceSetLabel(WGPUDevice device, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuDeviceSetLoggingCallback(WGPUDevice device, WGPULoggingCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuDeviceTick(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuDeviceValidateTextureDescriptor(WGPUDevice device, WGPUTextureDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuDeviceAddRef(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE; @@ -4024,24 +3909,23 @@ WGPU_EXPORT void wgpuDeviceRelease(WGPUDevice device) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuExternalTextureDestroy(WGPUExternalTexture externalTexture) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuExternalTextureExpire(WGPUExternalTexture externalTexture) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuExternalTextureRefresh(WGPUExternalTexture externalTexture) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuExternalTextureSetLabel(WGPUExternalTexture externalTexture, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuExternalTextureSetLabel(WGPUExternalTexture externalTexture, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuExternalTextureAddRef(WGPUExternalTexture externalTexture) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuExternalTextureRelease(WGPUExternalTexture externalTexture) WGPU_FUNCTION_ATTRIBUTE; // Methods of Instance WGPU_EXPORT WGPUSurface wgpuInstanceCreateSurface(WGPUInstance instance, WGPUSurfaceDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT size_t wgpuInstanceEnumerateWGSLLanguageFeatures(WGPUInstance instance, WGPUWGSLFeatureName * features) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUBool wgpuInstanceHasWGSLLanguageFeature(WGPUInstance instance, WGPUWGSLFeatureName feature) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUStatus wgpuInstanceGetWGSLLanguageFeatures(WGPUInstance instance, WGPUSupportedWGSLLanguageFeatures * features) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUBool wgpuInstanceHasWGSLLanguageFeature(WGPUInstance instance, WGPUWGSLLanguageFeatureName feature) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuInstanceProcessEvents(WGPUInstance instance) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuInstanceRequestAdapter(WGPUInstance instance, WGPU_NULLABLE WGPURequestAdapterOptions const * options, WGPURequestAdapterCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuInstanceRequestAdapter2(WGPUInstance instance, WGPU_NULLABLE WGPURequestAdapterOptions const * options, WGPURequestAdapterCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuInstanceRequestAdapterF(WGPUInstance instance, WGPU_NULLABLE WGPURequestAdapterOptions const * options, WGPURequestAdapterCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUFuture wgpuInstanceRequestAdapter(WGPUInstance instance, WGPU_NULLABLE WGPURequestAdapterOptions const * options, WGPURequestAdapterCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUWaitStatus 
wgpuInstanceWaitAny(WGPUInstance instance, size_t futureCount, WGPUFutureWaitInfo * futures, uint64_t timeoutNS) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuInstanceAddRef(WGPUInstance instance) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuInstanceRelease(WGPUInstance instance) WGPU_FUNCTION_ATTRIBUTE; // Methods of PipelineLayout -WGPU_EXPORT void wgpuPipelineLayoutSetLabel(WGPUPipelineLayout pipelineLayout, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuPipelineLayoutSetLabel(WGPUPipelineLayout pipelineLayout, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuPipelineLayoutAddRef(WGPUPipelineLayout pipelineLayout) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuPipelineLayoutRelease(WGPUPipelineLayout pipelineLayout) WGPU_FUNCTION_ATTRIBUTE; @@ -4049,17 +3933,15 @@ WGPU_EXPORT void wgpuPipelineLayoutRelease(WGPUPipelineLayout pipelineLayout) WG WGPU_EXPORT void wgpuQuerySetDestroy(WGPUQuerySet querySet) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT uint32_t wgpuQuerySetGetCount(WGPUQuerySet querySet) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUQueryType wgpuQuerySetGetType(WGPUQuerySet querySet) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuQuerySetSetLabel(WGPUQuerySet querySet, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuQuerySetSetLabel(WGPUQuerySet querySet, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuQuerySetAddRef(WGPUQuerySet querySet) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuQuerySetRelease(WGPUQuerySet querySet) WGPU_FUNCTION_ATTRIBUTE; // Methods of Queue WGPU_EXPORT void wgpuQueueCopyExternalTextureForBrowser(WGPUQueue queue, WGPUImageCopyExternalTexture const * source, WGPUImageCopyTexture const * destination, WGPUExtent3D const * copySize, WGPUCopyTextureForBrowserOptions const * options) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuQueueCopyTextureForBrowser(WGPUQueue queue, WGPUImageCopyTexture const * source, WGPUImageCopyTexture const * destination, WGPUExtent3D const * copySize, WGPUCopyTextureForBrowserOptions const * options) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuQueueOnSubmittedWorkDone(WGPUQueue queue, WGPUQueueWorkDoneCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuQueueOnSubmittedWorkDone2(WGPUQueue queue, WGPUQueueWorkDoneCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuQueueOnSubmittedWorkDoneF(WGPUQueue queue, WGPUQueueWorkDoneCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuQueueSetLabel(WGPUQueue queue, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUFuture wgpuQueueOnSubmittedWorkDone(WGPUQueue queue, WGPUQueueWorkDoneCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuQueueSetLabel(WGPUQueue queue, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuQueueSubmit(WGPUQueue queue, size_t commandCount, WGPUCommandBuffer const * commands) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuQueueWriteBuffer(WGPUQueue queue, WGPUBuffer buffer, uint64_t bufferOffset, void const * data, size_t size) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuQueueWriteTexture(WGPUQueue queue, WGPUImageCopyTexture const * destination, void const * data, size_t dataSize, WGPUTextureDataLayout const * dataLayout, WGPUExtent3D const * writeSize) WGPU_FUNCTION_ATTRIBUTE; @@ -4067,7 +3949,7 @@ WGPU_EXPORT void wgpuQueueAddRef(WGPUQueue queue) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuQueueRelease(WGPUQueue queue) WGPU_FUNCTION_ATTRIBUTE; // Methods of 
RenderBundle -WGPU_EXPORT void wgpuRenderBundleSetLabel(WGPURenderBundle renderBundle, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuRenderBundleSetLabel(WGPURenderBundle renderBundle, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderBundleAddRef(WGPURenderBundle renderBundle) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderBundleRelease(WGPURenderBundle renderBundle) WGPU_FUNCTION_ATTRIBUTE; @@ -4077,12 +3959,12 @@ WGPU_EXPORT void wgpuRenderBundleEncoderDrawIndexed(WGPURenderBundleEncoder rend WGPU_EXPORT void wgpuRenderBundleEncoderDrawIndexedIndirect(WGPURenderBundleEncoder renderBundleEncoder, WGPUBuffer indirectBuffer, uint64_t indirectOffset) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderBundleEncoderDrawIndirect(WGPURenderBundleEncoder renderBundleEncoder, WGPUBuffer indirectBuffer, uint64_t indirectOffset) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPURenderBundle wgpuRenderBundleEncoderFinish(WGPURenderBundleEncoder renderBundleEncoder, WGPU_NULLABLE WGPURenderBundleDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuRenderBundleEncoderInsertDebugMarker(WGPURenderBundleEncoder renderBundleEncoder, char const * markerLabel) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuRenderBundleEncoderInsertDebugMarker(WGPURenderBundleEncoder renderBundleEncoder, WGPUStringView markerLabel) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderBundleEncoderPopDebugGroup(WGPURenderBundleEncoder renderBundleEncoder) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuRenderBundleEncoderPushDebugGroup(WGPURenderBundleEncoder renderBundleEncoder, char const * groupLabel) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuRenderBundleEncoderPushDebugGroup(WGPURenderBundleEncoder renderBundleEncoder, WGPUStringView groupLabel) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderBundleEncoderSetBindGroup(WGPURenderBundleEncoder renderBundleEncoder, uint32_t groupIndex, WGPU_NULLABLE WGPUBindGroup group, size_t dynamicOffsetCount, uint32_t const * dynamicOffsets) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderBundleEncoderSetIndexBuffer(WGPURenderBundleEncoder renderBundleEncoder, WGPUBuffer buffer, WGPUIndexFormat format, uint64_t offset, uint64_t size) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuRenderBundleEncoderSetLabel(WGPURenderBundleEncoder renderBundleEncoder, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuRenderBundleEncoderSetLabel(WGPURenderBundleEncoder renderBundleEncoder, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderBundleEncoderSetPipeline(WGPURenderBundleEncoder renderBundleEncoder, WGPURenderPipeline pipeline) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderBundleEncoderSetVertexBuffer(WGPURenderBundleEncoder renderBundleEncoder, uint32_t slot, WGPU_NULLABLE WGPUBuffer buffer, uint64_t offset, uint64_t size) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderBundleEncoderAddRef(WGPURenderBundleEncoder renderBundleEncoder) WGPU_FUNCTION_ATTRIBUTE; @@ -4097,14 +3979,16 @@ WGPU_EXPORT void wgpuRenderPassEncoderDrawIndirect(WGPURenderPassEncoder renderP WGPU_EXPORT void wgpuRenderPassEncoderEnd(WGPURenderPassEncoder renderPassEncoder) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderPassEncoderEndOcclusionQuery(WGPURenderPassEncoder renderPassEncoder) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderPassEncoderExecuteBundles(WGPURenderPassEncoder renderPassEncoder, size_t bundleCount, WGPURenderBundle const * bundles) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT 
void wgpuRenderPassEncoderInsertDebugMarker(WGPURenderPassEncoder renderPassEncoder, char const * markerLabel) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuRenderPassEncoderInsertDebugMarker(WGPURenderPassEncoder renderPassEncoder, WGPUStringView markerLabel) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuRenderPassEncoderMultiDrawIndexedIndirect(WGPURenderPassEncoder renderPassEncoder, WGPUBuffer indirectBuffer, uint64_t indirectOffset, uint32_t maxDrawCount, WGPU_NULLABLE WGPUBuffer drawCountBuffer, uint64_t drawCountBufferOffset) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuRenderPassEncoderMultiDrawIndirect(WGPURenderPassEncoder renderPassEncoder, WGPUBuffer indirectBuffer, uint64_t indirectOffset, uint32_t maxDrawCount, WGPU_NULLABLE WGPUBuffer drawCountBuffer, uint64_t drawCountBufferOffset) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderPassEncoderPixelLocalStorageBarrier(WGPURenderPassEncoder renderPassEncoder) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderPassEncoderPopDebugGroup(WGPURenderPassEncoder renderPassEncoder) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuRenderPassEncoderPushDebugGroup(WGPURenderPassEncoder renderPassEncoder, char const * groupLabel) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuRenderPassEncoderPushDebugGroup(WGPURenderPassEncoder renderPassEncoder, WGPUStringView groupLabel) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderPassEncoderSetBindGroup(WGPURenderPassEncoder renderPassEncoder, uint32_t groupIndex, WGPU_NULLABLE WGPUBindGroup group, size_t dynamicOffsetCount, uint32_t const * dynamicOffsets) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderPassEncoderSetBlendConstant(WGPURenderPassEncoder renderPassEncoder, WGPUColor const * color) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderPassEncoderSetIndexBuffer(WGPURenderPassEncoder renderPassEncoder, WGPUBuffer buffer, WGPUIndexFormat format, uint64_t offset, uint64_t size) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuRenderPassEncoderSetLabel(WGPURenderPassEncoder renderPassEncoder, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuRenderPassEncoderSetLabel(WGPURenderPassEncoder renderPassEncoder, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderPassEncoderSetPipeline(WGPURenderPassEncoder renderPassEncoder, WGPURenderPipeline pipeline) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderPassEncoderSetScissorRect(WGPURenderPassEncoder renderPassEncoder, uint32_t x, uint32_t y, uint32_t width, uint32_t height) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderPassEncoderSetStencilReference(WGPURenderPassEncoder renderPassEncoder, uint32_t reference) WGPU_FUNCTION_ATTRIBUTE; @@ -4116,30 +4000,28 @@ WGPU_EXPORT void wgpuRenderPassEncoderRelease(WGPURenderPassEncoder renderPassEn // Methods of RenderPipeline WGPU_EXPORT WGPUBindGroupLayout wgpuRenderPipelineGetBindGroupLayout(WGPURenderPipeline renderPipeline, uint32_t groupIndex) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuRenderPipelineSetLabel(WGPURenderPipeline renderPipeline, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuRenderPipelineSetLabel(WGPURenderPipeline renderPipeline, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderPipelineAddRef(WGPURenderPipeline renderPipeline) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuRenderPipelineRelease(WGPURenderPipeline renderPipeline) WGPU_FUNCTION_ATTRIBUTE; // Methods of Sampler -WGPU_EXPORT void wgpuSamplerSetLabel(WGPUSampler sampler, char const * label) WGPU_FUNCTION_ATTRIBUTE; 
+WGPU_EXPORT void wgpuSamplerSetLabel(WGPUSampler sampler, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSamplerAddRef(WGPUSampler sampler) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSamplerRelease(WGPUSampler sampler) WGPU_FUNCTION_ATTRIBUTE; // Methods of ShaderModule -WGPU_EXPORT void wgpuShaderModuleGetCompilationInfo(WGPUShaderModule shaderModule, WGPUCompilationInfoCallback callback, void * userdata) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuShaderModuleGetCompilationInfo2(WGPUShaderModule shaderModule, WGPUCompilationInfoCallbackInfo2 callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUFuture wgpuShaderModuleGetCompilationInfoF(WGPUShaderModule shaderModule, WGPUCompilationInfoCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuShaderModuleSetLabel(WGPUShaderModule shaderModule, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUFuture wgpuShaderModuleGetCompilationInfo(WGPUShaderModule shaderModule, WGPUCompilationInfoCallbackInfo callbackInfo) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuShaderModuleSetLabel(WGPUShaderModule shaderModule, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuShaderModuleAddRef(WGPUShaderModule shaderModule) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuShaderModuleRelease(WGPUShaderModule shaderModule) WGPU_FUNCTION_ATTRIBUTE; // Methods of SharedBufferMemory -WGPU_EXPORT WGPUBool wgpuSharedBufferMemoryBeginAccess(WGPUSharedBufferMemory sharedBufferMemory, WGPUBuffer buffer, WGPUSharedBufferMemoryBeginAccessDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUStatus wgpuSharedBufferMemoryBeginAccess(WGPUSharedBufferMemory sharedBufferMemory, WGPUBuffer buffer, WGPUSharedBufferMemoryBeginAccessDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUBuffer wgpuSharedBufferMemoryCreateBuffer(WGPUSharedBufferMemory sharedBufferMemory, WGPU_NULLABLE WGPUBufferDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUBool wgpuSharedBufferMemoryEndAccess(WGPUSharedBufferMemory sharedBufferMemory, WGPUBuffer buffer, WGPUSharedBufferMemoryEndAccessState * descriptor) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUStatus wgpuSharedBufferMemoryEndAccess(WGPUSharedBufferMemory sharedBufferMemory, WGPUBuffer buffer, WGPUSharedBufferMemoryEndAccessState * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUStatus wgpuSharedBufferMemoryGetProperties(WGPUSharedBufferMemory sharedBufferMemory, WGPUSharedBufferMemoryProperties * properties) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUBool wgpuSharedBufferMemoryIsDeviceLost(WGPUSharedBufferMemory sharedBufferMemory) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuSharedBufferMemorySetLabel(WGPUSharedBufferMemory sharedBufferMemory, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuSharedBufferMemorySetLabel(WGPUSharedBufferMemory sharedBufferMemory, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSharedBufferMemoryAddRef(WGPUSharedBufferMemory sharedBufferMemory) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSharedBufferMemoryRelease(WGPUSharedBufferMemory sharedBufferMemory) WGPU_FUNCTION_ATTRIBUTE; @@ -4149,12 +4031,12 @@ WGPU_EXPORT void wgpuSharedFenceAddRef(WGPUSharedFence sharedFence) WGPU_FUNCTIO WGPU_EXPORT void wgpuSharedFenceRelease(WGPUSharedFence sharedFence) WGPU_FUNCTION_ATTRIBUTE; // Methods of SharedTextureMemory -WGPU_EXPORT WGPUBool wgpuSharedTextureMemoryBeginAccess(WGPUSharedTextureMemory sharedTextureMemory, WGPUTexture texture, 
WGPUSharedTextureMemoryBeginAccessDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUStatus wgpuSharedTextureMemoryBeginAccess(WGPUSharedTextureMemory sharedTextureMemory, WGPUTexture texture, WGPUSharedTextureMemoryBeginAccessDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUTexture wgpuSharedTextureMemoryCreateTexture(WGPUSharedTextureMemory sharedTextureMemory, WGPU_NULLABLE WGPUTextureDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUBool wgpuSharedTextureMemoryEndAccess(WGPUSharedTextureMemory sharedTextureMemory, WGPUTexture texture, WGPUSharedTextureMemoryEndAccessState * descriptor) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUStatus wgpuSharedTextureMemoryEndAccess(WGPUSharedTextureMemory sharedTextureMemory, WGPUTexture texture, WGPUSharedTextureMemoryEndAccessState * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUStatus wgpuSharedTextureMemoryGetProperties(WGPUSharedTextureMemory sharedTextureMemory, WGPUSharedTextureMemoryProperties * properties) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUBool wgpuSharedTextureMemoryIsDeviceLost(WGPUSharedTextureMemory sharedTextureMemory) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuSharedTextureMemorySetLabel(WGPUSharedTextureMemory sharedTextureMemory, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuSharedTextureMemorySetLabel(WGPUSharedTextureMemory sharedTextureMemory, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSharedTextureMemoryAddRef(WGPUSharedTextureMemory sharedTextureMemory) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSharedTextureMemoryRelease(WGPUSharedTextureMemory sharedTextureMemory) WGPU_FUNCTION_ATTRIBUTE; @@ -4162,19 +4044,12 @@ WGPU_EXPORT void wgpuSharedTextureMemoryRelease(WGPUSharedTextureMemory sharedTe WGPU_EXPORT void wgpuSurfaceConfigure(WGPUSurface surface, WGPUSurfaceConfiguration const * config) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUStatus wgpuSurfaceGetCapabilities(WGPUSurface surface, WGPUAdapter adapter, WGPUSurfaceCapabilities * capabilities) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSurfaceGetCurrentTexture(WGPUSurface surface, WGPUSurfaceTexture * surfaceTexture) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUTextureFormat wgpuSurfaceGetPreferredFormat(WGPUSurface surface, WGPUAdapter adapter) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSurfacePresent(WGPUSurface surface) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuSurfaceSetLabel(WGPUSurface surface, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSurfaceUnconfigure(WGPUSurface surface) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSurfaceAddRef(WGPUSurface surface) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSurfaceRelease(WGPUSurface surface) WGPU_FUNCTION_ATTRIBUTE; -// Methods of SwapChain -WGPU_EXPORT WGPUTexture wgpuSwapChainGetCurrentTexture(WGPUSwapChain swapChain) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUTextureView wgpuSwapChainGetCurrentTextureView(WGPUSwapChain swapChain) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuSwapChainPresent(WGPUSwapChain swapChain) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuSwapChainAddRef(WGPUSwapChain swapChain) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuSwapChainRelease(WGPUSwapChain swapChain) WGPU_FUNCTION_ATTRIBUTE; - // Methods of Texture WGPU_EXPORT WGPUTextureView wgpuTextureCreateErrorView(WGPUTexture texture, WGPU_NULLABLE WGPUTextureViewDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUTextureView wgpuTextureCreateView(WGPUTexture 
texture, WGPU_NULLABLE WGPUTextureViewDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; @@ -4185,14 +4060,14 @@ WGPU_EXPORT WGPUTextureFormat wgpuTextureGetFormat(WGPUTexture texture) WGPU_FUN WGPU_EXPORT uint32_t wgpuTextureGetHeight(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT uint32_t wgpuTextureGetMipLevelCount(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT uint32_t wgpuTextureGetSampleCount(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUTextureUsageFlags wgpuTextureGetUsage(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUTextureUsage wgpuTextureGetUsage(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT uint32_t wgpuTextureGetWidth(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT void wgpuTextureSetLabel(WGPUTexture texture, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuTextureSetLabel(WGPUTexture texture, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuTextureAddRef(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuTextureRelease(WGPUTexture texture) WGPU_FUNCTION_ATTRIBUTE; // Methods of TextureView -WGPU_EXPORT void wgpuTextureViewSetLabel(WGPUTextureView textureView, char const * label) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT void wgpuTextureViewSetLabel(WGPUTextureView textureView, WGPUStringView label) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuTextureViewAddRef(WGPUTextureView textureView) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuTextureViewRelease(WGPUTextureView textureView) WGPU_FUNCTION_ATTRIBUTE; From 254e4ea68718bff95b36ceaf6f7d7afb7288755f Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Tue, 28 Jan 2025 16:16:40 -0500 Subject: [PATCH 30/44] remove legacy dir from experimental --- experimental/legacy/README.md | 1 - experimental/legacy/audio/Makefile | 46 --- experimental/legacy/audio/run.cpp | 148 --------- experimental/legacy/transformer/Makefile | 27 -- experimental/legacy/transformer/run.cpp | 247 --------------- experimental/legacy/transformer/shaders.h | 221 -------------- .../legacy/transformer/test_kernels.cpp | 283 ------------------ 7 files changed, 973 deletions(-) delete mode 100644 experimental/legacy/README.md delete mode 100644 experimental/legacy/audio/Makefile delete mode 100644 experimental/legacy/audio/run.cpp delete mode 100644 experimental/legacy/transformer/Makefile delete mode 100644 experimental/legacy/transformer/run.cpp delete mode 100644 experimental/legacy/transformer/shaders.h delete mode 100644 experimental/legacy/transformer/test_kernels.cpp diff --git a/experimental/legacy/README.md b/experimental/legacy/README.md deleted file mode 100644 index 5b8629e..0000000 --- a/experimental/legacy/README.md +++ /dev/null @@ -1 +0,0 @@ -Code in this directory is not actively developed and will probably be removed or rewritten in the future. diff --git a/experimental/legacy/audio/Makefile b/experimental/legacy/audio/Makefile deleted file mode 100644 index eb1a4ef..0000000 --- a/experimental/legacy/audio/Makefile +++ /dev/null @@ -1,46 +0,0 @@ -CXX=clang++ -GPUCPP ?= $(PWD)/../.. -LIBDIR ?= $(GPUCPP)/third_party/lib -LIBSPEC ?= . 
$(GPUCPP)/source -NUM_JOBS?=$(shell nproc) -TARGET=microphone -ifeq ($(shell $(CXX) -std=c++17 -x c++ -E -include array - < /dev/null > /dev/null 2>&1 ; echo $$?),0) - STDLIB := -else - STDLIB := -stdlib=libc++ -endif - -PA_FLAG=-I$(GPUCPP)/third_party/headers/portaudio -PA_LIB=-lportaudio.2 - -FLAGS=-std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib $(PA_FLAG) run.cpp -ldl -ldawn $(PA_LIB) - -run: ./build/$(TARGET) dawnlib - $(LIBSPEC) && ./build/$(TARGET) - -dawnlib: $(if $(wildcard $(GPUCPP)/third_party/lib/libdawn.so $(GPUCPP)/third_party/lib/libdawn.dylib),,run_setup) - -run_setup: check-python - cd $(GPUCPP) && python3 setup.py - -build/$(TARGET): run.cpp check-portaudio - mkdir -p build && $(CXX) $(FLAGS) -DNDEBUG -o ./build/$(TARGET) - -clean: - read -r -p "This will delete the contents of build/*. Are you sure? [CTRL-C to abort] " response && rm -rf build/* - -check-python: - @command -v python3 >/dev/null 2>&1 || { echo >&2 "Python needs to be installed and in your path."; exit 1; } - -check-portaudio: - # check if portaudio.2.dylib or portaudio.so is in the third_party/lib directory - @echo "Checking for portaudio library..." - @if [ ! -f $(GPUCPP)/third_party/lib/libportaudio.2.dylib ] && [ ! -f $(GPUCPP)/third_party/lib/libportaudio.so ]; then \ - echo "Portaudio library not found. Please install portaudio and place the library in the third_party/lib directory."; \ - exit 1; \ - fi - # check header file third_party/headers/portaudio/portaudio.h - if [ ! -f $(GPUCPP)/third_party/headers/portaudio/portaudio.h ]; then \ - echo "Portaudio header file not found. Please install portaudio and place the header file in the third_party/headers/portaudio directory."; \ - exit 1; \ - fi diff --git a/experimental/legacy/audio/run.cpp b/experimental/legacy/audio/run.cpp deleted file mode 100644 index bfbf0db..0000000 --- a/experimental/legacy/audio/run.cpp +++ /dev/null @@ -1,148 +0,0 @@ -#include -#include -#include - -#include "gpu.hpp" -#include "portaudio.h" - -#define SAMPLE_RATE (44100) -#define PA_SAMPLE_TYPE paFloat32 -#define FRAMES_PER_BUFFER (1024) - -typedef float SAMPLE; - -struct Buffer { - float *buffer; // non-owning pointer into buffer - size_t size; -}; - - -float sigmoid(float x) { - return 1 / (1 + std::exp(-4 * x)) - 0.5; -} - -static int gNumNoInputs = 0; -static int callback(const void *inputBuffer, void *outputBuffer, - unsigned long framesPerBuffer, - const PaStreamCallbackTimeInfo *timeInfo, - PaStreamCallbackFlags statusFlags, void *userData) { - SAMPLE* out = (SAMPLE *)outputBuffer; - const SAMPLE *in = (const SAMPLE *)inputBuffer; - (void)timeInfo; /* Prevent unused variable warnings. 
*/ - (void)statusFlags; - (void)userData; - - Buffer *buffer = reinterpret_cast(userData); - size_t timeIndex = (timeInfo->currentTime /* in seconds */ * SAMPLE_RATE) / - FRAMES_PER_BUFFER * FRAMES_PER_BUFFER; - - int scale = 1; - size_t reverseIndex = buffer->size - scale * timeIndex; - - if (inputBuffer == NULL) { - for (int i = 0; i < framesPerBuffer; i++) { - *out++ = sigmoid(0); - *out++ = sigmoid(0); - } - gNumNoInputs += 1; - } else { - for (int i = 0; i < framesPerBuffer; i++) { - size_t playHead0 = (timeIndex + i) % buffer->size; - size_t playHead1 = (reverseIndex - scale * i ) % buffer->size; // reverse playhead - - SAMPLE sample = *in++; /* MONO input */ - - float value = sigmoid(0.5 * sample + 0.5 * buffer->buffer[playHead1]); - - *out++ = value; /* LEFT */ - *out++ = value; /* RIGHT */ - buffer->buffer[playHead0] = sample; - - printf("\033[H\033[H\n\nTime = %f\nplayHead0 = %zu\nplayHead1 (reverse) " - "index=%zu\noutput value=%.2f\ninput value=%.2f\nplayhead1 value =%.2f\n", - timeInfo->currentTime, playHead0, playHead1, value, - sample, - buffer->buffer[playHead1]); - } - } - - return paContinue; -} - -void check(bool condition, const char *message) { - if (!condition) { - fprintf(stderr, "%s\n", message); - Pa_Terminate(); - exit(1); - } -} - -int main(void) { - PaStreamParameters inputParameters, outputParameters; - PaStream *stream; - PaError err; - - printf("\033[H\033[J"); - - printf("Turn down volume before starting.\nPress Enter to start."); - getchar(); - - printf("\033[H\033[J"); - - err = Pa_Initialize(); - check(err == paNoError, "Error: Pa_Initialize failed."); - - // Setup device - - inputParameters = { - .device = Pa_GetDefaultInputDevice(), - .channelCount = 1, - .sampleFormat = PA_SAMPLE_TYPE, - .suggestedLatency = - Pa_GetDeviceInfo(Pa_GetDefaultInputDevice())->defaultLowInputLatency, - .hostApiSpecificStreamInfo = NULL}; - - check(inputParameters.device != paNoDevice, - "Error: No default input device."); - - outputParameters = {.device = Pa_GetDefaultOutputDevice(), - .channelCount = 2, - .sampleFormat = PA_SAMPLE_TYPE, - .suggestedLatency = - Pa_GetDeviceInfo(Pa_GetDefaultOutputDevice()) - ->defaultLowOutputLatency, - .hostApiSpecificStreamInfo = NULL}; - - if (outputParameters.device == paNoDevice) { - fprintf(stderr, "Error: No default output device.\n"); - exit(1); - } - - constexpr size_t kBufferTime = 4; // seconds - std::array bufferAlloc; - // zero - for (size_t i = 0; i < bufferAlloc.size(); i++) { - bufferAlloc[i] = 0; - } - - Buffer buffer = {bufferAlloc.data(), bufferAlloc.size()}; - - err = Pa_OpenStream(&stream, &inputParameters, &outputParameters, SAMPLE_RATE, - FRAMES_PER_BUFFER, 0, - /* paClipOff, */ /* we won't output out of range samples - so don't bother clipping them */ - callback, reinterpret_cast(&buffer)); - check(err == paNoError, "Error: Pa_OpenStream failed."); - - err = Pa_StartStream(stream); - check(err == paNoError, "Error: Pa_StartStream failed."); - - printf("Hit Enter to stop program.\n"); - getchar(); - err = Pa_CloseStream(stream); - check(err == paNoError, "Error: Pa_CloseStream failed."); - - printf("Finished. gNumNoInputs = %d\n", gNumNoInputs); - Pa_Terminate(); - return 0; -} diff --git a/experimental/legacy/transformer/Makefile b/experimental/legacy/transformer/Makefile deleted file mode 100644 index 8102130..0000000 --- a/experimental/legacy/transformer/Makefile +++ /dev/null @@ -1,27 +0,0 @@ -CXX=clang++ -GPUCPP ?= $(PWD)/../.. -LIBDIR ?= $(GPUCPP)/third_party/lib -LIBSPEC ?= . 
$(GPUCPP)/source -NUM_JOBS?=$(shell nproc) -TARGET=transformer -CODEPATH = find . ../../utils ../../ -maxdepth 1 -type f - -tests: - mkdir -p build && $(CXX) -std=c++17 -I$(GPUCPP) -I$(GPUCPP)/utils -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib test_kernels.cpp -ldawn -ldl -o ./build/run_tests && $(LIBSPEC) && ./build/run_tests - -run: ./build/$(TARGET) - $(LIBSPEC) && ./build/$(TARGET) - -# Use clang -v to see the include paths -build/$(TARGET): run.cpp - mkdir -p build && $(CXX) -std=c++17 -I$(GPUCPP) -I$(GPUCPP)/utils -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib run.cpp -ldl -ldawn -o ./build/$(TARGET) - -watch: - @command -v entr >/dev/null 2>&1 || { echo >&2 "Please install entr with 'brew install entr' or 'sudo apt-get install entr'"; exit 1; } - mkdir -p build && $(CODEPATH) | entr -s "$(LIBSPEC) && rm -f ./build/$(TARGET) && make -j$(NUM_JOBS) ./build/$(TARGET) && ./build/$(TARGET)" - -clean: - read -r -p "This will delete the contents of build/*. Are you sure? [CTRL-C to abort] " response && rm -rf build/* - -watch-tests: - mkdir -p build && ls | entr -s "$(CXX) -std=c++17 -I$(GPUCPP) -I$(GPUCPP)/utils -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib test_kernels.cpp -ldawn -ldl -o ./build/run_tests && $(LIBSPEC) && ./build/run_tests" diff --git a/experimental/legacy/transformer/run.cpp b/experimental/legacy/transformer/run.cpp deleted file mode 100644 index c7c3da5..0000000 --- a/experimental/legacy/transformer/run.cpp +++ /dev/null @@ -1,247 +0,0 @@ -#include "gpu.hpp" -#include "utils/array_utils.hpp" -#include "utils/logging.hpp" -#include - -#include "llmc/reference_impls.h" - -using namespace gpu; - -static const char *kShaderGelu = R"( -const GELU_SCALING_FACTOR: f32 = 0.7978845608028654; // sqrt(2.0 / PI) -@group(0) @binding(0) var inp: array<{{precision}}>; -@group(0) @binding(1) var out: array<{{precision}}>; -@compute @workgroup_size({{workgroupSize}}) -fn main( - @builtin(global_invocation_id) GlobalInvocationID: vec3) { - let i: u32 = GlobalInvocationID.x; - if (i < arrayLength(&inp)) { - let x: f32 = inp[i]; - // select is more stable for larger values of x - out[i] = select(0.5 * x * (1.0 + tanh(GELU_SCALING_FACTOR - * (x + .044715 * x * x * x))), x, x > 10.0); - } -} -)"; - -static const char *kMLPGate = R"( -const GELU_SCALING_FACTOR: f32 = 0.7978845608028654; // sqrt(2.0 / PI) -@group(0) @binding(0) var gate: array<{{precision}}>; -@group(0) @binding(0) var gated: array<{{precision}}>; -@group(0) @binding(1) var out: array<{{precision}}>; -@compute @workgroup_size({{workgroupSize}}) -fn main( - @builtin(global_invocation_id) GlobalInvocationID: vec3) { - let i: u32 = GlobalInvocationID.x; - if (i < arrayLength(&a)) { - let x: f32 = gate[i]; - out[i] = gated[i] * select(0.5 * x * (1.0 + tanh(GELU_SCALING_FACTOR - * (x + .044715 * x * x * x))), x, x > 10.0); - } -} -)"; - -static const char *kShaderMatmul1 = R"( -@group(0) @binding(0) var A: array; -@group(0) @binding(1) var B: array; -@group(0) @binding(2) var C: array; -@compute @workgroup_size({{workgroupSize}}) -fn main( - @builtin(global_invocation_id) global_id : vec3) { - let row = global_id.y; // row and column of C - let col = global_id.x; - if (row >= {{M}} || col >= {{N}}) { - return; - } - var total: f32 = A[row * {{K}}] * B[col * {{K}}]; // assumes size >= 1 - for (var k = 1u; k < {{K}}; k = k + 1u) { - // B is stored as B^T, effectively column-major - total += A[row * {{K}} + k] * B[col * {{K}} + k]; - } - C[row * {{N}} + col] = total; -} -)"; - -static const 
char *kShaderMatmul2 = R"( -@group(0) @binding(0) var A: array; -@group(0) @binding(1) var B: array; -@group(0) @binding(2) var C: array; -var tileA: array; -var tileB: array; -@compute @workgroup_size({{workgroupSize}}) -fn main( - @builtin(global_invocation_id) global_id : vec3, - @builtin(local_invocation_id) local_id : vec3, - @builtin(local_invocation_index) local_index : u32, - @builtin(workgroup_id) group_id : vec3) { - let row = global_id.y; - let col = global_id.x; - let localRow = local_id.y; - let localCol = local_id.x; - if (row >= {{M}} || col >= {{N}}) { - return; - } - var total: f32 = 0; - for (var tileIndex = 0u; tileIndex < {{K}} / {{workgroupSizeX}}; tileIndex = tileIndex + 1u) { - // TODO - } - C[row * {{N}} + col] = total; -} -)"; - -struct Transformer { - Tensor qkv; // modelDim * 3 * qkvDim - Tensor rmsNormPre; // modelDim - Tensor rmsNormPost; // modelDim - Tensor out; // 3 * qkvDim * modelDim - Tensor mlp1; // modelDim * (2 * hidden_width * modelDim) - Tensor mlp2; // modelDim * (2 * hidden_width * modelDim) -}; - -struct Activations { - Tensor qkv; // batchSize * 3 * nHeads * qkvDim - Tensor qk; - Tensor att; -}; - -struct KVCache { - Tensor keyCache; - Tensor valueCache; -}; - -void createTransformer(Context &ctx, size_t modelDim, size_t qkvDim, - size_t nHeads, size_t batchSize, size_t seqLen, - size_t hiddenWidth, Transformer &transformer, - Activations &activations, KVCache &kvCache) { - std::mt19937 gen(314159); - transformer = { - .qkv = createTensor(ctx, Shape{3 * nHeads * qkvDim, modelDim}, - kf32), // column-major - .rmsNormPre = createTensor(ctx, Shape{modelDim}, kf32), - .rmsNormPost = createTensor(ctx, Shape{modelDim}, kf32), - .out = createTensor(ctx, Shape{3 * qkvDim, modelDim}, kf32), - .mlp1 = createTensor(ctx, Shape{modelDim, 2 * hiddenWidth}, kf32), - .mlp2 = createTensor(ctx, Shape{modelDim, 2 * hiddenWidth}, kf32), - }; - - // Initialize values - std::unique_ptr qkvInit(new float[modelDim * 3 * nHeads * qkvDim]); - // randint(qkvInit.get(), size(transformer.qkv.shape), gen, -2, 2); - range(qkvInit.get(), size(transformer.qkv.shape), 0.0); - LOG(kDefLog, kInfo, "%s", - show(qkvInit.get(), transformer.qkv.shape[0], - transformer.qkv.shape[1], "QKV Weights") - .c_str()); - toGPU(ctx, qkvInit.get(), transformer.qkv); - - activations = { - .qkv = createTensor(ctx, Shape{batchSize * 3 * nHeads * qkvDim}, kf32), - .qk = createTensor(ctx, Shape{batchSize * nHeads}, kf32), - .att = createTensor(ctx, Shape{batchSize * nHeads}, kf32)}; - - kvCache = { - .keyCache = createTensor(ctx, Shape{seqLen, qkvDim}, kf32), - .valueCache = createTensor(ctx, Shape{seqLen, qkvDim}, kf32), - }; - std::unique_ptr keyCacheInit(new float[seqLen * qkvDim]); - std::unique_ptr valueCacheInit(new float[seqLen * qkvDim]); - range(keyCacheInit.get(), size(kvCache.keyCache.shape), 0.0); - range(valueCacheInit.get(), size(kvCache.valueCache.shape), 0.0); - toGPU(ctx, keyCacheInit.get(), kvCache.keyCache); - toGPU(ctx, valueCacheInit.get(), kvCache.valueCache); -} - -inline KernelCode createMatmul(const char *shaderTemplate, const size_t M, - const size_t K, const size_t N, - const Shape &workgroupSize = {256, 1, 1}, - NumType precision = kf32) { - std::string codeString(shaderTemplate); - replaceAll(codeString, "{{workgroupSize}}", toString(workgroupSize)); - replaceAll(codeString, "{{workgroupSizeX}}", - std::to_string(workgroupSize[0])); - replaceAll(codeString, "{{workgroupSizeY}}", - std::to_string(workgroupSize[1])); - replaceAll(codeString, "{{workgroupSizeZ}}", - 
std::to_string(workgroupSize[2])); - replaceAll(codeString, "{{precision}}", toString(precision)); - replaceAll(codeString, "{{M}}", std::to_string(M)); - replaceAll(codeString, "{{K}}", std::to_string(K)); - replaceAll(codeString, "{{N}}", std::to_string(N)); - // LOG(kDefLog, kInfo, "Shader code:\n%s\n", codeString.c_str()); - return KernelCode{codeString, workgroupSize}; -} - -int main() { - printf("\033[2J\033[1;1H"); - Context ctx = createContext(); - static constexpr size_t seqLen = 24; - static constexpr size_t batchSize = 1; - static constexpr size_t modelDim = 4; // 3072; - static constexpr size_t hiddenWidth = modelDim * 2; - static constexpr size_t qkvDim = 3; // 256; - static constexpr size_t nHeads = 8; - std::mt19937 gen(314); - - Transformer transformer; - Activations activations; - KVCache kvcache; - LOG(kDefLog, kInfo, "Initializing transformer, allocating GPU buffers ...\n"); - createTransformer(ctx, modelDim, qkvDim, nHeads, batchSize, seqLen, - hiddenWidth, transformer, activations, kvcache); - - std::array inputArr; - randint(inputArr, gen, -2, 2); - LOG(kDefLog, kInfo, "%s", - show(inputArr.data(), 1, modelDim, "Input").c_str()); - Tensor input = createTensor(ctx, Shape{modelDim}, kf32, inputArr.data()); - - /* QKV Projection */ - - LOG(kDefLog, kInfo, "QKV Projection"); - { - KernelCode matmul = createMatmul(kShaderMatmul1, /*M*/ batchSize, - /*K*/ modelDim, /*N*/ 3 * qkvDim); - Kernel qkv = createKernel( - ctx, matmul, Bindings{input, transformer.qkv, activations.qkv}, - /*nthreads*/ {modelDim, 1, 1}); - std::promise promise; - std::future future = promise.get_future(); - dispatchKernel(ctx, qkv, promise); - wait(ctx, future); - std::array outputArr; - toCPU(ctx, activations.qkv, outputArr.data(), sizeof(outputArr)); - LOG(kDefLog, kInfo, "Output: %s", - show(outputArr.data(), 1, 3 * qkvDim, "QKV Output").c_str()); - std::array outputRefArr; - std::array weightsArr; - toCPU(ctx, transformer.qkv, weightsArr.data(), sizeof(weightsArr)); - ref::matmul_forward_cpu( - outputRefArr.data(), inputArr.data(), weightsArr.data(), nullptr, - /* batch */ 1, /* T */ 1, /* C */ modelDim, /* OC */ 3 * qkvDim); - LOG(kDefLog, kInfo, "Reference Output: %s", - show(outputRefArr.data(), 1, 3 * qkvDim, - "QKV Output (Reference)") - .c_str()); - LOG(kDefLog, kInfo, - isclose(outputArr.data(), outputRefArr.data(), 3 * qkvDim) ? 
"PASS" - : "FAIL"); - } - - /* QK Dot Products */ - - LOG(kDefLog, kInfo, "QK Dot Product"); - { - KernelCode dot = createMatmul(kShaderMatmul1, /*M*/ batchSize * nHeads, - /*K*/ qkvDim, /*M*/ 1); - - /* - // TODO(avh): need to pass in activation views that don't overlap here - Kernel qk = createKernel( - ctx, dot, Bindings{activations.qkv, activations.qkv, activations.qk}, - {batchSize * nHeads, 1, 1}); - */ - // TODO(avh): check nThreads - } - - LOG(kDefLog, kInfo, "Done"); -} diff --git a/experimental/legacy/transformer/shaders.h b/experimental/legacy/transformer/shaders.h deleted file mode 100644 index f218807..0000000 --- a/experimental/legacy/transformer/shaders.h +++ /dev/null @@ -1,221 +0,0 @@ -#ifndef KERNELS_H -#define KERNELS_H - -#include "gpu.hpp" - -namespace gpu { - - -static const char *kShaderGelu = R"( -const GELU_SCALING_FACTOR: f32 = 0.7978845608028654; // sqrt(2.0 / PI) -@group(0) @binding(0) var inp: array<{{precision}}>; -@group(0) @binding(1) var out: array<{{precision}}>; -@compute @workgroup_size({{workgroupSize}}) -fn main( - @builtin(global_invocation_id) GlobalInvocationID: vec3) { - let i: u32 = GlobalInvocationID.x; - if (i < arrayLength(&inp)) { - let x: f32 = inp[i]; - // select is more stable for larger values of x - out[i] = select(0.5 * x * (1.0 + tanh(GELU_SCALING_FACTOR - * (x + .044715 * x * x * x))), x, x > 10.0); - } -} -)"; - -static const char *kShaderHadamard = R"( -@group(0) @binding(0) var A: array<{{precision}}>; -@group(0) @binding(1) var B: array<{{precision}}>; -@group(0) @binding(2) var C: array<{{precision}}>; -@compute @workgroup_size({{workgroupSize}}) -fn main( - @builtin(global_invocation_id) GlobalInvocationID: vec3) { - let idx = GlobalInvocationID.x; - if (idx < arrayLength(&A)) { - C[idx] = A[idx] * B[idx]; - } -} -)"; - -static const char *kShaderResidual = R"( -@group(0) @binding(0) var A: array<{{precision}}>; -@group(0) @binding(1) var B: array<{{precision}}>; -@group(0) @binding(2) var C: array<{{precision}}>; -@compute @workgroup_size({{workgroupSize}}) -fn main( - @builtin(global_invocation_id) GlobalInvocationID: vec3) { - let idx = GlobalInvocationID.x; - if (idx < arrayLength(&A)) { - C[idx] = A[idx] + B[idx]; - } -} -)"; - -/* LayerNorm - * v1: - * - No caching mean/std for backwards - * - No parallel reduction - * - Simple 1 thread for each 1..N - */ -// TODO(avh): Allow larger virtual 1D workgroups by making use of y / z -// dimensions and calculating the threadID accordingly. 
-static const char *kShaderLayerNorm1 = R"( -@group(0) @binding(0) var inp: array<{{precision}}>; -@group(0) @binding(1) var weight: array<{{precision}}>; -@group(0) @binding(2) var bias: array<{{precision}}>; -@group(0) @binding(3) var out: array<{{precision}}>; -@group(0) @binding(4) var params: Params; - -struct Params { - N: u32, - C: u32, -}; - -@compute @workgroup_size({{workgroupSize}}) -fn main(@builtin(global_invocation_id) GlobalInvocationID: vec3, - @builtin(local_invocation_id) LocalInvocationID: vec3, - @builtin(workgroup_id) WorkgroupID: vec3) { - let idx: u32 = GlobalInvocationID.x; - - if (idx >= params.N) { return; } - - let C: u32 = params.C; - - // Calculate mean - var sum: f32 = 0.0; - for (var i: u32 = 0; i < C; i = i + 1) { - sum += inp[idx * C + i]; - } - let mean_val: f32 = sum / f32(C); - - // Calculate rstd - sum = 0.0; - for (var i: u32 = 0; i < C; i = i + 1) { - let diff: f32 = inp[idx * C + i] - mean_val; - sum += diff * diff; - } - let rstd_val: f32 = 1.0 / sqrt(sum / f32(C) + 1e-5); - - for (var i: u32 = 0; i < C; i = i + 1) { - let n: f32 = rstd_val * (inp[idx * C + i] - mean_val); - out[idx * C + i] = n * weight[i] + bias[i]; - } -} -)"; - -// matrix multiplication (naive implementation) -static const char *kShaderMatMul1 = R"( -@group(0) @binding(0) var A: array<{{precision}}>; -@group(0) @binding(1) var B: array<{{precision}}>; -@group(0) @binding(2) var C: array<{{precision}}>; -@compute @workgroup_size({{workgroupSize}}) -fn main( - @builtin(global_invocation_id) GlobalInvocationID: vec3) { - let i: u32 = GlobalInvocationID.x / {{N}}; - let j: u32 = GlobalInvocationID.x % {{N}}; - if (i < {{M}} && j < {{N}}) { - var sum: f32 = 0.0; - for (var k: u32 = 0; k < {{K}}; k = k + 1) { - sum = sum + A[i * {{K}} + k] * B[k * {{N}} + j]; - } - C[i * {{N}} + j] = sum; - } -} -)"; - -static const char *kShaderMatMul2 = R"( -@group(0) @binding(0) var A: array; -@group(0) @binding(1) var B: array; -@group(0) @binding(2) var C: array; -var tileA: array; -var tileB: array; -@compute @workgroup_size(workgroupSizeX, workgroupSizeY, 1) -fn matmul( - @builtin(global_invocation_id) global_id : vec3, - @builtin(local_invocation_id) local_id : vec3, - @builtin(workgroup_id) workgroup_id : vec3 -) { - let row = global_id.x; - let col = global_id.y; - if (row >= {{M}} || col >= {{N}}) { - return; - } - var result: f32 = 0.0; - for (var i = 0u; i < {{K}}; i = i + workgroupSizeX) { - // Load tiles into shared memory - tileA[local_id.y][local_id.x] = A[row][i + local_id.x]; - tileB[local_id.y][local_id.x] = B[i + local_id.y][col]; - // Synchronize to make sure the tile is loaded - workgroupBarrier(); - // Perform partial dot product for the current tile - for (var k = 0u; k < workgroupSizeX; k = k + 1u) { - result = result + tileA[local_id.y][k] * tileB[k][local_id.x]; - } - // Synchronize before loading the next tile - workgroupBarrier(); - } - C[row][col] = result; -} -)"; - -/* Generates KernelCode instance for all matmul kernels - pass in - * the template code via `shaderRaw`. - * - * This is intended to be run ahead of time, so is not performance critical. 
- * */ -KernelCode MatmulShader(size_t workgroupSize, const char *shaderRaw, - NumType precision, size_t M, size_t K, size_t N) { - KernelCode shader = {shaderRaw, workgroupSize, precision}; - replaceAll(shader.data, "{{M}}", std::to_string(M)); - replaceAll(shader.data, "{{K}}", std::to_string(K)); - replaceAll(shader.data, "{{N}}", std::to_string(N)); - return shader; -} - -/* Softmax - * v1: - * - equivalent to naive softmax with one thread per row - */ -static const char *kShaderSoftmax1 = R"( -@group(0) @binding(0) var inp : array<{{precision}}>; -@group(0) @binding(1) var out : array<{{precision}}>; -@group(0) @binding(2) var params : Params; -struct Params { - N: u32, - C: u32, -}; -const NEG_INFINITY: f32 = -3.0e38; // WGSL has problem representing -3.4028235e+38 -@compute @workgroup_size({{workgroupSize}}) -fn main(@builtin(global_invocation_id) global_id : vec3) { - let N : u32 = params.N; - let C : u32 = params.C; - let i : u32 = global_id.x; - if (i < N) { - let inp_row_start : u32 = i * C; - var maxval : f32 = NEG_INFINITY; - // Find the maximum value in the row - for (var j : u32 = 0u; j < C; j++) { - let val : f32 = inp[inp_row_start + j]; - if (val > maxval) { - maxval = val; - } - } - var sum : f32 = 0.0; - // Compute the exponentials and sum them - for (var j : u32 = 0u; j < C; j++) { - let exp_val : f32 = exp(inp[inp_row_start + j] - maxval); - out[inp_row_start + j] = exp_val; - sum += exp_val; - } - // Normalize the row to get probabilities - let norm : f32 = 1.0f / sum; - for (var j : u32 = 0u; j < C; j++) { - out[inp_row_start + j] /= sum; - } - } -} -)"; - -} // namespace gpu - -#endif // KERNELS_H diff --git a/experimental/legacy/transformer/test_kernels.cpp b/experimental/legacy/transformer/test_kernels.cpp deleted file mode 100644 index b689a7b..0000000 --- a/experimental/legacy/transformer/test_kernels.cpp +++ /dev/null @@ -1,283 +0,0 @@ -#include -#include -#include -#include - -#include "gpu.hpp" -#include "utils/array_utils.hpp" -#include "utils/logging.hpp" - -#include "llmc/reference_impls.h" -#include "shaders.h" - -using namespace gpu; - -void testResidual(Context &ctx) { - constexpr size_t N = 200000; - constexpr size_t workgroupSize = 256; - std::array input1Arr; - std::array input2Arr; - range(input1Arr); - range(input2Arr); - std::array outputArr; - Tensor input1 = createTensor(ctx, {N}, kf32, input1Arr.data()); - Tensor input2 = createTensor(ctx, {N}, kf32, input2Arr.data()); - Tensor output = createTensor(ctx, {N}, kf32, outputArr.data()); - std::promise promise; - std::future future = promise.get_future(); - KernelCode shaderCode = {kShaderResidual, workgroupSize, kf32}; - LOG(kDefLog, kInfo, "Shader Code :\n%s", shaderCode.data.c_str()); - Kernel op = - createKernel(ctx, {kShaderResidual, workgroupSize, kf32}, - Bindings{input1, input2, output}, {cdiv(N, workgroupSize), 1, 1}); - dispatchKernel(ctx, op, promise); - wait(ctx, future); - toCPU(ctx, output, outputArr.data(), sizeof(outputArr)); - LOG(kDefLog, kInfo, "%s", - show(outputArr, "Residual Output").c_str()); - std::array outputRef; - ref::residual_forward_cpu(outputRef.data(), input1Arr.data(), - input2Arr.data(), N); - LOG(kDefLog, kInfo, "%s", - show(outputRef, "Residual Reference Output").c_str()); - assert(isclose(outputArr.data(), outputRef.data(), N)); - LOG(kDefLog, kInfo, "Done with Residual Test"); -} - -void testHadamard(Context &ctx) { - constexpr size_t N = 200000; - constexpr size_t workgroupSize = 256; - std::array input1Arr; - std::array input2Arr; - range(input1Arr); - 
range(input2Arr); - std::array outputArr; - Tensor input1 = createTensor(ctx, {N}, kf32, input1Arr.data()); - Tensor input2 = createTensor(ctx, {N}, kf32, input2Arr.data()); - Tensor output = createTensor(ctx, {N}, kf32, outputArr.data()); - KernelCode shaderCode = {kShaderHadamard, workgroupSize, kf32}; - LOG(kDefLog, kInfo, "Shader Code :\n%s", shaderCode.data.c_str()); - std::promise promise; - std::future future = promise.get_future(); - Kernel op = - createKernel(ctx, {kShaderHadamard, workgroupSize, kf32}, - Bindings{input1, input2, output}, {cdiv(N, workgroupSize), 1, 1}); - dispatchKernel(ctx, op, promise); - wait(ctx, future); - LOG(kDefLog, kInfo, "%s", - show(outputArr, "Hadamard Output").c_str()); -} - -void testMatmul(Context &ctx) { - static constexpr size_t M = 4; - static constexpr size_t K = 5; - static constexpr size_t N = 4; - auto gen = std::mt19937(31415); - std::array input1Arr; - std::array input2Arr; - std::array outputArr; - randint(input1Arr, gen, 0, 5); - range(input2Arr); - Tensor input1 = createTensor(ctx, {M, K}, kf32, input1Arr.data()); - Tensor input2 = createTensor(ctx, {K, N}, kf32, input2Arr.data()); - Tensor output = createTensor(ctx, {M, N}, kf32, outputArr.data()); - Kernel op = createKernel( - ctx, MatmulShader(256, kShaderMatMul1, kf32, M, K, N), - Bindings{input1, input2, output}, {cdiv(M * N, 256), 1, 1}); - std::promise promise; - std::future future = promise.get_future(); - dispatchKernel(ctx, op, promise); - wait(ctx, future); - toCPU(ctx, output, outputArr.data(), sizeof(outputArr)); - LOG(kDefLog, kInfo, "%s", show(input1Arr, "A").c_str()); - LOG(kDefLog, kInfo, "%s", show(input2Arr, "B").c_str()); - LOG(kDefLog, kInfo, "%s", show(outputArr, "C").c_str()); - - std::array refOutputArr; - std::array input2ArrT; - transpose(input2Arr.data(), input2ArrT.data(), K, N); - LOG(kDefLog, kInfo, "%s", show(input2ArrT, "B'").c_str()); - ref::matmul_forward_cpu(refOutputArr.data(), input1Arr.data(), - input2ArrT.data(), nullptr, 1, M, K, N); - LOG(kDefLog, kInfo, show(refOutputArr, "C (reference)").c_str()); - - LOG(kDefLog, kInfo, "Done with Matmul Test"); - bool passed = isclose(outputArr.data(), refOutputArr.data(), N); - assert(passed); -} - -void testTensorPool(Context &ctx) { - LOG(kDefLog, kInfo, "Starting Tensor Pool Test"); - // Test using the tensor pool to prepare tensor buffers for kernel invocation - TensorPool pool = ctx.pool; - std::array inputArr = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - Tensor input = createTensor(ctx, {2, 3}, kf32, inputArr.data()); - Tensor output = createTensor(ctx, {2, 3}, kf32); - for (int i = 0; i < 10; i++) { - Tensor t = createTensor(ctx, {2, 3}, kf32); - } - // initializing a gpu buffer w/ value and then copy it back to CPU - std::array initValue = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - LOG(kDefLog, kInfo, "making tensors with init"); - Tensor tInit = createTensor(ctx, {2, 3}, kf32, initValue.data()); - LOG(kDefLog, kInfo, "Done with Tensor Pool Test"); - std::array targetValue; - toCPU(ctx, tInit, targetValue.data(), sizeof(initValue)); - LOG(kDefLog, kInfo, "%s", - show(initValue, "initialized GPU value").c_str()); - LOG(kDefLog, kInfo, "%s", - show(targetValue, "To CPU from GPU").c_str()); - LOG(kDefLog, kInfo, "Done with Tensor Pool Test"); -} - -void testGelu(Context &ctx) { - static constexpr size_t N = 3072; - std::array inputArr; - // range(inputArr); - auto gen = std::mt19937(31415); - // TODO(avh): investigate - on metal tanh seems to produce nan for values > 10 - randint(inputArr, gen, 0, 10); // for debugging - 
std::array outputArr; - Tensor geluIn = createTensor(ctx, {N}, kf32, inputArr.data()); - Tensor geluOut = createTensor(ctx, {N}, kf32, outputArr.data()); - LOG(kDefLog, kInfo, "Creating GELU Shader"); - KernelCode shader = {kShaderGelu, 256, kf32}; - Kernel op = createKernel(ctx, shader, Bindings{geluIn, geluOut}, - {cdiv(N, 256), 1, 1}); - LOG(kDefLog, kInfo, "Workgroup size: %s", - toString(shader.workgroupSize).c_str()); - LOG(kDefLog, kInfo, "dispatching GELU Shader"); - std::promise promise; - std::future future = promise.get_future(); - dispatchKernel(ctx, op, promise); - wait(ctx, future); - toCPU(ctx, geluOut, outputArr.data(), sizeof(outputArr)); - LOG(kDefLog, kInfo, "%s", show(inputArr, "GELU Input").c_str()); - LOG(kDefLog, kInfo, "%s", - show(outputArr, "GELU Output").c_str()); - std::array refOutputArr; - ref::gelu_forward_cpu(refOutputArr.data(), inputArr.data(), N); - LOG(kDefLog, kInfo, "%s", - show(refOutputArr, "GELU Reference Output").c_str()); - bool passed = isclose(outputArr.data(), refOutputArr.data(), N); - assert(passed); - LOG(kDefLog, kInfo, "Gelu passed? %d", passed); - LOG(kDefLog, kInfo, "Done with Gelu Test"); -} - -void testLayerNorm(Context &ctx) { - struct LNParam { - uint32_t N; // check - uint32_t C; - }; - constexpr size_t N = 6; - constexpr size_t C = 3072; - std::mt19937 gen(31415); - std::array inputArr; - randint(inputArr, gen, 0, 3); - // range(inputArr); - std::array outputArr; - std::array weightArr; - std::array biasArr; - Tensor input = createTensor(ctx, {N, C}, kf32, inputArr.data()); - LNParam params = {N, C}; - randint(weightArr, gen, 0, 5); // populate randomly - randint(biasArr, gen, 0, 5); - Tensor weight = createTensor(ctx, {C}, kf32, weightArr.data()); - Tensor bias = createTensor(ctx, {C}, kf32, biasArr.data()); - Tensor output = createTensor(ctx, {N, C}, kf32, outputArr.data()); - std::promise promise; - std::future future = promise.get_future(); - Kernel op = createKernel(ctx, {kShaderLayerNorm1, 256, kf32}, - Bindings{input, weight, bias, output}, - /* n threads */ {N, 1, 1}, params); - dispatchKernel(ctx, op, promise); - wait(ctx, future); - toCPU(ctx, output, outputArr.data(), sizeof(outputArr)); - LOG(kDefLog, kInfo, "%s", - show(inputArr, "LayerNorm Input").c_str()); - LOG(kDefLog, kInfo, "%s", - show(weightArr, "LayerNorm Weight").c_str()); - LOG(kDefLog, kInfo, "%s", - show(biasArr, "LayerNorm Bias").c_str()); - LOG(kDefLog, kInfo, "%s", - show(outputArr, "LayerNorm Output").c_str()); - std::array refOutputArr; - ref::layernorm_forward_cpu(refOutputArr.data(), inputArr.data(), - weightArr.data(), biasArr.data(), N, 1, C); - LOG(kDefLog, kInfo, "%s", - show(refOutputArr, - "LayerNorm Reference Implementation Output") - .c_str()); - bool passed = isclose(outputArr.data(), refOutputArr.data(), N * C); - assert(passed); - LOG(kDefLog, kInfo, "LayerNorm passed? 
%d", passed); - LOG(kDefLog, kInfo, "Done with LayerNorm Test"); -} - -void testSoftmax(Context &ctx) { - - struct SoftmaxParam { - uint32_t N; - uint32_t C; - }; - static constexpr size_t B = 6; // batch size - static constexpr size_t T = 8; // token index - static constexpr size_t C = 3072; // input channels - std::array inputArr; - std::array outputArr; - std::mt19937 gen(31415); - randint(inputArr, gen, 0, 3); - Tensor input = createTensor(ctx, {B * T, C}, kf32, inputArr.data()); - Tensor output = createTensor(ctx, {B * T, C}, kf32, outputArr.data()); - LOG(kDefLog, kInfo, "num threads: %d", B * T); - std::promise promise; - std::future future = promise.get_future(); - Kernel op = createKernel( - ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output}, - Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{B * T, C}); - dispatchKernel(ctx, op, promise); - wait(ctx, future); - toCPU(ctx, output, outputArr.data(), sizeof(outputArr)); - LOG(kDefLog, kInfo, "%s", - show(inputArr, "Softmax Input").c_str()); - LOG(kDefLog, kInfo, "%s", - show(outputArr, "Softmax Output").c_str()); - std::array refOutputArr; - ref::softmax_forward_cpu(refOutputArr.data(), inputArr.data(), B * T, C); - LOG(kDefLog, kInfo, "%s", - show(refOutputArr, "Softmax reference Output").c_str()); - - LOG(kDefLog, kInfo, "number of elements: %d", B * T * C); - bool passed = isclose(outputArr.data(), refOutputArr.data(), B * T * C); - assert(passed); - LOG(kDefLog, kInfo, "Softmax passed? %d", passed); - LOG(kDefLog, kInfo, "Done with Softmax Test"); -} - -void testAttention(Context &ctx) { - static constexpr size_t B = 6; - static constexpr size_t T = 32; // token index - static constexpr size_t C = 3072; // input channels - static constexpr size_t QKV_DIM = 256; - static constexpr size_t N_HEADS = 12; - static constexpr size_t OC = - QKV_DIM * N_HEADS * 3; // output channels, 3 for Q, K, V - std::array inputArr; - std::array outputArr; - std::array weightArr; -} - -int main(int argc, char **argv) { - Context ctx = createContext(); - - testTensorPool(ctx); - testResidual(ctx); - testHadamard(ctx); - testMatmul(ctx); - testGelu(ctx); - testLayerNorm(ctx); - testSoftmax(ctx); - - LOG(kDefLog, kInfo, "Done with all tests"); -} From f3f3b27f3d8486263a484492e728dd9ffac13261 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Tue, 28 Jan 2025 16:31:02 -0500 Subject: [PATCH 31/44] gpt2_webgpu_aot runs on mac after updating experimental/kernels/Makefile to point to updated dylib artifact, add note to third_party/lib/README.md regarding new dawn shared library build process --- experimental/kernels/Makefile | 12 ++++++------ third_party/lib/README.md | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/experimental/kernels/Makefile b/experimental/kernels/Makefile index 1162aaf..e2d89b1 100644 --- a/experimental/kernels/Makefile +++ b/experimental/kernels/Makefile @@ -16,7 +16,7 @@ CXXFLAGS=-std=c++17 -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -I. -Iunittest_l CFLAGS=-Ofast -march=native -I. -Iunittest_llmc # CFLAGS=-O2 -march=native -I. 
-Iunittest_llmc -LDFLAGS=$(STDLIB) -L$(GPUCPP)/third_party/lib -ldl -ldawn -fsanitize=address +LDFLAGS=$(STDLIB) -L$(GPUCPP)/third_party/lib -ldl -lwebgpu_dawn -fsanitize=address FLAGS=$(CXXFLAGS) $(LDFLAGS) ifeq ($(shell [ -d /opt/homebrew/opt/libomp/lib ] && echo "exists"), exists) @@ -83,18 +83,18 @@ endef build/test_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin mkdir -p build $(call preprocess_file) - $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o -g + $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o build/test_gpt2_with_metal_profiler: llm.c build/unittest_kernels.o gpt2_124M.bin mkdir -p build $(call preprocess_file) $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/test_gpt2.c build/unittest_kernels.o -I$(GPUCPP) $(GPUCPP)/experimental/profiler/metal.mm -framework metal -framework Foundation -DMETAL_PROFILER -g - install_name_tool -change @rpath/libdawn.dylib $(GPUCPP)/third_party/lib/libdawn.dylib $@ + install_name_tool -change @rpath/libwebgpu_dawn.dylib $(GPUCPP)/third_party/lib/libwebgpu_dawn.dylib $@ build/train_gpt2: llm.c build/unittest_kernels.o gpt2_124M.bin mkdir -p build $(call preprocess_file) - $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o -g + $(CC) $(CFLAGS) $(LDFLAGS) -o $@ llm.c/train_gpt2.c build/unittest_kernels.o build/ops.o: ops.cpp ops.hpp kernels.h llm.c mkdir -p build && $(CXX) $(CXXFLAGS) -c -o $@ $< @@ -105,7 +105,7 @@ build/gpt2_webgpu: llm.c gpt2_124M.bin llm.c gpt2_webgpu.cpp ops.cpp build/gpt2_webgpu_aot: llm.c gpt2_124M.bin llm.c gpt2_webgpu_aot.cpp ops_aot.cpp mkdir -p build - $(CC) $(CXXFLAGS) -Illm.c $(LDFLAGS) -o $@ gpt2_webgpu_aot.cpp ops_aot.cpp -g + $(CC) $(CXXFLAGS) -Illm.c $(LDFLAGS) -o $@ gpt2_webgpu_aot.cpp ops_aot.cpp build/gpt2_webgpu.html: check-emsdk gpt2_webgpu.cpp term.html llm.c em++ gpt2_webgpu.cpp ops.cpp \ @@ -139,7 +139,7 @@ server: build/gpt2_webgpu.html build/unittest_kernels.o: unittest_llmc/unittest_kernels.cpp unittest_llmc/unittest_kernels.h kernels.h mkdir -p build && $(CXX) $(CXXFLAGS) -DNDEBUG -c -o $@ $< -dawnlib: $(if $(wildcard $(GPUCPP)/third_party/lib/libdawn.so $(GPUCPP)/third_party/lib/libdawn.dylib),,run_setup) +dawnlib: $(if $(wildcard $(GPUCPP)/third_party/lib/libwebgpu_dawn.so $(GPUCPP)/third_party/lib/libwebgpu_dawn.dylib),,run_setup) run_setup: check-python cd $(GPUCPP) && python3 setup.py diff --git a/third_party/lib/README.md b/third_party/lib/README.md index 340c10e..9391cbf 100644 --- a/third_party/lib/README.md +++ b/third_party/lib/README.md @@ -1,3 +1,22 @@ +# Release 0.2.0 (draft) + +Switched from +https://github.com/jspanchu/webgpu-dawn-binaries +to building from the dawn repository: +https://github.com/google/dawn + +Commit hash: +5a00ab1fbc22d6ebbab39c901c1f90144e9b71e9 + +Build with clang (assumes running from out/Release) + +``` +cmake -DBUILD_SHARED_LIBS=ON -DDAWN_BUILD_MONOLITHIC_LIBRARY=ON -DCMAKE_BUILD_TYPE=Release ../.. 
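+# A sketch of the surrounding steps, under assumptions not stated in the
+# upstream dawn docs (fresh clone, build directory created beforehand):
+#   git clone https://github.com/google/dawn && cd dawn
+#   git checkout 5a00ab1fbc22d6ebbab39c901c1f90144e9b71e9
+#   mkdir -p out/Release && cd out/Release
+# then run the cmake configure command above, and build with:
+#   cmake --build . --parallel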
+``` +Library artifact is at `src/dawn/native/libwebgpu_dawn.dylib` (7.4 MB) + +# Release 0.1.0 + https://github.com/jspanchu/webgpu-dawn-binaries commit hash: c0602d5d0466040f6e080d6cb7209860538f9f8d From b39795919b37563a55df35c618ac405294391d11 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Tue, 28 Jan 2025 16:37:13 -0500 Subject: [PATCH 32/44] add detailed note regarding dawn modifications to fix linker errors on mac --- third_party/lib/README.md | 64 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/third_party/lib/README.md b/third_party/lib/README.md index 9391cbf..014a48a 100644 --- a/third_party/lib/README.md +++ b/third_party/lib/README.md @@ -15,6 +15,70 @@ cmake -DBUILD_SHARED_LIBS=ON -DDAWN_BUILD_MONOLITHIC_LIBRARY=ON -DCMAKE_BUILD_TY ``` Library artifact is at `src/dawn/native/libwebgpu_dawn.dylib` (7.4 MB) +Note that for OSX builds, needed to make modifications to the `CMakeLists.txt` file to get the build to work. Specifically, needed to add the following `FORCE_OBJECT` flags to dawn_glfw, dawn_wgpu_utils, and dawn_test_utils. Otherwise we get linker errors for missing symbols when building the shared library. This issue does not appear to be present on Linux builds. + +``` +(base) austinhuang@Austins-MacBook-Pro dawn % git diff 556f960f44690b3b808c779c08b44d48d4292925 5a00ab1fbc22d6ebbab39 +diff --git a/src/dawn/glfw/CMakeLists.txt b/src/dawn/glfw/CMakeLists.txt +index dc3f3ade03..d6d8d0ef4f 100644 +--- a/src/dawn/glfw/CMakeLists.txt ++++ b/src/dawn/glfw/CMakeLists.txt +@@ -40,6 +40,7 @@ endif () + + dawn_add_library( + dawn_glfw ++ FORCE_OBJECT + UTILITY_TARGET dawn_internal_config + HEADERS + "${headers}" +@@ -56,5 +57,5 @@ target_compile_definitions(dawn_glfw PRIVATE "WGPU_GLFW_IMPLEMENTATION") + if(BUILD_SHARED_LIBS) + target_compile_definitions(dawn_glfw PUBLIC "WGPU_GLFW_SHARED_LIBRARY") + endif() +- ++# target_link_libraries(dawn_glfw PUBLIC webgpu_dawn) + add_library(webgpu_glfw ALIAS dawn_glfw) +diff --git a/src/dawn/utils/CMakeLists.txt b/src/dawn/utils/CMakeLists.txt +index 5eb7120d99..3b00664829 100644 +--- a/src/dawn/utils/CMakeLists.txt ++++ b/src/dawn/utils/CMakeLists.txt +@@ -36,6 +36,7 @@ endif() + + dawn_add_library( + dawn_wgpu_utils ++ FORCE_OBJECT + ENABLE_EMSCRIPTEN + UTILITY_TARGET dawn_internal_config + PRIVATE_HEADERS +@@ -55,6 +56,8 @@ dawn_add_library( + ${private_wgpu_depends} + ) + ++# target_link_libraries(dawn_wgpu_utils PUBLIC webgpu_dawn) ++ + # Needed by WGPUHelpers + target_compile_definitions(dawn_wgpu_utils + PUBLIC +@@ -66,6 +69,7 @@ target_compile_definitions(dawn_wgpu_utils + ############################################################################### + dawn_add_library( + dawn_test_utils ++ FORCE_OBJECT + UTILITY_TARGET dawn_internal_config + PRIVATE_HEADERS + "BinarySemaphore.h" +@@ -84,6 +88,9 @@ dawn_add_library( + dawn::partition_alloc + ) + ++# target_link_libraries(dawn_test_utils PUBLIC webgpu_dawn dawn_wgpu_utils dawn_proc) ++ ++ + ############################################################################### + # Dawn system utilities + # - Used in tests and samples + ``` + # Release 0.1.0 https://github.com/jspanchu/webgpu-dawn-binaries From 40fd25d0c1349fc1042b4f438f9500acbc074ae7 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Tue, 28 Jan 2025 16:39:50 -0500 Subject: [PATCH 33/44] correct commit hash --- third_party/lib/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/lib/README.md b/third_party/lib/README.md index 014a48a..3800e54 100644 
--- a/third_party/lib/README.md +++ b/third_party/lib/README.md @@ -6,7 +6,7 @@ to building from the dawn repository: https://github.com/google/dawn Commit hash: -5a00ab1fbc22d6ebbab39c901c1f90144e9b71e9 +556f960f44690b3b808c779c08b44d48d4292925 Build with clang (assumes running from out/Release) From 46db79d087e8a87148153a5fce01f6b1dadadcae Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Tue, 28 Jan 2025 16:51:37 -0500 Subject: [PATCH 34/44] update setup.py auto-downloads to point to updated libwebgpu_dawn.dylib/so shared libraries --- setup.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 931306c..40cc5cc 100644 --- a/setup.py +++ b/setup.py @@ -57,12 +57,12 @@ def download_dawn(os_name): print("=====================\n") outfile_map = { - "macOS": "third_party/lib/libdawn.dylib", - "Linux": "third_party/lib/libdawn.so", + "macOS": "third_party/lib/libwebgpu_dawn.dylib", + "Linux": "third_party/lib/libwebgpu_dawn.so", } url_map = { - "macOS": "https://github.com/austinvhuang/dawn-artifacts/releases/download/prerelease/libdawn.dylib", - "Linux": "https://github.com/austinvhuang/dawn-artifacts/releases/download/prerelease/libdawn.so", + "macOS": "https://github.com/austinvhuang/dawn-artifacts/releases/download/0.2.0/libwebgpu_dawn.dylib", + "Linux": "https://github.com/austinvhuang/dawn-artifacts/releases/download/0.2.0/libwebgpu_dawn.so", } outfile = outfile_map.get(os_name) From 73f438a62cd58cba53fe4d3adcf7be090c1cebb6 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Tue, 28 Jan 2025 16:54:11 -0500 Subject: [PATCH 35/44] clang-format cleanup --- gpu.hpp | 397 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 200 insertions(+), 197 deletions(-) diff --git a/gpu.hpp b/gpu.hpp index d3641f7..c34894a 100644 --- a/gpu.hpp +++ b/gpu.hpp @@ -323,19 +323,22 @@ struct KernelCode { * the y and z dimensions. 
 *
  * @param[in] pData Shader template string with placeholders @param[in]
- * workgroupSize 3D Workgroup size
+ * workgroupSize 3D Workgroup size
  * @param[in] precision Data type precision for the shader
  *
  * @code
  * KernelCode code = {kPuzzle1, 256, kf32};
  * @endcode
  */
-  inline KernelCode(const std::string &pData, const Shape &workgroupSize =
-  {256, 1, 1}, NumType precision = kf32) : data(pData),
-  workgroupSize(workgroupSize), precision(precision) { if (precision == kf16) {
-  data = "enable f16;\n" + data; } replaceAll(data, "{{workgroupSize}}",
-  toString(workgroupSize)); replaceAll(data, "{{precision}}",
-  toString(precision)); LOG(kDefLog, kInfo, "Shader code:\n%s",
-  data.c_str()); }
-
+  inline KernelCode(const std::string &pData,
+                    const Shape &workgroupSize = {256, 1, 1},
+                    NumType precision = kf32)
+      : data(pData), workgroupSize(workgroupSize), precision(precision) {
+    if (precision == kf16) {
+      data = "enable f16;\n" + data;
+    }
+    replaceAll(data, "{{workgroupSize}}", toString(workgroupSize));
+    replaceAll(data, "{{precision}}", toString(precision));
+    LOG(kDefLog, kInfo, "Shader code:\n%s", data.c_str());
+  }
 
   /**
    * @brief Overload of the constructor, adding totalWorkgroups parameter to
@@ -351,10 +354,8 @@ struct KernelCode {
    * KernelCode code = {kPuzzle1, {256, 1, 1}, kf32, {2, 2, 1}};
    * @endcode
    */
-  inline KernelCode(const std::string &pData,
-                    const Shape &workgroupSize,
-                    NumType precision,
-                    const Shape &totalWorkgroups)
+  inline KernelCode(const std::string &pData, const Shape &workgroupSize,
+                    NumType precision, const Shape &totalWorkgroups)
     : data(pData), workgroupSize(workgroupSize), precision(precision) {
     if (precision == kf16) {
       data = "enable f16;\n" + data;
@@ -365,7 +366,6 @@ struct KernelCode {
     LOG(kDefLog, kInfo, "Shader code:\n%s", data.c_str());
   }
 
-
   /**
    * @brief Overload of the constructor, adding totalWorkgroups parameter as
    * well as the size_t 1D workgroupSize parameter.
@@ -379,11 +379,10 @@ struct KernelCode {
    * KernelCode code = {kPuzzle1, {256, 1, 1}, kf32, {2, 2, 1}};
    * @endcode
    */
-  inline KernelCode(const std::string &pData,
-                    const size_t &workgroupSize,
-                    NumType precision,
-                    const Shape &totalWorkgroups)
-    : data(pData), workgroupSize({workgroupSize, 1, 1}), precision(precision) {
+  inline KernelCode(const std::string &pData, const size_t &workgroupSize,
+                    NumType precision, const Shape &totalWorkgroups)
+      : data(pData), workgroupSize({workgroupSize, 1, 1}),
+        precision(precision) {
     if (precision == kf16) {
       data = "enable f16;\n" + data;
     }
@@ -465,7 +464,6 @@ struct RawKernel {
 
 typedef std::shared_ptr<RawKernel> Kernel;
 
-
 /**
  * @brief A struct to package the result of a WGSL code compilation.
 */
@@ -575,12 +573,11 @@ struct Context {
  * Tensor tensor = createTensor(pool, device, {256, 256}, kf32);
  * @endcode
  */
-inline Tensor
-createTensor(TensorPool &pool, WGPUDevice &device, const Shape &shape,
-             NumType dtype,
-             WGPUBufferUsage usage = WGPUBufferUsage_Storage |
-                                     WGPUBufferUsage_CopyDst |
-                                     WGPUBufferUsage_CopySrc) {
+inline Tensor createTensor(TensorPool &pool, WGPUDevice &device,
+                           const Shape &shape, NumType dtype,
+                           WGPUBufferUsage usage = WGPUBufferUsage_Storage |
+                                                   WGPUBufferUsage_CopyDst |
+                                                   WGPUBufferUsage_CopySrc) {
   LOG(kDefLog, kTrace, "Creating tensor");
   size_t numElements = size(shape);
   size_t size = sizeBytes(dtype) * numElements;
@@ -798,32 +795,31 @@ inline Context createContext(const WGPUInstanceDescriptor &desc = {},
   auto onAdapterRequestEnded = [](WGPURequestAdapterStatus status,
                                   WGPUAdapter adapter, WGPUStringView message,
                                   void *pUserData, void *) {
-        AdapterData &adapterData = *reinterpret_cast<AdapterData *>(pUserData);
+    AdapterData &adapterData = *reinterpret_cast<AdapterData *>(pUserData);
 #ifdef __EMSCRIPTEN__
-        if (status != WGPURequestAdapterStatus_Success) {
-          LOG(kDefLog, kError, "Could not get WebGPU adapter: %.*s",
-              static_cast<int>(message.length), message.data);
-          LOG(kDefLog, kError,
-              "\n\nA common reason is that the browser does not have WebGPU "
-              "enabled, particularly on Linux.\n"
-              "- Open `chrome://flags/` in the browser and make sure "
-              "\"WebGPU Support\" is enabled.\n"
-              "- Chrome is launched with vulkan enabled. From the command line "
-              "launch chrome as `google-chrome --enable-features=Vulkan`\n");
-        }
+    if (status != WGPURequestAdapterStatus_Success) {
+      LOG(kDefLog, kError, "Could not get WebGPU adapter: %.*s",
+          static_cast<int>(message.length), message.data);
+      LOG(kDefLog, kError,
+          "\n\nA common reason is that the browser does not have WebGPU "
+          "enabled, particularly on Linux.\n"
+          "- Open `chrome://flags/` in the browser and make sure "
+          "\"WebGPU Support\" is enabled.\n"
+          "- Chrome is launched with vulkan enabled. From the command line "
+          "launch chrome as `google-chrome --enable-features=Vulkan`\n");
+    }
 #endif
-        check(status == WGPURequestAdapterStatus_Success,
-              "Request WebGPU adapter", __FILE__, __LINE__);
-        adapterData.adapter = adapter;
-        adapterData.requestEnded = true;
+    check(status == WGPURequestAdapterStatus_Success,
+          "Request WebGPU adapter", __FILE__, __LINE__);
+    adapterData.adapter = adapter;
+    adapterData.requestEnded = true;
   };
 
   WGPURequestAdapterCallbackInfo callbackInfo = {
       .mode = WGPUCallbackMode_AllowSpontaneous,
       .callback = onAdapterRequestEnded,
       .userdata1 = &adapterData,
-      .userdata2 = nullptr
-  };
+      .userdata2 = nullptr};
 
   wgpuInstanceRequestAdapter(context.instance, &adapterOpts, callbackInfo);
 
   while (!adapterData.requestEnded) {
@@ -844,22 +840,22 @@ inline Context createContext(const WGPUInstanceDescriptor &desc = {},
   auto onDeviceRequestEnded = [](WGPURequestDeviceStatus status,
                                  WGPUDevice device, WGPUStringView message,
                                  void *pUserData, void *) {
-        DeviceData &devData = *reinterpret_cast<DeviceData *>(pUserData);
-        check(status == WGPURequestDeviceStatus_Success,
-              "Could not get WebGPU device.", __FILE__, __LINE__);
-        LOG(kDefLog, kTrace, "Device Request succeeded %x",
-            static_cast<void *>(device));
-        devData.device = device;
-        devData.requestEnded = true;
+    DeviceData &devData = *reinterpret_cast<DeviceData *>(pUserData);
+    check(status == WGPURequestDeviceStatus_Success,
+          "Could not get WebGPU device.", __FILE__, __LINE__);
+    LOG(kDefLog, kTrace, "Device Request succeeded %x",
+        static_cast<void *>(device));
+    devData.device = device;
+    devData.requestEnded = true;
   };
 
   WGPURequestDeviceCallbackInfo deviceCallbackInfo = {
       .mode = WGPUCallbackMode_AllowSpontaneous,
-      .callback = onDeviceRequestEnded,
-      .userdata1 = &devData,
-      .userdata2 = nullptr
-  };
-  wgpuAdapterRequestDevice(context.adapter, &devDescriptor, deviceCallbackInfo);
+      .callback = onDeviceRequestEnded,
+      .userdata1 = &devData,
+      .userdata2 = nullptr};
+  wgpuAdapterRequestDevice(context.adapter, &devDescriptor,
+                           deviceCallbackInfo);
 
   LOG(kDefLog, kInfo, "Waiting for device request to end");
   while (!devData.requestEnded) {
@@ -870,22 +866,23 @@ inline Context createContext(const WGPUInstanceDescriptor &desc = {},
     context.device = devData.device;
 
     WGPULoggingCallbackInfo loggingCallbackInfo = {
-        .callback = [](WGPULoggingType type, WGPUStringView message, void* userdata1, void* userdata2) {
-          LOG(kDefLog, kError, "Device logging callback: %.*s", (int)message.length, message.data);
-          if (type == WGPULoggingType_Error) {
+        .callback =
+            [](WGPULoggingType type, WGPUStringView message, void *userdata1,
+               void *userdata2) {
+              LOG(kDefLog, kError, "Device logging callback: %.*s",
+                  (int)message.length, message.data);
+              if (type == WGPULoggingType_Error) {
                 throw std::runtime_error("Device error logged.");
-          }
-        },
+              }
+            },
         .userdata1 = nullptr,
-        .userdata2 = nullptr
-    };
+        .userdata2 = nullptr};
     wgpuDeviceSetLoggingCallback(context.device, loggingCallbackInfo);
   }
   context.queue = wgpuDeviceGetQueue(context.device);
   return context;
 }
-
 #ifdef USE_DAWN_API
 /**
  * @brief Factory function to create a GPU context, which aggregates WebGPU API
@@ -893,7 +890,8 @@ inline Context createContext(const WGPUInstanceDescriptor &desc = {},
  * queue.
 *
 * The function takes a gpu index to support multiple GPUs.
- * To activate this function, it needs not only webgpu's headers but also DAWN's headers.
+ * To activate this function, it needs not only webgpu's headers but also DAWN's
+ * headers.
 *
 * If dawn is used, it also sets up an error callback for device loss.
 *
@@ -906,9 +904,9 @@ inline Context createContext(const WGPUInstanceDescriptor &desc = {},
 * Context ctx = createContextByGpuIdx(1);
 * @endcode
 */
-inline Context createContextByGpuIdx(int gpuIdx,
-                                     const WGPUInstanceDescriptor &desc = {},
-                                     const WGPUDeviceDescriptor &devDescriptor = {}) {
+inline Context
+createContextByGpuIdx(int gpuIdx, const WGPUInstanceDescriptor &desc = {},
+                      const WGPUDeviceDescriptor &devDescriptor = {}) {
   Context context;
   {
 #ifdef __EMSCRIPTEN__
@@ -925,12 +923,15 @@ inline Context createContextByGpuIdx(int gpuIdx,
   LOG(kDefLog, kInfo, "Requesting adapter");
   {
     std::vector<dawn::native::Adapter> adapters =
-        dawn::native::Instance(reinterpret_cast<dawn::native::InstanceBase *>(context.instance))
-            .EnumerateAdapters();
+        dawn::native::Instance(
+            reinterpret_cast<dawn::native::InstanceBase *>(context.instance))
+            .EnumerateAdapters();
     LOG(kDefLog, kInfo, "The number of GPUs=%d\n", adapters.size());
-    // Note: Second gpu is not available on Macos, but the number of GPUs is 2 on Macos.
-    // Calling wgpuAdapterGetInfo function for the second gpu becomes segfault.
-    // When you check all GPUs on linux, uncomment out following codes.
+    // Note: Second gpu is not available on Macos, but the number of GPUs is 2
+    // on Macos.
+    // Calling wgpuAdapterGetInfo function for the second gpu becomes
+    // segfault. When you check all GPUs on linux, uncomment out following
+    // codes.
     //
     // for (size_t i = 0; i < adapters.size(); i++) {
     //   WGPUAdapterInfo info {};
@@ -939,18 +940,19 @@ inline Context createContextByGpuIdx(int gpuIdx,
     //   wgpuAdapterGetInfo(ptr, &info);
     //   LOG(kDefLog, kInfo, "GPU(Adapter)[%d] = %s\n", i, info.description);
     //   wgpuAdapterInfoFreeMembers(info);
-    //  }
+    //  }
     // }
     {
       LOG(kDefLog, kInfo, "Use GPU(Adapter)[%d]\n", gpuIdx);
      auto ptr = adapters[gpuIdx].Get();
      if (ptr) {
-        WGPUAdapterInfo info {};
-        wgpuAdapterGetInfo(ptr, &info);
-        LOG(kDefLog, kInfo, "GPU(Adapter)[%d] = %s\n", gpuIdx, info.description);
-        wgpuAdapterInfoFreeMembers(info);
-      }
+        WGPUAdapterInfo info{};
+        wgpuAdapterGetInfo(ptr, &info);
+        LOG(kDefLog, kInfo, "GPU(Adapter)[%d] = %s\n", gpuIdx,
+            info.description);
+        wgpuAdapterInfoFreeMembers(info);
+      }
       context.adapter = adapters[gpuIdx].Get();
       dawn::native::GetProcs().adapterAddRef(context.adapter);
     }
@@ -965,25 +967,24 @@ inline Context createContextByGpuIdx(int gpuIdx,
   DeviceData devData;
 
   auto onDeviceRequestEnded = [](WGPURequestDeviceStatus status,
-                                  WGPUDevice device, WGPUStringView message,
-                                  void *pUserData, void *) {
-    DeviceData &devData = *reinterpret_cast<DeviceData *>(pUserData);
-    check(status == WGPURequestDeviceStatus_Success,
-          "Could not get WebGPU device.", __FILE__, __LINE__);
-    LOG(kDefLog, kTrace, "Device Request succeeded %x",
-        static_cast<void *>(device));
-    devData.device = device;
-    devData.requestEnded = true;
-};
+                                 WGPUDevice device, WGPUStringView message,
+                                 void *pUserData, void *) {
+    DeviceData &devData = *reinterpret_cast<DeviceData *>(pUserData);
+    check(status == WGPURequestDeviceStatus_Success,
+          "Could not get WebGPU device.", __FILE__, __LINE__);
+    LOG(kDefLog, kTrace, "Device Request succeeded %x",
+        static_cast<void *>(device));
+    devData.device = device;
+    devData.requestEnded = true;
+  };
 
   WGPURequestDeviceCallbackInfo deviceCallbackInfo = {
       .mode = WGPUCallbackMode_AllowSpontaneous,
       .callback = onDeviceRequestEnded,
       .userdata1 = &devData,
-      .userdata2 = nullptr
-  };
-  wgpuAdapterRequestDevice(context.adapter, &devDescriptor, deviceCallbackInfo);
-
+      .userdata2 = nullptr};
+  wgpuAdapterRequestDevice(context.adapter, &devDescriptor,
+                           deviceCallbackInfo);
 
   LOG(kDefLog, kInfo, "Waiting for device request to end");
   while (!devData.requestEnded) {
(!devData.requestEnded) { @@ -994,17 +995,18 @@ inline Context createContextByGpuIdx(int gpuIdx, context.device = devData.device; WGPULoggingCallbackInfo loggingCallbackInfo = { - .callback = [](WGPULoggingType type, WGPUStringView message, void* userdata1, void* userdata2) { - LOG(kDefLog, kError, "Device logging callback: %.*s", (int)message.length, message.data); - if (type == WGPULoggingType_Error) { + .callback = + [](WGPULoggingType type, WGPUStringView message, void *userdata1, + void *userdata2) { + LOG(kDefLog, kError, "Device logging callback: %.*s", + (int)message.length, message.data); + if (type == WGPULoggingType_Error) { throw std::runtime_error("Device error logged."); - } - }, + } + }, .userdata1 = nullptr, - .userdata2 = nullptr - }; + .userdata2 = nullptr}; wgpuDeviceSetLoggingCallback(context.device, loggingCallbackInfo); - } context.queue = wgpuDeviceGetQueue(context.device); return context; @@ -1039,31 +1041,33 @@ inline void toCPU(Context &ctx, Tensor &tensor, void *data, size_t bufferSize, WGPUQueueWorkDoneCallbackInfo workDoneCallbackInfo = { .mode = WGPUCallbackMode_AllowSpontaneous, - .callback = [](WGPUQueueWorkDoneStatus status, void* userdata1, void* userdata2) { - check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done", - __FILE__, __LINE__); - const auto *data = static_cast(userdata1); - WGPUBufferMapCallbackInfo mapCallbackInfo = { - .mode = WGPUCallbackMode_AllowSpontaneous, - .callback = [](WGPUMapAsyncStatus status, WGPUStringView message, void* userdata1, void* userdata2) { - const auto *data = static_cast(userdata1); - check(status == WGPUMapAsyncStatus_Success, - "Map readbackBuffer", __FILE__, __LINE__); - const void *mappedData = wgpuBufferGetConstMappedRange( - data->buffer, /*offset=*/0, data->bufferSize); - check(mappedData, "Get mapped range", __FILE__, __LINE__); - memcpy(data->output, mappedData, data->bufferSize); - wgpuBufferUnmap(data->buffer); - data->promise->set_value(); - }, - .userdata1 = const_cast(data), - .userdata2 = nullptr - }; - wgpuBufferMapAsync(data->buffer, WGPUMapMode_Read, 0, data->bufferSize, mapCallbackInfo); - }, + .callback = + [](WGPUQueueWorkDoneStatus status, void *userdata1, void *userdata2) { + check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done", + __FILE__, __LINE__); + const auto *data = static_cast(userdata1); + WGPUBufferMapCallbackInfo mapCallbackInfo = { + .mode = WGPUCallbackMode_AllowSpontaneous, + .callback = + [](WGPUMapAsyncStatus status, WGPUStringView message, + void *userdata1, void *userdata2) { + const auto *data = static_cast(userdata1); + check(status == WGPUMapAsyncStatus_Success, + "Map readbackBuffer", __FILE__, __LINE__); + const void *mappedData = wgpuBufferGetConstMappedRange( + data->buffer, /*offset=*/0, data->bufferSize); + check(mappedData, "Get mapped range", __FILE__, __LINE__); + memcpy(data->output, mappedData, data->bufferSize); + wgpuBufferUnmap(data->buffer); + data->promise->set_value(); + }, + .userdata1 = const_cast(data), + .userdata2 = nullptr}; + wgpuBufferMapAsync(data->buffer, WGPUMapMode_Read, 0, + data->bufferSize, mapCallbackInfo); + }, .userdata1 = &callbackData, - .userdata2 = nullptr - }; + .userdata2 = nullptr}; wgpuQueueOnSubmittedWorkDone(ctx.queue, workDoneCallbackInfo); wait(ctx, op.future); @@ -1126,8 +1130,7 @@ void toCPU(Context &ctx, Tensor &tensor, std::array &data) { toCPU(ctx, tensor, data.data(), sizeof(data)); } -inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data, - size_t size) { +inline void toCPU(Context &ctx, 
WGPUBuffer buffer, void *data, size_t size) { uint64_t bufferSize = size; CopyData op; op.future = op.promise.get_future(); @@ -1153,15 +1156,18 @@ inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data, CallbackData callbackData = {op.readbackBuffer, bufferSize, data, &op.promise, &op.future}; - WGPUQueueWorkDoneCallbackInfo workDoneCallbackInfo = { - .mode = WGPUCallbackMode_AllowSpontaneous, - .callback = [](WGPUQueueWorkDoneStatus status, void* userdata1, void* userdata2) { - check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done", - __FILE__, __LINE__); - const auto *data = static_cast(userdata1); - WGPUBufferMapCallbackInfo mapCallbackInfo = { - .mode = WGPUCallbackMode_AllowSpontaneous, - .callback = [](WGPUMapAsyncStatus status, WGPUStringView message, void* userdata1, void* userdata2) { + WGPUQueueWorkDoneCallbackInfo workDoneCallbackInfo = { + .mode = WGPUCallbackMode_AllowSpontaneous, + .callback = + [](WGPUQueueWorkDoneStatus status, void *userdata1, void *userdata2) { + check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done", + __FILE__, __LINE__); + const auto *data = static_cast(userdata1); + WGPUBufferMapCallbackInfo mapCallbackInfo = { + .mode = WGPUCallbackMode_AllowSpontaneous, + .callback = + [](WGPUMapAsyncStatus status, WGPUStringView message, + void *userdata1, void *userdata2) { const auto *data = static_cast(userdata1); check(status == WGPUMapAsyncStatus_Success, "Map readbackBuffer", __FILE__, __LINE__); @@ -1171,16 +1177,15 @@ inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data, memcpy(data->output, mappedData, data->bufferSize); wgpuBufferUnmap(data->buffer); data->promise->set_value(); - }, - .userdata1 = const_cast(data), - .userdata2 = nullptr - }; - wgpuBufferMapAsync(data->buffer, WGPUMapMode_Read, 0, data->bufferSize, mapCallbackInfo); + }, + .userdata1 = const_cast(data), + .userdata2 = nullptr}; + wgpuBufferMapAsync(data->buffer, WGPUMapMode_Read, 0, + data->bufferSize, mapCallbackInfo); }, - .userdata1 = &callbackData, - .userdata2 = nullptr - }; - wgpuQueueOnSubmittedWorkDone(ctx.queue, workDoneCallbackInfo); + .userdata1 = &callbackData, + .userdata2 = nullptr}; + wgpuQueueOnSubmittedWorkDone(ctx.queue, workDoneCallbackInfo); wait(ctx, op.future); if (op.readbackBuffer) { @@ -1188,7 +1193,6 @@ inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data, } } - /** * @brief Copies data from CPU memory to a GPU buffer. 
The toGPU overloads are * effectively a convenience wrapper around the WebGPU API call @@ -1234,8 +1238,9 @@ inline void toGPU(Context &ctx, const int *data, Tensor &tensor) { wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data, tensor.data.size); } - -inline void toGPU(Context &ctx, const float *data, Tensor &tensor, size_t size) { + +inline void toGPU(Context &ctx, const float *data, Tensor &tensor, + size_t size) { wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data, size); } @@ -1345,13 +1350,15 @@ inline Shape cdiv(Shape total, Shape group) { */ inline Kernel createKernel(Context &ctx, const KernelCode &code, const Tensor *dataBindings, size_t numTensors, - const size_t *viewOffsets, const Shape &totalWorkgroups, - const void *params = nullptr, - size_t paramsSize = 0, - CompilationInfo* compilationInfo = nullptr, - const char* cacheKey = nullptr) { - // Create a cache key by the pointer values of the data bindings and the kernel code - if (cacheKey != nullptr && ctx.kernelPool.data.find(cacheKey) != ctx.kernelPool.data.end()) { + const size_t *viewOffsets, + const Shape &totalWorkgroups, + const void *params = nullptr, size_t paramsSize = 0, + CompilationInfo *compilationInfo = nullptr, + const char *cacheKey = nullptr) { + // Create a cache key by the pointer values of the data bindings and the + // kernel code + if (cacheKey != nullptr && + ctx.kernelPool.data.find(cacheKey) != ctx.kernelPool.data.end()) { LOG(kDefLog, kInfo, "Kernel cache hit"); return ctx.kernelPool.data[cacheKey]; } @@ -1462,8 +1469,7 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, WGPUShaderSourceWGSL wgslDesc = { .chain = {.sType = WGPUSType_ShaderSourceWGSL}, - .code = {.data = code.data.c_str(), .length = code.data.length()} - }; + .code = {.data = code.data.c_str(), .length = code.data.length()}}; WGPUShaderModuleDescriptor shaderModuleDesc = {}; shaderModuleDesc.nextInChain = &wgslDesc.chain; @@ -1474,50 +1480,50 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, computePipelineDesc.compute.module = wgpuDeviceCreateShaderModule(device, &shaderModuleDesc); - computePipelineDesc.compute.entryPoint = {code.entryPoint.c_str(), code.entryPoint.length()}; + computePipelineDesc.compute.entryPoint = {code.entryPoint.c_str(), + code.entryPoint.length()}; computePipelineDesc.label = {code.label.c_str(), code.label.length()}; op->computePipeline = wgpuDeviceCreateComputePipeline(device, &computePipelineDesc); - op->totalWorkgroups = {totalWorkgroups[0], totalWorkgroups[1], totalWorkgroups[2]}; + op->totalWorkgroups = {totalWorkgroups[0], totalWorkgroups[1], + totalWorkgroups[2]}; resetCommandBuffer(device, op); if (cacheKey != nullptr) - ctx.kernelPool.data[cacheKey]=op; + ctx.kernelPool.data[cacheKey] = op; auto compilationInfoCallback = [](WGPUCompilationInfoRequestStatus status, WGPUCompilationInfo const *compilationInfo, void *userdata1, void *userdata2) { - CompilationInfo *result = static_cast(userdata1); - if (compilationInfo && result) { - result->status = status; - for (uint32_t i = 0; i < compilationInfo->messageCount; ++i) { - printf("Message %d: %.*s\n", i, - static_cast(compilationInfo->messages[i].message.length), - compilationInfo->messages[i].message.data); - result->messages.push_back(std::string( - compilationInfo->messages[i].message.data, - compilationInfo->messages[i].message.length)); - result->lineNums.push_back(compilationInfo->messages[i].lineNum); - result->linePos.push_back(compilationInfo->messages[i].linePos); - } - result->finished = 
true; - } else { - LOG(kDefLog, kTrace, "No compilation info or result"); + CompilationInfo *result = static_cast(userdata1); + if (compilationInfo && result) { + result->status = status; + for (uint32_t i = 0; i < compilationInfo->messageCount; ++i) { + printf("Message %d: %.*s\n", i, + static_cast(compilationInfo->messages[i].message.length), + compilationInfo->messages[i].message.data); + result->messages.push_back( + std::string(compilationInfo->messages[i].message.data, + compilationInfo->messages[i].message.length)); + result->lineNums.push_back(compilationInfo->messages[i].lineNum); + result->linePos.push_back(compilationInfo->messages[i].linePos); } + result->finished = true; + } else { + LOG(kDefLog, kTrace, "No compilation info or result"); + } }; WGPUCompilationInfoCallbackInfo compilationCallbackInfo = { .mode = WGPUCallbackMode_AllowSpontaneous, .callback = compilationInfoCallback, .userdata1 = static_cast(compilationInfo), - .userdata2 = nullptr - }; + .userdata2 = nullptr}; while (compilationInfo && !compilationInfo->finished) { processEvents(ctx.instance); } return op; - } /** @@ -1530,8 +1536,8 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, * @param[in] code WGSL code for the kernel * @param[in] dataBindings A Bindings of tensors whose GPU buffers are bound * to the kernel as inputs and outputs. - * @param[in] totalWorkgroups Number of workgroups in the x, y, z grid, must be a - * Shape of rank == 3. + * @param[in] totalWorkgroups Number of workgroups in the x, y, z grid, must be + * a Shape of rank == 3. * @param[in] params Optional parameters for the kernel. If the kernel does * not have any parameters, use NoParam. * @return Kernel instance representing the created kernel @@ -1546,20 +1552,17 @@ Kernel createKernel(Context &ctx, const KernelCode &code, const Bindings &dataBindings, const Shape &totalWorkgroups, const ParamsType ¶ms = ParamsType{}, - CompilationInfo* compilationInfo = nullptr, - const char* cacheKey = nullptr - ) { + CompilationInfo *compilationInfo = nullptr, + const char *cacheKey = nullptr) { if constexpr (!IsNoParam) { return createKernel(ctx, code, dataBindings.data.data(), numInputs, dataBindings.viewOffsets.data(), totalWorkgroups, reinterpret_cast(¶ms), - sizeof(ParamsType), compilationInfo, - cacheKey); + sizeof(ParamsType), compilationInfo, cacheKey); } else { return createKernel(ctx, code, dataBindings.data.data(), numInputs, - dataBindings.viewOffsets.data(), totalWorkgroups, nullptr, - 0, compilationInfo, - cacheKey); + dataBindings.viewOffsets.data(), totalWorkgroups, + nullptr, 0, compilationInfo, cacheKey); } } @@ -1592,15 +1595,15 @@ inline void dispatchKernel(Context &ctx, Kernel &kernel, WGPUQueueWorkDoneCallbackInfo workDoneCallbackInfo = { .mode = WGPUCallbackMode_AllowSpontaneous, - .callback = [](WGPUQueueWorkDoneStatus status, void* userdata1, void* userdata2) { - check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done", - __FILE__, __LINE__); - auto *promise = static_cast *>(userdata1); - promise->set_value(); - }, + .callback = + [](WGPUQueueWorkDoneStatus status, void *userdata1, void *userdata2) { + check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done", + __FILE__, __LINE__); + auto *promise = static_cast *>(userdata1); + promise->set_value(); + }, .userdata1 = &promise, - .userdata2 = nullptr - }; + .userdata2 = nullptr}; wgpuQueueOnSubmittedWorkDone(ctx.queue, workDoneCallbackInfo); } From d8d618a56ce8558023ae944dc0f8bbf9d4c6f126 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: 
Thu, 30 Jan 2025 09:47:02 -0500 Subject: [PATCH 36/44] make Context lifetime more robust dont rely on RVO which seems to fail if createContext is called from eg conditional branches. Remove webgpu from scratch tutorial to avoid having to maintain/update the implementation --- examples/hello_world/run.cpp | 8 +- examples/matmul/run.cpp | 45 +- examples/webgpu_from_scratch/CMakeLists.txt | 21 - examples/webgpu_from_scratch/Makefile | 8 - examples/webgpu_from_scratch/run.cpp | 446 ------- gpu.hpp | 189 ++- third_party/headers/portaudio/portaudio.h | 1251 ------------------- 7 files changed, 158 insertions(+), 1810 deletions(-) delete mode 100644 examples/webgpu_from_scratch/CMakeLists.txt delete mode 100644 examples/webgpu_from_scratch/Makefile delete mode 100644 examples/webgpu_from_scratch/run.cpp delete mode 100644 third_party/headers/portaudio/portaudio.h diff --git a/examples/hello_world/run.cpp b/examples/hello_world/run.cpp index 3fbafc4..7453869 100644 --- a/examples/hello_world/run.cpp +++ b/examples/hello_world/run.cpp @@ -3,9 +3,7 @@ #include #include -using namespace gpu; // createContext, createTensor, createKernel, - // createShader, dispatchKernel, wait, toCPU - // Tensor, Kernel, Context, Shape, kf32 +using namespace gpu; static const char *kGelu = R"( const GELU_SCALING_FACTOR: f32 = 0.7978845608028654; // sqrt(2.0 / PI) @@ -29,6 +27,7 @@ int main(int argc, char **argv) { printf("\nHello gpu.cpp!\n"); printf("--------------\n\n"); + // std::unique_ptr ctx = createContext(); Context ctx = createContext(); static constexpr size_t N = 10000; std::array inputArr, outputArr; @@ -41,7 +40,7 @@ int main(int argc, char **argv) { std::future future = promise.get_future(); Kernel op = createKernel(ctx, {kGelu, 256, kf32}, Bindings{input, output}, - /* nWorkgroups */ {cdiv(N, 256), 1, 1}); + {cdiv(N, 256), 1, 1}); dispatchKernel(ctx, op, promise); wait(ctx, future); toCPU(ctx, output, outputArr.data(), sizeof(outputArr)); @@ -50,5 +49,4 @@ int main(int argc, char **argv) { } printf(" ...\n\n"); printf("Computed %zu values of GELU(x)\n\n", N); - return 0; } diff --git a/examples/matmul/run.cpp b/examples/matmul/run.cpp index 3db4f78..42d7009 100644 --- a/examples/matmul/run.cpp +++ b/examples/matmul/run.cpp @@ -792,13 +792,40 @@ void runTest(int version, size_t M, size_t K, size_t N, } // Allocate GPU buffers and copy data - Context ctx = createContext( - {}, {}, - /*device descriptor, enabling f16 in WGSL*/ - { + WGPUDeviceDescriptor devDescriptor = {}; + devDescriptor.requiredFeatureCount = 1; + devDescriptor.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data(); + + Context ctx; + if (numtype == kf16) { + ctx = createContext( + {}, {}, + /*device descriptor, enabling f16 in WGSL*/ + { .requiredFeatureCount = 1, - .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data(), - }); + .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data() + }); + if (ctx.adapterStatus != WGPURequestAdapterStatus_Success) { + LOG(kDefLog, kError, "Failed to create adapter with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9)."); + exit(1); + } + if (ctx.deviceStatus != WGPURequestDeviceStatus_Success) { + LOG(kDefLog, kError, "Failed to create device with f16 support, try running an f32 test instead. 
(`export MATMUL_VERSION=9)"); + exit(1); + } + } + + if (numtype == kf32) { + ctx = createContext({}, {}, {}); + if (ctx.adapterStatus != WGPURequestAdapterStatus_Success || + ctx.deviceStatus != WGPURequestDeviceStatus_Success) { + LOG(kDefLog, kError, "Failed to create adapter or device"); + // stop execution + exit(1); + } else { + LOG(kDefLog, kInfo, "Successfully created adapter and device"); + } + } Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get()); Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major @@ -810,8 +837,6 @@ void runTest(int version, size_t M, size_t K, size_t N, #endif // Initialize Kernel and bind GPU buffers - - // pre-allocate for async dispatch std::array, nIter> promises; std::array, nIter> futures; @@ -823,10 +848,6 @@ void runTest(int version, size_t M, size_t K, size_t N, kernels[i] = selectMatmul(ctx, version, {input, weights, outputs[i]}, M, K, N, numtype); } -#ifndef METAL_PROFILER - printf("[ Press enter to start tests ... ]\n"); - getchar(); -#endif LOG(kDefLog, kInfo, "Dispatching Kernel version %d: %s, %d iterations ...", version, versionToStr(version).c_str(), nIter); diff --git a/examples/webgpu_from_scratch/CMakeLists.txt b/examples/webgpu_from_scratch/CMakeLists.txt deleted file mode 100644 index 8804628..0000000 --- a/examples/webgpu_from_scratch/CMakeLists.txt +++ /dev/null @@ -1,21 +0,0 @@ -cmake_minimum_required(VERSION 3.11) -project(wgpu_tutorial) - -include(FetchContent) - -FetchContent_Declare( - webgpu-backend-dawn - GIT_REPOSITORY https://github.com/eliemichel/WebGPU-distribution - GIT_TAG dawn-6376 - GIT_SHALLOW TRUE -) -FetchContent_MakeAvailable(webgpu-backend-dawn) - -FetchContent_Declare(spdlog - GIT_REPOSITORY https://github.com/gabime/spdlog.git - GIT_TAG 27cb4c76708608465c413f6d0e6b8d99a4d84302 -) -FetchContent_MakeAvailable(spdlog) - -add_executable(wgpu_tutorial run.cpp) -target_link_libraries(wgpu_tutorial webgpu spdlog) diff --git a/examples/webgpu_from_scratch/Makefile b/examples/webgpu_from_scratch/Makefile deleted file mode 100644 index 6e0878f..0000000 --- a/examples/webgpu_from_scratch/Makefile +++ /dev/null @@ -1,8 +0,0 @@ -run: - mkdir -p build && cd build && cmake .. -DCMAKE_BUILD_TYPE=Debug -DWEBGPU_BACKEND=DAWN -DCMAKE_VERBOSE_MAKEFILE:BOOL=ON && make wgpu_tutorial && ./wgpu_tutorial - -watch: - mkdir -p build && cd build && ls ../* | entr -s "cmake .. -DCMAKE_BUILD_TYPE=Debug -DWEBGPU_BACKEND=DAWN -DCMAKE_VERBOSE_MAKEFILE:BOOL=ON && make wgpu_tutorial && ./wgpu_tutorial" - -clean: - read -r -p "Are you sure? [CTRL-C to abort] " response && rm -rf build/* diff --git a/examples/webgpu_from_scratch/run.cpp b/examples/webgpu_from_scratch/run.cpp deleted file mode 100644 index 38f9b98..0000000 --- a/examples/webgpu_from_scratch/run.cpp +++ /dev/null @@ -1,446 +0,0 @@ -#include -#include - -#include "webgpu/webgpu.h" -#include "spdlog/spdlog.h" -/* - * Approximate GELU kernel definition, implemented as a WGSL. - * In general GPU device code for WEBGPU is written in the WGSL domain specific - * language. - * - * Here inp and out correspond to bindings 0 and 1 respectively. In the main - * code, we create buffers for these bindings and populate them with data. 
- * - */ -const char *kShaderGELU = R"( -const GELU_SCALING_FACTOR: f32 = 0.7978845608028654; // sqrt(2.0 / PI) -@group(0) @binding(0) var inp: array; -@group(0) @binding(1) var out: array; -@compute @workgroup_size(256) -fn main( - @builtin(global_invocation_id) GlobalInvocationID: vec3) { - let i: u32 = GlobalInvocationID.x; - // Ensure we do not access out of bounds - if (i < 3072) { - let x: f32 = inp[i]; - let cube: f32 = 0.044715 * x * x * x; - out[i] = 0.5 * x * (1.0 + tanh(GELU_SCALING_FACTOR * (x + cube))); - } -} -)"; - -/* - * Convenience function to check if a condition is true, if not log an error - * message and exit. - * - * @param condition: The condition to check. - * @param message: The error message to log if the condition is false. - * @param file: The file where the error occurred. - * @param line: The line where the error occurred. - */ -inline void check(bool condition, const char *message, - const char *file = "unkown", int line = -1) { - if (!condition) { - spdlog::error("Error in file {} line {}:\n{}", file, line, message); - exit(1); - } else { - spdlog::trace("Success in file {} line {}:\n{}", file, line, message); - } -} - -/* - * Convenience function to display the first few elements of an array. A more - * robust/extensive version of this is in array_utils.hpp this is minimal to keep - * this example self-contained. - * - * @param a: The array to show. - * @param name: The name of the array. - * @return: A string representation of the array. - */ -template -std::string show(std::array a, std::string name) { - std::string output = "\n\n"; - if (name != "") { - output += name + " (" + std::to_string(N) + ") : \n"; - } - for (size_t i = 0; i < N; i++) { - output += std::to_string(a[i]) + "\n"; - if (i > 10) { - output += "...\n"; - break; - } - } - return output; -} - -int main() { - - static constexpr size_t N = 3072; - - // Host data - input and output arrays on the CPU - std::array inputArr; - std::array outputArr; - for (size_t i = 0; i < N; i++) { - // Populate input array with a range of dummy values - inputArr[i] = static_cast(i); - } - - // API representations for interfacing with the GPU - WGPUInstance instance; // The instance is the top-level context object for - // WebGPU. It is used to create adapters. - WGPUAdapter adapter; // The adapter is the physical device that WebGPU uses - // to interface with the GPU. - WGPUDevice device; // The device is the logical device that WebGPU uses to - // interface with the adapter. - WGPUQueue queue; // The queue is used to submit work to the GPU. - - // Buffers - buffers are used to store data on the GPU. - WGPUBuffer inputBuffer; // The input buffer is used to store the input data. - WGPUBuffer outputBuffer; // The output buffer is used to store the output data. - WGPUBuffer readbackBuffer; // The readback buffer is used to copy the output - // data from the GPU back to the CPU. - WGPUCommandBuffer commandBuffer; // The command buffer is used to store the - // sequence of operations to be executed on - // the GPU. - - // Async management - polling the GPU is asynchronous, so we need to manage - // the async work. - std::promise promise; // used to signal when the work is done. - std::future future; // used to wait for the work to be done. - - // Here we initialize the instance, adapter, device, and queue. 
- spdlog::info("Setting up GPU Context"); - { - const WGPUInstanceDescriptor desc = {}; - WGPURequestAdapterOptions adapterOpts = {}; - WGPUDeviceDescriptor devDescriptor = {}; - spdlog::info("Creating instance"); - { - instance = wgpuCreateInstance(&desc); - check(instance, "Initialize WebGPU", __FILE__, __LINE__); - } - spdlog::info("Requesting adapter"); - { - struct AdapterData { - WGPUAdapter adapter = nullptr; - bool requestEnded = false; - }; - AdapterData adapterData; - auto onAdapterRequestEnded = [](WGPURequestAdapterStatus status, - WGPUAdapter adapter, char const *message, - void *pUserData) { - AdapterData &adapterData = *reinterpret_cast(pUserData); - check(status == WGPURequestAdapterStatus_Success, - "Request WebGPU adapter", __FILE__, __LINE__); - adapterData.adapter = adapter; - adapterData.requestEnded = true; - }; - wgpuInstanceRequestAdapter(instance, &adapterOpts, onAdapterRequestEnded, - (void *)&adapterData); - assert(adapterData.requestEnded); - adapter = adapterData.adapter; - check(adapter, "Get WebGPU adapter", __FILE__, __LINE__); - } - spdlog::info("Requesting device"); - { - struct DeviceData { - WGPUDevice device = nullptr; - bool requestEnded = false; - }; - DeviceData devData; - auto onDeviceRequestEnded = [](WGPURequestDeviceStatus status, - WGPUDevice device, char const *message, - void *pUserData) { - DeviceData &devData = *reinterpret_cast(pUserData); - check(status == WGPURequestDeviceStatus_Success, - "Could not get WebGPU device.", __FILE__, __LINE__); - spdlog::info("Device Request succeeded {}", - static_cast(device)); - devData.device = device; - devData.requestEnded = true; - }; - devDescriptor.deviceLostCallback = - [](WGPUDeviceLostReason reason, char const *message, void *userdata) { - spdlog::error("Device lost:\n{}", message); - }; - wgpuAdapterRequestDevice(adapter, &devDescriptor, onDeviceRequestEnded, - (void *)&devData); - assert(devData.requestEnded); - device = devData.device; - spdlog::info("Setting error callback"); - wgpuDeviceSetUncapturedErrorCallback( - device, - [](WGPUErrorType type, char const *message, void *devData) { - spdlog::error("Device uncaptured error: {}", message); - }, - nullptr); - wgpuDeviceSetLoggingCallback( - device, - [](WGPULoggingType level, const char *message, void *userdata) { - spdlog::info("WebGPU Validation: {}", message); - }, - NULL); - } - // Queue - spdlog::info("Instantiating device queue"); - queue = wgpuDeviceGetQueue(device); - } - - // Here we setup the binding group layout. The binding group layout is used to - // define the layout of the bind group - e.g. how many buffers are going to be - // used and what their sizes are. - // - // The general pattern of using the WebGPU API is to populate a configuration - // using a descriptor type (*Descriptor), and then pass the descriptor to a - // factory function (*Create*) operation which returns a handle to the - // object. Sometimes the descriptors can be hierarchical and nested, but - // ultimately they are still just an elaborate set of configuration - // parameters. - // - // For example, here we populate a WGPUBindGroupLayoutDescriptor and then - // pass that to the wgpuDeviceCreateBindGroupLayout() function to get back a - // WGPUBindGroupLayout. 
- spdlog::info("Setting up binding group layout"); - WGPUBindGroupLayout bgLayout; - static constexpr uint32_t bufferSize = - static_cast(sizeof(float) * N); - spdlog::info("Buffer size: {}, number of elements {}", bufferSize, N); - { - WGPUBindGroupLayoutEntry bgLayoutEntries[2]; - bgLayoutEntries[0] = (WGPUBindGroupLayoutEntry){ - .binding = 0, - .visibility = WGPUShaderStage_Compute, - .buffer = - (WGPUBufferBindingLayout){ - .type = WGPUBufferBindingType_Storage, - .minBindingSize = bufferSize, - }, - }; - bgLayoutEntries[1] = (WGPUBindGroupLayoutEntry){ - .binding = 1, - .visibility = WGPUShaderStage_Compute, - .buffer = - (WGPUBufferBindingLayout){ - .type = WGPUBufferBindingType_Storage, - .minBindingSize = bufferSize, - }, - }; - spdlog::info("Creating Binding Group Layout Description"); - WGPUBindGroupLayoutDescriptor bgLayoutDesc = { - .entryCount = std::size(bgLayoutEntries), - .entries = bgLayoutEntries, - }; - bgLayout = wgpuDeviceCreateBindGroupLayout(device, &bgLayoutDesc); - } - - // After setting up the binding group layout we initialize the buffers by - // interacting with the device. - spdlog::info("Create buffers: input, output, and readback"); - { - WGPUBufferDescriptor inputBufferDesc = { - .usage = WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst, - .size = bufferSize, - }; - inputBuffer = wgpuDeviceCreateBuffer(device, &inputBufferDesc); - WGPUBufferDescriptor outputBufferDesc = { - .usage = WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst | - WGPUBufferUsage_CopySrc, - .size = bufferSize, - }; - outputBuffer = wgpuDeviceCreateBuffer(device, &outputBufferDesc); - WGPUBufferDescriptor readbackBufferDescriptor = { - .usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead, - .size = bufferSize, - }; - readbackBuffer = wgpuDeviceCreateBuffer(device, &readbackBufferDescriptor); - check(inputBuffer, "Create input buffer", __FILE__, __LINE__); - check(outputBuffer, "Create output buffer", __FILE__, __LINE__); - check(readbackBuffer, "Create readback buffer", __FILE__, __LINE__); - } - - // We create the bind group with references to the buffers and initialize the - // binding group. Does this seem redundant with the binding group layout? - // Probably. - // The bind group is used to bind the buffers to the compute pipeline. - // The bind group layout is used to define the layout of the bind group. - spdlog::info("Create the bind group"); - WGPUBindGroup bindGroup; - { - WGPUBindGroupEntry bindGroupEntries[2]; - bindGroupEntries[0] = (WGPUBindGroupEntry){ - .binding = 0, - .buffer = inputBuffer, - .offset = 0, - .size = bufferSize, - }; - bindGroupEntries[1] = (WGPUBindGroupEntry){ - .binding = 1, - .buffer = outputBuffer, - .offset = 0, - .size = bufferSize, - }; - WGPUBindGroupDescriptor bindGroupDesc = { - .layout = bgLayout, - .entryCount = std::size(bindGroupEntries), - .entries = bindGroupEntries, - }; - bindGroup = wgpuDeviceCreateBindGroup(device, &bindGroupDesc); - } - - // We create the compute pipeline with the shader module and pipeline layout. - // The compute pipeline is used to run the compute shader. 
- spdlog::info("Creating the compute pipeline"); - WGPUComputePipeline computePipeline; - { - WGPUPipelineLayout pipelineLayout; - WGPUPipelineLayoutDescriptor pipelineLayoutDesc = { - .bindGroupLayoutCount = 1, - .bindGroupLayouts = &bgLayout, - }; - pipelineLayout = - wgpuDeviceCreatePipelineLayout(device, &pipelineLayoutDesc); - WGPUShaderModuleWGSLDescriptor wgslDesc = { - .code = kShaderGELU, - }; - wgslDesc.chain.sType = WGPUSType_ShaderModuleWGSLDescriptor; - WGPUShaderModuleDescriptor shaderModuleDesc = {}; - shaderModuleDesc.nextInChain = &wgslDesc.chain; - shaderModuleDesc.label = "shader"; - WGPUComputePipelineDescriptor computePipelineDesc = {}; - computePipelineDesc.layout = pipelineLayout; - computePipelineDesc.compute.module = - wgpuDeviceCreateShaderModule(device, &shaderModuleDesc); - computePipelineDesc.compute.entryPoint = "main"; - computePipeline = - wgpuDeviceCreateComputePipeline(device, &computePipelineDesc); - check(computePipeline, "Create compute pipeline", __FILE__, __LINE__); - } - - // We create the command encoder and the compute pass encoder. The command - // encoder is used to encode commands for the GPU. The compute pass encoder is - // used to encode commands for the compute pipeline. - spdlog::info("Create the command encoder"); - { - static constexpr uint32_t kWorkgroupSize = 256; // This needs to match the - // workgroup size in the - // shader. - WGPUCommandEncoder commandEncoder; - WGPUComputePassEncoder computePassEncoder; - commandEncoder = wgpuDeviceCreateCommandEncoder(device, nullptr); - computePassEncoder = - wgpuCommandEncoderBeginComputePass(commandEncoder, nullptr); - wgpuComputePassEncoderSetPipeline(computePassEncoder, computePipeline); - wgpuComputePassEncoderSetBindGroup(computePassEncoder, 0, bindGroup, 0, - nullptr); - wgpuComputePassEncoderDispatchWorkgroups( - computePassEncoder, (N + (kWorkgroupSize - 1)) / kWorkgroupSize, 1, 1); - wgpuComputePassEncoderEnd(computePassEncoder); - wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, outputBuffer, 0, - readbackBuffer, 0, bufferSize); - commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr); - check(commandBuffer, "Create command buffer", __FILE__, __LINE__); - } - spdlog::info("Initializing promise and future"); - promise = std::promise(); - future = promise.get_future(); - - spdlog::info("Copying input data to GPU"); - wgpuQueueWriteBuffer(queue, inputBuffer, 0, inputArr.data(), bufferSize); - - // Submit the command buffer and launch the kernel. The command buffer is - // submitted to the queue and a callback is set up to handle the completion of - // the job which updates the promise. A while loop is used to wait for the - // promise to be set. 
- spdlog::info("Submit the command buffer and launching the kernel"); - struct CallbackData { - WGPUBuffer buffer; - size_t bufferSize; - float *output; - std::promise *promise; - }; - { - - // Submit the command buffer - wgpuQueueSubmit(queue, 1, &commandBuffer); - CallbackData callbackData = - CallbackData{readbackBuffer, sizeof(outputArr), nullptr, &promise}; - // Set up the callback for when the work is done - wgpuQueueOnSubmittedWorkDone( - queue, - [](WGPUQueueWorkDoneStatus status, void *callbackData) { - spdlog::info("QueueOnSubmittedWorkDone status: {}", - WGPUQueueWorkDoneStatus_Success == status); - check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done", - __FILE__, __LINE__); - const auto *data = static_cast(callbackData); - data->promise->set_value(); - }, - &callbackData); - // Wait for the promise to be set - while (future.wait_for(std::chrono::seconds(0)) != - std::future_status::ready) { - wgpuInstanceProcessEvents(instance); - } - } - - // Copy the output data back to the CPU. This requires its own command encoder - // and command buffer. As with the computation a job is asynchronously - // submitted to the queue and a callback is set up to handle the completion - // of the job which updates the promise. - // - // The execution blocks on the future until the promise is set, after which - // the result of the computation is copied to the outputArr array and is - // printed. - spdlog::info("Copying output to the CPU"); - { - // reset the promise and future - promise = std::promise(); - future = promise.get_future(); - spdlog::info("Setting up command encoder and command buffer for copying " - "output to the CPU"); - { - WGPUCommandEncoder commandEncoder; - WGPUComputePassEncoder computePassEncoder; - commandEncoder = wgpuDeviceCreateCommandEncoder(device, nullptr); - wgpuCommandEncoderCopyBufferToBuffer(commandEncoder, outputBuffer, 0, - readbackBuffer, 0, bufferSize); - commandBuffer = wgpuCommandEncoderFinish(commandEncoder, nullptr); - check(commandBuffer, "Create command buffer", __FILE__, __LINE__); - } - wgpuQueueSubmit(queue, 1, &commandBuffer); - CallbackData callbackData = {readbackBuffer, bufferSize, outputArr.data(), - &promise}; - wgpuQueueOnSubmittedWorkDone( - queue, - [](WGPUQueueWorkDoneStatus status, void *callbackData) { - spdlog::info("QueueOnSubmittedWorkDone status: {}", - WGPUQueueWorkDoneStatus_Success == status); - check(status == WGPUQueueWorkDoneStatus_Success, "Queue work done", - __FILE__, __LINE__); - const auto *data = static_cast(callbackData); - wgpuBufferMapAsync( - data->buffer, WGPUMapMode_Read, 0, bufferSize, - [](WGPUBufferMapAsyncStatus status, void *captureData) { - const auto *data = static_cast(captureData); - check(status == WGPUBufferMapAsyncStatus_Success, - "Map readbackBuffer", __FILE__, __LINE__); - const void *mappedData = wgpuBufferGetConstMappedRange( - data->buffer, /*offset=*/0, data->bufferSize); - check(mappedData, "Get mapped range", __FILE__, __LINE__); - memcpy(data->output, mappedData, data->bufferSize); - wgpuBufferUnmap(data->buffer); - data->promise->set_value(); - }, - callbackData); - }, - &callbackData); - while (future.wait_for(std::chrono::seconds(0)) != - std::future_status::ready) { - wgpuInstanceProcessEvents(instance); - } - } - - spdlog::info("{}", show(inputArr, "GELU Input")); - spdlog::info("{}", show(outputArr, "GELU Output")); - spdlog::info("Done with GELU kernel"); -} diff --git a/gpu.hpp b/gpu.hpp index c34894a..c4b6fe9 100644 --- a/gpu.hpp +++ b/gpu.hpp @@ -518,12 +518,56 @@ inline 
void processEvents(const WGPUInstance &instance) { * to simplify lifetime management of GPU resources. */ struct Context { - WGPUInstance instance; - WGPUAdapter adapter; - WGPUDevice device; - WGPUQueue queue; + WGPUInstance instance = nullptr; + WGPUAdapter adapter = nullptr; + WGPUDevice device = nullptr; + WGPUQueue queue = nullptr; TensorPool pool = TensorPool(this); KernelPool kernelPool = KernelPool(this); + WGPURequestAdapterStatus adapterStatus; + WGPURequestDeviceStatus deviceStatus; + + // Default constructor + Context() = default; + + // Move constructor: steals GPU handles so the source destructor won't free them. + Context(Context&& other) noexcept + : instance(other.instance), + adapter(other.adapter), + device(other.device), + queue(other.queue), + // Re‐initialize pools to point to *this*: + pool(this), + kernelPool(this), + adapterStatus(other.adapterStatus), + deviceStatus(other.deviceStatus) + { + // Move over the resources in the pools: + pool.data = std::move(other.pool.data); + kernelPool.data = std::move(other.kernelPool.data); + + // Null out handles in the source so its destructor won't release them. + other.instance = nullptr; + other.adapter = nullptr; + other.device = nullptr; + other.queue = nullptr; + // other.adapterStatus = 0; + // other.deviceStatus = 0; + } + + // Optional move‐assignment operator, similarly stealing resources: + Context& operator=(Context&& other) noexcept { + if (this != &other) { + // Free any existing resources. In most cases, this should be a no-op + // since we typically shouldn't have two active initialized Context + // instances with resources acquired. + this->~Context(); + // Then placement‐new a move‐constructed copy in-place: + new (this) Context(std::move(other)); + } + return *this; + } + ~Context() { LOG(kDefLog, kTrace, "Destroying context"); if (queue) { @@ -582,6 +626,7 @@ inline Tensor createTensor(TensorPool &pool, WGPUDevice &device, size_t numElements = size(shape); size_t size = sizeBytes(dtype) * numElements; WGPUBufferDescriptor bufferDesc = { + .label = {.data = nullptr, .length = 0}, .usage = usage, .size = size, }; @@ -767,66 +812,61 @@ inline void check(bool condition, const char *message, * @param[in] devDescriptor Device descriptor for the WebGPU device (optional) * @return Context instance representing the created GPU context * - * @code - * Context ctx = createContext(); - * @endcode */ -inline Context createContext(const WGPUInstanceDescriptor &desc = {}, - const WGPURequestAdapterOptions &adapterOpts = {}, - const WGPUDeviceDescriptor &devDescriptor = {}) { - Context context; - { +inline Context createContext( + const WGPUInstanceDescriptor &desc = {}, + const WGPURequestAdapterOptions &adapterOpts = {}, + const WGPUDeviceDescriptor &devDescriptor = {}) +{ + Context ctx; // stack-allocated + #ifdef __EMSCRIPTEN__ - context.instance = wgpuCreateInstance(nullptr); + ctx.instance = wgpuCreateInstance(nullptr); #else - context.instance = wgpuCreateInstance(&desc); + ctx.instance = wgpuCreateInstance(&desc); #endif - check(context.instance, "Initialize WebGPU", __FILE__, __LINE__); - } + check(ctx.instance, "Initialize WebGPU", __FILE__, __LINE__); LOG(kDefLog, kInfo, "Requesting adapter"); { struct AdapterData { WGPUAdapter adapter = nullptr; bool requestEnded = false; + WGPURequestAdapterStatus status; }; AdapterData adapterData; auto onAdapterRequestEnded = [](WGPURequestAdapterStatus status, - WGPUAdapter adapter, WGPUStringView message, + WGPUAdapter adapter, + WGPUStringView message, void *pUserData, 
void *) { - AdapterData &adapterData = *reinterpret_cast(pUserData); + auto &ad = *reinterpret_cast(pUserData); + ad.status = status; #ifdef __EMSCRIPTEN__ if (status != WGPURequestAdapterStatus_Success) { LOG(kDefLog, kError, "Could not get WebGPU adapter: %.*s", static_cast(message.length), message.data); - LOG(kDefLog, kError, - "\n\nA common reason is that the browser does not have WebGPU " - "enabled, particularly on Linux.\n" - "- Open `chrome://flags/` in the browser and make sure " - "\"WebGPU Support\" is enabled.\n" - "- Chrome is launched with vulkan enabled. From the command line " - "launch chrome as `google-chrome --enable-features=Vulkan`\n"); } #endif check(status == WGPURequestAdapterStatus_Success, "Request WebGPU adapter", __FILE__, __LINE__); - adapterData.adapter = adapter; - adapterData.requestEnded = true; + ad.adapter = adapter; + ad.requestEnded = true; }; - WGPURequestAdapterCallbackInfo callbackInfo = { - .mode = WGPUCallbackMode_AllowSpontaneous, - .callback = onAdapterRequestEnded, - .userdata1 = &adapterData, - .userdata2 = nullptr}; - wgpuInstanceRequestAdapter(context.instance, &adapterOpts, callbackInfo); + WGPURequestAdapterCallbackInfo callbackInfo { + .mode = WGPUCallbackMode_AllowSpontaneous, + .callback = onAdapterRequestEnded, + .userdata1 = &adapterData, + .userdata2 = nullptr + }; + wgpuInstanceRequestAdapter(ctx.instance, &adapterOpts, callbackInfo); while (!adapterData.requestEnded) { - processEvents(context.instance); + processEvents(ctx.instance); } - assert(adapterData.requestEnded); - context.adapter = adapterData.adapter; + ctx.adapter = adapterData.adapter; + ctx.adapterStatus = adapterData.status; } LOG(kDefLog, kInfo, "Requesting device"); @@ -834,55 +874,66 @@ inline Context createContext(const WGPUInstanceDescriptor &desc = {}, struct DeviceData { WGPUDevice device = nullptr; bool requestEnded = false; + WGPURequestDeviceStatus status; }; DeviceData devData; auto onDeviceRequestEnded = [](WGPURequestDeviceStatus status, - WGPUDevice device, WGPUStringView message, + WGPUDevice device, + WGPUStringView message, void *pUserData, void *) { - DeviceData &devData = *reinterpret_cast(pUserData); + auto &dd = *reinterpret_cast(pUserData); + dd.status = status; check(status == WGPURequestDeviceStatus_Success, "Could not get WebGPU device.", __FILE__, __LINE__); - LOG(kDefLog, kTrace, "Device Request succeeded %x", - static_cast(device)); - devData.device = device; - devData.requestEnded = true; + LOG(kDefLog, kTrace, "Device Request succeeded %p", + static_cast(device)); + dd.device = device; + dd.requestEnded= true; }; - WGPURequestDeviceCallbackInfo deviceCallbackInfo = { - .mode = WGPUCallbackMode_AllowSpontaneous, - .callback = onDeviceRequestEnded, - .userdata1 = &devData, - .userdata2 = nullptr}; - wgpuAdapterRequestDevice(context.adapter, &devDescriptor, - deviceCallbackInfo); + WGPURequestDeviceCallbackInfo deviceCallbackInfo { + .mode = WGPUCallbackMode_AllowSpontaneous, + .callback = onDeviceRequestEnded, + .userdata1= &devData, + .userdata2= nullptr + }; + wgpuAdapterRequestDevice(ctx.adapter, &devDescriptor, deviceCallbackInfo); LOG(kDefLog, kInfo, "Waiting for device request to end"); while (!devData.requestEnded) { - processEvents(context.instance); + processEvents(ctx.instance); } LOG(kDefLog, kInfo, "Device request ended"); - assert(devData.requestEnded); - context.device = devData.device; - WGPULoggingCallbackInfo loggingCallbackInfo = { + ctx.device = devData.device; + ctx.deviceStatus = devData.status; + + // If the device was 
created, set up logging and fetch the queue + if (devData.status == WGPURequestDeviceStatus_Success) { + WGPULoggingCallbackInfo loggingCallbackInfo { + .nextInChain = nullptr, .callback = - [](WGPULoggingType type, WGPUStringView message, void *userdata1, - void *userdata2) { - LOG(kDefLog, kError, "Device logging callback: %.*s", - (int)message.length, message.data); - if (type == WGPULoggingType_Error) { - throw std::runtime_error("Device error logged."); - } - }, + [](WGPULoggingType type, WGPUStringView message, + void *, void *) { + LOG(kDefLog, kError, "Device logging callback: %.*s", + static_cast(message.length), message.data); + if (type == WGPULoggingType_Error) { + throw std::runtime_error("Device error logged."); + } + }, .userdata1 = nullptr, - .userdata2 = nullptr}; - wgpuDeviceSetLoggingCallback(context.device, loggingCallbackInfo); + .userdata2 = nullptr + }; + wgpuDeviceSetLoggingCallback(ctx.device, loggingCallbackInfo); + ctx.queue = wgpuDeviceGetQueue(ctx.device); + } } - context.queue = wgpuDeviceGetQueue(context.device); - return context; + + return std::move(ctx); } + #ifdef USE_DAWN_API /** * @brief Factory function to create a GPU context, which aggregates WebGPU API @@ -995,11 +1046,12 @@ createContextByGpuIdx(int gpuIdx, const WGPUInstanceDescriptor &desc = {}, context.device = devData.device; WGPULoggingCallbackInfo loggingCallbackInfo = { + .nextInChain = nullptr, .callback = [](WGPULoggingType type, WGPUStringView message, void *userdata1, void *userdata2) { LOG(kDefLog, kError, "Device logging callback: %.*s", - (int)message.length, message.data); + static_cast(message.length), message.data); if (type == WGPULoggingType_Error) { throw std::runtime_error("Device error logged."); } @@ -1093,6 +1145,7 @@ inline void toCPU(Context &ctx, Tensor &tensor, void *data, size_t bufferSize) { op.future = op.promise.get_future(); { WGPUBufferDescriptor readbackBufferDescriptor = { + .label = {.data = nullptr, .length = 0}, .usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead, .size = bufferSize, }; @@ -1136,6 +1189,7 @@ inline void toCPU(Context &ctx, WGPUBuffer buffer, void *data, size_t size) { op.future = op.promise.get_future(); { WGPUBufferDescriptor readbackBufferDescriptor = { + .label = {.data = nullptr, .length = 0}, .usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead, .size = bufferSize, }; @@ -1348,7 +1402,7 @@ inline Shape cdiv(Shape total, Shape group) { * @endcode * output, nThreads, params, paramsSize); */ -inline Kernel createKernel(Context &ctx, const KernelCode &code, +inline Kernel createKernel(Context& ctx, const KernelCode &code, const Tensor *dataBindings, size_t numTensors, const size_t *viewOffsets, const Shape &totalWorkgroups, @@ -1422,6 +1476,7 @@ inline Kernel createKernel(Context &ctx, const KernelCode &code, // Create a buffer for the Params struct if (paramsSize > 0) { WGPUBufferDescriptor paramsBufferDesc = { + .label = {.data = nullptr, .length = 0}, .usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst, .size = paramsSize, .mappedAtCreation = false, diff --git a/third_party/headers/portaudio/portaudio.h b/third_party/headers/portaudio/portaudio.h deleted file mode 100644 index b8878cc..0000000 --- a/third_party/headers/portaudio/portaudio.h +++ /dev/null @@ -1,1251 +0,0 @@ -#ifndef PORTAUDIO_H -#define PORTAUDIO_H -/* - * $Id$ - * PortAudio Portable Real-Time Audio Library - * PortAudio API Header File - * Latest version available at: http://www.portaudio.com/ - * - * Copyright (c) 1999-2002 Ross Bencina and Phil Burk - 
* - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files - * (the "Software"), to deal in the Software without restriction, - * including without limitation the rights to use, copy, modify, merge, - * publish, distribute, sublicense, and/or sell copies of the Software, - * and to permit persons to whom the Software is furnished to do so, - * subject to the following conditions: - * - * The above copyright notice and this permission notice shall be - * included in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR - * ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF - * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION - * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - */ - -/* - * The text above constitutes the entire PortAudio license; however, - * the PortAudio community also makes the following non-binding requests: - * - * Any person wishing to distribute modifications to the Software is - * requested to send the modifications to the original developer so that - * they can be incorporated into the canonical version. It is also - * requested that these non-binding requests be included along with the - * license above. - */ - -/** @file - @ingroup public_header - @brief The portable PortAudio API. -*/ - - -#ifdef __cplusplus -extern "C" -{ -#endif /* __cplusplus */ - -/** Retrieve the release number of the currently running PortAudio build. - For example, for version "19.5.1" this will return 0x00130501. - - @see paMakeVersionNumber -*/ -int Pa_GetVersion( void ); - -/** Retrieve a textual description of the current PortAudio build, - e.g. "PortAudio V19.5.0-devel, revision 1952M". - The format of the text may change in the future. Do not try to parse the - returned string. - - @deprecated As of 19.5.0, use Pa_GetVersionInfo()->versionText instead. -*/ -const char* Pa_GetVersionText( void ); - -/** - Generate a packed integer version number in the same format used - by Pa_GetVersion(). Use this to compare a specified version number with - the currently running version. For example: - - @code - if( Pa_GetVersion() < paMakeVersionNumber(19,5,1) ) {} - @endcode - - @see Pa_GetVersion, Pa_GetVersionInfo - @version Available as of 19.5.0. -*/ -#define paMakeVersionNumber(major, minor, subminor) \ - (((major)&0xFF)<<16 | ((minor)&0xFF)<<8 | ((subminor)&0xFF)) - - -/** - A structure containing PortAudio API version information. - @see Pa_GetVersionInfo, paMakeVersionNumber - @version Available as of 19.5.0. -*/ -typedef struct PaVersionInfo { - int versionMajor; - int versionMinor; - int versionSubMinor; - /** - This is currently the Git revision hash but may change in the future. - The versionControlRevision is updated by running a script before compiling the library. - If the update does not occur, this value may refer to an earlier revision. - Encoded as UTF-8. - */ - const char *versionControlRevision; - /** Version as a string, for example "PortAudio V19.5.0-devel, revision 1952M". Encoded as UTF-8. */ - const char *versionText; -} PaVersionInfo; - -/** Retrieve version information for the currently running PortAudio build. 
- @return A pointer to an immutable PaVersionInfo structure. - - @note This function can be called at any time. It does not require PortAudio - to be initialized. The structure pointed to is statically allocated. Do not - attempt to free it or modify it. - - @see PaVersionInfo, paMakeVersionNumber - @version Available as of 19.5.0. -*/ -const PaVersionInfo* Pa_GetVersionInfo( void ); - - -/** Error codes returned by PortAudio functions. - Note that with the exception of paNoError, all PaErrorCodes are negative. -*/ - -typedef int PaError; -typedef enum PaErrorCode -{ - paNoError = 0, - - paNotInitialized = -10000, - paUnanticipatedHostError, - paInvalidChannelCount, - paInvalidSampleRate, - paInvalidDevice, - paInvalidFlag, - paSampleFormatNotSupported, - paBadIODeviceCombination, - paInsufficientMemory, - paBufferTooBig, - paBufferTooSmall, - paNullCallback, - paBadStreamPtr, - paTimedOut, - paInternalError, - paDeviceUnavailable, - paIncompatibleHostApiSpecificStreamInfo, - paStreamIsStopped, - paStreamIsNotStopped, - paInputOverflowed, - paOutputUnderflowed, - paHostApiNotFound, - paInvalidHostApi, - paCanNotReadFromACallbackStream, - paCanNotWriteToACallbackStream, - paCanNotReadFromAnOutputOnlyStream, - paCanNotWriteToAnInputOnlyStream, - paIncompatibleStreamHostApi, - paBadBufferPtr, - paCanNotInitializeRecursively -} PaErrorCode; - - -/** Translate the supplied PortAudio error code into a human readable - message, encoded as UTF-8. -*/ -const char *Pa_GetErrorText( PaError errorCode ); - - -/** Library initialization function - call this before using PortAudio. - This function initializes internal data structures and prepares underlying - host APIs for use. With the exception of Pa_GetVersion(), Pa_GetVersionText(), - and Pa_GetErrorText(), this function MUST be called before using any other - PortAudio API functions. - - If Pa_Initialize() is called multiple times, each successful - call must be matched with a corresponding call to Pa_Terminate(). - Pairs of calls to Pa_Initialize() / Pa_Terminate() may overlap, and are not - required to be fully nested. - - Note that if Pa_Initialize() returns an error code, Pa_Terminate() should - NOT be called. - - @return paNoError if successful, otherwise an error code indicating the cause - of failure. - - @see Pa_Terminate -*/ -PaError Pa_Initialize( void ); - - -/** Library termination function - call this when finished using PortAudio. - This function deallocates all resources allocated by PortAudio since it was - initialized by a call to Pa_Initialize(). In cases where Pa_Initialize() has - been called multiple times, each call must be matched with a corresponding call - to Pa_Terminate(). The final matching call to Pa_Terminate() will automatically - close any PortAudio streams that are still open. - - Pa_Terminate() MUST be called before exiting a program which uses PortAudio. - Failure to do so may result in serious resource leaks, such as audio devices - not being available until the next reboot. - - @return paNoError if successful, otherwise an error code indicating the cause - of failure. - - @see Pa_Initialize -*/ -PaError Pa_Terminate( void ); - - - -/** The type used to refer to audio devices. Values of this type usually - range from 0 to (Pa_GetDeviceCount()-1), and may also take on the PaNoDevice - and paUseHostApiSpecificDeviceSpecification values. 
- - @see Pa_GetDeviceCount, paNoDevice, paUseHostApiSpecificDeviceSpecification -*/ -typedef int PaDeviceIndex; - - -/** A special PaDeviceIndex value indicating that no device is available, - or should be used. - - @see PaDeviceIndex -*/ -#define paNoDevice ((PaDeviceIndex)-1) - - -/** A special PaDeviceIndex value indicating that the device(s) to be used - are specified in the host api specific stream info structure. - - @see PaDeviceIndex -*/ -#define paUseHostApiSpecificDeviceSpecification ((PaDeviceIndex)-2) - - -/* Host API enumeration mechanism */ - -/** The type used to enumerate to host APIs at runtime. Values of this type - range from 0 to (Pa_GetHostApiCount()-1). - - @see Pa_GetHostApiCount -*/ -typedef int PaHostApiIndex; - - -/** Retrieve the number of available host APIs. Even if a host API is - available it may have no devices available. - - @return A non-negative value indicating the number of available host APIs - or, a PaErrorCode (which are always negative) if PortAudio is not initialized - or an error is encountered. - - @see PaHostApiIndex -*/ -PaHostApiIndex Pa_GetHostApiCount( void ); - - -/** Retrieve the index of the default host API. The default host API will be - the lowest common denominator host API on the current platform and is - unlikely to provide the best performance. - - @return A non-negative value ranging from 0 to (Pa_GetHostApiCount()-1) - indicating the default host API index or, a PaErrorCode (which are always - negative) if PortAudio is not initialized or an error is encountered. -*/ -PaHostApiIndex Pa_GetDefaultHostApi( void ); - - -/** Unchanging unique identifiers for each supported host API. This type - is used in the PaHostApiInfo structure. The values are guaranteed to be - unique and to never change, thus allowing code to be written that - conditionally uses host API specific extensions. - - New type ids will be allocated when support for a host API reaches - "public alpha" status, prior to that developers should use the - paInDevelopment type id. - - @see PaHostApiInfo -*/ -typedef enum PaHostApiTypeId -{ - paInDevelopment=0, /* use while developing support for a new host API */ - paDirectSound=1, - paMME=2, - paASIO=3, - paSoundManager=4, - paCoreAudio=5, - paOSS=7, - paALSA=8, - paAL=9, - paBeOS=10, - paWDMKS=11, - paJACK=12, - paWASAPI=13, - paAudioScienceHPI=14, - paAudioIO=15, - paPulseAudio=16, - paSndio=17 -} PaHostApiTypeId; - - -/** A structure containing information about a particular host API. */ - -typedef struct PaHostApiInfo -{ - /** this is struct version 1 */ - int structVersion; - /** The well known unique identifier of this host API @see PaHostApiTypeId */ - PaHostApiTypeId type; - /** A textual description of the host API for display on user interfaces. Encoded as UTF-8. */ - const char *name; - - /** The number of devices belonging to this host API. This field may be - used in conjunction with Pa_HostApiDeviceIndexToDeviceIndex() to enumerate - all devices for this host API. - @see Pa_HostApiDeviceIndexToDeviceIndex - */ - int deviceCount; - - /** The default input device for this host API. The value will be a - device index ranging from 0 to (Pa_GetDeviceCount()-1), or paNoDevice - if no default input device is available. - */ - PaDeviceIndex defaultInputDevice; - - /** The default output device for this host API. The value will be a - device index ranging from 0 to (Pa_GetDeviceCount()-1), or paNoDevice - if no default output device is available. 
- */ - PaDeviceIndex defaultOutputDevice; - -} PaHostApiInfo; - - -/** Retrieve a pointer to a structure containing information about a specific - host Api. - - @param hostApi A valid host API index ranging from 0 to (Pa_GetHostApiCount()-1) - - @return A pointer to an immutable PaHostApiInfo structure describing - a specific host API. If the hostApi parameter is out of range or an error - is encountered, the function returns NULL. - - The returned structure is owned by the PortAudio implementation and must not - be manipulated or freed. The pointer is only guaranteed to be valid between - calls to Pa_Initialize() and Pa_Terminate(). -*/ -const PaHostApiInfo * Pa_GetHostApiInfo( PaHostApiIndex hostApi ); - - -/** Convert a static host API unique identifier, into a runtime - host API index. - - @param type A unique host API identifier belonging to the PaHostApiTypeId - enumeration. - - @return A valid PaHostApiIndex ranging from 0 to (Pa_GetHostApiCount()-1) or, - a PaErrorCode (which are always negative) if PortAudio is not initialized - or an error is encountered. - - The paHostApiNotFound error code indicates that the host API specified by the - type parameter is not available. - - @see PaHostApiTypeId -*/ -PaHostApiIndex Pa_HostApiTypeIdToHostApiIndex( PaHostApiTypeId type ); - - -/** Convert a host-API-specific device index to standard PortAudio device index. - This function may be used in conjunction with the deviceCount field of - PaHostApiInfo to enumerate all devices for the specified host API. - - @param hostApi A valid host API index ranging from 0 to (Pa_GetHostApiCount()-1) - - @param hostApiDeviceIndex A valid per-host device index in the range - 0 to (Pa_GetHostApiInfo(hostApi)->deviceCount-1) - - @return A non-negative PaDeviceIndex ranging from 0 to (Pa_GetDeviceCount()-1) - or, a PaErrorCode (which are always negative) if PortAudio is not initialized - or an error is encountered. - - A paInvalidHostApi error code indicates that the host API index specified by - the hostApi parameter is out of range. - - A paInvalidDevice error code indicates that the hostApiDeviceIndex parameter - is out of range. - - @see PaHostApiInfo -*/ -PaDeviceIndex Pa_HostApiDeviceIndexToDeviceIndex( PaHostApiIndex hostApi, - int hostApiDeviceIndex ); - - - -/** Structure used to return information about a host error condition. -*/ -typedef struct PaHostErrorInfo{ - PaHostApiTypeId hostApiType; /**< the host API which returned the error code */ - long errorCode; /**< the error code returned */ - const char *errorText; /**< a textual description of the error if available (encoded as UTF-8), otherwise a zero-length C string */ -}PaHostErrorInfo; - - -/** Return information about the last host error encountered. The error - information returned by Pa_GetLastHostErrorInfo() will never be modified - asynchronously by errors occurring in other PortAudio owned threads - (such as the thread that manages the stream callback.) - - This function is provided as a last resort, primarily to enhance debugging - by providing clients with access to all available error information. - - @return A pointer to an immutable structure constraining information about - the host error. The values in this structure will only be valid if a - PortAudio function has previously returned the paUnanticipatedHostError - error code. -*/ -const PaHostErrorInfo* Pa_GetLastHostErrorInfo( void ); - - - -/* Device enumeration and capabilities */ - -/** Retrieve the number of available devices. The number of available devices - may be zero. 
-
-
-
-/* Device enumeration and capabilities */
-
-/** Retrieve the number of available devices. The number of available devices
- may be zero.
-
- @return A non-negative value indicating the number of available devices or,
- a PaErrorCode (which are always negative) if PortAudio is not initialized
- or an error is encountered.
-*/
-PaDeviceIndex Pa_GetDeviceCount( void );
-
-
-/** Retrieve the index of the default input device. The result can be
- used in the inputDevice parameter to Pa_OpenStream().
-
- @return The default input device index for the default host API, or paNoDevice
- if no default input device is available or an error was encountered.
-*/
-PaDeviceIndex Pa_GetDefaultInputDevice( void );
-
-
-/** Retrieve the index of the default output device. The result can be
- used in the outputDevice parameter to Pa_OpenStream().
-
- @return The default output device index for the default host API, or paNoDevice
- if no default output device is available or an error was encountered.
-
- @note
- On the PC, the user can specify a default device by
- setting an environment variable. For example, to use device #1.
-
- set PA_RECOMMENDED_OUTPUT_DEVICE=1
-
- The user should first determine the available device ids by using
- the supplied application "pa_devs".
-*/
-PaDeviceIndex Pa_GetDefaultOutputDevice( void );
-
-
-/** The type used to represent monotonic time in seconds. PaTime is
- used for the fields of the PaStreamCallbackTimeInfo argument to the
- PaStreamCallback and as the result of Pa_GetStreamTime().
-
- PaTime values have unspecified origin.
-
- @see PaStreamCallback, PaStreamCallbackTimeInfo, Pa_GetStreamTime
-*/
-typedef double PaTime;
-
-
-/** A type used to specify one or more sample formats. Each value indicates
- a possible format for sound data passed to and from the stream callback,
- Pa_ReadStream() and Pa_WriteStream().
-
- The standard formats paFloat32, paInt16, paInt32, paInt24, paInt8
- and paUInt8 are usually implemented by all implementations.
-
- The floating point representation (paFloat32) uses +1.0 and -1.0 as the
- maximum and minimum respectively.
-
- paUInt8 is an unsigned 8 bit format where 128 is considered "ground".
-
- The paNonInterleaved flag indicates that audio data is passed as an array
- of pointers to separate buffers, one buffer for each channel. Usually,
- when this flag is not used, audio data is passed as a single buffer with
- all channels interleaved.
-
- @see Pa_OpenStream, Pa_OpenDefaultStream, PaDeviceInfo
- @see paFloat32, paInt16, paInt32, paInt24, paInt8
- @see paUInt8, paCustomFormat, paNonInterleaved
-*/
-typedef unsigned long PaSampleFormat;
-
-
-#define paFloat32 ((PaSampleFormat) 0x00000001) /**< @see PaSampleFormat */
-#define paInt32 ((PaSampleFormat) 0x00000002) /**< @see PaSampleFormat */
-#define paInt24 ((PaSampleFormat) 0x00000004) /**< Packed 24 bit format. @see PaSampleFormat */
-#define paInt16 ((PaSampleFormat) 0x00000008) /**< @see PaSampleFormat */
-#define paInt8 ((PaSampleFormat) 0x00000010) /**< @see PaSampleFormat */
-#define paUInt8 ((PaSampleFormat) 0x00000020) /**< @see PaSampleFormat */
-#define paCustomFormat ((PaSampleFormat) 0x00010000) /**< @see PaSampleFormat */
-
-#define paNonInterleaved ((PaSampleFormat) 0x80000000) /**< @see PaSampleFormat */
-
-/** A structure providing information and capabilities of PortAudio devices.
- Devices may support input, output or both input and output.
-*/
-typedef struct PaDeviceInfo
-{
- int structVersion; /**< this is struct version 2 */
-
- /** Human readable device name. Encoded as UTF-8. */
- const char *name;
-
- /** Host API index in the range 0 to (Pa_GetHostApiCount()-1). Note: this is a host API index, not a type id. */
- PaHostApiIndex hostApi;
-
- int maxInputChannels;
- int maxOutputChannels;
-
- /** Default latency values for interactive performance. */
- PaTime defaultLowInputLatency;
- PaTime defaultLowOutputLatency;
- /** Default latency values for robust non-interactive applications (eg. playing sound files). */
- PaTime defaultHighInputLatency;
- PaTime defaultHighOutputLatency;
-
- double defaultSampleRate;
-} PaDeviceInfo;
-
-
-/** Retrieve a pointer to a PaDeviceInfo structure containing information
- about the specified device.
- @return A pointer to an immutable PaDeviceInfo structure. If the device
- parameter is out of range the function returns NULL.
-
- @param device A valid device index in the range 0 to (Pa_GetDeviceCount()-1)
-
- @note PortAudio manages the memory referenced by the returned pointer,
- the client must not manipulate or free the memory. The pointer is only
- guaranteed to be valid between calls to Pa_Initialize() and Pa_Terminate().
-
- @see PaDeviceInfo, PaDeviceIndex
-*/
-const PaDeviceInfo* Pa_GetDeviceInfo( PaDeviceIndex device );
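For illustration, a minimal program that inspects the default output device with the functions above might look like this sketch:

#include <stdio.h>
#include "portaudio.h"

int main(void) {
    if (Pa_Initialize() != paNoError) return 1;
    PaDeviceIndex dev = Pa_GetDefaultOutputDevice();
    if (dev == paNoDevice) {
        fprintf(stderr, "no default output device\n");
        Pa_Terminate();
        return 1;
    }
    const PaDeviceInfo *info = Pa_GetDeviceInfo(dev);
    printf("device: %s\n", info->name);
    printf("max output channels: %d\n", info->maxOutputChannels);
    printf("default sample rate: %.0f Hz\n", info->defaultSampleRate);
    printf("default low output latency: %.3f s\n", info->defaultLowOutputLatency);
    Pa_Terminate();
    return 0;
}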
-
-
-/** Parameters for one direction (input or output) of a stream.
-*/
-typedef struct PaStreamParameters
-{
- /** A valid device index in the range 0 to (Pa_GetDeviceCount()-1)
- specifying the device to be used or the special constant
- paUseHostApiSpecificDeviceSpecification which indicates that the actual
- device(s) to use are specified in hostApiSpecificStreamInfo.
- This field must not be set to paNoDevice.
- */
- PaDeviceIndex device;
-
- /** The number of channels of sound to be delivered to the
- stream callback or accessed by Pa_ReadStream() or Pa_WriteStream().
- It can range from 1 to the value of maxInputChannels in the
- PaDeviceInfo record for the device specified by the device parameter.
- */
- int channelCount;
-
- /** The sample format of the buffer provided to the stream callback,
- Pa_ReadStream() or Pa_WriteStream(). It may be any of the formats described
- by the PaSampleFormat enumeration.
- */
- PaSampleFormat sampleFormat;
-
- /** The desired latency in seconds. Where practical, implementations should
- configure their latency based on these parameters, otherwise they may
- choose the closest viable latency instead. Unless the suggested latency
- is greater than the absolute upper limit for the device, implementations
- should round the suggestedLatency up to the next practical value - ie to
- provide an equal or higher latency than suggestedLatency wherever possible.
- Actual latency values for an open stream may be retrieved using the
- inputLatency and outputLatency fields of the PaStreamInfo structure
- returned by Pa_GetStreamInfo().
- @see default*Latency in PaDeviceInfo, *Latency in PaStreamInfo
- */
- PaTime suggestedLatency;
-
- /** An optional pointer to a host api specific data structure
- containing additional information for device setup and/or stream processing.
- hostApiSpecificStreamInfo is never required for correct operation,
- if not used it should be set to NULL.
- */
- void *hostApiSpecificStreamInfo;
-
-} PaStreamParameters;
-
-
-/** Return code for Pa_IsFormatSupported() indicating success. */
-#define paFormatIsSupported (0)
-
-/** Determine whether it would be possible to open a stream with the specified
- parameters.
-
- @param inputParameters A structure that describes the input parameters used to
- open a stream. The suggestedLatency field is ignored. See PaStreamParameters
- for a description of these parameters. inputParameters must be NULL for
- output-only streams.
-
- @param outputParameters A structure that describes the output parameters used
- to open a stream. The suggestedLatency field is ignored. See PaStreamParameters
- for a description of these parameters. outputParameters must be NULL for
- input-only streams.
-
- @param sampleRate The required sampleRate. For full-duplex streams it is the
- sample rate for both input and output.
-
- @return Returns 0 if the format is supported, and an error code indicating why
- the format is not supported otherwise. The constant paFormatIsSupported is
- provided to compare with the return value for success.
-
- @see paFormatIsSupported, PaStreamParameters
-*/
-PaError Pa_IsFormatSupported( const PaStreamParameters *inputParameters,
- const PaStreamParameters *outputParameters,
- double sampleRate );
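A capability probe built on Pa_IsFormatSupported() might look like the following sketch; the helper name and the 48 kHz stereo float32 choice are illustrative, not part of the API:

#include "portaudio.h"

/* Returns 1 if 48 kHz stereo float32 output is supported on the default
   output device. Assumes Pa_Initialize() has already succeeded. */
static int supports_48k_stereo_float(void) {
    PaStreamParameters out;
    out.device = Pa_GetDefaultOutputDevice();
    if (out.device == paNoDevice) return 0;
    out.channelCount = 2;
    out.sampleFormat = paFloat32;
    /* suggestedLatency is ignored by Pa_IsFormatSupported but set anyway */
    out.suggestedLatency = Pa_GetDeviceInfo(out.device)->defaultLowOutputLatency;
    out.hostApiSpecificStreamInfo = NULL;
    return Pa_IsFormatSupported(NULL, &out, 48000.0) == paFormatIsSupported;
}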
-
-
-
-/* Streaming types and functions */
-
-
-/**
- A single PaStream can provide multiple channels of real-time
- streaming audio input and output to a client application. A stream
- provides access to audio hardware represented by one or more
- PaDevices. Depending on the underlying Host API, it may be possible
- to open multiple streams using the same device; however, this behavior
- is implementation defined. Portable applications should assume that
- a PaDevice may be simultaneously used by at most one PaStream.
-
- Pointers to PaStream objects are passed between PortAudio functions that
- operate on streams.
-
- @see Pa_OpenStream, Pa_OpenDefaultStream, Pa_CloseStream,
- Pa_StartStream, Pa_StopStream, Pa_AbortStream, Pa_IsStreamActive,
- Pa_GetStreamTime, Pa_GetStreamCpuLoad
-
-*/
-typedef void PaStream;
-
-
-/** Can be passed as the framesPerBuffer parameter to Pa_OpenStream()
- or Pa_OpenDefaultStream() to indicate that the stream callback will
- accept buffers of any size.
-*/
-#define paFramesPerBufferUnspecified (0)
-
-
-/** Flags used to control the behavior of a stream. They are passed as
- parameters to Pa_OpenStream() or Pa_OpenDefaultStream(). Multiple flags may be
- ORed together.
-
- @see Pa_OpenStream, Pa_OpenDefaultStream
- @see paNoFlag, paClipOff, paDitherOff, paNeverDropInput,
- paPrimeOutputBuffersUsingStreamCallback, paPlatformSpecificFlags
-*/
-typedef unsigned long PaStreamFlags;
-
-/** @see PaStreamFlags */
-#define paNoFlag ((PaStreamFlags) 0)
-
-/** Disable default clipping of out of range samples.
- @see PaStreamFlags
-*/
-#define paClipOff ((PaStreamFlags) 0x00000001)
-
-/** Disable default dithering.
- @see PaStreamFlags
-*/
-#define paDitherOff ((PaStreamFlags) 0x00000002)
-
-/** Flag requests that where possible a full duplex stream will not discard
- overflowed input samples without calling the stream callback. This flag is
- only valid for full duplex callback streams and only when used in combination
- with the paFramesPerBufferUnspecified (0) framesPerBuffer parameter. Using
- this flag incorrectly results in a paInvalidFlag error being returned from
- Pa_OpenStream() and Pa_OpenDefaultStream().
-
- @see PaStreamFlags, paFramesPerBufferUnspecified
-*/
-#define paNeverDropInput ((PaStreamFlags) 0x00000004)
-
-/** Call the stream callback to fill initial output buffers, rather than the
- default behavior of priming the buffers with zeros (silence). This flag has
- no effect for input-only and blocking read/write streams.
-
- @see PaStreamFlags
-*/
-#define paPrimeOutputBuffersUsingStreamCallback ((PaStreamFlags) 0x00000008)
-
-/** A mask specifying the platform specific bits.
- @see PaStreamFlags
-*/
-#define paPlatformSpecificFlags ((PaStreamFlags)0xFFFF0000)
-
-/**
- Timing information for the buffers passed to the stream callback.
-
- Time values are expressed in seconds and are synchronised with the time base used by Pa_GetStreamTime() for the associated stream.
-
- @see PaStreamCallback, Pa_GetStreamTime
-*/
-typedef struct PaStreamCallbackTimeInfo{
- PaTime inputBufferAdcTime; /**< The time when the first sample of the input buffer was captured at the ADC input */
- PaTime currentTime; /**< The time when the stream callback was invoked */
- PaTime outputBufferDacTime; /**< The time when the first sample of the output buffer will be output by the DAC */
-} PaStreamCallbackTimeInfo;
-
-
-/**
- Flag bit constants for the statusFlags to PaStreamCallback.
-
- @see paInputUnderflow, paInputOverflow, paOutputUnderflow, paOutputOverflow,
- paPrimingOutput
-*/
-typedef unsigned long PaStreamCallbackFlags;
-
-/** In a stream opened with paFramesPerBufferUnspecified, indicates that
- input data is all silence (zeros) because no real data is available. In a
- stream opened without paFramesPerBufferUnspecified, it indicates that one or
- more zero samples have been inserted into the input buffer to compensate
- for an input underflow.
- @see PaStreamCallbackFlags
-*/
-#define paInputUnderflow ((PaStreamCallbackFlags) 0x00000001)
-
-/** In a stream opened with paFramesPerBufferUnspecified, indicates that data
- prior to the first sample of the input buffer was discarded due to an
- overflow, possibly because the stream callback is using too much CPU time.
- Otherwise indicates that data prior to one or more samples in the
- input buffer was discarded.
- @see PaStreamCallbackFlags
-*/
-#define paInputOverflow ((PaStreamCallbackFlags) 0x00000002)
-
-/** Indicates that output data (or a gap) was inserted, possibly because the
- stream callback is using too much CPU time.
- @see PaStreamCallbackFlags
-*/
-#define paOutputUnderflow ((PaStreamCallbackFlags) 0x00000004)
-
-/** Indicates that output data will be discarded because no room is available.
- @see PaStreamCallbackFlags
-*/
-#define paOutputOverflow ((PaStreamCallbackFlags) 0x00000008)
-
-/** Some or all of the output data will be used to prime the stream, input
- data may be zero.
- @see PaStreamCallbackFlags
-*/
-#define paPrimingOutput ((PaStreamCallbackFlags) 0x00000010)
-
-/**
- Allowable return values for the PaStreamCallback.
- @see PaStreamCallback
-*/
-typedef enum PaStreamCallbackResult
-{
- paContinue=0, /**< Signal that the stream should continue invoking the callback and processing audio. */
- paComplete=1, /**< Signal that the stream should stop invoking the callback and finish once all output samples have played. */
- paAbort=2 /**< Signal that the stream should stop invoking the callback and finish as soon as possible. */
-} PaStreamCallbackResult;
-
-
-/**
- Functions of type PaStreamCallback are implemented by PortAudio clients.
- They consume, process or generate audio in response to requests from an
- active PortAudio stream.
-
- When a stream is running, PortAudio calls the stream callback periodically.
- The callback function is responsible for processing buffers of audio samples
- passed via the input and output parameters.
-
- The PortAudio stream callback runs at very high or real-time priority.
- It is required to consistently meet its time deadlines. Do not allocate
- memory, access the file system, call library functions or call other functions
- from the stream callback that may block or take an unpredictable amount of
- time to complete.
-
- In order for a stream to maintain glitch-free operation, the callback
- must consume and return audio data faster than it is recorded and/or
- played. PortAudio anticipates that each callback invocation may execute for
- a duration approaching the duration of frameCount audio frames at the stream
- sample rate. It is reasonable to expect to be able to utilise 70% or more of
- the available CPU time in the PortAudio callback. However, due to buffer size
- adaptation and other factors, not all host APIs are able to guarantee audio
- stability under heavy CPU load with arbitrary fixed callback buffer sizes.
- When high callback CPU utilisation is required, the most robust behavior
- can be achieved by using paFramesPerBufferUnspecified as the
- Pa_OpenStream() framesPerBuffer parameter.
-
- @param input and @param output are either arrays of interleaved samples or,
- if non-interleaved samples were requested using the paNonInterleaved sample
- format flag, an array of buffer pointers, one non-interleaved buffer for
- each channel.
-
- The format, packing and number of channels used by the buffers are
- determined by parameters to Pa_OpenStream().
-
- @param frameCount The number of sample frames to be processed by
- the stream callback.
-
- @param timeInfo Timestamps indicating the ADC capture time of the first sample
- in the input buffer, the DAC output time of the first sample in the output buffer
- and the time the callback was invoked.
- See PaStreamCallbackTimeInfo and Pa_GetStreamTime()
-
- @param statusFlags Flags indicating whether input and/or output buffers
- have been inserted or will be dropped to overcome underflow or overflow
- conditions.
-
- @param userData The value of a user supplied pointer passed to
- Pa_OpenStream() intended for storing synthesis data etc.
-
- @return
- The stream callback should return one of the values in the
- ::PaStreamCallbackResult enumeration. To ensure that the callback continues
- to be called, it should return paContinue (0). Either paComplete or paAbort
- can be returned to finish stream processing; after either of these values is
- returned the callback will not be called again. If paAbort is returned the
- stream will finish as soon as possible. If paComplete is returned, the stream
- will continue until all buffers generated by the callback have been played.
- This may be useful in applications such as soundfile players where a specific
- duration of output is required. However, it is not necessary to utilize this
- mechanism as Pa_StopStream(), Pa_AbortStream() or Pa_CloseStream() can also
- be used to stop the stream. The callback must always fill the entire output
- buffer irrespective of its return value.
-
- @see Pa_OpenStream, Pa_OpenDefaultStream
-
- @note With the exception of Pa_GetStreamCpuLoad() it is not permissible to call
- PortAudio API functions from within the stream callback.
-*/
-typedef int PaStreamCallback(
- const void *input, void *output,
- unsigned long frameCount,
- const PaStreamCallbackTimeInfo* timeInfo,
- PaStreamCallbackFlags statusFlags,
- void *userData );
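As a concrete sketch of the contract described above, a minimal output callback that synthesizes a 440 Hz tone into an interleaved stereo float32 buffer could look like this; toneCallback and ToneState are invented names for the example:

#include <math.h>
#include "portaudio.h"

/* Illustrative user-data struct carried through the userData pointer. */
typedef struct { double phase; } ToneState;

static int toneCallback(const void *input, void *output,
                        unsigned long frameCount,
                        const PaStreamCallbackTimeInfo *timeInfo,
                        PaStreamCallbackFlags statusFlags,
                        void *userData) {
    ToneState *state = (ToneState *) userData;
    float *out = (float *) output;
    (void) input; (void) timeInfo; (void) statusFlags;  /* unused here */
    for (unsigned long i = 0; i < frameCount; ++i) {
        float s = (float) (0.2 * sin(state->phase));    /* keep the level modest */
        *out++ = s;  /* left */
        *out++ = s;  /* right */
        /* advance phase for 440 Hz at an assumed 44.1 kHz sample rate */
        state->phase += 2.0 * 3.14159265358979323846 * 440.0 / 44100.0;
    }
    return paContinue;  /* keep the callback running */
}

Note that the body does no allocation, I/O, or blocking calls, in keeping with the real-time constraints documented above.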
-
-
-/** Opens a stream for either input, output or both.
-
- @param stream The address of a PaStream pointer which will receive
- a pointer to the newly opened stream.
-
- @param inputParameters A structure that describes the input parameters used by
- the opened stream. See PaStreamParameters for a description of these parameters.
- inputParameters must be NULL for output-only streams.
-
- @param outputParameters A structure that describes the output parameters used by
- the opened stream. See PaStreamParameters for a description of these parameters.
- outputParameters must be NULL for input-only streams.
-
- @param sampleRate The desired sampleRate. For full-duplex streams it is the
- sample rate for both input and output. Note that the actual sampleRate
- may differ very slightly from the desired rate because of hardware limitations.
- The exact rate can be queried using Pa_GetStreamInfo(). If nothing close
- to the desired sampleRate is available then the open will fail and return an error.
-
- @param framesPerBuffer The number of frames passed to the stream callback
- function, or the preferred block granularity for a blocking read/write stream.
- The special value paFramesPerBufferUnspecified (0) may be used to request that
- the stream callback will receive an optimal (and possibly varying) number of
- frames based on host requirements and the requested latency settings.
- Note: With some host APIs, the use of non-zero framesPerBuffer for a callback
- stream may introduce an additional layer of buffering which could introduce
- additional latency. PortAudio guarantees that the additional latency
- will be kept to the theoretical minimum; however, it is strongly recommended
- that a non-zero framesPerBuffer value only be used when your algorithm
- requires a fixed number of frames per stream callback.
-
- @param streamFlags Flags which modify the behavior of the streaming process.
- This parameter may contain a combination of flags ORed together. Some flags may
- only be relevant to certain buffer formats.
-
- @param streamCallback A pointer to a client supplied function that is responsible
- for processing and filling input and output buffers. If this parameter is NULL
- the stream will be opened in 'blocking read/write' mode. In blocking mode,
- the client can receive sample data using Pa_ReadStream() and write sample data
- using Pa_WriteStream(); the number of samples that may be read or written
- without blocking is returned by Pa_GetStreamReadAvailable() and
- Pa_GetStreamWriteAvailable() respectively.
-
- @param userData A client supplied pointer which is passed to the stream callback
- function. It could, for example, contain a pointer to instance data necessary
- for processing the audio buffers. This parameter is ignored if streamCallback
- is NULL.
-
- @return
- Upon success Pa_OpenStream() returns paNoError and places a pointer to a
- valid PaStream in the stream argument. The stream is inactive (stopped).
- If a call to Pa_OpenStream() fails, a non-zero error code is returned (see
- PaError for possible error codes) and the value of stream is invalid.
-
- @see PaStreamParameters, PaStreamCallback, Pa_ReadStream, Pa_WriteStream,
- Pa_GetStreamReadAvailable, Pa_GetStreamWriteAvailable
-*/
-PaError Pa_OpenStream( PaStream** stream,
- const PaStreamParameters *inputParameters,
- const PaStreamParameters *outputParameters,
- double sampleRate,
- unsigned long framesPerBuffer,
- PaStreamFlags streamFlags,
- PaStreamCallback *streamCallback,
- void *userData );
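Opening an output-only callback stream with explicit PaStreamParameters might then look like the following sketch, reusing toneCallback and ToneState from above:

#include "portaudio.h"

/* Opens an output-only callback stream on the default output device.
   Assumes Pa_Initialize() has already succeeded. */
static PaError open_tone_stream(PaStream **stream, ToneState *state) {
    PaStreamParameters out;
    out.device = Pa_GetDefaultOutputDevice();
    if (out.device == paNoDevice) return paDeviceUnavailable;
    out.channelCount = 2;
    out.sampleFormat = paFloat32;
    out.suggestedLatency = Pa_GetDeviceInfo(out.device)->defaultLowOutputLatency;
    out.hostApiSpecificStreamInfo = NULL;
    /* NULL inputParameters: output-only stream */
    return Pa_OpenStream(stream, NULL, &out, 44100.0,
                         paFramesPerBufferUnspecified, paClipOff,
                         toneCallback, state);
}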
-
-
-/** A simplified version of Pa_OpenStream() that opens the default input
- and/or output devices.
-
- @param stream The address of a PaStream pointer which will receive
- a pointer to the newly opened stream.
-
- @param numInputChannels The number of channels of sound that will be supplied
- to the stream callback or returned by Pa_ReadStream(). It can range from 1 to
- the value of maxInputChannels in the PaDeviceInfo record for the default input
- device. If 0 the stream is opened as an output-only stream.
-
- @param numOutputChannels The number of channels of sound to be delivered to the
- stream callback or passed to Pa_WriteStream. It can range from 1 to the value
- of maxOutputChannels in the PaDeviceInfo record for the default output device.
- If 0 the stream is opened as an input-only stream.
-
- @param sampleFormat The sample format of both the input and output buffers
- provided to the callback or passed to and from Pa_ReadStream() and Pa_WriteStream().
- sampleFormat may be any of the formats described by the PaSampleFormat
- enumeration.
-
- @param sampleRate Same as Pa_OpenStream parameter of the same name.
- @param framesPerBuffer Same as Pa_OpenStream parameter of the same name.
- @param streamCallback Same as Pa_OpenStream parameter of the same name.
- @param userData Same as Pa_OpenStream parameter of the same name.
-
- @return As for Pa_OpenStream
-
- @see Pa_OpenStream, PaStreamCallback
-*/
-PaError Pa_OpenDefaultStream( PaStream** stream,
- int numInputChannels,
- int numOutputChannels,
- PaSampleFormat sampleFormat,
- double sampleRate,
- unsigned long framesPerBuffer,
- PaStreamCallback *streamCallback,
- void *userData );
-
-
-/** Closes an audio stream. If the audio stream is active, it
- discards any pending buffers as if Pa_AbortStream() had been called.
-*/
-PaError Pa_CloseStream( PaStream *stream );
-
-
-/** Functions of type PaStreamFinishedCallback are implemented by PortAudio
- clients. They can be registered with a stream using the Pa_SetStreamFinishedCallback()
- function. Once registered they are called when the stream becomes inactive
- (ie once a call to Pa_StopStream() will not block).
- A stream will become inactive after the stream callback returns non-zero,
- or when Pa_StopStream() or Pa_AbortStream() is called. For a stream providing audio
- output, if the stream callback returns paComplete, or Pa_StopStream() is called,
- the stream finished callback will not be called until all generated sample data
- has been played.
-
- @param userData The userData parameter supplied to Pa_OpenStream()
-
- @see Pa_SetStreamFinishedCallback
-*/
-typedef void PaStreamFinishedCallback( void *userData );
-
-
-/** Register a stream finished callback function which will be called when the
- stream becomes inactive. See the description of PaStreamFinishedCallback for
- further details about when the callback will be called.
-
- @param stream a pointer to a PaStream that is in the stopped state - if the
- stream is not stopped, the stream's finished callback will remain unchanged
- and an error code will be returned.
-
- @param streamFinishedCallback a pointer to a function with the same signature
- as PaStreamFinishedCallback, that will be called when the stream becomes
- inactive. Passing NULL for this parameter will un-register a previously
- registered stream finished callback function.
-
- @return on success returns paNoError, otherwise an error code indicating the cause
- of the error.
-
- @see PaStreamFinishedCallback
-*/
-PaError Pa_SetStreamFinishedCallback( PaStream *stream, PaStreamFinishedCallback* streamFinishedCallback );
-
-
-/** Commences audio processing.
-*/
-PaError Pa_StartStream( PaStream *stream );
-
-
-/** Terminates audio processing. It waits until all pending
- audio buffers have been played before it returns.
-*/
-PaError Pa_StopStream( PaStream *stream );
-
-
-/** Terminates audio processing promptly without necessarily waiting for
- pending buffers to complete.
-*/
-PaError Pa_AbortStream( PaStream *stream );
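Putting the lifecycle functions together, a sketch that plays the tone from the earlier example for roughly two seconds (ToneState and toneCallback as defined above; error checks trimmed for brevity):

#include "portaudio.h"

int play_tone_for_two_seconds(void) {
    ToneState state = { 0.0 };
    PaStream *stream;
    if (Pa_Initialize() != paNoError) return -1;
    Pa_OpenDefaultStream(&stream, 0, 2, paFloat32, 44100.0,
                         paFramesPerBufferUnspecified, toneCallback, &state);
    Pa_StartStream(stream);
    Pa_Sleep(2000);              /* let the callback run for ~2 s */
    Pa_StopStream(stream);       /* waits for pending buffers to play */
    Pa_CloseStream(stream);
    Pa_Terminate();
    return 0;
}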
-
-
-/** Determine whether the stream is stopped.
- A stream is considered to be stopped prior to a successful call to
- Pa_StartStream() and after a successful call to Pa_StopStream() or Pa_AbortStream().
- If a stream callback returns a value other than paContinue the stream is NOT
- considered to be stopped.
-
- @return Returns one (1) when the stream is stopped, zero (0) when
- the stream is running or, a PaErrorCode (which are always negative) if
- PortAudio is not initialized or an error is encountered.
-
- @see Pa_StopStream, Pa_AbortStream, Pa_IsStreamActive
-*/
-PaError Pa_IsStreamStopped( PaStream *stream );
-
-
-/** Determine whether the stream is active.
- A stream is active after a successful call to Pa_StartStream(), until it
- becomes inactive either as a result of a call to Pa_StopStream() or
- Pa_AbortStream(), or as a result of a return value other than paContinue from
- the stream callback. In the latter case, the stream is considered inactive
- after the last buffer has finished playing.
-
- @return Returns one (1) when the stream is active (ie playing or recording
- audio), zero (0) when not playing or, a PaErrorCode (which are always negative)
- if PortAudio is not initialized or an error is encountered.
-
- @see Pa_StopStream, Pa_AbortStream, Pa_IsStreamStopped
-*/
-PaError Pa_IsStreamActive( PaStream *stream );
-
-
-
-/** A structure containing unchanging information about an open stream.
- @see Pa_GetStreamInfo
-*/
-
-typedef struct PaStreamInfo
-{
- /** this is struct version 1 */
- int structVersion;
-
- /** The input latency of the stream in seconds. This value provides the most
- accurate estimate of input latency available to the implementation. It may
- differ significantly from the suggestedLatency value passed to Pa_OpenStream().
- The value of this field will be zero (0.) for output-only streams.
- @see PaTime
- */
- PaTime inputLatency;
-
- /** The output latency of the stream in seconds. This value provides the most
- accurate estimate of output latency available to the implementation. It may
- differ significantly from the suggestedLatency value passed to Pa_OpenStream().
- The value of this field will be zero (0.) for input-only streams.
- @see PaTime
- */
- PaTime outputLatency;
-
- /** The sample rate of the stream in Hertz (samples per second). In cases
- where the hardware sample rate is inaccurate and PortAudio is aware of it,
- the value of this field may be different from the sampleRate parameter
- passed to Pa_OpenStream(). If information about the actual hardware sample
- rate is not available, this field will have the same value as the sampleRate
- parameter passed to Pa_OpenStream().
- */
- double sampleRate;
-
-} PaStreamInfo;
-
-
-/** Retrieve a pointer to a PaStreamInfo structure containing information
- about the specified stream.
- @return A pointer to an immutable PaStreamInfo structure. If the stream
- parameter is invalid, or an error is encountered, the function returns NULL.
-
- @param stream A pointer to an open stream previously created with Pa_OpenStream().
-
- @note PortAudio manages the memory referenced by the returned pointer,
- the client must not manipulate or free the memory. The pointer is only
- guaranteed to be valid until the specified stream is closed.
-
- @see PaStreamInfo
-*/
-const PaStreamInfo* Pa_GetStreamInfo( PaStream *stream );
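A small sketch that reports the latency and sample rate actually negotiated for an open stream, which may differ from the values requested in Pa_OpenStream():

#include <stdio.h>
#include "portaudio.h"

static void print_stream_info(PaStream *stream) {
    const PaStreamInfo *info = Pa_GetStreamInfo(stream);
    if (info != NULL) {
        printf("output latency: %.3f s\n", info->outputLatency);
        printf("actual sample rate: %.1f Hz\n", info->sampleRate);
    }
}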
-
-
-/** Returns the current time in seconds for a stream according to the same clock used
- to generate callback PaStreamCallbackTimeInfo timestamps. The time values are
- monotonically increasing and have unspecified origin.
-
- Pa_GetStreamTime returns valid time values for the entire life of the stream,
- from when the stream is opened until it is closed. Starting and stopping the stream
- does not affect the passage of time returned by Pa_GetStreamTime.
-
- This time may be used for synchronizing other events to the audio stream, for
- example synchronizing audio to MIDI.
-
- @return The stream's current time in seconds, or 0 if an error occurred.
-
- @see PaTime, PaStreamCallback, PaStreamCallbackTimeInfo
-*/
-PaTime Pa_GetStreamTime( PaStream *stream );
-
-
-/** Retrieve CPU usage information for the specified stream.
- The "CPU Load" is a fraction of total CPU time consumed by a callback stream's
- audio processing routines including, but not limited to the client supplied
- stream callback. This function does not work with blocking read/write streams.
-
- This function may be called from the stream callback function or the
- application.
-
- @return
- A floating point value, typically between 0.0 and 1.0, where 1.0 indicates
- that the stream callback is consuming the maximum number of CPU cycles possible
- to maintain real-time operation. A value of 0.5 would imply that PortAudio and
- the stream callback were consuming roughly 50% of the available CPU time. The
- return value may exceed 1.0. A value of 0.0 will always be returned for a
- blocking read/write stream, or if an error occurs.
-*/
-double Pa_GetStreamCpuLoad( PaStream* stream );
-
-
-/** Read samples from an input stream. The function doesn't return until
- the entire buffer has been filled - this may involve waiting for the operating
- system to supply the data.
-
- Reading from a stream that is stopped is not currently supported. In particular,
- it is not possible to drain the read buffer by calling Pa_ReadStream after
- calling Pa_StopStream().
-
- @param stream A pointer to an open stream previously created with Pa_OpenStream().
-
- @param buffer A pointer to a buffer of sample frames. The buffer contains
- samples in the format specified by the inputParameters->sampleFormat field
- used to open the stream, and the number of channels specified by
- inputParameters->numChannels. If non-interleaved samples were requested using
- the paNonInterleaved sample format flag, buffer is a pointer to the first element
- of an array of buffer pointers, one non-interleaved buffer for each channel.
-
- @param frames The number of frames to be read into buffer. This parameter
- is not constrained to a specific range, however high performance applications
- will want to match this parameter to the framesPerBuffer parameter used
- when opening the stream.
-
- @return On success paNoError will be returned, or paInputOverflowed if input
- data was discarded by PortAudio after the previous call and before this call.
-*/
-PaError Pa_ReadStream( PaStream* stream,
- void *buffer,
- unsigned long frames );
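A blocking-mode sketch that records one second of mono float32 audio; the stream is opened with a NULL streamCallback, which is what enables Pa_ReadStream():

#include <stdio.h>
#include "portaudio.h"

#define SR 44100
static float recording[SR];  /* one second of mono frames */

int record_one_second(void) {
    PaStream *stream;
    if (Pa_Initialize() != paNoError) return -1;
    PaError err = Pa_OpenDefaultStream(&stream, 1, 0, paFloat32, (double) SR,
                                       paFramesPerBufferUnspecified, NULL, NULL);
    if (err != paNoError) { Pa_Terminate(); return -1; }
    Pa_StartStream(stream);
    err = Pa_ReadStream(stream, recording, SR);  /* blocks until SR frames read */
    if (err == paInputOverflowed)
        fprintf(stderr, "warning: input data was discarded\n");
    Pa_StopStream(stream);
    Pa_CloseStream(stream);
    Pa_Terminate();
    return 0;
}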
-
-
-/** Write samples to an output stream. This function doesn't return until the
- entire buffer has been written - this may involve waiting for the operating
- system to consume the data.
-
- Writing to a stream that is stopped is not currently supported. In particular,
- it is not possible to prefill the write buffer by calling Pa_WriteStream
- prior to calling Pa_StartStream().
-
- @param stream A pointer to an open stream previously created with Pa_OpenStream().
-
- @param buffer A pointer to a buffer of sample frames. The buffer contains
- samples in the format specified by the outputParameters->sampleFormat field
- used to open the stream, and the number of channels specified by
- outputParameters->numChannels. If non-interleaved samples were requested using
- the paNonInterleaved sample format flag, buffer is a pointer to the first element
- of an array of buffer pointers, one non-interleaved buffer for each channel.
-
- @param frames The number of frames to be written from buffer. This parameter
- is not constrained to a specific range, however high performance applications
- will want to match this parameter to the framesPerBuffer parameter used
- when opening the stream.
-
- @return On success paNoError will be returned, or paOutputUnderflowed if
- additional output data was inserted after the previous call and before this
- call.
-*/
-PaError Pa_WriteStream( PaStream* stream,
- const void *buffer,
- unsigned long frames );
-
-
-/** Retrieve the number of frames that can be read from the stream without
- waiting.
-
- When the stream is stopped the return value of Pa_GetStreamReadAvailable is not
- defined.
-
- @return Returns a non-negative value representing the maximum number of frames
- that can be read from the stream without blocking or busy waiting or, a
- PaErrorCode (which are always negative) if PortAudio is not initialized or an
- error is encountered.
-*/
-signed long Pa_GetStreamReadAvailable( PaStream* stream );
-
-
-/** Retrieve the number of frames that can be written to the stream without
- waiting.
-
- When the stream is stopped the return value of Pa_GetStreamWriteAvailable is not
- defined.
-
- @return Returns a non-negative value representing the maximum number of frames
- that can be written to the stream without blocking or busy waiting or, a
- PaErrorCode (which are always negative) if PortAudio is not initialized or an
- error is encountered.
-*/
-signed long Pa_GetStreamWriteAvailable( PaStream* stream );
-
-
-/* Miscellaneous utilities */
-
-
-/** Retrieve the size of a given sample format in bytes.
-
- @return The size in bytes of a single sample in the specified format,
- or paSampleFormatNotSupported if the format is not supported.
-*/
-PaError Pa_GetSampleSize( PaSampleFormat format );
-
-
-/** Put the caller to sleep for at least 'msec' milliseconds. This function is
- provided only as a convenience for authors of portable code (such as the tests
- and examples in the PortAudio distribution.)
-
- The function may sleep longer than requested so don't rely on this for accurate
- musical timing.
-*/
-void Pa_Sleep( long msec );
-
-
-
-#ifdef __cplusplus
-}
-#endif /* __cplusplus */
-#endif /* PORTAUDIO_H */

From ec68b142d82bcb09a08af41e3134adfcc1bde188 Mon Sep 17 00:00:00 2001
From: austinvhuang
Date: Thu, 30 Jan 2025 09:56:05 -0500
Subject: [PATCH 37/44] Change priority of internal logging from kInfo to kTrace. Make julia the default shader for shadertui.
--- examples/shadertui/shader.wgsl | 65 ++++++++++++---------------------- gpu.hpp | 21 ++++++----- 2 files changed, 32 insertions(+), 54 deletions(-) diff --git a/examples/shadertui/shader.wgsl b/examples/shadertui/shader.wgsl index 84d3206..7d95150 100644 --- a/examples/shadertui/shader.wgsl +++ b/examples/shadertui/shader.wgsl @@ -2,60 +2,39 @@ @group(0) @binding(1) var params: Params; struct Params { - time: f32, - screenwidth: u32, - screenheight: u32, + time: f32, + screenwidth: u32, + screenheight: u32, }; -struct Particle { - position: vec2, - velocity: vec2, - life: f32, -} - -const NUM_PARTICLES: u32 = 1000u; -const PARTICLE_LIFE: f32 = 9.0; -const EMISSION_RATE: f32 = 300.0; +const MAX_ITERATIONS: u32 = 100; -fn rand(n: f32) -> f32 { - return fract(sin(n) * 43758.5453123); -} - -fn initialize_particle(id: f32, time: f32) -> Particle { - let random1 = rand(id * 0.01 + time * 0.1); - let random2 = rand(id * 0.02 + time * 0.1); - let angle = random1 * 2.0 * 3.14159; - let speed = 0.05 + random2 * 0.05; - - return Particle( - vec2(0.5, 0.5), - vec2(cos(angle) * speed, sin(angle) * speed), - PARTICLE_LIFE - ); +fn complex_mul(a: vec2, b: vec2) -> vec2 { + return vec2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); } @compute @workgroup_size(16, 16, 1) fn main(@builtin(global_invocation_id) globalID : vec3) { let resolution = vec2(f32(params.screenwidth), f32(params.screenheight)); - let uv = vec2(f32(globalID.x) / resolution.x, f32(globalID.y) / resolution.y); - let idx = globalID.y * params.screenwidth + globalID.x; + let uv = (vec2(globalID.xy) - 0.5 * resolution) / min(resolution.x, resolution.y); + + // Animate the Julia set parameters + let t = params.time * 0.3; + let c = 0.7885 * vec2(cos(t), sin(t)); - var color: f32 = 0.0; + var z = uv * 3.0; + var i: u32 = 0u; - for (var i: f32 = 0.0; i < f32(NUM_PARTICLES); i += 1.0) { - let spawn_time = i / EMISSION_RATE; - let particle_age = fract((params.time - spawn_time) / PARTICLE_LIFE) * PARTICLE_LIFE; - - if (particle_age < PARTICLE_LIFE) { - var particle = initialize_particle(i, spawn_time); - particle.position += particle.velocity * particle_age; - - let distance = length(uv - particle.position); - if (distance < 0.005) { - color += 0.5 * (1.0 - particle_age / PARTICLE_LIFE); - } + for (; i < MAX_ITERATIONS; i = i + 1u) { + z = complex_mul(z, z) + c; + if (dot(z, z) > 4.0) { + break; } } - out[idx] = min(color, 1.0); + let smooth_i = f32(i) + 1.0 - log2(log2(dot(z, z))); + let color = 0.5 + 0.5 * cos(3.0 + smooth_i * 0.15 + vec3(0.0, 0.6, 1.0)); + + let idx = globalID.y * params.screenwidth + globalID.x; + out[idx] = (color.r + color.g + color.b) / 3.0; } diff --git a/gpu.hpp b/gpu.hpp index c4b6fe9..5327fe7 100644 --- a/gpu.hpp +++ b/gpu.hpp @@ -530,7 +530,6 @@ struct Context { // Default constructor Context() = default; - // Move constructor: steals GPU handles so the source destructor won't free them. Context(Context&& other) noexcept : instance(other.instance), adapter(other.adapter), @@ -542,6 +541,7 @@ struct Context { adapterStatus(other.adapterStatus), deviceStatus(other.deviceStatus) { + LOG(kDefLog, kTrace, "Moving Context ownership"); // Move over the resources in the pools: pool.data = std::move(other.pool.data); kernelPool.data = std::move(other.kernelPool.data); @@ -555,7 +555,6 @@ struct Context { // other.deviceStatus = 0; } - // Optional move‐assignment operator, similarly stealing resources: Context& operator=(Context&& other) noexcept { if (this != &other) { // Free any existing resources. 
In most cases, this should be a no-op @@ -573,26 +572,26 @@ struct Context { if (queue) { wgpuQueueRelease(queue); } else { - LOG(kDefLog, kWarn, "Queue is null"); + LOG(kDefLog, kTrace, "Queue is null"); } if (device) { wgpuDeviceRelease(device); processEvents(instance); } else { - LOG(kDefLog, kWarn, "Device is null"); + LOG(kDefLog, kTrace, "Device is null"); } if (adapter) { wgpuAdapterRelease(adapter); processEvents(instance); } else { - LOG(kDefLog, kWarn, "Adapter is null"); + LOG(kDefLog, kTrace, "Adapter is null"); } if (instance) { wgpuInstanceRelease(instance); } else { - LOG(kDefLog, kWarn, "Instance is null"); + LOG(kDefLog, kTrace, "Instance is null"); } - LOG(kDefLog, kInfo, "Context destroyed"); + LOG(kDefLog, kTrace, "Context destroyed"); } }; @@ -827,7 +826,7 @@ inline Context createContext( #endif check(ctx.instance, "Initialize WebGPU", __FILE__, __LINE__); - LOG(kDefLog, kInfo, "Requesting adapter"); + LOG(kDefLog, kTrace, "Requesting adapter"); { struct AdapterData { WGPUAdapter adapter = nullptr; @@ -869,7 +868,7 @@ inline Context createContext( ctx.adapterStatus = adapterData.status; } - LOG(kDefLog, kInfo, "Requesting device"); + LOG(kDefLog, kTrace, "Requesting device"); { struct DeviceData { WGPUDevice device = nullptr; @@ -900,11 +899,11 @@ inline Context createContext( }; wgpuAdapterRequestDevice(ctx.adapter, &devDescriptor, deviceCallbackInfo); - LOG(kDefLog, kInfo, "Waiting for device request to end"); + LOG(kDefLog, kTrace, "Waiting for device request to end"); while (!devData.requestEnded) { processEvents(ctx.instance); } - LOG(kDefLog, kInfo, "Device request ended"); + LOG(kDefLog, kTrace, "Device request ended"); ctx.device = devData.device; ctx.deviceStatus = devData.status; From 4589f1f2492cd8dd18c9c968cff0027ae1d6c3a1 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Sat, 1 Feb 2025 17:09:04 -0500 Subject: [PATCH 38/44] bump dawn version to c469d593ac and remove WebGPU-distribution from third_party --- .gitmodules | 4 - setup.py | 4 +- third_party/headers/webgpu/README.md | 6 +- third_party/headers/webgpu/webgpu.h | 481 +++++++++++++------------- third_party/local/.gitkeep | 0 third_party/local/WebGPU-distribution | 1 - 6 files changed, 241 insertions(+), 255 deletions(-) delete mode 100644 third_party/local/.gitkeep delete mode 160000 third_party/local/WebGPU-distribution diff --git a/.gitmodules b/.gitmodules index 59eacdf..468fa55 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,3 @@ -[submodule "third_party/local/WebGPU-distribution"] - path = third_party/local/WebGPU-distribution - url = https://github.com/eliemichel/WebGPU-distribution.git - branch = dawn [submodule "third_party/llm.c"] path = third_party/llm.c url = https://github.com/karpathy/llm.c diff --git a/setup.py b/setup.py index 40cc5cc..964f345 100644 --- a/setup.py +++ b/setup.py @@ -61,8 +61,8 @@ def download_dawn(os_name): "Linux": "third_party/lib/libwebgpu_dawn.so", } url_map = { - "macOS": "https://github.com/austinvhuang/dawn-artifacts/releases/download/0.2.0/libwebgpu_dawn.dylib", - "Linux": "https://github.com/austinvhuang/dawn-artifacts/releases/download/0.2.0/libwebgpu_dawn.so", + "macOS": "https://github.com/austinvhuang/dawn-artifacts/releases/download/0.2.0-pre2/libwebgpu_dawn.dylib", + "Linux": "https://github.com/austinvhuang/dawn-artifacts/releases/download/0.2.0-pre2/libwebgpu_dawn.so", } outfile = outfile_map.get(os_name) diff --git a/third_party/headers/webgpu/README.md b/third_party/headers/webgpu/README.md index 5b2f551..c29db50 100644 --- 
a/third_party/headers/webgpu/README.md +++ b/third_party/headers/webgpu/README.md @@ -1 +1,5 @@ -webgpu.h from dawn build: _deps/dawn-build/gen/include/dawn/ +0.2.0 + +Dawn Commit Hash: c469d593ac + +webgpu.h from dawn build directory: gen/include/dawn/webgpu.h diff --git a/third_party/headers/webgpu/webgpu.h b/third_party/headers/webgpu/webgpu.h index b36a758..a77052f 100644 --- a/third_party/headers/webgpu/webgpu.h +++ b/third_party/headers/webgpu/webgpu.h @@ -144,7 +144,6 @@ struct WGPUBufferBindingLayout; struct WGPUBufferHostMappedPointer; struct WGPUColor; struct WGPUColorTargetStateExpandResolveTextureDawn; -struct WGPUComputePassTimestampWrites; struct WGPUCopyTextureForBrowserOptions; struct WGPUDawnWGSLBlocklist; struct WGPUDawnAdapterPropertiesPowerPreference; @@ -165,18 +164,18 @@ struct WGPUExtent3D; struct WGPUExternalTextureBindingEntry; struct WGPUExternalTextureBindingLayout; struct WGPUFuture; -struct WGPUInstanceFeatures; +struct WGPUInstanceCapabilities; struct WGPULimits; struct WGPUMemoryHeapInfo; struct WGPUMultisampleState; struct WGPUOrigin2D; struct WGPUOrigin3D; +struct WGPUPassTimestampWrites; struct WGPUPipelineLayoutStorageAttachment; struct WGPUPrimitiveState; struct WGPURenderPassDepthStencilAttachment; struct WGPURenderPassDescriptorExpandResolveRect; struct WGPURenderPassMaxDrawCount; -struct WGPURenderPassTimestampWrites; struct WGPURequestAdapterOptions; struct WGPUSamplerBindingLayout; struct WGPUShaderModuleCompilationOptions; @@ -287,18 +286,6 @@ struct WGPUComputePipelineDescriptor; struct WGPUFragmentState; struct WGPURenderPipelineDescriptor; -typedef enum WGPUWGSLFeatureName { - WGPUWGSLFeatureName_ReadonlyAndReadwriteStorageTextures = 0x00000001, - WGPUWGSLFeatureName_Packed4x8IntegerDotProduct = 0x00000002, - WGPUWGSLFeatureName_UnrestrictedPointerParameters = 0x00000003, - WGPUWGSLFeatureName_PointerCompositeAccess = 0x00000004, - WGPUWGSLFeatureName_ChromiumTestingUnimplemented = 0x00050000, - WGPUWGSLFeatureName_ChromiumTestingUnsafeExperimental = 0x00050001, - WGPUWGSLFeatureName_ChromiumTestingExperimental = 0x00050002, - WGPUWGSLFeatureName_ChromiumTestingShippedWithKillswitch = 0x00050003, - WGPUWGSLFeatureName_ChromiumTestingShipped = 0x00050004, - WGPUWGSLFeatureName_Force32 = 0x7FFFFFFF -} WGPUWGSLFeatureName WGPU_ENUM_ATTRIBUTE; typedef enum WGPUWGSLLanguageFeatureName { WGPUWGSLLanguageFeatureName_ReadonlyAndReadwriteStorageTextures = 0x00000001, WGPUWGSLLanguageFeatureName_Packed4x8IntegerDotProduct = 0x00000002, @@ -1066,11 +1053,11 @@ typedef struct WGPUBufferMapCallbackInfo { } WGPUBufferMapCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_BUFFER_MAP_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUBufferMapCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.mode=*/{} WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ + /*.callback=*/NULL WGPU_COMMA \ + /*.userdata1=*/NULL WGPU_COMMA \ + /*.userdata2=*/NULL WGPU_COMMA \ }) typedef struct WGPUCompilationInfoCallbackInfo { @@ -1082,11 +1069,11 @@ typedef struct WGPUCompilationInfoCallbackInfo { } WGPUCompilationInfoCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_COMPILATION_INFO_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUCompilationInfoCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.mode=*/{} WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr 
WGPU_COMMA \ + /*.callback=*/NULL WGPU_COMMA \ + /*.userdata1=*/NULL WGPU_COMMA \ + /*.userdata2=*/NULL WGPU_COMMA \ }) typedef struct WGPUCreateComputePipelineAsyncCallbackInfo { @@ -1098,11 +1085,11 @@ typedef struct WGPUCreateComputePipelineAsyncCallbackInfo { } WGPUCreateComputePipelineAsyncCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_CREATE_COMPUTE_PIPELINE_ASYNC_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUCreateComputePipelineAsyncCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.mode=*/{} WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ + /*.callback=*/NULL WGPU_COMMA \ + /*.userdata1=*/NULL WGPU_COMMA \ + /*.userdata2=*/NULL WGPU_COMMA \ }) typedef struct WGPUCreateRenderPipelineAsyncCallbackInfo { @@ -1114,11 +1101,11 @@ typedef struct WGPUCreateRenderPipelineAsyncCallbackInfo { } WGPUCreateRenderPipelineAsyncCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_CREATE_RENDER_PIPELINE_ASYNC_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUCreateRenderPipelineAsyncCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.mode=*/{} WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ + /*.callback=*/NULL WGPU_COMMA \ + /*.userdata1=*/NULL WGPU_COMMA \ + /*.userdata2=*/NULL WGPU_COMMA \ }) typedef struct WGPUDeviceLostCallbackInfo { @@ -1130,11 +1117,11 @@ typedef struct WGPUDeviceLostCallbackInfo { } WGPUDeviceLostCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DEVICE_LOST_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUDeviceLostCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.mode=*/{} WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ + /*.callback=*/NULL WGPU_COMMA \ + /*.userdata1=*/NULL WGPU_COMMA \ + /*.userdata2=*/NULL WGPU_COMMA \ }) typedef struct WGPULoggingCallbackInfo { @@ -1145,10 +1132,10 @@ typedef struct WGPULoggingCallbackInfo { } WGPULoggingCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_LOGGING_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPULoggingCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ + /*.callback=*/NULL WGPU_COMMA \ + /*.userdata1=*/NULL WGPU_COMMA \ + /*.userdata2=*/NULL WGPU_COMMA \ }) typedef struct WGPUPopErrorScopeCallbackInfo { @@ -1160,11 +1147,11 @@ typedef struct WGPUPopErrorScopeCallbackInfo { } WGPUPopErrorScopeCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_POP_ERROR_SCOPE_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUPopErrorScopeCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.mode=*/{} WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ + /*.callback=*/NULL WGPU_COMMA \ + /*.userdata1=*/NULL WGPU_COMMA \ + /*.userdata2=*/NULL WGPU_COMMA \ }) typedef struct WGPUQueueWorkDoneCallbackInfo { @@ -1176,11 +1163,11 @@ typedef struct WGPUQueueWorkDoneCallbackInfo { } WGPUQueueWorkDoneCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_QUEUE_WORK_DONE_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUQueueWorkDoneCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ 
/*.mode=*/{} WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ + /*.callback=*/NULL WGPU_COMMA \ + /*.userdata1=*/NULL WGPU_COMMA \ + /*.userdata2=*/NULL WGPU_COMMA \ }) typedef struct WGPURequestAdapterCallbackInfo { @@ -1192,11 +1179,11 @@ typedef struct WGPURequestAdapterCallbackInfo { } WGPURequestAdapterCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_REQUEST_ADAPTER_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPURequestAdapterCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.mode=*/{} WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ + /*.callback=*/NULL WGPU_COMMA \ + /*.userdata1=*/NULL WGPU_COMMA \ + /*.userdata2=*/NULL WGPU_COMMA \ }) typedef struct WGPURequestDeviceCallbackInfo { @@ -1208,11 +1195,11 @@ typedef struct WGPURequestDeviceCallbackInfo { } WGPURequestDeviceCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_REQUEST_DEVICE_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPURequestDeviceCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.mode=*/{} WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ + /*.callback=*/NULL WGPU_COMMA \ + /*.userdata1=*/NULL WGPU_COMMA \ + /*.userdata2=*/NULL WGPU_COMMA \ }) typedef struct WGPUUncapturedErrorCallbackInfo { @@ -1223,10 +1210,10 @@ typedef struct WGPUUncapturedErrorCallbackInfo { } WGPUUncapturedErrorCallbackInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_UNCAPTURED_ERROR_CALLBACK_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUUncapturedErrorCallbackInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.callback=*/nullptr WGPU_COMMA \ - /*.userdata1=*/nullptr WGPU_COMMA \ - /*.userdata2=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ + /*.callback=*/NULL WGPU_COMMA \ + /*.userdata1=*/NULL WGPU_COMMA \ + /*.userdata2=*/NULL WGPU_COMMA \ }) @@ -1245,7 +1232,7 @@ typedef struct WGPUAdapterPropertiesD3D { } WGPUAdapterPropertiesD3D WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_ADAPTER_PROPERTIES_D3D_INIT WGPU_MAKE_INIT_STRUCT(WGPUAdapterPropertiesD3D, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_AdapterPropertiesD3D} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_AdapterPropertiesD3D} WGPU_COMMA \ /*.shaderModel=*/{} WGPU_COMMA \ }) @@ -1257,7 +1244,7 @@ typedef struct WGPUAdapterPropertiesSubgroups { } WGPUAdapterPropertiesSubgroups WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_ADAPTER_PROPERTIES_SUBGROUPS_INIT WGPU_MAKE_INIT_STRUCT(WGPUAdapterPropertiesSubgroups, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_AdapterPropertiesSubgroups} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_AdapterPropertiesSubgroups} WGPU_COMMA \ /*.subgroupMinSize=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ /*.subgroupMaxSize=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ }) @@ -1269,7 +1256,7 @@ typedef struct WGPUAdapterPropertiesVk { } WGPUAdapterPropertiesVk WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_ADAPTER_PROPERTIES_VK_INIT WGPU_MAKE_INIT_STRUCT(WGPUAdapterPropertiesVk, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_AdapterPropertiesVk} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_AdapterPropertiesVk} WGPU_COMMA \ /*.driverVersion=*/{} WGPU_COMMA \ }) @@ -1284,13 +1271,13 @@ typedef struct 
WGPUBindGroupEntry { } WGPUBindGroupEntry WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_BIND_GROUP_ENTRY_INIT WGPU_MAKE_INIT_STRUCT(WGPUBindGroupEntry, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.binding=*/{} WGPU_COMMA \ - /*.buffer=*/nullptr WGPU_COMMA \ + /*.buffer=*/NULL WGPU_COMMA \ /*.offset=*/0 WGPU_COMMA \ /*.size=*/WGPU_WHOLE_SIZE WGPU_COMMA \ - /*.sampler=*/nullptr WGPU_COMMA \ - /*.textureView=*/nullptr WGPU_COMMA \ + /*.sampler=*/NULL WGPU_COMMA \ + /*.textureView=*/NULL WGPU_COMMA \ }) typedef struct WGPUBlendComponent { @@ -1313,7 +1300,7 @@ typedef struct WGPUBufferBindingLayout { } WGPUBufferBindingLayout WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_BUFFER_BINDING_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUBufferBindingLayout, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.type=*/WGPUBufferBindingType_Uniform WGPU_COMMA \ /*.hasDynamicOffset=*/false WGPU_COMMA \ /*.minBindingSize=*/0 WGPU_COMMA \ @@ -1328,7 +1315,7 @@ typedef struct WGPUBufferHostMappedPointer { } WGPUBufferHostMappedPointer WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_BUFFER_HOST_MAPPED_POINTER_INIT WGPU_MAKE_INIT_STRUCT(WGPUBufferHostMappedPointer, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_BufferHostMappedPointer} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_BufferHostMappedPointer} WGPU_COMMA \ /*.pointer=*/{} WGPU_COMMA \ /*.disposeCallback=*/{} WGPU_COMMA \ /*.userdata=*/{} WGPU_COMMA \ @@ -1355,22 +1342,10 @@ typedef struct WGPUColorTargetStateExpandResolveTextureDawn { } WGPUColorTargetStateExpandResolveTextureDawn WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_COLOR_TARGET_STATE_EXPAND_RESOLVE_TEXTURE_DAWN_INIT WGPU_MAKE_INIT_STRUCT(WGPUColorTargetStateExpandResolveTextureDawn, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_ColorTargetStateExpandResolveTextureDawn} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_ColorTargetStateExpandResolveTextureDawn} WGPU_COMMA \ /*.enabled=*/false WGPU_COMMA \ }) -typedef struct WGPUComputePassTimestampWrites { - WGPUQuerySet querySet; - uint32_t beginningOfPassWriteIndex; - uint32_t endOfPassWriteIndex; -} WGPUComputePassTimestampWrites WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_COMPUTE_PASS_TIMESTAMP_WRITES_INIT WGPU_MAKE_INIT_STRUCT(WGPUComputePassTimestampWrites, { \ - /*.querySet=*/{} WGPU_COMMA \ - /*.beginningOfPassWriteIndex=*/WGPU_QUERY_SET_INDEX_UNDEFINED WGPU_COMMA \ - /*.endOfPassWriteIndex=*/WGPU_QUERY_SET_INDEX_UNDEFINED WGPU_COMMA \ -}) - typedef struct WGPUCopyTextureForBrowserOptions { WGPUChainedStruct* nextInChain; WGPUBool flipY; @@ -1384,13 +1359,13 @@ typedef struct WGPUCopyTextureForBrowserOptions { } WGPUCopyTextureForBrowserOptions WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_COPY_TEXTURE_FOR_BROWSER_OPTIONS_INIT WGPU_MAKE_INIT_STRUCT(WGPUCopyTextureForBrowserOptions, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.flipY=*/false WGPU_COMMA \ /*.needsColorSpaceConversion=*/false WGPU_COMMA \ /*.srcAlphaMode=*/WGPUAlphaMode_Unpremultiplied WGPU_COMMA \ - /*.srcTransferFunctionParameters=*/nullptr WGPU_COMMA \ - /*.conversionMatrix=*/nullptr WGPU_COMMA \ - /*.dstTransferFunctionParameters=*/nullptr WGPU_COMMA \ + /*.srcTransferFunctionParameters=*/NULL WGPU_COMMA \ + /*.conversionMatrix=*/NULL WGPU_COMMA \ + /*.dstTransferFunctionParameters=*/NULL WGPU_COMMA \ /*.dstAlphaMode=*/WGPUAlphaMode_Unpremultiplied WGPU_COMMA \ /*.internalUsage=*/false 
WGPU_COMMA \ }) @@ -1403,7 +1378,7 @@ typedef struct WGPUDawnWGSLBlocklist { } WGPUDawnWGSLBlocklist WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_WGSL_BLOCKLIST_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnWGSLBlocklist, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnWGSLBlocklist} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_DawnWGSLBlocklist} WGPU_COMMA \ /*.blocklistedFeatureCount=*/0 WGPU_COMMA \ /*.blocklistedFeatures=*/{} WGPU_COMMA \ }) @@ -1415,7 +1390,7 @@ typedef struct WGPUDawnAdapterPropertiesPowerPreference { } WGPUDawnAdapterPropertiesPowerPreference WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_ADAPTER_PROPERTIES_POWER_PREFERENCE_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnAdapterPropertiesPowerPreference, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnAdapterPropertiesPowerPreference} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_DawnAdapterPropertiesPowerPreference} WGPU_COMMA \ /*.powerPreference=*/WGPUPowerPreference_Undefined WGPU_COMMA \ }) @@ -1426,7 +1401,7 @@ typedef struct WGPUDawnBufferDescriptorErrorInfoFromWireClient { } WGPUDawnBufferDescriptorErrorInfoFromWireClient WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_BUFFER_DESCRIPTOR_ERROR_INFO_FROM_WIRE_CLIENT_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnBufferDescriptorErrorInfoFromWireClient, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnBufferDescriptorErrorInfoFromWireClient} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_DawnBufferDescriptorErrorInfoFromWireClient} WGPU_COMMA \ /*.outOfMemory=*/false WGPU_COMMA \ }) @@ -1447,7 +1422,7 @@ typedef struct WGPUDawnEncoderInternalUsageDescriptor { } WGPUDawnEncoderInternalUsageDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_ENCODER_INTERNAL_USAGE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnEncoderInternalUsageDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnEncoderInternalUsageDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_DawnEncoderInternalUsageDescriptor} WGPU_COMMA \ /*.useInternalUsages=*/false WGPU_COMMA \ }) @@ -1458,7 +1433,7 @@ typedef struct WGPUDawnExperimentalImmediateDataLimits { } WGPUDawnExperimentalImmediateDataLimits WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_EXPERIMENTAL_IMMEDIATE_DATA_LIMITS_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnExperimentalImmediateDataLimits, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnExperimentalImmediateDataLimits} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_DawnExperimentalImmediateDataLimits} WGPU_COMMA \ /*.maxImmediateDataRangeByteSize=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ }) @@ -1470,7 +1445,7 @@ typedef struct WGPUDawnExperimentalSubgroupLimits { } WGPUDawnExperimentalSubgroupLimits WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_EXPERIMENTAL_SUBGROUP_LIMITS_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnExperimentalSubgroupLimits, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnExperimentalSubgroupLimits} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_DawnExperimentalSubgroupLimits} WGPU_COMMA \ /*.minSubgroupSize=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ /*.maxSubgroupSize=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ }) @@ -1480,7 +1455,7 @@ typedef struct WGPUDawnFormatCapabilities { } WGPUDawnFormatCapabilities WGPU_STRUCTURE_ATTRIBUTE; #define 
WGPU_DAWN_FORMAT_CAPABILITIES_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnFormatCapabilities, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ }) // Can be chained in WGPURenderPassColorAttachment @@ -1490,7 +1465,7 @@ typedef struct WGPUDawnRenderPassColorAttachmentRenderToSingleSampled { } WGPUDawnRenderPassColorAttachmentRenderToSingleSampled WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_RENDER_PASS_COLOR_ATTACHMENT_RENDER_TO_SINGLE_SAMPLED_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnRenderPassColorAttachmentRenderToSingleSampled, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnRenderPassColorAttachmentRenderToSingleSampled} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_DawnRenderPassColorAttachmentRenderToSingleSampled} WGPU_COMMA \ /*.implicitSampleCount=*/1 WGPU_COMMA \ }) @@ -1501,7 +1476,7 @@ typedef struct WGPUDawnShaderModuleSPIRVOptionsDescriptor { } WGPUDawnShaderModuleSPIRVOptionsDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_SHADER_MODULE_SPIRV_OPTIONS_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnShaderModuleSPIRVOptionsDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnShaderModuleSPIRVOptionsDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_DawnShaderModuleSPIRVOptionsDescriptor} WGPU_COMMA \ /*.allowNonUniformDerivatives=*/false WGPU_COMMA \ }) @@ -1512,7 +1487,7 @@ typedef struct WGPUDawnTexelCopyBufferRowAlignmentLimits { } WGPUDawnTexelCopyBufferRowAlignmentLimits WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_TEXEL_COPY_BUFFER_ROW_ALIGNMENT_LIMITS_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnTexelCopyBufferRowAlignmentLimits, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnTexelCopyBufferRowAlignmentLimits} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_DawnTexelCopyBufferRowAlignmentLimits} WGPU_COMMA \ /*.minTexelCopyBufferRowAlignment=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ }) @@ -1523,7 +1498,7 @@ typedef struct WGPUDawnTextureInternalUsageDescriptor { } WGPUDawnTextureInternalUsageDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_TEXTURE_INTERNAL_USAGE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnTextureInternalUsageDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnTextureInternalUsageDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_DawnTextureInternalUsageDescriptor} WGPU_COMMA \ /*.internalUsage=*/WGPUTextureUsage_None WGPU_COMMA \ }) @@ -1539,7 +1514,7 @@ typedef struct WGPUDawnTogglesDescriptor { } WGPUDawnTogglesDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_TOGGLES_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnTogglesDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnTogglesDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_DawnTogglesDescriptor} WGPU_COMMA \ /*.enabledToggleCount=*/0 WGPU_COMMA \ /*.enabledToggles=*/{} WGPU_COMMA \ /*.disabledToggleCount=*/0 WGPU_COMMA \ @@ -1555,7 +1530,7 @@ typedef struct WGPUDawnWireWGSLControl { } WGPUDawnWireWGSLControl WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_WIRE_WGSL_CONTROL_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnWireWGSLControl, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnWireWGSLControl} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_DawnWireWGSLControl} WGPU_COMMA \ 
/*.enableExperimental=*/false WGPU_COMMA \ /*.enableUnsafe=*/false WGPU_COMMA \ /*.enableTesting=*/false WGPU_COMMA \ @@ -1590,7 +1565,7 @@ typedef struct WGPUExternalTextureBindingEntry { } WGPUExternalTextureBindingEntry WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_EXTERNAL_TEXTURE_BINDING_ENTRY_INIT WGPU_MAKE_INIT_STRUCT(WGPUExternalTextureBindingEntry, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_ExternalTextureBindingEntry} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_ExternalTextureBindingEntry} WGPU_COMMA \ /*.externalTexture=*/{} WGPU_COMMA \ }) @@ -1600,7 +1575,7 @@ typedef struct WGPUExternalTextureBindingLayout { } WGPUExternalTextureBindingLayout WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_EXTERNAL_TEXTURE_BINDING_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUExternalTextureBindingLayout, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_ExternalTextureBindingLayout} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_ExternalTextureBindingLayout} WGPU_COMMA \ }) typedef struct WGPUFuture { @@ -1611,14 +1586,14 @@ typedef struct WGPUFuture { /*.id=*/{} WGPU_COMMA \ }) -typedef struct WGPUInstanceFeatures { +typedef struct WGPUInstanceCapabilities { WGPUChainedStruct* nextInChain; WGPUBool timedWaitAnyEnable; size_t timedWaitAnyMaxCount; -} WGPUInstanceFeatures WGPU_STRUCTURE_ATTRIBUTE; +} WGPUInstanceCapabilities WGPU_STRUCTURE_ATTRIBUTE; -#define WGPU_INSTANCE_FEATURES_INIT WGPU_MAKE_INIT_STRUCT(WGPUInstanceFeatures, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ +#define WGPU_INSTANCE_CAPABILITIES_INIT WGPU_MAKE_INIT_STRUCT(WGPUInstanceCapabilities, { \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.timedWaitAnyEnable=*/false WGPU_COMMA \ /*.timedWaitAnyMaxCount=*/0 WGPU_COMMA \ }) @@ -1719,7 +1694,7 @@ typedef struct WGPUMultisampleState { } WGPUMultisampleState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_MULTISAMPLE_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUMultisampleState, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.count=*/1 WGPU_COMMA \ /*.mask=*/0xFFFFFFFF WGPU_COMMA \ /*.alphaToCoverageEnabled=*/false WGPU_COMMA \ @@ -1747,6 +1722,20 @@ typedef struct WGPUOrigin3D { /*.z=*/0 WGPU_COMMA \ }) +typedef struct WGPUPassTimestampWrites { + WGPUChainedStruct* nextInChain; + WGPUQuerySet querySet; + uint32_t beginningOfPassWriteIndex; + uint32_t endOfPassWriteIndex; +} WGPUPassTimestampWrites WGPU_STRUCTURE_ATTRIBUTE; + +#define WGPU_PASS_TIMESTAMP_WRITES_INIT WGPU_MAKE_INIT_STRUCT(WGPUPassTimestampWrites, { \ + /*.nextInChain=*/NULL WGPU_COMMA \ + /*.querySet=*/{} WGPU_COMMA \ + /*.beginningOfPassWriteIndex=*/WGPU_QUERY_SET_INDEX_UNDEFINED WGPU_COMMA \ + /*.endOfPassWriteIndex=*/WGPU_QUERY_SET_INDEX_UNDEFINED WGPU_COMMA \ +}) + typedef struct WGPUPipelineLayoutStorageAttachment { uint64_t offset; WGPUTextureFormat format; @@ -1767,7 +1756,7 @@ typedef struct WGPUPrimitiveState { } WGPUPrimitiveState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_PRIMITIVE_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUPrimitiveState, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.topology=*/WGPUPrimitiveTopology_TriangleList WGPU_COMMA \ /*.stripIndexFormat=*/WGPUIndexFormat_Undefined WGPU_COMMA \ /*.frontFace=*/WGPUFrontFace_CCW WGPU_COMMA \ @@ -1776,6 +1765,7 @@ typedef struct WGPUPrimitiveState { }) typedef struct WGPURenderPassDepthStencilAttachment { + WGPUChainedStruct* nextInChain; WGPUTextureView view; WGPULoadOp depthLoadOp; WGPUStoreOp depthStoreOp; @@ 
-1788,6 +1778,7 @@ typedef struct WGPURenderPassDepthStencilAttachment { } WGPURenderPassDepthStencilAttachment WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_RENDER_PASS_DEPTH_STENCIL_ATTACHMENT_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderPassDepthStencilAttachment, { \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.view=*/{} WGPU_COMMA \ /*.depthLoadOp=*/WGPULoadOp_Undefined WGPU_COMMA \ /*.depthStoreOp=*/WGPUStoreOp_Undefined WGPU_COMMA \ @@ -1809,7 +1800,7 @@ typedef struct WGPURenderPassDescriptorExpandResolveRect { } WGPURenderPassDescriptorExpandResolveRect WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_RENDER_PASS_DESCRIPTOR_EXPAND_RESOLVE_RECT_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderPassDescriptorExpandResolveRect, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_RenderPassDescriptorExpandResolveRect} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_RenderPassDescriptorExpandResolveRect} WGPU_COMMA \ /*.x=*/{} WGPU_COMMA \ /*.y=*/{} WGPU_COMMA \ /*.width=*/{} WGPU_COMMA \ @@ -1823,22 +1814,10 @@ typedef struct WGPURenderPassMaxDrawCount { } WGPURenderPassMaxDrawCount WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_RENDER_PASS_MAX_DRAW_COUNT_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderPassMaxDrawCount, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_RenderPassMaxDrawCount} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_RenderPassMaxDrawCount} WGPU_COMMA \ /*.maxDrawCount=*/50000000 WGPU_COMMA \ }) -typedef struct WGPURenderPassTimestampWrites { - WGPUQuerySet querySet; - uint32_t beginningOfPassWriteIndex; - uint32_t endOfPassWriteIndex; -} WGPURenderPassTimestampWrites WGPU_STRUCTURE_ATTRIBUTE; - -#define WGPU_RENDER_PASS_TIMESTAMP_WRITES_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderPassTimestampWrites, { \ - /*.querySet=*/{} WGPU_COMMA \ - /*.beginningOfPassWriteIndex=*/WGPU_QUERY_SET_INDEX_UNDEFINED WGPU_COMMA \ - /*.endOfPassWriteIndex=*/WGPU_QUERY_SET_INDEX_UNDEFINED WGPU_COMMA \ -}) - typedef struct WGPURequestAdapterOptions { WGPUChainedStruct* nextInChain; WGPU_NULLABLE WGPUSurface compatibleSurface; @@ -1849,8 +1828,8 @@ typedef struct WGPURequestAdapterOptions { } WGPURequestAdapterOptions WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_REQUEST_ADAPTER_OPTIONS_INIT WGPU_MAKE_INIT_STRUCT(WGPURequestAdapterOptions, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.compatibleSurface=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ + /*.compatibleSurface=*/NULL WGPU_COMMA \ /*.featureLevel=*/WGPUFeatureLevel_Core WGPU_COMMA \ /*.powerPreference=*/WGPUPowerPreference_Undefined WGPU_COMMA \ /*.backendType=*/WGPUBackendType_Undefined WGPU_COMMA \ @@ -1863,7 +1842,7 @@ typedef struct WGPUSamplerBindingLayout { } WGPUSamplerBindingLayout WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SAMPLER_BINDING_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUSamplerBindingLayout, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.type=*/WGPUSamplerBindingType_Filtering WGPU_COMMA \ }) @@ -1874,7 +1853,7 @@ typedef struct WGPUShaderModuleCompilationOptions { } WGPUShaderModuleCompilationOptions WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHADER_MODULE_COMPILATION_OPTIONS_INIT WGPU_MAKE_INIT_STRUCT(WGPUShaderModuleCompilationOptions, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_ShaderModuleCompilationOptions} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_ShaderModuleCompilationOptions} WGPU_COMMA \ /*.strictMath=*/{} WGPU_COMMA \ }) @@ -1886,7 +1865,7 @@ typedef struct 
WGPUShaderSourceSPIRV { } WGPUShaderSourceSPIRV WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHADER_SOURCE_SPIRV_INIT WGPU_MAKE_INIT_STRUCT(WGPUShaderSourceSPIRV, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_ShaderSourceSPIRV} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_ShaderSourceSPIRV} WGPU_COMMA \ /*.codeSize=*/{} WGPU_COMMA \ /*.code=*/{} WGPU_COMMA \ }) @@ -1900,7 +1879,7 @@ typedef struct WGPUSharedBufferMemoryBeginAccessDescriptor { } WGPUSharedBufferMemoryBeginAccessDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_BUFFER_MEMORY_BEGIN_ACCESS_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedBufferMemoryBeginAccessDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.initialized=*/{} WGPU_COMMA \ /*.fenceCount=*/0 WGPU_COMMA \ /*.fences=*/{} WGPU_COMMA \ @@ -1916,7 +1895,7 @@ typedef struct WGPUSharedBufferMemoryEndAccessState { } WGPUSharedBufferMemoryEndAccessState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_BUFFER_MEMORY_END_ACCESS_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedBufferMemoryEndAccessState, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.initialized=*/{} WGPU_COMMA \ /*.fenceCount=*/0 WGPU_COMMA \ /*.fences=*/{} WGPU_COMMA \ @@ -1930,7 +1909,7 @@ typedef struct WGPUSharedBufferMemoryProperties { } WGPUSharedBufferMemoryProperties WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_BUFFER_MEMORY_PROPERTIES_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedBufferMemoryProperties, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.usage=*/{} WGPU_COMMA \ /*.size=*/{} WGPU_COMMA \ }) @@ -1942,7 +1921,7 @@ typedef struct WGPUSharedFenceDXGISharedHandleDescriptor { } WGPUSharedFenceDXGISharedHandleDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_DXGI_SHARED_HANDLE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceDXGISharedHandleDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceDXGISharedHandleDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedFenceDXGISharedHandleDescriptor} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) @@ -1953,7 +1932,7 @@ typedef struct WGPUSharedFenceDXGISharedHandleExportInfo { } WGPUSharedFenceDXGISharedHandleExportInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_DXGI_SHARED_HANDLE_EXPORT_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceDXGISharedHandleExportInfo, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceDXGISharedHandleExportInfo} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedFenceDXGISharedHandleExportInfo} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) @@ -1964,7 +1943,7 @@ typedef struct WGPUSharedFenceMTLSharedEventDescriptor { } WGPUSharedFenceMTLSharedEventDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_MTL_SHARED_EVENT_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceMTLSharedEventDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceMTLSharedEventDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedFenceMTLSharedEventDescriptor} WGPU_COMMA \ /*.sharedEvent=*/{} WGPU_COMMA \ }) @@ -1975,7 +1954,7 @@ typedef struct WGPUSharedFenceMTLSharedEventExportInfo { } WGPUSharedFenceMTLSharedEventExportInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_MTL_SHARED_EVENT_EXPORT_INFO_INIT 
WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceMTLSharedEventExportInfo, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceMTLSharedEventExportInfo} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedFenceMTLSharedEventExportInfo} WGPU_COMMA \ /*.sharedEvent=*/{} WGPU_COMMA \ }) @@ -1985,7 +1964,7 @@ typedef struct WGPUSharedFenceExportInfo { } WGPUSharedFenceExportInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_EXPORT_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceExportInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.type=*/{} WGPU_COMMA \ }) @@ -1996,7 +1975,7 @@ typedef struct WGPUSharedFenceSyncFDDescriptor { } WGPUSharedFenceSyncFDDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_SYNC_FD_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceSyncFDDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceSyncFDDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedFenceSyncFDDescriptor} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) @@ -2007,7 +1986,7 @@ typedef struct WGPUSharedFenceSyncFDExportInfo { } WGPUSharedFenceSyncFDExportInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_SYNC_FD_EXPORT_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceSyncFDExportInfo, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceSyncFDExportInfo} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedFenceSyncFDExportInfo} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) @@ -2018,7 +1997,7 @@ typedef struct WGPUSharedFenceVkSemaphoreOpaqueFDDescriptor { } WGPUSharedFenceVkSemaphoreOpaqueFDDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_VK_SEMAPHORE_OPAQUE_FD_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceVkSemaphoreOpaqueFDDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceVkSemaphoreOpaqueFDDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedFenceVkSemaphoreOpaqueFDDescriptor} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) @@ -2029,7 +2008,7 @@ typedef struct WGPUSharedFenceVkSemaphoreOpaqueFDExportInfo { } WGPUSharedFenceVkSemaphoreOpaqueFDExportInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_VK_SEMAPHORE_OPAQUE_FD_EXPORT_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceVkSemaphoreOpaqueFDExportInfo, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceVkSemaphoreOpaqueFDExportInfo} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedFenceVkSemaphoreOpaqueFDExportInfo} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) @@ -2040,7 +2019,7 @@ typedef struct WGPUSharedFenceVkSemaphoreZirconHandleDescriptor { } WGPUSharedFenceVkSemaphoreZirconHandleDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_VK_SEMAPHORE_ZIRCON_HANDLE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceVkSemaphoreZirconHandleDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceVkSemaphoreZirconHandleDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedFenceVkSemaphoreZirconHandleDescriptor} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) @@ -2051,7 +2030,7 @@ typedef struct WGPUSharedFenceVkSemaphoreZirconHandleExportInfo { } WGPUSharedFenceVkSemaphoreZirconHandleExportInfo WGPU_STRUCTURE_ATTRIBUTE; #define 
WGPU_SHARED_FENCE_VK_SEMAPHORE_ZIRCON_HANDLE_EXPORT_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceVkSemaphoreZirconHandleExportInfo, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedFenceVkSemaphoreZirconHandleExportInfo} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedFenceVkSemaphoreZirconHandleExportInfo} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ }) @@ -2062,7 +2041,7 @@ typedef struct WGPUSharedTextureMemoryD3DSwapchainBeginState { } WGPUSharedTextureMemoryD3DSwapchainBeginState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_D3D_SWAPCHAIN_BEGIN_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryD3DSwapchainBeginState, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryD3DSwapchainBeginState} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryD3DSwapchainBeginState} WGPU_COMMA \ /*.isSwapchain=*/false WGPU_COMMA \ }) @@ -2074,7 +2053,7 @@ typedef struct WGPUSharedTextureMemoryDXGISharedHandleDescriptor { } WGPUSharedTextureMemoryDXGISharedHandleDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_DXGI_SHARED_HANDLE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryDXGISharedHandleDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryDXGISharedHandleDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryDXGISharedHandleDescriptor} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ /*.useKeyedMutex=*/{} WGPU_COMMA \ }) @@ -2086,7 +2065,7 @@ typedef struct WGPUSharedTextureMemoryEGLImageDescriptor { } WGPUSharedTextureMemoryEGLImageDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_EGL_IMAGE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryEGLImageDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryEGLImageDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryEGLImageDescriptor} WGPU_COMMA \ /*.image=*/{} WGPU_COMMA \ }) @@ -2097,7 +2076,7 @@ typedef struct WGPUSharedTextureMemoryIOSurfaceDescriptor { } WGPUSharedTextureMemoryIOSurfaceDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_IO_SURFACE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryIOSurfaceDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryIOSurfaceDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryIOSurfaceDescriptor} WGPU_COMMA \ /*.ioSurface=*/{} WGPU_COMMA \ }) @@ -2109,7 +2088,7 @@ typedef struct WGPUSharedTextureMemoryAHardwareBufferDescriptor { } WGPUSharedTextureMemoryAHardwareBufferDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_A_HARDWARE_BUFFER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryAHardwareBufferDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryAHardwareBufferDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryAHardwareBufferDescriptor} WGPU_COMMA \ /*.handle=*/{} WGPU_COMMA \ /*.useExternalFormat=*/{} WGPU_COMMA \ }) @@ -2124,7 +2103,7 @@ typedef struct WGPUSharedTextureMemoryBeginAccessDescriptor { } WGPUSharedTextureMemoryBeginAccessDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define 
WGPU_SHARED_TEXTURE_MEMORY_BEGIN_ACCESS_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryBeginAccessDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.concurrentRead=*/{} WGPU_COMMA \ /*.initialized=*/{} WGPU_COMMA \ /*.fenceCount=*/{} WGPU_COMMA \ @@ -2153,7 +2132,7 @@ typedef struct WGPUSharedTextureMemoryEndAccessState { } WGPUSharedTextureMemoryEndAccessState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_END_ACCESS_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryEndAccessState, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.initialized=*/{} WGPU_COMMA \ /*.fenceCount=*/{} WGPU_COMMA \ /*.fences=*/{} WGPU_COMMA \ @@ -2171,7 +2150,7 @@ typedef struct WGPUSharedTextureMemoryOpaqueFDDescriptor { } WGPUSharedTextureMemoryOpaqueFDDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_OPAQUE_FD_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryOpaqueFDDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryOpaqueFDDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryOpaqueFDDescriptor} WGPU_COMMA \ /*.vkImageCreateInfo=*/{} WGPU_COMMA \ /*.memoryFD=*/{} WGPU_COMMA \ /*.memoryTypeIndex=*/{} WGPU_COMMA \ @@ -2186,7 +2165,7 @@ typedef struct WGPUSharedTextureMemoryVkDedicatedAllocationDescriptor { } WGPUSharedTextureMemoryVkDedicatedAllocationDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_VK_DEDICATED_ALLOCATION_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryVkDedicatedAllocationDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryVkDedicatedAllocationDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryVkDedicatedAllocationDescriptor} WGPU_COMMA \ /*.dedicatedAllocation=*/{} WGPU_COMMA \ }) @@ -2198,7 +2177,7 @@ typedef struct WGPUSharedTextureMemoryVkImageLayoutBeginState { } WGPUSharedTextureMemoryVkImageLayoutBeginState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_VK_IMAGE_LAYOUT_BEGIN_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryVkImageLayoutBeginState, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryVkImageLayoutBeginState} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryVkImageLayoutBeginState} WGPU_COMMA \ /*.oldLayout=*/{} WGPU_COMMA \ /*.newLayout=*/{} WGPU_COMMA \ }) @@ -2211,7 +2190,7 @@ typedef struct WGPUSharedTextureMemoryVkImageLayoutEndState { } WGPUSharedTextureMemoryVkImageLayoutEndState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_VK_IMAGE_LAYOUT_END_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryVkImageLayoutEndState, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryVkImageLayoutEndState} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryVkImageLayoutEndState} WGPU_COMMA \ /*.oldLayout=*/{} WGPU_COMMA \ /*.newLayout=*/{} WGPU_COMMA \ }) @@ -2224,7 +2203,7 @@ typedef struct WGPUSharedTextureMemoryZirconHandleDescriptor { } WGPUSharedTextureMemoryZirconHandleDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_ZIRCON_HANDLE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryZirconHandleDescriptor, { \ - 
/*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryZirconHandleDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryZirconHandleDescriptor} WGPU_COMMA \ /*.memoryFD=*/{} WGPU_COMMA \ /*.allocationSize=*/{} WGPU_COMMA \ }) @@ -2237,7 +2216,7 @@ typedef struct WGPUStaticSamplerBindingLayout { } WGPUStaticSamplerBindingLayout WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_STATIC_SAMPLER_BINDING_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUStaticSamplerBindingLayout, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_StaticSamplerBindingLayout} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_StaticSamplerBindingLayout} WGPU_COMMA \ /*.sampler=*/{} WGPU_COMMA \ /*.sampledTextureBinding=*/WGPU_LIMIT_U32_UNDEFINED WGPU_COMMA \ }) @@ -2264,7 +2243,7 @@ typedef struct WGPUStorageTextureBindingLayout { } WGPUStorageTextureBindingLayout WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_STORAGE_TEXTURE_BINDING_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUStorageTextureBindingLayout, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.access=*/WGPUStorageTextureAccess_WriteOnly WGPU_COMMA \ /*.format=*/WGPUTextureFormat_Undefined WGPU_COMMA \ /*.viewDimension=*/WGPUTextureViewDimension_2D WGPU_COMMA \ @@ -2276,7 +2255,7 @@ typedef struct WGPUStringView { } WGPUStringView WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_STRING_VIEW_INIT WGPU_MAKE_INIT_STRUCT(WGPUStringView, { \ - /*.data=*/nullptr WGPU_COMMA \ + /*.data=*/NULL WGPU_COMMA \ /*.length=*/WGPU_STRLEN WGPU_COMMA \ }) @@ -2312,7 +2291,7 @@ typedef struct WGPUSurfaceCapabilities { } WGPUSurfaceCapabilities WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SURFACE_CAPABILITIES_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceCapabilities, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.usages=*/{} WGPU_COMMA \ /*.formatCount=*/{} WGPU_COMMA \ /*.formats=*/{} WGPU_COMMA \ @@ -2336,12 +2315,12 @@ typedef struct WGPUSurfaceConfiguration { } WGPUSurfaceConfiguration WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SURFACE_CONFIGURATION_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceConfiguration, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.device=*/{} WGPU_COMMA \ /*.format=*/{} WGPU_COMMA \ /*.usage=*/WGPUTextureUsage_RenderAttachment WGPU_COMMA \ /*.viewFormatCount=*/0 WGPU_COMMA \ - /*.viewFormats=*/nullptr WGPU_COMMA \ + /*.viewFormats=*/NULL WGPU_COMMA \ /*.alphaMode=*/WGPUCompositeAlphaMode_Auto WGPU_COMMA \ /*.width=*/{} WGPU_COMMA \ /*.height=*/{} WGPU_COMMA \ @@ -2355,7 +2334,7 @@ typedef struct WGPUSurfaceDescriptorFromWindowsCoreWindow { } WGPUSurfaceDescriptorFromWindowsCoreWindow WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SURFACE_DESCRIPTOR_FROM_WINDOWS_CORE_WINDOW_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceDescriptorFromWindowsCoreWindow, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SurfaceDescriptorFromWindowsCoreWindow} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SurfaceDescriptorFromWindowsCoreWindow} WGPU_COMMA \ /*.coreWindow=*/{} WGPU_COMMA \ }) @@ -2366,7 +2345,7 @@ typedef struct WGPUSurfaceDescriptorFromWindowsSwapChainPanel { } WGPUSurfaceDescriptorFromWindowsSwapChainPanel WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SURFACE_DESCRIPTOR_FROM_WINDOWS_SWAP_CHAIN_PANEL_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceDescriptorFromWindowsSwapChainPanel, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA 
/*.sType*/WGPUSType_SurfaceDescriptorFromWindowsSwapChainPanel} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SurfaceDescriptorFromWindowsSwapChainPanel} WGPU_COMMA \ /*.swapChainPanel=*/{} WGPU_COMMA \ }) @@ -2378,7 +2357,7 @@ typedef struct WGPUSurfaceSourceXCBWindow { } WGPUSurfaceSourceXCBWindow WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SURFACE_SOURCE_XCB_WINDOW_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceSourceXCBWindow, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceXCBWindow} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceXCBWindow} WGPU_COMMA \ /*.connection=*/{} WGPU_COMMA \ /*.window=*/{} WGPU_COMMA \ }) @@ -2390,7 +2369,7 @@ typedef struct WGPUSurfaceSourceAndroidNativeWindow { } WGPUSurfaceSourceAndroidNativeWindow WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SURFACE_SOURCE_ANDROID_NATIVE_WINDOW_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceSourceAndroidNativeWindow, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceAndroidNativeWindow} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceAndroidNativeWindow} WGPU_COMMA \ /*.window=*/{} WGPU_COMMA \ }) @@ -2401,7 +2380,7 @@ typedef struct WGPUSurfaceSourceMetalLayer { } WGPUSurfaceSourceMetalLayer WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SURFACE_SOURCE_METAL_LAYER_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceSourceMetalLayer, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceMetalLayer} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceMetalLayer} WGPU_COMMA \ /*.layer=*/{} WGPU_COMMA \ }) @@ -2413,7 +2392,7 @@ typedef struct WGPUSurfaceSourceWaylandSurface { } WGPUSurfaceSourceWaylandSurface WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SURFACE_SOURCE_WAYLAND_SURFACE_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceSourceWaylandSurface, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceWaylandSurface} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceWaylandSurface} WGPU_COMMA \ /*.display=*/{} WGPU_COMMA \ /*.surface=*/{} WGPU_COMMA \ }) @@ -2426,7 +2405,7 @@ typedef struct WGPUSurfaceSourceWindowsHWND { } WGPUSurfaceSourceWindowsHWND WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SURFACE_SOURCE_WINDOWS_HWND_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceSourceWindowsHWND, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceWindowsHWND} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceWindowsHWND} WGPU_COMMA \ /*.hinstance=*/{} WGPU_COMMA \ /*.hwnd=*/{} WGPU_COMMA \ }) @@ -2439,7 +2418,7 @@ typedef struct WGPUSurfaceSourceXlibWindow { } WGPUSurfaceSourceXlibWindow WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SURFACE_SOURCE_XLIB_WINDOW_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceSourceXlibWindow, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceXlibWindow} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SurfaceSourceXlibWindow} WGPU_COMMA \ /*.display=*/{} WGPU_COMMA \ /*.window=*/{} WGPU_COMMA \ }) @@ -2464,7 +2443,7 @@ typedef struct WGPUTextureBindingLayout { } WGPUTextureBindingLayout WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_TEXTURE_BINDING_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUTextureBindingLayout, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ 
/*.sampleType=*/WGPUTextureSampleType_Float WGPU_COMMA \ /*.viewDimension=*/WGPUTextureViewDimension_2D WGPU_COMMA \ /*.multisampled=*/false WGPU_COMMA \ @@ -2477,7 +2456,7 @@ typedef struct WGPUTextureBindingViewDimensionDescriptor { } WGPUTextureBindingViewDimensionDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_TEXTURE_BINDING_VIEW_DIMENSION_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUTextureBindingViewDimensionDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_TextureBindingViewDimensionDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_TextureBindingViewDimensionDescriptor} WGPU_COMMA \ /*.textureBindingViewDimension=*/WGPUTextureViewDimension_Undefined WGPU_COMMA \ }) @@ -2489,7 +2468,7 @@ typedef struct WGPUTextureDataLayout { } WGPUTextureDataLayout WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_TEXTURE_DATA_LAYOUT_INIT WGPU_MAKE_INIT_STRUCT(WGPUTextureDataLayout, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.offset=*/0 WGPU_COMMA \ /*.bytesPerRow=*/WGPU_COPY_STRIDE_UNDEFINED WGPU_COMMA \ /*.rowsPerImage=*/WGPU_COPY_STRIDE_UNDEFINED WGPU_COMMA \ @@ -2526,7 +2505,7 @@ typedef struct WGPUYCbCrVkDescriptor { } WGPUYCbCrVkDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_Y_CB_CR_VK_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUYCbCrVkDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_YCbCrVkDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_YCbCrVkDescriptor} WGPU_COMMA \ /*.vkFormat=*/0 WGPU_COMMA \ /*.vkYCbCrModel=*/0 WGPU_COMMA \ /*.vkYCbCrRange=*/0 WGPU_COMMA \ @@ -2563,7 +2542,7 @@ typedef struct WGPUAdapterInfo { } WGPUAdapterInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_ADAPTER_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUAdapterInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.vendor=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.architecture=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.device=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ @@ -2583,7 +2562,7 @@ typedef struct WGPUAdapterPropertiesMemoryHeaps { } WGPUAdapterPropertiesMemoryHeaps WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_ADAPTER_PROPERTIES_MEMORY_HEAPS_INIT WGPU_MAKE_INIT_STRUCT(WGPUAdapterPropertiesMemoryHeaps, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_AdapterPropertiesMemoryHeaps} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_AdapterPropertiesMemoryHeaps} WGPU_COMMA \ /*.heapCount=*/{} WGPU_COMMA \ /*.heapInfo=*/{} WGPU_COMMA \ }) @@ -2597,7 +2576,7 @@ typedef struct WGPUBindGroupDescriptor { } WGPUBindGroupDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_BIND_GROUP_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUBindGroupDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.layout=*/{} WGPU_COMMA \ /*.entryCount=*/{} WGPU_COMMA \ @@ -2615,7 +2594,7 @@ typedef struct WGPUBindGroupLayoutEntry { } WGPUBindGroupLayoutEntry WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_BIND_GROUP_LAYOUT_ENTRY_INIT WGPU_MAKE_INIT_STRUCT(WGPUBindGroupLayoutEntry, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.binding=*/{} WGPU_COMMA \ /*.visibility=*/{} WGPU_COMMA \ /*.buffer=*/WGPU_BUFFER_BINDING_LAYOUT_INIT WGPU_COMMA \ @@ -2643,7 +2622,7 @@ typedef struct WGPUBufferDescriptor { } WGPUBufferDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_BUFFER_DESCRIPTOR_INIT 
WGPU_MAKE_INIT_STRUCT(WGPUBufferDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.usage=*/{} WGPU_COMMA \ /*.size=*/{} WGPU_COMMA \ @@ -2656,7 +2635,7 @@ typedef struct WGPUCommandBufferDescriptor { } WGPUCommandBufferDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_COMMAND_BUFFER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUCommandBufferDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ }) @@ -2666,7 +2645,7 @@ typedef struct WGPUCommandEncoderDescriptor { } WGPUCommandEncoderDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_COMMAND_ENCODER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUCommandEncoderDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ }) @@ -2684,7 +2663,7 @@ typedef struct WGPUCompilationMessage { } WGPUCompilationMessage WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_COMPILATION_MESSAGE_INIT WGPU_MAKE_INIT_STRUCT(WGPUCompilationMessage, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.message=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.type=*/{} WGPU_COMMA \ /*.lineNum=*/{} WGPU_COMMA \ @@ -2699,13 +2678,13 @@ typedef struct WGPUCompilationMessage { typedef struct WGPUComputePassDescriptor { WGPUChainedStruct* nextInChain; WGPUStringView label; - WGPU_NULLABLE WGPUComputePassTimestampWrites const * timestampWrites; + WGPU_NULLABLE WGPUPassTimestampWrites const * timestampWrites; } WGPUComputePassDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_COMPUTE_PASS_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUComputePassDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ - /*.timestampWrites=*/nullptr WGPU_COMMA \ + /*.timestampWrites=*/NULL WGPU_COMMA \ }) typedef struct WGPUConstantEntry { @@ -2715,7 +2694,7 @@ typedef struct WGPUConstantEntry { } WGPUConstantEntry WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_CONSTANT_ENTRY_INIT WGPU_MAKE_INIT_STRUCT(WGPUConstantEntry, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.key=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.value=*/{} WGPU_COMMA \ }) @@ -2730,11 +2709,11 @@ typedef struct WGPUDawnCacheDeviceDescriptor { } WGPUDawnCacheDeviceDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_CACHE_DEVICE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnCacheDeviceDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnCacheDeviceDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_DawnCacheDeviceDescriptor} WGPU_COMMA \ /*.isolationKey=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ - /*.loadDataFunction=*/nullptr WGPU_COMMA \ - /*.storeDataFunction=*/nullptr WGPU_COMMA \ - /*.functionUserdata=*/nullptr WGPU_COMMA \ + /*.loadDataFunction=*/NULL WGPU_COMMA \ + /*.storeDataFunction=*/NULL WGPU_COMMA \ + /*.functionUserdata=*/NULL WGPU_COMMA \ }) // Can be chained in WGPUDawnFormatCapabilities @@ -2745,7 +2724,7 @@ typedef struct WGPUDawnDrmFormatCapabilities { } WGPUDawnDrmFormatCapabilities WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DAWN_DRM_FORMAT_CAPABILITIES_INIT WGPU_MAKE_INIT_STRUCT(WGPUDawnDrmFormatCapabilities, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_DawnDrmFormatCapabilities} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA 
/*.sType*/WGPUSType_DawnDrmFormatCapabilities} WGPU_COMMA \ /*.propertiesCount=*/{} WGPU_COMMA \ /*.properties=*/{} WGPU_COMMA \ }) @@ -2765,7 +2744,7 @@ typedef struct WGPUDepthStencilState { } WGPUDepthStencilState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DEPTH_STENCIL_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUDepthStencilState, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.format=*/{} WGPU_COMMA \ /*.depthWriteEnabled=*/WGPUOptionalBool_Undefined WGPU_COMMA \ /*.depthCompare=*/WGPUCompareFunction_Undefined WGPU_COMMA \ @@ -2785,7 +2764,7 @@ typedef struct WGPUEmscriptenSurfaceSourceCanvasHTMLSelector { } WGPUEmscriptenSurfaceSourceCanvasHTMLSelector WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_EMSCRIPTEN_SURFACE_SOURCE_CANVAS_HTML_SELECTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUEmscriptenSurfaceSourceCanvasHTMLSelector, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_EmscriptenSurfaceSourceCanvasHTMLSelector} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_EmscriptenSurfaceSourceCanvasHTMLSelector} WGPU_COMMA \ /*.selector=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ }) @@ -2807,15 +2786,15 @@ typedef struct WGPUExternalTextureDescriptor { } WGPUExternalTextureDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_EXTERNAL_TEXTURE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUExternalTextureDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.plane0=*/{} WGPU_COMMA \ - /*.plane1=*/nullptr WGPU_COMMA \ + /*.plane1=*/NULL WGPU_COMMA \ /*.cropOrigin=*/WGPU_ORIGIN_2D_INIT WGPU_COMMA \ /*.cropSize=*/WGPU_EXTENT_2D_INIT WGPU_COMMA \ /*.apparentSize=*/WGPU_EXTENT_2D_INIT WGPU_COMMA \ /*.doYuvToRgbConversionOnly=*/false WGPU_COMMA \ - /*.yuvToRgbConversionMatrix=*/nullptr WGPU_COMMA \ + /*.yuvToRgbConversionMatrix=*/NULL WGPU_COMMA \ /*.srcTransferFunctionParameters=*/{} WGPU_COMMA \ /*.dstTransferFunctionParameters=*/{} WGPU_COMMA \ /*.gamutConversionMatrix=*/{} WGPU_COMMA \ @@ -2851,7 +2830,7 @@ typedef struct WGPUImageCopyExternalTexture { } WGPUImageCopyExternalTexture WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_IMAGE_COPY_EXTERNAL_TEXTURE_INIT WGPU_MAKE_INIT_STRUCT(WGPUImageCopyExternalTexture, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.externalTexture=*/{} WGPU_COMMA \ /*.origin=*/WGPU_ORIGIN_3D_INIT WGPU_COMMA \ /*.naturalSize=*/WGPU_EXTENT_2D_INIT WGPU_COMMA \ @@ -2873,12 +2852,14 @@ typedef struct WGPUImageCopyTexture { typedef struct WGPUInstanceDescriptor { WGPUChainedStruct* nextInChain; - WGPUInstanceFeatures features; + WGPUInstanceCapabilities capabilities; + WGPUInstanceCapabilities features; } WGPUInstanceDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_INSTANCE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUInstanceDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.features=*/WGPU_INSTANCE_FEATURES_INIT WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ + /*.capabilities=*/WGPU_INSTANCE_CAPABILITIES_INIT WGPU_COMMA \ + /*.features=*/WGPU_INSTANCE_CAPABILITIES_INIT WGPU_COMMA \ }) typedef struct WGPUPipelineLayoutDescriptor { @@ -2890,10 +2871,10 @@ typedef struct WGPUPipelineLayoutDescriptor { } WGPUPipelineLayoutDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_PIPELINE_LAYOUT_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUPipelineLayoutDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ 
/*.bindGroupLayoutCount=*/{} WGPU_COMMA \ - /*.bindGroupLayouts=*/nullptr WGPU_COMMA \ + /*.bindGroupLayouts=*/NULL WGPU_COMMA \ /*.immediateDataRangeByteSize=*/0 WGPU_COMMA \ }) @@ -2906,7 +2887,7 @@ typedef struct WGPUPipelineLayoutPixelLocalStorage { } WGPUPipelineLayoutPixelLocalStorage WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_PIPELINE_LAYOUT_PIXEL_LOCAL_STORAGE_INIT WGPU_MAKE_INIT_STRUCT(WGPUPipelineLayoutPixelLocalStorage, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_PipelineLayoutPixelLocalStorage} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_PipelineLayoutPixelLocalStorage} WGPU_COMMA \ /*.totalPixelLocalStorageSize=*/{} WGPU_COMMA \ /*.storageAttachmentCount=*/0 WGPU_COMMA \ /*.storageAttachments=*/{} WGPU_COMMA \ @@ -2920,7 +2901,7 @@ typedef struct WGPUQuerySetDescriptor { } WGPUQuerySetDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_QUERY_SET_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUQuerySetDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.type=*/{} WGPU_COMMA \ /*.count=*/{} WGPU_COMMA \ @@ -2932,7 +2913,7 @@ typedef struct WGPUQueueDescriptor { } WGPUQueueDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_QUEUE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUQueueDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ }) @@ -2942,7 +2923,7 @@ typedef struct WGPURenderBundleDescriptor { } WGPURenderBundleDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_RENDER_BUNDLE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderBundleDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ }) @@ -2958,7 +2939,7 @@ typedef struct WGPURenderBundleEncoderDescriptor { } WGPURenderBundleEncoderDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_RENDER_BUNDLE_ENCODER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderBundleEncoderDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.colorFormatCount=*/{} WGPU_COMMA \ /*.colorFormats=*/{} WGPU_COMMA \ @@ -2979,10 +2960,10 @@ typedef struct WGPURenderPassColorAttachment { } WGPURenderPassColorAttachment WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_RENDER_PASS_COLOR_ATTACHMENT_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderPassColorAttachment, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ - /*.view=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ + /*.view=*/NULL WGPU_COMMA \ /*.depthSlice=*/WGPU_DEPTH_SLICE_UNDEFINED WGPU_COMMA \ - /*.resolveTarget=*/nullptr WGPU_COMMA \ + /*.resolveTarget=*/NULL WGPU_COMMA \ /*.loadOp=*/{} WGPU_COMMA \ /*.storeOp=*/{} WGPU_COMMA \ /*.clearValue=*/WGPU_COLOR_INIT WGPU_COMMA \ @@ -2998,7 +2979,7 @@ typedef struct WGPURenderPassStorageAttachment { } WGPURenderPassStorageAttachment WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_RENDER_PASS_STORAGE_ATTACHMENT_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderPassStorageAttachment, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.offset=*/0 WGPU_COMMA \ /*.storage=*/{} WGPU_COMMA \ /*.loadOp=*/{} WGPU_COMMA \ @@ -3012,7 +2993,7 @@ typedef struct WGPURequiredLimits { } WGPURequiredLimits WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_REQUIRED_LIMITS_INIT WGPU_MAKE_INIT_STRUCT(WGPURequiredLimits, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ 
/*.limits=*/WGPU_LIMITS_INIT WGPU_COMMA \ }) @@ -3032,7 +3013,7 @@ typedef struct WGPUSamplerDescriptor { } WGPUSamplerDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SAMPLER_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSamplerDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.addressModeU=*/WGPUAddressMode_ClampToEdge WGPU_COMMA \ /*.addressModeV=*/WGPUAddressMode_ClampToEdge WGPU_COMMA \ @@ -3052,7 +3033,7 @@ typedef struct WGPUShaderModuleDescriptor { } WGPUShaderModuleDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHADER_MODULE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUShaderModuleDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ }) @@ -3063,7 +3044,7 @@ typedef struct WGPUShaderSourceWGSL { } WGPUShaderSourceWGSL WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHADER_SOURCE_WGSL_INIT WGPU_MAKE_INIT_STRUCT(WGPUShaderSourceWGSL, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_ShaderSourceWGSL} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_ShaderSourceWGSL} WGPU_COMMA \ /*.code=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ }) @@ -3073,7 +3054,7 @@ typedef struct WGPUSharedBufferMemoryDescriptor { } WGPUSharedBufferMemoryDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_BUFFER_MEMORY_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedBufferMemoryDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ }) @@ -3083,7 +3064,7 @@ typedef struct WGPUSharedFenceDescriptor { } WGPUSharedFenceDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_FENCE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedFenceDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ }) @@ -3094,7 +3075,7 @@ typedef struct WGPUSharedTextureMemoryAHardwareBufferProperties { } WGPUSharedTextureMemoryAHardwareBufferProperties WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_A_HARDWARE_BUFFER_PROPERTIES_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryAHardwareBufferProperties, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryAHardwareBufferProperties} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryAHardwareBufferProperties} WGPU_COMMA \ /*.yCbCrInfo=*/WGPU_Y_CB_CR_VK_DESCRIPTOR_INIT WGPU_COMMA \ }) @@ -3104,7 +3085,7 @@ typedef struct WGPUSharedTextureMemoryDescriptor { } WGPUSharedTextureMemoryDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ }) @@ -3119,7 +3100,7 @@ typedef struct WGPUSharedTextureMemoryDmaBufDescriptor { } WGPUSharedTextureMemoryDmaBufDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_DMA_BUF_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryDmaBufDescriptor, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryDmaBufDescriptor} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_SharedTextureMemoryDmaBufDescriptor} WGPU_COMMA \ /*.size=*/WGPU_EXTENT_3D_INIT WGPU_COMMA \ /*.drmFormat=*/{} WGPU_COMMA 
\ /*.drmModifier=*/{} WGPU_COMMA \ @@ -3135,7 +3116,7 @@ typedef struct WGPUSharedTextureMemoryProperties { } WGPUSharedTextureMemoryProperties WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SHARED_TEXTURE_MEMORY_PROPERTIES_INIT WGPU_MAKE_INIT_STRUCT(WGPUSharedTextureMemoryProperties, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.usage=*/{} WGPU_COMMA \ /*.size=*/WGPU_EXTENT_3D_INIT WGPU_COMMA \ /*.format=*/{} WGPU_COMMA \ @@ -3147,7 +3128,7 @@ typedef struct WGPUSupportedLimits { } WGPUSupportedLimits WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SUPPORTED_LIMITS_INIT WGPU_MAKE_INIT_STRUCT(WGPUSupportedLimits, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.limits=*/WGPU_LIMITS_INIT WGPU_COMMA \ }) @@ -3157,7 +3138,7 @@ typedef struct WGPUSurfaceDescriptor { } WGPUSurfaceDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_SURFACE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUSurfaceDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ }) @@ -3175,7 +3156,7 @@ typedef struct WGPUTextureDescriptor { } WGPUTextureDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_TEXTURE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUTextureDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.usage=*/{} WGPU_COMMA \ /*.dimension=*/WGPUTextureDimension_2D WGPU_COMMA \ @@ -3184,7 +3165,7 @@ typedef struct WGPUTextureDescriptor { /*.mipLevelCount=*/1 WGPU_COMMA \ /*.sampleCount=*/1 WGPU_COMMA \ /*.viewFormatCount=*/0 WGPU_COMMA \ - /*.viewFormats=*/nullptr WGPU_COMMA \ + /*.viewFormats=*/NULL WGPU_COMMA \ }) typedef struct WGPUTextureViewDescriptor { @@ -3201,7 +3182,7 @@ typedef struct WGPUTextureViewDescriptor { } WGPUTextureViewDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_TEXTURE_VIEW_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUTextureViewDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.format=*/WGPUTextureFormat_Undefined WGPU_COMMA \ /*.dimension=*/WGPUTextureViewDimension_Undefined WGPU_COMMA \ @@ -3235,7 +3216,7 @@ typedef struct WGPUBindGroupLayoutDescriptor { } WGPUBindGroupLayoutDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_BIND_GROUP_LAYOUT_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUBindGroupLayoutDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.entryCount=*/{} WGPU_COMMA \ /*.entries=*/{} WGPU_COMMA \ @@ -3249,9 +3230,9 @@ typedef struct WGPUColorTargetState { } WGPUColorTargetState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_COLOR_TARGET_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUColorTargetState, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.format=*/{} WGPU_COMMA \ - /*.blend=*/nullptr WGPU_COMMA \ + /*.blend=*/NULL WGPU_COMMA \ /*.writeMask=*/WGPUColorWriteMask_All WGPU_COMMA \ }) @@ -3262,7 +3243,7 @@ typedef struct WGPUCompilationInfo { } WGPUCompilationInfo WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_COMPILATION_INFO_INIT WGPU_MAKE_INIT_STRUCT(WGPUCompilationInfo, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.messageCount=*/{} WGPU_COMMA \ /*.messages=*/{} WGPU_COMMA \ }) @@ -3276,7 +3257,7 @@ typedef struct WGPUComputeState { } WGPUComputeState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_COMPUTE_STATE_INIT 
WGPU_MAKE_INIT_STRUCT(WGPUComputeState, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.module=*/{} WGPU_COMMA \ /*.entryPoint=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.constantCount=*/0 WGPU_COMMA \ @@ -3295,11 +3276,11 @@ typedef struct WGPUDeviceDescriptor { } WGPUDeviceDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_DEVICE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUDeviceDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.requiredFeatureCount=*/0 WGPU_COMMA \ - /*.requiredFeatures=*/nullptr WGPU_COMMA \ - /*.requiredLimits=*/nullptr WGPU_COMMA \ + /*.requiredFeatures=*/NULL WGPU_COMMA \ + /*.requiredLimits=*/NULL WGPU_COMMA \ /*.defaultQueue=*/WGPU_QUEUE_DESCRIPTOR_INIT WGPU_COMMA \ /*.deviceLostCallbackInfo=*/{} WGPU_COMMA \ /*.uncapturedErrorCallbackInfo=*/{} WGPU_COMMA \ @@ -3312,17 +3293,17 @@ typedef struct WGPURenderPassDescriptor { WGPURenderPassColorAttachment const * colorAttachments; WGPU_NULLABLE WGPURenderPassDepthStencilAttachment const * depthStencilAttachment; WGPU_NULLABLE WGPUQuerySet occlusionQuerySet; - WGPU_NULLABLE WGPURenderPassTimestampWrites const * timestampWrites; + WGPU_NULLABLE WGPUPassTimestampWrites const * timestampWrites; } WGPURenderPassDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_RENDER_PASS_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderPassDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.colorAttachmentCount=*/{} WGPU_COMMA \ /*.colorAttachments=*/{} WGPU_COMMA \ - /*.depthStencilAttachment=*/nullptr WGPU_COMMA \ - /*.occlusionQuerySet=*/nullptr WGPU_COMMA \ - /*.timestampWrites=*/nullptr WGPU_COMMA \ + /*.depthStencilAttachment=*/NULL WGPU_COMMA \ + /*.occlusionQuerySet=*/NULL WGPU_COMMA \ + /*.timestampWrites=*/NULL WGPU_COMMA \ }) // Can be chained in WGPURenderPassDescriptor @@ -3334,7 +3315,7 @@ typedef struct WGPURenderPassPixelLocalStorage { } WGPURenderPassPixelLocalStorage WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_RENDER_PASS_PIXEL_LOCAL_STORAGE_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderPassPixelLocalStorage, { \ - /*.chain=*/{/*.nextInChain*/nullptr WGPU_COMMA /*.sType*/WGPUSType_RenderPassPixelLocalStorage} WGPU_COMMA \ + /*.chain=*/{/*.nextInChain*/NULL WGPU_COMMA /*.sType*/WGPUSType_RenderPassPixelLocalStorage} WGPU_COMMA \ /*.totalPixelLocalStorageSize=*/{} WGPU_COMMA \ /*.storageAttachmentCount=*/0 WGPU_COMMA \ /*.storageAttachments=*/{} WGPU_COMMA \ @@ -3351,7 +3332,7 @@ typedef struct WGPUVertexState { } WGPUVertexState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_VERTEX_STATE_INIT WGPU_MAKE_INIT_STRUCT(WGPUVertexState, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.module=*/{} WGPU_COMMA \ /*.entryPoint=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.constantCount=*/0 WGPU_COMMA \ @@ -3368,9 +3349,9 @@ typedef struct WGPUComputePipelineDescriptor { } WGPUComputePipelineDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_COMPUTE_PIPELINE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPUComputePipelineDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ - /*.layout=*/nullptr WGPU_COMMA \ + /*.layout=*/NULL WGPU_COMMA \ /*.compute=*/WGPU_COMPUTE_STATE_INIT WGPU_COMMA \ }) @@ -3385,7 +3366,7 @@ typedef struct WGPUFragmentState { } WGPUFragmentState WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_FRAGMENT_STATE_INIT 
WGPU_MAKE_INIT_STRUCT(WGPUFragmentState, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.module=*/{} WGPU_COMMA \ /*.entryPoint=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ /*.constantCount=*/0 WGPU_COMMA \ @@ -3406,20 +3387,28 @@ typedef struct WGPURenderPipelineDescriptor { } WGPURenderPipelineDescriptor WGPU_STRUCTURE_ATTRIBUTE; #define WGPU_RENDER_PIPELINE_DESCRIPTOR_INIT WGPU_MAKE_INIT_STRUCT(WGPURenderPipelineDescriptor, { \ - /*.nextInChain=*/nullptr WGPU_COMMA \ + /*.nextInChain=*/NULL WGPU_COMMA \ /*.label=*/WGPU_STRING_VIEW_INIT WGPU_COMMA \ - /*.layout=*/nullptr WGPU_COMMA \ + /*.layout=*/NULL WGPU_COMMA \ /*.vertex=*/WGPU_VERTEX_STATE_INIT WGPU_COMMA \ /*.primitive=*/WGPU_PRIMITIVE_STATE_INIT WGPU_COMMA \ - /*.depthStencil=*/nullptr WGPU_COMMA \ + /*.depthStencil=*/NULL WGPU_COMMA \ /*.multisample=*/WGPU_MULTISAMPLE_STATE_INIT WGPU_COMMA \ - /*.fragment=*/nullptr WGPU_COMMA \ + /*.fragment=*/NULL WGPU_COMMA \ }) +// WGPUComputePassTimestampWrites is deprecated. +// Use WGPUPassTimestampWrites instead. +typedef WGPUPassTimestampWrites WGPUComputePassTimestampWrites; + // WGPURenderPassDescriptorMaxDrawCount is deprecated. // Use WGPURenderPassMaxDrawCount instead. typedef WGPURenderPassMaxDrawCount WGPURenderPassDescriptorMaxDrawCount; +// WGPURenderPassTimestampWrites is deprecated. +// Use WGPUPassTimestampWrites instead. +typedef WGPUPassTimestampWrites WGPURenderPassTimestampWrites; + // WGPUShaderModuleSPIRVDescriptor is deprecated. // Use WGPUShaderSourceSPIRV instead. typedef WGPUShaderSourceSPIRV WGPUShaderModuleSPIRVDescriptor; @@ -3468,7 +3457,7 @@ typedef void (*WGPUProcAdapterInfoFreeMembers)( WGPUAdapterInfo value) WG typedef void (*WGPUProcAdapterPropertiesMemoryHeapsFreeMembers)( WGPUAdapterPropertiesMemoryHeaps value) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUInstance (*WGPUProcCreateInstance)( WGPU_NULLABLE WGPUInstanceDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcDawnDrmFormatCapabilitiesFreeMembers)( WGPUDawnDrmFormatCapabilities value) WGPU_FUNCTION_ATTRIBUTE; -typedef WGPUStatus (*WGPUProcGetInstanceFeatures)( WGPUInstanceFeatures * features) WGPU_FUNCTION_ATTRIBUTE; +typedef WGPUStatus (*WGPUProcGetInstanceCapabilities)( WGPUInstanceCapabilities * capabilities) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUProc (*WGPUProcGetProcAddress)( WGPUStringView procName) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcSharedBufferMemoryEndAccessStateFreeMembers)( WGPUSharedBufferMemoryEndAccessState value) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcSharedTextureMemoryEndAccessStateFreeMembers)( WGPUSharedTextureMemoryEndAccessState value) WGPU_FUNCTION_ATTRIBUTE; @@ -3609,7 +3598,6 @@ typedef void (*WGPUProcExternalTextureRelease)(WGPUExternalTexture externalTextu // Procs of Instance typedef WGPUSurface (*WGPUProcInstanceCreateSurface)(WGPUInstance instance, WGPUSurfaceDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -typedef size_t (*WGPUProcInstanceEnumerateWGSLLanguageFeatures)(WGPUInstance instance, WGPUWGSLFeatureName * features) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUStatus (*WGPUProcInstanceGetWGSLLanguageFeatures)(WGPUInstance instance, WGPUSupportedWGSLLanguageFeatures * features) WGPU_FUNCTION_ATTRIBUTE; typedef WGPUBool (*WGPUProcInstanceHasWGSLLanguageFeature)(WGPUInstance instance, WGPUWGSLLanguageFeatureName feature) WGPU_FUNCTION_ATTRIBUTE; typedef void (*WGPUProcInstanceProcessEvents)(WGPUInstance instance) WGPU_FUNCTION_ATTRIBUTE; @@ -3774,7 +3762,7 @@ WGPU_EXPORT void 
wgpuAdapterInfoFreeMembers(WGPUAdapterInfo value) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuAdapterPropertiesMemoryHeapsFreeMembers(WGPUAdapterPropertiesMemoryHeaps value) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUInstance wgpuCreateInstance(WGPU_NULLABLE WGPUInstanceDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuDawnDrmFormatCapabilitiesFreeMembers(WGPUDawnDrmFormatCapabilities value) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT WGPUStatus wgpuGetInstanceFeatures(WGPUInstanceFeatures * features) WGPU_FUNCTION_ATTRIBUTE; +WGPU_EXPORT WGPUStatus wgpuGetInstanceCapabilities(WGPUInstanceCapabilities * capabilities) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUProc wgpuGetProcAddress(WGPUStringView procName) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSharedBufferMemoryEndAccessStateFreeMembers(WGPUSharedBufferMemoryEndAccessState value) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuSharedTextureMemoryEndAccessStateFreeMembers(WGPUSharedTextureMemoryEndAccessState value) WGPU_FUNCTION_ATTRIBUTE; @@ -3915,7 +3903,6 @@ WGPU_EXPORT void wgpuExternalTextureRelease(WGPUExternalTexture externalTexture) // Methods of Instance WGPU_EXPORT WGPUSurface wgpuInstanceCreateSurface(WGPUInstance instance, WGPUSurfaceDescriptor const * descriptor) WGPU_FUNCTION_ATTRIBUTE; -WGPU_EXPORT size_t wgpuInstanceEnumerateWGSLLanguageFeatures(WGPUInstance instance, WGPUWGSLFeatureName * features) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUStatus wgpuInstanceGetWGSLLanguageFeatures(WGPUInstance instance, WGPUSupportedWGSLLanguageFeatures * features) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT WGPUBool wgpuInstanceHasWGSLLanguageFeature(WGPUInstance instance, WGPUWGSLLanguageFeatureName feature) WGPU_FUNCTION_ATTRIBUTE; WGPU_EXPORT void wgpuInstanceProcessEvents(WGPUInstance instance) WGPU_FUNCTION_ATTRIBUTE; diff --git a/third_party/local/.gitkeep b/third_party/local/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/third_party/local/WebGPU-distribution b/third_party/local/WebGPU-distribution deleted file mode 160000 index 1025b97..0000000 --- a/third_party/local/WebGPU-distribution +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1025b977e1927b6d0327e67352f90feb4bcf8274 From c3ee69b894816c641afabeb9be738d6d6fad6a4d Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Sun, 2 Feb 2025 11:07:57 -0500 Subject: [PATCH 39/44] move web build example to experimental due to emscripten's webgpu implementation lagging; a few readme tweaks --- README.md | 16 ++++------------ examples/README.md | 2 -- {examples => experimental}/web/CMakeLists.txt | 0 {examples => experimental}/web/Makefile | 0 experimental/web/README.md | 3 +++ {examples => experimental}/web/build/.gitkeep | 0 {examples => experimental}/web/custom_shell.html | 0 {examples => experimental}/web/run.cpp | 0 8 files changed, 7 insertions(+), 14 deletions(-) rename {examples => experimental}/web/CMakeLists.txt (100%) rename {examples => experimental}/web/Makefile (100%) create mode 100644 experimental/web/README.md rename {examples => experimental}/web/build/.gitkeep (100%) rename {examples => experimental}/web/custom_shell.html (100%) rename {examples => experimental}/web/run.cpp (100%) diff --git a/README.md b/README.md index 4b69bef..46340b7 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ GPU code in C++ projects and have it run on Nvidia, Intel, AMD, and other GPUs. The same C++ code can work on a wide variety of laptops, workstations, mobile devices or virtually any hardware with Vulkan, Metal, or DirectX support.
-## Technical Objectives: Lightweight, Fast Iteration, and Low Boilerplate +## Objectives: Lightweight, Fast Iteration, and Low Boilerplate With gpu.cpp we want to enable a high-leverage library for individual developers and researchers to incorporate GPU computation into programs relying on nothing more than a standard C++ compiler as tooling. Our goals are: @@ -189,7 +189,7 @@ illustrate how to use gpu.cpp as a library. After you have run `make` in the top-level directory which retrieves the prebuilt Dawn shared library, you can run each example by navigating to its directory and running `make` from the example's directory. -An example of tiled matrix multiplication is in [examples/matmul](https://github.com/AnswerDotAI/gpu.cpp/blob/main/examples/matmul/). This implements a WebGPU version of the first few kernels of Simon Boehm's [How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog](https://siboehm.com/articles/22/CUDA-MMM) post. It currently runs at ~ 2.5+ TFLOPs on a Macbook Pro M1 Max laptop, which has a theoretical peak of 10.4 TFLOPs. Contributions to optimize this further are welcome. +An example of tiled matrix multiplication is in [examples/matmul](https://github.com/AnswerDotAI/gpu.cpp/blob/main/examples/matmul/). This implements a WebGPU version of the first few kernels of Simon Boehm's [How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog](https://siboehm.com/articles/22/CUDA-MMM) post. It currently runs at ~ 3.5+ TFLOPs on a Macbook Pro M1 Max laptop. Contributions to optimize this further are welcome. A parallel physics simulation of an ensemble of double pendulums simulated in parallel with different initial conditions on the GPU is shown in [examples/physics](https://github.com/AnswerDotAI/gpu.cpp/tree/main/examples/physics). @@ -198,9 +198,7 @@ A parallel physics simulation of an ensemble of double pendulums simulated in pa physics example animated gif -We also show some examples of signed distance function computations, rendered in the terminal as ascii. A 3D SDF of spheres is shown in [examples/render](https://github.com/AnswerDotAI/gpu.cpp/tree/main/examples/render]) and a shadertoy-like live-reloading example is in [examples/shadertui](https://github.com/AnswerDotAI/gpu.cpp/tree/main/examples/shadertui). - -Interestingly, given a starting example, LLMs such as Claude 3.5 Sonnet can be quite capable at writing low-level WGSL code for you - the other shaders in the shadertui example are written by the LLM. +We also show some examples of signed distance function computations, rendered in the terminal as ascii. A 3D SDF of spheres is shown in [examples/render](https://github.com/AnswerDotAI/gpu.cpp/tree/main/examples/render) and a shadertoy-like live-reloading example is in [examples/shadertui](https://github.com/AnswerDotAI/gpu.cpp/tree/main/examples/shadertui).
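For readers new to signed distance functions, the idea is compact enough to sketch inline: an SDF maps a point to its distance from the nearest surface, negative inside and positive outside, and a renderer marches rays against that function. Below is a minimal CPU-side sketch of the sphere case; the `Vec3` type and `sdfSphere` helper are illustrative only, not part of the gpu.cpp API, and the WGSL in examples/render evaluates the equivalent distance expression per thread.

```cpp
#include <cmath>
#include <cstdio>

// Illustrative helper type -- not gpu.cpp API.
struct Vec3 { float x, y, z; };

// Signed distance from point p to a sphere of radius r centered at c:
// negative inside the sphere, zero on the surface, positive outside.
float sdfSphere(Vec3 p, Vec3 c, float r) {
  float dx = p.x - c.x, dy = p.y - c.y, dz = p.z - c.z;
  return std::sqrt(dx * dx + dy * dy + dz * dz) - r;
}

int main() {
  // A point 2 units from the origin lies 1 unit outside a unit sphere.
  std::printf("%f\n", sdfSphere({2.0f, 0.0f, 0.0f}, {0.0f, 0.0f, 0.0f}, 1.0f));
  return 0;
}
```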
shadertui example animated gif @@ -232,22 +230,16 @@ gpu.cpp lets us implement and drop-in any algorithm with fine-grained control of gpu.cpp is meant for developers with some familiarity with C++ and GPU programming. It is not a high-level numerical computing or machine learning framework or inference engine, though it can be used in support of such implementations. -Second, in spite of the name, WebGPU has native implementations decoupled from the web and the browser. gpu.cpp leverages WebGPU as a portable _native_ GPU API first and foremost, with the possibility of running in the browser being a convenient additional benefit in the future. - -If you find it counterintuitive, as many do, that WebGPU is a native technology and not just for the web, watch Elie Michel's excellent talk ["WebGPU is Not Just About the Web"](https://www.youtube.com/watch?v=qHrx41aOTUQ). +Second, in spite of the name, WebGPU has native implementations decoupled from the web and the browser. If you find it counterintuitive, watch Elie Michel's excellent talk ["WebGPU is Not Just About the Web"](https://www.youtube.com/watch?v=qHrx41aOTUQ). Finally, the focus of gpu.cpp is general-purpose GPU computation rather than rendering/graphics on the GPU, although it can be useful for offline rendering or video processing use cases. We may explore directions with graphics in the future, but for now our focus is GPU compute. ## Limitations and Upcoming Features -_API Improvements_ - gpu.cpp is a work-in-progress and there are many features and improvements to come. At this early stage, we expect the API design to evolve as we identify improvements / needs from use cases. In particular, the handling of structured parameters and asynchronous dispatch will undergo refinement and maturation in the short-term. - _Browser Targets_ - In spite of using WebGPU we haven't tested builds targeting the browser yet though this is a short-term priority. _Reusable Kernel Library_ - Currently the core library is strictly the operations and types for interfacing with the WebGPU API, with some specific use case example WGSL implementations in `examples/`. Over time, as kernel implementations mature we may migrate some of the reusable operations from specific examples into a small reusable kernel library. -_More Use Case Examples and Tests_ - Expect an iteration loop of use cases to design tweaks and improvements, which in turn make the use cases cleaner and easier to write. One short term use cases to flesh out the kernels from [llm.c](https://github.com/karpathy/llm.c) in WebGPU form. As these mature into a reusable kernel library, we hope to help realize the potential for WebGPU compute in AI. - ## Troubleshooting If you run into issues building the project, please open an issue. diff --git a/examples/README.md b/examples/README.md index b73de3b..bfd513e 100644 --- a/examples/README.md +++ b/examples/README.md @@ -18,7 +18,6 @@ directory of the repository. | [shadertui](shadertui) | An example of runtime live reloading of WGSL - demonstrated using a terminal shadertoy-like ascii rendering. | | [render](render) | GPU ascii rendering of a signed distance function for two rotating 3D spheres. | | [physics](physics) | Parallel physics simulation of a double pendulum with each thread starting at a different initial condition. | -| [web](web) | A minimal example of how to use gpu.cpp to build a WebAssembly module that runs in the browser.
Before building this example, make sure you've installed the emscripten sdk by following the [instructions here](https://emscripten.org/docs/getting_started/downloads.html) and run `source emsdk_env.sh` from the `emsdk/` directory that was created when you cloned the emscripten repository. | ## Advanced Examples @@ -27,4 +26,3 @@ directory of the repository. | [float16](float16) | Hello World example using the float16 WebGPU extension, instead of the default float32. | | [matmul](matmul) | Tiled matrix multiplication. | | [transpose](transpose) | Tiled matrix transpose. | -| [webgpu_from_scratch](webgpu_from_scratch) | A minimal from-scratch example of how to use WebGPU directly without this library. This is useful to understand the code internals of gpu.cpp. Note this takes a while to build as it compiles the WebGPU C API implementation. | diff --git a/examples/web/CMakeLists.txt b/experimental/web/CMakeLists.txt similarity index 100% rename from examples/web/CMakeLists.txt rename to experimental/web/CMakeLists.txt diff --git a/examples/web/Makefile b/experimental/web/Makefile similarity index 100% rename from examples/web/Makefile rename to experimental/web/Makefile diff --git a/experimental/web/README.md b/experimental/web/README.md new file mode 100644 index 0000000..6b4e11b --- /dev/null +++ b/experimental/web/README.md @@ -0,0 +1,3 @@ +Warning: web targets are not supported for now. + +We'll enable them and move this to examples/ once emscripten's WebGPU implementation catches up with the Dawn commit we're using. diff --git a/examples/web/build/.gitkeep b/experimental/web/build/.gitkeep similarity index 100% rename from examples/web/build/.gitkeep rename to experimental/web/build/.gitkeep diff --git a/examples/web/custom_shell.html b/experimental/web/custom_shell.html similarity index 100% rename from examples/web/custom_shell.html rename to experimental/web/custom_shell.html diff --git a/examples/web/run.cpp b/experimental/web/run.cpp similarity index 100% rename from examples/web/run.cpp rename to experimental/web/run.cpp From 3924552b206f35629e3992045e8b4c8a765b4a80 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Sun, 2 Feb 2025 11:09:58 -0500 Subject: [PATCH 40/44] update artifact link --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 964f345..40cc5cc 100644 --- a/setup.py +++ b/setup.py @@ -61,8 +61,8 @@ def download_dawn(os_name): "Linux": "third_party/lib/libwebgpu_dawn.so", } url_map = { - "macOS": "https://github.com/austinvhuang/dawn-artifacts/releases/download/0.2.0-pre2/libwebgpu_dawn.dylib", - "Linux": "https://github.com/austinvhuang/dawn-artifacts/releases/download/0.2.0-pre2/libwebgpu_dawn.so", + "macOS": "https://github.com/austinvhuang/dawn-artifacts/releases/download/0.2.0/libwebgpu_dawn.dylib", + "Linux": "https://github.com/austinvhuang/dawn-artifacts/releases/download/0.2.0/libwebgpu_dawn.so", } outfile = outfile_map.get(os_name) From a8a44d30bf74f83ebfc4b639dc5c0e18a84d3d92 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Sun, 2 Feb 2025 11:18:41 -0500 Subject: [PATCH 41/44] skip float16 targets in CI --- .github/workflows/build.yml | 2 +- Makefile | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 21eacea..74a5f88 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -30,7 +30,7 @@ jobs: sudo apt-get install -y libxrandr-dev - name: Build - run: make all + run: make all-portable - name: 
Run hello world run: make diff --git a/Makefile b/Makefile index 8e5d67b..434c4b3 100644 --- a/Makefile +++ b/Makefile @@ -69,6 +69,15 @@ all: dawnlib check-clang check-linux-vulkan lib pch cd examples/shadertui && make build/shadertui cd examples/transpose && make build/transpose +all-portable: dawnlib check-clang check-linux-vulkan lib pch + cd examples/gpu_puzzles && make build/gpu_puzzles + cd examples/hello_world && make build/hello_world + cd examples/matmul && export MATMUL_VERSION=9 && make build/matmul + cd examples/physics && make build/physics + cd examples/render && make build/render + cd examples/shadertui && make build/shadertui + cd examples/transpose && make build/transpose + # Test 16-bit floating point type test-half: dawnlib check-clang $(LIBSPEC) && clang++ -std=c++17 $(INCLUDES) numeric_types/half.cpp -L$(LIBDIR) -lwebgpu_dawn -ldl -o build/half && ./build/half From 7dc064ca5a33ac12aff59e1149da643882c42411 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Sun, 2 Feb 2025 11:47:19 -0500 Subject: [PATCH 42/44] test float16 in CI --- .github/workflows/build.yml | 2 +- Makefile | 9 --------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 74a5f88..21eacea 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -30,7 +30,7 @@ jobs: sudo apt-get install -y libxrandr-dev - name: Build - run: make all-portable + run: make all - name: Run hello world run: make diff --git a/Makefile b/Makefile index 434c4b3..8e5d67b 100644 --- a/Makefile +++ b/Makefile @@ -69,15 +69,6 @@ all: dawnlib check-clang check-linux-vulkan lib pch cd examples/shadertui && make build/shadertui cd examples/transpose && make build/transpose -all-portable: dawnlib check-clang check-linux-vulkan lib pch - cd examples/gpu_puzzles && make build/gpu_puzzles - cd examples/hello_world && make build/hello_world - cd examples/matmul && export MATMUL_VERSION=9 && make build/matmul - cd examples/physics && make build/physics - cd examples/render && make build/render - cd examples/shadertui && make build/shadertui - cd examples/transpose && make build/transpose - # Test 16-bit floating point type test-half: dawnlib check-clang $(LIBSPEC) && clang++ -std=c++17 $(INCLUDES) numeric_types/half.cpp -L$(LIBDIR) -lwebgpu_dawn -ldl -o build/half && ./build/half From 041d2fd1fcef84bd04294f1fd02fe6f495a644b2 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Fri, 7 Feb 2025 15:52:53 +0900 Subject: [PATCH 43/44] Fix pybind --- bindings/python/Makefile | 5 ++--- bindings/python/gpu_cpp.cpp | 12 ++++++------ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/bindings/python/Makefile b/bindings/python/Makefile index fa70468..78e0b58 100644 --- a/bindings/python/Makefile +++ b/bindings/python/Makefile @@ -10,15 +10,14 @@ else STDLIB := -stdlib=libc++ endif -FLAGS=-shared -fPIC -std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib -ldawn \ +FLAGS=-shared -fPIC -std=c++17 $(STDLIB) -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -L$(GPUCPP)/third_party/lib -lwebgpu_dawn \ `python3 -m pybind11 --includes` \ - `python3-config --include --ldflags --embed` + `python3-config --includes --ldflags` SUFFIX=$(shell $(PYTHON)-config --extension-suffix) gpu_cpp$(SUFFIX): gpu_cpp.cpp $(CXX) $(FLAGS) -o $@ $< - install_name_tool -change @rpath/libdawn.dylib $(LIBDIR)/libdawn.dylib gpu_cpp$(SUFFIX) test: test_gpu_cpp.py gpu_cpp$(SUFFIX) $(PYTHON) test_gpu_cpp.py diff --git a/bindings/python/gpu_cpp.cpp 
b/bindings/python/gpu_cpp.cpp index 51f7c9e..8bd762d 100644 --- a/bindings/python/gpu_cpp.cpp +++ b/bindings/python/gpu_cpp.cpp @@ -40,7 +40,7 @@ KernelCode* py_createKernelCode(const std::string &pData, size_t workgroupSize, return new KernelCode(pData, workgroupSize, (NumType)precision); } -Kernel* py_createKernel(Context *ctx, const KernelCode *code, +Kernel py_createKernel(Context *ctx, const KernelCode *code, // const Tensor *dataBindings, size_t numTensors, const py::list& dataBindings_py, // const size_t *viewOffsets, @@ -54,7 +54,7 @@ Kernel* py_createKernel(Context *ctx, const KernelCode *code, for (auto item : viewOffsets_py) { viewOffsets.push_back(item.cast<size_t>()); } - return new Kernel(createKernel(*ctx, *code, bindings.data(), bindings.size(), viewOffsets.data(), vector_to_shape(totalWorkgroups))); + return createKernel(*ctx, *code, bindings.data(), bindings.size(), viewOffsets.data(), vector_to_shape(totalWorkgroups)); } Tensor* py_createTensor(Context *ctx, const std::vector<size_t> &dims, int dtype) { @@ -82,9 +82,9 @@ struct GpuAsync { } }; -GpuAsync* py_dispatchKernel(Context *ctx, Kernel *kernel) { +GpuAsync* py_dispatchKernel(Context *ctx, Kernel kernel) { auto async = new GpuAsync(); - dispatchKernel(*ctx, *kernel, async->promise); + dispatchKernel(*ctx, kernel, async->promise); return async; } @@ -96,12 +96,12 @@ PYBIND11_MODULE(gpu_cpp, m) { m.doc() = "gpu.cpp plugin"; py::class_<Context>(m, "Context"); py::class_<Tensor>(m, "Tensor"); - py::class_<Kernel>(m, "Kernel"); + py::class_<Kernel, std::shared_ptr<Kernel>>(m, "Kernel"); py::class_<KernelCode>(m, "KernelCode"); py::class_<GpuAsync>(m, "GpuAsync"); m.def("create_context", &py_createContext, py::return_value_policy::take_ownership); m.def("create_tensor", &py_createTensor, py::return_value_policy::take_ownership); - m.def("create_kernel", &py_createKernel, py::return_value_policy::take_ownership); + m.def("create_kernel", &py_createKernel); m.def("create_kernel_code", &py_createKernelCode, py::return_value_policy::take_ownership); m.def("dispatch_kernel", &py_dispatchKernel, py::return_value_policy::take_ownership); m.def("wait", &py_wait, "Wait for GPU"); From 89f9097b2cf316733ed30607561ac473a770a2a3 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Fri, 7 Feb 2025 22:01:46 +0900 Subject: [PATCH 44/44] Fix haskell binding --- bindings/haskell/gpu-cpp.cabal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindings/haskell/gpu-cpp.cabal b/bindings/haskell/gpu-cpp.cabal index 39bab54..90cb4fa 100644 --- a/bindings/haskell/gpu-cpp.cabal +++ b/bindings/haskell/gpu-cpp.cabal @@ -26,7 +26,7 @@ library hs-source-dirs: src default-language: Haskell2010 ghc-options: -optcxx-std=c++17 - extra-libraries: dawn + extra-libraries: webgpu_dawn executable gpu-cpp import: warnings
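A note of context for the pybind fix in PATCH 43: when a bound type is a reference-counted handle returned by value, as `Kernel` is here, declaring `std::shared_ptr` as the pybind11 holder lets Python share ownership with C++ rather than adopting a raw pointer, which is also why `create_kernel` drops `py::return_value_policy::take_ownership`. Below is a minimal self-contained sketch of the holder pattern; the `Handle` and `make_handle` names are hypothetical stand-ins, not gpu.cpp types.

```cpp
#include <memory>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical reference-counted resource, standing in for a kernel handle.
struct Handle {
  int id = 42;
};

// Factory returning shared ownership by value, loosely mirroring a
// createKernel-style API that returns a handle rather than a raw pointer.
std::shared_ptr<Handle> make_handle() { return std::make_shared<Handle>(); }

int get_id(const std::shared_ptr<Handle> &h) { return h->id; }

PYBIND11_MODULE(holder_demo, m) {
  // With shared_ptr declared as the holder, pybind11 shares ownership with
  // C++; the default return_value_policy::automatic is then correct, and
  // take_ownership would be wrong for a value the binding does not solely own.
  py::class_<Handle, std::shared_ptr<Handle>>(m, "Handle");
  m.def("make_handle", &make_handle);
  m.def("get_id", &get_id);
}
```

From Python, `holder_demo.get_id(holder_demo.make_handle())` returns 42, and the underlying object is freed only once both the Python wrapper and any C++ copies release their references.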