DLPack: sync CUDA stream when exporting to TensorFlow #377


Open · merlinND wants to merge 3 commits into master

Conversation

merlinND (Member) commented Apr 3, 2025

This PR attempts to fix synchronization issues that come up in the unit tests of the new TF interop feature (#301): #301 (comment)

Since TensorFlow uses non-default CUDA streams for compute and data movement, I believe that we need to synchronize the stream used by DrJit before exporting a tensor to TF.
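
For illustration, a minimal sketch of the idea (not the PR's actual implementation; `to_tf` is a hypothetical helper, and the exact calls made by `array.tf()` may differ):

```python
import drjit as dr
import tensorflow as tf

def to_tf(x):
    # Launch any pending kernels that produce `x`, then block until Dr.Jit's
    # CUDA stream has finished writing the buffer.
    dr.eval(x)
    dr.sync_thread()
    # Only now is it safe for TF's own (non-default) compute / memcpy streams
    # to read the exported buffer.
    return tf.experimental.dlpack.from_dlpack(x.__dlpack__())
```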

merlinND (Member Author) commented Apr 3, 2025

@njroussel After running the relevant unit tests in a loop on my machine for a while, I believe this resolves the synchronization issue. I wouldn't want to introduce unnecessary synchronization points, though; do you also think this is the right solution, given that TF uses separate non-default streams?

wjakob (Member) commented Apr 4, 2025

@merlinND, I'm wondering whether this change could negatively affect the conversion of other DLPack tensors. Synchronization is always bad and best avoided. For example, should we perhaps only interpret 2 as a forced flush when the active thread is not the main application thread?

merlinND (Member Author) commented Apr 4, 2025

Hi Wenzel,

I agree that we should definitely avoid adding unnecessary synchronizations.

Checking the documentation again, I see that:

stream is provided by the consumer to the producer to instruct the producer to ensure that operations can safely be performed on the array (e.g., by inserting a dependency between streams via “wait for event”).
(...)

  • stream = 2 is the per-thread default stream

So essentially, when passing stream, the consumer (caller of __dlpack__()) indicates which stream they intend to use the value on. So if the stream they use is the same stream that we've produced the value on, there's no need for synchronization. Right?
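
To illustrate that contract from the consumer's side (names like `import_array` and `build_tensor_from_capsule` are hypothetical, not Dr.Jit or TF API):

```python
def import_array(producer_array, consumer_stream_handle):
    # `consumer_stream_handle` is the integer CUstream handle (> 2) on which the
    # consumer will launch its kernels. By passing it here, the consumer asks the
    # producer to make the data safe to use on that stream, e.g. by recording an
    # event on the producer's stream and having the consumer's stream wait on it.
    # No synchronization is needed if both sides use the same stream.
    capsule = producer_array.__dlpack__(stream=consumer_stream_handle)
    return build_tensor_from_capsule(capsule)  # hypothetical consumer-side import
```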

should we perhaps only interpret 2 as a force flush when the thread being active is not the main application thread?

In DrJit, is the main default stream always used, even when launching kernels from non-main threads?
If so, then jit_sync_thread() only syncs the main thread's default stream regardless of which thread it's called from?
Then, indeed we would only want to call jit_sync_thread() when __dlpack__ is being called from a non-main thread.

In DrJit, what would be a reliable way to check if the current thread is the "main thread"? (= the thread for which the DrJit stream is the default stream)


If we make the changes above, then I'm not sure which value of stream to use in order to request a sync when calling .tf(), because:

  • TF uses non-default streams and doesn't expose any of them, so we can't pass stream=pointer
  • .tf() can be called from the main thread but we always want to sync (because TF never uses the same stream as DrJit AFAICT), so we can't pass stream=2
  • stream=0 would currently do the right thing in DrJit, but it's technically disallowed by the spec: "0 is disallowed due to its ambiguity: 0 could mean either None, 1, or 2"

Maybe stream=0 is still our best bet.

merlinND (Member Author) commented Apr 9, 2025

Hi @wjakob,

Could you please check the questions above? I'll update the PR based on your answers and it should fix the flaky TF unit tests.

wjakob force-pushed the master branch 4 times, most recently from c821068 to dcb5217 (April 14, 2025, 16:01)
wjakob (Member) commented Apr 15, 2025

Hi Merlin,

I just had a moment to take a look at this.

So essentially, when passing stream, the consumer (caller of __dlpack__()) indicates which stream they intend to use the value on. So if the stream they use is the same stream that we've produced the value on, there's no need for synchronization. Right?

That's right. I looked again at what Dr.Jit does, since this changed some time ago. Mitsuba creates a custom stream per device, which is set up with the flag CU_STREAM_DEFAULT. This means that it synchronizes with respect to the NULL stream (the global default stream), but not with respect to other streams (in particular, it does not implicitly synchronize with per-thread default streams / PTDS). Mitsuba also does not use the PTDS/PTSZ ABI of the CUDA driver. We used to, and it caused difficulties with parallel scene loading.
IMO the following behavior is sensible:

  • stream is -1: user opted out of synchronization, don't do anything.
  • stream is 0: ambiguous case, synchronize via jit_sync_thread().
  • stream is 1: user requested a sync with respect to the global default stream. We don't need to do anything since we are already synced.
  • stream is 2: user requested a sync with respect to the per-thread default stream. We do need to synchronize. But doing a device-wide synchronization that waits on the CPU is a highly pessimistic way of doing that. The better way is to insert an event and wait on that, but it means that we will need to access the PTDS/PTSZ-flavored ABI of the CUDA driver to get access to the per-thread default stream. I've attached an untested patch that tries to do this below; I hope it makes sense.
  • stream is > 2: we are getting a CUDA stream handle. Pass to jitc_cuda_sync_stream(), which will push an event and then wait on that.
diff --git a/include/drjit-core/jit.h b/include/drjit-core/jit.h
index 720ca16..200d3cc 100644
--- a/include/drjit-core/jit.h
+++ b/include/drjit-core/jit.h
@@ -202,8 +202,11 @@ extern JIT_EXPORT void *jit_cuda_lookup(const char *name);
  * \brief Add CUDA event synchronization between thread state's and external
  * CUDA stream.
  *
- * An event will be recorded into the thread's states stream and the external stream
- * will wait on the event before performing any subsequent work.
+ * An event will be recorded into the thread's states stream and the external
+ * stream will wait on the event before performing any subsequent work. The
+ * special value stream==2 denotes the caller's per-thread default stream.
+ * There is no need to ever synchronize with the global NULL stream, since
+ * Dr.Jit implicitly synchronizes with respect to it.
  *
  * \param stream The CUstream handle of the external stream
  */
diff --git a/src/cuda_api.cpp b/src/cuda_api.cpp
index 293f4de..b3d0f70 100644
--- a/src/cuda_api.cpp
+++ b/src/cuda_api.cpp
@@ -123,6 +123,7 @@ bool jitc_cuda_api_init() {
         LOAD(cuStreamDestroy, "v2");
         LOAD(cuStreamSynchronize);
         LOAD(cuStreamWaitEvent);
+        LOAD(cuStreamWaitEvent_ptsz);
         LOAD(cuPointerGetAttribute);
         LOAD(cuArrayCreate, "v2");
         LOAD(cuArray3DCreate, "v2");
@@ -174,7 +175,7 @@ void jitc_cuda_api_shutdown() {
     Z(cuModuleGetFunction); Z(cuModuleLoadData); Z(cuModuleLoadDataEx); Z(cuModuleUnload);
     Z(cuOccupancyMaxPotentialBlockSize); Z(cuCtxPushCurrent);
     Z(cuCtxPopCurrent); Z(cuStreamCreate); Z(cuStreamDestroy);
-    Z(cuStreamSynchronize); Z(cuStreamWaitEvent); Z(cuPointerGetAttribute);
+    Z(cuStreamSynchronize); Z(cuStreamWaitEvent); Z(cuStreamWaitEvent_ptsz); Z(cuPointerGetAttribute);
     Z(cuArrayCreate); Z(cuArray3DCreate); Z(cuArray3DGetDescriptor);
     Z(cuArrayDestroy); Z(cuTexObjectCreate); Z(cuTexObjectGetResourceDesc);
     Z(cuTexObjectDestroy); Z(cuMemcpy2DAsync); Z(cuMemcpy3DAsync);
diff --git a/src/cuda_api.h b/src/cuda_api.h
index a2a04c0..142d9e5 100644
--- a/src/cuda_api.h
+++ b/src/cuda_api.h
@@ -260,6 +260,7 @@ DR_CUDA_SYM(CUresult (*cuStreamCreate)(CUstream *, unsigned int));
 DR_CUDA_SYM(CUresult (*cuStreamDestroy)(CUstream));
 DR_CUDA_SYM(CUresult (*cuStreamSynchronize)(CUstream));
 DR_CUDA_SYM(CUresult (*cuStreamWaitEvent)(CUstream, CUevent, unsigned int));
+DR_CUDA_SYM(CUresult (*cuStreamWaitEvent_ptsz)(CUstream, CUevent, unsigned int));
 DR_CUDA_SYM(CUresult (*cuMemAllocAsync)(CUdeviceptr *, size_t, CUstream));
 DR_CUDA_SYM(CUresult (*cuMemFreeAsync)(CUdeviceptr, CUstream));

diff --git a/src/cuda_core.cpp b/src/cuda_core.cpp
index 1e54389..cbcbdda 100644
--- a/src/cuda_core.cpp
+++ b/src/cuda_core.cpp
@@ -109,8 +109,12 @@ std::pair<CUmodule, bool> jitc_cuda_compile(const char *buf, bool release_state_
 void jitc_cuda_sync_stream(uintptr_t stream) {
     ThreadState* ts = thread_state(JitBackend::CUDA);
     CUevent sync_event = ts->sync_stream_event;
-    cuda_check(cuEventRecord(sync_event, (CUstream)ts->stream));
-    cuda_check(cuStreamWaitEvent((CUstream)stream, sync_event, CU_EVENT_DEFAULT));
+    scoped_set_context guard(ts->context);
+    cuda_check(cuEventRecord(sync_event, (CUstream) ts->stream));
+    if (stream != 2)
+        cuda_check(cuStreamWaitEvent((CUstream)stream, sync_event, CU_EVENT_DEFAULT));
+    else
+        cuda_check(cuStreamWaitEvent_ptsz(nullptr, sync_event, CU_EVENT_DEFAULT));
 }
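
Restated at the Python level, the dispatch described in the list above could look roughly like this (a sketch only; `sync_thread` and `sync_external_stream` are hypothetical stand-ins for jit_sync_thread() and jitc_cuda_sync_stream()):

```python
def producer_handle_dlpack_stream(stream):
    # Sketch of producer-side handling of the DLPack `stream` argument.
    if stream == -1:
        pass                          # consumer opted out of synchronization
    elif stream == 0:
        sync_thread()                 # ambiguous value: conservative full sync
    elif stream == 1:
        pass                          # global default stream: already synced
    elif stream == 2:
        sync_external_stream(2)       # event + wait on the per-thread default stream
    elif stream > 2:
        sync_external_stream(stream)  # event + wait on the given CUstream handle
```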

merlinND (Member Author) commented:

Thank you @wjakob for looking into it and providing the patch! I've opened mitsuba-renderer/drjit-core#139 with your patch.
The solution makes sense to me.

@dvicini, do the changes in this PR make sense to you? In particular, calling array.tf() will now trigger a jit_sync_thread() call.
The reasoning is that otherwise, the streams used by TF and the stream used by DrJit are not synchronized at all. Did you ever run into sync issues?

wjakob (Member) commented Apr 16, 2025

@merlinND Is that so? I thought we now insert an event and wait for it asynchronously. That is assuming that TF uses the special 2 argument.

merlinND (Member Author) commented:

Sorry, it was not very clear because two things were included in this PR:

  1. TF uses 3 non-default streams (compute, device-to-host memcpy, host-to-device memcpy) that it creates and manages itself. I didn't find a way to get pointers to those streams, so I don't know how to ensure synchronization w.r.t. TF's streams other than jit_sync_thread().
  2. stream=2 is part of the DLPack spec but was not correctly implemented in DrJit. Your patch added support, which is nice. But I don't think we can use stream=2 in array.tf() because TF doesn't use the per-thread stream.

Please let me know if I missed something. As you said, jit_sync_thread() is heavy and we want to avoid it as much as possible.

wjakob (Member) commented Apr 18, 2025

Ok, that makes sense. Let's wait for @dvicini's take from the Google™ side before merging :-)

merlinND (Member Author) commented:

Ping @dvicini, do you have an opinion on the above regarding TF interop?
#377 (comment)
