Skip to content

Commit 7afe4c2

Browse files
[Plugin EP] Port graph capture/replay APIs (#27958)
## Description Ports graph capture/replay APIs (e.g., CUDA Graph) to the Plugin EP (`OrtEp`) C API so that plugin-based execution providers can participate in ORT-managed graph capture and replay. ### What changed **New Plugin EP C API functions** (`onnxruntime_ep_c_api.h`): - `OrtEp::IsGraphCaptureEnabled` — indicates whether the EP has graph capture enabled. - `OrtEp::IsGraphCaptured` — indicates whether a graph has been captured for a given annotation ID. - `OrtEp::ReplayGraph` — replays a previously captured graph. - `OrtEp::GetGraphCaptureNodeAssignmentPolicy` — returns the node assignment validation policy for graph capture. All four are optional (NULL defaults to safe behavior) and version-gated (`ort_version_supported >= 26`). If `IsGraphCaptureEnabled` returns true, `IsGraphCaptured` and `ReplayGraph` must also be implemented. otherwise `PluginExecutionProvider` logs a warning and disables graph capture for that EP. **New `OrtGraphCaptureNodeAssignmentPolicy` enum** (`onnxruntime_ep_c_api.h`): Replaces the hardcoded EP-name checks in `InferenceSession::Initialize()` with a policy-based approach: - `ALL_NODES_ON_EP` — all nodes must be on the target EP (e.g., TensorRT). - `ALLOW_CPU_FOR_SHAPES` — CPU nodes allowed for shape computation if no memcpy nodes exist (e.g., CUDA, WebGPU, DML). **Refactored `InferenceSession` graph capture selection** (`inference_session.cc`): - Removed the hardcoded `graph_support_ep_list` and per-EP `strcmp` checks. - Now iterates over all registered EPs and uses `IsGraphCaptureEnabled()` + `GetGraphCaptureNodeAssignmentPolicy()` to select and validate the graph-capturing EP. - `AreAllComputeNodesAssignedToCudaOrJsOrDmlEpWebGpuEp()` → generalized to `AreAllComputeNodesAssignedToEpOrCpu()`, which also requires at least one node on the target EP. - `IExecutionProvider::GetGraphCaptureNodeAssignmentPolicy()` added to the base class (defaults to `ALL_NODES_ON_EP`). **Bounded graph capture recursion** (`inference_session.cc/h`): - `Run()` now delegates to `RunImpl()` with a `graph_capture_depth` parameter. - Caps internal run attempts at `kMaxGraphCaptureRunAttempts = 8`, returning a clear error if the EP never reports `IsGraphCaptured() == true`. **EP implementations**: - **WebGPU plugin EP**: Fully implements all four graph capture APIs by forwarding to the underlying `IExecutionProvider`. - **CUDA plugin EP**: Stubs with TODOs (returns disabled/not-implemented). - **NvTensorRTRTX EP**: `IsGraphCaptureEnabled()` now returns `false` since this EP manages graph capture internally (not via ORT). **C++ wrapper** (`onnxruntime_cxx_api.h` / `onnxruntime_cxx_inline.h`): - Added `Ort::Env::CopyTensor()` convenience overload for copying a single tensor (wraps `CopyTensors` with `num_tensors=1`). ### Tests - **`ep_plugin_provider_test.cc`**: Unit tests for each new `PluginExecutionProvider` graph capture method, including NULL function pointer defaults, version < 26 backward compatibilities, and validation that `IsGraphCaptureEnabled()` returns false when `IsGraphCaptured` or `ReplayGraph` are NULL. - **`test_graph_capture.cc`**: End-to-end test for WebGPU plugin EP graph capture/replay using IO binding (warm-up + capture run, then replay with different inputs). ### Motivation and Context Previously, graph capture support was limited to a hardcoded list of EPs (`kCudaExecutionProvider`, `kTensorrtExecutionProvider`, `kJsExecutionProvider`, `kWebGpuExecutionProvider`, `kDmlExecutionProvider`) with EP-specific validation logic in `InferenceSession`. This made it impossible for plugin EPs to participate in ORT-managed graph capture/replay without modifying the core session code. This PR makes graph capture/replay extensible to any EP, including out-of-tree plugin EPs, by exposing it through the `OrtEp` C API.
1 parent 3fad293 commit 7afe4c2

20 files changed

Lines changed: 708 additions & 107 deletions

include/onnxruntime/core/framework/execution_provider.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,15 @@ class IExecutionProvider {
292292
return Status::OK();
293293
}
294294

295+
/**
296+
Get the node assignment validation policy for graph capture.
297+
When graph capture is enabled, ORT validates that nodes are assigned to EPs
298+
in a way compatible with graph capture. This tells ORT which policy to apply.
299+
*/
300+
virtual OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const {
301+
return OrtGraphCaptureNodeAssignmentPolicy_ALL_NODES_ON_EP;
302+
}
303+
295304
/**
296305
Called when session creation is complete
297306
This provides an opportunity for execution providers to optionally synchronize and

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,6 +1389,10 @@ struct Env : detail::Base<OrtEnv> {
13891389
const std::vector<Value>& dst_tensors,
13901390
OrtSyncStream* stream) const; ///< Wraps OrtApi::CopyTensors
13911391

1392+
/// Wraps OrtApi::CopyTensors
1393+
/// Copies only one src tensor to another dst tensor.
1394+
Status CopyTensor(const OrtValue* src_tensor, OrtValue* dst_tensor, OrtSyncStream* stream) const;
1395+
13921396
/// \brief Wraps OrtApi::SetPerSessionThreadPoolCallbacks
13931397
/// Stores work callbacks on the Env for per-session thread pools.
13941398
/// Only affects sessions created after this call. Does not affect global thread pools.

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,11 @@ inline Status Env::CopyTensors(const std::vector<Value>& src_tensors,
10551055
return Status(status);
10561056
}
10571057

1058+
inline Status Env::CopyTensor(const OrtValue* src_tensor, OrtValue* dst_tensor, OrtSyncStream* stream) const {
1059+
OrtStatus* status = GetApi().CopyTensors(p_, &src_tensor, &dst_tensor, stream, 1);
1060+
return Status(status);
1061+
}
1062+
10581063
inline UnownedAllocator Env::CreateSharedAllocator(const OrtEpDevice* ep_device, OrtDeviceMemoryType mem_type,
10591064
OrtAllocatorType allocator_type,
10601065
const OrtKeyValuePairs* allocator_options) {

include/onnxruntime/core/session/onnxruntime_ep_c_api.h

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2027,6 +2027,23 @@ typedef enum OrtEpDataLayout {
20272027
OrtEpDataLayout_Default = OrtEpDataLayout_NCHW,
20282028
} OrtEpDataLayout;
20292029

2030+
/**
2031+
* \brief Node assignment policies for graph capture validation.
2032+
*
2033+
* When graph capture is enabled, ORT validates that nodes are assigned to EPs in a way that is
2034+
* compatible with graph capture. An EP can specify which validation policy ORT should apply.
2035+
*
2036+
* \since Version 1.26.
2037+
*/
2038+
typedef enum OrtGraphCaptureNodeAssignmentPolicy {
2039+
/** All nodes in the main graph must be assigned to this EP. No CPU fallback is allowed. */
2040+
OrtGraphCaptureNodeAssignmentPolicy_ALL_NODES_ON_EP = 0,
2041+
2042+
/** Compute nodes must be on this EP. CPU nodes are allowed for shape computation as long as
2043+
* no memory copy nodes exist. */
2044+
OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES = 1,
2045+
} OrtGraphCaptureNodeAssignmentPolicy;
2046+
20302047
/**
20312048
* \brief The OrtEp struct provides functions to implement for an execution provider.
20322049
* \since Version 1.22.
@@ -2346,6 +2363,101 @@ struct OrtEp {
23462363
*/
23472364
ORT_API2_STATUS(CreateProfiler, _In_ OrtEp* this_ptr,
23482365
_Outptr_result_maybenull_ OrtEpProfilerImpl** profiler);
2366+
2367+
/** \brief Indicate whether the graph capturing mode (e.g., CUDA graph) is enabled for the provider.
2368+
*
2369+
* Graph capture allows an EP to record a sequence of device (e.g., GPU) operations during an initial run and replay
2370+
* them on subsequent runs, bypassing per-kernel CPU launch overhead.
2371+
*
2372+
* Applications enable graph capture via EP-specific provider options (e.g., `enable_cuda_graph=1`
2373+
* for the CUDA EP). An EP should return true from this function if it has been configured to enable
2374+
* graph capture/replay.
2375+
*
2376+
* **ORT graph capture/replay summary:**
2377+
* During OrtSession initialization, ORT calls OrtEp::IsGraphCaptureEnabled() on each EP in the order specified during
2378+
* provider registration with the session. If an EP returns true, ORT validates that the graph is suitable for
2379+
* graph capture, and if so, caches the EP for graph capture during the next run. The graph validation ensures
2380+
* that there are no control flow nodes and that node-to-EP assignments are compatible with the policy specified
2381+
* by the EP via OrtEp::GetGraphCaptureNodeAssignmentPolicy().
2382+
* Note that an OrtSession only supports graph capture for one EP (i.e., the first EP to claim support).
2383+
*
2384+
* During the first call to OrtApi::Run() for the OrtSession, ORT performs multiple internal runs of the model
2385+
* until the EP indicates that the graph has been captured by returning `true` from `OrtEp::IsGraphCaptured()`.
2386+
* If the EP is unable to capture the graph within 8 runs, the call to OrtApi::Run() returns an error OrtStatus.
2387+
* Each internal run invokes `OrtEp::OnRunStart()`, normal execution, and `OrtEp::OnRunEnd()`. EPs should use
2388+
* these run callbacks to track the number of necessary warm-up runs and begin/end graph capture when ready.
2389+
*
2390+
* After successful graph capture, subsequent calls to OrtApi::Run() skip normal execution and ORT instead calls
2391+
* `OrtEp::ReplayGraph()` directly.
2392+
*
2393+
* Applications can capture and replay multiple graphs (e.g., one per distinct input shape) by setting the
2394+
* `"gpu_graph_id"` run config entry via `OrtApi::AddRunConfigEntry()` to different integer values. ORT passes
2395+
* the value as the `graph_annotation_id` parameter to `OrtEp::IsGraphCaptured()` and `OrtEp::ReplayGraph()`.
2396+
*
2397+
* \param[in] this_ptr The OrtEp instance.
2398+
* \return true if graph capture mode is enabled, false otherwise.
2399+
*
2400+
* \note Implementation of this function is optional. If set to NULL, ORT assumes graph capture is not enabled.
2401+
* \note If this function returns true, `OrtEp::IsGraphCaptured` and `OrtEp::ReplayGraph` must also be implemented.
2402+
* If either is NULL, ORT will log a warning and ignore this EP for graph capture.
2403+
*
2404+
* \since Version 1.26.
2405+
*/
2406+
ORT_API_T(bool, IsGraphCaptureEnabled, _In_ const OrtEp* this_ptr);
2407+
2408+
/** \brief Indicate whether a graph has been captured and instantiated.
2409+
*
2410+
* ORT calls this before each `Session::Run()`. If true, ORT calls `ReplayGraph()` instead of
2411+
* normal execution. After a run where this returns false, ORT automatically retries until it
2412+
* returns true (handling warm-up runs transparently).
2413+
*
2414+
* \param[in] this_ptr The OrtEp instance.
2415+
* \param[in] graph_annotation_id Identifies which captured graph to query.
2416+
* Applications can set this value via `OrtApi::AddRunConfigEntry()` with the key `"gpu_graph_id"`.
2417+
* The default value is 0 when the run config entry is not set.
2418+
* Setting different IDs allows the EP to capture and manage multiple graphs (e.g., one per
2419+
* distinct input shape). A value of -1 means graph capture/replay should be skipped for this run.
2420+
* \return true if the graph has been captured, false otherwise.
2421+
*
2422+
* \note This function must be implemented if `OrtEp::IsGraphCaptureEnabled` is implemented and may return true.
2423+
*
2424+
* \since Version 1.26.
2425+
*/
2426+
ORT_API_T(bool, IsGraphCaptured, _In_ const OrtEp* this_ptr, _In_ int graph_annotation_id);
2427+
2428+
/** \brief Run the instantiated (captured) graph.
2429+
*
2430+
* Called by ORT instead of normal execution when `IsGraphCaptured()` returns true.
2431+
*
2432+
* \param[in] this_ptr The OrtEp instance.
2433+
* \param[in] graph_annotation_id Identifies which captured graph to replay.
2434+
* Applications can set this value via `OrtApi::AddRunConfigEntry()` with the key `"gpu_graph_id"`.
2435+
* The default value is 0 when the run config entry is not set.
2436+
* A value of -1 means graph replay should be skipped for this run.
2437+
*
2438+
* \snippet{doc} snippets.dox OrtStatus Return Value
2439+
*
2440+
* \note This function must be implemented if `OrtEp::IsGraphCaptureEnabled` is implemented and may return true.
2441+
*
2442+
* \since Version 1.26.
2443+
*/
2444+
ORT_API2_STATUS(ReplayGraph, _In_ OrtEp* this_ptr, _In_ int graph_annotation_id);
2445+
2446+
/** \brief Get the node assignment validation policy for graph capture.
2447+
*
2448+
* When graph capture is enabled, ORT validates that nodes are assigned to EPs in a way that is
2449+
* compatible with graph capture. This function tells ORT which validation policy to apply.
2450+
*
2451+
* \param[in] this_ptr The OrtEp instance.
2452+
* \return The node assignment policy for graph capture.
2453+
*
2454+
* \note Implementation of this function is optional. If set to NULL, ORT uses
2455+
* OrtGraphCaptureNodeAssignmentPolicy_ALL_NODES_ON_EP (strictest validation).
2456+
*
2457+
* \since Version 1.26.
2458+
*/
2459+
ORT_API_T(OrtGraphCaptureNodeAssignmentPolicy, GetGraphCaptureNodeAssignmentPolicy,
2460+
_In_ const OrtEp* this_ptr);
23492461
};
23502462

