@@ -229,6 +229,83 @@ class MlasSQLutGemmTest : public MlasTestBase {
229229 }
230230 }
231231
232+ // Verifies that the bias argument to MlasLutGemm is correctly broadcast-added per row.
233+ // Bias has shape [N] and must be added to every row of the [M, N] output.
234+ void TestWithBias (size_t M, size_t N, size_t K, bool WithThreadpool, bool Symmetric) {
235+ MLAS_THREADPOOL* tp = WithThreadpool ? GetMlasThreadPool () : nullptr ;
236+
237+ MlasClearLutGemmKernelConfig ();
238+
239+ const float * A = BufferA.GetBuffer (K * M);
240+ const float * B = BufferB.GetBuffer (N * K);
241+ float * C = BufferC.GetBuffer (N * M, true );
242+ float * CReference = BufferCReference.GetBuffer (N * M, true );
243+
244+ uint8_t * QuantBData = nullptr ;
245+ float * QuantBScale = nullptr ;
246+ uint8_t * QuantBZeroPoint = nullptr ;
247+
248+ {
249+ size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes;
250+ MlasBlockwiseQuantizedBufferSizes<BlkBitWidth>(BlkLen, /* columnwise */ true ,
251+ static_cast <int >(K), static_cast <int >(N),
252+ QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes);
253+
254+ QuantBData = BufferQuantBData.GetBuffer (QuantBDataSizeInBytes);
255+ QuantBScale = BufferQuantBScale.GetBuffer (QuantBScaleSize);
256+ if (!Symmetric) {
257+ QuantBZeroPoint = BufferQuantBZeroPoint.GetBuffer (QuantBZeroPointSizeInBytes);
258+ }
259+
260+ MlasQuantizeBlockwise<float , BlkBitWidth>(QuantBData, QuantBScale, QuantBZeroPoint,
261+ B, BlkLen, /* columnwise */ true ,
262+ static_cast <int >(K), static_cast <int >(N),
263+ static_cast <int >(N), GetMlasThreadPool ());
264+ }
265+
266+ MlasInitLutGemmKernelConfig (N, K, BlkBitWidth, BlkLen, !Symmetric);
267+
268+ size_t PackedBufSize = MlasLutGemmPackedSize (N, K, BlkBitWidth, BlkLen, !Symmetric);
269+ std::byte* PackedBuf = BufferPackedB.GetBuffer (PackedBufSize);
270+
271+ MlasLutGemmPack (
272+ N, K, BlkBitWidth, BlkLen, !Symmetric,
273+ reinterpret_cast <std::byte*>(QuantBData),
274+ QuantBScale,
275+ QuantBZeroPoint,
276+ false , // IsFloatZeroPoint
277+ PackedBuf, tp);
278+
279+ // Build a deterministic per-output-feature bias vector with non-trivial variation across N
280+ // so that any stride/transpose bug in the fused bias add will be caught.
281+ std::vector<float > Bias (N);
282+ for (size_t n = 0 ; n < N; ++n) {
283+ Bias[n] = 0 .125f + 0 .5f * static_cast <float >(n % 7 ) - 0 .25f * static_cast <float >(n % 11 );
284+ }
285+
286+ MlasLutGemm (
287+ A, BlkLen, PackedBuf, C,
288+ static_cast <int >(K), static_cast <int >(M), static_cast <int >(N),
289+ !Symmetric, tp, Bias.data ());
290+
291+ // Reference: same GEMM as the no-bias path, then broadcast-add Bias.
292+ CallReferenceGemm (M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, CReference);
293+ for (size_t m = 0 ; m < M; m++) {
294+ for (size_t n = 0 ; n < N; n++) {
295+ CReference[m * N + n] += Bias[n];
296+ }
297+ }
298+
299+ size_t f = 0 ;
300+ for (size_t m = 0 ; m < M; m++) {
301+ for (size_t n = 0 ; n < N; n++, f++) {
302+ ASSERT_TRUE (CloseEnough (C[f], CReference[f]))
303+ << " Expected: " << CReference[f] << " Actual: " << C[f] << " @[" << m << " x" << n << " ], "
304+ << " M=" << M << " , N=" << N << " , K=" << K << " , Symmetric=" << Symmetric << " , WithBias=true" ;
305+ }
306+ }
307+ }
308+
232309 public:
233310 static const char * GetTestSuiteName () {
234311 static std::string suite_name = std::string (" SQLutGemm" ) +
@@ -369,6 +446,69 @@ class SQLutGemmFloatZPTest : public MlasTestFixture<MlasSQLutGemmTest<BlkBitWidt
369446 float ZPValue_;
370447};
371448
449+ // Fixture for the fused bias path (MlasLutGemm Bias parameter).
450+ template <size_t BlkBitWidth, size_t BlkLen>
451+ class SQLutGemmBiasTest : public MlasTestFixture <MlasSQLutGemmTest<BlkBitWidth, BlkLen>> {
452+ public:
453+ explicit SQLutGemmBiasTest (size_t M, size_t N, size_t K, bool WithThreadpool, bool Symmetric)
454+ : M_(M), N_(N), K_(K), WithThreadpool_(WithThreadpool), Symmetric_(Symmetric) {}
455+
456+ void TestBody () override {
457+ MlasTestFixture<MlasSQLutGemmTest<BlkBitWidth, BlkLen>>::mlas_tester->TestWithBias (
458+ M_, N_, K_, WithThreadpool_, Symmetric_);
459+ }
460+
461+ static size_t RegisterSingleTest (size_t M, size_t N, size_t K, bool WithThreadpool, bool Symmetric) {
462+ if (!MlasIsLutGemmAvailable (N, K, BlkBitWidth, BlkLen)) {
463+ return 0 ;
464+ }
465+ if (N < BlkLen) {
466+ return 0 ;
467+ }
468+
469+ std::stringstream ss;
470+ ss << (WithThreadpool ? " Threaded" : " SingleThread" )
471+ << " /Bias/isSymmetric" << Symmetric
472+ << " /M" << M << " xN" << N << " xK" << K;
473+
474+ auto test_name = ss.str ();
475+
476+ testing::RegisterTest (
477+ MlasSQLutGemmTest<BlkBitWidth, BlkLen>::GetTestSuiteName (),
478+ test_name.c_str (),
479+ nullptr ,
480+ test_name.c_str (),
481+ __FILE__,
482+ __LINE__,
483+ [=]() -> MlasTestFixture<MlasSQLutGemmTest<BlkBitWidth, BlkLen>>* {
484+ return new SQLutGemmBiasTest<BlkBitWidth, BlkLen>(M, N, K, WithThreadpool, Symmetric);
485+ });
486+
487+ return 1 ;
488+ }
489+
490+ static size_t RegisterShortExecuteTests () {
491+ size_t count = 0 ;
492+ for (bool with_threadpool : {true }) {
493+ for (bool symmetric : {true , false }) {
494+ // Cover M=1 (decode-like) and M>1 (prefill/batched) paths, and multiple chunked-N sizes
495+ // so we exercise the per-tile bias slice arithmetic across more than one tile.
496+ count += RegisterSingleTest (1 , 128 , 128 , with_threadpool, symmetric);
497+ count += RegisterSingleTest (1 , 256 , 256 , with_threadpool, symmetric);
498+ count += RegisterSingleTest (1 , 1024 , 1024 , with_threadpool, symmetric);
499+ count += RegisterSingleTest (32 , 128 , 128 , with_threadpool, symmetric);
500+ count += RegisterSingleTest (32 , 256 , 256 , with_threadpool, symmetric);
501+ }
502+ }
503+ return count;
504+ }
505+
506+ private:
507+ size_t M_, N_, K_;
508+ bool WithThreadpool_;
509+ bool Symmetric_;
510+ };
511+
372512static size_t SQLutGemmRegisterAllShortExecuteTests () {
373513 size_t count = 0 ;
374514 count += SQLutGemmShortExecuteTest<2 , 32 >::RegisterShortExecuteTests ();
@@ -379,6 +519,10 @@ static size_t SQLutGemmRegisterAllShortExecuteTests() {
379519 count += SQLutGemmFloatZPTest<2 , 32 >::RegisterShortExecuteTests ();
380520 count += SQLutGemmFloatZPTest<2 , 64 >::RegisterShortExecuteTests ();
381521 count += SQLutGemmFloatZPTest<2 , 128 >::RegisterShortExecuteTests ();
522+ // Fused bias tests
523+ count += SQLutGemmBiasTest<2 , 32 >::RegisterShortExecuteTests ();
524+ count += SQLutGemmBiasTest<2 , 64 >::RegisterShortExecuteTests ();
525+ count += SQLutGemmBiasTest<2 , 128 >::RegisterShortExecuteTests ();
382526 return count;
383527}
384528
0 commit comments