Skip to content

Commit e94aa02

Browse files
authored
Merge pull request #71 from junjihashimoto/feature/reduce
Add summation kernels
2 parents 0e89e65 + 189375f commit e94aa02

File tree

4 files changed

+602
-0
lines changed

4 files changed

+602
-0
lines changed

experimental/kernels/Makefile

+4
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ endif
2929

3030
default: run-native
3131

32+
build/reduce: reduce.cpp kernels.h
33+
$(CC) $(CFLAGS) $(CXXFLAGS) $(LDFLAGS) -o $@ $<
34+
$(LIBSPEC) && build/reduce
35+
3236
run_llm.c: ./build/test_gpt2 dawnlib
3337
$(LIBSPEC) && $<
3438

experimental/kernels/kernels.h

+72
Original file line numberDiff line numberDiff line change
@@ -781,6 +781,78 @@ fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
781781
}
782782
)";
783783

784+
static const char *kSum = R"(
785+
@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
786+
@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
787+
var<workgroup> buffer: array<{{precision}}, 1024>;
788+
@compute @workgroup_size({{workgroupSize}})
789+
fn main(
790+
@builtin(global_invocation_id) globalID : vec3<u32>,
791+
@builtin(local_invocation_id) localID : vec3<u32>,
792+
@builtin(workgroup_id) groupid : vec3<u32>,
793+
@builtin(num_workgroups) numGroups : vec3<u32>) {
794+
let blockSize3d: vec3<u32> = vec3({{workgroupSize}});
795+
let blockSize: u32 = blockSize3d.x;
796+
let threadId: u32 = localID.x;
797+
let blockId: u32 = groupid.x + groupid.y * numGroups.x;
798+
let blockStart = blockId * blockSize * 2 + threadId;
799+
800+
buffer[threadId] = inp[blockStart] + inp[blockStart + blockSize];
801+
workgroupBarrier();
802+
var stride: u32 = blockSize / 2;
803+
804+
if (blockSize >= 1024 && threadId < 512) {
805+
buffer[threadId] += buffer[threadId + 512];
806+
}
807+
workgroupBarrier();
808+
809+
if (blockSize >= 512 && threadId < 256) {
810+
buffer[threadId] += buffer[threadId + 256];
811+
}
812+
workgroupBarrier();
813+
814+
if (blockSize >= 256 && threadId < 128) {
815+
buffer[threadId] += buffer[threadId + 128];
816+
}
817+
workgroupBarrier();
818+
819+
if (threadId < 64) {
820+
buffer[threadId] += buffer[threadId + 64];
821+
}
822+
workgroupBarrier();
823+
824+
if (threadId < 32) {
825+
buffer[threadId] += buffer[threadId + 32];
826+
}
827+
workgroupBarrier();
828+
829+
if (threadId < 16) {
830+
buffer[threadId] += buffer[threadId + 16];
831+
}
832+
workgroupBarrier();
833+
834+
if (threadId < 8) {
835+
buffer[threadId] += buffer[threadId + 8];
836+
}
837+
workgroupBarrier();
838+
839+
if (threadId < 4) {
840+
buffer[threadId] += buffer[threadId + 4];
841+
}
842+
workgroupBarrier();
843+
844+
if (threadId < 2) {
845+
buffer[threadId] += buffer[threadId + 2];
846+
}
847+
workgroupBarrier();
848+
849+
if (threadId == 0) {
850+
buffer[0] += buffer[1];
851+
out[blockId] = buffer[0];
852+
}
853+
}
854+
)";
855+
784856
} // namespace gpu
785857

786858
#endif // KERNELS_H

0 commit comments

Comments
 (0)