Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ class MatMulNBits final : public OpKernel {
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
has_g_idx_{info.GetInputCount() > InputIndex::g_idx && info.node().InputDefs()[InputIndex::g_idx]->Exists()},
has_bias_{info.GetInputCount() > InputIndex::bias && info.node().InputDefs()[InputIndex::bias]->Exists()},
prefer_lut_gemm_{info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasLutGemm) == "1" &&
prefer_lut_gemm_{std::is_same_v<T1, float> &&
info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasLutGemm) == "1" &&
MlasIsLutGemmAvailable(narrow<size_t>(info.GetAttr<int64_t>("N")),
narrow<size_t>(info.GetAttr<int64_t>("K")),
narrow<size_t>(info.GetAttr<int64_t>("bits")),
Expand Down Expand Up @@ -192,6 +193,7 @@ class MatMulNBits final : public OpKernel {
const MatMulComputeHelper& helper) const;

Status ComputeBPackedLUT(const Tensor* a,
const Tensor* bias,
Tensor* y,
concurrency::ThreadPool* thread_pool,
const MatMulComputeHelper& helper) const;
Expand Down Expand Up @@ -641,6 +643,7 @@ Status MatMulNBits<T1>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>&

template <typename T1>
Status MatMulNBits<T1>::ComputeBPackedLUT(const Tensor* a,
const Tensor* bias,
Tensor* y,
concurrency::ThreadPool* thread_pool,
const MatMulComputeHelper& helper) const {
Expand All @@ -650,7 +653,21 @@ Status MatMulNBits<T1>::ComputeBPackedLUT(const Tensor* a,
const int N = static_cast<int>(helper.N());
const int K = static_cast<int>(helper.K());

MlasLutGemm(a_data, block_size_, packed_b_.get(), y_data, K, M, N, has_zp_input_, thread_pool);
// Bias is fused into MlasLutGemm: it is broadcast-added per output-feature tile inside
// the same parallel loop that runs the GEMM, so the bias addition is multi-threaded and
// operates on data that is still hot in cache. MlasLutGemm currently only supports fp32
// activations/outputs; reject any other type when a bias is present.
const float* bias_data = nullptr;
if (bias != nullptr) {
if constexpr (std::is_same_v<T1, float>) {
bias_data = bias->Data<float>();
} else {
Comment thread
hariharans29 marked this conversation as resolved.
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"MatMulNBits LUT GEMM path does not support non-fp32 bias.");
}
}

MlasLutGemm(a_data, block_size_, packed_b_.get(), y_data, K, M, N, has_zp_input_, thread_pool, bias_data);
return Status::OK();
}

Expand Down Expand Up @@ -1227,7 +1244,7 @@ Status MatMulNBits<T1>::Compute(OpKernelContext* ctx) const {
// MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch()
// with B directly too.
if (prefer_lut_gemm_) {
return ComputeBPackedLUT(a, y, thread_pool, helper);
return ComputeBPackedLUT(a, bias, y, thread_pool, helper);
}
Comment thread
hariharans29 marked this conversation as resolved.

if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) {
Expand Down
8 changes: 7 additions & 1 deletion onnxruntime/core/mlas/inc/mlas_qnbit.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,11 @@ MlasLutGemmPack(
* @param[in] N column size of matrix B
* @param[in] HasZeroPoint whether zero points are provided
* @param[in] threadpool thread pool for parallel computation
* @param[in] Bias optional bias vector of length N (one value per output feature).
* When non-null, it is broadcast-added to every row of the [M, N]
* output. The addition is fused into the per-tile compute loop so
* it inherits the same multi-threading as the GEMM itself.
* Pass nullptr if no bias is to be applied.
*/
void MLASCALL
MlasLutGemm(
Expand All @@ -369,5 +374,6 @@ MlasLutGemm(
size_t M,
size_t N,
bool HasZeroPoint,
MLAS_THREADPOOL* threadpool
MLAS_THREADPOOL* threadpool,
const float* Bias = nullptr
);
19 changes: 18 additions & 1 deletion onnxruntime/core/mlas/lib/qlutgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,8 @@ MlasLutGemm(
size_t M, // batch size (number of rows in activation)
size_t N,
bool HasZeroPoint,
MLAS_THREADPOOL* threadpool
MLAS_THREADPOOL* threadpool,
const float* Bias
)
{
// adapted from ggml_backend_tmac_mul_mat
Expand Down Expand Up @@ -616,6 +617,22 @@ MlasLutGemm(
BlkLen, // Weight quantization group size
HasZeroPoint // Whether zero points are used
);

// Fused bias add: broadcast the per-output-feature Bias[N] slice into the
// just-written tile. The output tile we just wrote is `ChunkSize0` contiguous
// floats at `act_output + dst_offset`, corresponding to output feature indices
// [ichunk0 * ChunkSize0, ichunk0 * ChunkSize0 + ChunkSize0). The bias slice
// therefore aligns at `Bias + ichunk0 * ChunkSize0`. Doing this here (rather
// than as a separate post-pass) keeps the data hot in cache and inherits the
// existing per-chunk parallelism for free.
if (Bias != nullptr) {
const size_t tile_n = ir0_end - ir0_start;
float* y_tile = act_output + dst_offset;
const float* bias_tile = Bias + ichunk0 * ChunkSize0;
for (size_t i = 0; i < tile_n; ++i) {
y_tile[i] += bias_tile[i];
}
}
}
}
}
Expand Down
60 changes: 58 additions & 2 deletions onnxruntime/test/contrib_ops/matmul_2bits_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ void TestMatMul2BitsTyped(float abs_error = 0.1f, float rel_error = 0.02f) {

template <typename AType>
void TestMatMul2BitsLutGemm(int64_t M, int64_t N, int64_t K, int64_t block_size,
bool has_zero_point, float abs_error = 0.15f, float rel_error = 0.05f) {
bool has_zero_point, bool has_bias = false,
float abs_error = 0.15f, float rel_error = 0.05f) {
if (K % 32 != 0 || N % 128 != 0 || block_size % 32 != 0) {
GTEST_SKIP() << "LUT GEMM requires K multiple of 32, N multiple of 128, block_size multiple of 32";
}
Expand Down Expand Up @@ -308,13 +309,27 @@ void TestMatMul2BitsLutGemm(int64_t M, int64_t N, int64_t K, int64_t block_size,
static_cast<int32_t>(N),
tp);

// Optional per-output-feature bias with non-trivial variation across N so any stride/transpose
// bug in the fused bias add (inside MlasLutGemm) is observable.
std::vector<float> bias;
if (has_bias) {
bias.resize(static_cast<size_t>(N));
for (int64_t n = 0; n < N; ++n) {
bias[static_cast<size_t>(n)] =
0.125f + 0.5f * static_cast<float>(n % 7) - 0.25f * static_cast<float>(n % 11);
}
}

std::vector<float> expected_vals(M * N);
for (int64_t m = 0; m < M; m++) {
for (int64_t n = 0; n < N; n++) {
float sum = 0.0f;
for (int64_t k = 0; k < K; k++) {
sum += input0_fp32_vals[m * K + k] * input1_fp32_vals[n * K + k];
}
if (has_bias) {
sum += bias[static_cast<size_t>(n)];
}
expected_vals[m * N + n] = sum;
}
}
Expand Down Expand Up @@ -344,7 +359,16 @@ void TestMatMul2BitsLutGemm(int64_t M, int64_t N, int64_t K, int64_t block_size,
}

test.AddOptionalInputEdge<int32_t>();
test.AddOptionalInputEdge<AType>();

if (has_bias) {
if constexpr (std::is_same<AType, float>::value) {
test.AddInput<AType>("bias", {N}, bias, true);
} else {
test.AddOptionalInputEdge<AType>();
}
} else {
test.AddOptionalInputEdge<AType>();
}

if constexpr (std::is_same<AType, float>::value) {
test.AddOutput<AType>("Y", {M, N}, expected_vals);
Expand Down Expand Up @@ -405,6 +429,38 @@ TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_Batch32_256x256) {
TestMatMul2BitsLutGemm<float>(32, 256, 256, 32, true);
}

// Fused-bias tests — verify the Bias parameter to MlasLutGemm is broadcast-added correctly.
// These are the regression tests for the bug where the LUT path silently dropped the optional
// `bias` input of MatMulNBits.
TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_128x128_Bias) {
TestMatMul2BitsLutGemm<float>(1, 128, 128, 32, /*has_zero_point=*/false, /*has_bias=*/true);
}

TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_128x128_Bias) {
TestMatMul2BitsLutGemm<float>(1, 128, 128, 32, /*has_zero_point=*/true, /*has_bias=*/true);
}

TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_256x256_BlkLen64_Bias) {
TestMatMul2BitsLutGemm<float>(1, 256, 256, 64, /*has_zero_point=*/false, /*has_bias=*/true);
}

TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_256x256_BlkLen64_Bias) {
TestMatMul2BitsLutGemm<float>(1, 256, 256, 64, /*has_zero_point=*/true, /*has_bias=*/true);
}

TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_128x256_BlkLen128_Bias) {
TestMatMul2BitsLutGemm<float>(1, 128, 256, 128, /*has_zero_point=*/true, /*has_bias=*/true);
}

