DP4AMatMul perf refinements #23539

Merged: 2 commits, merged on Jan 31, 2025
Changes from all commits
onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc: 170 changes (107 additions, 63 deletions)
@@ -613,17 +613,14 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
const tile_size_k = 32;
const vec_factor = 4;
const u32_factor = 4;
const tile_size_k_vec = 4;
const tile_size_k_vec = 2;
const block_size = 32;
// Shared memory
var<workgroup> tile_A : array<array<vec2<u32>, tile_size_k_vec>, tile_size>; // 64 x 32
var<workgroup> scale_A : array<output_element_t, tile_size>; // 64 x 1
var<workgroup> tile_B : array<array<vec2<u32>, tile_size_k_vec>, tile_size>; // 64 x 32
var<workgroup> scale_B : array<output_element_t, tile_size>; // 64 x 1
// Private memory
var<private> lane_output: array<output_element_t, 16>;
var<workgroup> tile_A : array<array<vec4<u32>, tile_size>, tile_size_k_vec>; // 64 x 32
var<workgroup> scale_A : array<output_element_t, tile_size>; // 64 x 1
var<workgroup> tile_B : array<array<vec4<u32>, tile_size>, tile_size_k_vec>; // 64 x 32
var<workgroup> scale_B : array<output_element_t, tile_size>; // 64 x 1
fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
{
@@ -632,11 +629,11 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
{
return;
}
tile_A[row][col] = input_a[a_global*uniforms.K8+kidx_v+col];
tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col];
if (col == 0)
{
// kidx_v - covers 8 values of k
scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/16];
// kidx_v - covers 16 values of k
scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8];
}
}
@@ -648,36 +645,45 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
return;
}
let b_value = input_b[b_global*uniforms.K8+kidx_v+col];
var b_value_lower = vec4<i32>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<i32>(8);
var b_value_upper = vec4<i32>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
tile_B[row][col][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
tile_B[row][col][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
let b_value = input_b[b_global*uniforms.K16+kidx_v+col];
var b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
var b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
tile_B[col][row][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
tile_B[col][row][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
tile_B[col][row][2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
tile_B[col][row][3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
if (col == 0)
{
// kidx_v - each kidx_v covers 8 values of k
scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/4];
// kidx_v - each kidx_v covers 16 values of k
scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/2];
}
}
fn DP4AI(a:vec4<u32>, b:vec4<u32>) -> i32
// Scaled dot product over 8 packed u32 words per operand (a1/a2 and b1/b2), i.e. 32 signed 8-bit values each, multiplied by scale.
fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t) -> output_element_t
{
var local_sum = dot4I8Packed(a[0], b[0]);
local_sum += dot4I8Packed(a[1], b[1]);
local_sum += dot4I8Packed(a[2], b[2]);
local_sum += dot4I8Packed(a[3], b[3]);
return local_sum;
var local_sum = dot4I8Packed(a1[0], b1[0]);
local_sum += dot4I8Packed(a1[1], b1[1]);
local_sum += dot4I8Packed(a1[2], b1[2]);
local_sum += dot4I8Packed(a1[3], b1[3]);
local_sum += dot4I8Packed(a2[0], b2[0]);
local_sum += dot4I8Packed(a2[1], b2[1]);
local_sum += dot4I8Packed(a2[2], b2[2]);
local_sum += dot4I8Packed(a2[3], b2[3]);
return output_element_t(local_sum) * scale;
}
)ADDNL_FN";

shader.MainFunctionBody() << R"MAIN_FN(
// During the load phase we use all 256 threads to load 64 rows of A/B.
// For each row we load 4 vectorized elements, which are 32 elements of K.
// For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K.
let a_global_base = workgroup_id.x * tile_size;
let b_global_base = workgroup_id.y * tile_size;
let load_row = u32(local_idx/4);
let load_col = u32(local_idx%4);
let load_AorB = u32(local_idx/128);
let load_row = u32((local_idx%128)/2);
let load_col = u32(local_idx%2);
// During the compute phase, we have the 64x64 tile split into
// subtiles of 16x16. We have a grid of 4x4 subtiles.
@@ -689,42 +695,81 @@
// For each subtile we have 16 threads assigned.
let a_idx = u32(local_idx % subtile_size);
// K's vectorization is 8 items per index. See input_a/input_b.
// tile_size_k_vec is the k tile size in vectorized k units/space (1/8).
for (var kidx_v:u32 = 0; kidx_v < uniforms.K8; kidx_v+=tile_size_k_vec)
var lane_output1: vec4<output_element_t>;
var lane_output2: vec4<output_element_t>;
var lane_output3: vec4<output_element_t>;
var lane_output4: vec4<output_element_t>;
// K's vectorization is 16 items per index. See input_a/input_b.
// tile_size_k_vec is the k tile size in vectorized space (1/16). That is,
// the k tile size is 32; in vectorized space that is 32/16 = 2.
for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec)
{
// Populate shared memory for the workgroup
loadSHMA(a_global_base, kidx_v, load_row, load_col);
loadSHMB(b_global_base, kidx_v, load_row, load_col);
// Load Phase: Populate shared memory for the workgroup.
if (load_AorB == 0)
{
loadSHMA(a_global_base, kidx_v, load_row, load_col);
}
else
{
loadSHMB(b_global_base, kidx_v, load_row, load_col);
}
workgroupBarrier();
var own_a0: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][0], tile_A[base_A + a_idx][1]);
var own_a1: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][2], tile_A[base_A + a_idx][3]);
var own_scale_a = scale_A[base_A + a_idx];
// Compute phase: Perform the matmul for this 16 x 32 x 16 subtile.
// Step 1: Load from shared memory into registers across the entire subgroup.
var own_a0: vec4<u32> = tile_A[0][base_A + a_idx];
var own_a1: vec4<u32> = tile_A[1][base_A + a_idx];
var own_scale_a: output_element_t = scale_A[base_A + a_idx];
if (sg_size == 16)
{
var own_b0: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][0], tile_B[base_B + sg_id][1]);
var own_b1: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][2], tile_B[base_B + sg_id][3]);
var own_scale_b = scale_B[base_B + sg_id];
for (var col:u32 = 0; col < 16; col++)
{
var local_scale_b = subgroupShuffle(own_scale_b, col);
local_scale_b = local_scale_b * own_scale_a;
var local_sum = DP4AI(own_a0, subgroupShuffle(own_b0, col));
local_sum += DP4AI(own_a1, subgroupShuffle(own_b1, col));
lane_output[col] += (output_element_t(local_sum) * local_scale_b);
}
var own_b0: vec4<u32> = tile_B[0][base_B + sg_id];
var own_b1: vec4<u32> = tile_B[1][base_B + sg_id];
var own_scale_b: output_element_t = scale_B[base_B + sg_id];
// Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul.
lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a);
lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a);
lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a);
lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a);
lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a);
lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a);
lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a);
lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a);
lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a);
lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a);
lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a);
lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a);
lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a);
lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a);
lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a);
lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a);
}
else
{
for (var col:u32 = 0; col < 16; col++)
{
var b0: vec4<u32> = vec4<u32>(tile_B[base_B + col][0], tile_B[base_B + col][1]);
var b1: vec4<u32> = vec4<u32>(tile_B[base_B + col][2], tile_B[base_B + col][3]);
var local_sum = DP4AI(own_a0, b0);
local_sum += DP4AI(own_a1, b1);
lane_output[col] += (output_element_t(local_sum) * own_scale_a * scale_B[base_B + col]);
}
// Path for other subgroup sizes; it simply does not use subgroups at all.
// It relies on the hardware optimizing the case where all threads read the
// same shared-memory location, tile_B[..][base_B + col].
lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]);
lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]);
lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]);
lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]);
lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]);
lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]);
lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]);
lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]);
lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]);
lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]);
lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]);
lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]);
lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]);
lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]);
lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]);
lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]);
}
workgroupBarrier();
}
@@ -735,11 +780,10 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
// This creates a shader requirement that uniforms.N % 16 == 0
if (a_global < uniforms.M && b_global < uniforms.N)
{
for (var i:u32 = 0; i < 4; i++)
{
let lidx = i * 4;
output[output_idx+i] = vec4<output_element_t>(lane_output[lidx], lane_output[lidx+1] , lane_output[lidx+2], lane_output[lidx+3]);
}
output[output_idx] = lane_output1;
output[output_idx+1] = lane_output2;
output[output_idx+2] = lane_output3;
output[output_idx+3] = lane_output4;
}
)MAIN_FN";

@@ -812,9 +856,9 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
mul_program.SetDispatchGroupSize(
(M + kTileSize - 1) / kTileSize,
(N + kTileSize - 1) / kTileSize, 1);
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components)},
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)},
{&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kU32Components)},
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components * kU32Components)},
{scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
.AddUniformVariables({{static_cast<uint32_t>(M)},
{static_cast<uint32_t>(N)},