diff --git a/sparse/_compressed/compressed.py b/sparse/_compressed/compressed.py
index 7bb21014..37c54840 100644
--- a/sparse/_compressed/compressed.py
+++ b/sparse/_compressed/compressed.py
@@ -16,6 +16,7 @@
     can_store,
     check_zero_fill_value,
     check_compressed_axes,
+    _zero_of_dtype,
     equivalent,
 )
 from .._coo.core import COO
@@ -143,7 +144,7 @@ def __init__(
         shape=None,
         compressed_axes=None,
         prune=False,
-        fill_value=0,
+        fill_value=None,
         idx_dtype=None,
     ):
         if isinstance(arg, ss.spmatrix):
@@ -169,6 +170,10 @@
                 arg.fill_value,
             )
 
+        self.data, self.indices, self.indptr = arg
+
+        if fill_value is None:
+            fill_value = _zero_of_dtype(self.data.dtype)
 
         if shape is None:
             raise ValueError("missing `shape` argument")
@@ -177,8 +182,6 @@
         if len(shape) == 1:
             compressed_axes = None
 
-        self.data, self.indices, self.indptr = arg
-
         if self.data.ndim != 1:
             raise ValueError("data must be a scalar or 1-dimensional.")
@@ -845,7 +848,12 @@ def _prune(self):
 
 class _Compressed2d(GCXS):
     def __init__(
-        self, arg, shape=None, compressed_axes=None, prune=False, fill_value=0
+        self,
+        arg,
+        shape=None,
+        compressed_axes=None,
+        prune=False,
+        fill_value=None,
     ):
         if not hasattr(arg, "shape") and shape is None:
             raise ValueError("missing `shape` argument")
@@ -888,7 +896,7 @@ class CSR(_Compressed2d):
     Sparse supports 2-D CSR.
     """
 
-    def __init__(self, arg, shape=None, prune=False, fill_value=0):
+    def __init__(self, arg, shape=None, prune=False, fill_value=None):
         super().__init__(arg, shape=shape, compressed_axes=(0,), fill_value=fill_value)
 
     @classmethod
@@ -913,7 +921,7 @@ class CSC(_Compressed2d):
     Sparse supports 2-D CSC.
    """
 
-    def __init__(self, arg, shape=None, prune=False, fill_value=0):
+    def __init__(self, arg, shape=None, prune=False, fill_value=None):
         super().__init__(arg, shape=shape, compressed_axes=(1,), fill_value=fill_value)
 
     @classmethod
diff --git a/sparse/_compressed/elemwise.py b/sparse/_compressed/elemwise.py
new file mode 100644
index 00000000..75c81afa
--- /dev/null
+++ b/sparse/_compressed/elemwise.py
@@ -0,0 +1,155 @@
+from functools import lru_cache
+from typing import Callable
+
+import numpy as np
+import scipy.sparse
+from numba import njit
+
+from .compressed import _Compressed2d
+
+
+def op_unary(func, a):
+    res = a.copy()
+    res.data = func(a.data)
+    return res
+
+
+@lru_cache(maxsize=None)
+def _numba_d(func):
+    return njit(lambda *x: func(*x))
+
+
+def binary_op(func, a, b):
+    func = _numba_d(func)
+    if isinstance(a, _Compressed2d) and isinstance(b, _Compressed2d):
+        return op_union_indices(func, a, b)
+    else:
+        raise NotImplementedError()
+
+
+# From scipy._util
+def _prune_array(array):
+    """Return an array equivalent to the input array. If the input
+    array is a view of a much larger array, copy its contents to a
+    newly allocated array. Otherwise, return the input unchanged.
+    """
+    if array.base is not None and array.size < array.base.size // 2:
+        return array.copy()
+    return array
+
+
+def op_union_indices(
+    op: Callable, a: scipy.sparse.csr_matrix, b: scipy.sparse.csr_matrix, *, default_value=0
+):
+    assert a.shape == b.shape
+
+    if type(a) != type(b):
+        b = type(a)(b)
+    # a.sort_indices()
+    # b.sort_indices()
+
+    # TODO: numpy is weird with bools here
+    out_dtype = np.array(op(a.data[0], b.data[0])).dtype
+    default_value = out_dtype.type(default_value)
+    out_indptr = np.zeros_like(a.indptr)
+    out_indices = np.zeros(len(a.indices) + len(b.indices), dtype=np.promote_types(a.indices.dtype, b.indices.dtype))
+    out_data = np.zeros(len(out_indices), dtype=out_dtype)
+
+    nnz = op_union_indices_csr_csr(
+        op,
+        a.indptr,
+        a.indices,
+        a.data,
+        b.indptr,
+        b.indices,
+        b.data,
+        out_indptr,
+        out_indices,
+        out_data,
+        out_dtype=out_dtype,
+        default_value=default_value,
+    )
+    out_data = _prune_array(out_data[:nnz])
+    out_indices = _prune_array(out_indices[:nnz])
+    return type(a)((out_data, out_indices, out_indptr), shape=a.shape)
+
+
+@njit
+def op_union_indices_csr_csr(
+    op: Callable,
+    a_indptr: np.ndarray,
+    a_indices: np.ndarray,
+    a_data: np.ndarray,
+    b_indptr: np.ndarray,
+    b_indices: np.ndarray,
+    b_data: np.ndarray,
+    out_indptr: np.ndarray,
+    out_indices: np.ndarray,
+    out_data: np.ndarray,
+    out_dtype,
+    default_value,
+):
+    # out_indptr = np.zeros_like(a_indptr)
+    # out_indices = np.zeros(len(a_indices) + len(b_indices), dtype=a_indices.dtype)
+    # out_data = np.zeros(len(out_indices), dtype=out_dtype)
+
+    out_idx = 0
+
+    for i in range(len(a_indptr) - 1):
+
+        a_idx = a_indptr[i]
+        a_end = a_indptr[i + 1]
+        b_idx = b_indptr[i]
+        b_end = b_indptr[i + 1]
+
+        while (a_idx < a_end) and (b_idx < b_end):
+            a_j = a_indices[a_idx]
+            b_j = b_indices[b_idx]
+            if a_j < b_j:
+                val = op(a_data[a_idx], default_value)
+                if val != default_value:
+                    out_indices[out_idx] = a_j
+                    out_data[out_idx] = val
+                    out_idx += 1
+                a_idx += 1
+            elif b_j < a_j:
+                val = op(default_value, b_data[b_idx])
+                if val != default_value:
+                    out_indices[out_idx] = b_j
+                    out_data[out_idx] = val
+                    out_idx += 1
+                b_idx += 1
+            else:
+                val = op(a_data[a_idx], b_data[b_idx])
+                if val != default_value:
+                    out_indices[out_idx] = a_j
+                    out_data[out_idx] = val
+                    out_idx += 1
+                a_idx += 1
+                b_idx += 1
+
+        # Catch up the other set
+        while a_idx < a_end:
+            val = op(a_data[a_idx], default_value)
+            if val != default_value:
+                out_indices[out_idx] = a_indices[a_idx]
+                out_data[out_idx] = val
+                out_idx += 1
+            a_idx += 1
+
+        while b_idx < b_end:
+            val = op(default_value, b_data[b_idx])
+            if val != default_value:
+                out_indices[out_idx] = b_indices[b_idx]
+                out_data[out_idx] = val
+                out_idx += 1
+            b_idx += 1
+
+        out_indptr[i + 1] = out_idx
+
+    # This may need to change to be "resize" to allow memory reallocation
+    # resize is currently not implemented in numba
+    out_indices = out_indices[: out_idx]
+    out_data = out_data[: out_idx]
+
+    return out_idx
\ No newline at end of file
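Illustrative sketch (not part of the patch): binary_op above is the same entry point used by _Elemwise._get_result_compressed_2d in the _umath.py changes below. Because the kernel walks the union of the two index sets, operators where op(x, 0) or op(0, x) is nonzero (for example subtraction) are still handled, and entries equal to default_value are dropped from the output. A quick check against dense arrays, assuming CSR construction from a COO input as exercised in the tests below:

    import numpy as np
    import sparse
    from sparse._compressed import CSR
    from sparse._compressed.elemwise import binary_op

    a = CSR(sparse.random((6, 8), density=0.4))
    b = CSR(sparse.random((6, 8), density=0.4))

    # np.subtract exercises the "only in a" / "only in b" branches of the merge.
    out = binary_op(np.subtract, a, b)
    assert isinstance(out, CSR)
    np.testing.assert_allclose(out.todense(), a.todense() - b.todense())
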
diff --git a/sparse/_umath.py b/sparse/_umath.py
index 016d4998..93cc431a 100644
--- a/sparse/_umath.py
+++ b/sparse/_umath.py
@@ -407,6 +407,48 @@ def broadcast_to(x, shape):
     )
 
 
+# TODO: Figure out the right way to type this
+# TODO: Figure out how to do 1d COO + CSR or CSC
+def _resolve_result_type(args: "list[ArrayLike]") -> "Type":
+    from ._compressed import GCXS, CSR, CSC
+    from ._coo import COO
+    from ._dok import DOK
+    from ._sparse_array import SparseArray
+    from ._compressed.compressed import _Compressed2d
+
+    args = [arg for arg in args if isinstance(arg, SparseArray)]
+
+    if all(isinstance(arg, DOK) for arg in args):
+        out_type = DOK
+    elif all(isinstance(arg, CSR) for arg in args):
+        out_type = CSR
+    elif all(isinstance(arg, CSC) for arg in args):
+        out_type = CSC
+    elif all(isinstance(arg, _Compressed2d) for arg in args):
+        out_type = CSR
+    elif all(isinstance(arg, GCXS) for arg in args):
+        out_type = GCXS
+    else:
+        out_type = COO
+    return out_type
+
+
+def _from_scipy_sparse(a):
+    from ._compressed import CSR, CSC
+    from ._coo import COO
+    from ._dok import DOK
+
+    assert isinstance(a, scipy.sparse.spmatrix)
+    if isinstance(a, scipy.sparse.csr_matrix):
+        return CSR(a)
+    elif isinstance(a, scipy.sparse.csc_matrix):
+        return CSC(a)
+    elif isinstance(a, scipy.sparse.dok_matrix):
+        return DOK(a.shape, data=dict(a))
+    else:
+        return COO(a)
+
+
 class _Elemwise:
     def __init__(self, func, *args, **kwargs):
         """
@@ -423,24 +465,26 @@ def __init__(self, func, *args, **kwargs):
         """
         from ._coo import COO
         from ._sparse_array import SparseArray
-        from ._compressed import GCXS
+        from ._compressed import GCXS, CSR, CSC
+        from ._compressed.compressed import _Compressed2d
         from ._dok import DOK
 
-        processed_args = []
-        out_type = GCXS
-
-        sparse_args = [arg for arg in args if isinstance(arg, SparseArray)]
+        args = [
+            arg
+            if not isinstance(arg, scipy.sparse.spmatrix)
+            else _from_scipy_sparse(arg)
+            for arg in args
+        ]
 
-        if all(isinstance(arg, DOK) for arg in sparse_args):
-            out_type = DOK
-        elif all(isinstance(arg, GCXS) for arg in sparse_args):
-            out_type = GCXS
-        else:
-            out_type = COO
+        processed_args = []
+        self.out_type = _resolve_result_type(args)
 
+        # Should this happen before dispatch?
+        # Hmm, this may need major major changes.
+        # Case to consider: CSR or CSC + 1d COO
         for arg in args:
-            if isinstance(arg, scipy.sparse.spmatrix):
-                processed_args.append(COO.from_scipy_sparse(arg))
+            if self.out_type != COO and isinstance(arg, _Compressed2d):
+                processed_args.append(arg)
             elif isscalar(arg) or isinstance(arg, np.ndarray):
                 # Faster and more reliable to pass ()-shaped ndarrays as scalars.
                 processed_args.append(np.asarray(arg))
@@ -454,7 +498,6 @@ def __init__(self, func, *args, **kwargs):
                 self.args = None
                 return
 
-        self.out_type = out_type
         self.args = tuple(processed_args)
         self.func = func
         self.dtype = kwargs.pop("dtype", None)
@@ -467,14 +510,19 @@
 
     def get_result(self):
         from ._coo import COO
+        from ._sparse_array import SparseArray
+        from ._compressed.compressed import _Compressed2d
 
         if self.args is None:
             return NotImplemented
 
         if self._dense_result:
-            args = [a.todense() if isinstance(a, COO) else a for a in self.args]
+            args = [a.todense() if isinstance(a, SparseArray) else a for a in self.args]
             return self.func(*args, **self.kwargs)
 
+        if issubclass(self.out_type, _Compressed2d):
+            return self._get_result_compressed_2d()
+
         if any(s == 0 for s in self.shape):
             data = np.empty((0,), dtype=self.fill_value.dtype)
             coords = np.empty((0, len(self.shape)), dtype=np.intp)
@@ -521,6 +569,29 @@ def get_result(self):
             fill_value=self.fill_value,
         ).asformat(self.out_type)
 
+    def _get_result_compressed_2d(self):
+        from ._compressed import elemwise as elemwise2d
+        from ._compressed.compressed import _Compressed2d
+
+        if len(self.args) == 1:
+            result = elemwise2d.op_unary(self.func, self.args[0])
+
+        processed_args = []
+        for arg in self.args:
+            if isinstance(arg, self.out_type):
+                processed_args.append(arg)
+            elif isinstance(arg, _Compressed2d):
+                processed_args.append(self.out_type(arg))
+            elif isinstance(arg, np.ndarray):
+                processed_args.append(np.broadcast_to(arg, self.shape))
+            else:
+                raise NotImplementedError()
+
+        if len(processed_args) == 2:
+            result = elemwise2d.binary_op(self.func, *processed_args)
+
+        return result
+
     def _get_fill_value(self):
         """
         A function that finds and returns the fill-value.
@@ -530,10 +601,11 @@ def _get_fill_value(self):
         ValueError
             If the fill-value is inconsistent.
         """
-        from ._coo import COO
+        from ._sparse_array import SparseArray
 
         zero_args = tuple(
-            arg.fill_value[...] if isinstance(arg, COO) else arg for arg in self.args
+            arg.fill_value[...] if isinstance(arg, SparseArray) else arg
+            for arg in self.args
         )
 
         # Some elemwise functions require a dtype argument, some abhorr it.
@@ -550,7 +622,9 @@
             fill_value = fill_value_array[(0,) * fill_value_array.ndim]
         except IndexError:
             zero_args = tuple(
-                arg.fill_value if isinstance(arg, COO) else _zero_of_dtype(arg.dtype)
+                arg.fill_value
+                if isinstance(arg, SparseArray)
+                else _zero_of_dtype(arg.dtype)
                 for arg in self.args
             )
             fill_value = self.func(*zero_args, **self.kwargs)[()]
diff --git a/sparse/tests/test_elemwise.py b/sparse/tests/test_elemwise.py
index 86cd6a05..0d66574a 100644
--- a/sparse/tests/test_elemwise.py
+++ b/sparse/tests/test_elemwise.py
@@ -3,8 +3,9 @@
 import pytest
 import operator
 from sparse import COO, DOK
-from sparse._compressed import GCXS
+from sparse._compressed import GCXS, CSR, CSC
 from sparse._utils import assert_eq, random_value_array
+from sparse import SparseArray
 
 
 @pytest.mark.parametrize(
@@ -481,6 +482,66 @@ def test_leftside_elemwise_scalar(func, scalar, convert_to_np_number):
     assert_eq(fs, func(y, x))
 
 
+from itertools import product
+from functools import singledispatch
+
+
+@singledispatch
+def asdense(x):
+    raise NotImplementedError()
+
+
+@asdense.register(SparseArray)
+def _(x):
+    return x.todense()
+
+
+@asdense.register(np.ndarray)
+def _(x):
+    return x
+
+
+# TODO: Add test for result types
+@pytest.mark.parametrize("func", [np.add, np.subtract, np.multiply])
+@pytest.mark.parametrize(
+    "a,b",
+    # TODO: would be nice if the tests would take names from these parameters
+    list(
+        product(
+            [
+                # pytest.param(CSR(sparse.random((20, 10), density=0.5)), id="CSR"),
+                # pytest.param(CSC(sparse.random((20, 10), density=0.5)), id="CSC"),
+                # pytest.param(sparse.random((20, 10), density=0.5).todense(), id="COO"),
+                # pytest.param(sparse.random((20, 10), density=0.5, format=COO), id="dense-2d"),
+                # pytest.param(np.random.rand(20), id="dense-row"),
+                # pytest.param(np.random.rand(1, 10), id="dense-col"),
+                CSR(sparse.random((20, 10), density=0.5)),
+                CSC(sparse.random((20, 10), density=0.5)),
+                sparse.random((20, 10), density=0.5).todense(),
+                sparse.random((20, 10), density=0.5, format=COO),
+                np.random.rand(10),
+                np.random.rand(20, 1),
+            ],
+            repeat=2,
+        )
+    ),
+)
+def test_2d_binary_op(func, a, b):
+    # TODO: implement COO.asformat(CSR)
+    def _is_ndarray_1d(x):
+        return isinstance(x, np.ndarray) and sum(s != 1 for s in x.shape) <= 1
+
+    from sparse import SparseArray
+
+    if func in [np.add, np.subtract] and (_is_ndarray_1d(a) or _is_ndarray_1d(b)):
+        # https://github.com/pydata/sparse/issues/460
+        pytest.skip()
+
+    ref_a = a.todense() if isinstance(a, SparseArray) else a
+    ref_b = b.todense() if isinstance(b, SparseArray) else b
+
+    expected = func(ref_a, ref_b)
+    result = func(a, b)
+
+    assert_eq(expected, asdense(result))
+
+
 @pytest.mark.parametrize(
     "func, scalar",
     [
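Illustrative usage sketch (not part of the patch): with the _umath.py changes above, binary ufuncs whose sparse inputs are all 2-D compressed arrays should stay in a compressed format instead of round-tripping through COO, with the output class chosen by _resolve_result_type. Dispatch through the ufunc protocol is the pre-existing sparse machinery and is assumed here; the result-type assertions follow _resolve_result_type and are not yet covered by the test above (see its TODO).

    import numpy as np
    import sparse
    from sparse._compressed import CSR, CSC

    a = CSR(sparse.random((20, 10), density=0.5))
    b = CSR(sparse.random((20, 10), density=0.5))
    c = CSC(sparse.random((20, 10), density=0.5))

    assert isinstance(np.multiply(a, b), CSR)  # all-CSR inputs stay CSR
    assert isinstance(np.add(c, c), CSC)       # all-CSC inputs stay CSC
    assert isinstance(np.add(a, c), CSR)       # mixed 2-D compressed resolves to CSR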