23512463
/** \brief The function signature that ORT will call to create OrtEpFactory instances.

onnxruntime/core/providers/cuda/cuda_execution_provider.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ class CUDAExecutionProvider : public IExecutionProvider {
124124
bool IsGraphCaptureEnabled() const override;
125125
bool IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const override;
126126
Status ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) override;
127+
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override {
128+
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
129+
}
127130
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
128131
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
129132
std::vector<AllocatorPtr> CreatePreferredAllocators() override;

onnxruntime/core/providers/cuda/plugin/cuda_ep.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& lo
5858
Compile = nullptr;
5959
ReleaseNodeComputeInfos = nullptr;
6060

61+
// Graph capture/replay
62+
IsGraphCaptureEnabled = IsGraphCaptureEnabledImpl;
63+
IsGraphCaptured = IsGraphCapturedImpl;
64+
ReplayGraph = ReplayGraphImpl;
65+
GetGraphCaptureNodeAssignmentPolicy = GetGraphCaptureNodeAssignmentPolicyImpl;
66+
6167
const OrtApi& ort_api = factory_.GetOrtApi();
6268
Ort::Status log_status(ort_api.Logger_LogMessage(&logger_, ORT_LOGGING_LEVEL_INFO,
6369
"CUDA Plugin EP created",
@@ -304,5 +310,30 @@ OrtStatus* ORT_API_CALL CudaEp::SyncImpl(OrtEp* this_ptr) noexcept {
304310
EXCEPTION_TO_STATUS_END
305311
}
306312

313+
bool ORT_API_CALL CudaEp::IsGraphCaptureEnabledImpl(const OrtEp* /*this_ptr*/) noexcept {
314+
// TODO: forward to EpImpl()->IsGraphCaptureEnabled()
315+
return false;
316+
}
317+
318+
/*static*/
319+
bool ORT_API_CALL CudaEp::IsGraphCapturedImpl(const OrtEp* /*this_ptr*/, int /*graph_annotation_id*/) noexcept {
320+
// TODO: forward to EpImpl()->IsGraphCaptured(graph_annotation_id)
321+
return false;
322+
}
323+
324+
/*static*/
325+
OrtStatus* ORT_API_CALL CudaEp::ReplayGraphImpl(OrtEp* /*this_ptr*/, int /*graph_annotation_id*/) noexcept {
326+
// TODO: forward to EpImpl()->ReplayGraph(graph_annotation_id)
327+
return Ort::GetApi().CreateStatus(ORT_NOT_IMPLEMENTED,
328+
"Graph capture replay is not yet supported in the CUDA plugin EP.");
329+
}
330+
331+
/*static*/
332+
OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL CudaEp::GetGraphCaptureNodeAssignmentPolicyImpl(
333+
const OrtEp* /*this_ptr*/) noexcept {
334+
// TODO: forward to EpImpl()->GetGraphCaptureNodeAssignmentPolicy()
335+
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
336+
}
337+
307338
} // namespace cuda_plugin
308339
} // namespace onnxruntime

