diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
index 90e6516ff45d1..c79efee65e5c5 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -613,17 +613,14 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
     const tile_size_k = 32;
     const vec_factor = 4;
     const u32_factor = 4;
-    const tile_size_k_vec = 4;
+    const tile_size_k_vec = 2;
     const block_size = 32;
 
     // Shared memory
-    var<workgroup> tile_A : array<array<vec2<u32>, tile_size_k_vec>, tile_size>;  // 64 x 32
-    var<workgroup> scale_A : array<output_element_t, tile_size>;                  // 64 x 1
-    var<workgroup> tile_B : array<array<vec2<u32>, tile_size_k_vec>, tile_size>;  // 64 x 32
-    var<workgroup> scale_B : array<output_element_t, tile_size>;                  // 64 x 1
-
-    // Private memory
-    var<private> lane_output: array<output_element_t, 16>;
+    var<workgroup> tile_A : array<array<vec4<u32>, tile_size>, tile_size_k_vec>;  // 64 x 32
+    var<workgroup> scale_A : array<output_element_t, tile_size>;                  // 64 x 1
+    var<workgroup> tile_B : array<array<vec4<u32>, tile_size>, tile_size_k_vec>;  // 64 x 32
+    var<workgroup> scale_B : array<output_element_t, tile_size>;                  // 64 x 1
 
     fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
     {
@@ -632,11 +629,11 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
       {
         return;
       }
-      tile_A[row][col] = input_a[a_global*uniforms.K8+kidx_v+col];
+      tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col];
       if (col == 0)
       {
-        // kidx_v - covers 8 values of k
-        scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/16];
+        // kidx_v - covers 16 values of k
+        scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8];
       }
     }
 
@@ -648,36 +645,45 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
         return;
       }
-      let b_value = input_b[b_global*uniforms.K8+kidx_v+col];
-      var b_value_lower = vec4<i32>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<i32>(8);
-      var b_value_upper = vec4<i32>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
-      tile_B[row][col][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
-      tile_B[row][col][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      let b_value = input_b[b_global*uniforms.K16+kidx_v+col];
+      var b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      var b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      tile_B[col][row][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      tile_B[col][row][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      tile_B[col][row][2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      tile_B[col][row][3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
       if (col == 0)
       {
-        // kidx_v - each kidx_v covers 8 values of k
-        scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/4];
+        // kidx_v - each kidx_v covers 16 values of k
+        scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/2];
       }
     }
 
-    fn DP4AI(a:vec4<u32>, b:vec4<u32>) -> i32
+    // Scaled dot product of 8 packed unsigned integers.
+    fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t) -> output_element_t
     {
-      var local_sum = dot4I8Packed(a[0], b[0]);
-      local_sum += dot4I8Packed(a[1], b[1]);
-      local_sum += dot4I8Packed(a[2], b[2]);
-      local_sum += dot4I8Packed(a[3], b[3]);
-      return local_sum;
+      var local_sum = dot4I8Packed(a1[0], b1[0]);
+      local_sum += dot4I8Packed(a1[1], b1[1]);
+      local_sum += dot4I8Packed(a1[2], b1[2]);
+      local_sum += dot4I8Packed(a1[3], b1[3]);
+      local_sum += dot4I8Packed(a2[0], b2[0]);
+      local_sum += dot4I8Packed(a2[1], b2[1]);
+      local_sum += dot4I8Packed(a2[2], b2[2]);
+      local_sum += dot4I8Packed(a2[3], b2[3]);
+      return output_element_t(local_sum) * scale;
     }
-
   )ADDNL_FN";
 
   shader.MainFunctionBody() << R"MAIN_FN(
     // During the load phase we use all 256 threads to load 64 rows of A/B.
-    // For each row we load 4 vectorized elements, which are 32 elements of K.
+    // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K.
     let a_global_base = workgroup_id.x * tile_size;
     let b_global_base = workgroup_id.y * tile_size;
-    let load_row = u32(local_idx/4);
-    let load_col = u32(local_idx%4);
+    let load_AorB = u32(local_idx/128);
+    let load_row = u32((local_idx%128)/2);
+    let load_col = u32(local_idx%2);
 
     // During the compute phase, we have the 64x64 tile split into
     // subtiles of 16x16. We have a grid of 4x4 subtiles.
@@ -689,42 +695,81 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
     // For each subtile we have 16 threads assigned.
     let a_idx = u32(local_idx % subtile_size);
 
-    // K's vectorization is 8 items per index. See input_a/input_b.
-    // tile_size_k_vec - is the k tile size in vectorized k units/space (1/8).
-    for (var kidx_v:u32 = 0; kidx_v < uniforms.K8; kidx_v+=tile_size_k_vec)
+    var lane_output1: vec4<output_element_t>;
+    var lane_output2: vec4<output_element_t>;
+    var lane_output3: vec4<output_element_t>;
+    var lane_output4: vec4<output_element_t>;
+    // K's vectorization is 16 items per index. See input_a/input_b.
+    // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is
+    // k tile size is 32. In vectorized space that is 32/16 = 2.
+    for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec)
     {
-      // Populate shared memory for the workgroup
-      loadSHMA(a_global_base, kidx_v, load_row, load_col);
-      loadSHMB(b_global_base, kidx_v, load_row, load_col);
+      // Load phase: Populate shared memory for the workgroup.
+      if (load_AorB == 0)
+      {
+        loadSHMA(a_global_base, kidx_v, load_row, load_col);
+      }
+      else
+      {
+        loadSHMB(b_global_base, kidx_v, load_row, load_col);
+      }
       workgroupBarrier();
 
-      var own_a0: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][0], tile_A[base_A + a_idx][1]);
-      var own_a1: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][2], tile_A[base_A + a_idx][3]);
-      var own_scale_a = scale_A[base_A + a_idx];
+      // Compute phase: Perform matmul for this subtile 16 x 32 x 16.
+      // Step 1: Load from shared memory into registers across the entire subgroup.
+      var own_a0: vec4<u32> = tile_A[0][base_A + a_idx];
+      var own_a1: vec4<u32> = tile_A[1][base_A + a_idx];
+      var own_scale_a: output_element_t = scale_A[base_A + a_idx];
       if (sg_size == 16)
      {
-        var own_b0: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][0], tile_B[base_B + sg_id][1]);
-        var own_b1: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][2], tile_B[base_B + sg_id][3]);
-        var own_scale_b = scale_B[base_B + sg_id];
-        for (var col:u32 = 0; col < 16; col++)
-        {
-          var local_scale_b = subgroupShuffle(own_scale_b, col);
-          local_scale_b = local_scale_b * own_scale_a;
-          var local_sum = DP4AI(own_a0, subgroupShuffle(own_b0, col));
-          local_sum += DP4AI(own_a1, subgroupShuffle(own_b1, col));
-          lane_output[col] += (output_element_t(local_sum) * local_scale_b);
-        }
+        var own_b0: vec4<u32> = tile_B[0][base_B + sg_id];
+        var own_b1: vec4<u32> = tile_B[1][base_B + sg_id];
+        var own_scale_b: output_element_t = scale_B[base_B + sg_id];
+        // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul.
+        lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a);
+        lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a);
+        lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a);
+        lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a);
+
+        lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a);
+        lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a);
+        lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a);
+        lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a);
+
+        lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a);
+        lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a);
+        lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a);
+        lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a);
+
+        lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a);
+        lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a);
+        lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a);
+        lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a);
       }
       else
       {
-        for (var col:u32 = 0; col < 16; col++)
-        {
-          var b0: vec4<u32> = vec4<u32>(tile_B[base_B + col][0], tile_B[base_B + col][1]);
-          var b1: vec4<u32> = vec4<u32>(tile_B[base_B + col][2], tile_B[base_B + col][3]);
-          var local_sum = DP4AI(own_a0, b0);
-          local_sum += DP4AI(own_a1, b1);
-          lane_output[col] += (output_element_t(local_sum) * own_scale_a * scale_B[base_B + col]);
-        }
+        // Code for other subgroup sizes, simply doesn't use subgroups at all.
+        // Relies on reads from the single location tile_B[][base_B + col] by all
+        // threads being optimized by the hardware.
+        lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]);
+        lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]);
+        lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]);
+        lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]);
+
+        lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]);
+        lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]);
+        lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]);
+        lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]);
+
+        lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]);
+        lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]);
+        lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]);
+        lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]);
+
+        lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]);
+        lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]);
+        lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]);
+        lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]);
       }
       workgroupBarrier();
     }
@@ -735,11 +780,10 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
     // This creates a shader requirement that uniforms.N % 16 == 0
     if (a_global < uniforms.M && b_global < uniforms.N)
     {
-      for (var i:u32 = 0; i < 4; i++)
-      {
-        let lidx = i * 4;
-        output[output_idx+i] = vec4<output_element_t>(lane_output[lidx], lane_output[lidx+1], lane_output[lidx+2], lane_output[lidx+3]);
-      }
+      output[output_idx] = lane_output1;
+      output[output_idx+1] = lane_output2;
+      output[output_idx+2] = lane_output3;
+      output[output_idx+3] = lane_output4;
     }
   )MAIN_FN";
 
@@ -812,9 +856,9 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
     mul_program.SetDispatchGroupSize(
         (M + kTileSize - 1) / kTileSize,
         (N + kTileSize - 1) / kTileSize, 1);
-    mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components)},
+    mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)},
                            {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
-                           {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kU32Components)},
+                           {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components * kU32Components)},
                            {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
         .AddUniformVariables({{static_cast<uint32_t>(M)},
                               {static_cast<uint32_t>(N)},
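
For reference only (not part of the patch): a minimal host-side C++ sketch of the arithmetic the new SDP8AI shader helper performs, assuming WGSL's dot4I8Packed semantics (each u32 is reinterpreted as four signed 8-bit lanes and the lane-wise products are summed). The names Dot4I8PackedRef and SDP8AIRef are hypothetical reference helpers, not onnxruntime or WebGPU APIs.

// Illustrative sketch: scalar model of the packed int8 math used by the shader.
#include <array>
#include <cstdint>
#include <cstdio>

// Emulates WGSL dot4I8Packed: dot product of two u32s viewed as 4 x i8 lanes.
static int32_t Dot4I8PackedRef(uint32_t a, uint32_t b) {
  int32_t sum = 0;
  for (int i = 0; i < 4; ++i) {
    int8_t ai = static_cast<int8_t>((a >> (8 * i)) & 0xFF);
    int8_t bi = static_cast<int8_t>((b >> (8 * i)) & 0xFF);
    sum += static_cast<int32_t>(ai) * static_cast<int32_t>(bi);
  }
  return sum;
}

// Emulates the shader's SDP8AI: accumulate 8 dot4I8Packed results (32 int8
// multiply-adds, i.e. one 32-wide K tile for a single A row / B column pair)
// and apply the combined dequantization scale once at the end.
static float SDP8AIRef(const std::array<uint32_t, 4>& a1, const std::array<uint32_t, 4>& b1,
                       const std::array<uint32_t, 4>& a2, const std::array<uint32_t, 4>& b2,
                       float scale) {
  int32_t local_sum = 0;
  for (int i = 0; i < 4; ++i) local_sum += Dot4I8PackedRef(a1[i], b1[i]);
  for (int i = 0; i < 4; ++i) local_sum += Dot4I8PackedRef(a2[i], b2[i]);
  return static_cast<float>(local_sum) * scale;
}

int main() {
  // 0x01010101 packs four int8 values of 1; 0x02020202 packs four values of 2.
  std::array<uint32_t, 4> ones{0x01010101u, 0x01010101u, 0x01010101u, 0x01010101u};
  std::array<uint32_t, 4> twos{0x02020202u, 0x02020202u, 0x02020202u, 0x02020202u};
  // 32 products of 1*2 = 64, scaled by 0.5 -> prints 32.
  std::printf("%f\n", SDP8AIRef(ones, twos, ones, twos, 0.5f));
  return 0;
}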