Integrate KleidiAI for MatMulNBits via MlasQNBitGemm #23627
base: main
Changes from all commits: bf69914, 6d527e1, de7f1e5, 1a7744d, f3d4b2e
File: cmake/onnxruntime_mlas.cmake
@@ -1,4 +1,5 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
+# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <[email protected]>
 # Licensed under the MIT License.

 set(MLAS_ROOT ${ONNXRUNTIME_ROOT}/core/mlas)

[Review comment, on the added SPDX line] This file was not mostly written by ARM. It has contributions from many contributors outside of Microsoft. If everyone adds such a license header here, it will soon become unmanageable. Please note that when anyone makes a contribution to Microsoft's open source project, they need to agree that Microsoft has the right to re-license the change. Therefore I think it is better to keep the license header unchanged.
@@ -99,6 +100,10 @@ function(setup_mlas_source_for_windows)
   ${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
 )

+setup_kleidiai()
+onnxruntime_fetchcontent_makeavailable(kleidiai)
+set(KLEIDIAI_SRC ${kleidiai_SOURCE_DIR})
+
 set(mlas_platform_preprocess_srcs
   ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDot.asm
   ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDotLd64.asm

[Review comment, on setup_kleidiai()] This means that kleidiai will be a new dependency for all ONNX Runtime build configs. For such changes the ONNX Runtime team needs to hold an internal discussion with the leadership of this project.
@@ -118,6 +123,10 @@ function(setup_mlas_source_for_windows)
   ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelNeon.asm
   ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDot.asm
   ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDotLd64.asm
+  ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S
+  ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S
+  ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S
+  ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S
 )
 else()
   target_sources(onnxruntime_mlas PRIVATE
@@ -243,6 +252,44 @@ function(setup_mlas_source_for_windows)
   endif()
 endfunction()

+function(setup_kleidiai)
+  # Disable the KleidiAI tests
+  set(KLEIDIAI_BUILD_TESTS OFF)
+
+  # Fetch KleidiAI sources:
+  if (NOT TARGET kleidiai)
+    if (POLICY CMP0135)
+      cmake_policy(SET CMP0135 NEW)
+    endif()
+
+    FetchContent_Declare(kleidiai URL ${DEP_URL_kleidiai} URL_HASH SHA1=${DEP_SHA1_kleidiai})
+  endif()
+  onnxruntime_fetchcontent_makeavailable(kleidiai)
+  set(KLEIDIAI_SRC ${kleidiai_SOURCE_DIR})
+
+  # KleidiAI
+  include_directories(${KLEIDIAI_SRC}/)
+
+  target_sources(onnxruntime_mlas PRIVATE
+    ${MLAS_SRC_DIR}/kai_ukernel_interface.cpp
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.c
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.c
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.c
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod_asm.S
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod_asm.S
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod_asm.S
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm_asm.S
+  )
+
+  if (NOT MSVC)
+    set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
+      PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
+  endif()
+endfunction()

 if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
   if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD)
     file(GLOB_RECURSE mlas_platform_srcs

@@ -378,6 +425,7 @@ else()
   ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
   ${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
 )
+setup_kleidiai()
 set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
   PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
 if (NOT APPLE)
File: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -1,4 +1,5 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
+// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <[email protected]>
 // Licensed under the MIT License.

 #include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"

@@ -133,6 +134,7 @@ class MatMulNBits final : public OpKernel {
   const size_t nbits_;
   const bool has_g_idx_;
   const bool has_bias_;
+  bool scales_are_packed_{false};
   const MLAS_QNBIT_GEMM_COMPUTE_TYPE compute_type_;
   bool has_unquantized_zero_point_{false};
   const bool column_wise_quant_{true};
@@ -181,13 +183,17 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
     return Status::OK();
   }
   if (input_idx == InputIndex::B) {
-    packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type_);
+    const Tensor* scales = nullptr;
+    OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales);
+
+    packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_);
     if (packed_b_size_ == 0) {
       return Status::OK();
     }
     auto qptr = tensor.DataRaw();
+    auto scale_ptr = scales ? scales->DataRaw() : nullptr;
     packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
-    MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr);
+    MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr, has_zp_input_, nullptr, nullptr);
     is_packed = true;
   } else if (compute_type_ == SQNBIT_CompInt8) {
 #ifdef MLAS_TARGET_AMD64_IX86

@@ -201,7 +207,13 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
     MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr);
     is_packed = false;
   }
-#endif  // MLAS_TARGET_AMD64_IX86
+#elif defined(MLAS_TARGET_ARM64)
+    if (input_idx == InputIndex::scales && packed_b_ != nullptr
+        && MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, compute_type_, has_zp_input_)) {
+      scales_are_packed_ = true;
+      is_packed = true;
+    }
+#endif  // MLAS_TARGET_ARM64
   }

   return Status::OK();
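For context, a minimal caller-side sketch of the packing contract the ARM64 branch above relies on: B is packed first (folding in the constant scales when available), and the later PrePack call for the scales input only reports them as packed when MlasQNBitGemmScalesPacked agrees. This is an illustrative standalone helper, not code from the PR; the name PackBAndMaybeScales is hypothetical, while the MLAS signatures come from the diffs in this PR.

    // Hypothetical helper, for illustration only; uses the MLAS API as extended by this PR.
    #include <cstddef>
    #include <vector>

    #include "mlas_qnbit.h"

    // Packs quantized B once, folding the constant scales in when the active
    // kernel packs them. Returns true if the scales are now owned by the packed
    // buffer and must not be passed again at compute time.
    bool PackBAndMaybeScales(size_t N, size_t K, size_t nbits, size_t blk_len,
                             bool has_zp, MLAS_QNBIT_GEMM_COMPUTE_TYPE ctype,
                             const void* b_data, const void* scales_data,
                             std::vector<std::byte>& packed_b) {
      const size_t packed_size =
          MlasQNBitGemmPackQuantBDataSize(N, K, nbits, blk_len, has_zp, ctype);
      if (packed_size == 0) {
        return false;  // this configuration does not pre-pack B
      }
      packed_b.resize(packed_size);
      MlasQNBitGemmPackQuantBData(N, K, nbits, blk_len, ctype,
                                  b_data, packed_b.data(), scales_data,
                                  has_zp, /*QuantBZeroPoint*/ nullptr,
                                  /*ThreadPool*/ nullptr);
      // Mirrors the #elif MLAS_TARGET_ARM64 branch above: if true, PrePack for
      // the scales input sets is_packed, and Compute() fetches scales as nullptr.
      return MlasQNBitGemmScalesPacked(K, nbits, blk_len, ctype, has_zp);
    }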
@@ -287,7 +299,7 @@ Status MatMulNBits<T1>::ComputeBPacked(const Tensor* a,
                                        concurrency::ThreadPool* thread_pool,
                                        const MatMulComputeHelper& helper) const {
   const auto* a_data = a->Data<T1>();
-  const auto* scales_data = scales->Data<T1>();
+  const auto* scales_data = scales == nullptr ? nullptr : scales->Data<T1>();
   const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
   const auto* bias_data = bias == nullptr ? nullptr : bias->Data<T1>();
   auto* y_data = y->MutableData<T1>();
@@ -300,7 +312,7 @@ Status MatMulNBits<T1>::ComputeBPacked(const Tensor* a,

   IAllocatorUniquePtr<std::byte> workspace{};
   const size_t workspace_size = MlasQNBitGemmBatchWorkspaceSize(
-      M, N, K, batch_count, nbits_, block_size_, compute_type_);
+      M, N, K, batch_count, nbits_, block_size_, zero_points, compute_type_);
   if (workspace_size > 0) {
     // Use reserve since no caching is needed
     workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size, true);

@@ -310,11 +322,9 @@ Status MatMulNBits<T1>::ComputeBPacked(const Tensor* a,
   for (size_t i = 0; i < batch_count; ++i) {
     data[i].A = a_data + helper.LeftOffsets()[i];
     data[i].lda = lda;
-#ifdef MLAS_TARGET_AMD64_IX86
     if (compute_type_ == SQNBIT_CompInt8) {
       data[i].QuantBDataWorkspace = packed_b_.get();
     }
-#endif
     data[i].PackedQuantBData = static_cast<std::byte*>(packed_b_.get());
     data[i].QuantBScale = scales_data;
     data[i].QuantBZeroPoint = zero_points_data;
@@ -351,7 +361,7 @@ Status MatMulNBits<MLFloat16>::ComputeBPacked(const Tensor* a,

   IAllocatorUniquePtr<std::byte> workspace{};
   const size_t workspace_size = MlasQNBitGemmBatchWorkspaceSize(
-      M, N, K, batch_count, nbits_, block_size_, compute_type_);
+      M, N, K, batch_count, nbits_, block_size_, zero_points, compute_type_);
   if (workspace_size > 0) {
     // Use reserve since no caching is needed
     workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size, true);
@@ -653,7 +663,7 @@ template <typename T1>
 Status MatMulNBits<T1>::Compute(OpKernelContext* ctx) const {
   concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
   const Tensor* a = ctx->Input<Tensor>(InputIndex::A);
-  const Tensor* scales = ctx->Input<Tensor>(InputIndex::scales);
+  const Tensor* scales = scales_are_packed_ ? nullptr : ctx->Input<Tensor>(InputIndex::scales);
   const Tensor* zero_points = ctx->Input<Tensor>(InputIndex::zero_points);
   const Tensor* reorder_idx = ctx->Input<Tensor>(InputIndex::g_idx);
   const Tensor* bias = ctx->Input<Tensor>(InputIndex::bias);
File: onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -1,6 +1,7 @@
 /*++

 Copyright (c) Microsoft Corporation. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <[email protected]>

 Licensed under the MIT License.

@@ -123,6 +124,7 @@ MlasIsQNBitGemmAvailable(
  * @param[in] BatchN number of batches
  * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
  * @param[in] BlkLen number of quantized values per block
+ * @param[in] has_zp_input whether zero points are provided
  * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
  */
 size_t MLASCALL

@@ -133,6 +135,7 @@ MlasQNBitGemmBatchWorkspaceSize(
     size_t BatchN,
     size_t BlkBitWidth,
     size_t BlkLen,
+    bool has_zp_input,
     MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
 );

@@ -147,6 +150,7 @@ MlasQNBitGemmBatchWorkspaceSize(
  * @param[in] K column size of matrix A and row size of matrix B
  * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
  * @param[in] BlkLen number of quantized values per block
+ * @param[in] has_zp_input whether zero points are provided
  * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
  */
 size_t MLASCALL

@@ -155,6 +159,7 @@ MlasQNBitGemmPackQuantBDataSize(
     size_t K,
     size_t BlkBitWidth,
     size_t BlkLen,
+    bool has_zp_input,
     MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
 );

@@ -199,3 +204,21 @@ MlasQNBitGemmPackQuantBData(
     const void* QuantBZeroPoint,
     MLAS_THREADPOOL* ThreadPool
 );
+
+/**
+ * @brief Returns true if scales are packed when calling MlasQNBitGemmPackQuantBData the first time.
+ *
+ * @param[in] K column size of matrix A and row size of matrix B
+ * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
+ * @param[in] BlkLen number of quantized values per block
+ * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
+ * @param[in] has_zp_input whether QuantBZeroPoint is provided
+ */
+bool MLASCALL
+MlasQNBitGemmScalesPacked(
+    size_t K,
+    size_t BlkBitWidth,
+    size_t BlkLen,
+    MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
+    bool has_zp_input
+);
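A hedged usage sketch of the extended sizing API above; QueryQNBitGemmWorkspace is a hypothetical wrapper, shown only to make the new has_zp_input threading explicit. Plausibly the flag matters because kernel selection, and hence workspace layout, can differ when zero points are present.

    // Hypothetical wrapper, for illustration only.
    #include <cstddef>

    #include "mlas_qnbit.h"

    size_t QueryQNBitGemmWorkspace(size_t M, size_t N, size_t K, size_t batch_count,
                                   size_t nbits, size_t blk_len, bool has_zp,
                                   MLAS_QNBIT_GEMM_COMPUTE_TYPE ctype) {
      // has_zp now participates in workspace sizing; callers such as
      // MatMulNBits::ComputeBPacked pass the zero_points tensor pointer,
      // which converts to false when no zero points are supplied.
      return MlasQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, nbits, blk_len,
                                             has_zp, ctype);
    }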
File: onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp (new file)

@@ -0,0 +1,81 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <[email protected]>
+//
+// SPDX-License-Identifier: MIT
+//
+
+#include "kai_ukernel_interface.h"
+#include "mlasi.h"
+
+#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
+#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
+#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
+#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h"
+
+kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod =
+    {kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
+     kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
+     kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
+     kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
+     kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
+     kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
+     kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
+     kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
+     kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
+     kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
+     kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod};
+
+kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod =
+    {kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
+     kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
+     kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
+     kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
+     kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
+     kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
+     kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
+     kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
+     kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
+     kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod,
+     kai_run_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod};
+
+kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod =
+    {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+     kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+     kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+     kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+     kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+     kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+     kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+     kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+     kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+     kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+     kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod};
+
+kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm =
+    {kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
+     kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
+     kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
+     kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
+     kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
+     kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
+     kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
+     kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
+     kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
+     kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
+     kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm};
+
+kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel() {
+  if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) {
+    return kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm;
+  } else {
+    return kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod;
+  }
+}
+
+kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel() {
+  if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) {
+    return kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod;
+  } else {
+    return kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod;
+  }
+}
File: onnxruntime/core/mlas/lib/kai_ukernel_interface.h (new file)

@@ -0,0 +1,12 @@
+//
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <[email protected]>
+//
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h"
+
+kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel();
+kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel();
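For reference, a sketch of how these accessors might be consumed. The struct's member names (get_mr, get_nr, get_kr, get_sr, run_matmul, and so on) are assumed to match the positional initializers in the .cpp above and KleidiAI's kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h; SelectKleidiAIUKernel and PrintPackingGeometry are hypothetical helpers, not part of the PR.

    // Hypothetical dispatch helper, for illustration only.
    #include <cstddef>

    #include "kai_ukernel_interface.h"

    // GEMV-shaped problems (M == 1) use the 1xN LHS-packed variants; larger M
    // uses the blocked GEMM variants. Both getters above fall back from i8mm
    // to dotprod kernels based on the detected CPU features.
    const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& SelectKleidiAIUKernel(size_t M) {
      return M == 1 ? GetKleidiAIGemvUKernel() : GetKleidiAIGemmUKernel();
    }

    void PrintPackingGeometry(size_t M) {
      const auto& uk = SelectKleidiAIUKernel(M);
      // mr/nr/kr/sr describe the packing geometry that the LHS/RHS pack
      // routines must match for this kernel (member names assumed).
      const size_t mr = uk.get_mr();
      const size_t nr = uk.get_nr();
      const size_t kr = uk.get_kr();
      const size_t sr = uk.get_sr();
      (void)mr; (void)nr; (void)kr; (void)sr;
    }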
[Review comment] Please update https://github.com/microsoft/onnxruntime/blob/main/ThirdPartyNotices.txt as well.