Skip to content

Commit f37e46e

Browse files
committed
DLPack: sync CUDA stream when exporting to TF
1 parent a54f606 commit f37e46e

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

src/python/dlpack.cpp

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -118,13 +118,13 @@ static nb::ndarray<> dlpack(nb::handle_t<ArrayBase> h, bool force_cpu, nb::handl
118118
// https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
119119
/*
120120
stream = -1 request producer to perform no synchronization
121-
stream = 0 is ambiguous
121+
stream = 0 is ambiguous (could mean either None, 1, or 2)
122122
stream = 1 or None is the legacy default stream
123123
stream = 2 is the per-thread default stream
124124
stream > 2 is a CUDA handle to the consumer's stream
125125
*/
126126
if (!stream.is_none() && !stream.equal(nb::int_(-1)) && !stream.equal(nb::int_(1))) {
127-
if (stream.equal(nb::int_(0)))
127+
if (stream.equal(nb::int_(0)) || stream.equal(nb::int_(2)))
128128
jit_sync_thread();
129129
else {
130130
uintptr_t stream_handle;
@@ -133,6 +133,7 @@ static nb::ndarray<> dlpack(nb::handle_t<ArrayBase> h, bool force_cpu, nb::handl
133133
jit_cuda_sync_stream(stream_handle);
134134
}
135135
}
136+
136137
} else {
137138
jit_sync_thread();
138139
}
@@ -265,6 +266,9 @@ void export_dlpack(nb::module_ &) {
265266
.def("tf",
266267
[](nb::handle_t<ArrayBase> h) {
267268
nb::module_ tf = nb::module_::import_("tensorflow.experimental.dlpack");
268-
return tf.attr("from_dlpack")(dlpack(h, false));
269+
// TensorFlow uses non-default streams for compute and data transfer, so
270+
// we must synchronize on the stream used by DrJit (producer) before
271+
// proceeding with TF.
272+
return tf.attr("from_dlpack")(dlpack(h, /* force_cpu */ false, /* stream */ nb::int_(2)));
269273
}, doc_tf);
270274
}

0 commit comments

Comments
 (0)