Skip to content

Commit f3e0dbc

Browse files
Add summantion kernels
1 parent 9a42592 commit f3e0dbc

File tree

4 files changed

+499
-0
lines changed

4 files changed

+499
-0
lines changed

experimental/kernels/Makefile

+4
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ endif
2929

3030
default: run-native
3131

32+
build/reduce: reduce.cpp kernels.h
33+
$(CC) $(CFLAGS) $(CXXFLAGS) $(LDFLAGS) -o $@ $<
34+
$(LIBSPEC) && build/reduce
35+
3236
run_llm.c: ./build/test_gpt2 dawnlib
3337
$(LIBSPEC) && $<
3438

experimental/kernels/kernels.h

+72
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,78 @@ fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
683683
}
684684
)";
685685

686+
static const char *kSum = R"(
687+
@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
688+
@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
689+
var<workgroup> buffer: array<{{precision}}, 1024>;
690+
@compute @workgroup_size({{workgroupSize}})
691+
fn main(
692+
@builtin(global_invocation_id) globalID : vec3<u32>,
693+
@builtin(local_invocation_id) localID : vec3<u32>,
694+
@builtin(workgroup_id) groupid : vec3<u32>,
695+
@builtin(num_workgroups) numGroups : vec3<u32>) {
696+
let blockSize3d: vec3<u32> = vec3({{workgroupSize}});
697+
let blockSize: u32 = blockSize3d.x;
698+
let threadId: u32 = localID.x;
699+
let blockId: u32 = groupid.x + groupid.y * numGroups.x;
700+
let blockStart = blockId * blockSize * 2 + threadId;
701+
702+
buffer[threadId] = inp[blockStart] + inp[blockStart + blockSize];
703+
workgroupBarrier();
704+
var stride: u32 = blockSize / 2;
705+
706+
if (blockSize >= 1024 && threadId < 512) {
707+
buffer[threadId] += buffer[threadId + 512];
708+
}
709+
workgroupBarrier();
710+
711+
if (blockSize >= 512 && threadId < 256) {
712+
buffer[threadId] += buffer[threadId + 256];
713+
}
714+
workgroupBarrier();
715+
716+
if (blockSize >= 256 && threadId < 128) {
717+
buffer[threadId] += buffer[threadId + 128];
718+
}
719+
workgroupBarrier();
720+
721+
if (threadId < 64) {
722+
buffer[threadId] += buffer[threadId + 64];
723+
}
724+
workgroupBarrier();
725+
726+
if (threadId < 32) {
727+
buffer[threadId] += buffer[threadId + 32];
728+
}
729+
workgroupBarrier();
730+
731+
if (threadId < 16) {
732+
buffer[threadId] += buffer[threadId + 16];
733+
}
734+
workgroupBarrier();
735+
736+
if (threadId < 8) {
737+
buffer[threadId] += buffer[threadId + 8];
738+
}
739+
workgroupBarrier();
740+
741+
if (threadId < 4) {
742+
buffer[threadId] += buffer[threadId + 4];
743+
}
744+
workgroupBarrier();
745+
746+
if (threadId < 2) {
747+
buffer[threadId] += buffer[threadId + 2];
748+
}
749+
workgroupBarrier();
750+
751+
if (threadId == 0) {
752+
buffer[0] += buffer[1];
753+
out[blockId] = buffer[0];
754+
}
755+
}
756+
)";
757+
686758
} // namespace gpu
687759

688760
#endif // KERNELS_H

0 commit comments

Comments
 (0)