Skip to content

Commit d165fba

Browse files
authored
CUDA Plugin EP: NHWC Cleanup & Hardening (#28612)
## Summary Unifies the NHWC-eligible op allowlist between the bundled CUDA EP and the CUDA plugin EP into a single shared header, adds kernel-miss diagnostics, and expands NHWC test coverage from 4 ops to 11. ## Motivation The bundled EP (`cuda_execution_provider.cc`) and the plugin EP (`plugin/cuda_ep.cc`) independently maintained their own copies of the NHWC allowlist. This created a maintenance hazard where ops could be added to one but not the other, leading to silent divergence. Additionally, there was no runtime diagnostic when the framework rewrote a node to the NHWC domain but the plugin EP lacked a matching kernel — failures were silent fallbacks to CPU. ## Key Changes ### Shared NHWC Allowlist (`cuda_nhwc_ops.h`) | Item | Detail | |------|--------| | New file | `onnxruntime/core/providers/cuda/cuda_nhwc_ops.h` | | Contents | `IsNhwcEligibleOnnxOp()`, `IsNhwcEligibleMsOp()`, `IsNhwcEligible()` inline functions | | Ops covered | AveragePool, BatchNormalization, Conv, ConvTranspose, DepthToSpace, GlobalAveragePool, GlobalMaxPool, GridSample, LRN, MaxPool, SpaceToDepth (+ MS-domain GridSample) | ### Bundled EP Refactor (`cuda_execution_provider.cc`) - Removed the static `std::unordered_set<std::string_view> cuda_nhwc_onnx_ops` and the inline domain check logic. - Replaced with a single call to `cuda::IsNhwcEligible(node_domain, node_op_type)`. ### Plugin EP Refactor & Diagnostics (`plugin/cuda_ep.cc`) - `ShouldConvertDataLayoutForOpImpl`: Replaced ~20 lines of static set + domain checks with a single `cuda::IsNhwcEligible()` call. - `GetCapabilityImpl`: Added a WARNING-level diagnostic in the `else` branch (kernel not found). When a node in the `com.ms.internal.nhwc` domain has no registered kernel, the log emits the op type, domain, version, and node name — making future NHWC registration gaps immediately visible at session creation. 
### Expanded NHWC Test Coverage (`test_cuda_plugin_ep.py`) - Added `_assert_nhwc_domain_assigned()` helper that verifies NHWC layout transformation occurred by checking for framework-inserted Transpose nodes in the EP's assignment info. - Added `_run_nhwc_model_test()` helper combining domain assertion + numerical validation. - Updated 4 existing NHWC tests (Conv, BatchNormalization, MaxPool, AveragePool) to include structural assertions. - Added 7 new NHWC test methods: - `test_nhwc_conv_transpose` - `test_nhwc_global_max_pool` - `test_nhwc_global_average_pool` - `test_nhwc_depth_to_space` - `test_nhwc_space_to_depth` - `test_nhwc_lrn` - `test_nhwc_grid_sample` ## Testing Notes Run the full CUDA plugin EP test suite with NHWC enabled: ```bash bash .env/cuda13_plugin.sh --build --install --test_plugin ``` Or run only the NHWC tests directly: ```bash cd onnxruntime/test/python/transformers ORT_TEST_CUDA_PLUGIN_EP=1 python -m unittest \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_conv \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_batch_normalization \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_maxpool \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_avgpool \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_conv_transpose \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_global_max_pool \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_global_average_pool \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_depth_to_space \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_space_to_depth \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_lrn \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_grid_sample ``` All 86 tests in the suite pass (11 NHWC + 75 existing), with no regressions.
1 parent 6a517f5 commit d165fba

4 files changed

Lines changed: 355 additions & 49 deletions

File tree

onnxruntime/core/providers/cuda/cuda_execution_provider.cc

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "core/platform/env_var_utils.h"
1313
#include "core/providers/cuda/cuda_execution_provider.h"
1414
#include "core/providers/cuda/cuda_common.h"
15+
#include "core/providers/cuda/cuda_nhwc_ops.h"
1516
#include "core/providers/cuda/cuda_allocator.h"
1617
#include "core/providers/cuda/cuda_fwd.h"
1718
#include "core/providers/cuda/gpu_data_transfer.h"
@@ -383,23 +384,7 @@ std::optional<bool> CUDAExecutionProvider::ShouldConvertDataLayoutForOp([[maybe_
383384
return std::nullopt;
384385
}
385386

386-
// TODO(mtavenrath) generate list from registered kernels using nhwc domain
387-
static const std::unordered_set<std::string_view> cuda_nhwc_onnx_ops{
388-
"BatchNormalization",
389-
"Conv",
390-
"ConvTranspose",
391-
"GlobalMaxPool",
392-
"MaxPool",
393-
"GlobalAveragePool",
394-
"AveragePool",
395-
"GridSample",
396-
"DepthToSpace",
397-
"SpaceToDepth",
398-
"LRN",
399-
};
400-
401-
return (node_domain == kOnnxDomain && cuda_nhwc_onnx_ops.find(node_op_type) != cuda_nhwc_onnx_ops.end()) ||
402-
(node_domain == kMSDomain && node_op_type == "GridSample");
387+
return cuda::IsNhwcEligible(node_domain, node_op_type);
403388

404389
#else // defined(ENABLE_CUDA_NHWC_OPS)
405390
ORT_UNUSED_PARAMETER(node_domain);
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include <string_view>
7+
8+
namespace onnxruntime {
namespace cuda {

// Unified allowlist of ops eligible for NHWC layout conversion in both the
// bundled CUDA EP and the CUDA plugin EP. Maintaining a single source of truth
// prevents silent divergence between the two implementations.
//
// All predicates below are pure, exact (case-sensitive) string comparisons.
// They are constexpr and noexcept so the allowlist can also be checked at
// compile time (e.g. in a static_assert).

// Returns true if |op_type| names an ONNX-domain op on the NHWC allowlist.
// The match is exact: "Conv" is eligible, "conv" and "Conv2D" are not.
constexpr bool IsNhwcEligibleOnnxOp(std::string_view op_type) noexcept {
  // Alphabetical order for easy maintenance.
  return op_type == "AveragePool" ||
         op_type == "BatchNormalization" ||
         op_type == "Conv" ||
         op_type == "ConvTranspose" ||
         op_type == "DepthToSpace" ||
         op_type == "GlobalAveragePool" ||
         op_type == "GlobalMaxPool" ||
         op_type == "GridSample" ||
         op_type == "LRN" ||
         op_type == "MaxPool" ||
         op_type == "SpaceToDepth";
}

// Returns true if |op_type| names a Microsoft-domain ("com.microsoft") op on
// the NHWC allowlist. Currently only GridSample qualifies.
constexpr bool IsNhwcEligibleMsOp(std::string_view op_type) noexcept {
  return op_type == "GridSample";
}

// Returns true if the given (domain, op_type) pair is eligible for NHWC
// conversion. |domain| should be kOnnxDomain ("") or kMSDomain
// ("com.microsoft"); any other domain is never eligible.
//
// The domain strings are spelled out as literals rather than referencing the
// kOnnxDomain / kMSDomain constants so this header depends only on
// <string_view>.
constexpr bool IsNhwcEligible(std::string_view domain, std::string_view op_type) noexcept {
  if (domain.empty()) {
    return IsNhwcEligibleOnnxOp(op_type);
  }
  if (domain == "com.microsoft") {
    return IsNhwcEligibleMsOp(op_type);
  }
  return false;
}

}  // namespace cuda
}  // namespace onnxruntime

onnxruntime/core/providers/cuda/plugin/cuda_ep.cc

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <unordered_set>
2121

2222
#include "core/graph/constants.h"
23+
#include "core/providers/cuda/cuda_nhwc_ops.h"
2324

2425
namespace onnxruntime {
2526
namespace cuda_plugin {
@@ -214,7 +215,7 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl(
214215
tentative_nodes.reserve(all_nodes.size());
215216

216217
for (const auto& node : all_nodes) {
217-
std::string ep_name = node.GetEpName();
218+
const std::string& ep_name = node.GetEpName();
218219
if (!ep_name.empty()) {
219220
if (ep_name == ep->name_) {
220221
candidate_nodes.push_back(node);
@@ -229,6 +230,18 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl(
229230
if (kernel_def != nullptr) {
230231
candidate_nodes.push_back(node);
231232
tentative_nodes.push_back(node);
233+
} else {
234+
// Emit a diagnostic when an NHWC-domain node has no matching kernel.
235+
// This helps identify gaps between the layout conversion allowlist and
236+
// the actually-registered NHWC kernels in the plugin build.
237+
const std::string& node_domain = node.GetDomain();
238+
if (node_domain == kMSInternalNHWCDomain) {
239+
ORT_CXX_LOGF(Ort::Logger(&ep->logger_), ORT_LOGGING_LEVEL_WARNING,
240+
"NHWC kernel miss: op=%s domain=%s version=%d node=%s - "
241+
"no matching kernel registered in the CUDA plugin EP.",
242+
node.GetOperatorType().c_str(), node_domain.c_str(),
243+
node.GetSinceVersion(), node.GetName().c_str());
244+
}
232245
}
233246
}
234247

@@ -308,36 +321,11 @@ OrtStatus* ORT_API_CALL CudaEp::ShouldConvertDataLayoutForOpImpl(
308321
return nullptr;
309322
}
310323

311-
// ONNX domain ops that have NHWC kernel registrations.
312-
static const std::unordered_set<std::string_view> cuda_nhwc_onnx_ops{
313-
"BatchNormalization",
314-
"Conv",
315-
"ConvTranspose",
316-
"GlobalMaxPool",
317-
"MaxPool",
318-
"GlobalAveragePool",
319-
"AveragePool",
320-
"GridSample",
321-
"DepthToSpace",
322-
"SpaceToDepth",
323-
"LRN",
324-
};
325-
326-
// Check ONNX domain (empty string) or MS domain (com.microsoft)
327-
bool is_onnx_domain = (safe_domain[0] == '\0');
328-
bool is_ms_domain = (std::strcmp(safe_domain, "com.microsoft") == 0);
329-
330-
if (is_onnx_domain && cuda_nhwc_onnx_ops.count(safe_op_type) > 0) {
324+
if (cuda::IsNhwcEligible(safe_domain, safe_op_type)) {
331325
*should_convert = 1; // Convert
332-
return nullptr;
333-
}
334-
335-
if (is_ms_domain && std::strcmp(safe_op_type, "GridSample") == 0) {
336-
*should_convert = 1; // Convert
337-
return nullptr;
326+
} else {
327+
*should_convert = 0; // Explicitly decline conversion for unsupported NHWC ops.
338328
}
339-
340-
*should_convert = 0; // Explicitly decline conversion for unsupported NHWC ops.
341329
return nullptr;
342330
#endif
343331
}

0 commit comments

Comments
 (0)