@@ -289,5 +289,55 @@ def test_reshape(rng, dtype):
289
289
np .testing .assert_array_equal (actual , expected )
290
290
291
291
# DENSE
292
- # NOTE: dense reshape is probably broken in MLIR
292
+ # NOTE: dense reshape is probably broken in MLIR in 19.x branch
293
293
# dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)
294
+
295
+
296
+ @parametrize_dtypes
297
+ def test_broadcast_to (dtype ):
298
+ # CSR, CSC, COO
299
+ for shape , new_shape , dimensions , input_arr , expected_arrs in [
300
+ (
301
+ (3 , 4 ),
302
+ (2 , 3 , 4 ),
303
+ [0 ],
304
+ np .array ([[0 , 1 , 0 , 3 ], [0 , 0 , 4 , 5 ], [6 , 7 , 0 , 0 ]]),
305
+ [
306
+ np .array ([0 , 3 , 6 ]),
307
+ np .array ([0 , 1 , 2 , 0 , 1 , 2 ]),
308
+ np .array ([0 , 2 , 4 , 6 , 8 , 10 , 12 ]),
309
+ np .array ([1 , 3 , 2 , 3 , 0 , 1 , 1 , 3 , 2 , 3 , 0 , 1 ]),
310
+ np .array ([1.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 , 1.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 ]),
311
+ ],
312
+ ),
313
+ (
314
+ (4 , 2 ),
315
+ (4 , 2 , 2 ),
316
+ [1 ],
317
+ np .array ([[0 , 1 ], [0 , 0 ], [2 , 3 ], [4 , 0 ]]),
318
+ [
319
+ np .array ([0 , 2 , 2 , 4 , 6 ]),
320
+ np .array ([0 , 1 , 0 , 1 , 0 , 1 ]),
321
+ np .array ([0 , 1 , 2 , 4 , 6 , 7 , 8 ]),
322
+ np .array ([1 , 1 , 0 , 1 , 0 , 1 , 0 , 0 ]),
323
+ np .array ([1.0 , 1.0 , 2.0 , 3.0 , 2.0 , 3.0 , 4.0 , 4.0 ]),
324
+ ],
325
+ ),
326
+ ]:
327
+ for fn_format in [sps .csr_array , sps .csc_array , sps .coo_array ]:
328
+ arr = fn_format (input_arr , shape = shape , dtype = dtype )
329
+ arr .sum_duplicates ()
330
+ tensor = sparse .asarray (arr )
331
+ result = sparse .broadcast_to (tensor , new_shape , dimensions = dimensions ).to_scipy_sparse ()
332
+
333
+ for actual , expected in zip (result , expected_arrs , strict = False ):
334
+ np .testing .assert_allclose (actual , expected )
335
+
336
+ # DENSE
337
+ np_arr = np .array ([0 , 0 , 2 , 3 , 0 , 1 ])
338
+ arr = np .asarray (np_arr , dtype = dtype )
339
+ tensor = sparse .asarray (arr )
340
+ result = sparse .broadcast_to (tensor , (3 , 6 ), dimensions = [0 ]).to_scipy_sparse ()
341
+
342
+ assert result .format == "csr"
343
+ np .testing .assert_allclose (result .todense (), np .repeat (np_arr [np .newaxis ], 3 , axis = 0 ))
0 commit comments