From a40c063551cda95e2a5db097ad3e4890b28ae8a5 Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 29 Aug 2025 17:33:12 -0700 Subject: [PATCH] [ET-VK][ez] Add ability to check for dot product extension support + upgrade glslc Pull Request resolved: https://github.com/pytorch/executorch/pull/13814 ## Motivation Prepare for shaders that will use accelerated int8 dot product GLSL extensions, i.e. `dotPacked4x8AccSatEXT` ## Changes * Query for support for the shader integer dot product extension when creating the VkPhysicalDevice * Request the shader integer dot product extension when creating VkDevice * Provide APIs to check if the extension is available in the current runtime. ghstack-source-id: 306632732 @exported-using-ghexport Differential Revision: [D81323427](https://our.internmc.facebook.com/intern/diff/D81323427/) --- .ci/scripts/setup-vulkan-linux-deps.sh | 2 +- backends/vulkan/runtime/api/Context.cpp | 6 + backends/vulkan/runtime/gen_vulkan_spv.py | 4 + .../vulkan/runtime/graph/ComputeGraph.cpp | 5 + backends/vulkan/runtime/graph/ComputeGraph.h | 8 ++ backends/vulkan/runtime/vk_api/Adapter.cpp | 112 ++++++++++++++++++ backends/vulkan/runtime/vk_api/Adapter.h | 9 ++ backends/vulkan/runtime/vk_api/Device.cpp | 13 ++ backends/vulkan/runtime/vk_api/Device.h | 6 + backends/vulkan/runtime/vk_api/Exception.cpp | 3 + backends/vulkan/runtime/vk_api/Exception.h | 1 + backends/vulkan/runtime/vk_api/QueryPool.cpp | 2 +- backends/vulkan/runtime/vk_api/Shader.cpp | 6 +- backends/vulkan/runtime/vk_api/Shader.h | 4 +- 14 files changed, 176 insertions(+), 5 deletions(-) diff --git a/.ci/scripts/setup-vulkan-linux-deps.sh b/.ci/scripts/setup-vulkan-linux-deps.sh index 1266bce38a6..cd99ff0d6ff 100755 --- a/.ci/scripts/setup-vulkan-linux-deps.sh +++ b/.ci/scripts/setup-vulkan-linux-deps.sh @@ -43,7 +43,7 @@ install_vulkan_sdk() { export PATH="${PATH}:${_vulkan_sdk_dir}/${VULKAN_SDK_VERSION}/x86_64/bin/" } -VULKAN_SDK_VERSION="1.3.296.0" +VULKAN_SDK_VERSION="1.4.321.1" 
install_swiftshader install_vulkan_sdk "${VULKAN_SDK_VERSION}" diff --git a/backends/vulkan/runtime/api/Context.cpp b/backends/vulkan/runtime/api/Context.cpp index 68db37b866e..8599cbfffb6 100644 --- a/backends/vulkan/runtime/api/Context.cpp +++ b/backends/vulkan/runtime/api/Context.cpp @@ -111,6 +111,12 @@ void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) { shader.kernel_name, vkapi::VulkanExtension::INT8_STORAGE); } } + if (shader.requires_integer_dot_product) { + if (!adapter_p_->supports_int8_dot_product()) { + throw vkapi::ShaderNotSupportedError( + shader.kernel_name, vkapi::VulkanExtension::INTEGER_DOT_PRODUCT); + } + } } vkapi::DescriptorSet Context::get_descriptor_set( diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 9b6d53c5d05..3f2d616b428 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -1103,6 +1103,7 @@ class ShaderInfo: requires_shader_int16_ext: bool = False requires_16bit_storage_ext: bool = False requires_8bit_storage_ext: bool = False + requires_integer_dot_product_ext: bool = False def getName(filePath: str) -> str: @@ -1213,6 +1214,8 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo: shader_info.requires_16bit_storage_ext = True if "GL_EXT_shader_8bit_storage" in line: shader_info.requires_8bit_storage_ext = True + if "GL_EXT_integer_dot_product" in line: + shader_info.requires_integer_dot_product_ext = True return shader_info @@ -1288,6 +1291,7 @@ def to_cpp_str(val: bool): to_cpp_str(shader_info.requires_shader_int16_ext), to_cpp_str(shader_info.requires_16bit_storage_ext), to_cpp_str(shader_info.requires_8bit_storage_ext), + to_cpp_str(shader_info.requires_integer_dot_product_ext), ] shader_info_str = textwrap.indent( diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index fff530d57cb..f40a6b0f286 100644 --- 
a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -155,6 +155,11 @@ ComputeGraph::ComputeGraph(GraphConfig config) config_.execute_threshold_node_count = 128; config_.execute_initial_threshold_node_count = 64; } + + // Check if the underlying GPU can access accelerated integer dot product + // instructions + can_use_int8_dot_product_ = + context_->adapter_ptr()->supports_int8_dot_product(); } ComputeGraph::~ComputeGraph() { diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 4257f63fab6..78fb79e65e8 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -221,6 +221,10 @@ class ComputeGraph final { // config.execute_threshold_node_count. size_t execute_threshold_node_count_ = 0; + // Whether the underlying GPU supports accelerated integer dot product + // extensions + bool can_use_int8_dot_product_ = false; + public: // // Accessors @@ -1013,6 +1017,10 @@ class ComputeGraph final { return execute_count_; } + inline bool can_use_int8_dot_product() const { + return can_use_int8_dot_product_; + } + /* * Check whether the GPU supports 8 bit buffers. 
*/ diff --git a/backends/vulkan/runtime/vk_api/Adapter.cpp b/backends/vulkan/runtime/vk_api/Adapter.cpp index e08491c656b..0e87dde1922 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.cpp +++ b/backends/vulkan/runtime/vk_api/Adapter.cpp @@ -109,6 +109,9 @@ VkDevice create_logical_device( #ifdef VK_KHR_shader_float16_int8 VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME, #endif /* VK_KHR_shader_float16_int8 */ +#ifdef VK_KHR_shader_integer_dot_product + VK_KHR_SHADER_INTEGER_DOT_PRODUCT_EXTENSION_NAME, +#endif /* VK_KHR_shader_integer_dot_product */ #if defined(VK_KHR_pipeline_executable_properties) && defined(VULKAN_DEBUG) VK_KHR_PIPELINE_EXECUTABLE_PROPERTIES_EXTENSION_NAME, #endif /* VK_KHR_pipeline_executable_properties */ @@ -160,6 +163,14 @@ VkDevice create_logical_device( extension_list_top = &shader_float16_int8_types; #endif /* VK_KHR_shader_float16_int8 */ +#ifdef VK_KHR_shader_integer_dot_product + VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR + shader_int_dot_product_features{ + physical_device.shader_int_dot_product_features}; + shader_int_dot_product_features.pNext = extension_list_top; + extension_list_top = &shader_int_dot_product_features; +#endif /* VK_KHR_shader_integer_dot_product */ + device_create_info.pNext = extension_list_top; VkDevice handle = nullptr; @@ -401,6 +412,107 @@ std::string Adapter::stringize() const { #endif /* VK_KHR_shader_float16_int8 */ ss << " }" << std::endl; +#ifdef VK_KHR_shader_integer_dot_product + ss << " Shader Integer Dot Product Features {" << std::endl; + PRINT_PROP( + physical_device_.shader_int_dot_product_features, + shaderIntegerDotProduct); + ss << " }" << std::endl; + + ss << " Shader Integer Dot Product Properties {" << std::endl; + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct8BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct8BitSignedAccelerated); + PRINT_PROP( + 
physical_device_.shader_int_dot_product_properties, + integerDotProduct8BitMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct4x8BitPackedUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct4x8BitPackedSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct4x8BitPackedMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct16BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct16BitSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct16BitMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct32BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct32BitSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct32BitMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct64BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct64BitSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct64BitMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating8BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating8BitSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating8BitMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + 
integerDotProductAccumulatingSaturating4x8BitPackedUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating4x8BitPackedSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating4x8BitPackedMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating16BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating16BitSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating16BitMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating32BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating32BitSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating32BitMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating64BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating64BitSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating64BitMixedSignednessAccelerated); + ss << " }" << std::endl; +#endif /* VK_KHR_shader_integer_dot_product */ + const VkPhysicalDeviceMemoryProperties& mem_props = physical_device_.memory_properties; diff --git a/backends/vulkan/runtime/vk_api/Adapter.h b/backends/vulkan/runtime/vk_api/Adapter.h index aa4c659c6d8..6a68b487348 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.h +++ b/backends/vulkan/runtime/vk_api/Adapter.h @@ 
-212,6 +212,15 @@ class Adapter final { #endif /* VK_KHR_shader_float16_int8 */ } + inline bool supports_int8_dot_product() { +#ifdef VK_KHR_shader_integer_dot_product + return physical_device_.shader_int_dot_product_features + .shaderIntegerDotProduct == VK_TRUE; +#else + return false; +#endif /* VK_KHR_shader_integer_dot_product */ + } + inline bool supports_int16_shader_types() { return physical_device_.supports_int16_shader_types; } diff --git a/backends/vulkan/runtime/vk_api/Device.cpp b/backends/vulkan/runtime/vk_api/Device.cpp index b9e3b444db2..a21130f1231 100644 --- a/backends/vulkan/runtime/vk_api/Device.cpp +++ b/backends/vulkan/runtime/vk_api/Device.cpp @@ -36,6 +36,12 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle) shader_float16_int8_types{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES_KHR}, #endif /* VK_KHR_shader_float16_int8 */ +#ifdef VK_KHR_shader_integer_dot_product + shader_int_dot_product_features{ + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR}, + shader_int_dot_product_properties{ + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_PROPERTIES_KHR}, +#endif queue_families{}, num_compute_queues(0), supports_int16_shader_types(false), @@ -77,6 +83,13 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle) extension_list_top = &shader_float16_int8_types; #endif /* VK_KHR_shader_float16_int8 */ +#ifdef VK_KHR_shader_integer_dot_product + shader_int_dot_product_features.pNext = extension_list_top; + extension_list_top = &shader_int_dot_product_features; + shader_int_dot_product_properties.pNext = extension_list_top; + extension_list_top = &shader_int_dot_product_properties; +#endif /* VK_KHR_shader_integer_dot_product */ + features2.pNext = extension_list_top; vkGetPhysicalDeviceFeatures2(handle, &features2); diff --git a/backends/vulkan/runtime/vk_api/Device.h b/backends/vulkan/runtime/vk_api/Device.h index 3fdfcc04a49..f5b7154d260 100644 --- 
a/backends/vulkan/runtime/vk_api/Device.h +++ b/backends/vulkan/runtime/vk_api/Device.h @@ -44,6 +44,12 @@ struct PhysicalDevice final { #ifdef VK_KHR_shader_float16_int8 VkPhysicalDeviceShaderFloat16Int8Features shader_float16_int8_types; #endif /* VK_KHR_shader_float16_int8 */ +#ifdef VK_KHR_shader_integer_dot_product + VkPhysicalDeviceShaderIntegerDotProductFeatures + shader_int_dot_product_features; + VkPhysicalDeviceShaderIntegerDotProductProperties + shader_int_dot_product_properties; +#endif /* VK_KHR_shader_integer_dot_product */ // Available GPU queues std::vector queue_families; diff --git a/backends/vulkan/runtime/vk_api/Exception.cpp b/backends/vulkan/runtime/vk_api/Exception.cpp index d26fbd8cb22..c07349fa7ca 100644 --- a/backends/vulkan/runtime/vk_api/Exception.cpp +++ b/backends/vulkan/runtime/vk_api/Exception.cpp @@ -92,6 +92,9 @@ std::ostream& operator<<(std::ostream& out, const VulkanExtension result) { case VulkanExtension::INT8_STORAGE: out << "VK_KHR_8bit_storage"; break; + case VulkanExtension::INTEGER_DOT_PRODUCT: + out << "VK_KHR_shader_integer_dot_product"; + break; } return out; } diff --git a/backends/vulkan/runtime/vk_api/Exception.h b/backends/vulkan/runtime/vk_api/Exception.h index a65afb1bcc5..a883a68fefc 100644 --- a/backends/vulkan/runtime/vk_api/Exception.h +++ b/backends/vulkan/runtime/vk_api/Exception.h @@ -82,6 +82,7 @@ enum class VulkanExtension : uint8_t { SHADER_INT16, INT16_STORAGE, INT8_STORAGE, + INTEGER_DOT_PRODUCT, }; class ShaderNotSupportedError : public std::exception { diff --git a/backends/vulkan/runtime/vk_api/QueryPool.cpp b/backends/vulkan/runtime/vk_api/QueryPool.cpp index 2f6d433b887..e8b3ca55206 100644 --- a/backends/vulkan/runtime/vk_api/QueryPool.cpp +++ b/backends/vulkan/runtime/vk_api/QueryPool.cpp @@ -209,7 +209,7 @@ std::string QueryPool::generate_string_report() { std::stringstream ss; - int kernel_name_w = 40; + int kernel_name_w = 120; int global_size_w = 25; int local_size_w = 25; int duration_w = 
25; diff --git a/backends/vulkan/runtime/vk_api/Shader.cpp b/backends/vulkan/runtime/vk_api/Shader.cpp index 458b1f83956..4356f92efe7 100644 --- a/backends/vulkan/runtime/vk_api/Shader.cpp +++ b/backends/vulkan/runtime/vk_api/Shader.cpp @@ -31,7 +31,8 @@ ShaderInfo::ShaderInfo( const utils::uvec3 tile_size, const bool requires_shader_int16_ext, const bool requires_16bit_storage_ext, - const bool requires_8bit_storage_ext) + const bool requires_8bit_storage_ext, + const bool requires_integer_dot_product_ext) : src_code{ spirv_bin, size, @@ -41,7 +42,8 @@ ShaderInfo::ShaderInfo( out_tile_size(tile_size), requires_shader_int16(requires_shader_int16_ext), requires_16bit_storage(requires_16bit_storage_ext), - requires_8bit_storage(requires_8bit_storage_ext) { + requires_8bit_storage(requires_8bit_storage_ext), + requires_integer_dot_product(requires_integer_dot_product_ext) { } bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) { diff --git a/backends/vulkan/runtime/vk_api/Shader.h b/backends/vulkan/runtime/vk_api/Shader.h index 7d0fa7b7476..21332381406 100644 --- a/backends/vulkan/runtime/vk_api/Shader.h +++ b/backends/vulkan/runtime/vk_api/Shader.h @@ -65,6 +65,7 @@ struct ShaderInfo final { bool requires_shader_int16 = false; bool requires_16bit_storage = false; bool requires_8bit_storage = false; + bool requires_integer_dot_product = false; explicit ShaderInfo(); @@ -76,7 +77,8 @@ struct ShaderInfo final { const utils::uvec3 tile_size, const bool requires_shader_int16_ext, const bool requires_16bit_storage_ext, - const bool requires_8bit_storage_ext); + const bool requires_8bit_storage_ext, + const bool requires_integer_dot_product_ext); operator bool() const { return src_code.bin != nullptr;