@@ -251,6 +251,13 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
251
251
uint32_t NumEntries,
252
252
ur_device_handle_t *phDevices,
253
253
uint32_t *pNumDevices) {
254
+ constexpr std::pair<const ur_platform_backend_t , const char *> adapters[6 ] = {
255
+ {UR_PLATFORM_BACKEND_UNKNOWN, " *" },
256
+ {UR_PLATFORM_BACKEND_LEVEL_ZERO, " level_zero" },
257
+ {UR_PLATFORM_BACKEND_OPENCL, " opencl" },
258
+ {UR_PLATFORM_BACKEND_CUDA, " cuda" },
259
+ {UR_PLATFORM_BACKEND_HIP, " hip" },
260
+ {UR_PLATFORM_BACKEND_NATIVE_CPU, " native_cpu" }};
254
261
255
262
if (!hPlatform) {
256
263
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
@@ -323,7 +330,8 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
323
330
// (If we wished to preserve the ordering of terms, we could replace
324
331
// `std::map` with `std::queue<std::pair<key_type_t, value_type_t>>` or
325
332
// something similar.)
326
- auto maybeEnvVarMap = getenv_to_map (" ONEAPI_DEVICE_SELECTOR" , false );
333
+ auto maybeEnvVarMap =
334
+ getenv_to_map (" ONEAPI_DEVICE_SELECTOR" , false , false , true );
327
335
logger::debug (
328
336
" getenv_to_map parsed env var and {} a map" ,
329
337
(maybeEnvVarMap.has_value () ? " produced" : " failed to produce" ));
@@ -380,35 +388,6 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
380
388
sizeof (ur_platform_backend_t ), &platformBackend, 0 )) {
381
389
return UR_RESULT_ERROR_INVALID_PLATFORM;
382
390
}
383
- const std::string platformBackendName = // hPlatform->get_backend_name();
384
- [&platformBackend]() constexpr {
385
- switch (platformBackend) {
386
- case UR_PLATFORM_BACKEND_UNKNOWN:
387
- return " *" ; // the only ODS string that matches
388
- break ;
389
- case UR_PLATFORM_BACKEND_LEVEL_ZERO:
390
- return " level_zero" ;
391
- break ;
392
- case UR_PLATFORM_BACKEND_OPENCL:
393
- return " opencl" ;
394
- break ;
395
- case UR_PLATFORM_BACKEND_CUDA:
396
- return " cuda" ;
397
- break ;
398
- case UR_PLATFORM_BACKEND_HIP:
399
- return " hip" ;
400
- break ;
401
- case UR_PLATFORM_BACKEND_NATIVE_CPU:
402
- return " *" ; // the only ODS string that matches
403
- break ;
404
- case UR_PLATFORM_BACKEND_FORCE_UINT32:
405
- return " " ; // no ODS string matches this
406
- break ;
407
- default :
408
- return " " ; // no ODS string matches this
409
- break ;
410
- }
411
- }();
412
391
413
392
using DeviceHardwareType = ur_device_type_t ;
414
393
@@ -482,18 +461,18 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
482
461
// Note the hPlatform -> platformBackend -> platformBackendName conversion
483
462
// above guarantees minimal sanity for the comparison with backend from the
484
463
// ODS string
485
- if (backend.front () != ' *' &&
486
- ! std::equal (platformBackendName. cbegin (), platformBackendName. cend (),
487
- backend. cbegin (), backend. cend () ,
488
- []( const auto &a, const auto &b ) {
489
- // case-insensitive comparison by converting both tolower
490
- return std::tolower ( static_cast < unsigned char >(a)) ==
491
- std::tolower ( static_cast < unsigned char >(b) );
492
- })) {
493
- // irrelevant term for current request: different backend -- silently
494
- // ignore
495
- logger::error ( " unrecognised backend '{}' " , backend) ;
496
- return UR_RESULT_ERROR_INVALID_VALUE;
464
+ if (backend.front () != ' *' ) {
465
+ auto cend = &adapters[ sizeof (adapters) / sizeof (adapters[ 0 ])];
466
+ auto found = std::find_if (adapters, cend,
467
+ [&]( auto &p ) { return p. second == backend; });
468
+ if (found == cend) {
469
+ // It's not a legal backend
470
+ logger::error ( " unrecognised backend '{}' " , backend );
471
+ return UR_RESULT_ERROR_INVALID_VALUE;
472
+ } else if (found-> first != platformBackend) {
473
+ // If it's a rule for a different backend, ignore it
474
+ continue ;
475
+ }
497
476
}
498
477
if (termPair.second .size () == 0 ) {
499
478
// malformed term: missing filterStrings -- output ERROR
0 commit comments