diff --git a/src/provider/provider_level_zero.c b/src/provider/provider_level_zero.c index f89661401..6f001667c 100644 --- a/src/provider/provider_level_zero.c +++ b/src/provider/provider_level_zero.c @@ -14,6 +14,7 @@ #include #include +#include "base_alloc_global.h" #include "provider_level_zero_internal.h" #include "utils_load_library.h" #include "utils_log.h" @@ -111,7 +112,6 @@ umf_memory_provider_ops_t *umfLevelZeroMemoryProviderOps(void) { #else // !defined(UMF_NO_LEVEL_ZERO_PROVIDER) -#include "base_alloc_global.h" #include "libumf.h" #include "utils_assert.h" #include "utils_common.h" @@ -211,6 +211,49 @@ static umf_result_t ze2umf_result(ze_result_t result) { } } +static umf_result_t ze_init_drivers(void *lib_handle, const char *lib_name) { + ze_result_t (*zeInitDriversFunc)(uint32_t *, ze_driver_handle_t *, + ze_init_driver_type_desc_t *); + *(void **)&zeInitDriversFunc = + utils_get_symbol_addr(lib_handle, "zeInitDrivers", lib_name); + if (!zeInitDriversFunc) { + return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE; + } + + ze_init_driver_type_desc_t desc = { + .stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC, + .pNext = NULL, + .flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU}; + uint32_t driverCount = 0; + ze_result_t result = zeInitDriversFunc(&driverCount, NULL, &desc); + if (result != ZE_RESULT_SUCCESS) { + return ze2umf_result(result); + } + + ze_driver_handle_t *zeAllDrivers = + umf_ba_global_alloc(sizeof(ze_driver_handle_t) * driverCount); + result = zeInitDriversFunc(&driverCount, zeAllDrivers, &desc); + umf_ba_global_free(zeAllDrivers); + if (result != ZE_RESULT_SUCCESS) { + return ze2umf_result(result); + } + + return UMF_RESULT_SUCCESS; +} + +static umf_result_t ze_init(void *lib_handle, const char *lib_name) { + ze_result_t (*zeInitFunc)(ze_init_flag_t); + *(void **)&zeInitFunc = + utils_get_symbol_addr(lib_handle, "zeInit", lib_name); + + if (!zeInitFunc) { + return UMF_RESULT_ERROR_DEPENDENCY_UNAVAILABLE; + } + + ze_result_t result = zeInitFunc(ZE_INIT_FLAG_GPU_ONLY); + return ze2umf_result(result); +} + static void init_ze_global_state(void) { #ifdef _WIN32 const char *lib_name = "ze_loader.dll"; @@ -266,6 +309,19 @@ static void init_ze_global_state(void) { utils_close_library(lib_handle); return; } + + if (ze_init_drivers(lib_handle, lib_name) != UMF_RESULT_SUCCESS) { + LOG_INFO("Initializing Level Zero through zeInitDrivers failed. " + "Falling back to zeInit."); + + if (ze_init(lib_handle, lib_name) != UMF_RESULT_SUCCESS) { + LOG_FATAL("Failed to initialize Level Zero"); + Init_ze_global_state_failed = true; + utils_close_library(lib_handle); + return; + } + } + ze_lib_handle = lib_handle; }