Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ class MatMulNBits final : public OpKernel {
const MatMulComputeHelper& helper) const;

Status ComputeBPackedLUT(const Tensor* a,
const Tensor* bias,
Tensor* y,
concurrency::ThreadPool* thread_pool,
const MatMulComputeHelper& helper) const;
Expand Down Expand Up @@ -641,6 +642,7 @@ Status MatMulNBits<T1>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>&

template <typename T1>
Status MatMulNBits<T1>::ComputeBPackedLUT(const Tensor* a,
const Tensor* bias,
Tensor* y,
concurrency::ThreadPool* thread_pool,
const MatMulComputeHelper& helper) const {
Expand All @@ -650,7 +652,21 @@ Status MatMulNBits<T1>::ComputeBPackedLUT(const Tensor* a,
const int N = static_cast<int>(helper.N());
const int K = static_cast<int>(helper.K());

MlasLutGemm(a_data, block_size_, packed_b_.get(), y_data, K, M, N, has_zp_input_, thread_pool);
// Bias is fused into MlasLutGemm: it is broadcast-added per output-feature tile inside
// the same parallel loop that runs the GEMM, so the bias addition is multi-threaded and
// operates on data that is still hot in cache. MlasLutGemm currently only supports fp32
// activations/outputs; reject any other type when a bias is present.
const float* bias_data = nullptr;
if (bias != nullptr) {
if constexpr (std::is_same_v<T1, float>) {
bias_data = bias->Data<float>();
} else {
Comment thread
hariharans29 marked this conversation as resolved.
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"MatMulNBits LUT GEMM path does not support non-fp32 bias.");
}
}

MlasLutGemm(a_data, block_size_, packed_b_.get(), y_data, K, M, N, has_zp_input_, thread_pool, bias_data);
return Status::OK();
}

