@@ -558,7 +558,147 @@ void puzzle12(Context &ctx) {
558
558
}
559
559
560
560
561
- // Puzzles 13 and 14 Coming Soon!
561
// Puzzle 13 : Axis Sum
// Implement a kernel that sums each length-`size` row of `a` (one batch per
// workgroup on the y axis) and stores the result in `output[batch]`.
const char *kPuzzle13 = R"(
@group(0) @binding(0) var<storage, read_write> a: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

struct Params {
  TPB: u32,
  size: u32,
};

var<workgroup> cache: array<f32, 256>;

@compute @workgroup_size({{workgroupSize}})
fn main(
  @builtin(local_invocation_id) LocalInvocationID: vec3<u32>,
  @builtin(global_invocation_id) GlobalInvocationID: vec3<u32>) {
    let i = GlobalInvocationID.x;
    let local_i = LocalInvocationID.x;
    let batch = GlobalInvocationID.y;

    // Copy one element per thread into workgroup memory; zero-pad the tail
    // so the tree reduction below can cover all TPB slots unconditionally.
    if (i < params.size) {
      cache[local_i] = a[batch * params.size + i];
    } else {
      cache[local_i] = 0.0;
    }
    workgroupBarrier();

    // Tree reduction over the workgroup. A barrier must run after EVERY
    // step — and in uniform control flow — otherwise a thread can read a
    // neighbor's slot before the previous step's write has landed.
    // Driving the stride from params.TPB (instead of a hardcoded step
    // count) keeps this correct for any workgroup size up to 256.
    for (var p: u32 = 1u; p < params.TPB; p = p * 2u) {
      if (local_i % (p * 2u) == 0u && local_i + p < params.TPB) {
        cache[local_i] = cache[local_i] + cache[local_i + p];
      }
      workgroupBarrier();
    }

    // One workgroup per y slot: thread 0 writes this batch's total.
    if (local_i == 0u) {
      output[batch] = cache[0];
    }
  }
)";
611
+ void puzzle13 (Context &ctx) {
612
+ printf (" \n\n Puzzle 13\n\n " );
613
+ static constexpr size_t N = 6 ;
614
+ static constexpr size_t TPB = 8 ;
615
+ static constexpr size_t BATCH = 4 ;
616
+ Tensor a = createTensor (ctx, {BATCH, N}, kf32, makeData<N * BATCH>().data ());
617
+ Tensor output = createTensor (ctx, {BATCH}, kf32);
618
+ struct Params {
619
+ uint32_t TPB = TPB;
620
+ uint32_t size = N;
621
+ };
622
+
623
+ Kernel op =
624
+ createKernel (ctx, {kPuzzle13 , {TPB, 1 , 1 }},
625
+ Bindings{a, output}, {1 , BATCH, 1 }, Params{TPB, N});
626
+ showResult<BATCH>(ctx, op, output);
627
+ }
628
+
629
// Puzzle 14 : Matrix Multiply!!
// Implement a kernel that computes the matrix product of a and b and stores it
// in out.
// Tip: The most efficient algorithm here will copy a block into shared memory
// before computing each of the individual row-column dot products. This is
// easy to do if the matrix fits in shared memory. Do that case first. Then
// update your code to compute a partial dot-product and iteratively move the
// part you copied into shared memory. You should be able to do the hard case
// in 6 global reads.
const char *kPuzzle14 = R"(
@group(0) @binding(0) var<storage, read_write> a: array<f32>;
@group(0) @binding(1) var<storage, read_write> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

struct Params {
  TPB: u32,
  size: u32,
};

// One TPB x TPB tile of each input, stored row-major:
// tile[row][col] lives at shared[row * TPB + col].
var<workgroup> a_shared: array<f32, 256>;
var<workgroup> b_shared: array<f32, 256>;

@compute @workgroup_size({{workgroupSize}})
fn main(
  @builtin(local_invocation_id) LocalInvocationID: vec3<u32>,
  @builtin(global_invocation_id) GlobalInvocationID: vec3<u32>) {
    let i = GlobalInvocationID.x;   // output row
    let j = GlobalInvocationID.y;   // output column
    let local_i = LocalInvocationID.x;
    let local_j = LocalInvocationID.y;

    var acc: f32 = 0.0;

    // March a TPB-wide window along the shared dimension.
    for (var k: u32 = 0u; k < params.size; k = k + params.TPB) {
      // Stage one tile of each input so that
      //   a_shared[local_i][local_k] = a[i][k + local_k]
      //   b_shared[local_k][local_j] = b[k + local_k][j]
      if (i < params.size && k + local_j < params.size) {
        a_shared[local_i * params.TPB + local_j] = a[i * params.size + (k + local_j)];
      }
      if (j < params.size && k + local_i < params.size) {
        // BUGFIX: store the b tile row-major. The previous index
        // (local_j * params.TPB + local_i) transposed the tile relative to
        // the read below, so the kernel computed a * transpose(b).
        b_shared[local_i * params.TPB + local_j] = b[(k + local_i) * params.size + j];
      }
      workgroupBarrier();

      // Partial dot product over this tile; clamp to the matrix edge so we
      // never read slots the guarded loads above left untouched.
      let local_k_max = min(params.TPB, params.size - k);
      for (var local_k: u32 = 0u; local_k < local_k_max; local_k = local_k + 1u) {
        acc += a_shared[local_i * params.TPB + local_k] * b_shared[local_k * params.TPB + local_j];
      }
      workgroupBarrier();
    }

    // Copy to out
    if (i < params.size && j < params.size) {
      output[i * params.size + j] = acc;
    }
  }
)";
685
+ void puzzle14 (Context &ctx) {
686
+ printf (" \n\n Puzzle 14\n\n " );
687
+ static constexpr size_t N = 2 ;
688
+ static constexpr size_t TPB = 3 ;
689
+ Tensor a = createTensor (ctx, {N, N}, kf32, makeData<N * N>().data ());
690
+ Tensor b = createTensor (ctx, {N, N}, kf32, makeData<N * N>().data ());
691
+ Tensor output = createTensor (ctx, {N, N}, kf32);
692
+ struct Params {
693
+ uint32_t TPB = TPB;
694
+ uint32_t size = N;
695
+ };
696
+
697
+ Kernel op =
698
+ createKernel (ctx, {kPuzzle14 , {TPB, TPB, 1 }},
699
+ Bindings{a, b, output}, {1 , 1 , 1 }, Params{TPB, N});
700
+ showResult<N, N, N>(ctx, op, output);
701
+ }
562
702
563
703
564
704
int main (int argc, char **argv) {
@@ -575,7 +715,7 @@ int main(int argc, char **argv) {
575
715
puzzle10 (ctx);
576
716
puzzle11 (ctx);
577
717
puzzle12 (ctx);
578
- // puzzle13(ctx);
579
- // puzzle14(ctx);
718
+ puzzle13 (ctx);
719
+ puzzle14 (ctx);
580
720
return 0 ;
581
721
}
0 commit comments