diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
index c79efee65e5c5..e6cf911e954d7 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -535,42 +535,23 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const
   shader.AddOutput("output", ShaderUsage::UseUniform);
   shader.AddOutput("scales", ShaderUsage::UseUniform);
 
-  shader.AdditionalImplementation() << R"ADDNL_FN(
-    var<workgroup> max_values : array<input_a_element_t, 4>;
-  )ADDNL_FN";
-
   shader.MainFunctionBody() << R"MAIN_FN(
-    var local_a = input_a[global_idx];
-    var max_val = subgroupMax(abs(local_a));
-    var max_temp = max(max_val.xy, max_val.zw);
-    var scale = max(max_temp[0], max_temp[1]);
-    if (local_idx % sg_size == 0) {
-      max_values[local_idx / sg_size] = scale;
-    }
-    workgroupBarrier();
-
-    if (sg_size == 8)
+    var local_a : array<vec4<input_a_element_t>, 32>;
+    var max_value:vec4<input_a_element_t> = vec4<input_a_element_t>(0);
+    for (var idx:u32=0;idx<32;idx+=1)
     {
-      scale = max(max_values[0], max_values[1]);
-      scale = max(scale, max_values[2]);
-      scale = max(scale, max_values[3]);
+      local_a[idx] = input_a[workgroup_id.x*32 + idx];
+      max_value = max(max_value, abs(local_a[idx]));
     }
-    else if (sg_size == 16)
-    {
-      scale = max(max_values[0], max_values[1]);
-    }
-    else
-    {
-      scale = max_values[0];
-    }
-
-    var norm_a = local_a/scale;
-    output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
-    if (local_idx == 0)
+    var scale = max(max_value.x, max_value.y);
+    scale = max(scale, max_value.z);
+    scale = max(scale, max_value.w);
+    for (var idx:u32=0;idx<32;idx+=1)
     {
-      // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
-      scales[workgroup_idx] = scale/127;
+      output[workgroup_id.x*32+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale));
     }
+    // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
+    scales[workgroup_id.x] = scale/127;
 )MAIN_FN";
   return Status::OK();
 }
@@ -838,7 +819,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
 
     constexpr uint32_t kBlockSizeA = 128;
     DP4AMatMulQuantizeProgram quantize_program;
-    quantize_program.SetWorkgroupSize(32);
+    quantize_program.SetWorkgroupSize(1);
    quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1);
     TensorShape a_quant_shape{1, M, K / kU32Components};
     Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), a_quant_shape);
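
For reference, the rewritten shader body amounts to the following per-block quantization, shown here as a minimal CPU sketch (not part of this change; the function name QuantizeBlockA is hypothetical, `float` stands in for the shader's input_a_element_t, and the 128-element block corresponds to kBlockSizeA, i.e. 32 vec4 loads per single-thread workgroup):

// Hypothetical CPU reference for the per-128-element block quantization
// that the rewritten DP4A shader performs.
#include <algorithm>
#include <cmath>
#include <cstdint>

void QuantizeBlockA(const float* a, int8_t* out, float& block_scale) {
  // Pass 1: find the maximum absolute value in the block
  // (the shader reduces a vec4 max_value to a scalar the same way).
  float max_abs = 0.0f;
  for (int i = 0; i < 128; ++i) {
    max_abs = std::max(max_abs, std::fabs(a[i]));
  }
  // Guard against an all-zero block; the shader leaves this unguarded.
  if (max_abs == 0.0f) {
    max_abs = 1.0f;
  }
  // Pass 2: normalize to [-1, 1] and snap to int8, approximating what
  // pack4x8snorm(vec4<f32>(local_a[idx]/scale)) does per lane.
  for (int i = 0; i < 128; ++i) {
    out[i] = static_cast<int8_t>(std::round(a[i] / max_abs * 127.0f));
  }
  // 127 is the max value of signed int8, so dequantization multiplies by
  // block_scale = max_abs / 127 (the value written to scales[workgroup_id.x]).
  block_scale = max_abs / 127.0f;
}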