DLPack: sync CUDA stream when exporting to TensorFlow #377
@njroussel After running the relevant unit tests in a loop on my machine for a while, this seems to resolve the synchronization issue. I wouldn't want to introduce unnecessary synchronization points though; do you also think this is the right solution given that TF uses separate non-default streams?
@merlinND, I'm wondering if this change has the potential to affect the conversion of other DLPack tensors negatively? Synchronization is always bad and best avoided. For example, should we perhaps only interpret |
Hi Wenzel, I agree that we should definitely avoid adding unnecessary synchronizations. Checking the documentation again, I see that:
So essentially, when passing
In DrJit, is the main default stream always used, even when launching kernels from non-main threads? In DrJit, what would be a reliable way to check if the current thread is the "main thread"? (= the thread for which the DrJit stream is the default stream) If we make the changes above, then I'm not sure which value of
Maybe |
Hi @wjakob, could you please check the questions above? I'll update the PR based on your answers, and it should fix the flaky TF unit tests.
c821068 to dcb5217
Hi Merlin, I just had a moment to take a look at this.
That's right. I looked again at what Dr.Jit does, since this changed some time ago. Mitsuba creates a custom stream per device, which is set up with flag
diff --git a/include/drjit-core/jit.h b/include/drjit-core/jit.h
index 720ca16..200d3cc 100644
--- a/include/drjit-core/jit.h
+++ b/include/drjit-core/jit.h
@@ -202,8 +202,11 @@ extern JIT_EXPORT void *jit_cuda_lookup(const char *name);
* \brief Add CUDA event synchronization between thread state's and external
* CUDA stream.
*
- * An event will be recorded into the thread's states stream and the external stream
- * will wait on the event before performing any subsequent work.
+ * An event will be recorded into the thread's states stream and the external
+ * stream will wait on the event before performing any subsequent work. The
+ * special value stream==2 denotes the caller's per-thread default stream.
+ * There is no need to ever synchronize with the global NULL stream, since
+ * Dr.Jit implicitly synchronizes with respect to it.
*
* \param stream The CUstream handle of the external stream
*/
diff --git a/src/cuda_api.cpp b/src/cuda_api.cpp
index 293f4de..b3d0f70 100644
--- a/src/cuda_api.cpp
+++ b/src/cuda_api.cpp
@@ -123,6 +123,7 @@ bool jitc_cuda_api_init() {
LOAD(cuStreamDestroy, "v2");
LOAD(cuStreamSynchronize);
LOAD(cuStreamWaitEvent);
+ LOAD(cuStreamWaitEvent_ptsz);
LOAD(cuPointerGetAttribute);
LOAD(cuArrayCreate, "v2");
LOAD(cuArray3DCreate, "v2");
@@ -174,7 +175,7 @@ void jitc_cuda_api_shutdown() {
Z(cuModuleGetFunction); Z(cuModuleLoadData); Z(cuModuleLoadDataEx); Z(cuModuleUnload);
Z(cuOccupancyMaxPotentialBlockSize); Z(cuCtxPushCurrent);
Z(cuCtxPopCurrent); Z(cuStreamCreate); Z(cuStreamDestroy);
- Z(cuStreamSynchronize); Z(cuStreamWaitEvent); Z(cuPointerGetAttribute);
+ Z(cuStreamSynchronize); Z(cuStreamWaitEvent); Z(cuStreamWaitEvent_ptsz); Z(cuPointerGetAttribute);
Z(cuArrayCreate); Z(cuArray3DCreate); Z(cuArray3DGetDescriptor);
Z(cuArrayDestroy); Z(cuTexObjectCreate); Z(cuTexObjectGetResourceDesc);
Z(cuTexObjectDestroy); Z(cuMemcpy2DAsync); Z(cuMemcpy3DAsync);
diff --git a/src/cuda_api.h b/src/cuda_api.h
index a2a04c0..142d9e5 100644
--- a/src/cuda_api.h
+++ b/src/cuda_api.h
@@ -260,6 +260,7 @@ DR_CUDA_SYM(CUresult (*cuStreamCreate)(CUstream *, unsigned int));
DR_CUDA_SYM(CUresult (*cuStreamDestroy)(CUstream));
DR_CUDA_SYM(CUresult (*cuStreamSynchronize)(CUstream));
DR_CUDA_SYM(CUresult (*cuStreamWaitEvent)(CUstream, CUevent, unsigned int));
+DR_CUDA_SYM(CUresult (*cuStreamWaitEvent_ptsz)(CUstream, CUevent, unsigned int));
DR_CUDA_SYM(CUresult (*cuMemAllocAsync)(CUdeviceptr *, size_t, CUstream));
DR_CUDA_SYM(CUresult (*cuMemFreeAsync)(CUdeviceptr, CUstream));
diff --git a/src/cuda_core.cpp b/src/cuda_core.cpp
index 1e54389..cbcbdda 100644
--- a/src/cuda_core.cpp
+++ b/src/cuda_core.cpp
@@ -109,8 +109,12 @@ std::pair<CUmodule, bool> jitc_cuda_compile(const char *buf, bool release_state_
void jitc_cuda_sync_stream(uintptr_t stream) {
ThreadState* ts = thread_state(JitBackend::CUDA);
CUevent sync_event = ts->sync_stream_event;
- cuda_check(cuEventRecord(sync_event, (CUstream)ts->stream));
- cuda_check(cuStreamWaitEvent((CUstream)stream, sync_event, CU_EVENT_DEFAULT));
+ scoped_set_context guard(ts->context);
+ cuda_check(cuEventRecord(sync_event, (CUstream) ts->stream));
+ if (stream != 2)
+ cuda_check(cuStreamWaitEvent((CUstream)stream, sync_event, CU_EVENT_DEFAULT));
+ else
+ cuda_check(cuStreamWaitEvent_ptsz(nullptr, sync_event, CU_EVENT_DEFAULT));
}
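To make the branching in the patch easier to follow, here is a minimal GPU-free sketch of the decision it encodes. The names `WaitTarget` and `classify_stream` are hypothetical and do not exist in drjit-core. The special case for `stream == 1` is an assumption based on the DLPack convention for CUDA (1 = legacy default stream, 2 = per-thread default stream, larger values = raw stream handle) combined with the comment in the patch that the global NULL stream never needs an explicit wait; the patch itself only special-cases the value 2.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical names for illustration only; not part of drjit-core.
enum class WaitTarget {
    None,             // assumed: legacy default stream, no explicit wait needed
    PerThreadDefault, // patch: cuStreamWaitEvent_ptsz(nullptr, event, ...)
    ExplicitStream    // patch: cuStreamWaitEvent((CUstream) stream, event, ...)
};

// Decide which wait call the external stream value maps to.
// Assumed convention (DLPack, CUDA): 1 = legacy default stream,
// 2 = per-thread default stream, anything else = raw stream handle.
constexpr WaitTarget classify_stream(std::uintptr_t stream) {
    if (stream == 1)
        return WaitTarget::None;             // assumption, not in the patch
    if (stream == 2)
        return WaitTarget::PerThreadDefault; // matches the `else` branch of the patch
    return WaitTarget::ExplicitStream;       // matches the `stream != 2` branch
}

// Compile-time checks of the mapping described above.
static_assert(classify_stream(2) == WaitTarget::PerThreadDefault, "ptsz path");
static_assert(classify_stream(1) == WaitTarget::None, "legacy default stream");

// Small runtime check with a made-up handle value.
inline void classify_stream_demo() {
    assert(classify_stream(0x7f00) == WaitTarget::ExplicitStream);
}
```

In either wait branch the event is first recorded on the Dr.Jit thread state's stream, so the external stream waits asynchronously on the GPU and the host thread is not blocked.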
61c30e2 to 6eaa0c3
Thank you @wjakob for looking into it and providing the patch! I've opened mitsuba-renderer/drjit-core#139 with your patch. @dvicini, do the changes in this PR make sense to you? In particular, calling |
@merlinND Is that so? I thought we now insert an event and wait for it asynchronously. That is assuming that TF uses the special |
Sorry, it was not very clear because two things were included in this PR:
Please let me know if I missed something. As you said, |
Ok, that makes sense. Let's wait for @dvicini's take from the Google™ viewpoint before merging :-)
Ping @dvicini, do you have an opinion on the above regarding TF interop?
This PR attempts to fix synchronization issues that come up in the unit tests of the new TF interop feature (#301): #301 (comment)
Since TensorFlow uses non-default CUDA streams for compute and data movement, I believe that we need to synchronize the stream used by DrJit before exporting a tensor to TF.