Skip to content

Commit

Permalink
Switching A to a block size of 32 works
Browse files Browse the repository at this point in the history
  • Loading branch information
sushraja-msft committed Feb 11, 2025
1 parent 1decc48 commit 654ce0c
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -536,22 +536,22 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const
shader.AddOutput("scales", ShaderUsage::UseUniform);

shader.MainFunctionBody() << R"MAIN_FN(
var local_a : array<vec4<input_a_element_t>, 32>;
var local_a : array<vec4<input_a_element_t>, 8>;
var max_value:vec4<input_a_element_t> = vec4<input_a_element_t>(0);
for (var idx:u32=0;idx<32;idx+=1)
for (var idx:u32=0;idx<8;idx+=1)
{
local_a[idx] = input_a[workgroup_id.x*32 + idx];
local_a[idx] = input_a[global_id.x*8+idx];
max_value = max(max_value, abs(local_a[idx]));
}
var scale = max(max_value.x, max_value.y);
scale = max(scale, max_value.z);
scale = max(scale, max_value.w);
for (var idx:u32=0;idx<32;idx+=1)
for (var idx:u32=0;idx<8;idx+=1)
{
output[workgroup_id.x*32+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale));
output[global_id.x*8+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale));
}
// 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
scales[workgroup_id.x] = scale/127;
scales[global_id.x] = scale/127;
)MAIN_FN";
return Status::OK();
}
Expand Down Expand Up @@ -614,7 +614,7 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
if (col == 0)
{
// kidx_v - covers 16 values of k
scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8];
scale_A[row] = scales_a[a_global*(uniforms.K/block_size) + kidx_v/2];
}
}
Expand Down Expand Up @@ -817,7 +817,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
constexpr uint32_t kVec2Components = 2;
constexpr uint32_t kU32Components = 4;

constexpr uint32_t kBlockSizeA = 128;
constexpr uint32_t kBlockSizeA = 32;
DP4AMatMulQuantizeProgram quantize_program;
quantize_program.SetWorkgroupSize(1);
quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1);
Expand Down

0 comments on commit 654ce0c

Please sign in to comment.