
Commit c39c9e3

sushraja-msft authored and ashrit-ms committed
WIP: DP4AMatMul fix matmul for subgroup size 64 GPUs (#23637)
### Description

This change moves away from using subgroup ops for quantization, because on AMD GPUs the subgroup size is 64 and our quantization function does not handle that case, resulting in garbage output. Implementing subgroup-size-64 quantization would require changing the workgroup size, after which supporting subgroup size 128 becomes a challenge.

With the new implementation, perf on Intel ADL remains about the same: 4.36 s for 1000K prefill.

Tests for this change are present here:
https://github.com/microsoft/onnxruntime/blob/e66650350b85cb5e3a408f6576fe6a7f4f4ddebc/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
However, to trigger the current issue they must be run on a GPU with subgroup size 64.
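For context, the new quantization scheme reads roughly as follows in isolation — a minimal standalone WGSL sketch, assuming an f16 element type for A and illustrative binding declarations (the real shader is assembled by ShaderHelper, so the names and bindings here are not the exact generated code):

```wgsl
enable f16;

// Illustrative bindings; the generated shader declares these via ShaderHelper.
@group(0) @binding(0) var<storage, read> input_a : array<vec4<f16>>;
@group(0) @binding(1) var<storage, read_write> output : array<u32>;
@group(0) @binding(2) var<storage, read_write> scales : array<f16>;

// One single-invocation workgroup quantizes one 128-element block of A
// (32 vec4 entries), so nothing depends on the subgroup size.
@compute @workgroup_size(1)
fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>) {
  var local_a : array<vec4<f16>, 32>;
  var max_value : vec4<f16> = vec4<f16>(0);

  // Load the block once and track the component-wise absolute maximum.
  for (var idx : u32 = 0; idx < 32; idx += 1) {
    local_a[idx] = input_a[workgroup_id.x * 32 + idx];
    max_value = max(max_value, abs(local_a[idx]));
  }

  // Reduce the vec4 maximum to a single per-block scale.
  var scale = max(max(max_value.x, max_value.y), max(max_value.z, max_value.w));

  // Normalize to [-1, 1] and pack four snorm8 values per u32.
  for (var idx : u32 = 0; idx < 32; idx += 1) {
    output[workgroup_id.x * 32 + idx] = pack4x8snorm(vec4<f32>(local_a[idx] / scale));
  }

  // 127 is the max value of signed int8 [-127, 127] used by pack4x8snorm for 1.0f.
  scales[workgroup_id.x] = scale / 127;
}
```

The trade-off is that the per-block max reduction and packing are now serialized in a single invocation instead of being spread across 32 threads with subgroup ops.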
1 parent: 755f28a · commit: c39c9e3

File tree

1 file changed: +13 -32 lines changed

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

Lines changed: 13 additions & 32 deletions
```diff
@@ -535,42 +535,23 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const
   shader.AddOutput("output", ShaderUsage::UseUniform);
   shader.AddOutput("scales", ShaderUsage::UseUniform);
 
-  shader.AdditionalImplementation() << R"ADDNL_FN(
-    var<workgroup> max_values : array<input_a_element_t, 4>;
-  )ADDNL_FN";
-
   shader.MainFunctionBody() << R"MAIN_FN(
-    var local_a = input_a[global_idx];
-    var max_val = subgroupMax(abs(local_a));
-    var max_temp = max(max_val.xy, max_val.zw);
-    var scale = max(max_temp[0], max_temp[1]);
-    if (local_idx % sg_size == 0) {
-      max_values[local_idx / sg_size] = scale;
-    }
-    workgroupBarrier();
-
-    if (sg_size == 8)
+    var local_a : array<vec4<input_a_element_t>, 32>;
+    var max_value : vec4<input_a_element_t> = vec4<input_a_element_t>(0);
+    for (var idx : u32 = 0; idx < 32; idx += 1)
     {
-      scale = max(max_values[0], max_values[1]);
-      scale = max(scale, max_values[2]);
-      scale = max(scale, max_values[3]);
+      local_a[idx] = input_a[workgroup_id.x * 32 + idx];
+      max_value = max(max_value, abs(local_a[idx]));
     }
-    else if (sg_size == 16)
-    {
-      scale = max(max_values[0], max_values[1]);
-    }
-    else
-    {
-      scale = max_values[0];
-    }
-
-    var norm_a = local_a / scale;
-    output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
-    if (local_idx == 0)
+    var scale = max(max_value.x, max_value.y);
+    scale = max(scale, max_value.z);
+    scale = max(scale, max_value.w);
+    for (var idx : u32 = 0; idx < 32; idx += 1)
     {
-      // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
-      scales[workgroup_idx] = scale / 127;
+      output[workgroup_id.x * 32 + idx] = pack4x8snorm(vec4<f32>(local_a[idx] / scale));
     }
+    // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
+    scales[workgroup_id.x] = scale / 127;
   )MAIN_FN";
   return Status::OK();
 }
@@ -838,7 +819,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
 
   constexpr uint32_t kBlockSizeA = 128;
   DP4AMatMulQuantizeProgram quantize_program;
-  quantize_program.SetWorkgroupSize(32);
+  quantize_program.SetWorkgroupSize(1);
   quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1);
   TensorShape a_quant_shape{1, M, K / kU32Components};
   Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), a_quant_shape);
```
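As a sanity check on the scale arithmetic: pack4x8snorm stores each lane as a rounded signed byte in [-127, 127], so for a 128-element block with absolute maximum m,

```math
q_i = \operatorname{round}\!\left(127 \cdot \frac{a_i}{m}\right), \qquad s = \frac{m}{127}, \qquad q_i \cdot s \approx a_i,
```

which is why the shader stores `scale / 127` rather than `scale` itself. With `kBlockSizeA = 128` and a workgroup size of 1, the dispatch count `M * K / kBlockSizeA` is unchanged; each workgroup simply handles its whole 32-entry block in one invocation.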