Expand Down Expand Up @@ -1227,7 +1243,7 @@ Status MatMulNBits<T1>::Compute(OpKernelContext* ctx) const {
// MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch()
// with B directly too.
if (prefer_lut_gemm_) {
return ComputeBPackedLUT(a, y, thread_pool, helper);
return ComputeBPackedLUT(a, bias, y, thread_pool, helper);
}
Comment thread
hariharans29 marked this conversation as resolved.

if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) {
Expand Down
8 changes: 7 additions & 1 deletion onnxruntime/core/mlas/inc/mlas_qnbit.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,11 @@ MlasLutGemmPack(
* @param[in] N column size of matrix B
* @param[in] HasZeroPoint whether zero points are provided
* @param[in] threadpool thread pool for parallel computation
* @param[in] Bias optional bias vector of length N (one value per output feature).
* When non-null, it is broadcast-added to every row of the [M, N]
* output. The addition is fused into the per-tile compute loop so
* it inherits the same multi-threading as the GEMM itself.
* Pass nullptr if no bias is to be applied.
*/
void MLASCALL
MlasLutGemm(
Expand All @@ -369,5 +374,6 @@ MlasLutGemm(
size_t M,
size_t N,
bool HasZeroPoint,
MLAS_THREADPOOL* threadpool
MLAS_THREADPOOL* threadpool,
const float* Bias = nullptr
);
19 changes: 18 additions & 1 deletion onnxruntime/core/mlas/lib/qlutgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,8 @@ MlasLutGemm(
size_t M, // batch size (number of rows in activation)
size_t N,
bool HasZeroPoint,
MLAS_THREADPOOL* threadpool
MLAS_THREADPOOL* threadpool,
const float* Bias
)
{
// adapted from ggml_backend_tmac_mul_mat
Expand Down Expand Up @@ -616,6 +617,22 @@ MlasLutGemm(
BlkLen, // Weight quantization group size
HasZeroPoint // Whether zero points are used
);

// Fused bias add: broadcast the per-output-feature Bias[N] slice into the
// just-written tile. The output tile we just wrote is `ChunkSize0` contiguous
// floats at `act_output + dst_offset`, corresponding to output feature indices
// [ichunk0 * ChunkSize0, ichunk0 * ChunkSize0 + ChunkSize0). The bias slice
// therefore aligns at `Bias + ichunk0 * ChunkSize0`. Doing this here (rather
// than as a separate post-pass) keeps the data hot in cache and inherits the
// existing per-chunk parallelism for free.
if (Bias != nullptr) {
const size_t tile_n = ir0_end - ir0_start;
float* y_tile = act_output + dst_offset;
const float* bias_tile = Bias + ichunk0 * ChunkSize0;
for (size_t i = 0; i < tile_n; ++i) {
y_tile[i] += bias_tile[i];
}
}
}
}
}
Expand Down
60 changes: 58 additions & 2 deletions onnxruntime/test/contrib_ops/matmul_2bits_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ void TestMatMul2BitsTyped(float abs_error = 0.1f, float rel_error = 0.02f) {

template <typename AType>
void TestMatMul2BitsLutGemm(int64_t M, int64_t N, int64_t K, int64_t block_size,
bool has_zero_point, float abs_error = 0.15f, float rel_error = 0.05f) {
bool has_zero_point, bool has_bias = false,
float abs_error = 0.15f, float rel_error = 0.05f) {
if (K % 32 != 0 || N % 128 != 0 || block_size % 32 != 0) {
GTEST_SKIP() << "LUT GEMM requires K multiple of 32, N multiple of 128, block_size multiple of 32";
}
Expand Down Expand Up @@ -308,13 +309,27 @@ void TestMatMul2BitsLutGemm(int64_t M, int64_t N, int64_t K, int64_t block_size,
static_cast<int32_t>(N),
tp);

// Optional per-output-feature bias with non-trivial variation across N so any stride/transpose
// bug in the fused bias add (inside MlasLutGemm) is observable.
std::vector<float> bias;
if (has_bias) {
bias.resize(static_cast<size_t>(N));
for (int64_t n = 0; n < N; ++n) {
bias[static_cast<size_t>(n)] =
0.125f + 0.5f * static_cast<float>(n % 7) - 0.25f * static_cast<float>(n % 11);
}
}

std::vector<float> expected_vals(M * N);
for (int64_t m = 0; m < M; m++) {
for (int64_t n = 0; n < N; n++) {
float sum = 0.0f;
for (int64_t k = 0; k < K; k++) {
sum += input0_fp32_vals[m * K + k] * input1_fp32_vals[n * K + k];
}
if (has_bias) {
sum += bias[static_cast<size_t>(n)];
}
expected_vals[m * N + n] = sum;
}
}
Expand Down Expand Up @@ -344,7 +359,16 @@ void TestMatMul2BitsLutGemm(int64_t M, int64_t N, int64_t K, int64_t block_size,
}

test.AddOptionalInputEdge<int32_t>();
test.AddOptionalInputEdge<AType>();

if (has_bias) {
if constexpr (std::is_same<AType, float>::value) {
test.AddInput<AType>("bias", {N}, bias, true);
} else {
test.AddOptionalInputEdge<AType>();
}
} else {
test.AddOptionalInputEdge<AType>();
}

if constexpr (std::is_same<AType, float>::value) {
test.AddOutput<AType>("Y", {M, N}, expected_vals);
Expand Down Expand Up @@ -405,6 +429,38 @@ TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_Batch32_256x256) {
TestMatMul2BitsLutGemm<float>(32, 256, 256, 32, true);
}

// Fused-bias tests — verify the Bias parameter to MlasLutGemm is broadcast-added correctly.
// These are the regression tests for the bug where the LUT path silently dropped the optional
// `bias` input of MatMulNBits.
//
// Positional arguments to TestMatMul2BitsLutGemm are
//   (M, N, K, block_size, has_zero_point, has_bias)
// and abs_error / rel_error are left at their defaults. Every shape used here has K and
// block_size as multiples of 32 and N as a multiple of 128, so none of these cases trips the
// helper's GTEST_SKIP guard.

// M=1, N=128, K=128, block_size=32, no zero points.
TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_128x128_Bias) {
TestMatMul2BitsLutGemm<float>(1, 128, 128, 32, /*has_zero_point=*/false, /*has_bias=*/true);
}

// M=1, N=128, K=128, block_size=32, with zero points.
TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_128x128_Bias) {
TestMatMul2BitsLutGemm<float>(1, 128, 128, 32, /*has_zero_point=*/true, /*has_bias=*/true);
}

// M=1, N=256, K=256, block_size=64, no zero points.
TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_256x256_BlkLen64_Bias) {
TestMatMul2BitsLutGemm<float>(1, 256, 256, 64, /*has_zero_point=*/false, /*has_bias=*/true);
}

// M=1, N=256, K=256, block_size=64, with zero points.
TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_256x256_BlkLen64_Bias) {
TestMatMul2BitsLutGemm<float>(1, 256, 256, 64, /*has_zero_point=*/true, /*has_bias=*/true);
}

