
Commit 655b69c

[TRT RTX EP] Don't register set device function when we use existing stream (microsoft#26542)
### Description
- Don't register the set-device function when an existing stream is used.
- Fix a bug in nv_execution_provider.cc: set the device only if the user did not provide an existing stream.

### Motivation and Context
In some use cases we push a user-created CUDA context, create streams on that context, and then provide those streams to TRT-RTX. However, after calling Run(), the custom context is replaced by another CUDA context created by ORT, so TRT-RTX is no longer using the original CUDA context. Further investigation showed that the new context is created in onnxruntime/core/framework/stream_execution_context.cc. The proposed fix is to not register the set-device function when the user provides the stream. There is also a bug in onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc: the device should be set only if the user has not provided a stream, consistent with the original comment.
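For context, the affected scenario looks roughly like the sketch below: the caller picks its own CUDA device/context and creates a stream, then hands that stream to ORT so the execution provider reuses it instead of creating its own. This is a minimal illustration, not code from the commit; it uses the CUDA EP's `OrtCUDAProviderOptions` (`has_user_compute_stream` / `user_compute_stream`), since the stream handles in cuda_stream_handle.cc are registered through that path, while the TRT-RTX case described above supplies its stream through that EP's own provider options. The model path `model.onnx` is a placeholder.

```cpp
// Hedged sketch (assumptions: CUDA EP, placeholder model path "model.onnx").
#include <cuda_runtime.h>
#include <onnxruntime_cxx_api.h>

int main() {
  // The caller picks the device/context and owns the stream.
  cudaSetDevice(0);
  cudaStream_t my_stream = nullptr;
  cudaStreamCreateWithFlags(&my_stream, cudaStreamNonBlocking);

  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "user-stream-example");
  Ort::SessionOptions session_options;

  // Hand the existing stream to the execution provider instead of letting ORT create one.
  OrtCUDAProviderOptions cuda_options{};
  cuda_options.device_id = 0;
  cuda_options.has_user_compute_stream = 1;
  cuda_options.user_compute_stream = my_stream;
  session_options.AppendExecutionProvider_CUDA(cuda_options);

  Ort::Session session(env, "model.onnx", session_options);
  // session.Run(...) now executes on my_stream; with this commit, ORT does not register a
  // set-device callback in the existing-stream path, so the caller's CUDA context stays active.

  cudaStreamDestroy(my_stream);
  return 0;
}
```

With an external compute stream supplied, `use_existing_stream` should be true in `RegisterCudaStreamHandles`, so after this change ORT no longer registers a set-device callback that could swap out the caller's context during Run().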
1 parent 935affb · commit 655b69c

File tree: 2 files changed (+5, -4 lines)

onnxruntime/core/providers/cuda/cuda_stream_handle.cc (4 additions, 3 deletions)

```diff
@@ -250,15 +250,16 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
   stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitCudaNotificationOnDevice);
   // wait cuda notification on cpu ep
   stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitCudaNotificationOnHost);
-  if (!use_existing_stream)
+  if (!use_existing_stream) {
     stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_cuda_stream, ep_info](const OrtDevice& device) {
       CUDA_CALL_THROW(cudaSetDevice(device.Id()));
       cudaStream_t stream = nullptr;
       CUDA_CALL_THROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
       // CUDA_CALL_THROW(cudaStreamCreate(&stream));
       return std::make_unique<CudaStream>(stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, true, nullptr, nullptr, ep_info);
     });
-  else
+    stream_handle_registry.RegisterSetDeviceFn(device_type, [](OrtDevice::DeviceId id) { CUDA_CALL_THROW(cudaSetDevice(id)); });
+  } else {
     stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator,
                                                                 release_cpu_buffer_on_cuda_stream,
                                                                 external_stream,
@@ -267,7 +268,7 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
                                                                 ep_info](const OrtDevice& device) {
       return std::make_unique<CudaStream>(external_stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, false, external_cudnn_handle, external_cublas_handle, ep_info);
     });
-  stream_handle_registry.RegisterSetDeviceFn(device_type, [](OrtDevice::DeviceId id) { CUDA_CALL_THROW(cudaSetDevice(id)); });
+  }
 }

 }  // namespace onnxruntime
```

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc (1 addition, 1 deletion)

```diff
@@ -766,7 +766,7 @@ Status BindKernelOutput(Ort::KernelContext& ctx,

 NvExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, bool has_user_compute_stream, cudaStream_t stream) {
   // Only set device if user hasn't provided a compute stream
-  if (has_user_compute_stream) {
+  if (!has_user_compute_stream) {
     CUDA_CALL_THROW(cudaSetDevice(device_id));
     (void)stream;
   }
```
