Skip to content

Commit 16dc03b

Browse files
committed
MLAS commits
1 parent a9cfc47 commit 16dc03b

5 files changed

Lines changed: 245 additions & 6 deletions

File tree

onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ class MatMulNBits final : public OpKernel {
192192
const MatMulComputeHelper& helper) const;
193193

194194
Status ComputeBPackedLUT(const Tensor* a,
195+
const Tensor* bias,
195196
Tensor* y,
196197
concurrency::ThreadPool* thread_pool,
197198
const MatMulComputeHelper& helper) const;
@@ -641,6 +642,7 @@ Status MatMulNBits<T1>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>&
641642

642643
template <typename T1>
643644
Status MatMulNBits<T1>::ComputeBPackedLUT(const Tensor* a,
645+
const Tensor* bias,
644646
Tensor* y,
645647
concurrency::ThreadPool* thread_pool,
646648
const MatMulComputeHelper& helper) const {
@@ -650,7 +652,21 @@ Status MatMulNBits<T1>::ComputeBPackedLUT(const Tensor* a,
650652
const int N = static_cast<int>(helper.N());
651653
const int K = static_cast<int>(helper.K());
652654

653-
MlasLutGemm(a_data, block_size_, packed_b_.get(), y_data, K, M, N, has_zp_input_, thread_pool);
655+
// Bias is fused into MlasLutGemm: it is broadcast-added per output-feature tile inside
656+
// the same parallel loop that runs the GEMM, so the bias addition is multi-threaded and
657+
// operates on data that is still hot in cache. MlasLutGemm currently only supports fp32
658+
// activations/outputs; reject any other type when a bias is present.
659+
const float* bias_data = nullptr;
660+
if (bias != nullptr) {
661+
if constexpr (std::is_same_v<T1, float>) {
662+
bias_data = bias->Data<float>();
663+
} else {
664+
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
665+
"MatMulNBits LUT GEMM path does not support non-fp32 bias.");
666+
}
667+
}
668+
669+
MlasLutGemm(a_data, block_size_, packed_b_.get(), y_data, K, M, N, has_zp_input_, thread_pool, bias_data);
654670
return Status::OK();
655671
}
656672

