WIP: DP4AMatMulNBits accuracy level 4 matmul for WebGPU EP (#23365)
### Description

This change implements accuracy level 4 (quantize A to int8) matmul for the
WebGPU EP. The matmul kernel uses DP4A for the matrix multiplication; to keep
the DP4A units fed, a co-operative matrix multiplication is implemented that
preloads the rows/cols into local variables before the multiplication
operation.
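
Below is a minimal CPU-side sketch of the accuracy-level-4 math this path relies on (the helper names are illustrative, not part of the change): each 128-element block of A is quantized to int8 with a single scale, and the matmul accumulates integer dot products against the signed int8 B values, rescaling by the product of the block scales.

```
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Quantize one block of A to int8 and return the per-block scale,
// mirroring scale = max|a| / 127 and q = round(a / scale).
float QuantizeBlockA(const std::vector<float>& a, std::vector<int8_t>& a_q) {
  float max_abs = 0.0f;
  for (float v : a) max_abs = std::max(max_abs, std::fabs(v));
  const float scale = max_abs / 127.0f;
  a_q.resize(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    a_q[i] = static_cast<int8_t>(std::lround(a[i] / (scale > 0.0f ? scale : 1.0f)));
  }
  return scale;
}

// Integer dot product over one block, rescaled by the A and B block scales.
// CPU analogue of accumulating dot4I8Packed results and multiplying by
// scale_A * scale_B afterwards.
float BlockDot(const std::vector<int8_t>& a_q, const std::vector<int8_t>& b_q,
               float scale_a, float scale_b) {
  int32_t acc = 0;
  for (size_t i = 0; i < a_q.size(); ++i) {
    acc += static_cast<int32_t>(a_q[i]) * static_cast<int32_t>(b_q[i]);
  }
  return static_cast<float>(acc) * scale_a * scale_b;
}
```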

Credits to @qjia7 for help with the quantizer shader.

Performance metrics on Intel ADL/TGL GPUs:

```
PS C:\onnxruntime> C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500
Batch size: 1, prompt tokens: 501, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       2.76762e+06
        **avg (tokens/s): 181.022**   <<< Prefill speed
        p50 (us):       2.74843e+06
        stddev (us):    41756.4
        n:              5 * 501 token(s)
Token generation:
        avg (us):       81500.7
        avg (tokens/s): 12.2698
        p50 (us):       81104.1
        stddev (us):    2961.31
        n:              635 * 1 token(s)
Token sampling:
        avg (us):       13.1836
        avg (tokens/s): 75851.9
        p50 (us):       12
        stddev (us):    6.47085
        n:              640 * 1 token(s)
E2E generation (entire generation loop):
        avg (ms):       13120
        p50 (ms):       13081.6
        stddev (ms):    114.689
        n:              5
Peak working set size (bytes): 5467533312
WebGPU device lost (2): Device was destroyed.

```
This kernel is 2.10x faster than its F16 counterpart for a 500-token
prefill. The previous prefill record was 86 tokens/s.

To support devices with subgroup sizes of 8 or 32, a no-subgroup version of
the same shader is included. It is slower than the subgroup version on ADL.

```
PS C:\onnxruntime> C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500 
Batch size: 1, prompt tokens: 501, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       4.11989e+06
        avg (tokens/s): 121.605
        p50 (us):       4.11847e+06
        stddev (us):    2147.48
        n:              5 * 501 token(s)
Token generation:
        avg (us):       81174.9
        avg (tokens/s): 12.3191
        p50 (us):       81301.1
        stddev (us):    2177.2
        n:              635 * 1 token(s)
Token sampling:
        avg (us):       14.7998
        avg (tokens/s): 67568.3
        p50 (us):       12.3
        stddev (us):    11.5481
        n:              640 * 1 token(s)
E2E generation (entire generation loop):
        avg (ms):       14431.1
        p50 (ms):       14433.8
        stddev (ms):    5.02473
        n:              5
Peak working set size (bytes): 5466480640
WebGPU device lost (2): Device was destroyed.
```
sushraja-msft authored Jan 21, 2025
1 parent c7f764c commit 58c29d3
Showing 2 changed files with 280 additions and 1 deletion.
261 changes: 260 additions & 1 deletion onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -530,6 +530,222 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
return Status::OK();
}

Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AddOutput("output", ShaderUsage::UseUniform);
shader.AddOutput("scales", ShaderUsage::UseUniform);

shader.AdditionalImplementation() << R"ADDNL_FN(
var<workgroup> max_values : array<input_a_element_t, 4>;
)ADDNL_FN";

shader.MainFunctionBody() << R"MAIN_FN(
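// Each thread reads one vec4 of A (4 elements), so a 32-thread workgroup covers
// one 128-element block of K and emits a single scale for it.
// subgroupMax plus the max_values scratch array reduce max(|A|) across the workgroup.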
var local_a = input_a[global_idx];
var max_val = subgroupMax(abs(local_a));
var max_temp = max(max_val.xy, max_val.zw);
var scale = max(max_temp[0], max_temp[1]);
if (local_idx % sg_size == 0) {
max_values[local_idx / sg_size] = scale;
}
workgroupBarrier();
if (sg_size == 8)
{
scale = max(max_values[0], max_values[1]);
scale = max(scale, max_values[2]);
scale = max(scale, max_values[3]);
}
else if (sg_size == 16)
{
scale = max(max_values[0], max_values[1]);
}
else
{
scale = max_values[0];
}
var norm_a = local_a/scale;
output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
if (local_idx == 0)
{
// pack4x8snorm maps 1.0f to 127, the max value of the signed int8 range [-127,127].
scales[workgroup_idx] = scale/127;
}
)MAIN_FN";
return Status::OK();
}

Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddInput("scales_a", ShaderUsage::UseUniform);
shader.AddInput("input_b", ShaderUsage::UseUniform);
shader.AddInput("scales_b", ShaderUsage::UseUniform);
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);

// This shader implements co-operative matrix multiply. The key idea here is to
// assume there is a primitive for medium size matrix multiply a subgroup can perform,
// using all its lanes and pooling all its registers to keep the values in registers.
//
// The entire workgroup, which has N subgroups, first loads a tile into shared memory.
// Then each subgroup loads a subtile from shared memory into registers and uses
// the medium size matrix multiply primitive to perform the math.
// The tile/subtile sizes are chosen to conform to the resource limits
// of an Alderlake/Tigerlake GPU. A tile is 64x64 and the workgroup is 256 threads,
// therefore there are 16 subgroups with 16 lanes in each subgroup.
// K, the hidden dimension, is paged in from RAM in K tiles of tile_size_k values.
// Keeping the output tile in shared memory as well would push the requirement past
// the 16KB WebGPU limit, so the output is accumulated in registers instead.
//
// Each subgroup performs a 16 x tile_size_k x 16 multiply per K tile, implemented with
// subgroup shuffle as a placeholder for the day the medium matrix multiply primitive
// becomes available in WGSL. The register requirement is ~2KB per subgroup; on
// Alderlake/Tigerlake a subgroup has 8KB of register space, pooling the
// 512B of registers from each lane.
//
// The medium size matmul is implemented using dot4I8Packed, so the inputs for
// this shader require A to be int8 quantized with a block size of 128 (see
// DP4AMatMulQuantizeProgram). B is the regular MatMulNBits input with block size 32.

shader.AdditionalImplementation() << R"ADDNL_FN(
const tile_size = 64;
const subtile_size = 16;
const tile_size_k = 32;
const vec_factor = 4;
const u32_factor = 4;
const tile_size_k_vec = 4;
const block_size = 32;
// Shared memory
var<workgroup> tile_A : array<array<vec2<u32>, tile_size_k_vec>, tile_size>; // 64 x 32
var<workgroup> scale_A : array<output_element_t, tile_size>; // 64 x 1
var<workgroup> tile_B : array<array<vec2<u32>, tile_size_k_vec>, tile_size>; // 64 x 32
var<workgroup> scale_B : array<output_element_t, tile_size>; // 64 x 1
// Private memory
var<private> lane_output: array<output_element_t, 16>;
fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
{
let a_global = a_global_base + row;
if (a_global >= uniforms.M)
{
return;
}
tile_A[row][col] = input_a[a_global*uniforms.K8+kidx_v+col];
if (col == 0)
{
// kidx_v - covers 8 values of k
scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/16];
}
}
fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32)
{
let b_global = b_global_base + row;
if (b_global >= uniforms.N)
{
return;
}
let b_value = input_b[b_global*uniforms.K8+kidx_v+col];
var b_value_lower = vec4<i32>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<i32>(8);
var b_value_upper = vec4<i32>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
tile_B[row][col][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
tile_B[row][col][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
if (col == 0)
{
// kidx_v is in units of 8 K values; one B scale covers 32 K values, hence kidx_v/4.
scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/4];
}
}
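// DP4AI computes a 16-element int8 dot product: four dot4I8Packed calls over the
// packed vec4<u32> operands.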
fn DP4AI(a:vec4<u32>, b:vec4<u32>) -> i32
{
var local_sum = dot4I8Packed(a[0], b[0]);
local_sum += dot4I8Packed(a[1], b[1]);
local_sum += dot4I8Packed(a[2], b[2]);
local_sum += dot4I8Packed(a[3], b[3]);
return local_sum;
}
)ADDNL_FN";

shader.MainFunctionBody() << R"MAIN_FN(
// During the load phase we use all 256 threads to load 64 rows of A/B.
// For each row we load 4 vectorized elements, which are 32 elements of K.
let a_global_base = workgroup_id.x * tile_size;
let b_global_base = workgroup_id.y * tile_size;
let load_row = u32(local_idx/4);
let load_col = u32(local_idx%4);
// During the compute phase, we have the 64x64 tile split into
// subtiles of 16x16. We have a grid of 4x4 subtiles.
let subtile_id = u32(local_idx / subtile_size);
let subtile_idx = u32(subtile_id / 4);
let subtile_idy = u32(subtile_id % 4);
let base_A = subtile_idx * 16;
let base_B = subtile_idy * 16;
// For each subtile we have 16 threads assigned.
let a_idx = u32(local_idx % subtile_size);
// K's vectorization is 8 items per index. See input_a/input_b.
// tile_size_k_vec is the K tile size in vectorized K units (i.e. 1/8 of the K elements).
for (var kidx_v:u32 = 0; kidx_v < uniforms.K8; kidx_v+=tile_size_k_vec)
{
// Populate shared memory for the workgroup
loadSHMA(a_global_base, kidx_v, load_row, load_col);
loadSHMB(b_global_base, kidx_v, load_row, load_col);
workgroupBarrier();
var own_a0: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][0], tile_A[base_A + a_idx][1]);
var own_a1: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][2], tile_A[base_A + a_idx][3]);
var own_scale_a = scale_A[base_A + a_idx];
if (sg_size == 16)
{
var own_b0: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][0], tile_B[base_B + sg_id][1]);
var own_b1: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][2], tile_B[base_B + sg_id][3]);
var own_scale_b = scale_B[base_B + sg_id];
for (var col:u32 = 0; col < 16; col++)
{
var local_scale_b = subgroupShuffle(own_scale_b, col);
local_scale_b = local_scale_b * own_scale_a;
var local_sum = DP4AI(own_a0, subgroupShuffle(own_b0, col));
local_sum += DP4AI(own_a1, subgroupShuffle(own_b1, col));
lane_output[col] += (output_element_t(local_sum) * local_scale_b);
}
}
else
{
for (var col:u32 = 0; col < 16; col++)
{
var b0: vec4<u32> = vec4<u32>(tile_B[base_B + col][0], tile_B[base_B + col][1]);
var b1: vec4<u32> = vec4<u32>(tile_B[base_B + col][2], tile_B[base_B + col][3]);
var local_sum = DP4AI(own_a0, b0);
local_sum += DP4AI(own_a1, b1);
lane_output[col] += (output_element_t(local_sum) * own_scale_a * scale_B[base_B + col]);
}
}
workgroupBarrier();
}
let a_global = a_global_base + base_A + a_idx;
let b_global = b_global_base + base_B;
let output_idx = ((a_global) * uniforms.N + b_global)/4;
// This creates a shader requirement that uniforms.N % 16 == 0
if (a_global < uniforms.M && b_global < uniforms.N)
{
for (var i:u32 = 0; i < 4; i++)
{
let lidx = i * 4;
output[output_idx+i] = vec4<output_element_t>(lane_output[lidx], lane_output[lidx+1] , lane_output[lidx+2], lane_output[lidx+3]);
}
}
)MAIN_FN";

return Status::OK();
}

Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const Tensor* a = context.Input(0);
const Tensor* b = context.Input(1);
@@ -565,11 +781,54 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
uint32_t components = GetMaxComponents(N);

const bool has_zero_points = zero_points != nullptr;
const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups);
// macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support.
// https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226
const bool use_dp4a = has_subgroup && context.AdapterInfo().backendType != wgpu::BackendType::Metal;
if (accuracy_level_ == 4 && block_size == 32 &&
batch_count == 1 && components_a == 4 && K % 64 == 0 && N % 16 == 0 &&
!has_zero_points && use_dp4a && M >= kMinMForTileOptimization) {
constexpr uint32_t kVec4Components = 4;
constexpr uint32_t kVec2Components = 2;
constexpr uint32_t kU32Components = 4;

constexpr uint32_t kBlockSizeA = 128;
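// Step 1: quantize A to int8, producing one scale per kBlockSizeA (128) elements of K.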
DP4AMatMulQuantizeProgram quantize_program;
quantize_program.SetWorkgroupSize(32);
quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1);
TensorShape a_quant_shape{1, M, K / kU32Components};
Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), a_quant_shape);
TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA});
Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims);
quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)}})
.AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), gsl::narrow<int>(1)},
{&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow<int>(1)}});
ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));

constexpr uint32_t kTileSize = 64;
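// Step 2: run the DP4A matmul over kTileSize x kTileSize (64x64) output tiles,
// one 256-thread workgroup per tile.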
TensorShape reshaped_y_shape{1, M, N / kVec4Components};
DP4AMatMulNBitsProgram mul_program;
mul_program.SetWorkgroupSize(256);
mul_program.SetDispatchGroupSize(
(M + kTileSize - 1) / kTileSize,
(N + kTileSize - 1) / kTileSize, 1);
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components)},
{&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kU32Components)},
{scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
.AddUniformVariables({{static_cast<uint32_t>(M)},
{static_cast<uint32_t>(N)},
{static_cast<uint32_t>(K)},
{static_cast<uint32_t>(K / 8)},
{static_cast<uint32_t>(K / 16)}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(kVec4Components)});
return context.RunProgram(mul_program);
}

// TODO: Support output_number > 1. Some cases are failed when output_number > 1.
constexpr uint32_t output_number = 1;
const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1;
const bool use_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups) && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32;
const bool use_subgroup = has_subgroup && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32;
MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow<int>(components_b), has_zero_points, use_subgroup};
if (M > kMinMForTileOptimization && block_size == 32) {
components = 1;
20 changes: 20 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
@@ -35,13 +35,32 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
bool use_subgroup_;
};

class DP4AMatMulQuantizeProgram final : public Program<DP4AMatMulQuantizeProgram> {
public:
DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
};

class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
public:
DP4AMatMulNBitsProgram() : Program{"DP4AMatMulNBits"} {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"M", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"K8", ProgramUniformVariableDataType::Uint32},
{"K16", ProgramUniformVariableDataType::Uint32});
};

class MatMulNBits final : public WebGpuKernel {
public:
MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) {
K_ = info.GetAttr<int64_t>("K");
N_ = info.GetAttr<int64_t>("N");
block_size_ = info.GetAttr<int64_t>("block_size");
int64_t bits = info.GetAttr<int64_t>("bits");
accuracy_level_ = info.GetAttrOrDefault<int64_t>("accuracy_level", 4);
ORT_ENFORCE(bits == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
}
@@ -52,6 +71,7 @@ class MatMulNBits final : public WebGpuKernel {
int64_t K_;
int64_t N_;
int64_t block_size_;
int64_t accuracy_level_;
};

} // namespace webgpu
