Skip to content

Commit 194db3f

Browse files
Correct return of object type at zero copy (#1571)
* Correct return of object type at zero copy in dpnp.asarray() * Add tests for gh-1570
1 parent 747ef6c commit 194db3f

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

dpnp/dpnp_container.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def asarray(
130130
)
131131

132132
# return x1 if dpctl returns a zero copy of x1_obj
133-
if array_obj is x1_obj:
133+
if array_obj is x1_obj and isinstance(x1, dpnp_array):
134134
return x1
135135

136136
return dpnp_array(array_obj.shape, buffer=array_obj, order=order)

tests/test_sycl_queue.py

+19
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import dpctl
2+
import dpctl.tensor as dpt
23
import numpy
34
import pytest
45
from dpctl.utils import ExecutionPlacementError
56
from numpy.testing import assert_allclose, assert_array_equal, assert_raises
67

78
import dpnp
9+
from dpnp.dpnp_array import dpnp_array
810

911
from .helper import assert_dtype_allclose, get_all_dtypes, is_win_platform
1012

@@ -1076,6 +1078,23 @@ def test_array_copy(device, func, device_param, queue_param):
10761078
assert_sycl_queue_equal(result.sycl_queue, dpnp_data.sycl_queue)
10771079

10781080

1081+
@pytest.mark.parametrize(
1082+
"copy", [True, False, None], ids=["True", "False", "None"]
1083+
)
1084+
@pytest.mark.parametrize(
1085+
"device",
1086+
valid_devices,
1087+
ids=[device.filter_string for device in valid_devices],
1088+
)
1089+
def test_array_creation_from_dpctl(copy, device):
1090+
dpt_data = dpt.ones((3, 3), device=device)
1091+
1092+
result = dpnp.array(dpt_data, copy=copy)
1093+
1094+
assert_sycl_queue_equal(result.sycl_queue, dpt_data.sycl_queue)
1095+
assert isinstance(result, dpnp_array)
1096+
1097+
10791098
@pytest.mark.parametrize(
10801099
"device",
10811100
valid_devices,

tests/test_usm_type.py

+12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from math import prod
22

3+
import dpctl.tensor as dpt
34
import dpctl.utils as du
45
import pytest
56

@@ -180,6 +181,17 @@ def test_array_copy(func, usm_type_x, usm_type_y):
180181
assert y.usm_type == usm_type_y
181182

182183

184+
@pytest.mark.parametrize(
185+
"copy", [True, False, None], ids=["True", "False", "None"]
186+
)
187+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
188+
def test_array_creation_from_dpctl(copy, usm_type_x):
189+
x = dpt.ones((3, 3), usm_type=usm_type_x)
190+
y = dp.array(x, copy=copy)
191+
192+
assert y.usm_type == usm_type_x
193+
194+
183195
@pytest.mark.parametrize(
184196
"usm_type_start", list_of_usm_types, ids=list_of_usm_types
185197
)

0 commit comments

Comments
 (0)