DP4AMatMul perf refinements (microsoft#23539)
In this change:

1. Vectorization of K is increased to vec4<u32>, so each load now covers 16 values of K instead of 8.
2. Tile_A and Tile_B are stored transposed in shared memory, which improves memory locality for our access pattern (see the indexing sketch after this list).
3. Lane output is switched to individual vectors and its accumulation loop is unrolled; this fixes the problem where lane_output was not kept in registers before.

Perf improvements are not very consistent with this change. On a Tiger Lake GPU with driver 32.0.101.6460 (the latest Intel driver):
```
Baseline

model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web\ -l 1000
Batch size: 1, prompt tokens: 1001, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       7.36557e+06                         <<<<
        avg (tokens/s): 135.903
        p50 (us):       7.35498e+06
        stddev (us):    27599
        n:              5 * 1001 token(s)

With Change

model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web\ -l 1000
Batch size: 1, prompt tokens: 1001, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       6.52302e+06                           <<<<
        avg (tokens/s): 153.457
        p50 (us):       6.52224e+06
        stddev (us):    10407.3
        n:              5 * 1001 token(s)
```
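(For reference, the reported tokens/s is just prompt tokens over average time to first token: 1001 / 7.36557 s ≈ 135.9 before and 1001 / 6.52302 s ≈ 153.5 after, i.e. roughly a 13% improvement in prompt processing throughput on this machine.)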

However, comparing before and after profiles in Intel GPA, one can clearly see straight runs of ALU work that are no longer interspersed with writebacks to the local memory that previously held lane_output. A sketch of what each accumulator update computes follows the screenshot below.


![image](https://github.com/user-attachments/assets/e01d3474-8406-4a61-b352-2ecbf0855a7f)
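For reference, below is a host-side C++ emulation (an approximation of the WGSL built-ins, not part of the change) of what each unrolled lane_output component accumulates: SDP8AI performs eight dot4I8Packed products, i.e. 32 signed 8-bit multiply-accumulates covering a 32-element slice of K, scaled once by the combined A/B scale. Because each accumulator component now has a fixed name instead of a dynamically indexed array slot, the compiler is free to keep it in a register, which is the behavior the profile shows.

```
#include <cstdint>
#include <cstdio>

// Scalar emulation of WGSL dot4I8Packed: each u32 packs four signed 8-bit values.
int32_t Dot4I8Packed(uint32_t a, uint32_t b) {
  int32_t sum = 0;
  for (int i = 0; i < 4; ++i) {
    int8_t av = static_cast<int8_t>((a >> (8 * i)) & 0xFF);
    int8_t bv = static_cast<int8_t>((b >> (8 * i)) & 0xFF);
    sum += av * bv;
  }
  return sum;
}

// Emulates the shader's SDP8AI: 8 packed dot products (32 int8 MACs), scaled once.
// a1/b1 and a2/b2 stand in for the two vec4<u32> halves of a 32-wide K slice.
float SDP8AI(const uint32_t a1[4], const uint32_t b1[4],
             const uint32_t a2[4], const uint32_t b2[4], float scale) {
  int32_t local_sum = 0;
  for (int i = 0; i < 4; ++i) local_sum += Dot4I8Packed(a1[i], b1[i]);
  for (int i = 0; i < 4; ++i) local_sum += Dot4I8Packed(a2[i], b2[i]);
  return static_cast<float>(local_sum) * scale;
}

int main() {
  // 0x01010101 packs four 1s, 0x02020202 packs four 2s: each dot4 is 8,
  // eight packed words give 64, and a scale of 0.5 yields 32.
  uint32_t ones[4] = {0x01010101u, 0x01010101u, 0x01010101u, 0x01010101u};
  uint32_t twos[4] = {0x02020202u, 0x02020202u, 0x02020202u, 0x02020202u};
  std::printf("SDP8AI = %.1f (expected 32.0)\n", SDP8AI(ones, twos, ones, twos, 0.5f));
  return 0;
}
```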
sushraja-msft authored and jatinwadhwa921 committed Feb 5, 2025
1 parent c65f8ca commit 82cc2c7
Showing 1 changed file with 107 additions and 63 deletions: onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -613,17 +613,14 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
const tile_size_k = 32;
const vec_factor = 4;
const u32_factor = 4;
-const tile_size_k_vec = 4;
+const tile_size_k_vec = 2;
const block_size = 32;
// Shared memory
-var<workgroup> tile_A : array<array<vec2<u32>, tile_size_k_vec>, tile_size>; // 64 x 32
-var<workgroup> scale_A : array<output_element_t, tile_size>; // 64 x 1
-var<workgroup> tile_B : array<array<vec2<u32>, tile_size_k_vec>, tile_size>; // 64 x 32
-var<workgroup> scale_B : array<output_element_t, tile_size>; // 64 x 1
-// Private memory
-var<private> lane_output: array<output_element_t, 16>;
+var<workgroup> tile_A : array<array<vec4<u32>, tile_size>, tile_size_k_vec>; // 64 x 32
+var<workgroup> scale_A : array<output_element_t, tile_size>; // 64 x 1
+var<workgroup> tile_B : array<array<vec4<u32>, tile_size>, tile_size_k_vec>; // 64 x 32
+var<workgroup> scale_B : array<output_element_t, tile_size>; // 64 x 1
fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
{
@@ -632,11 +629,11 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
{
return;
}
-tile_A[row][col] = input_a[a_global*uniforms.K8+kidx_v+col];
+tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col];
if (col == 0)
{
-// kidx_v - covers 8 values of k
-scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/16];
+// kidx_v - covers 16 values of k
+scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8];
}
}
@@ -648,36 +645,45 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
return;
}
-let b_value = input_b[b_global*uniforms.K8+kidx_v+col];
-var b_value_lower = vec4<i32>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<i32>(8);
-var b_value_upper = vec4<i32>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
-tile_B[row][col][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
-tile_B[row][col][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+let b_value = input_b[b_global*uniforms.K16+kidx_v+col];
+var b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+var b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+tile_B[col][row][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+tile_B[col][row][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+tile_B[col][row][2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+tile_B[col][row][3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
if (col == 0)
{
-// kidx_v - each kidx_v covers 8 values of k
-scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/4];
+// kidx_v - each kidx_v covers 16 values of k
+scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/2];
}
}
-fn DP4AI(a:vec4<u32>, b:vec4<u32>) -> i32
+// Scaled dot product of 8 packed unsigned integers.
+fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t) -> output_element_t
{
-var local_sum = dot4I8Packed(a[0], b[0]);
-local_sum += dot4I8Packed(a[1], b[1]);
-local_sum += dot4I8Packed(a[2], b[2]);
-local_sum += dot4I8Packed(a[3], b[3]);
-return local_sum;
+var local_sum = dot4I8Packed(a1[0], b1[0]);
+local_sum += dot4I8Packed(a1[1], b1[1]);
+local_sum += dot4I8Packed(a1[2], b1[2]);
+local_sum += dot4I8Packed(a1[3], b1[3]);
+local_sum += dot4I8Packed(a2[0], b2[0]);
+local_sum += dot4I8Packed(a2[1], b2[1]);
+local_sum += dot4I8Packed(a2[2], b2[2]);
+local_sum += dot4I8Packed(a2[3], b2[3]);
+return output_element_t(local_sum) * scale;
}
)ADDNL_FN";

shader.MainFunctionBody() << R"MAIN_FN(
// During the load phase we use all 256 threads to load 64 rows of A/B.
-// For each row we load 4 vectorized elements, which are 32 elements of K.
+// For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K.
let a_global_base = workgroup_id.x * tile_size;
let b_global_base = workgroup_id.y * tile_size;
-let load_row = u32(local_idx/4);
-let load_col = u32(local_idx%4);
+let load_AorB = u32(local_idx/128);
+let load_row = u32((local_idx%128)/2);
+let load_col = u32(local_idx%2);
// During the compute phase, we have the 64x64 tile split into
// subtiles of 16x16. We have a grid of 4x4 subtiles.
@@ -689,42 +695,81 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
// For each subtile we have 16 threads assigned.
let a_idx = u32(local_idx % subtile_size);
-// K's vectorization is 8 items per index. See input_a/input_b.
-// tile_size_k_vec - is the k tile size in vectorized k units/space (1/8).
-for (var kidx_v:u32 = 0; kidx_v < uniforms.K8; kidx_v+=tile_size_k_vec)
+var lane_output1: vec4<output_element_t>;
+var lane_output2: vec4<output_element_t>;
+var lane_output3: vec4<output_element_t>;
+var lane_output4: vec4<output_element_t>;
+// K's vectorization is 16 items per index. See input_a/input_b.
+// tile_size_k_vec - is the k tile size in vectorized space (1/16). That is
+// k tile size is 32. In vectorized space that is 32/16 = 2.
+for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec)
{
-// Populate shared memory for the workgroup
-loadSHMA(a_global_base, kidx_v, load_row, load_col);
-loadSHMB(b_global_base, kidx_v, load_row, load_col);
+// Load Phase: Populate shared memory for the workgroup.
+if (load_AorB == 0)
+{
+loadSHMA(a_global_base, kidx_v, load_row, load_col);
+}
+else
+{
+loadSHMB(b_global_base, kidx_v, load_row, load_col);
+}
workgroupBarrier();
-var own_a0: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][0], tile_A[base_A + a_idx][1]);
-var own_a1: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][2], tile_A[base_A + a_idx][3]);
-var own_scale_a = scale_A[base_A + a_idx];
+// Compute phase: Perform matmul for this subtile 16 x 32 x 16.
+// Step 1: Load from shared memory into registers across entire subgroup.
+var own_a0: vec4<u32> = tile_A[0][base_A + a_idx];
+var own_a1: vec4<u32> = tile_A[1][base_A + a_idx];
+var own_scale_a: output_element_t = scale_A[base_A + a_idx];
if (sg_size == 16)
{
-var own_b0: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][0], tile_B[base_B + sg_id][1]);
-var own_b1: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][2], tile_B[base_B + sg_id][3]);
-var own_scale_b = scale_B[base_B + sg_id];
-for (var col:u32 = 0; col < 16; col++)
-{
-var local_scale_b = subgroupShuffle(own_scale_b, col);
-local_scale_b = local_scale_b * own_scale_a;
-var local_sum = DP4AI(own_a0, subgroupShuffle(own_b0, col));
-local_sum += DP4AI(own_a1, subgroupShuffle(own_b1, col));
-lane_output[col] += (output_element_t(local_sum) * local_scale_b);
-}
+var own_b0: vec4<u32> = tile_B[0][base_B + sg_id];
+var own_b1: vec4<u32> = tile_B[1][base_B + sg_id];
+var own_scale_b: output_element_t = scale_B[base_B + sg_id];
+// Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul.
+lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a);
+lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a);
+lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a);
+lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a);
+lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a);
+lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a);
+lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a);
+lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a);
+lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a);
+lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a);
+lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a);
+lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a);
+lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a);
+lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a);
+lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a);
+lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a);
}
else
{
-for (var col:u32 = 0; col < 16; col++)
-{
-var b0: vec4<u32> = vec4<u32>(tile_B[base_B + col][0], tile_B[base_B + col][1]);
-var b1: vec4<u32> = vec4<u32>(tile_B[base_B + col][2], tile_B[base_B + col][3]);
-var local_sum = DP4AI(own_a0, b0);
-local_sum += DP4AI(own_a1, b1);
-lane_output[col] += (output_element_t(local_sum) * own_scale_a * scale_B[base_B + col]);
-}
+// Code for other subgroup sizes, simply doesn't use subgroups at all.
+// Relies on reads from single location tile_B[][base_B + col] by all
+// being optimized by the hardware.
+lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]);
+lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]);
+lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]);
+lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]);
+lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]);
+lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]);
+lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]);
+lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]);
+lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]);
+lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]);
+lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]);
+lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]);
+lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]);
+lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]);
+lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]);
+lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]);
}
workgroupBarrier();
}
@@ -735,11 +780,10 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
// This creates a shader requirement that uniforms.N % 16 == 0
if (a_global < uniforms.M && b_global < uniforms.N)
{
-for (var i:u32 = 0; i < 4; i++)
-{
-let lidx = i * 4;
-output[output_idx+i] = vec4<output_element_t>(lane_output[lidx], lane_output[lidx+1] , lane_output[lidx+2], lane_output[lidx+3]);
-}
+output[output_idx] = lane_output1;
+output[output_idx+1] = lane_output2;
+output[output_idx+2] = lane_output3;
+output[output_idx+3] = lane_output4;
}
)MAIN_FN";

@@ -812,9 +856,9 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
mul_program.SetDispatchGroupSize(
(M + kTileSize - 1) / kTileSize,
(N + kTileSize - 1) / kTileSize, 1);
-mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components)},
+mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)},
{&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
-{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kU32Components)},
+{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components * kU32Components)},
{scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
.AddUniformVariables({{static_cast<uint32_t>(M)},
{static_cast<uint32_t>(N)},
