@@ -118,13 +118,13 @@ static nb::ndarray<> dlpack(nb::handle_t<ArrayBase> h, bool force_cpu, nb::handl
118
118
// https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
119
119
/*
120
120
stream = -1 request producer to perform no synchronization
121
- stream = 0 is ambiguous
121
+ stream = 0 is ambiguous (could mean either None, 1, or 2)
122
122
stream = 1 or None is the legacy default stream
123
123
stream = 2 is the per-thread default stream
124
124
stream > 2 is a CUDA handle to the consumer's stream
125
125
*/
126
126
if (!stream.is_none () && !stream.equal (nb::int_ (-1 )) && !stream.equal (nb::int_ (1 ))) {
127
- if (stream.equal (nb::int_ (0 )))
127
+ if (stream.equal (nb::int_ (0 )) || stream. equal ( nb::int_ ( 2 )) )
128
128
jit_sync_thread ();
129
129
else {
130
130
uintptr_t stream_handle;
@@ -133,6 +133,7 @@ static nb::ndarray<> dlpack(nb::handle_t<ArrayBase> h, bool force_cpu, nb::handl
133
133
jit_cuda_sync_stream (stream_handle);
134
134
}
135
135
}
136
+
136
137
} else {
137
138
jit_sync_thread ();
138
139
}
@@ -265,6 +266,9 @@ void export_dlpack(nb::module_ &) {
265
266
.def (" tf" ,
266
267
[](nb::handle_t <ArrayBase> h) {
267
268
nb::module_ tf = nb::module_::import_ (" tensorflow.experimental.dlpack" );
268
- return tf.attr (" from_dlpack" )(dlpack (h, false ));
269
+ // TensorFlow uses non-default streams for compute and data transfer, so
270
+ // we must synchronize on the stream used by DrJit (producer) before
271
+ // proceeding with TF.
272
+ return tf.attr (" from_dlpack" )(dlpack (h, /* force_cpu */ false , /* stream */ nb::int_ (2 )));
269
273
}, doc_tf);
270
274
}
0 commit comments