@@ -1227,7 +1243,7 @@ Status MatMulNBits<T1>::Compute(OpKernelContext* ctx) const {
12271243
// MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch()
12281244
// with B directly too.
12291245
if (prefer_lut_gemm_) {
1230-
return ComputeBPackedLUT(a, y, thread_pool, helper);
1246+
return ComputeBPackedLUT(a, bias, y, thread_pool, helper);
12311247
}
12321248

12331249
if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) {

onnxruntime/core/mlas/inc/mlas_qnbit.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,11 @@ MlasLutGemmPack(
358358
* @param[in] N column size of matrix B
359359
* @param[in] HasZeroPoint whether zero points are provided
360360
* @param[in] threadpool thread pool for parallel computation
361+
* @param[in] Bias optional bias vector of length N (one value per output feature).
362+
* When non-null, it is broadcast-added to every row of the [M, N]
363+
* output. The addition is fused into the per-tile compute loop so
364+
* it inherits the same multi-threading as the GEMM itself.
365+
* Pass nullptr if no bias is to be applied.
361366
*/
362367
void MLASCALL
363368
MlasLutGemm(
@@ -369,5 +374,6 @@ MlasLutGemm(
369374
size_t M,
370375
size_t N,
371376
bool HasZeroPoint,
372-
MLAS_THREADPOOL* threadpool
377+
MLAS_THREADPOOL* threadpool,
378+
const float* Bias = nullptr
373379
);

onnxruntime/core/mlas/lib/qlutgemm.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,8 @@ MlasLutGemm(
420420
size_t M, // batch size (number of rows in activation)
421421
size_t N,
422422
bool HasZeroPoint,
423-
MLAS_THREADPOOL* threadpool
423+
MLAS_THREADPOOL* threadpool,
424+
const float* Bias
424425
)
425426
{
426427
// adapted from ggml_backend_tmac_mul_mat
@@ -616,6 +617,22 @@ MlasLutGemm(
616617
BlkLen, // Weight quantization group size
617618
HasZeroPoint // Whether zero points are used
618619
);
620+
621+
// Fused bias add: broadcast the per-output-feature Bias[N] slice into the
622+
// just-written tile. The output tile we just wrote is `ChunkSize0` contiguous
623+
// floats at `act_output + dst_offset`, corresponding to output feature indices
624+
// [ichunk0 * ChunkSize0, ichunk0 * ChunkSize0 + ChunkSize0). The bias slice
625+
// therefore aligns at `Bias + ichunk0 * ChunkSize0`. Doing this here (rather
626+
// than as a separate post-pass) keeps the data hot in cache and inherits the
627+
// existing per-chunk parallelism for free.
628+
if (Bias != nullptr) {
629+
const size_t tile_n = ir0_end - ir0_start;
630+
float* y_tile = act_output + dst_offset;
631+
const float* bias_tile = Bias + ichunk0 * ChunkSize0;
632+
for (size_t i = 0; i < tile_n; ++i) {
633+
y_tile[i] += bias_tile[i];
634+
}
635+
}
619636
}
620637
}
621638
}

onnxruntime/test/contrib_ops/matmul_2bits_test.cc

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,8 @@ void TestMatMul2BitsTyped(float abs_error = 0.1f, float rel_error = 0.02f) {
254254

255255
template <typename AType>
256256
void TestMatMul2BitsLutGemm(int64_t M, int64_t N, int64_t K, int64_t block_size,
257-
bool has_zero_point, float abs_error = 0.15f, float rel_error = 0.05f) {
257+
bool has_zero_point, bool has_bias = false,
258+
float abs_error = 0.15f, float rel_error = 0.05f) {
258259
if (K % 32 != 0 || N % 128 != 0 || block_size % 32 != 0) {
259260
GTEST_SKIP() << "LUT GEMM requires K multiple of 32, N multiple of 128, block_size multiple of 32";
260261
}
@@ -308,13 +309,27 @@ void TestMatMul2BitsLutGemm(int64_t M, int64_t N, int64_t K, int64_t block_size,
308309
static_cast<int32_t>(N),
309310
tp);
310311

312+
// Optional per-output-feature bias with non-trivial variation across N so any stride/transpose
313+
// bug in the fused bias add (inside MlasLutGemm) is observable.
314+
std::vector<float> bias;
315+
if (has_bias) {
316+
bias.resize(static_cast<size_t>(N));
317+
for (int64_t n = 0; n < N; ++n) {
318+
bias[static_cast<size_t>(n)] =
319+
0.125f + 0.5f * static_cast<float>(n % 7) - 0.25f * static_cast<float>(n % 11);
320+
}
321+
}
322+
311323
std::vector<float> expected_vals(M * N);
312324
for (int64_t m = 0; m < M; m++) {
313325
for (int64_t n = 0; n < N; n++) {
314326
float sum = 0.0f;
315327
for (int64_t k = 0; k < K; k++) {
316328
sum += input0_fp32_vals[m * K + k] * input1_fp32_vals[n * K + k];
317329
}
330+
if (has_bias) {
331+
sum += bias[static_cast<size_t>(n)];
332+
}
318333
expected_vals[m * N + n] = sum;
319334
}
320335
}
@@ -344,7 +359,16 @@ void TestMatMul2BitsLutGemm(int64_t M, int64_t N, int64_t K, int64_t block_size,
344359
}
345360

346361
test.AddOptionalInputEdge<int32_t>();
347-
test.AddOptionalInputEdge<AType>();
362+
363+
if (has_bias) {
364+
if constexpr (std::is_same<AType, float>::value) {
365+
test.AddInput<AType>("bias", {N}, bias, true);
366+
} else {
367+
test.AddOptionalInputEdge<AType>();
368+
}
369+
} else {
370+
test.AddOptionalInputEdge<AType>();
371+
}
348372

349373
if constexpr (std::is_same<AType, float>::value) {
350374
test.AddOutput<AType>("Y", {M, N}, expected_vals);
@@ -405,6 +429,38 @@ TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_Batch32_256x256) {
405429
TestMatMul2BitsLutGemm<float>(32, 256, 256, 32, true);
406430
}
407431

432+
// Fused-bias tests — verify the Bias parameter to MlasLutGemm is broadcast-added correctly.
433+
// These are the regression tests for the bug where the LUT path silently dropped the optional
434+
// `bias` input of MatMulNBits.
435+
TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_128x128_Bias) {
436+
TestMatMul2BitsLutGemm<float>(1, 128, 128, 32, /*has_zero_point=*/false, /*has_bias=*/true);
437+
}
438+
439+
TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_128x128_Bias) {
440+
TestMatMul2BitsLutGemm<float>(1, 128, 128, 32, /*has_zero_point=*/true, /*has_bias=*/true);
441+
}
442+
443+
TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_256x256_BlkLen64_Bias) {
444+
TestMatMul2BitsLutGemm<float>(1, 256, 256, 64, /*has_zero_point=*/false, /*has_bias=*/true);
445+
}
446+
447+
TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_256x256_BlkLen64_Bias) {
448+
TestMatMul2BitsLutGemm<float>(1, 256, 256, 64, /*has_zero_point=*/true, /*has_bias=*/true);
449+
}
450+
451+
TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_128x256_BlkLen128_Bias) {
452+
TestMatMul2BitsLutGemm<float>(1, 128, 256, 128, /*has_zero_point=*/true, /*has_bias=*/true);
453+
}
454+
455+
// Batched (M>1) bias tests — exercise the per-row bias broadcast across many activation rows.
456+
TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_Batch32_128x128_Bias) {
457+
TestMatMul2BitsLutGemm<float>(32, 128, 128, 32, /*has_zero_point=*/false, /*has_bias=*/true);
458+
}
459+
460+
TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_Batch32_256x256_Bias) {
461+
TestMatMul2BitsLutGemm<float>(32, 256, 256, 32, /*has_zero_point=*/true, /*has_bias=*/true);
462+
}
463+
408464
// Float zero point tests — directed QAD scenario (zp=1.5)
409465
void RunTest2BitsFloatZP(int64_t M, int64_t N, int64_t K, int64_t block_size, float zp_value) {
410466
RandomValueGenerator random{1234};

onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,83 @@ class MlasSQLutGemmTest : public MlasTestBase {
229229
}
230230
}
231231

