Skip to content

Commit

Permalink
Keep names
Browse files Browse the repository at this point in the history
  • Loading branch information
jchen10 committed Feb 7, 2025
1 parent 8b0bccb commit 9501966
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 26 deletions.
9 changes: 7 additions & 2 deletions onnxruntime/core/providers/webgpu/compute_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@ ComputeContext::ComputeContext(OpKernelContext& kernel_context)
}

void ComputeContext::PushErrorScope() {
webgpu_context_.PushErrorScopeIfNoLowerThan(ValidationMode::Full);
if (webgpu_context_.ValidationMode() >= ValidationMode::Full) {
webgpu_context_.PushErrorScope();
}
}

Status ComputeContext::PopErrorScope() {
return webgpu_context_.PopErrorScopeIfNoLowerThan(ValidationMode::Full);
if (webgpu_context_.ValidationMode() >= ValidationMode::Full) {
return webgpu_context_.PopErrorScope();
}
return Status::OK();
}

} // namespace webgpu
Expand Down
30 changes: 12 additions & 18 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -616,26 +616,20 @@ void WebGpuContext::EndProfiling(TimePoint /* tp */, profiling::Events& events,
}
}

void WebGpuContext::PushErrorScopeIfNoLowerThan(webgpu::ValidationMode validation_mode) {
if (ValidationMode() >= validation_mode) {
device_.PushErrorScope(wgpu::ErrorFilter::Validation);
}
}
void WebGpuContext::PushErrorScope() { device_.PushErrorScope(wgpu::ErrorFilter::Validation); }

Status WebGpuContext::PopErrorScopeIfNoLowerThan(webgpu::ValidationMode validation_mode) {
Status WebGpuContext::PopErrorScope() {
Status status{};
if (ValidationMode() >= validation_mode) {
ORT_RETURN_IF_ERROR(Wait(device_.PopErrorScope(
wgpu::CallbackMode::WaitAnyOnly,
[](wgpu::PopErrorScopeStatus pop_status, wgpu::ErrorType error_type, char const* message, Status* status) {
ORT_ENFORCE(pop_status == wgpu::PopErrorScopeStatus::Success, "Instance dropped.");
if (error_type == wgpu::ErrorType::NoError) {
return;
}
*status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "WebGPU validation failed. ", message);
},
&status)));
}
ORT_RETURN_IF_ERROR(Wait(device_.PopErrorScope(
wgpu::CallbackMode::WaitAnyOnly,
[](wgpu::PopErrorScopeStatus pop_status, wgpu::ErrorType error_type, char const* message, Status* status) {
ORT_ENFORCE(pop_status == wgpu::PopErrorScopeStatus::Success, "Instance dropped.");
if (error_type == wgpu::ErrorType::NoError) {
return;
}
*status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "WebGPU validation failed. ", message);
},
&status)));
return status;
}

Expand Down
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,18 @@ class WebGpuContext final {
void EndProfiling(TimePoint, profiling::Events& events, profiling::Events& cached_events);

//
// Push error scope if the context's validation mode is no lower than the given mode.
// Push error scope.
//
// This is useful only when "skip_validation" is not set.
//
void PushErrorScopeIfNoLowerThan(webgpu::ValidationMode validation_mode);
void PushErrorScope();

//
// Pop error scope if the context's validation mode is no lower than the given mode.
// Pop error scope.
//
// This is useful only when "skip_validation" is not set.
//
Status PopErrorScopeIfNoLowerThan(webgpu::ValidationMode validation_mode);
Status PopErrorScope();

Status Run(ComputeContext& context, const ProgramBase& program);

Expand Down
10 changes: 8 additions & 2 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,9 @@ std::unique_ptr<profiling::EpProfiler> WebGpuExecutionProvider::GetProfiler() {
}

Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
context_.PushErrorScopeIfNoLowerThan(ValidationMode::Basic);
if (context_.ValidationMode() >= ValidationMode::Basic) {
context_.PushErrorScope();
}

if (profiler_->Enabled()) {
context_.StartProfiling();
Expand Down Expand Up @@ -860,7 +862,11 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
context_.CollectProfilingData(profiler_->Events());
}

return context_.PopErrorScopeIfNoLowerThan(ValidationMode::Basic);
if (context_.ValidationMode() >= ValidationMode::Basic) {
return context_.PopErrorScope();
}

return Status::OK();
}

bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const {
Expand Down

0 comments on commit 9501966

Please sign in to comment.