From f3e0dbca692d6d3aee0323854256bfa408a231bf Mon Sep 17 00:00:00 2001
From: Junji Hashimoto
Date: Wed, 30 Oct 2024 17:48:42 +0900
Subject: [PATCH 1/3] Add summation kernels

---
 experimental/kernels/Makefile   |   4 +
 experimental/kernels/kernels.h  |  72 ++++++
 experimental/kernels/reduce.cpp | 415 ++++++++++++++++++++++++++++++++
 gpu.hpp                         |   8 +
 4 files changed, 499 insertions(+)
 create mode 100644 experimental/kernels/reduce.cpp

diff --git a/experimental/kernels/Makefile b/experimental/kernels/Makefile
index c233ef5..90da2fe 100644
--- a/experimental/kernels/Makefile
+++ b/experimental/kernels/Makefile
@@ -29,6 +29,10 @@ endif
 
 default: run-native
 
+build/reduce: reduce.cpp kernels.h
+	$(CC) $(CFLAGS) $(CXXFLAGS) $(LDFLAGS) -o $@ $<
+	$(LIBSPEC) && build/reduce
+
 run_llm.c: ./build/test_gpt2 dawnlib
 	$(LIBSPEC) && $<
 
diff --git a/experimental/kernels/kernels.h b/experimental/kernels/kernels.h
index 212c075..b8a08f8 100644
--- a/experimental/kernels/kernels.h
+++ b/experimental/kernels/kernels.h
@@ -683,6 +683,78 @@ fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
 }
 )";
 
+static const char *kSum = R"(
+@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
+var<workgroup> buffer: array<{{precision}}, 1024>;
+@compute @workgroup_size({{workgroupSize}})
+fn main(
+    @builtin(global_invocation_id) globalID : vec3<u32>,
+    @builtin(local_invocation_id) localID : vec3<u32>,
+    @builtin(workgroup_id) groupid : vec3<u32>,
+    @builtin(num_workgroups) numGroups : vec3<u32>) {
+  let blockSize3d: vec3<u32> = vec3<u32>({{workgroupSize}});
+  let blockSize: u32 = blockSize3d.x;
+  let threadId: u32 = localID.x;
+  let blockId: u32 = groupid.x + groupid.y * numGroups.x;
+  let blockStart = blockId * blockSize * 2 + threadId;
+
+  buffer[threadId] = inp[blockStart] + inp[blockStart + blockSize];
+  workgroupBarrier();
+  var stride: u32 = blockSize / 2;
+
+  if (blockSize >= 1024 && threadId < 512) {
+    buffer[threadId] += buffer[threadId + 512];
+  }
+  workgroupBarrier();
+
+  if (blockSize >= 512 && threadId < 256) {
+    buffer[threadId] += buffer[threadId + 256];
+  }
+  workgroupBarrier();
+
+  if (blockSize >= 256 && threadId < 128) {
+    buffer[threadId] += buffer[threadId + 128];
+  }
+  workgroupBarrier();
+
+  if (threadId < 64) {
+    buffer[threadId] += buffer[threadId + 64];
+  }
+  workgroupBarrier();
+
+  if (threadId < 32) {
+    buffer[threadId] += buffer[threadId + 32];
+  }
+  workgroupBarrier();
+
+  if (threadId < 16) {
+    buffer[threadId] += buffer[threadId + 16];
+  }
+  workgroupBarrier();
+
+  if (threadId < 8) {
+    buffer[threadId] += buffer[threadId + 8];
+  }
+  workgroupBarrier();
+
+  if (threadId < 4) {
+    buffer[threadId] += buffer[threadId + 4];
+  }
+  workgroupBarrier();
+
+  if (threadId < 2) {
+    buffer[threadId] += buffer[threadId + 2];
+  }
+  workgroupBarrier();
+
+  if (threadId == 0) {
+    buffer[0] += buffer[1];
+    out[blockId] = buffer[0];
+  }
+}
+)";
+
 } // namespace gpu
 
 #endif // KERNELS_H
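Note: kSum is the classic two-elements-per-thread shared-memory tree
reduction, so each workgroup of {{workgroupSize}} threads collapses a
contiguous block of 2 * workgroupSize inputs into one partial sum at
out[blockId]. As a reading aid, here is a minimal CPU model of what one
workgroup computes (the function name is mine, not part of the patch):

    #include <cstddef>
    #include <vector>

    // Serial model of kSum: blockSize "threads" reduce 2 * blockSize
    // contiguous floats starting at blockStart into one partial sum.
    float workgroupSum(const float *inp, size_t blockStart, size_t blockSize) {
      std::vector<float> buffer(blockSize);
      for (size_t t = 0; t < blockSize; ++t)   // parallel load + first add
        buffer[t] = inp[blockStart + t] + inp[blockStart + t + blockSize];
      for (size_t stride = blockSize / 2; stride > 0; stride /= 2)
        for (size_t t = 0; t < stride; ++t)    // threads t < stride stay active
          buffer[t] += buffer[t + stride];
      return buffer[0];                        // thread 0 writes out[blockId]
    }

The unrolled if (threadId < N) ladder in the WGSL is this loop with the
iteration count fixed for blockSize <= 1024; the workgroupBarrier() calls
between steps are what the serial model hides.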
diff --git a/experimental/kernels/reduce.cpp b/experimental/kernels/reduce.cpp
new file mode 100644
index 0000000..13c6c40
--- /dev/null
+++ b/experimental/kernels/reduce.cpp
@@ -0,0 +1,415 @@
+#include "gpu.hpp"
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <cstdio>
+#include <future>
+#include <memory>
+#include <random>
+#include "utils/array_utils.hpp" // show, isclose, randn, randint
+#include "kernels.h"
+
+using namespace gpu;
+
+#define LIMITS { \
+    .nextInChain = nullptr, \
+    .limits = { \
+      .maxTextureDimension1D=8192, \
+      .maxTextureDimension2D=8192, \
+      .maxTextureDimension3D=2048, \
+      .maxTextureArrayLayers=256, \
+      .maxBindGroups=4, \
+      .maxBindGroupsPlusVertexBuffers=24, \
+      .maxBindingsPerBindGroup=1000, \
+      .maxDynamicUniformBuffersPerPipelineLayout=8, \
+      .maxDynamicStorageBuffersPerPipelineLayout=4, \
+      .maxSampledTexturesPerShaderStage=16, \
+      .maxSamplersPerShaderStage=16, \
+      .maxStorageBuffersPerShaderStage=8, \
+      .maxStorageTexturesPerShaderStage=4, \
+      .maxUniformBuffersPerShaderStage=12, \
+      .maxUniformBufferBindingSize=65536, \
+      .maxStorageBufferBindingSize=1073741824, \
+      .minUniformBufferOffsetAlignment=256, \
+      .minStorageBufferOffsetAlignment=256, \
+      .maxVertexBuffers=8, \
+      .maxBufferSize=0x80000000, \
+      .maxVertexAttributes=16, \
+      .maxVertexBufferArrayStride=2048, \
+      .maxInterStageShaderComponents=64, \
+      .maxInterStageShaderVariables=16, \
+      .maxColorAttachments=8, \
+      .maxColorAttachmentBytesPerSample=32, \
+      .maxComputeWorkgroupStorageSize=16384, \
+      .maxComputeInvocationsPerWorkgroup=1024, \
+      .maxComputeWorkgroupSizeX=1024, \
+      .maxComputeWorkgroupSizeY=1024, \
+      .maxComputeWorkgroupSizeZ=64, \
+      .maxComputeWorkgroupsPerDimension=65535 \
+    } \
+  }
+
+struct DurationTime {
+  std::chrono::high_resolution_clock::time_point start;
+  std::chrono::high_resolution_clock::time_point end;
+  std::chrono::microseconds duration;
+  std::string src;
+  bool verbose;
+  int num;
+
+  inline DurationTime(const std::string& src, bool verbose = true, int num = 1) {
+    this->src = src;
+    this->verbose = verbose;
+    this->num = num;
+    start = std::chrono::high_resolution_clock::now();
+  }
+
+  inline ~DurationTime() {
+    end = std::chrono::high_resolution_clock::now();
+    duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
+    if (this->verbose) {
+      printf("Duration(%s): %.1f microseconds\n", src.c_str(), static_cast<double>(duration.count()) / static_cast<double>(num));
+    }
+  }
+};
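Note: DurationTime is a small RAII timer: the constructor stamps the clock
and the destructor prints the elapsed time divided by num, so wrapping a
loop of num iterations in a scope yields a per-iteration figure. A minimal
usage sketch (not itself part of the patch):

    {
      int nIter = 100;
      DurationTime dt("GPU", /*verbose=*/true, nIter); // starts the clock
      for (int t = 0; t < nIter; t++) {
        // ... dispatch work ...
      }
    } // dt's destructor prints: Duration(GPU): <total/nIter> microseconds

This is exactly the pattern sum_gpu below uses to time its 100 dispatch
rounds.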
+
+static const char *kSumVersion1 = R"(
+@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
+var<workgroup> buffer: array<{{precision}}, 1024>;
+@compute @workgroup_size({{workgroupSize}})
+fn main(
+    @builtin(local_invocation_id) localID : vec3<u32>,
+    @builtin(workgroup_id) groupid : vec3<u32>,
+    @builtin(num_workgroups) numGroups : vec3<u32>) {
+  let blockSize3d: vec3<u32> = vec3<u32>({{workgroupSize}});
+  let blockSize: u32 = blockSize3d.x;
+  let threadId: u32 = localID.x;
+  let blockId: u32 = groupid.x + groupid.y * numGroups.x;
+  let blockStart = blockId * blockSize * 2 + threadId;
+
+  buffer[threadId] = inp[blockStart] + inp[blockStart + blockSize];
+  workgroupBarrier();
+
+  for (var stride: u32 = blockSize / 2; stride > 0; stride /= 2) {
+    if (threadId < stride) {
+      buffer[threadId] += buffer[threadId + stride];
+    }
+    workgroupBarrier();
+  }
+
+  if (threadId == 0) {
+    out[blockId] = buffer[0];
+  }
+}
+)";
+
+static const char *kSumVersion2 = R"(
+@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
+var<workgroup> buffer: array<{{precision}}, 1024>;
+@compute @workgroup_size({{workgroupSize}})
+fn main(
+    @builtin(global_invocation_id) globalID : vec3<u32>,
+    @builtin(local_invocation_id) localID : vec3<u32>,
+    @builtin(workgroup_id) groupid : vec3<u32>,
+    @builtin(num_workgroups) numGroups : vec3<u32>) {
+  let blockSize3d: vec3<u32> = vec3<u32>({{workgroupSize}});
+  let blockSize: u32 = blockSize3d.x;
+  let threadId: u32 = localID.x;
+  let blockId: u32 = groupid.x + groupid.y * numGroups.x;
+  let n: u32 = arrayLength(&inp);
+  let blockStart = blockId * blockSize * 2 + threadId;
+
+  buffer[threadId] = inp[blockStart] + inp[blockStart + blockSize];
+  workgroupBarrier();
+  var stride: u32 = blockSize / 2;
+
+  if (threadId < stride) {
+    buffer[threadId] += buffer[threadId + stride];
+  }
+  workgroupBarrier();
+
+  stride /= 2; // 1/4
+  if (threadId < stride) {
+    buffer[threadId] += buffer[threadId + stride];
+  }
+  workgroupBarrier();
+
+  stride /= 2; // 1/8
+  if (threadId < stride) {
+    buffer[threadId] += buffer[threadId + stride];
+  }
+  workgroupBarrier();
+
+  stride /= 2; // 1/16
+  if (threadId < stride) {
+    buffer[threadId] += buffer[threadId + stride];
+  }
+  workgroupBarrier();
+
+  stride /= 2; // 1/32
+  if (threadId < stride) {
+    buffer[threadId] += buffer[threadId + stride];
+  }
+  workgroupBarrier();
+
+  stride /= 2; // 1/64
+  if (threadId < stride) {
+    buffer[threadId] += buffer[threadId + stride];
+  }
+  workgroupBarrier();
+
+  stride /= 2; // 1/128
+  if (threadId < stride) {
+    buffer[threadId] += buffer[threadId + stride];
+  }
+  workgroupBarrier();
+
+  stride /= 2; // 1/256
+  if (threadId < stride) {
+    buffer[threadId] += buffer[threadId + stride];
+  }
+  workgroupBarrier();
+
+  stride /= 2; // 1/512
+  if (threadId < stride) {
+    buffer[threadId] += buffer[threadId + stride];
+  }
+  workgroupBarrier();
+
+  stride /= 2; // 1/1024
+  if (threadId < stride) {
+    buffer[threadId] += buffer[threadId + stride];
+  }
+
+  if (threadId == 0) {
+    out[blockId] = buffer[0];
+  }
+}
+)";
+
+static const char *kSum2d = R"(
+@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
+@group(0) @binding(2) var<uniform> params : Params;
+struct Params {
+    N: u32,
+    C: u32,
+};
+var<workgroup> buffer: array<{{precision}}, 1024>;
+@compute @workgroup_size({{workgroupSize}})
+fn main(
+    @builtin(global_invocation_id) globalID : vec3<u32>,
+    @builtin(local_invocation_id) localID : vec3<u32>,
+    @builtin(workgroup_id) groupid : vec3<u32>,
+    @builtin(num_workgroups) numGroups : vec3<u32>) {
+  let blockSize3d: vec3<u32> = vec3<u32>({{workgroupSize}});
+  let blockSize: u32 = blockSize3d.x;
+  let threadId: u32 = localID.x;
+  let blockId: u32 = groupid.x + groupid.y * numGroups.x;
+  let blockStart = blockId * blockSize * 2 + threadId;
+
+  buffer[threadId] = inp[blockStart] + inp[blockStart + blockSize];
+  workgroupBarrier();
+  var stride: u32 = blockSize / 2;
+
+  if (blockSize >= 1024 && threadId < 512) {
+    buffer[threadId] += buffer[threadId + 512];
+  }
+  workgroupBarrier();
+
+  if (blockSize >= 512 && threadId < 256) {
+    buffer[threadId] += buffer[threadId + 256];
+  }
+  workgroupBarrier();
+
+  if (blockSize >= 256 && threadId < 128) {
+    buffer[threadId] += buffer[threadId + 128];
+  }
+  workgroupBarrier();
+
+  if (threadId < 64) {
+    buffer[threadId] += buffer[threadId + 64];
+  }
+  workgroupBarrier();
+
+  if (threadId < 32) {
+    buffer[threadId] += buffer[threadId + 32];
+  }
+  workgroupBarrier();
+
+  if (threadId < 16) {
+    buffer[threadId] += buffer[threadId + 16];
+  }
+  workgroupBarrier();
+
+  if (threadId < 8) {
+    buffer[threadId] += buffer[threadId + 8];
+  }
+  workgroupBarrier();
+
+  if (threadId < 4) {
+    buffer[threadId] += buffer[threadId + 4];
+  }
+  workgroupBarrier();
+
+  if (threadId < 2) {
+    buffer[threadId] += buffer[threadId + 2];
+  }
+  workgroupBarrier();
+
+  if (threadId == 0) {
+    buffer[0] += buffer[1];
+    out[blockId] = buffer[0];
+  }
+}
+)";
+
+float sum_cpu(const float* data, size_t size) {
+  float result = 0;
+  for (size_t i = 0; i < size; ++i) {
+    result += data[i];
+  }
+  return result;
+}
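Note: because each kSum dispatch folds 2 * workgroupSize = 2048 consecutive
inputs into one output per workgroup, reducing a large array to a scalar
takes a short chain of dispatches. A back-of-the-envelope schedule for the
8192 x 8192 = 67,108,864 element test below (my arithmetic, not code from
the patch):

    pass 1: 67,108,864 inputs -> 32,768 partial sums
    pass 2:     32,768 inputs ->     16 partial sums
    pass 3:         16 inputs ->      1 sum

createSumKernel below sizes the grid for one such pass, and sum_gpu chains
the passes by ping-ponging through a list of intermediate tensors.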
+Kernel createSumKernel(Context& ctx, Tensor& input, Tensor& output, size_t size) {
+  uint32_t num_threads = 1024;
+  uint32_t num_blocks = ((size + num_threads -1) / num_threads);
+  uint32_t size_x = 32768u < num_blocks ? 32768u : num_blocks;
+  uint32_t size_y = size_x == 32768u ? num_blocks / 32768u : 1;
+  size_x /= 2;
+  size_x = size_x < 1 ? 1 : size_x;
+  // print size_x, size_y
+  // printf("size_x: %u, size_y: %u, num_blocks: %u\n", size_x, size_y, num_blocks);
+  return createKernel(ctx, {kSum, num_threads, kf32}, Bindings{input, output}, {size_x, size_y, 1});
+}
+
+float sum_gpu(Context& ctx, const float* data, const float* buffer, size_t size) {
+  WGPURequiredLimits requiredLimits = LIMITS;
+  uint32_t num_threads = 1024;
+  int nSum = round(log2(size) / log2(num_threads));
+  int input_size = size;
+  unsigned long output_size = size;
+  std::vector<Tensor> outputs;
+  std::vector<Kernel> ops;
+  outputs.push_back(createTensor(ctx, Shape{std::max(size, static_cast<size_t>(1024*2))}, kf32));
+  for(int i=size,j=0;i>0;i/=num_threads,j++){
+    output_size = (output_size + num_threads - 1) / num_threads;
+    outputs.push_back(createTensor(ctx, Shape{std::max(output_size, static_cast<unsigned long>(1024*2))}, kf32));
+    ops.push_back(createSumKernel(ctx, outputs[j], outputs[j+1], input_size));
+    // printf("size: %d\n", input_size);
+    input_size = output_size;
+  }
+  toGPU(ctx, data, outputs[0], size * sizeof(float));
+
+
+  {
+    for(int i=size,j=0;i>0;i/=num_threads,j++){
+      std::promise<void> promise;
+      std::future<void> future = promise.get_future();
+      dispatchKernel(ctx, ops[j], promise);
+      wait(ctx, future);
+      resetCommandBuffer(ctx.device, ops[j]);
+    }
+  }
+
+  {
+    int nIter = 100;
+    DurationTime dt("GPU", true, nIter);
+    for (int t = 0; t < nIter; t++){
+      for(int i=size,j=0;i>0;i/=num_threads,j++){
+        std::promise<void> promise;
+        std::future<void> future = promise.get_future();
+        dispatchKernel(ctx, ops[j], promise);
+        wait(ctx, future);
+        resetCommandBuffer(ctx.device, ops[j]);
+      }
+    }
+  }
+
+  float r = 0;
+  toCPU(ctx, outputs[outputs.size()-1], (void*)buffer, 4 * sizeof(float));
+
+  return buffer[0];
+}
+
+// float sum_gpu2d(Context& ctx, const float* data, const float* buffer, size_t size_x, size_t size_y) {
+//   WGPURequiredLimits requiredLimits = LIMITS;
+//   Tensor input = createTensor(ctx, Shape{size}, kf32, data);
+//   Tensor output = createTensor(ctx, Shape{size}, kf32);
+//   uint32_t num_threads = 1024;
+//   uint32_t num_blocks = ((size_x + num_threads -1) / num_threads);
+//   printf("size: %u, size_x: %u, size_y: %u\n", size, size_x, size_y);
+//   Kernel op = createKernel(ctx, {kSum, num_threads, kf32}, Bindings{input, output}, {size_x, size_y, 1});
+//
+//   {
+//     for (int i = 0; i < 100; ++i){
+//       DurationTime dt("GPU");
+//       std::promise<void> promise;
+//       std::future<void> future = promise.get_future();
+//       dispatchKernel(ctx, op, promise);
+//       wait(ctx, future);
+//       resetCommandBuffer(ctx.device, op);
+//     }
+//   }
+//
+//   float r = 0;
+//   toCPU(ctx, output, (void*)buffer, num_blocks * sizeof(float));
+//
+//   for (int i = 0; i < num_blocks; i++){
+//     r+=buffer[i];
+//   }
+//   return r;
+// }
+
+int main(int argc, char **argv) {
+  static constexpr size_t M = 4096*2;
+  static constexpr size_t N = 4096*2;
+  static constexpr size_t BUF_SIZE = 16;
+  std::unique_ptr<float[]> inputArr = std::make_unique<float[]>(M * N);
+  std::unique_ptr<float[]> buffer = std::make_unique<float[]>(BUF_SIZE);
+  std::mt19937 gen(314159);
+  printf("Initializing %zu values\n", M*N);
+  randn(inputArr.get(), M*N, gen);
+  // for(int i=0;i<M*N;i++){ inputArr[i] = 1.0f; }
+  float cpu_result;
+  float gpu_result;
+  WGPURequiredLimits requiredLimits = LIMITS;
+  gpu::Context ctx = gpu::createContext({}, {}, {
+      .requiredLimits = &requiredLimits
+    });
+  Tensor input = createTensor(ctx, Shape{M*N}, kf32, inputArr.get());
+
+  printf("Start testing sum(x) on %zu values\n", M*N);
+  cpu_result = sum_cpu(inputArr.get(), M*N);
+  gpu_result = sum_gpu(ctx, inputArr.get(), buffer.get(), M*N);
+  float diff = fabs(cpu_result - gpu_result);
+  if (diff >= 1e-0f) {
+    printf("Error: diff = %.6f\n", diff);
+  } else {
+    printf("Success: diff = %.6f\n", diff);
+  }
+
+  printf("Computed %zu values of kSum(x)\n\n", M*N);
+  return 0;
+}

diff --git a/gpu.hpp b/gpu.hpp
index 83fc94b..8047646 100644
--- a/gpu.hpp
+++ b/gpu.hpp
@@ -1119,6 +1119,14 @@ inline void toGPU(Context &ctx, const half *data, Tensor &tensor) {
                       tensor.data.size);
 }
 
+inline void toGPU(Context &ctx, const float *data, Tensor &tensor, size_t size) {
+  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data, size);
+}
+
+inline void toGPU(Context &ctx, const half *data, Tensor &tensor, size_t size) {
+  wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data, size);
+}
+
 template <typename Params>
 inline void toGPU(Context &ctx, Params &params, Kernel &op) {
   // TODO(avh): Maintain params metadata in Kernel and check for consistency.
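Note: the gpu.hpp change adds size-qualified variants of toGPU. Unlike the
existing overloads, which always copy tensor.data.size bytes, these let the
caller upload just the first size bytes into a tensor that may have been
over-allocated (reduce.cpp rounds every intermediate tensor up to at least
2048 elements). A minimal sketch of the intended call pattern (the tensor
and sizes are illustrative, not from the patch):

    size_t n = 1000;                                  // valid elements
    Tensor t = createTensor(ctx, Shape{2048}, kf32);  // padded allocation
    std::vector<float> host(n, 1.0f);
    toGPU(ctx, host.data(), t, n * sizeof(float));    // copy only n floats

Note the size argument is in bytes, mirroring wgpuQueueWriteBuffer.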
From f956f2b78b08bbfd0212efa53b3d1fd90d9b0941 Mon Sep 17 00:00:00 2001
From: Junji Hashimoto
Date: Thu, 31 Oct 2024 04:14:51 +0900
Subject: [PATCH 2/3] Add SumKernel

---
 experimental/kernels/reduce.cpp | 61 ++++++++++++++++++---------------
 1 file changed, 33 insertions(+), 28 deletions(-)

diff --git a/experimental/kernels/reduce.cpp b/experimental/kernels/reduce.cpp
index 13c6c40..e1bc387 100644
--- a/experimental/kernels/reduce.cpp
+++ b/experimental/kernels/reduce.cpp
@@ -285,51 +285,55 @@ Kernel createSumKernel(Context& ctx, Tensor& input, Tensor& output, size_t size)
   return createKernel(ctx, {kSum, num_threads, kf32}, Bindings{input, output}, {size_x, size_y, 1});
 }
 
-float sum_gpu(Context& ctx, const float* data, const float* buffer, size_t size) {
-  WGPURequiredLimits requiredLimits = LIMITS;
-  uint32_t num_threads = 1024;
-  int nSum = round(log2(size) / log2(num_threads));
-  int input_size = size;
-  unsigned long output_size = size;
+struct SumKernel {
   std::vector<Tensor> outputs;
   std::vector<Kernel> ops;
-  outputs.push_back(createTensor(ctx, Shape{std::max(size, static_cast<size_t>(1024*2))}, kf32));
-  for(int i=size,j=0;i>0;i/=num_threads,j++){
-    output_size = (output_size + num_threads - 1) / num_threads;
-    outputs.push_back(createTensor(ctx, Shape{std::max(output_size, static_cast<unsigned long>(1024*2))}, kf32));
-    ops.push_back(createSumKernel(ctx, outputs[j], outputs[j+1], input_size));
-    // printf("size: %d\n", input_size);
-    input_size = output_size;
-  }
-  toGPU(ctx, data, outputs[0], size * sizeof(float));
-
-
-  {
+  SumKernel(Context& ctx, size_t size) {
+    uint32_t num_threads = 1024;
+    int nSum = round(log2(size) / log2(num_threads));
+    int input_size = size;
+    unsigned long output_size = size;
+    outputs.push_back(createTensor(ctx, Shape{std::max(size, static_cast<size_t>(num_threads*2))}, kf32));
     for(int i=size,j=0;i>0;i/=num_threads,j++){
+      output_size = (output_size + num_threads - 1) / num_threads;
+      outputs.push_back(createTensor(ctx, Shape{std::max(output_size, static_cast<unsigned long>(num_threads*2))}, kf32));
+      ops.push_back(createSumKernel(ctx, outputs[j], outputs[j+1], input_size));
+      input_size = output_size;
+    }
+  }
+  void dispatchKernel(Context& ctx) {
+    for(int i=0;i<ops.size();i++){
       std::promise<void> promise;
       std::future<void> future = promise.get_future();
-      dispatchKernel(ctx, ops[j], promise);
+      gpu::dispatchKernel(ctx, ops[i], promise);
       wait(ctx, future);
-      resetCommandBuffer(ctx.device, ops[j]);
+      resetCommandBuffer(ctx.device, ops[i]);
     }
   }
+  void toGPU(Context& ctx, const float* data, size_t size) {
+    gpu::toGPU(ctx, data, outputs[0], size);
+  }
+  void toCPU(Context& ctx, float* data, size_t size) {
+    gpu::toCPU(ctx, outputs[outputs.size()-1], data, size);
+  }
+};
+
+float sum_gpu(Context& ctx, const float* data, float* buffer, size_t size) {
+  WGPURequiredLimits requiredLimits = LIMITS;
+  SumKernel sumKernel(ctx, size);
+  sumKernel.toGPU(ctx, data, size * sizeof(float));
+  sumKernel.dispatchKernel(ctx);
 
   {
     int nIter = 100;
     DurationTime dt("GPU", true, nIter);
     for (int t = 0; t < nIter; t++){
-      for(int i=size,j=0;i>0;i/=num_threads,j++){
-        std::promise<void> promise;
-        std::future<void> future = promise.get_future();
-        dispatchKernel(ctx, ops[j], promise);
-        wait(ctx, future);
-        resetCommandBuffer(ctx.device, ops[j]);
-      }
+      sumKernel.dispatchKernel(ctx);
     }
   }
 
   float r = 0;
-  toCPU(ctx, outputs[outputs.size()-1], (void*)buffer, 4 * sizeof(float));
+  sumKernel.toCPU(ctx, buffer, 4 * sizeof(float));
 
   return buffer[0];
 }
@@ -363,6 +367,7 @@ float sum_gpu(Context& ctx, const float* data, const float* buffer, size_t size)
 //   return r;
 // }
 
+
 int main(int argc, char **argv) {
   static constexpr size_t M = 4096*2;
   static constexpr size_t N = 4096*2;
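Note: with this refactor the tensor chain and kernel chain are built once in
the SumKernel constructor, so repeated reductions reuse the same pipeline.
A minimal usage sketch (context setup elided; names illustrative):

    SumKernel sum(ctx, n);                          // builds the dispatch chain
    sum.toGPU(ctx, host.data(), n * sizeof(float)); // upload input
    sum.dispatchKernel(ctx);                        // run all passes
    float result;
    sum.toCPU(ctx, &result, sizeof(float));         // read back the scalar

One design point worth flagging: the member functions shadow the free
functions gpu::dispatchKernel / gpu::toGPU / gpu::toCPU, which is why the
implementations qualify those calls with gpu::.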
From c13833fd07bac1659ef530eec4dd3558a56233d9 Mon Sep 17 00:00:00 2001
From: Junji Hashimoto
Date: Sun, 3 Nov 2024 16:13:33 +0900
Subject: [PATCH 3/3] Add SumKernel2d

---
 experimental/kernels/reduce.cpp | 280 +++++++++++++++++++++-----------
 1 file changed, 187 insertions(+), 93 deletions(-)

diff --git a/experimental/kernels/reduce.cpp b/experimental/kernels/reduce.cpp
index e1bc387..38cb6a7 100644
--- a/experimental/kernels/reduce.cpp
+++ b/experimental/kernels/reduce.cpp
@@ -199,68 +199,37 @@ struct Params {
 var<workgroup> buffer: array<{{precision}}, 1024>;
 @compute @workgroup_size({{workgroupSize}})
 fn main(
-    @builtin(global_invocation_id) globalID : vec3<u32>,
     @builtin(local_invocation_id) localID : vec3<u32>,
     @builtin(workgroup_id) groupid : vec3<u32>,
     @builtin(num_workgroups) numGroups : vec3<u32>) {
+  let N : u32 = params.N;
+  let C : u32 = params.C;
   let blockSize3d: vec3<u32> = vec3<u32>({{workgroupSize}});
   let blockSize: u32 = blockSize3d.x;
   let threadId: u32 = localID.x;
   let blockId: u32 = groupid.x + groupid.y * numGroups.x;
-  let blockStart = blockId * blockSize * 2 + threadId;
-
-  buffer[threadId] = inp[blockStart] + inp[blockStart + blockSize];
-  workgroupBarrier();
-  var stride: u32 = blockSize / 2;
-
-  if (blockSize >= 1024 && threadId < 512) {
-    buffer[threadId] += buffer[threadId + 512];
-  }
-  workgroupBarrier();
-
-  if (blockSize >= 512 && threadId < 256) {
-    buffer[threadId] += buffer[threadId + 256];
-  }
-  workgroupBarrier();
-
-  if (blockSize >= 256 && threadId < 128) {
-    buffer[threadId] += buffer[threadId + 128];
-  }
-  workgroupBarrier();
-
-  if (threadId < 64) {
-    buffer[threadId] += buffer[threadId + 64];
-  }
-  workgroupBarrier();
-
-  if (threadId < 32) {
-    buffer[threadId] += buffer[threadId + 32];
-  }
-  workgroupBarrier();
-
-  if (threadId < 16) {
-    buffer[threadId] += buffer[threadId + 16];
-  }
-  workgroupBarrier();
-
-  if (threadId < 8) {
-    buffer[threadId] += buffer[threadId + 8];
-  }
-  workgroupBarrier();
-
-  if (threadId < 4) {
-    buffer[threadId] += buffer[threadId + 4];
-  }
-  workgroupBarrier();
-
-  if (threadId < 2) {
-    buffer[threadId] += buffer[threadId + 2];
-  }
-  workgroupBarrier();
-
-  if (threadId == 0) {
-    buffer[0] += buffer[1];
-    out[blockId] = buffer[0];
+  for (var i: u32 = 0; i < C; i++) {
+    let blockStart = blockId * blockSize * 2 + threadId;
+    if (blockStart >= N) {
+      buffer[threadId] = 0.0;
+    } else if (blockStart + blockSize >= N) {
+      buffer[threadId] = inp[blockStart * C + i];
+    } else {
+      buffer[threadId] = inp[blockStart * C + i] + inp[(blockStart + blockSize) * C + i];
+    }
+    workgroupBarrier();
+
+    for (var stride: u32 = blockSize / 2; stride > 0; stride /= 2) {
+      if (threadId < stride) {
+        buffer[threadId] += buffer[threadId + stride];
+      }
+      workgroupBarrier();
+    }
+
+    if (threadId == 0) {
+      out[blockId * C + i] = buffer[0];
+    }
+    workgroupBarrier();
   }
 }
 )";
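Note: the reworked kSum2d reduces along the row dimension only. Logically
inp is an N x C row-major matrix; each workgroup sums a strip of up to
2 * blockSize rows for every column i, and the N-boundary guard zero-fills
lanes that fall past the last row so partial blocks stay correct. In index
terms (illustrative arithmetic, not patch code):

    addr(r, c) = r * C + c                 // row-major layout
    out[b * C + c] = sum over rows r covered by block b of inp[r * C + c]

so for C = 4096, a thread with blockStart = 2048 loads inp[2048 * 4096 + c]
and, only if row 2048 + blockSize is still below N, also adds
inp[(2048 + blockSize) * 4096 + c].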
@@ -273,33 +242,100 @@ float sum_cpu(const float* data, size_t size) {
   return result;
 }
 
+void sum_cpu_2d(const float* data, float* out, size_t size0, size_t size1) {
+  for (size_t j = 0; j < size1; ++j) {
+    out[j] = 0;
+  }
+  for (size_t i = 0; i < size0; ++i) {
+    for (size_t j = 0; j < size1; ++j) {
+      out[j] += data[(i * size1) + j];
+    }
+  }
+}
+
-Kernel createSumKernel(Context& ctx, Tensor& input, Tensor& output, size_t size) {
-  uint32_t num_threads = 1024;
+Kernel createSumKernel(Context& ctx, Tensor& input, Tensor& output, size_t size, uint32_t num_threads = 1024) {
   uint32_t num_blocks = ((size + num_threads -1) / num_threads);
   uint32_t size_x = 32768u < num_blocks ? 32768u : num_blocks;
   uint32_t size_y = size_x == 32768u ? num_blocks / 32768u : 1;
   size_x /= 2;
   size_x = size_x < 1 ? 1 : size_x;
   // print size_x, size_y
-  // printf("size_x: %u, size_y: %u, num_blocks: %u\n", size_x, size_y, num_blocks);
+  printf("size_x: %u, size_y: %u, num_blocks: %u\n", size_x, size_y, num_blocks);
   return createKernel(ctx, {kSum, num_threads, kf32}, Bindings{input, output}, {size_x, size_y, 1});
 }
 
+Kernel createSumKernel2d(Context& ctx, Tensor& input, Tensor& output, size_t size0, size_t size1, uint32_t num_threads = 1024) {
+  struct Params {
+    uint32_t N;
+    uint32_t C;
+  };
+  uint32_t num_blocks = ((size0 + num_threads -1) / num_threads);
+  uint32_t size_x = num_blocks;
+  uint32_t size_y = size1;
+  size_x /= 2;
+  size_x = size_x < 1 ? 1 : size_x;
+  printf("size_x: %u, size_y: %u, num_blocks: %u\n", size_x, size_y, num_blocks);
+  return createKernel(ctx,
+                      {kSum2d, num_threads, kf32},
+                      Bindings{input, output},
+                      {size_x, size_y, 1},
+                      Params{
+                        static_cast<uint32_t>(size0),
+                        static_cast<uint32_t>(size1),
+                      });
+}
+
 struct SumKernel {
   std::vector<Tensor> outputs;
   std::vector<Kernel> ops;
-  SumKernel(Context& ctx, size_t size) {
-    uint32_t num_threads = 1024;
-    int nSum = round(log2(size) / log2(num_threads));
+  SumKernel(Context& ctx, size_t size, uint32_t num_threads = 1024) {
     int input_size = size;
     unsigned long output_size = size;
     outputs.push_back(createTensor(ctx, Shape{std::max(size, static_cast<size_t>(num_threads*2))}, kf32));
-    for(int i=size,j=0;i>0;i/=num_threads,j++){
-      output_size = (output_size + num_threads - 1) / num_threads;
+    for(int j=0;output_size>1;j++){
+      output_size = (output_size + (num_threads * 2) - 1) / (num_threads * 2);
       outputs.push_back(createTensor(ctx, Shape{std::max(output_size, static_cast<unsigned long>(num_threads*2))}, kf32));
-      ops.push_back(createSumKernel(ctx, outputs[j], outputs[j+1], input_size));
+      ops.push_back(createSumKernel(ctx, outputs[j], outputs[j+1], input_size, num_threads));
       input_size = output_size;
     }
   }
   void dispatchKernel(Context& ctx) {
     for(int i=0;i<ops.size();i++){
       std::promise<void> promise;
       std::future<void> future = promise.get_future();
       gpu::dispatchKernel(ctx, ops[i], promise);
       wait(ctx, future);
       resetCommandBuffer(ctx.device, ops[i]);
     }
   }
   void toGPU(Context& ctx, const float* data, size_t size) {
     gpu::toGPU(ctx, data, outputs[0], size);
   }
   void toCPU(Context& ctx, float* data, size_t size) {
     gpu::toCPU(ctx, outputs[outputs.size()-1], data, size);
   }
 };
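Note: two quiet fixes ride along in this hunk: the loop condition is now
driven by output_size rather than a parallel counter, and each pass divides
by num_threads * 2, since one workgroup consumes two elements per thread.
Worked through for the 1d test's 67,108,864 elements with num_threads = 1024
(my arithmetic):

    pass 1: output_size = ceil(67108864 / 2048) = 32768
    pass 2: output_size = ceil(32768 / 2048)    = 16
    pass 3: output_size = ceil(16 / 2048)       = 1   -> loop stops

so exactly three kernels are created, whereas the old num_threads divisor
over-counted the surviving partials and sized the intermediate tensors for
twice as many elements as any pass actually writes.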
+struct SumKernel2d {
+  std::vector<Tensor> outputs;
+  std::vector<Kernel> ops;
+  bool debug;
+  SumKernel2d(Context& ctx, size_t size0, size_t size1, uint32_t num_threads = 1024) {
+    debug = false;
+    int input_size = size0;
+    unsigned long output_size = size0;
+    outputs.push_back(createTensor(ctx, Shape{std::max(size0, static_cast<size_t>(num_threads*2)), size1}, kf32));
+    for(int j=0;output_size>1;j++){
+      output_size = (output_size + (num_threads * 2) - 1) / (num_threads * 2);
+      if (debug)
+        printf("size0: %zu, num_threads: %u, output_size: %lu\n", size0, num_threads, output_size);
+      outputs.push_back(createTensor(ctx, Shape{std::max(output_size, static_cast<unsigned long>(num_threads*2)), size1}, kf32));
+      ops.push_back(createSumKernel2d(ctx, outputs[j], outputs[j+1], input_size, size1, num_threads));
+      input_size = output_size;
+    }
+    if (debug)
+      printf("ops.size(): %zu\n", ops.size());
+  }
+  void dispatchKernel(Context& ctx) {
+    for(int i=0;i<ops.size();i++){
+      if (debug) {
+        std::unique_ptr<float[]> buffer = std::make_unique<float[]>(8);
+        gpu::toCPU(ctx, outputs[i], buffer.get(), 8 * sizeof(float));
+        for(int j=0;j<8;j++){
+          printf("outputs[%d][%d]: %f\n", i, j, buffer[j]);
+        }
+      }
+      std::promise<void> promise;
+      std::future<void> future = promise.get_future();
+      gpu::dispatchKernel(ctx, ops[i], promise);
+      wait(ctx, future);
+      resetCommandBuffer(ctx.device, ops[i]);
+    }
+  }
+  void toGPU(Context& ctx, const float* data, size_t size) {
+    gpu::toGPU(ctx, data, outputs[0], size);
+  }
+  void toCPU(Context& ctx, float* data, size_t size) {
+    gpu::toCPU(ctx, outputs[outputs.size()-1], data, size);
+  }
+};
 
-// float sum_gpu2d(Context& ctx, const float* data, const float* buffer, size_t size_x, size_t size_y) {
-//   WGPURequiredLimits requiredLimits = LIMITS;
-//   Tensor input = createTensor(ctx, Shape{size}, kf32, data);
-//   Tensor output = createTensor(ctx, Shape{size}, kf32);
-//   uint32_t num_threads = 1024;
-//   uint32_t num_blocks = ((size_x + num_threads -1) / num_threads);
-//   printf("size: %u, size_x: %u, size_y: %u\n", size, size_x, size_y);
-//   Kernel op = createKernel(ctx, {kSum, num_threads, kf32}, Bindings{input, output}, {size_x, size_y, 1});
-//
-//   {
-//     for (int i = 0; i < 100; ++i){
-//       DurationTime dt("GPU");
-//       std::promise<void> promise;
-//       std::future<void> future = promise.get_future();
-//       dispatchKernel(ctx, op, promise);
-//       wait(ctx, future);
-//       resetCommandBuffer(ctx.device, op);
-//     }
-//   }
-//
-//   float r = 0;
-//   toCPU(ctx, output, (void*)buffer, num_blocks * sizeof(float));
-//
-//   for (int i = 0; i < num_blocks; i++){
-//     r+=buffer[i];
-//   }
-//   return r;
-// }
+void sum_gpu_2d(Context& ctx, const float* data, float* out, size_t size0, size_t size1) {
+  WGPURequiredLimits requiredLimits = LIMITS;
+  SumKernel2d sumKernel(ctx, size0, size1);
+  sumKernel.toGPU(ctx, data, size0 * size1 * sizeof(float));
+  sumKernel.dispatchKernel(ctx);
+
+  {
+    int nIter = 3;
+    DurationTime dt("GPU", true, nIter);
+    for (int t = 0; t < nIter; t++){
+      sumKernel.dispatchKernel(ctx);
+    }
+  }
+  sumKernel.toCPU(ctx, out, size1 * sizeof(float));
+}
 
-int main(int argc, char **argv) {
+int main_1d(int argc, char **argv) {
   static constexpr size_t M = 4096*2;
   static constexpr size_t N = 4096*2;
   static constexpr size_t BUF_SIZE = 16;
@@ -389,7 +423,6 @@ int main(int argc, char **argv) {
   gpu::Context ctx = gpu::createContext({}, {}, {
       .requiredLimits = &requiredLimits
     });
-  Tensor input = createTensor(ctx, Shape{M*N}, kf32, inputArr.get());
 
   printf("Start testing sum(x) on %zu values\n", M*N);
   cpu_result = sum_cpu(inputArr.get(), M*N);
@@ -418,3 +451,64 @@ int main(int argc, char **argv) {
   printf("Computed %zu values of kSum(x)\n\n", M*N);
   return 0;
 }
+
+int main_2d(int argc, char **argv) {
+  static constexpr size_t M = 4096;
+  static constexpr size_t N = 4096;
+  std::unique_ptr<float[]> inputArr = std::make_unique<float[]>(M * N);
+  std::unique_ptr<float[]> outputCpuArr = std::make_unique<float[]>(N);
+  std::unique_ptr<float[]> outputGpuArr = std::make_unique<float[]>(N);
+  std::mt19937 gen(314159);
+  printf("Initializing %zu values\n", M*N);
+  randn(inputArr.get(), M*N, gen);
+  for(int i=0;i<M*N;i++){
+    inputArr[i] = 1.0f;
+  }
+  WGPURequiredLimits requiredLimits = LIMITS;
+  gpu::Context ctx = gpu::createContext({}, {}, {
+      .requiredLimits = &requiredLimits
+    });
+
+  printf("Start testing sum2d(x) on %zu x %zu values\n", M, N);
+  sum_cpu_2d(inputArr.get(), outputCpuArr.get(), M, N);
+  sum_gpu_2d(ctx, inputArr.get(), outputGpuArr.get(), M, N);
+  float diff = 0;
+  for (size_t j = 0; j < N; j++) {
+    diff = std::max(diff, static_cast<float>(fabs(outputCpuArr[j] - outputGpuArr[j])));
+  }
+  if (diff >= 1e-0f) {
+    printf("Error: diff = %.6f\n", diff);
+  } else {
+    printf("Success: diff = %.6f\n", diff);
+  }
+
+  return 0;
+}
+
+int main(int argc, char **argv) {
+  printf("================================\n");
+  printf("Start testing reduce-1d\n");
+  main_1d(argc,argv);
+  printf("================================\n");
+  printf("Start testing reduce-2d\n");
+  main_2d(argc,argv);
+  return 0;
+}
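Note: with the series applied, the target added in patch 1 builds and runs
the whole test (make build/reduce compiles reduce.cpp and executes it,
printing the 1d and 2d sections above). For host code that wants a column
reduction without going through main_2d, the intended call sequence boils
down to the following sketch (assumes a Context created with the LIMITS
shown earlier; sizes are illustrative):

    size_t M = 4096, N = 4096;
    std::vector<float> in(M * N, 1.0f);  // row-major M x N, M rows summed away
    std::vector<float> out(N);           // one sum per column
    sum_gpu_2d(ctx, in.data(), out.data(), M, N);

Internally sum_gpu_2d builds a SumKernel2d, uploads M * N floats, runs the
dispatch chain once for warm-up plus three timed iterations, and copies back
N column sums.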