232+
// Verifies that the bias argument to MlasLutGemm is correctly broadcast-added per row.
233+
// Bias has shape [N] and must be added to every row of the [M, N] output.
234+
void TestWithBias(size_t M, size_t N, size_t K, bool WithThreadpool, bool Symmetric) {
235+
MLAS_THREADPOOL* tp = WithThreadpool ? GetMlasThreadPool() : nullptr;
236+
237+
MlasClearLutGemmKernelConfig();
238+
239+
const float* A = BufferA.GetBuffer(K * M);
240+
const float* B = BufferB.GetBuffer(N * K);
241+
float* C = BufferC.GetBuffer(N * M, true);
242+
float* CReference = BufferCReference.GetBuffer(N * M, true);
243+
244+
uint8_t* QuantBData = nullptr;
245+
float* QuantBScale = nullptr;
246+
uint8_t* QuantBZeroPoint = nullptr;
247+
248+
{
249+
size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes;
250+
MlasBlockwiseQuantizedBufferSizes<BlkBitWidth>(BlkLen, /* columnwise */ true,
251+
static_cast<int>(K), static_cast<int>(N),
252+
QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes);
253+
254+
QuantBData = BufferQuantBData.GetBuffer(QuantBDataSizeInBytes);
255+
QuantBScale = BufferQuantBScale.GetBuffer(QuantBScaleSize);
256+
if (!Symmetric) {
257+
QuantBZeroPoint = BufferQuantBZeroPoint.GetBuffer(QuantBZeroPointSizeInBytes);
258+
}
259+
260+
MlasQuantizeBlockwise<float, BlkBitWidth>(QuantBData, QuantBScale, QuantBZeroPoint,
261+
B, BlkLen, /* columnwise */ true,
262+
static_cast<int>(K), static_cast<int>(N),
263+
static_cast<int>(N), GetMlasThreadPool());
264+
}
265+
266+
MlasInitLutGemmKernelConfig(N, K, BlkBitWidth, BlkLen, !Symmetric);
267+
268+
size_t PackedBufSize = MlasLutGemmPackedSize(N, K, BlkBitWidth, BlkLen, !Symmetric);
269+
std::byte* PackedBuf = BufferPackedB.GetBuffer(PackedBufSize);
270+
271+
MlasLutGemmPack(
272+
N, K, BlkBitWidth, BlkLen, !Symmetric,
273+
reinterpret_cast<std::byte*>(QuantBData),
274+
QuantBScale,
275+
QuantBZeroPoint,
276+
false, // IsFloatZeroPoint
277+
PackedBuf, tp);
278+
279+
// Build a deterministic per-output-feature bias vector with non-trivial variation across N
280+
// so that any stride/transpose bug in the fused bias add will be caught.
281+
std::vector<float> Bias(N);
282+
for (size_t n = 0; n < N; ++n) {
283+
Bias[n] = 0.125f + 0.5f * static_cast<float>(n % 7) - 0.25f * static_cast<float>(n % 11);
284+
}
285+
286+
MlasLutGemm(
287+
A, BlkLen, PackedBuf, C,
288+
static_cast<int>(K), static_cast<int>(M), static_cast<int>(N),
289+
!Symmetric, tp, Bias.data());
290+
291+
// Reference: same GEMM as the no-bias path, then broadcast-add Bias.
292+
CallReferenceGemm(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, CReference);
293+
for (size_t m = 0; m < M; m++) {
294+
for (size_t n = 0; n < N; n++) {
295+
CReference[m * N + n] += Bias[n];
296+
}
297+
}
298+
299+
size_t f = 0;
300+
for (size_t m = 0; m < M; m++) {
301+
for (size_t n = 0; n < N; n++, f++) {
302+
ASSERT_TRUE(CloseEnough(C[f], CReference[f]))
303+
<< "Expected: " << CReference[f] << " Actual: " << C[f] << "@[" << m << "x" << n << "], "
304+
<< "M=" << M << ", N=" << N << ", K=" << K << ", Symmetric=" << Symmetric << ", WithBias=true";
305+
}
306+
}
307+
}
308+
232309
public:
233310
static const char* GetTestSuiteName() {
234311
static std::string suite_name = std::string("SQLutGemm") +
@@ -369,6 +446,69 @@ class SQLutGemmFloatZPTest : public MlasTestFixture<MlasSQLutGemmTest<BlkBitWidt
369446
float ZPValue_;
370447
};
371448

