From c00f22d03da2d5e59a0cd454b842a470d1a3b207 Mon Sep 17 00:00:00 2001 From: Igor Chorazewicz Date: Fri, 14 Feb 2025 19:47:56 +0000 Subject: [PATCH] Call zeInitDrivers in L0 provider According to the L0 spec, zeInitDrivers must be called (by every library) before calling any other APIs. Not calling zeInitDrivers causes crash when using statically linked L0 loader in UR. --- src/provider/provider_level_zero.c | 58 +++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/src/provider/provider_level_zero.c b/src/provider/provider_level_zero.c index f89661401..6f001667c 100644 --- a/src/provider/provider_level_zero.c +++ b/src/provider/provider_level_zero.c @@ -14,6 +14,7 @@ #include #include +#include "base_alloc_global.h" #include "provider_level_zero_internal.h" #include "utils_load_library.h" #include "utils_log.h" @@ -111,7 +112,6 @@ umf_memory_provider_ops_t *umfLevelZeroMemoryProviderOps(void) { #else // !defined(UMF_NO_LEVEL_ZERO_PROVIDER) -#include "base_alloc_global.h" #include "libumf.h" #include "utils_assert.h" #include "utils_common.h" @@ -211,6 +211,49 @@ static umf_result_t ze2umf_result(ze_result_t result) { } } +static umf_result_t ze_init_drivers(void *lib_handle, const char *lib_name) { + ze_result_t (*zeInitDriversFunc)(uint32_t *, ze_driver_handle_t *, + ze_init_driver_type_desc_t *); + *(void **)&zeInitDriversFunc = + utils_get_symbol_addr(lib_handle, "zeInitDrivers", lib_name); + if (!zeInitDriversFunc) { + return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE; + } + + ze_init_driver_type_desc_t desc = { + .stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC, + .pNext = NULL, + .flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU}; + uint32_t driverCount = 0; + ze_result_t result = zeInitDriversFunc(&driverCount, NULL, &desc); + if (result != ZE_RESULT_SUCCESS) { + return ze2umf_result(result); + } + + ze_driver_handle_t *zeAllDrivers = + umf_ba_global_alloc(sizeof(ze_driver_handle_t) * driverCount); + result = zeInitDriversFunc(&driverCount, zeAllDrivers, &desc); + umf_ba_global_free(zeAllDrivers); + if (result != ZE_RESULT_SUCCESS) { + return ze2umf_result(result); + } + + return UMF_RESULT_SUCCESS; +} + +static umf_result_t ze_init(void *lib_handle, const char *lib_name) { + ze_result_t (*zeInitFunc)(ze_init_flag_t); + *(void **)&zeInitFunc = + utils_get_symbol_addr(lib_handle, "zeInit", lib_name); + + if (!zeInitFunc) { + return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE; + } + + ze_result_t result = zeInitFunc(ZE_INIT_FLAG_GPU_ONLY); + return ze2umf_result(result); +} + static void init_ze_global_state(void) { #ifdef _WIN32 const char *lib_name = "ze_loader.dll"; @@ -266,6 +309,19 @@ static void init_ze_global_state(void) { utils_close_library(lib_handle); return; } + + if (ze_init_drivers(lib_handle, lib_name) != UMF_RESULT_SUCCESS) { + LOG_INFO("Initializing Level Zero through zeInitDrivers failed. " + "Falling back to zeInit."); + + if (ze_init(lib_handle, lib_name) != UMF_RESULT_SUCCESS) { + LOG_FATAL("Failed to initialize Level Zero"); + Init_ze_global_state_failed = true; + utils_close_library(lib_handle); + return; + } + } + ze_lib_handle = lib_handle; }