Integrate KleidiAI for MatMulNBits via MlasQNBitGemm #23627

Open · wants to merge 5 commits into main
Changes from all commits
2 changes: 1 addition & 1 deletion cmake/deps.txt
@@ -59,4 +59,4 @@ composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/arch
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.7.0.zip;d0753d8d5b39947ca0729d7773cb84653a129eb1
dawn;https://github.com/google/dawn/archive/b9b4a37041dec3dd62ac92014a6cc1aece48d9f3.zip;e8b8c2ebabdedb7c57d931fc4a19ae22146d31e1
kleidiai;https://gitlab.arm.com/kleidi/kleidiai/-/archive/d15722976120710080ca098fe8ddabf4556cb40f/kleidiai-d15722976120710080ca098fe8ddabf4556cb40f.zip;d6c840d00c3b05aedf06e957ddaece1013d1f40b
kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.3.0.tar.gz;58777d6907bdedb165fbca2e467a26b1363dc924

48 changes: 48 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -1,4 +1,5 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <[email protected]>
Member commented:
This file was not written primarily by Arm; it has contributions from many people outside of Microsoft. If everyone added such a license header, it would soon become unmanageable. Please note that when anyone contributes to Microsoft's open-source projects, they must agree that Microsoft has the right to re-license the change. Therefore I think it's better to keep the license header unchanged.

# Licensed under the MIT License.

set(MLAS_ROOT ${ONNXRUNTIME_ROOT}/core/mlas)
@@ -99,6 +100,10 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
)

setup_kleidiai()
Member commented:
This means that KleidiAI would become a new dependency for all ONNX Runtime build configs. For such a change, the ONNX Runtime team needs to hold an internal discussion with the project's leadership.

onnxruntime_fetchcontent_makeavailable(kleidiai)
set(KLEIDIAI_SRC ${kleidiai_SOURCE_DIR})

set(mlas_platform_preprocess_srcs
${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDot.asm
${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDotLd64.asm
@@ -118,6 +123,10 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelNeon.asm
${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDot.asm
${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDotLd64.asm
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S
)
else()
target_sources(onnxruntime_mlas PRIVATE
@@ -243,6 +252,44 @@ function(setup_mlas_source_for_windows)
endif()
endfunction()

function(setup_kleidiai)
# Disable the KleidiAI tests
set(KLEIDIAI_BUILD_TESTS OFF)

# Fetch KleidiAI sources:
if (NOT TARGET kleidiai)
if (POLICY CMP0135)
cmake_policy(SET CMP0135 NEW)
endif()

FetchContent_Declare(kleidiai URL ${DEP_URL_kleidiai} URL_HASH SHA1=${DEP_SHA1_kleidiai})
endif()
onnxruntime_fetchcontent_makeavailable(kleidiai)
set(KLEIDIAI_SRC ${kleidiai_SOURCE_DIR})

# KleidiAI
include_directories(${KLEIDIAI_SRC}/)

target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/kai_ukernel_interface.cpp
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S
)

if (NOT MSVC)
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
endif()
endfunction()

if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD)
file(GLOB_RECURSE mlas_platform_srcs
@@ -378,6 +425,7 @@ else()
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
)
setup_kleidiai()
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
if (NOT APPLE)
28 changes: 19 additions & 9 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -1,4 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <[email protected]>
// Licensed under the MIT License.

#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"
@@ -133,6 +134,7 @@ class MatMulNBits final : public OpKernel {
const size_t nbits_;
const bool has_g_idx_;
const bool has_bias_;
bool scales_are_packed_{false};
const MLAS_QNBIT_GEMM_COMPUTE_TYPE compute_type_;
bool has_unquantized_zero_point_{false};
const bool column_wise_quant_{true};
@@ -181,13 +183,17 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
return Status::OK();
}
if (input_idx == InputIndex::B) {
packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_);
const Tensor* scales = nullptr;
OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales);

packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_);
if (packed_b_size_ == 0) {
return Status::OK();
}
auto qptr = tensor.DataRaw();
auto scale_ptr = scales? scales->DataRaw() : nullptr;
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr);
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr, has_zp_input_, nullptr, nullptr);
is_packed = true;
} else if (compute_type_ == SQNBIT_CompInt8) {
#ifdef MLAS_TARGET_AMD64_IX86
@@ -201,7 +207,13 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr);
is_packed = false;
}
#endif // MLAS_TARGET_AMD64_IX86
#elif defined(MLAS_TARGET_ARM64)
if (input_idx == InputIndex::scales && packed_b_ != nullptr
&& MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, compute_type_, has_zp_input_)) {
scales_are_packed_ = true;
is_packed = true;
}
#endif // MLAS_TARGET_ARM64
}

return Status::OK();
@@ -287,7 +299,7 @@ Status MatMulNBits<T1>::ComputeBPacked(const Tensor* a,
concurrency::ThreadPool* thread_pool,
const MatMulComputeHelper& helper) const {
const auto* a_data = a->Data<T1>();
const auto* scales_data = scales->Data<T1>();
const auto* scales_data = scales == nullptr? nullptr : scales->Data<T1>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
const auto* bias_data = bias == nullptr ? nullptr : bias->Data<T1>();
auto* y_data = y->MutableData<T1>();
@@ -300,7 +312,7 @@ Status MatMulNBits<T1>::ComputeBPacked(const Tensor* a,

IAllocatorUniquePtr<std::byte> workspace{};
const size_t workspace_size = MlasQNBitGemmBatchWorkspaceSize(
M, N, K, batch_count, nbits_, block_size_, compute_type_);
M, N, K, batch_count, nbits_, block_size_, zero_points, compute_type_);
if (workspace_size > 0) {
// Use reserve since no caching is needed
workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size, true);
@@ -310,11 +322,9 @@ Status MatMulNBits<T1>::ComputeBPacked(const Tensor* a,
for (size_t i = 0; i < batch_count; ++i) {
data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
#ifdef MLAS_TARGET_AMD64_IX86
if (compute_type_ == SQNBIT_CompInt8) {
data[i].QuantBDataWorkspace = packed_b_.get();
}
#endif
data[i].PackedQuantBData = static_cast<std::byte*>(packed_b_.get());
data[i].QuantBScale = scales_data;
data[i].QuantBZeroPoint = zero_points_data;
@@ -351,7 +361,7 @@ Status MatMulNBits<MLFloat16>::ComputeBPacked(const Tensor* a,

IAllocatorUniquePtr<std::byte> workspace{};
const size_t workspace_size = MlasQNBitGemmBatchWorkspaceSize(
M, N, K, batch_count, nbits_, block_size_, compute_type_);
M, N, K, batch_count, nbits_, block_size_, zero_points, compute_type_);
if (workspace_size > 0) {
// Use reserve since no caching is needed
workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size, true);
@@ -653,7 +663,7 @@ template <typename T1>
Status MatMulNBits<T1>::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
const Tensor* a = ctx->Input<Tensor>(InputIndex::A);
const Tensor* scales = ctx->Input<Tensor>(InputIndex::scales);
const Tensor* scales = scales_are_packed_? nullptr : ctx->Input<Tensor>(InputIndex::scales);
const Tensor* zero_points = ctx->Input<Tensor>(InputIndex::zero_points);
const Tensor* reorder_idx = ctx->Input<Tensor>(InputIndex::g_idx);
const Tensor* bias = ctx->Input<Tensor>(InputIndex::bias);
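Taken together, the matmul_nbits.cc changes give the ARM64 path a three-step hand-off: PrePack(B) folds constant scales into the packed buffer, PrePack(scales) then marks the scales tensor as consumed, and Compute() stops fetching it. The sketch below is illustrative only: the Mlas* calls and their argument order mirror the diff above, while the wrapper function, the PrepackResult type, the std::vector buffer, and the include path are simplifications rather than the kernel's actual code.

```cpp
// Illustrative sketch of the new ARM64 pre-packing hand-off (not the kernel's
// verbatim code). Only the Mlas* calls and their argument order come from the
// diff above; everything else is a simplification.
#include <cstddef>
#include <vector>
#include "mlas_qnbit.h"  // assumed include path for the MLAS QNBit GEMM API

struct PrepackResult {
    std::vector<std::byte> packed_b;  // stands in for the IAllocator-backed packed_b_
    bool scales_are_packed = false;   // mirrors the new scales_are_packed_ member
};

PrepackResult PrepackWeightsAndScales(size_t N, size_t K, size_t nbits, size_t block_size,
                                      bool has_zp_input,
                                      const void* quant_b_data,        // InputIndex::B
                                      const void* constant_scales) {   // InputIndex::scales, if constant
    PrepackResult result;
    // PrePack(B): the size query now also depends on whether zero points exist.
    const size_t packed_b_size = MlasQNBitGemmPackQuantBDataSize(
        N, K, nbits, block_size, has_zp_input, SQNBIT_CompInt8);
    if (packed_b_size == 0) {
        return result;  // nothing to pack for this configuration
    }
    result.packed_b.resize(packed_b_size);
    // Constant scales (when available) are passed so the KleidiAI path can fold
    // them into the packed RHS.
    MlasQNBitGemmPackQuantBData(N, K, nbits, block_size, SQNBIT_CompInt8,
                                quant_b_data, result.packed_b.data(), constant_scales,
                                has_zp_input, /*QuantBZeroPoint*/ nullptr,
                                /*ThreadPool*/ nullptr);
    // PrePack(scales): if MLAS reports the scales were folded in, Compute() will
    // later pass a null scales pointer (scales_are_packed_ in the kernel).
    result.scales_are_packed = MlasQNBitGemmScalesPacked(
        K, nbits, block_size, SQNBIT_CompInt8, has_zp_input);
    return result;
}
```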
23 changes: 23 additions & 0 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -1,6 +1,7 @@
/*++

Copyright (c) Microsoft Corporation. All rights reserved.
SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <[email protected]>

Licensed under the MIT License.

@@ -123,6 +124,7 @@ MlasIsQNBitGemmAvailable(
* @param[in] BatchN number of batches
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] has_zp_input whether zero points are provided
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
size_t MLASCALL
@@ -133,6 +135,7 @@ MlasQNBitGemmBatchWorkspaceSize(
size_t BatchN,
size_t BlkBitWidth,
size_t BlkLen,
bool has_zp_input,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

@@ -147,6 +150,7 @@ MlasQNBitGemmBatchWorkspaceSize(
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] has_zp_input whether zero points are provided
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
size_t MLASCALL
@@ -155,6 +159,7 @@ MlasQNBitGemmPackQuantBDataSize(
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
bool has_zp_input,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

@@ -199,3 +204,21 @@ MlasQNBitGemmPackQuantBData(
const void* QuantBZeroPoint,
MLAS_THREADPOOL* ThreadPool
);

/**
* @brief Returns true if scales are packed when calling MlasQNBitGemmPackQuantBData the first time.
*
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
* @param[in] has_zp_input whether QuantBZeroPoint is provided
*/
bool MLASCALL
MlasQNBitGemmScalesPacked(
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
bool has_zp_input
);
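The workspace-size query gains the same flag. Below is a hedged usage sketch: the argument order follows the declaration above, the BlkBitWidth/BlkLen values and the helper function are arbitrary illustrations, and the remark about why the flag matters is an inference rather than something stated in this header.

```cpp
// Hedged sketch of the widened MlasQNBitGemmBatchWorkspaceSize call. Presumably
// the flag matters because the zero-point-free case can take the KleidiAI path,
// whose LHS quant/pack workspace needs differ from the generic int8 path.
#include <cstddef>
#include <vector>
#include "mlas_qnbit.h"  // assumed include path

std::vector<std::byte> AllocateQNBitGemmWorkspace(size_t M, size_t N, size_t K,
                                                  size_t batch_count, bool has_zp_input) {
    const size_t workspace_size = MlasQNBitGemmBatchWorkspaceSize(
        M, N, K, batch_count, /*BlkBitWidth*/ 4, /*BlkLen*/ 32,
        has_zp_input, SQNBIT_CompInt8);
    // matmul_nbits.cc passes its zero_points tensor pointer here, which converts
    // to bool; a plain flag expresses the same thing.
    return std::vector<std::byte>(workspace_size);
}
```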
81 changes: 81 additions & 0 deletions onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp
@@ -0,0 +1,81 @@
//
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <[email protected]>
//
// SPDX-License-Identifier: MIT
//

#include "kai_ukernel_interface.h"
#include "mlasi.h"

#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h"

kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod =
{kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod};

kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod =
{kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
kai_run_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod};

kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod =
{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod};

kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm =
{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm};

kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel() {
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) {
return kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm;
} else {
return kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod;
}
}

kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel() {
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) {
return kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod;
} else {
return kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod;
}
}
12 changes: 12 additions & 0 deletions onnxruntime/core/mlas/lib/kai_ukernel_interface.h
@@ -0,0 +1,12 @@
//
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <[email protected]>
//
// SPDX-License-Identifier: MIT
//

#pragma once

#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h"

kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel();
kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel();
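The header exposes only these two selectors, so the rest of MLAS stays decoupled from the concrete KleidiAI kernels. A hedged usage sketch follows: the struct's function-pointer field names are assumed to mirror the per-kernel getters listed in kai_ukernel_interface.cpp (as in KleidiAI's kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h), and choosing the GEMV variant for single-row inputs is an assumption about intended use, not something this header enforces.

```cpp
// Hedged usage sketch. Assumes the KleidiAI interface struct exposes function
// pointers named after the getters in kai_ukernel_interface.cpp
// (get_m_step, get_n_step, get_mr, get_nr, get_kr, get_sr, ...); the M == 1
// dispatch below is an assumed usage pattern, not part of this header.
#include <cstddef>
#include <cstdio>
#include "kai_ukernel_interface.h"  // assumed include path

void DescribeSelectedUKernel(size_t M) {
    const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel =
        (M == 1) ? GetKleidiAIGemvUKernel()   // skinny LHS: 1xN-oriented micro-kernel
                 : GetKleidiAIGemmUKernel();  // general case: 16x4-oriented micro-kernel

    // Blocking parameters that would drive M/N tiling and LHS/RHS packing.
    std::printf("m_step=%zu n_step=%zu mr=%zu nr=%zu kr=%zu sr=%zu\n",
                ukernel.get_m_step(), ukernel.get_n_step(),
                ukernel.get_mr(), ukernel.get_nr(),
                ukernel.get_kr(), ukernel.get_sr());
}
```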