Skip to content

Commit

Permalink
[webgpu] no longer need pass-in gpu adapter for custom context (#23593)
Browse files Browse the repository at this point in the history
### Description

Remove the need to pass in the GPU adapter for the custom context.

With the introduction of the `wgpuDeviceGetAdapterInfo` API, we no
longer need user to specify the GPU adapter when creating a custom
context.
  • Loading branch information
fs-eire authored Feb 10, 2025
1 parent af679a0 commit e666503
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 38 deletions.
41 changes: 19 additions & 22 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ namespace webgpu {

void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture) {
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type, enable_pix_capture]() {
// Create wgpu::Adapter
if (adapter_ == nullptr) {
if (device_ == nullptr) {
// Create wgpu::Adapter
#if !defined(__wasm__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN)
// If we are using the D3D12 backend on Windows and the build does not use external Dawn, dxil.dll and dxcompiler.dll are required.
//
Expand Down Expand Up @@ -77,20 +77,19 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
req_adapter_options.nextInChain = &adapter_toggles_desc;
#endif

wgpu::Adapter adapter;
ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(instance_.RequestAdapter(
&req_adapter_options,
wgpu::CallbackMode::WaitAnyOnly,
[](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, wgpu::StringView message, wgpu::Adapter* ptr) {
ORT_ENFORCE(status == wgpu::RequestAdapterStatus::Success, "Failed to get a WebGPU adapter: ", std::string_view{message});
*ptr = adapter;
*ptr = std::move(adapter);
},
&adapter_),
&adapter),
UINT64_MAX));
ORT_ENFORCE(adapter_ != nullptr, "Failed to get a WebGPU adapter.");
}
ORT_ENFORCE(adapter != nullptr, "Failed to get a WebGPU adapter.");

// Create wgpu::Device
if (device_ == nullptr) {
// Create wgpu::Device
wgpu::DeviceDescriptor device_desc = {};

#if !defined(__wasm__)
Expand All @@ -106,12 +105,12 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
device_toggles_desc.disabledToggles = disabled_device_toggles.data();
#endif

std::vector<wgpu::FeatureName> required_features = GetAvailableRequiredFeatures(adapter_);
std::vector<wgpu::FeatureName> required_features = GetAvailableRequiredFeatures(adapter);
if (required_features.size() > 0) {
device_desc.requiredFeatures = required_features.data();
device_desc.requiredFeatureCount = required_features.size();
}
wgpu::RequiredLimits required_limits = GetRequiredLimits(adapter_);
wgpu::RequiredLimits required_limits = GetRequiredLimits(adapter);
device_desc.requiredLimits = &required_limits;

// TODO: revise temporary error handling
Expand All @@ -123,20 +122,20 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
LOGS_DEFAULT(INFO) << "WebGPU device lost (" << int(reason) << "): " << std::string_view{message};
});

ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter_.RequestDevice(
ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter.RequestDevice(
&device_desc,
wgpu::CallbackMode::WaitAnyOnly,
[](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message, wgpu::Device* ptr) {
ORT_ENFORCE(status == wgpu::RequestDeviceStatus::Success, "Failed to get a WebGPU device: ", std::string_view{message});
*ptr = device;
*ptr = std::move(device);
},
&device_),
UINT64_MAX));
ORT_ENFORCE(device_ != nullptr, "Failed to get a WebGPU device.");
}

// cache adapter info
ORT_ENFORCE(Adapter().GetInfo(&adapter_info_));
ORT_ENFORCE(Device().GetAdapterInfo(&adapter_info_));
// cache device limits
wgpu::SupportedLimits device_supported_limits;
ORT_ENFORCE(Device().GetLimits(&device_supported_limits));
Expand Down Expand Up @@ -706,13 +705,12 @@ wgpu::Instance WebGpuContextFactory::default_instance_;
WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& config) {
const int context_id = config.context_id;
WGPUInstance instance = config.instance;
WGPUAdapter adapter = config.adapter;
WGPUDevice device = config.device;

if (context_id == 0) {
// context ID is preserved for the default context. User cannot use context ID 0 as a custom context.
ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr,
"WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device.");
ORT_ENFORCE(instance == nullptr && device == nullptr,
"WebGPU EP default context (contextId=0) must not have custom WebGPU instance or device.");

std::call_once(init_default_flag_, [
#if !defined(__wasm__)
Expand Down Expand Up @@ -750,23 +748,22 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
});
instance = default_instance_.Get();
} else {
// for context ID > 0, user must provide custom WebGPU instance, adapter and device.
ORT_ENFORCE(instance != nullptr && adapter != nullptr && device != nullptr,
"WebGPU EP custom context (contextId>0) must have custom WebGPU instance, adapter and device.");
// for context ID > 0, user must provide custom WebGPU instance and device.
ORT_ENFORCE(instance != nullptr && device != nullptr,
"WebGPU EP custom context (contextId>0) must have custom WebGPU instance and device.");
}