// Batched (M>1) bias tests — exercise the per-row bias broadcast across many activation rows.
TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_Batch32_128x128_Bias) {
TestMatMul2BitsLutGemm<float>(32, 128, 128, 32, /*has_zero_point=*/false, /*has_bias=*/true);
}

TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_Batch32_256x256_Bias) {
TestMatMul2BitsLutGemm<float>(32, 256, 256, 32, /*has_zero_point=*/true, /*has_bias=*/true);
}

// Float zero point tests — directed QAD scenario (zp=1.5)
void RunTest2BitsFloatZP(int64_t M, int64_t N, int64_t K, int64_t block_size, float zp_value) {
RandomValueGenerator random{1234};
Expand Down
27 changes: 18 additions & 9 deletions onnxruntime/test/mlas/bench/bench_lutgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <stdexcept>

static const std::vector<std::string> lutgemm_bench_arg_names = {"BlkLen", "N", "K", "Threads", "HasZP"};
static const std::vector<std::string> lutgemm_compute_arg_names = {"BlkLen", "M", "N", "K", "Threads", "HasZP"};
static const std::vector<std::string> lutgemm_compute_arg_names = {"BlkLen", "M", "N", "K", "Threads", "HasZP", "HasBias"};

template <size_t BlkBitWidth>
void LUTGEMM_PACK(benchmark::State& state) {
Expand Down Expand Up @@ -95,6 +95,7 @@ void LUTGEMM_COMPUTE(benchmark::State& state) {
const size_t K = static_cast<size_t>(state.range(3));
const size_t Threads = static_cast<size_t>(state.range(4));
const bool HasZeroPoint = static_cast<bool>(state.range(5));
const bool HasBias = static_cast<bool>(state.range(6));

if (!MlasIsLutGemmAvailable(N, K, BlkBitWidth, BlkLen)) {
state.SkipWithMessage("LUT GEMM is not available with the given configuration.");
Expand Down Expand Up @@ -144,14 +145,21 @@ void LUTGEMM_COMPUTE(benchmark::State& state) {
PackedBuf.data(),
tp.get());

std::vector<float> Bias;
const float* BiasPtr = nullptr;
if (HasBias) {
Bias = RandomVectorUniform(N, -1.0f, 1.0f);
BiasPtr = Bias.data();
}

MlasLutGemm(A.data(), BlkLen, PackedBuf.data(), C.data(),
static_cast<int>(K), static_cast<int>(M), static_cast<int>(N),
HasZeroPoint, tp.get());
HasZeroPoint, tp.get(), BiasPtr);

for (auto _ : state) {
MlasLutGemm(A.data(), BlkLen, PackedBuf.data(), C.data(),
static_cast<int>(K), static_cast<int>(M), static_cast<int>(N),
HasZeroPoint, tp.get());
HasZeroPoint, tp.get(), BiasPtr);
}
}

Expand All @@ -169,12 +177,13 @@ static void LutGemmPackArgs(benchmark::internal::Benchmark* b) {
static void LutGemmComputeArgs(benchmark::internal::Benchmark* b) {
b->ArgNames(lutgemm_compute_arg_names);
b->ArgsProduct({
{128}, // BlkLen
{1, 32}, // M
{4096}, // N
{4096}, // K
{8}, // Threads
{int64_t{false}}, // HasZeroPoint
{128}, // BlkLen
{1, 32}, // M
{4096}, // N
{4096}, // K
{8}, // Threads
{int64_t{false}}, // HasZeroPoint
Comment thread
hariharans29 marked this conversation as resolved.
Outdated
{int64_t{false}, int64_t{true}}, // HasBias
});
}

Expand Down
Loading
Loading