[ARM CPU] Add fp16 mlas kernels for exp, tanh, softmax, logsoftmax, softcap (#23597)

### Description
Add fp16 MLAS kernels for exp, tanh, softmax, logsoftmax, and softcap on ARM CPU.



### Motivation and Context
These kernels enable Group Query Attention to support a fast fp16 path on the CPU EP.
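For reference, here is a minimal scalar sketch of the math the new NEON kernels vectorize. The names below are illustrative only, and the reference uses fp32 for clarity; the actual kernels in this PR operate on fp16 data and are not reproduced here. Softcap is assumed to follow the usual definition cap * tanh(x / cap).

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>

// Illustrative scalar reference, not the MLAS implementation.

// Softcap: squash each logit into (-cap, cap) via cap * tanh(x / cap).
void SoftcapRef(const float* input, float* output, size_t n, float cap) {
    for (size_t i = 0; i < n; ++i) {
        output[i] = cap * std::tanh(input[i] / cap);
    }
}

// Softmax / LogSoftmax over one row of length d, using the usual
// max-subtraction trick for numerical stability.
void SoftmaxRowRef(const float* input, float* output, size_t d, bool log_softmax) {
    float max_val = input[0];
    for (size_t i = 1; i < d; ++i) max_val = std::max(max_val, input[i]);

    float sum = 0.0f;
    for (size_t i = 0; i < d; ++i) {
        output[i] = std::exp(input[i] - max_val);
        sum += output[i];
    }

    if (log_softmax) {
        const float log_sum = std::log(sum);
        for (size_t i = 0; i < d; ++i) output[i] = (input[i] - max_val) - log_sum;
    } else {
        for (size_t i = 0; i < d; ++i) output[i] /= sum;
    }
}
```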
fajin-corp authored Feb 10, 2025
1 parent 9ba5619 commit e206950
Showing 17 changed files with 1,835 additions and 134 deletions.
8 changes: 8 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -43,6 +43,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/cast.cpp
${MLAS_SRC_DIR}/rotary_embedding.h
${MLAS_SRC_DIR}/rotary_embedding.cpp
+ ${MLAS_SRC_DIR}/softmax.h
)

target_sources(onnxruntime_mlas PRIVATE
@@ -97,6 +98,9 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
+ ${MLAS_SRC_DIR}/softmax_kernel_neon.h
+ ${MLAS_SRC_DIR}/softmax_kernel_neon.cpp
+ ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp
)

set(mlas_platform_preprocess_srcs
@@ -377,6 +381,8 @@ else()
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
+ ${MLAS_SRC_DIR}/softmax_kernel_neon.h
+ ${MLAS_SRC_DIR}/softmax_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
@@ -398,6 +404,7 @@ else()
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
+ ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -411,6 +418,7 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
25 changes: 19 additions & 6 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -990,11 +990,12 @@ MlasComputeErf(
size_t N
);

+ template <typename T>
void
MLASCALL
MlasComputeExp(
- const float* Input,
- float* Output,
+ const T* Input,
+ T* Output,
size_t N
);

@@ -1006,23 +1007,35 @@ MlasComputeLogistic(
size_t N
);

+ template <typename T>
void
MLASCALL
MlasComputeSoftmax(
- const float* Input,
- float* Output,
+ const T* Input,
+ T* Output,
size_t N,
size_t D,
bool LogSoftmax,
bool SmoothSoftmax,
MLAS_THREADPOOL* ThreadPool
);

+ template <typename T>
+ void
+ MLASCALL
+ MlasComputeSoftcap(
+ const T* Input,
+ T* Output,
+ size_t N,
+ T cap
+ );
+
+ template<typename T>
void
MLASCALL
MlasComputeTanh(
- const float* Input,
- float* Output,
+ const T* Input,
+ T* Output,
size_t N
);

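A hedged usage sketch of the templated entry points declared above. It assumes the fp16 path instantiates T as MLAS_FP16 (the half type declared in mlas.h), that N and D keep the fp32 convention of N rows of length D, and that softcap can be applied as a flat pass over all N * D elements before the row-wise softmax; none of these details are confirmed by this diff.

```cpp
#include "mlas.h"  // include path as set up by the onnxruntime_mlas build

// Sketch only: buffer allocation and the choice of T = MLAS_FP16 are the
// caller's (and this example's) assumptions, not taken verbatim from the PR.
void Fp16AttentionScoresSketch(const MLAS_FP16* logits, MLAS_FP16* capped,
                               MLAS_FP16* probs, size_t rows, size_t cols,
                               MLAS_FP16 cap, MLAS_THREADPOOL* thread_pool) {
    // Optional logit soft-capping ahead of the softmax.
    MlasComputeSoftcap(logits, capped, rows * cols, cap);

    // Row-wise softmax; pass LogSoftmax=true to get log-softmax output instead.
    MlasComputeSoftmax(capped, probs, rows, cols,
                       /*LogSoftmax=*/false, /*SmoothSoftmax=*/false, thread_pool);
}
```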
34 changes: 17 additions & 17 deletions onnxruntime/core/mlas/lib/activate_fp16.cpp
@@ -51,12 +51,12 @@ struct MLAS_HALF_ACTIVATION_FUNCTION<MlasReluActivation>

MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value)
{
- return MlasMaximumFloat16x8(ZeroVec, Value);
+ return MlasMaximumFloat16(ZeroVec, Value);
}

MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value)
{
- return MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(ZeroVec), Value);
+ return MlasMaximumFloat16(MlasToLowHalfFloat16x4(ZeroVec), Value);
}
};

@@ -75,15 +75,15 @@ struct MLAS_HALF_ACTIVATION_FUNCTION<MlasLeakyReluActivation>

MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value)
{
- MLAS_FLOAT16X8 ValueTimesAlpha = MlasMultiplyFloat16x8(Value, AlphaBroadcast);
+ MLAS_FLOAT16X8 ValueTimesAlpha = MlasMultiplyFloat16(Value, AlphaBroadcast);
return MlasBitwiseSelectFloat16x8(MlasCmpLessEqualFloat16x8(Value, ZeroVec),
ValueTimesAlpha, Value);
}

MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value)
{
MLAS_FLOAT16X4 ValueTimesAlpha =
- MlasMultiplyFloat16x4(Value, MlasToLowHalfFloat16x4(AlphaBroadcast));
+ MlasMultiplyFloat16(Value, MlasToLowHalfFloat16x4(AlphaBroadcast));
return MlasBitwiseSelectFloat16x4(
MlasCmpLessEqualFloat16x4(Value, MlasToLowHalfFloat16x4(ZeroVec)), ValueTimesAlpha,
Value);
@@ -539,16 +539,16 @@ struct MLAS_HALF_ACTIVATION_FUNCTION<MlasClipActivation> {

MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value)
{
- Value = MlasMaximumFloat16x8(MinimumBroadcast, Value);
- Value = MlasMinimumFloat16x8(MaximumBroadcast, Value);
+ Value = MlasMaximumFloat16(MinimumBroadcast, Value);
+ Value = MlasMinimumFloat16(MaximumBroadcast, Value);

return Value;
}

MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value)
{
- Value = MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(MinimumBroadcast), Value);
- Value = MlasMinimumFloat16x4(MlasToLowHalfFloat16x4(MaximumBroadcast), Value);
+ Value = MlasMaximumFloat16(MlasToLowHalfFloat16x4(MinimumBroadcast), Value);
+ Value = MlasMinimumFloat16(MlasToLowHalfFloat16x4(MaximumBroadcast), Value);
return Value;
}
};
@@ -573,19 +573,19 @@ struct MLAS_HALF_ACTIVATION_FUNCTION<MlasHardSigmoidActivation>

MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value)
{
- Value = MlasMultiplyAddFloat16x8(Value, AlphaBroadcast, BetaBroadcast);
- Value = MlasMinimumFloat16x8(MaximumBroadcast, Value);
- Value = MlasMaximumFloat16x8(MinimumBroadcast, Value);
+ Value = MlasMultiplyAddFloat16(Value, AlphaBroadcast, BetaBroadcast);
+ Value = MlasMinimumFloat16(MaximumBroadcast, Value);
+ Value = MlasMaximumFloat16(MinimumBroadcast, Value);

return Value;
}

MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value)
{
- Value = MlasMultiplyAddFloat16x4(Value, MlasToLowHalfFloat16x4(AlphaBroadcast),
+ Value = MlasMultiplyAddFloat16(Value, MlasToLowHalfFloat16x4(AlphaBroadcast),
MlasToLowHalfFloat16x4(BetaBroadcast));
- Value = MlasMinimumFloat16x4(MlasToLowHalfFloat16x4(MaximumBroadcast), Value);
- Value = MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(MinimumBroadcast), Value);
+ Value = MlasMinimumFloat16(MlasToLowHalfFloat16x4(MaximumBroadcast), Value);
+ Value = MlasMaximumFloat16(MlasToLowHalfFloat16x4(MinimumBroadcast), Value);

return Value;
}
@@ -692,7 +692,7 @@ MlasActivationKernel(
MLAS_FLOAT16X8 AVec = MlasLoadFloat16x8(addsrc);
MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer);
addsrc += 8;
- Vector = MlasAddFloat16x8(Vector, AVec);
+ Vector = MlasAddFloat16(Vector, AVec);
Vector = ActivationFunction.Activate(Vector);
MlasStoreFloat16x8(buffer, Vector);
buffer += 8;
@@ -703,7 +703,7 @@ MlasActivationKernel(
MLAS_FLOAT16X4 AVec = MlasLoadFloat16x4(addsrc);
MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer);
addsrc += 4;
- Vector = MlasAddFloat16x4(Vector, AVec);
+ Vector = MlasAddFloat16(Vector, AVec);
Vector = ActivationFunction.Activate(Vector);
MlasStoreFloat16x4(buffer, Vector);
buffer += 4;
@@ -715,7 +715,7 @@ MlasActivationKernel(
MLAS_FLOAT16X4 buf;
std::memcpy(&addbuf, addsrc, n * sizeof(_mlas_fp16_));
std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_));
- buf = MlasAddFloat16x4(buf, addbuf);
+ buf = MlasAddFloat16(buf, addbuf);
buf = ActivationFunction.Activate(buf);
MlasStorePartialFloat16x4(buffer, buf, n);
}
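The activation changes above fold the width-suffixed helpers (MlasAddFloat16x8 / MlasAddFloat16x4, MlasMaximumFloat16x8 / MlasMaximumFloat16x4, etc.) into overloads selected by operand type. Below is a minimal standalone sketch of that pattern, assuming the helpers wrap the standard ARMv8.2 FP16 NEON intrinsics; the actual definitions live in MLAS's fp16 vector helper header and are not part of this diff, and the typedefs here are assumptions made only for the sketch.

```cpp
// Standalone sketch; compile with -march=armv8.2-a+fp16.
#include <arm_neon.h>

typedef float16x8_t MLAS_FLOAT16X8;  // assumed alias, for illustration only
typedef float16x4_t MLAS_FLOAT16X4;  // assumed alias, for illustration only

inline MLAS_FLOAT16X8 MlasAddFloat16(MLAS_FLOAT16X8 a, MLAS_FLOAT16X8 b) { return vaddq_f16(a, b); }
inline MLAS_FLOAT16X4 MlasAddFloat16(MLAS_FLOAT16X4 a, MLAS_FLOAT16X4 b) { return vadd_f16(a, b); }

inline MLAS_FLOAT16X8 MlasMaximumFloat16(MLAS_FLOAT16X8 a, MLAS_FLOAT16X8 b) { return vmaxq_f16(a, b); }
inline MLAS_FLOAT16X4 MlasMaximumFloat16(MLAS_FLOAT16X4 a, MLAS_FLOAT16X4 b) { return vmax_f16(a, b); }

// Call sites no longer encode the vector width in the name;
// overload resolution picks the right intrinsic for the operand type.
inline MLAS_FLOAT16X8 ReluSketch(MLAS_FLOAT16X8 v) {
    return MlasMaximumFloat16(vdupq_n_f16((float16_t)0.0f), v);
}
```

This mirrors the rename visible at the call sites above (for example, MlasMaximumFloat16x8 becoming MlasMaximumFloat16) while leaving behavior unchanged.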