
Commit 58c29d3

WIP: Dp4MatMulNBits accuracy level 4 matmul for WebGPU EP (#23365)
### Description

This change implements accuracy level 4 (quantize A to int8) matmul for the WebGPU EP. The matmul kernel uses DP4A for the matrix multiplication; to keep the DP4A units fed, a co-operative matrix multiply is implemented that preloads the row/col into local variables before the multiply. Credits to @qjia7 for help with the quantizer shader.

Performance metrics on an Intel ADL/TGL GPU:

```
PS C:\onnxruntime> C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500
Batch size: 1, prompt tokens: 501, tokens to generate: 128
Prompt processing (time to first token):
    avg (us):       2.76762e+06
    avg (tokens/s): 181.022   <<< prefill speed
    p50 (us):       2.74843e+06
    stddev (us):    41756.4
    n:              5 * 501 token(s)
Token generation:
    avg (us):       81500.7
    avg (tokens/s): 12.2698
    p50 (us):       81104.1
    stddev (us):    2961.31
    n:              635 * 1 token(s)
Token sampling:
    avg (us):       13.1836
    avg (tokens/s): 75851.9
    p50 (us):       12
    stddev (us):    6.47085
    n:              640 * 1 token(s)
E2E generation (entire generation loop):
    avg (ms):       13120
    p50 (ms):       13081.6
    stddev (ms):    114.689
    n:              5
Peak working set size (bytes): 5467533312
WebGPU device lost (2): Device was destroyed.
```

This kernel is 2.10x faster than its F16 counterpart for a 500-token prefill; the previous prefill record was 86 tokens/s.

To support devices with subgroup sizes of 8/32, a no-subgroup version of the same shader is included. It is slower than the subgroup version on ADL:

```
PS C:\onnxruntime> C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500
Batch size: 1, prompt tokens: 501, tokens to generate: 128
Prompt processing (time to first token):
    avg (us):       4.11989e+06
    avg (tokens/s): 121.605
    p50 (us):       4.11847e+06
    stddev (us):    2147.48
    n:              5 * 501 token(s)
Token generation:
    avg (us):       81174.9
    avg (tokens/s): 12.3191
    p50 (us):       81301.1
    stddev (us):    2177.2
    n:              635 * 1 token(s)
Token sampling:
    avg (us):       14.7998
    avg (tokens/s): 67568.3
    p50 (us):       12.3
    stddev (us):    11.5481
    n:              640 * 1 token(s)
E2E generation (entire generation loop):
    avg (ms):       14431.1
    p50 (ms):       14433.8
    stddev (ms):    5.02473
    n:              5
Peak working set size (bytes): 5466480640
WebGPU device lost (2): Device was destroyed.
```
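For context, a minimal CPU sketch (not part of the change, all values illustrative) of what "accuracy level 4" means numerically: A is dynamically quantized to int8 with one scale per block, the dot product is accumulated in int32 (the operation DP4A accelerates), and the result is rescaled by the product of the A and B scales.

```cpp
// Illustrative sketch only: int8 dynamic quantization of A plus an integer
// dot product, i.e. the arithmetic that accuracy level 4 maps onto DP4A.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> a = {0.10f, -0.75f, 0.33f, 0.50f};  // activations (made up)
  const std::vector<int8_t> b = {7, -3, 5, 1};                 // 4-bit weights already widened to int8
  const float scale_b = 0.02f;                                 // per-block weight scale (made up)

  // Quantize A so that the largest |value| maps to 127.
  float max_abs = 0.0f;
  for (float v : a) max_abs = std::max(max_abs, std::fabs(v));
  const float scale_a = max_abs / 127.0f;

  int32_t acc = 0;  // integer accumulation - the part DP4A does in hardware
  for (size_t i = 0; i < a.size(); ++i) {
    const auto qa = static_cast<int32_t>(std::lround(a[i] / scale_a));
    acc += qa * static_cast<int32_t>(b[i]);
  }
  const float approx = static_cast<float>(acc) * scale_a * scale_b;

  float exact = 0.0f;  // fp reference, to see the quantization error
  for (size_t i = 0; i < a.size(); ++i) exact += a[i] * (static_cast<float>(b[i]) * scale_b);
  std::printf("int8 path: %f  fp path: %f\n", approx, exact);
  return 0;
}
```

The per-block A scale is what the quantizer shader in this change computes on the GPU.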
1 parent c7f764c commit 58c29d3

2 files changed (+280 −1 lines)


onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

Lines changed: 260 additions & 1 deletion
```
@@ -530,6 +530,222 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   return Status::OK();
 }
 
+Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+  shader.AddOutput("output", ShaderUsage::UseUniform);
+  shader.AddOutput("scales", ShaderUsage::UseUniform);
+
+  shader.AdditionalImplementation() << R"ADDNL_FN(
+    var<workgroup> max_values : array<input_a_element_t, 4>;
+  )ADDNL_FN";
+
+  shader.MainFunctionBody() << R"MAIN_FN(
+    var local_a = input_a[global_idx];
+    var max_val = subgroupMax(abs(local_a));
+    var max_temp = max(max_val.xy, max_val.zw);
+    var scale = max(max_temp[0], max_temp[1]);
+    if (local_idx % sg_size == 0) {
+      max_values[local_idx / sg_size] = scale;
+    }
+    workgroupBarrier();
+
+    if (sg_size == 8)
+    {
+      scale = max(max_values[0], max_values[1]);
+      scale = max(scale, max_values[2]);
+      scale = max(scale, max_values[3]);
+    }
+    else if (sg_size == 16)
+    {
+      scale = max(max_values[0], max_values[1]);
+    }
+    else
+    {
+      scale = max_values[0];
+    }
+
+    var norm_a = local_a / scale;
+    output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
+    if (local_idx == 0)
+    {
+      // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
+      scales[workgroup_idx] = scale / 127;
+    }
+  )MAIN_FN";
+  return Status::OK();
+}
+
```
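The quantizer above gives each 32-thread workgroup one 128-element block of A (each thread holds a vec4), reduces the absolute maximum across the subgroup and then across subgroups, normalizes, and packs with `pack4x8snorm`. A CPU sketch of the same computation, with hypothetical helper names and `pack4x8snorm` emulated per the WGSL spec formula `floor(0.5 + 127 * clamp(v, -1, 1))`:

```cpp
// CPU emulation (sketch, not production code) of one DP4AMatMulQuantizeProgram
// workgroup: 32 threads x vec4 = 128 elements of A -> 32 packed u32 + 1 scale.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Emulates WGSL pack4x8snorm: clamp to [-1, 1], scale by 127, round, pack little-endian.
uint32_t pack4x8snorm(const float v[4]) {
  uint32_t packed = 0;
  for (int i = 0; i < 4; ++i) {
    const float c = std::min(1.0f, std::max(-1.0f, v[i]));
    const auto q = static_cast<int8_t>(std::floor(0.5f + 127.0f * c));
    packed |= static_cast<uint32_t>(static_cast<uint8_t>(q)) << (8 * i);
  }
  return packed;
}

// One quantization block: 128 floats of A -> packed int8 plus one scale.
// Storing max_abs / 127 means dequantization is simply q * stored_scale,
// which is why the shader divides by 127 when it writes scales[workgroup_idx].
void quantize_block(const std::vector<float>& block,    // 128 values
                    std::vector<uint32_t>& packed_out,  // receives 32 u32 words
                    float& stored_scale) {
  float max_abs = 0.0f;  // the subgroupMax + max_values reduction in the shader
  for (float v : block) max_abs = std::max(max_abs, std::fabs(v));
  for (size_t i = 0; i < block.size(); i += 4) {
    float norm[4];
    for (int j = 0; j < 4; ++j) norm[j] = block[i + j] / max_abs;  // now in [-1, 1]
    packed_out.push_back(pack4x8snorm(norm));
  }
  stored_scale = max_abs / 127.0f;
}
```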
```
+Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+  shader.AddInput("scales_a", ShaderUsage::UseUniform);
+  shader.AddInput("input_b", ShaderUsage::UseUniform);
+  shader.AddInput("scales_b", ShaderUsage::UseUniform);
+  shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
+
+  // This shader implements a co-operative matrix multiply. The key idea is to
+  // assume there is a primitive for a medium-size matrix multiply that a subgroup
+  // can perform, using all of its lanes and pooling their registers to keep the
+  // values in registers.
+  //
+  // The entire workgroup, which has N subgroups, first loads a tile into shared
+  // memory. Each subgroup then loads a subtile from shared memory into registers
+  // and uses the medium-size matrix multiply primitive to perform the math.
+  // The tile/subtile sizes are chosen to conform to the resource limits of an
+  // Alderlake/Tigerlake GPU. A tile is 64x64 and the workgroup is 256 threads,
+  // so there are 16 subgroups with 16 lanes each.
+  // K, the hidden dimension, is paged in from RAM at a K tile size of 32.
+  // All this puts the shared memory requirement slightly above 16KB.
+  // The WebGPU limit is 16KB, so the output is kept in registers instead of
+  // shared memory to make everything fit.
+  //
+  // Each subgroup performs a 16 x 64 x 16 multiply, implemented with subgroup
+  // shuffle as a placeholder until a medium-size matrix multiply primitive
+  // becomes available in WGSL. The register requirement is ~2KB per subgroup;
+  // on Alderlake/Tigerlake a subgroup has 8KB of register space, pooling the
+  // 512B of registers from each lane.
+  //
+  // The medium-size matmul is implemented using dot4I8Packed, so this shader
+  // requires A to be int8 quantized with block size 128. B is regular
+  // MatMulNBits input with block size 32.
+
+  shader.AdditionalImplementation() << R"ADDNL_FN(
+    const tile_size = 64;
+    const subtile_size = 16;
+    const tile_size_k = 32;
+    const vec_factor = 4;
+    const u32_factor = 4;
+    const tile_size_k_vec = 4;
+    const block_size = 32;
+
+    // Shared memory
+    var<workgroup> tile_A : array<array<vec2<u32>, tile_size_k_vec>, tile_size>;   // 64 x 32
+    var<workgroup> scale_A : array<output_element_t, tile_size>;                   // 64 x 1
+    var<workgroup> tile_B : array<array<vec2<u32>, tile_size_k_vec>, tile_size>;   // 64 x 32
+    var<workgroup> scale_B : array<output_element_t, tile_size>;                   // 64 x 1
+
+    // Private memory
+    var<private> lane_output: array<output_element_t, 16>;
+
+    fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
+    {
+      let a_global = a_global_base + row;
+      if (a_global >= uniforms.M)
+      {
+        return;
+      }
+      tile_A[row][col] = input_a[a_global*uniforms.K8+kidx_v+col];
+      if (col == 0)
+      {
+        // kidx_v - each kidx_v covers 8 values of k.
+        scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/16];
+      }
+    }
+
+    fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32)
+    {
+      let b_global = b_global_base + row;
+      if (b_global >= uniforms.N)
+      {
+        return;
+      }
+
+      let b_value = input_b[b_global*uniforms.K8+kidx_v+col];
+      var b_value_lower = vec4<i32>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      var b_value_upper = vec4<i32>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      tile_B[row][col][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      tile_B[row][col][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      if (col == 0)
+      {
+        // kidx_v - each kidx_v covers 8 values of k.
+        scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/4];
+      }
+    }
+
+    fn DP4AI(a:vec4<u32>, b:vec4<u32>) -> i32
+    {
+      var local_sum = dot4I8Packed(a[0], b[0]);
+      local_sum += dot4I8Packed(a[1], b[1]);
+      local_sum += dot4I8Packed(a[2], b[2]);
+      local_sum += dot4I8Packed(a[3], b[3]);
+      return local_sum;
+    }
+  )ADDNL_FN";
+
```
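The bit-twiddling in `loadSHMB` and `DP4AI` is the heart of the kernel: each u32 of `input_b` carries eight 4-bit weights along K, the low and high nibbles are widened to int8 (with the default zero point of 8 removed) and re-interleaved so they line up with the consecutive int8 values of quantized A, and `dot4I8Packed` then performs a 4-element int8 dot product per packed word. A CPU sketch of the same steps (the WGSL built-ins are emulated; the nibble-order comment is inferred from the interleave in `loadSHMB`, not stated in the source):

```cpp
// Sketch: CPU emulation of loadSHMB's nibble handling and the DP4AI helper.
#include <cstdint>

// One u32 of input_b holds eight 4-bit weights along K. Judging by the
// interleave in loadSHMB, the low nibbles are the even k offsets and the high
// nibbles the odd ones, so interleaving lower/upper restores consecutive k order.
void unpack_b_word(uint32_t b_value, int8_t out[8]) {
  for (int byte = 0; byte < 4; ++byte) {
    const uint32_t lo = (b_value >> (8 * byte)) & 0x0F;      // b_value & 0x0F0F0F0Fu
    const uint32_t hi = (b_value >> (8 * byte + 4)) & 0x0F;  // (b_value >> 4) & 0x0F0F0F0Fu
    out[2 * byte + 0] = static_cast<int8_t>(static_cast<int>(lo) - 8);  // subtract zero point 8
    out[2 * byte + 1] = static_cast<int8_t>(static_cast<int>(hi) - 8);
  }
}

// Emulates WGSL dot4I8Packed: treat each u32 as four int8 lanes, accumulate in i32.
int32_t dot4I8Packed(uint32_t a, uint32_t b) {
  int32_t sum = 0;
  for (int i = 0; i < 4; ++i) {
    const auto ai = static_cast<int8_t>((a >> (8 * i)) & 0xFFu);
    const auto bi = static_cast<int8_t>((b >> (8 * i)) & 0xFFu);
    sum += static_cast<int32_t>(ai) * static_cast<int32_t>(bi);
  }
  return sum;
}

// DP4AI: a 16-element int8 dot product built from four packed dot products.
int32_t DP4AI(const uint32_t a[4], const uint32_t b[4]) {
  return dot4I8Packed(a[0], b[0]) + dot4I8Packed(a[1], b[1]) +
         dot4I8Packed(a[2], b[2]) + dot4I8Packed(a[3], b[3]);
}
```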
```
+  shader.MainFunctionBody() << R"MAIN_FN(
+    // During the load phase we use all 256 threads to load 64 rows of A/B.
+    // For each row we load 4 vectorized elements, which are 32 elements of K.
+    let a_global_base = workgroup_id.x * tile_size;
+    let b_global_base = workgroup_id.y * tile_size;
+    let load_row = u32(local_idx/4);
+    let load_col = u32(local_idx%4);
+
+    // During the compute phase, the 64x64 tile is split into 16x16 subtiles,
+    // giving a grid of 4x4 subtiles.
+    let subtile_id = u32(local_idx / subtile_size);
+    let subtile_idx = u32(subtile_id / 4);
+    let subtile_idy = u32(subtile_id % 4);
+    let base_A = subtile_idx * 16;
+    let base_B = subtile_idy * 16;
+    // Each subtile has 16 threads assigned to it.
+    let a_idx = u32(local_idx % subtile_size);
+
+    // K's vectorization is 8 items per index; see input_a/input_b.
+    // tile_size_k_vec is the K tile size in vectorized K units (1/8).
+    for (var kidx_v:u32 = 0; kidx_v < uniforms.K8; kidx_v+=tile_size_k_vec)
+    {
+      // Populate shared memory for the workgroup.
+      loadSHMA(a_global_base, kidx_v, load_row, load_col);
+      loadSHMB(b_global_base, kidx_v, load_row, load_col);
+      workgroupBarrier();
+
+      var own_a0: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][0], tile_A[base_A + a_idx][1]);
+      var own_a1: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][2], tile_A[base_A + a_idx][3]);
+      var own_scale_a = scale_A[base_A + a_idx];
+      if (sg_size == 16)
+      {
+        var own_b0: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][0], tile_B[base_B + sg_id][1]);
+        var own_b1: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][2], tile_B[base_B + sg_id][3]);
+        var own_scale_b = scale_B[base_B + sg_id];
+        for (var col:u32 = 0; col < 16; col++)
+        {
+          var local_scale_b = subgroupShuffle(own_scale_b, col);
+          local_scale_b = local_scale_b * own_scale_a;
+          var local_sum = DP4AI(own_a0, subgroupShuffle(own_b0, col));
+          local_sum += DP4AI(own_a1, subgroupShuffle(own_b1, col));
+          lane_output[col] += (output_element_t(local_sum) * local_scale_b);
+        }
+      }
+      else
+      {
+        for (var col:u32 = 0; col < 16; col++)
+        {
+          var b0: vec4<u32> = vec4<u32>(tile_B[base_B + col][0], tile_B[base_B + col][1]);
+          var b1: vec4<u32> = vec4<u32>(tile_B[base_B + col][2], tile_B[base_B + col][3]);
+          var local_sum = DP4AI(own_a0, b0);
+          local_sum += DP4AI(own_a1, b1);
+          lane_output[col] += (output_element_t(local_sum) * own_scale_a * scale_B[base_B + col]);
+        }
+      }
+      workgroupBarrier();
+    }
+
+    let a_global = a_global_base + base_A + a_idx;
+    let b_global = b_global_base + base_B;
+    let output_idx = ((a_global) * uniforms.N + b_global)/4;
+    // This creates a shader requirement that uniforms.N % 16 == 0.
+    if (a_global < uniforms.M && b_global < uniforms.N)
+    {
+      for (var i:u32 = 0; i < 4; i++)
+      {
+        let lidx = i * 4;
+        output[output_idx+i] = vec4<output_element_t>(lane_output[lidx], lane_output[lidx+1], lane_output[lidx+2], lane_output[lidx+3]);
+      }
+    }
+  )MAIN_FN";
+
+  return Status::OK();
+}
+
 Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
   const Tensor* a = context.Input(0);
   const Tensor* b = context.Input(1);
```
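Inside the K loop above, each thread owns one row of quantized A within its 16x16 subtile and accumulates 16 consecutive output columns in registers. On the subgroup path the 16 columns of B live in the registers of the 16 lanes and are exchanged with `subgroupShuffle`; the fallback path reads them straight from shared memory. A scalar sketch of one thread's contribution for one K tile (shapes and names are illustrative; `DP4AI` is the emulation from the previous sketch):

```cpp
// Sketch of one thread's work for one 32-wide K tile in DP4AMatMulNBitsProgram.
// The thread owns 32 int8 values of A (8 packed u32) and one A scale, and it
// accumulates 16 output columns against 16 rows of B from the shared tile.
#include <cstdint>

int32_t DP4AI(const uint32_t a[4], const uint32_t b[4]);  // see the sketch above

void accumulate_k_tile(const uint32_t own_a0[4], const uint32_t own_a1[4],  // 32 int8 of A
                       float own_scale_a,                                   // scale_A for this row
                       const uint32_t tile_b0[16][4],                       // first 16 k of each B column
                       const uint32_t tile_b1[16][4],                       // second 16 k of each B column
                       const float scale_b[16],                             // scale_B per column
                       float lane_output[16]) {                             // running per-column sums
  for (int col = 0; col < 16; ++col) {
    int32_t local_sum = DP4AI(own_a0, tile_b0[col]);   // k = 0..15 of this tile
    local_sum += DP4AI(own_a1, tile_b1[col]);          // k = 16..31 of this tile
    lane_output[col] += static_cast<float>(local_sum) * own_scale_a * scale_b[col];
  }
}
```

The subgroup path computes exactly this, except `tile_b0[col]`, `tile_b1[col]`, and `scale_b[col]` come from lane `col`'s registers via `subgroupShuffle` rather than from shared memory.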
```
@@ -565,11 +781,54 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
   uint32_t components = GetMaxComponents(N);
 
   const bool has_zero_points = zero_points != nullptr;
+  const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups);
+  // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support.
+  // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226
+  const bool use_dp4a = has_subgroup && context.AdapterInfo().backendType != wgpu::BackendType::Metal;
+  if (accuracy_level_ == 4 && block_size == 32 &&
+      batch_count == 1 && components_a == 4 && K % 64 == 0 && N % 16 == 0 &&
+      !has_zero_points && use_dp4a && M >= kMinMForTileOptimization) {
+    constexpr uint32_t kVec4Components = 4;
+    constexpr uint32_t kVec2Components = 2;
+    constexpr uint32_t kU32Components = 4;
+
+    constexpr uint32_t kBlockSizeA = 128;
+    DP4AMatMulQuantizeProgram quantize_program;
+    quantize_program.SetWorkgroupSize(32);
+    quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1);
+    TensorShape a_quant_shape{1, M, K / kU32Components};
+    Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), a_quant_shape);
+    TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA});
+    Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims);
+    quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)}})
+        .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), gsl::narrow<int>(1)},
+                     {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow<int>(1)}});
+    ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));
+
+    constexpr uint32_t kTileSize = 64;
+    TensorShape reshaped_y_shape{1, M, N / kVec4Components};
+    DP4AMatMulNBitsProgram mul_program;
+    mul_program.SetWorkgroupSize(256);
+    mul_program.SetDispatchGroupSize(
+        (M + kTileSize - 1) / kTileSize,
+        (N + kTileSize - 1) / kTileSize, 1);
+    mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components)},
+                           {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
+                           {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kU32Components)},
+                           {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
+        .AddUniformVariables({{static_cast<uint32_t>(M)},
+                              {static_cast<uint32_t>(N)},
+                              {static_cast<uint32_t>(K)},
+                              {static_cast<uint32_t>(K / 8)},
+                              {static_cast<uint32_t>(K / 16)}})
+        .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(kVec4Components)});
+    return context.RunProgram(mul_program);
+  }
 
   // TODO: Support output_number > 1. Some cases are failed when output_number > 1.
   constexpr uint32_t output_number = 1;
   const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1;
-  const bool use_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups) && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32;
+  const bool use_subgroup = has_subgroup && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32;
   MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow<int>(components_b), has_zero_points, use_subgroup};
   if (M > kMinMForTileOptimization && block_size == 32) {
     components = 1;
```
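The dispatch sizes follow directly from the shader geometry: the quantizer runs one 32-thread workgroup per 128 elements of A, and the matmul runs one 256-thread workgroup per 64x64 output tile, with the `K8`/`K16` uniforms giving K in 8- and 16-element units. A small sketch of that arithmetic with illustrative (not Phi-3.5) shapes:

```cpp
// Sketch of the dispatch arithmetic for the DP4A path (illustrative sizes only).
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t M = 500, N = 4096, K = 3072;  // example GEMM shape; K % 64 == 0, N % 16 == 0
  const uint32_t kBlockSizeA = 128;            // elements of A quantized per workgroup (32 threads x vec4)
  const uint32_t kTileSize = 64;               // output tile edge handled by one 256-thread workgroup

  const uint32_t quantize_groups = M * K / kBlockSizeA;     // DP4AMatMulQuantizeProgram dispatch
  const uint32_t grid_x = (M + kTileSize - 1) / kTileSize;  // DP4AMatMulNBitsProgram dispatch, rows
  const uint32_t grid_y = (N + kTileSize - 1) / kTileSize;  // DP4AMatMulNBitsProgram dispatch, cols

  // Uniforms handed to the matmul shader: K8 indexes the u32-packed int8 data
  // (8 weights/activations per u32), K16 is K in 16-element units.
  std::printf("quantize workgroups: %u, matmul grid: %u x %u, K8=%u, K16=%u\n",
              quantize_groups, grid_x, grid_y, K / 8, K / 16);
  return 0;
}
```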

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h

Lines changed: 20 additions & 0 deletions
```
@@ -35,13 +35,32 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
   bool use_subgroup_;
 };
 
+class DP4AMatMulQuantizeProgram final : public Program<DP4AMatMulQuantizeProgram> {
+ public:
+  DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {}
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+};
+
+class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
+ public:
+  DP4AMatMulNBitsProgram() : Program{"DP4AMatMulNBits"} {}
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
+      {"M", ProgramUniformVariableDataType::Uint32},
+      {"N", ProgramUniformVariableDataType::Uint32},
+      {"K", ProgramUniformVariableDataType::Uint32},
+      {"K8", ProgramUniformVariableDataType::Uint32},
+      {"K16", ProgramUniformVariableDataType::Uint32});
+};
+
 class MatMulNBits final : public WebGpuKernel {
  public:
   MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) {
     K_ = info.GetAttr<int64_t>("K");
     N_ = info.GetAttr<int64_t>("N");
     block_size_ = info.GetAttr<int64_t>("block_size");
     int64_t bits = info.GetAttr<int64_t>("bits");
+    accuracy_level_ = info.GetAttrOrDefault<int64_t>("accuracy_level", 4);
     ORT_ENFORCE(bits == 4,
                 "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
   }
@@ -52,6 +71,7 @@ class MatMulNBits final : public WebGpuKernel {
   int64_t K_;
   int64_t N_;
   int64_t block_size_;
+  int64_t accuracy_level_;
 };
 
 } // namespace webgpu
```
