Skip to content

Commit bbf7a43

Browse files
authored
Update semantics of to_dlpack. (#2707)
to_dlpack now takes ownership of the original buffer, leaving it in an invalid state.
1 parent d6ab70c commit bbf7a43

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

jax/dlpack.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
def to_dlpack(x: xla.DeviceArray):
2222
"""Returns a DLPack tensor that encapsulates a DeviceArray `x`.
2323
24-
The DLPack shares memory with `x`.
24+
Takes ownership of the contents of `x`; leaves `x` in an invalid/deleted
25+
state.
2526
2627
Args:
2728
x: a `DeviceArray`, on either CPU or GPU.

tests/array_interoperability_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def testJaxRoundTrip(self, shape, dtype):
6969
x = jnp.array(np)
7070
dlpack = jax.dlpack.to_dlpack(x)
7171
y = jax.dlpack.from_dlpack(dlpack)
72-
self.assertAllClose(x, y, check_dtypes=True)
72+
self.assertAllClose(np, y, check_dtypes=True)
7373

7474
self.assertRaisesRegex(RuntimeError,
7575
"DLPack tensor may be consumed at most once",

0 commit comments

Comments
 (0)