449+
// Fixture for the fused bias path (MlasLutGemm Bias parameter).
450+
template <size_t BlkBitWidth, size_t BlkLen>
451+
class SQLutGemmBiasTest : public MlasTestFixture<MlasSQLutGemmTest<BlkBitWidth, BlkLen>> {
452+
public:
453+
explicit SQLutGemmBiasTest(size_t M, size_t N, size_t K, bool WithThreadpool, bool Symmetric)
454+
: M_(M), N_(N), K_(K), WithThreadpool_(WithThreadpool), Symmetric_(Symmetric) {}
455+
456+
void TestBody() override {
457+
MlasTestFixture<MlasSQLutGemmTest<BlkBitWidth, BlkLen>>::mlas_tester->TestWithBias(
458+
M_, N_, K_, WithThreadpool_, Symmetric_);
459+
}
460+
461+
static size_t RegisterSingleTest(size_t M, size_t N, size_t K, bool WithThreadpool, bool Symmetric) {
462+
if (!MlasIsLutGemmAvailable(N, K, BlkBitWidth, BlkLen)) {
463+
return 0;
464+
}
465+
if (N < BlkLen) {
466+
return 0;
467+
}
468+
469+
std::stringstream ss;
470+
ss << (WithThreadpool ? "Threaded" : "SingleThread")
471+
<< "/Bias/isSymmetric" << Symmetric
472+
<< "/M" << M << "xN" << N << "xK" << K;
473+
474+
auto test_name = ss.str();
475+
476+
testing::RegisterTest(
477+
MlasSQLutGemmTest<BlkBitWidth, BlkLen>::GetTestSuiteName(),
478+
test_name.c_str(),
479+
nullptr,
480+
test_name.c_str(),
481+
__FILE__,
482+
__LINE__,
483+
[=]() -> MlasTestFixture<MlasSQLutGemmTest<BlkBitWidth, BlkLen>>* {
484+
return new SQLutGemmBiasTest<BlkBitWidth, BlkLen>(M, N, K, WithThreadpool, Symmetric);
485+
});
486+
487+
return 1;
488+
}
489+
490+
static size_t RegisterShortExecuteTests() {
491+
size_t count = 0;
492+
for (bool with_threadpool : {true}) {
493+
for (bool symmetric : {true, false}) {
494+
// Cover M=1 (decode-like) and M>1 (prefill/batched) paths, and multiple chunked-N sizes
495+
// so we exercise the per-tile bias slice arithmetic across more than one tile.
496+
count += RegisterSingleTest(1, 128, 128, with_threadpool, symmetric);
497+
count += RegisterSingleTest(1, 256, 256, with_threadpool, symmetric);
498+
count += RegisterSingleTest(1, 1024, 1024, with_threadpool, symmetric);
499+
count += RegisterSingleTest(32, 128, 128, with_threadpool, symmetric);
500+
count += RegisterSingleTest(32, 256, 256, with_threadpool, symmetric);
501+
}
502+
}
503+
return count;
504+
}
505+
506+
private:
507+
size_t M_, N_, K_;
508+
bool WithThreadpool_;
509+
bool Symmetric_;
510+
};
511+
372512
static size_t SQLutGemmRegisterAllShortExecuteTests() {
373513
size_t count = 0;
374514
count += SQLutGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests();
@@ -379,6 +519,10 @@ static size_t SQLutGemmRegisterAllShortExecuteTests() {
379519
count += SQLutGemmFloatZPTest<2, 32>::RegisterShortExecuteTests();
380520
count += SQLutGemmFloatZPTest<2, 64>::RegisterShortExecuteTests();
381521
count += SQLutGemmFloatZPTest<2, 128>::RegisterShortExecuteTests();
522+
// Fused bias tests
523+
count += SQLutGemmBiasTest<2, 32>::RegisterShortExecuteTests();
524+
count += SQLutGemmBiasTest<2, 64>::RegisterShortExecuteTests();
525+
count += SQLutGemmBiasTest<2, 128>::RegisterShortExecuteTests();
382526
return count;
383527
}
384528

0 commit comments

Comments
 (0)