@@ -530,6 +530,222 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   return Status::OK();
 }
 
+Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+  shader.AddOutput("output", ShaderUsage::UseUniform);
+  shader.AddOutput("scales", ShaderUsage::UseUniform);
+
+  shader.AdditionalImplementation() << R"ADDNL_FN(
+    var<workgroup> max_values : array<input_a_element_t, 4>;
+  )ADDNL_FN";
+
+  shader.MainFunctionBody() << R"MAIN_FN(
+  var local_a = input_a[global_idx];
+  var max_val = subgroupMax(abs(local_a));
+  var max_temp = max(max_val.xy, max_val.zw);
+  var scale = max(max_temp[0], max_temp[1]);
+  if (local_idx % sg_size == 0) {
+    max_values[local_idx / sg_size] = scale;
+  }
+  workgroupBarrier();
+
+  if (sg_size == 8)
+  {
+    scale = max(max_values[0], max_values[1]);
+    scale = max(scale, max_values[2]);
+    scale = max(scale, max_values[3]);
+  }
+  else if (sg_size == 16)
+  {
+    scale = max(max_values[0], max_values[1]);
+  }
+  else
+  {
+    scale = max_values[0];
+  }
+
+  var norm_a = local_a/scale;
+  output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
+  if (local_idx == 0)
+  {
+    // 127 is the max value of a signed int8 in [-127, 127]; pack4x8snorm maps 1.0f to 127.
+    scales[workgroup_idx] = scale/127;
+  }
+)MAIN_FN";
+  return Status::OK();
+}
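For intuition, here is a minimal CPU-side sketch of the quantization scheme this shader implements, assuming one scale per 128-element block of A (32 threads times a vec4 load, matching kBlockSizeA in the host code further down). The function name and layout are illustrative only, not part of this change:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative reference only: quantize one 128-float block of A the way the
// shader does - max-abs scale, snorm int8 values, stored scale = max_abs / 127.
void QuantizeBlockA_Sketch(const float* a, int8_t* q, float& scale_out) {
  float max_abs = 0.0f;
  for (int i = 0; i < 128; ++i) max_abs = std::max(max_abs, std::fabs(a[i]));
  for (int i = 0; i < 128; ++i) {
    // norm_a = local_a / scale, then pack4x8snorm maps [-1, 1] to [-127, 127].
    float norm = std::clamp(a[i] / max_abs, -1.0f, 1.0f);
    q[i] = static_cast<int8_t>(std::lround(norm * 127.0f));
  }
  // Dequantization is q[i] * scale_out ~= a[i]; like the shader, this assumes
  // the block is not all zeros.
  scale_out = max_abs / 127.0f;
}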
+
+Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+  shader.AddInput("scales_a", ShaderUsage::UseUniform);
+  shader.AddInput("input_b", ShaderUsage::UseUniform);
+  shader.AddInput("scales_b", ShaderUsage::UseUniform);
+  shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
+
+  // This shader implements a co-operative matrix multiply. The key idea is to
+  // assume there is a primitive for a medium-sized matrix multiply that a subgroup
+  // can perform, using all of its lanes and pooling all of their registers to keep
+  // the values in registers.
+  //
+  // The entire workgroup, which has N subgroups, first loads a tile into shared
+  // memory. Each subgroup then loads a subtile from shared memory into registers
+  // and uses the medium-sized matrix multiply primitive to perform the math.
+  // The tile/subtile sizes are chosen to conform to the resource limits of an
+  // Alder Lake / Tiger Lake GPU: a tile is 64x64 and the workgroup is 256 threads,
+  // so there are 16 subgroups with 16 lanes each.
+  // K, the hidden dimension, is paged in from RAM at the K tile size, which is 64.
+  // All of this puts the shared memory requirement slightly above 16KB. Since the
+  // WebGPU limit is 16KB, the output is kept in registers instead of shared memory
+  // so that everything fits.
+  //
+  // Each subgroup performs a 16 x 64 x 16 multiply, implemented with subgroup
+  // shuffle as a placeholder until a medium-sized matrix multiply primitive becomes
+  // available in WGSL. The register requirement is ~2KB per subgroup; on
+  // Alder Lake / Tiger Lake a subgroup has 8KB of register space, pooling the
+  // 512B of registers from each lane.
+  //
+  // The medium-sized matmul is implemented using dot4I8Packed, so this shader
+  // requires A to be int8-quantized with block size 64, while B is regular
+  // MatMulNBits input with block size 32.
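To make the memory budget concrete, here is a rough accounting of the per-workgroup shared memory implied by the constants declared below, written as a standalone sketch added for clarity (it assumes f32 scales and output; with f16 those figures halve, and it is not code from this change):

#include <cstdint>

// tile_A / tile_B: 64 rows x 4 vec2<u32> entries = 64 * 4 * 8 bytes each.
constexpr uint32_t kTileABytes = 64 * 4 * 8;       // 2048
constexpr uint32_t kTileBBytes = 64 * 4 * 8;       // 2048
constexpr uint32_t kScaleBytes = 2 * 64 * 4;       // scale_A + scale_B as f32: 512
// A 64x64 f32 output tile would add another 16384 bytes, pushing the total past
// the 16KB WebGPU default - which is why lane_output lives in private registers.
constexpr uint32_t kOutputTileBytes = 64 * 64 * 4;
static_assert(kTileABytes + kTileBBytes + kScaleBytes + kOutputTileBytes > 16 * 1024);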
+
+  shader.AdditionalImplementation() << R"ADDNL_FN(
+  const tile_size = 64;
+  const subtile_size = 16;
+  const tile_size_k = 32;
+  const vec_factor = 4;
+  const u32_factor = 4;
+  const tile_size_k_vec = 4;
+  const block_size = 32;
+
+  // Shared memory
+  var<workgroup> tile_A : array<array<vec2<u32>, tile_size_k_vec>, tile_size>;   // 64 x 32
+  var<workgroup> scale_A : array<output_element_t, tile_size>;                   // 64 x 1
+  var<workgroup> tile_B : array<array<vec2<u32>, tile_size_k_vec>, tile_size>;   // 64 x 32
+  var<workgroup> scale_B : array<output_element_t, tile_size>;                   // 64 x 1
+
+  // Private memory
+  var<private> lane_output: array<output_element_t, 16>;
+
+  fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
+  {
+    let a_global = a_global_base + row;
+    if (a_global >= uniforms.M)
+    {
+      return;
+    }
+    tile_A[row][col] = input_a[a_global*uniforms.K8+kidx_v+col];
+    if (col == 0)
+    {
+      // kidx_v - covers 8 values of k
+      scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/16];
+    }
+  }
+
+  fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32)
+  {
+    let b_global = b_global_base + row;
+    if (b_global >= uniforms.N)
+    {
+      return;
+    }
+
+    let b_value = input_b[b_global*uniforms.K8+kidx_v+col];
+    var b_value_lower = vec4<i32>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<i32>(8);
+    var b_value_upper = vec4<i32>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+    tile_B[row][col][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+    tile_B[row][col][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+    if (col == 0)
+    {
+      // each kidx_v covers 8 values of k
+      scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/4];
+    }
+  }
+
+  fn DP4AI(a:vec4<u32>, b:vec4<u32>) -> i32
+  {
+    var local_sum = dot4I8Packed(a[0], b[0]);
+    local_sum += dot4I8Packed(a[1], b[1]);
+    local_sum += dot4I8Packed(a[2], b[2]);
+    local_sum += dot4I8Packed(a[3], b[3]);
+    return local_sum;
+  }
+
+  )ADDNL_FN";
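As a point of reference, the DP4AI helper above is simply a 16-element signed dot product over packed int8 lanes. A scalar sketch of the same computation follows; it is purely illustrative (dot4I8Packed is the WGSL builtin the shader actually relies on):

#include <cstdint>

// Illustrative scalar equivalent of DP4AI: each u32 packs four int8 lanes, so the
// four dot4I8Packed calls accumulate a 16-element signed dot product.
int32_t DP4AI_Reference(const uint32_t a[4], const uint32_t b[4]) {
  int32_t sum = 0;
  for (int word = 0; word < 4; ++word) {
    for (int lane = 0; lane < 4; ++lane) {
      const int8_t av = static_cast<int8_t>((a[word] >> (8 * lane)) & 0xFF);
      const int8_t bv = static_cast<int8_t>((b[word] >> (8 * lane)) & 0xFF);
      sum += static_cast<int32_t>(av) * static_cast<int32_t>(bv);
    }
  }
  return sum;
}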
+
+  shader.MainFunctionBody() << R"MAIN_FN(
+  // During the load phase we use all 256 threads to load 64 rows of A/B.
+  // For each row we load 4 vectorized elements, which are 32 elements of K.
+  let a_global_base = workgroup_id.x * tile_size;
+  let b_global_base = workgroup_id.y * tile_size;
+  let load_row = u32(local_idx/4);
+  let load_col = u32(local_idx%4);
+
+  // During the compute phase, the 64x64 tile is split into subtiles of 16x16,
+  // giving a grid of 4x4 subtiles.
+  let subtile_id = u32(local_idx / subtile_size);
+  let subtile_idx = u32(subtile_id / 4);
+  let subtile_idy = u32(subtile_id % 4);
+  let base_A = subtile_idx * 16;
+  let base_B = subtile_idy * 16;
+  // Each subtile has 16 threads assigned to it.
+  let a_idx = u32(local_idx % subtile_size);
+
+  // K's vectorization is 8 items per index. See input_a/input_b.
+  // tile_size_k_vec is the K tile size in vectorized K units (1/8 of K).
+  for (var kidx_v:u32 = 0; kidx_v < uniforms.K8; kidx_v+=tile_size_k_vec)
+  {
+    // Populate shared memory for the workgroup.
+    loadSHMA(a_global_base, kidx_v, load_row, load_col);
+    loadSHMB(b_global_base, kidx_v, load_row, load_col);
+    workgroupBarrier();
+
+    var own_a0: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][0], tile_A[base_A + a_idx][1]);
+    var own_a1: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][2], tile_A[base_A + a_idx][3]);
+    var own_scale_a = scale_A[base_A + a_idx];
+    if (sg_size == 16)
+    {
+      var own_b0: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][0], tile_B[base_B + sg_id][1]);
+      var own_b1: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][2], tile_B[base_B + sg_id][3]);
+      var own_scale_b = scale_B[base_B + sg_id];
+      for (var col:u32 = 0; col < 16; col++)
+      {
+        var local_scale_b = subgroupShuffle(own_scale_b, col);
+        local_scale_b = local_scale_b * own_scale_a;
+        var local_sum = DP4AI(own_a0, subgroupShuffle(own_b0, col));
+        local_sum += DP4AI(own_a1, subgroupShuffle(own_b1, col));
+        lane_output[col] += (output_element_t(local_sum) * local_scale_b);
+      }
+    }
+    else
+    {
+      for (var col:u32 = 0; col < 16; col++)
+      {
+        var b0: vec4<u32> = vec4<u32>(tile_B[base_B + col][0], tile_B[base_B + col][1]);
+        var b1: vec4<u32> = vec4<u32>(tile_B[base_B + col][2], tile_B[base_B + col][3]);
+        var local_sum = DP4AI(own_a0, b0);
+        local_sum += DP4AI(own_a1, b1);
+        lane_output[col] += (output_element_t(local_sum) * own_scale_a * scale_B[base_B + col]);
+      }
+    }
+    workgroupBarrier();
+  }
+
+  let a_global = a_global_base + base_A + a_idx;
+  let b_global = b_global_base + base_B;
+  let output_idx = ((a_global) * uniforms.N + b_global)/4;
+  // This creates a shader requirement that uniforms.N % 16 == 0.
+  if (a_global < uniforms.M && b_global < uniforms.N)
+  {
+    for (var i:u32 = 0; i < 4; i++)
+    {
+      let lidx = i * 4;
+      output[output_idx+i] = vec4<output_element_t>(lane_output[lidx], lane_output[lidx+1], lane_output[lidx+2], lane_output[lidx+3]);
+    }
+  }
+)MAIN_FN";
+
+  return Status::OK();
+}
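For readers tracing the math, here is a scalar sketch of what one output element of this path computes, assuming A is int8-quantized as in DP4AMatMulQuantizeProgram (one scale per 128 values of K) and B is 4-bit MatMulNBits data with one scale per 32 values of K, zero point 8, and two K values per byte with the low nibble holding the even K index. The function name and memory layout are assumptions made for illustration:

#include <cstdint>

// Illustrative reference: one output element of the DP4A path.
// a_q: int8-quantized A row (length K), a_scale: one scale per 128 K values.
// b_nibbles: packed 4-bit B row (K/2 bytes), b_scale: one scale per 32 K values.
float DP4AMatMulRef(const int8_t* a_q, const float* a_scale,
                    const uint8_t* b_nibbles, const float* b_scale, uint32_t K) {
  float acc = 0.0f;
  for (uint32_t k = 0; k < K; ++k) {
    const uint8_t packed = b_nibbles[k / 2];
    // Low nibble holds the even K value, high nibble the odd one; 8 is the zero point.
    const int32_t b_val = static_cast<int32_t>((k % 2 == 0) ? (packed & 0x0F) : (packed >> 4)) - 8;
    acc += static_cast<float>(static_cast<int32_t>(a_q[k]) * b_val) *
           a_scale[k / 128] * b_scale[k / 32];
  }
  return acc;
}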
+
 Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
   const Tensor* a = context.Input(0);
   const Tensor* b = context.Input(1);
@@ -565,11 +781,54 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
   uint32_t components = GetMaxComponents(N);
 
   const bool has_zero_points = zero_points != nullptr;
+  const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups);
+  // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support.
+  // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226
+  const bool use_dp4a = has_subgroup && context.AdapterInfo().backendType != wgpu::BackendType::Metal;
+  if (accuracy_level_ == 4 && block_size == 32 &&
+      batch_count == 1 && components_a == 4 && K % 64 == 0 && N % 16 == 0 &&
+      !has_zero_points && use_dp4a && M >= kMinMForTileOptimization) {
+    constexpr uint32_t kVec4Components = 4;
+    constexpr uint32_t kVec2Components = 2;
+    constexpr uint32_t kU32Components = 4;
+
+    constexpr uint32_t kBlockSizeA = 128;
+    DP4AMatMulQuantizeProgram quantize_program;
+    quantize_program.SetWorkgroupSize(32);
+    quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1);
+    TensorShape a_quant_shape{1, M, K / kU32Components};
+    Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), a_quant_shape);
+    TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA});
+    Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims);
+    quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)}})
+        .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), gsl::narrow<int>(1)},
+                     {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow<int>(1)}});
+    ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));
+
+    constexpr uint32_t kTileSize = 64;
+    TensorShape reshaped_y_shape{1, M, N / kVec4Components};
+    DP4AMatMulNBitsProgram mul_program;
+    mul_program.SetWorkgroupSize(256);
+    mul_program.SetDispatchGroupSize(
+        (M + kTileSize - 1) / kTileSize,
+        (N + kTileSize - 1) / kTileSize, 1);
+    mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components)},
+                           {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
+                           {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kU32Components)},
+                           {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
+        .AddUniformVariables({{static_cast<uint32_t>(M)},
+                              {static_cast<uint32_t>(N)},
+                              {static_cast<uint32_t>(K)},
+                              {static_cast<uint32_t>(K / 8)},
+                              {static_cast<uint32_t>(K / 16)}})
+        .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(kVec4Components)});
+    return context.RunProgram(mul_program);
+  }
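To illustrate the dispatch geometry above, a small worked example with hypothetical dimensions (M = 1024, K = 4096, N = 4096 are illustrative values chosen to satisfy the K % 64 == 0 and N % 16 == 0 requirements; they do not come from this change):

#include <cstdint>

// Hypothetical dimensions used only to show how the dispatch sizes work out.
constexpr uint32_t M = 1024, K = 4096, N = 4096;
constexpr uint32_t kBlockSizeA = 128, kTileSize = 64;
static_assert(M * K / kBlockSizeA == 32768);           // quantize pass: 32768 workgroups of 32 threads
static_assert((M + kTileSize - 1) / kTileSize == 16);  // matmul pass: 16 x 64 workgroups of 256 threads
static_assert((N + kTileSize - 1) / kTileSize == 64);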
 
   // TODO: Support output_number > 1. Some cases are failed when output_number > 1.
   constexpr uint32_t output_number = 1;
   const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1;
-  const bool use_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups) && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32;
+  const bool use_subgroup = has_subgroup && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32;
   MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow<int>(components_b), has_zero_points, use_subgroup};
   if (M > kMinMForTileOptimization && block_size == 32) {
     components = 1;