// M=1, N=128, K=256, block_size=128, with zero points (the only non-square N/K case here).
TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_128x256_BlkLen128_Bias) {
TestMatMul2BitsLutGemm<float>(1, 128, 256, 128, /*has_zero_point=*/true, /*has_bias=*/true);
}

// Batched (M>1) bias tests — exercise the per-row bias broadcast across many activation rows.

// M=32, N=128, K=128, block_size=32, no zero points.
TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_Batch32_128x128_Bias) {
TestMatMul2BitsLutGemm<float>(32, 128, 128, 32, /*has_zero_point=*/false, /*has_bias=*/true);
}

// M=32, N=256, K=256, block_size=32, with zero points.
TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_Batch32_256x256_Bias) {
TestMatMul2BitsLutGemm<float>(32, 256, 256, 32, /*has_zero_point=*/true, /*has_bias=*/true);
}

// Float zero point tests — directed QAD scenario (zp=1.5)
void RunTest2BitsFloatZP(int64_t M, int64_t N, int64_t K, int64_t block_size, float zp_value) {
RandomValueGenerator random{1234};
Expand Down
144 changes: 144 additions & 0 deletions onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,83 @@ class MlasSQLutGemmTest : public MlasTestBase {
}
}

// Verifies that the bias argument to MlasLutGemm is correctly broadcast-added per row.
// Bias has shape [N] and must be added to every row of the [M, N] output.
void TestWithBias(size_t M, size_t N, size_t K, bool WithThreadpool, bool Symmetric) {
MLAS_THREADPOOL* tp = WithThreadpool ? GetMlasThreadPool() : nullptr;

MlasClearLutGemmKernelConfig();

const float* A = BufferA.GetBuffer(K * M);
const float* B = BufferB.GetBuffer(N * K);
float* C = BufferC.GetBuffer(N * M, true);
float* CReference = BufferCReference.GetBuffer(N * M, true);

uint8_t* QuantBData = nullptr;
float* QuantBScale = nullptr;
uint8_t* QuantBZeroPoint = nullptr;

{
size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes;
MlasBlockwiseQuantizedBufferSizes<BlkBitWidth>(BlkLen, /* columnwise */ true,
static_cast<int>(K), static_cast<int>(N),
QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes);

QuantBData = BufferQuantBData.GetBuffer(QuantBDataSizeInBytes);
QuantBScale = BufferQuantBScale.GetBuffer(QuantBScaleSize);
if (!Symmetric) {
QuantBZeroPoint = BufferQuantBZeroPoint.GetBuffer(QuantBZeroPointSizeInBytes);
}

MlasQuantizeBlockwise<float, BlkBitWidth>(QuantBData, QuantBScale, QuantBZeroPoint,
B, BlkLen, /* columnwise */ true,
static_cast<int>(K), static_cast<int>(N),
static_cast<int>(N), GetMlasThreadPool());
}

MlasInitLutGemmKernelConfig(N, K, BlkBitWidth, BlkLen, !Symmetric);

size_t PackedBufSize = MlasLutGemmPackedSize(N, K, BlkBitWidth, BlkLen, !Symmetric);
std::byte* PackedBuf = BufferPackedB.GetBuffer(PackedBufSize);

MlasLutGemmPack(
N, K, BlkBitWidth, BlkLen, !Symmetric,
reinterpret_cast<std::byte*>(QuantBData),
QuantBScale,
QuantBZeroPoint,
false, // IsFloatZeroPoint
PackedBuf, tp);

// Build a deterministic per-output-feature bias vector with non-trivial variation across N
// so that any stride/transpose bug in the fused bias add will be caught.
std::vector<float> Bias(N);
for (size_t n = 0; n < N; ++n) {
Bias[n] = 0.125f + 0.5f * static_cast<float>(n % 7) - 0.25f * static_cast<float>(n % 11);
}

MlasLutGemm(
A, BlkLen, PackedBuf, C,
static_cast<int>(K), static_cast<int>(M), static_cast<int>(N),
!Symmetric, tp, Bias.data());

// Reference: same GEMM as the no-bias path, then broadcast-add Bias.
CallReferenceGemm(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, CReference);
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
CReference[m * N + n] += Bias[n];
}
}

size_t f = 0;
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++, f++) {
ASSERT_TRUE(CloseEnough(C[f], CReference[f]))
<< "Expected: " << CReference[f] << " Actual: " << C[f] << "@[" << m << "x" << n << "], "
<< "M=" << M << ", N=" << N << ", K=" << K << ", Symmetric=" << Symmetric << ", WithBias=true";
}
}
}

