Skip to content

Commit 619b10c

Browse files
committed
[UR] Fix backend parsing in ONEAPI_DEVICE_SELECTOR
`native_cpu` is now accepted as a valid backend, and having a different backend to the one matching the platform is now valid. `native_cpu:*;level_zero:*` now works properly.
1 parent 62e74fa commit 619b10c

File tree

3 files changed

+35
-43
lines changed

3 files changed

+35
-43
lines changed

unified-runtime/source/common/ur_util.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ using EnvVarMap = std::map<std::string, std::vector<std::string>>;
209209
/// @param env_var_name name of an environment variable to be parsed
210210
/// @param reject_empy whether to throw an error on discovering an empty value
211211
/// @param allow_duplicate whether to allow multiple pairs with the same key
212+
/// @param lower convert keys to lowercase
212213
/// @return std::optional with a possible map with parsed parameters as keys and
213214
/// vectors of strings containing parsed values as keys.
214215
/// Otherwise, optional is set to std::nullopt when the environment
@@ -217,7 +218,8 @@ using EnvVarMap = std::map<std::string, std::vector<std::string>>;
217218
/// wrong format
218219
inline std::optional<EnvVarMap> getenv_to_map(const char *env_var_name,
219220
bool reject_empty = true,
220-
bool allow_duplicate = false) {
221+
bool allow_duplicate = false,
222+
bool lower = false) {
221223
char main_delim = ';';
222224
char key_value_delim = ':';
223225
char values_delim = ',';
@@ -254,6 +256,10 @@ inline std::optional<EnvVarMap> getenv_to_map(const char *env_var_name,
254256
throw_wrong_format_map(env_var_name, *env_var);
255257
}
256258

259+
if (lower) {
260+
std::transform(key.begin(), key.end(), key.begin(), tolower);
261+
}
262+
257263
std::vector<std::string> values_vec;
258264
std::stringstream values_ss(values);
259265
std::string value;

unified-runtime/source/loader/ur_lib.cpp

Lines changed: 21 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,13 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
251251
uint32_t NumEntries,
252252
ur_device_handle_t *phDevices,
253253
uint32_t *pNumDevices) {
254+
constexpr std::pair<const ur_platform_backend_t, const char *> adapters[6] = {
255+
{UR_PLATFORM_BACKEND_UNKNOWN, "*"},
256+
{UR_PLATFORM_BACKEND_LEVEL_ZERO, "level_zero"},
257+
{UR_PLATFORM_BACKEND_OPENCL, "opencl"},
258+
{UR_PLATFORM_BACKEND_CUDA, "cuda"},
259+
{UR_PLATFORM_BACKEND_HIP, "hip"},
260+
{UR_PLATFORM_BACKEND_NATIVE_CPU, "native_cpu"}};
254261

255262
if (!hPlatform) {
256263
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
@@ -323,7 +330,8 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
323330
// (If we wished to preserve the ordering of terms, we could replace
324331
// `std::map` with `std::queue<std::pair<key_type_t, value_type_t>>` or
325332
// something similar.)
326-
auto maybeEnvVarMap = getenv_to_map("ONEAPI_DEVICE_SELECTOR", false);
333+
auto maybeEnvVarMap =
334+
getenv_to_map("ONEAPI_DEVICE_SELECTOR", false, false, true);
327335
logger::debug(
328336
"getenv_to_map parsed env var and {} a map",
329337
(maybeEnvVarMap.has_value() ? "produced" : "failed to produce"));
@@ -380,35 +388,6 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
380388
sizeof(ur_platform_backend_t), &platformBackend, 0)) {
381389
return UR_RESULT_ERROR_INVALID_PLATFORM;
382390
}
383-
const std::string platformBackendName = // hPlatform->get_backend_name();
384-
[&platformBackend]() constexpr {
385-
switch (platformBackend) {
386-
case UR_PLATFORM_BACKEND_UNKNOWN:
387-
return "*"; // the only ODS string that matches
388-
break;
389-
case UR_PLATFORM_BACKEND_LEVEL_ZERO:
390-
return "level_zero";
391-
break;
392-
case UR_PLATFORM_BACKEND_OPENCL:
393-
return "opencl";
394-
break;
395-
case UR_PLATFORM_BACKEND_CUDA:
396-
return "cuda";
397-
break;
398-
case UR_PLATFORM_BACKEND_HIP:
399-
return "hip";
400-
break;
401-
case UR_PLATFORM_BACKEND_NATIVE_CPU:
402-
return "*"; // the only ODS string that matches
403-
break;
404-
case UR_PLATFORM_BACKEND_FORCE_UINT32:
405-
return ""; // no ODS string matches this
406-
break;
407-
default:
408-
return ""; // no ODS string matches this
409-
break;
410-
}
411-
}();
412391

413392
using DeviceHardwareType = ur_device_type_t;
414393

@@ -482,18 +461,18 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
482461
// Note the hPlatform -> platformBackend -> platformBackendName conversion
483462
// above guarantees minimal sanity for the comparison with backend from the
484463
// ODS string
485-
if (backend.front() != '*' &&
486-
!std::equal(platformBackendName.cbegin(), platformBackendName.cend(),
487-
backend.cbegin(), backend.cend(),
488-
[](const auto &a, const auto &b) {
489-
// case-insensitive comparison by converting both tolower
490-
return std::tolower(static_cast<unsigned char>(a)) ==
491-
std::tolower(static_cast<unsigned char>(b));
492-
})) {
493-
// irrelevant term for current request: different backend -- silently
494-
// ignore
495-
logger::error("unrecognised backend '{}'", backend);
496-
return UR_RESULT_ERROR_INVALID_VALUE;
464+
if (backend.front() != '*') {
465+
auto cend = &adapters[sizeof(adapters) / sizeof(adapters[0])];
466+
auto found = std::find_if(adapters, cend,
467+
[&](auto &p) { return p.second == backend; });
468+
if (found == cend) {
469+
// It's not a legal backend
470+
logger::error("unrecognised backend '{}'", backend);
471+
return UR_RESULT_ERROR_INVALID_VALUE;
472+
} else if (found->first != platformBackend) {
473+
// If it's a rule for a different backend, ignore it
474+
continue;
475+
}
497476
}
498477
if (termPair.second.size() == 0) {
499478
// malformed term: missing filterStrings -- output ERROR

unified-runtime/test/conformance/device/urDeviceGetSelected.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,13 @@ TEST_P(urDeviceGetSelectedTest, InvalidGarbageBackendString) {
178178
ASSERT_EQ(count, 0);
179179
}
180180

181+
TEST_P(urDeviceGetSelectedTest, SuccessCaseSensitive) {
182+
setenv("ONEAPI_DEVICE_SELECTOR", "OpEnCl:0", 1);
183+
uint32_t count = 0;
184+
ASSERT_SUCCESS(
185+
urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count));
186+
}
187+
181188
TEST_P(urDeviceGetSelectedTest, InvalidMissingFilterStrings) {
182189
setenv("ONEAPI_DEVICE_SELECTOR", "*", 1);
183190
uint32_t count = 0;

0 commit comments

Comments
 (0)