std::lock_guard<std::mutex> lock(mutex_);

auto it = contexts_.find(context_id);
if (it == contexts_.end()) {
GSL_SUPPRESS(r.11)
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, adapter, device, config.validation_mode));
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, device, config.validation_mode));
it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first;
} else if (context_id != 0) {
ORT_ENFORCE(it->second.context->instance_.Get() == instance &&
it->second.context->adapter_.Get() == adapter &&
it->second.context->device_.Get() == device,
"WebGPU EP context ID ", context_id, " is already created with different WebGPU instance, adapter or device.");
"WebGPU EP context ID ", context_id, " is already created with different WebGPU instance or device.");
}
it->second.ref_count++;
return *it->second.context;
Expand Down
7 changes: 2 additions & 5 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class ProgramBase;
struct WebGpuContextConfig {
int context_id;
WGPUInstance instance;
WGPUAdapter adapter;
WGPUDevice device;
const void* dawn_proc_table;
ValidationMode validation_mode;
Expand Down Expand Up @@ -76,7 +75,6 @@ class WebGpuContext final {

Status Wait(wgpu::Future f);

const wgpu::Adapter& Adapter() const { return adapter_; }
const wgpu::Device& Device() const { return device_; }

const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; }
Expand Down Expand Up @@ -149,8 +147,8 @@ class WebGpuContext final {
AtPasses
};

WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, webgpu::ValidationMode validation_mode)
: instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {}
WebGpuContext(WGPUInstance instance, WGPUDevice device, webgpu::ValidationMode validation_mode)
: instance_{instance}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext);

std::vector<const char*> GetEnabledAdapterToggles() const;
Expand Down Expand Up @@ -198,7 +196,6 @@ class WebGpuContext final {
LibraryHandles modules_;

wgpu::Instance instance_;
wgpu::Adapter adapter_;
wgpu::Device device_;

webgpu::ValidationMode validation_mode_;
Expand Down
11 changes: 1 addition & 10 deletions onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,6 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
std::from_chars(webgpu_instance_str.data(), webgpu_instance_str.data() + webgpu_instance_str.size(), webgpu_instance).ec);
}

size_t webgpu_adapter = 0;
std::string webgpu_adapter_str;
if (config_options.TryGetConfigEntry(kWebGpuAdapter, webgpu_adapter_str)) {
static_assert(sizeof(WGPUAdapter) == sizeof(size_t), "WGPUAdapter size mismatch");
ORT_ENFORCE(std::errc{} ==
std::from_chars(webgpu_adapter_str.data(), webgpu_adapter_str.data() + webgpu_adapter_str.size(), webgpu_adapter).ec);
}

size_t webgpu_device = 0;
std::string webgpu_device_str;
if (config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) {
Expand Down Expand Up @@ -154,7 +146,6 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
webgpu::WebGpuContextConfig context_config{
context_id,
reinterpret_cast<WGPUInstance>(webgpu_instance),
reinterpret_cast<WGPUAdapter>(webgpu_adapter),
reinterpret_cast<WGPUDevice>(webgpu_device),
reinterpret_cast<const void*>(dawn_proc_table),
validation_mode,
Expand Down Expand Up @@ -238,7 +229,7 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
// STEP.4 - start initialization.
//

// Load the Dawn library and create the WebGPU instance and adapter.
// Load the Dawn library and create the WebGPU instance.
auto& context = webgpu::WebGpuContextFactory::CreateContext(context_config);

// Create WebGPU device and initialize the context.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ constexpr const char* kDawnBackendType = "WebGPU:dawnBackendType";

constexpr const char* kDeviceId = "WebGPU:deviceId";
constexpr const char* kWebGpuInstance = "WebGPU:webgpuInstance";
constexpr const char* kWebGpuAdapter = "WebGPU:webgpuAdapter";
constexpr const char* kWebGpuDevice = "WebGPU:webgpuDevice";

constexpr const char* kStorageBufferCacheMode = "WebGPU:storageBufferCacheMode";
Expand Down

0 comments on commit e666503

Please sign in to comment.