@@ -781,6 +781,78 @@ fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
781
781
}
782
782
)" ;
783
783
784
/**
 * WGSL source for a per-workgroup sum reduction.
 *
 * Each workgroup loads up to 2 * blockSize consecutive elements of `inp`
 * (two per invocation), reduces them in workgroup memory with an unrolled
 * tree, and invocation 0 stores the workgroup's partial sum in out[blockId].
 * Producing a single total therefore requires reducing `out` again (or
 * dispatching a single workgroup).
 *
 * Template placeholders, expected to be substituted before the shader is
 * compiled:
 *   {{precision}}     - element type of `inp`, `out` and the scratch array.
 *   {{workgroupSize}} - workgroup dimensions; only the x component is used
 *                       as the block size.
 *
 * Expectations:
 *   - The block size should be a power of two no larger than 1024 (the
 *     scratch array length). Other sizes above 128 leave some scratch
 *     elements out of the tree.
 *   - Elements past the end of `inp` contribute zero. The length is taken
 *     from arrayLength(), i.e. from the size of the bound buffer range.
 *     NOTE(review): assumes the binding covers exactly the logical input.
 */
static const char *kSum = R"(
@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
var<workgroup> buffer: array<{{precision}}, 1024>;
@compute @workgroup_size({{workgroupSize}})
fn main(
    @builtin(global_invocation_id) globalID : vec3<u32>,
    @builtin(local_invocation_id) localID : vec3<u32>,
    @builtin(workgroup_id) groupid : vec3<u32>,
    @builtin(num_workgroups) numGroups : vec3<u32>) {
    let blockSize3d: vec3<u32> = vec3({{workgroupSize}});
    let blockSize: u32 = blockSize3d.x;
    let threadId: u32 = localID.x;
    let blockId: u32 = groupid.x + groupid.y * numGroups.x;
    let blockStart = blockId * blockSize * 2 + threadId;

    // Guarded loads: an out-of-bounds read would yield an indeterminate
    // value, so the tail of the input is treated as zero instead.
    let inputLength: u32 = arrayLength(&inp);
    let secondIndex: u32 = blockStart + blockSize;
    var loaded: {{precision}};
    if (blockStart < inputLength) {
        loaded = inp[blockStart];
    }
    if (secondIndex < inputLength) {
        loaded += inp[secondIndex];
    }
    buffer[threadId] = loaded;
    workgroupBarrier();

    if (blockSize >= 1024 && threadId < 512) {
        buffer[threadId] += buffer[threadId + 512];
    }
    workgroupBarrier();

    if (blockSize >= 512 && threadId < 256) {
        buffer[threadId] += buffer[threadId + 256];
    }
    workgroupBarrier();

    if (blockSize >= 256 && threadId < 128) {
        buffer[threadId] += buffer[threadId + 128];
    }
    workgroupBarrier();

    // The remaining steps are unconditional: for block sizes below 128 the
    // unused slots still hold their initial zero value.
    if (threadId < 64) {
        buffer[threadId] += buffer[threadId + 64];
    }
    workgroupBarrier();

    if (threadId < 32) {
        buffer[threadId] += buffer[threadId + 32];
    }
    workgroupBarrier();

    if (threadId < 16) {
        buffer[threadId] += buffer[threadId + 16];
    }
    workgroupBarrier();

    if (threadId < 8) {
        buffer[threadId] += buffer[threadId + 8];
    }
    workgroupBarrier();

    if (threadId < 4) {
        buffer[threadId] += buffer[threadId + 4];
    }
    workgroupBarrier();

    if (threadId < 2) {
        buffer[threadId] += buffer[threadId + 2];
    }
    workgroupBarrier();

    if (threadId == 0) {
        buffer[0] += buffer[1];
        // Skip the store when `out` has no slot for this workgroup.
        if (blockId < arrayLength(&out)) {
            out[blockId] = buffer[0];
        }
    }
}
)";
855
+
784
856
} // namespace gpu
785
857
786
858
#endif // KERNELS_H
0 commit comments