public:
static const char* GetTestSuiteName() {
static std::string suite_name = std::string("SQLutGemm") +
Expand Down Expand Up @@ -369,6 +446,69 @@ class SQLutGemmFloatZPTest : public MlasTestFixture<MlasSQLutGemmTest<BlkBitWidt
float ZPValue_;
};

// Fixture for the fused bias path (MlasLutGemm Bias parameter): each registered instance
// runs MlasSQLutGemmTest::TestWithBias for one (M, N, K, threading, symmetry) combination.
template <size_t BlkBitWidth, size_t BlkLen>
class SQLutGemmBiasTest : public MlasTestFixture<MlasSQLutGemmTest<BlkBitWidth, BlkLen>> {
  using Tester = MlasSQLutGemmTest<BlkBitWidth, BlkLen>;
  using Fixture = MlasTestFixture<Tester>;

 public:
  explicit SQLutGemmBiasTest(size_t M, size_t N, size_t K, bool WithThreadpool, bool Symmetric)
      : M_(M), N_(N), K_(K), WithThreadpool_(WithThreadpool), Symmetric_(Symmetric) {}

  // Runs the bias check with the parameters captured at registration time.
  void TestBody() override {
    Fixture::mlas_tester->TestWithBias(M_, N_, K_, WithThreadpool_, Symmetric_);
  }

  // Registers one test case with GoogleTest. Returns 1 if it was registered, or 0 if
  // MlasIsLutGemmAvailable rejects the configuration or N is smaller than BlkLen.
  static size_t RegisterSingleTest(size_t M, size_t N, size_t K, bool WithThreadpool, bool Symmetric) {
    if (!MlasIsLutGemmAvailable(N, K, BlkBitWidth, BlkLen) || N < BlkLen) {
      return 0;
    }

    // The case name encodes the threading mode, the symmetry flag and the shape.
    std::stringstream name_builder;
    name_builder << (WithThreadpool ? "Threaded" : "SingleThread");
    name_builder << "/Bias/isSymmetric" << Symmetric;
    name_builder << "/M" << M << "xN" << N << "xK" << K;
    const auto case_name = name_builder.str();

    // The factory captures the parameters by value so it can build the fixture later.
    auto make_fixture = [M, N, K, WithThreadpool, Symmetric]() -> Fixture* {
      return new SQLutGemmBiasTest<BlkBitWidth, BlkLen>(M, N, K, WithThreadpool, Symmetric);
    };

    testing::RegisterTest(
        Tester::GetTestSuiteName(),
        case_name.c_str(),
        nullptr,
        case_name.c_str(),
        __FILE__,
        __LINE__,
        make_fixture);

    return 1;
  }

  // Registers the short-running bias cases and returns how many were registered.
  static size_t RegisterShortExecuteTests() {
    struct Shape {
      size_t M;
      size_t N;
      size_t K;
    };
    // M=1 (decode-like) and M>1 (prefill/batched) rows, with several N values intended to
    // span more than one output tile so the per-tile bias slice arithmetic is exercised.
    const Shape shapes[] = {
        {1, 128, 128},
        {1, 256, 256},
        {1, 1024, 1024},
        {32, 128, 128},
        {32, 256, 256},
    };

    size_t registered = 0;
    // Only the threaded mode is registered.
    for (bool with_threadpool : {true}) {
      for (bool symmetric : {true, false}) {
        for (const Shape& shape : shapes) {
          registered += RegisterSingleTest(shape.M, shape.N, shape.K, with_threadpool, symmetric);
        }
      }
    }
    return registered;
  }

 private:
  size_t M_, N_, K_;
  bool WithThreadpool_;
  bool Symmetric_;
};

static size_t SQLutGemmRegisterAllShortExecuteTests() {
size_t count = 0;
count += SQLutGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests();
Expand All @@ -379,6 +519,10 @@ static size_t SQLutGemmRegisterAllShortExecuteTests() {
count += SQLutGemmFloatZPTest<2, 32>::RegisterShortExecuteTests();
count += SQLutGemmFloatZPTest<2, 64>::RegisterShortExecuteTests();
count += SQLutGemmFloatZPTest<2, 128>::RegisterShortExecuteTests();
// Fused bias tests
count += SQLutGemmBiasTest<2, 32>::RegisterShortExecuteTests();
count += SQLutGemmBiasTest<2, 64>::RegisterShortExecuteTests();
count += SQLutGemmBiasTest<2, 128>::RegisterShortExecuteTests();
return count;
}

Expand Down
Loading