diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index bde73252449dc..77b18390b6afd 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -774,11 +774,29 @@ else() ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp ) - if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) + + include(CheckCSourceCompiles) + + set(MLAS_OLD_CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS}") + if(CMAKE_REQUIRED_FLAGS) + set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} -mavx512fp16") + else() + set(CMAKE_REQUIRED_FLAGS "-mavx512fp16") + endif() + check_c_source_compiles(" + int main() { + __asm__ volatile(\"vcvtneeph2ps %ymm0, %ymm1\"); + return 0; + } + " COMPILER_SUPPORTS_AVX512FP16) + set(CMAKE_REQUIRED_FLAGS "${MLAS_OLD_CMAKE_REQUIRED_FLAGS}") + + if(COMPILER_SUPPORTS_AVX512FP16 AND NOT APPLE) set(mlas_platform_srcs_avx2 ${mlas_platform_srcs_avx2} ${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S ) + list(APPEND mlas_private_compile_definitions MLAS_SUPPORTS_AVX512FP16) endif() message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") @@ -997,4 +1015,4 @@ if (NOT onnxruntime_ORT_MINIMAL_BUILD) endif() endif() -endif() \ No newline at end of file +endif() diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index e9f140a2ee0f7..1c295799541b2 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -527,14 +527,14 @@ Return Value: } #ifndef __APPLE__ -#if (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13)) +#if defined(MLAS_SUPPORTS_AVX512FP16) // // Check if the processor supports AVX NE CONVERT. // if ((Cpuid7_1[3] & (0b1 << 5)) != 0) { this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx; } -#endif // (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13)) +#endif // MLAS_SUPPORTS_AVX512FP16 // @@ -671,7 +671,7 @@ Return Value: } else{ this->ErfFP16KernelRoutine = MlasNeonErfFP16Kernel; - this->GeluFP16KernelRoutine = MlasNeonGeluFP16Kernel; + this->GeluFP16KernelRoutine = MlasNeonGeluFP16Kernel; } #else this->ErfFP16KernelRoutine = MlasNeonErfFP16Kernel;