onnxruntime/core/providers/cuda/plugin/cuda_ep.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ class CudaEp : public onnxruntime::ep::adapter::Ep {
6161

6262
static OrtStatus* ORT_API_CALL SyncImpl(OrtEp* this_ptr) noexcept;
6363

64+
static bool ORT_API_CALL IsGraphCaptureEnabledImpl(const OrtEp* this_ptr) noexcept;
65+
66+
static bool ORT_API_CALL IsGraphCapturedImpl(const OrtEp* this_ptr,
67+
int graph_annotation_id) noexcept;
68+
69+
static OrtStatus* ORT_API_CALL ReplayGraphImpl(OrtEp* this_ptr,
70+
int graph_annotation_id) noexcept;
71+
72+
static OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL GetGraphCaptureNodeAssignmentPolicyImpl(
73+
const OrtEp* this_ptr) noexcept;
74+
6475
CudaEpFactory& factory_;
6576
std::string name_;
6677
Config config_;

onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,11 @@ namespace Dml
356356
return m_impl->ReplayGraph(graph_annotation_id);
357357
}
358358

359+
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override
360+
{
361+
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
362+
}
363+
359364
private:
360365
ComPtr<ExecutionProviderImpl> m_impl;
361366
};

onnxruntime/core/providers/js/js_execution_provider.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ class JsExecutionProvider : public IExecutionProvider {
7272
bool IsGraphCaptureEnabled() const override;
7373
bool IsGraphCaptured(int graph_annotation_id) const override;
7474
Status ReplayGraph(int graph_annotation_id) override;
75+
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override {
76+
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
77+
}
7578

7679
private:
7780
bool IsGraphCaptureAllowed() const;

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1280,7 +1280,9 @@ void NvExecutionProvider::HandleCudaGraphStart(cudaStream_t stream, bool require
12801280
}
12811281

12821282
bool NvExecutionProvider::IsGraphCaptureEnabled() const {
1283-
return cuda_graph_enable_;
1283+
// Return false so that ORT's framework does not cache this EP for ORT-managed graph capture/replay.
1284+
// NvTensorRTRTX manages CUDA graph capture/replay internally.
1285+
return false;
12841286
}
12851287

12861288
bool NvExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {

0 commit comments

Comments
 (0)