[MLAS] Fix MatMulNBits CPU LUT GEMM path to apply optional bias#28742
Conversation
There was a problem hiding this comment.
Pull request overview
This PR fixes a correctness issue where the LUT GEMM path (MLAS MlasLutGemm) was silently ignoring the optional bias input, by adding a fused bias-add in the MLAS kernel, wiring the MatMulNBits CPU kernel to pass bias through, and adding targeted regression tests.
Changes:
- Extend
MlasLutGemmto accept an optionalBias[N]pointer and fuse bias broadcast-add into the per-tile compute loop. - Update MatMulNBits LUT path to forward bias to
MlasLutGemm. - Add MLAS-kernel-level and CPU-EP operator-level tests validating correct bias broadcast behavior across multiple shapes/configurations.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp | Adds MLAS unit tests that compare fused-bias results against a reference GEMM + broadcast-add. |
| onnxruntime/test/contrib_ops/matmul_2bits_test.cc | Adds MatMulNBits LUT-path regression tests that include a non-trivial bias vector and validate broadcast-add. |
| onnxruntime/core/mlas/lib/qlutgemm.cpp | Implements fused bias broadcast-add in the LUT GEMM kernel after each output tile is computed. |
| onnxruntime/core/mlas/inc/mlas_qnbit.h | Extends the public MlasLutGemm API signature/docs to include an optional bias pointer. |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Passes bias to MlasLutGemm from the MatMulNBits LUT path. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
tianleiwu
left a comment
There was a problem hiding this comment.
Review
This correctly fixes the bug where the MatMulNBits LUT GEMM path (mlas.use_lut_gemm=1) silently dropped the optional bias input — the common case for the MatMulNBits+Add fusion. The fix threads bias into MlasLutGemm and broadcast-adds it inside the existing per-tile parallel loop, and additionally gates prefer_lut_gemm_ on T1 == float so the fp32-only kernel is never reached for MLFloat16.
Correctness — verified:
- Output
Cis[M, N]row-major (dst_offset = OutputRows * ine11 + ichunk0 * ChunkSize0,OutputRows == N). The bias sliceBias + ichunk0 * ChunkSize0of lengthtile_n = ir0_end - ir0_startmaps exactly to the output-feature range just written byComputeGemm, and stays within[0, N). ChunkSize0 = N / n_tiles_numdividesNexactly, so every tile is full andBias[N]is never over-read.- Each chunk writes a disjoint output region, so the fused bias add is race-free.
Performance: bias add is fused into the hot loop (cache-resident) and inherits the existing per-chunk parallelism with no extra sync or allocation.
Tests: strong coverage at both MLAS (SQLutGemmBiasTest) and op level (MatMulNBitsLutGemm.*Bias) across BlkLen 32/64/128, M in {1,32}, symmetric/asymmetric. The n%7/n%11-varying bias vector would catch any stride/transpose bug.
Note: the earlier automated comment about MatMulNBits<MLFloat16> reaching the LUT path is already addressed by the std::is_same_v<T1, float> gating in this PR.
One minor (optional) observation inline.
Description
The MatMulNBits LUT GEMM path (enabled by session config
mlas.use_lut_gemm=1) silently dropped the optional bias input,
causing incorrect outputs whenever the graph optimizer's
MatMulNBits+Add fusion produced MatMulNBits nodes with bias
populated. This is the common case for transformer models with
fused linear+bias layers.
Fuses the bias broadcast-add into the existing per-tile parallel
loop in MlasLutGemm so the addition is multi-threaded and
operates on data already hot in cache.
Also gates prefer_lut_gemm_ to T1==float since MlasLutGemm
interprets A/C as float* internally.
Adds MLAS-level (test_sqlutgemm SQLutGemmBiasTest) and op-level
(matmul_2bits_test MatMulNBitsLutGemm.*Bias) coverage for the
bias-present case across BlkLen 32/64/128, M in {1,32}, both
symmetric and asymmetric quantization.
Motivation and Context
Fix correctness issues with LUT Gemm