diff --git a/docs/reference.rst b/docs/reference.rst index 6d369a13..f862e7cb 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -621,6 +621,8 @@ Low-level bits .. autofunction:: thread_count .. autofunction:: set_thread_count .. autofunction:: sync_thread +.. autofunction:: sync_device +.. autofunction:: sync_all_devices .. autofunction:: flush_kernel_cache .. autofunction:: flush_malloc_cache .. autofunction:: expand_threshold diff --git a/ext/drjit-core b/ext/drjit-core index 32486b64..a8f8f0b2 160000 --- a/ext/drjit-core +++ b/ext/drjit-core @@ -1 +1 @@ -Subproject commit 32486b64dcbf3a8f7d0a28c274863d2c8ea25f65 +Subproject commit a8f8f0b2f72493117965ffa599461c1e6c5b18ff diff --git a/src/python/dlpack.cpp b/src/python/dlpack.cpp index a0ac00ed..1bdffdf4 100644 --- a/src/python/dlpack.cpp +++ b/src/python/dlpack.cpp @@ -118,21 +118,24 @@ static nb::ndarray<> dlpack(nb::handle_t h, bool force_cpu, nb::handl // https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html /* stream = -1 request producer to perform no synchronization - stream = 0 is ambiguous + stream = 0 is ambiguous (could mean either None, 1, or 2) stream = 1 or None is the legacy default stream stream = 2 is the per-thread default stream stream > 2 is a CUDA handle to the consumer's stream */ if (!stream.is_none() && !stream.equal(nb::int_(-1)) && !stream.equal(nb::int_(1))) { - if (stream.equal(nb::int_(0))) + if (stream.equal(nb::int_(0))) { jit_sync_thread(); - else { + } else { + // Note: the special value 2 (syncing w.r.t. the per-thread default stream) + // is handled by `jit_cuda_sync_stream()`. 
uintptr_t stream_handle; if (!nb::try_cast(stream, stream_handle)) nb::raise_type_error("__dlpack__(): 'stream' argument must be 'None' or of type 'int'."); jit_cuda_sync_stream(stream_handle); } } + } else { jit_sync_thread(); } @@ -265,6 +268,11 @@ void export_dlpack(nb::module_ &) { .def("tf", [](nb::handle_t h) { nb::module_ tf = nb::module_::import_("tensorflow.experimental.dlpack"); - return tf.attr("from_dlpack")(dlpack(h, false)); + // TensorFlow uses non-default streams for compute and data transfer, so + // we must synchronize on the stream used by Dr.Jit (producer) before + // proceeding with TF. Unfortunately, we do not have access to TF's streams, + // so we cannot use a lightweight stream-to-stream synchronization. + return tf.attr("from_dlpack")(dlpack(h, /* force_cpu */ false, + /* stream */ nb::int_(0))); }, doc_tf); } diff --git a/src/python/docstr.rst b/src/python/docstr.rst index 3b4cfaf2..e9fbe4fd 100644 --- a/src/python/docstr.rst +++ b/src/python/docstr.rst @@ -7144,6 +7144,16 @@ then you have found a bug. Please report it on the project's `GitHub issue tracker `__. + +.. topic:: sync_device + + Wait for all computation on the current device to finish. + +.. topic:: sync_all_devices + + Wait for all computation on *all devices* to finish. + + .. topic:: flush_malloc_cache Free the memory allocation cache maintained by Dr.Jit.
diff --git a/src/python/main.cpp b/src/python/main.cpp index 99712f3e..103e0fb5 100644 --- a/src/python/main.cpp +++ b/src/python/main.cpp @@ -171,6 +171,8 @@ NB_MODULE(_drjit_ext, m_) { m.def("has_backend", &jit_has_backend, doc_has_backend); m.def("sync_thread", &jit_sync_thread, doc_sync_thread) + .def("sync_device", &jit_sync_device, doc_sync_device) + .def("sync_all_devices", &jit_sync_all_devices, doc_sync_all_devices) .def("flush_kernel_cache", &jit_flush_kernel_cache, doc_flush_kernel_cache) .def("flush_malloc_cache", &jit_flush_malloc_cache, doc_flush_malloc_cache) .def("malloc_clear_statistics", &jit_malloc_clear_statistics)