@@ -683,6 +683,78 @@ fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
683
683
}
684
684
)" ;
685
685
686
// WGSL source for a per-workgroup sum reduction.
//
// Template placeholders substituted before compilation:
//   {{precision}}     - element type of the input/output arrays.
//   {{workgroupSize}} - workgroup dimensions; only the x component is used
//                       as the reduction block size.
//
// Each workgroup consumes up to 2 * blockSize consecutive elements of `inp`
// starting at blockId * blockSize * 2, and writes their sum to out[blockId].
// Elements past the end of `inp` contribute zero, so the input length does
// not have to be a multiple of 2 * blockSize.
//
// The shared scratch array holds 1024 elements, so blockSize must not exceed
// 1024. NOTE(review): the tree reduction below pairs elements at fixed
// power-of-two offsets; it appears to assume a power-of-two blockSize of at
// least 128, or scratch slots beyond blockSize that still hold zero - confirm
// against the callers.
static const char *kSum = R"(
@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
var<workgroup> buffer: array<{{precision}}, 1024>;
@compute @workgroup_size({{workgroupSize}})
fn main(
    @builtin(global_invocation_id) globalID : vec3<u32>,
    @builtin(local_invocation_id) localID : vec3<u32>,
    @builtin(workgroup_id) groupid : vec3<u32>,
    @builtin(num_workgroups) numGroups : vec3<u32>) {
    let blockSize3d: vec3<u32> = vec3({{workgroupSize}});
    let blockSize: u32 = blockSize3d.x;
    let threadId: u32 = localID.x;
    let blockId: u32 = groupid.x + groupid.y * numGroups.x;
    let blockStart = blockId * blockSize * 2 + threadId;
    let inputLength: u32 = arrayLength(&inp);

    // First reduction step happens while loading: each invocation adds two
    // input elements that are blockSize apart. Loads are bounds-checked so a
    // trailing partial block only sums elements that actually exist.
    var partial: {{precision}};
    if (blockStart < inputLength) {
        partial = inp[blockStart];
    }
    if (blockStart + blockSize < inputLength) {
        partial += inp[blockStart + blockSize];
    }
    buffer[threadId] = partial;
    workgroupBarrier();

    // Tree reduction in workgroup memory. Every barrier is executed by all
    // invocations; only the additions are conditional.
    if (blockSize >= 1024 && threadId < 512) {
        buffer[threadId] += buffer[threadId + 512];
    }
    workgroupBarrier();

    if (blockSize >= 512 && threadId < 256) {
        buffer[threadId] += buffer[threadId + 256];
    }
    workgroupBarrier();

    if (blockSize >= 256 && threadId < 128) {
        buffer[threadId] += buffer[threadId + 128];
    }
    workgroupBarrier();

    if (threadId < 64) {
        buffer[threadId] += buffer[threadId + 64];
    }
    workgroupBarrier();

    if (threadId < 32) {
        buffer[threadId] += buffer[threadId + 32];
    }
    workgroupBarrier();

    if (threadId < 16) {
        buffer[threadId] += buffer[threadId + 16];
    }
    workgroupBarrier();

    if (threadId < 8) {
        buffer[threadId] += buffer[threadId + 8];
    }
    workgroupBarrier();

    if (threadId < 4) {
        buffer[threadId] += buffer[threadId + 4];
    }
    workgroupBarrier();

    if (threadId < 2) {
        buffer[threadId] += buffer[threadId + 2];
    }
    workgroupBarrier();

    // Invocation 0 folds the last pair and publishes this block's sum.
    if (threadId == 0) {
        buffer[0] += buffer[1];
        out[blockId] = buffer[0];
    }
}
)";
757
+
686
758
} // namespace gpu
687
759
688
760
#endif // KERNELS_H
0 commit comments