@@ -195,6 +195,8 @@ class COO(SparseArray, NDArrayOperatorsMixin): # lgtm [py/missing-equals]
195
195
196
196
__array_priority__ = 12
197
197
198
+ __array_members__ = ("data" , "coords" )
199
+
198
200
def __init__ (
199
201
self ,
200
202
coords ,
@@ -207,6 +209,8 @@ def __init__(
207
209
fill_value = None ,
208
210
idx_dtype = None ,
209
211
):
212
+ from .._common import _coerce_to_supported_dense
213
+
210
214
if isinstance (coords , COO ):
211
215
self ._make_shallow_copy_of (coords )
212
216
if data is not None or shape is not None :
@@ -226,8 +230,8 @@ def __init__(
226
230
self .enable_caching ()
227
231
return
228
232
229
- self .data = np . asarray (data )
230
- self .coords = np . asarray (coords )
233
+ self .data = _coerce_to_supported_dense (data )
234
+ self .coords = _coerce_to_supported_dense (coords )
231
235
232
236
if self .coords .ndim == 1 :
233
237
if self .coords .size == 0 and shape is not None :
@@ -236,7 +240,7 @@ def __init__(
236
240
self .coords = self .coords [None , :]
237
241
238
242
if self .data .ndim == 0 :
239
- self .data = np .broadcast_to (self .data , self .coords .shape [1 ])
243
+ self .data = self . _component_namespace .broadcast_to (self .data , self .coords .shape [1 ])
240
244
241
245
if self .data .ndim != 1 :
242
246
raise ValueError ("`data` must be a scalar or 1-dimensional." )
@@ -251,7 +255,9 @@ def __init__(
251
255
shape = tuple (shape )
252
256
253
257
if shape and not self .coords .size :
254
- self .coords = np .zeros ((len (shape ) if isinstance (shape , Iterable ) else 1 , 0 ), dtype = np .intp )
258
+ self .coords = self ._component_namespace .zeros (
259
+ (len (shape ) if isinstance (shape , Iterable ) else 1 , 0 ), dtype = np .intp
260
+ )
255
261
super ().__init__ (shape , fill_value = fill_value )
256
262
if idx_dtype :
257
263
if not can_store (idx_dtype , max (shape )):
@@ -1307,10 +1313,10 @@ def _sort_indices(self):
1307
1313
"""
1308
1314
linear = self .linear_loc ()
1309
1315
1310
- if (np .diff (linear ) >= 0 ).all (): # already sorted
1316
+ if (self . _component_namespace .diff (linear ) >= 0 ).all (): # already sorted
1311
1317
return
1312
1318
1313
- order = np .argsort (linear , kind = "mergesort" )
1319
+ order = self . _component_namespace .argsort (linear , kind = "mergesort" )
1314
1320
self .coords = self .coords [:, order ]
1315
1321
self .data = self .data [order ]
1316
1322
@@ -1336,16 +1342,16 @@ def _sum_duplicates(self):
1336
1342
# Inspired by scipy/sparse/coo.py::sum_duplicates
1337
1343
# See https://github.com/scipy/scipy/blob/main/LICENSE.txt
1338
1344
linear = self .linear_loc ()
1339
- unique_mask = np .diff (linear ) != 0
1345
+ unique_mask = self . _component_namespace .diff (linear ) != 0
1340
1346
1341
1347
if unique_mask .sum () == len (unique_mask ): # already unique
1342
1348
return
1343
1349
1344
- unique_mask = np .append (True , unique_mask )
1350
+ unique_mask = self . _component_namespace .append (True , unique_mask )
1345
1351
1346
1352
coords = self .coords [:, unique_mask ]
1347
- (unique_inds ,) = np .nonzero (unique_mask )
1348
- data = np .add .reduceat (self .data , unique_inds , dtype = self .data .dtype )
1353
+ (unique_inds ,) = self . _component_namespace .nonzero (unique_mask )
1354
+ data = self . _component_namespace .add .reduceat (self .data , unique_inds , dtype = self .data .dtype )
1349
1355
1350
1356
self .data = data
1351
1357
self .coords = coords
0 commit comments