@@ -530,6 +530,222 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
  return Status::OK();
}

+ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const {
+   shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+   shader.AddOutput("output", ShaderUsage::UseUniform);
+   shader.AddOutput("scales", ShaderUsage::UseUniform);
+
+   shader.AdditionalImplementation() << R"ADDNL_FN(
+     var<workgroup> max_values : array<input_a_element_t, 4>;
+   )ADDNL_FN";
+
+   shader.MainFunctionBody() << R"MAIN_FN(
+     var local_a = input_a[global_idx];
+     var max_val = subgroupMax(abs(local_a));
+     var max_temp = max(max_val.xy, max_val.zw);
+     var scale = max(max_temp[0], max_temp[1]);
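+     // The first lane of each subgroup publishes its subgroup-wide max so the
+     // per-workgroup max can be combined across subgroups after the barrier.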
+     if (local_idx % sg_size == 0) {
+       max_values[local_idx / sg_size] = scale;
+     }
+     workgroupBarrier();
+
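+     // The workgroup has 32 threads, so there are 32 / sg_size subgroup maxima to
+     // combine: 4 when sg_size is 8, 2 when it is 16, and otherwise 1.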
+     if (sg_size == 8)
+     {
+       scale = max(max_values[0], max_values[1]);
+       scale = max(scale, max_values[2]);
+       scale = max(scale, max_values[3]);
+     }
+     else if (sg_size == 16)
+     {
+       scale = max(max_values[0], max_values[1]);
+     }
+     else
+     {
+       scale = max_values[0];
+     }
+
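+     // Normalize to [-1, 1] and pack four snorm8 values into one u32; the matching
+     // scale is written once per workgroup below.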
+     var norm_a = local_a/scale;
+     output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
+     if (local_idx == 0)
+     {
+       // pack4x8snorm maps 1.0f to 127, the maximum magnitude of a signed int8 in [-127, 127].
+       scales[workgroup_idx] = scale/127;
+     }
+   )MAIN_FN";
+   return Status::OK();
+ }
+
+ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
+   shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+   shader.AddInput("scales_a", ShaderUsage::UseUniform);
+   shader.AddInput("input_b", ShaderUsage::UseUniform);
+   shader.AddInput("scales_b", ShaderUsage::UseUniform);
+   shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
+
+   // This shader implements co-operative matrix multiply. The key idea here is to
+   // assume there is a primitive for medium size matrix multiply a subgroup can perform,
+   // using all its lanes and pooling all of their registers to keep the values resident.
+   //
+   // The entire workgroup, which has N subgroups, first loads a tile into shared memory.
+   // Then each subgroup loads a subtile from shared memory into registers and uses
+   // the medium size matrix multiply primitive to perform the math.
+   // The values for tile/subtile size are chosen to conform to the resource limits
+   // of an Alder Lake/Tiger Lake GPU. A tile is 64x64 and the workgroup is 256 threads -
+   // therefore there are 16 subgroups and 16 lanes in each subgroup.
+   // K, the hidden dimension, is paged in from RAM at a k tile size of 64.
+   // All this puts the shared memory requirement slightly above 16KB. The WebGPU
+   // limit is 16KB, so the output is kept in registers instead of shared memory
+   // to make everything fit.
+   //
+   // Each subgroup performs a 16 x 64 x 16 multiply which is implemented with
+   // subgroup shuffle as a placeholder for the day the medium matrix mul primitive
+   // becomes available in WGSL. The register requirement is ~2KB per subgroup; on
+   // Alder Lake/Tiger Lake a subgroup has 8KB of register space, pooling the
+   // 512B of registers from each lane.
+   //
+   // The medium size matmul is implemented using dot4I8Packed, so the inputs for
+   // this shader require A to be int8 quantized with a block size of 128 (one scale
+   // per 128 elements of K). B is regular MatMulNBits input with block size 32.
+
+   shader.AdditionalImplementation() << R"ADDNL_FN(
+     const tile_size = 64;
+     const subtile_size = 16;
+     const tile_size_k = 32;
+     const vec_factor = 4;
+     const u32_factor = 4;
+     const tile_size_k_vec = 4;
+     const block_size = 32;
+
+     // Shared memory
+     var<workgroup> tile_A : array<array<vec2<u32>, tile_size_k_vec>, tile_size>;  // 64 x 32
+     var<workgroup> scale_A : array<output_element_t, tile_size>;                  // 64 x 1
+     var<workgroup> tile_B : array<array<vec2<u32>, tile_size_k_vec>, tile_size>;  // 64 x 32
+     var<workgroup> scale_B : array<output_element_t, tile_size>;                  // 64 x 1
+
+     // Private memory
+     var<private> lane_output: array<output_element_t, 16>;
+
+     fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
+     {
+       let a_global = a_global_base + row;
+       if (a_global >= uniforms.M)
+       {
+         return;
+       }
+       tile_A[row][col] = input_a[a_global*uniforms.K8+kidx_v+col];
+       if (col == 0)
+       {
+         // Each kidx_v covers 8 values of k, so kidx_v/16 selects the 128-element
+         // quantization block of A that this scale belongs to.
+         scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/16];
+       }
+     }
+
+     fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32)
+     {
+       let b_global = b_global_base + row;
+       if (b_global >= uniforms.N)
+       {
+         return;
+       }
+
+       let b_value = input_b[b_global*uniforms.K8+kidx_v+col];
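+       // Each u32 of input_b packs eight 4-bit weights, low nibble first within each byte.
+       // Subtract the implied zero point of 8 and repack them as two u32s of four int8
+       // values each so they can be fed to dot4I8Packed.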
+       var b_value_lower = vec4<i32>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<i32>(8);
+       var b_value_upper = vec4<i32>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+       tile_B[row][col][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+       tile_B[row][col][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+       if (col == 0)
+       {
+         // Each kidx_v covers 8 values of k, so kidx_v/4 selects the 32-element
+         // quantization block of B that this scale belongs to.
+         scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/4];
+       }
+     }
+
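+     // Dot product of two 16-element int8 vectors, packed four values per u32.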
+     fn DP4AI(a:vec4<u32>, b:vec4<u32>) -> i32
+     {
+       var local_sum = dot4I8Packed(a[0], b[0]);
+       local_sum += dot4I8Packed(a[1], b[1]);
+       local_sum += dot4I8Packed(a[2], b[2]);
+       local_sum += dot4I8Packed(a[3], b[3]);
+       return local_sum;
+     }
+
+   )ADDNL_FN";
+
+   shader.MainFunctionBody() << R"MAIN_FN(
+     // During the load phase we use all 256 threads to load 64 rows of A/B.
+     // For each row we load 4 vectorized elements, which are 32 elements of K.
+     let a_global_base = workgroup_id.x * tile_size;
+     let b_global_base = workgroup_id.y * tile_size;
+     let load_row = u32(local_idx/4);
+     let load_col = u32(local_idx%4);
+
+     // During the compute phase, we have the 64x64 tile split into
+     // subtiles of 16x16. We have a grid of 4x4 subtiles.
+     let subtile_id = u32(local_idx / subtile_size);
+     let subtile_idx = u32(subtile_id / 4);
+     let subtile_idy = u32(subtile_id % 4);
+     let base_A = subtile_idx * 16;
+     let base_B = subtile_idy * 16;
+     // For each subtile we have 16 threads assigned.
+     let a_idx = u32(local_idx % subtile_size);
+
+     // K is vectorized 8 items per index. See input_a/input_b.
+     // tile_size_k_vec is the k tile size in vectorized k units, i.e. 1/8 of the k values.
+     for (var kidx_v:u32 = 0; kidx_v < uniforms.K8; kidx_v+=tile_size_k_vec)
+     {
+       // Populate shared memory for the workgroup.
+       loadSHMA(a_global_base, kidx_v, load_row, load_col);
+       loadSHMB(b_global_base, kidx_v, load_row, load_col);
+       workgroupBarrier();
+
+       var own_a0: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][0], tile_A[base_A + a_idx][1]);
+       var own_a1: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][2], tile_A[base_A + a_idx][3]);
+       var own_scale_a = scale_A[base_A + a_idx];
+       if (sg_size == 16)
+       {
+         var own_b0: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][0], tile_B[base_B + sg_id][1]);
+         var own_b1: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][2], tile_B[base_B + sg_id][3]);
+         var own_scale_b = scale_B[base_B + sg_id];
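+         // Each lane owns one row of A (base_A + a_idx) and one row of B (base_B + sg_id).
+         // Shuffling B's packed values and scale across the 16 lanes lets every lane
+         // accumulate a full 16-column strip of the output without extra shared-memory reads.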
+         for (var col:u32 = 0; col < 16; col++)
+         {
+           var local_scale_b = subgroupShuffle(own_scale_b, col);
+           local_scale_b = local_scale_b * own_scale_a;
+           var local_sum = DP4AI(own_a0, subgroupShuffle(own_b0, col));
+           local_sum += DP4AI(own_a1, subgroupShuffle(own_b1, col));
+           lane_output[col] += (output_element_t(local_sum) * local_scale_b);
+         }
+       }
+       else
+       {
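+         // Fallback for subgroup sizes other than 16: read the B rows and scales
+         // directly from shared memory instead of shuffling them across lanes.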
+         for (var col:u32 = 0; col < 16; col++)
+         {
+           var b0: vec4<u32> = vec4<u32>(tile_B[base_B + col][0], tile_B[base_B + col][1]);
+           var b1: vec4<u32> = vec4<u32>(tile_B[base_B + col][2], tile_B[base_B + col][3]);
+           var local_sum = DP4AI(own_a0, b0);
+           local_sum += DP4AI(own_a1, b1);
+           lane_output[col] += (output_element_t(local_sum) * own_scale_a * scale_B[base_B + col]);
+         }
+       }
+       workgroupBarrier();
+     }
+
+     let a_global = a_global_base + base_A + a_idx;
+     let b_global = b_global_base + base_B;
+     let output_idx = ((a_global) * uniforms.N + b_global)/4;
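+     // Each lane writes its 16 outputs (one 16-column strip of its row) as four vec4 stores.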
+     // This creates a shader requirement that uniforms.N % 16 == 0.
+     if (a_global < uniforms.M && b_global < uniforms.N)
+     {
+       for (var i:u32 = 0; i < 4; i++)
+       {
+         let lidx = i * 4;
+         output[output_idx+i] = vec4<output_element_t>(lane_output[lidx], lane_output[lidx+1], lane_output[lidx+2], lane_output[lidx+3]);
+       }
+     }
+   )MAIN_FN";
+
+   return Status::OK();
+ }
+
Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
  const Tensor* a = context.Input(0);
  const Tensor* b = context.Input(1);
@@ -565,11 +781,54 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
  uint32_t components = GetMaxComponents(N);

  const bool has_zero_points = zero_points != nullptr;
+   const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups);
+   // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support.
+   // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226
+   const bool use_dp4a = has_subgroup && context.AdapterInfo().backendType != wgpu::BackendType::Metal;
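+   // DP4A path: quantize A to int8 on the GPU, then run the co-operative tiled matmul.
+   // It is only taken for accuracy_level 4, block_size 32, a single batch, A vectorized
+   // by 4, no zero points, K % 64 == 0, N % 16 == 0, and M large enough to benefit from tiling.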
+   if (accuracy_level_ == 4 && block_size == 32 &&
+       batch_count == 1 && components_a == 4 && K % 64 == 0 && N % 16 == 0 &&
+       !has_zero_points && use_dp4a && M >= kMinMForTileOptimization) {
+     constexpr uint32_t kVec4Components = 4;
+     constexpr uint32_t kVec2Components = 2;
+     constexpr uint32_t kU32Components = 4;
+
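+     // Each 32-thread workgroup of the quantize program reads kBlockSizeA (128)
+     // elements of A (one vec4 per thread) and emits one scale per block.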
+     constexpr uint32_t kBlockSizeA = 128;
+     DP4AMatMulQuantizeProgram quantize_program;
+     quantize_program.SetWorkgroupSize(32);
+     quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1);
+     TensorShape a_quant_shape{1, M, K / kU32Components};
+     Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), a_quant_shape);
+     TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA});
+     Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims);
+     quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)}})
+         .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), gsl::narrow<int>(1)},
+                      {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow<int>(1)}});
+     ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));
+
+     constexpr uint32_t kTileSize = 64;
+     TensorShape reshaped_y_shape{1, M, N / kVec4Components};
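+     // One 256-thread workgroup of the matmul program produces a 64x64 output tile,
+     // so the dispatch grid is ceil(M / kTileSize) x ceil(N / kTileSize).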
+     DP4AMatMulNBitsProgram mul_program;
+     mul_program.SetWorkgroupSize(256);
+     mul_program.SetDispatchGroupSize(
+         (M + kTileSize - 1) / kTileSize,
+         (N + kTileSize - 1) / kTileSize, 1);
+     mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components)},
+                            {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
+                            {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kU32Components)},
+                            {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
+         .AddUniformVariables({{static_cast<uint32_t>(M)},
+                               {static_cast<uint32_t>(N)},
+                               {static_cast<uint32_t>(K)},
+                               {static_cast<uint32_t>(K / 8)},
+                               {static_cast<uint32_t>(K / 16)}})
+         .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(kVec4Components)});
+     return context.RunProgram(mul_program);
+   }

  // TODO: Support output_number > 1. Some cases fail when output_number > 1.
  constexpr uint32_t output_number = 1;
  const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1;
-   const bool use_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups) && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32;
+   const bool use_subgroup = has_subgroup && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32;
  MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow<int>(components_b), has_zero_points, use_subgroup};
  if (M > kMinMForTileOptimization && block_size == 32) {
    components = 1;