Skip to content

Commit d90d81f

Browse files
authored
Merge pull request #267 from ev-br/fix_count_nonzero
BUG: fix count_nonzero
2 parents fc8777f + ea068a0 commit d90d81f

File tree

4 files changed

+54
-5
lines changed

4 files changed

+54
-5
lines changed

array_api_compat/cupy/_aliases.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,20 @@ def astype(
125125
return out.copy() if copy and out is x else out
126126

127127

128+
# cupy.count_nonzero does not have keepdims
129+
def count_nonzero(
130+
x: ndarray,
131+
axis=None,
132+
keepdims=False
133+
) -> ndarray:
134+
result = cp.count_nonzero(x, axis)
135+
if keepdims:
136+
if axis is None:
137+
return cp.reshape(result, [1]*x.ndim)
138+
return cp.expand_dims(result, axis)
139+
return result
140+
141+
128142
# These functions are completely new here. If the library already has them
129143
# (i.e., numpy 2.0), use the library version instead of our wrapper.
130144
if hasattr(cp, 'vecdot'):
@@ -146,6 +160,6 @@ def astype(
146160
'acos', 'acosh', 'asin', 'asinh', 'atan',
147161
'atan2', 'atanh', 'bitwise_left_shift',
148162
'bitwise_invert', 'bitwise_right_shift',
149-
'bool', 'concat', 'pow', 'sign']
163+
'bool', 'concat', 'count_nonzero', 'pow', 'sign']
150164

151165
_all_ignore = ['cp', 'get_xp']

array_api_compat/dask/array/_aliases.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,21 @@ def argsort(
335335
return restore(x)
336336

337337

338+
# dask.array.count_nonzero does not have keepdims
339+
def count_nonzero(
340+
x: Array,
341+
axis=None,
342+
keepdims=False
343+
) -> Array:
344+
result = da.count_nonzero(x, axis)
345+
if keepdims:
346+
if axis is None:
347+
return da.reshape(result, [1]*x.ndim)
348+
return da.expand_dims(result, axis)
349+
return result
350+
351+
352+
338353
__all__ = _aliases.__all__ + [
339354
'__array_namespace_info__', 'asarray', 'astype', 'acos',
340355
'acosh', 'asin', 'asinh', 'atan', 'atan2',
@@ -343,6 +358,6 @@ def argsort(
343358
'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
344359
'uint8', 'uint16', 'uint32', 'uint64',
345360
'complex64', 'complex128', 'iinfo', 'finfo',
346-
'can_cast', 'result_type']
361+
'can_cast', 'count_nonzero', 'result_type']
347362

348363
_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"]

array_api_compat/numpy/_aliases.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,19 @@ def astype(
127127
return x.astype(dtype=dtype, copy=copy)
128128

129129

130+
# count_nonzero returns a python int for axis=None and keepdims=False
131+
# https://github.com/numpy/numpy/issues/17562
132+
def count_nonzero(
133+
x : ndarray,
134+
axis=None,
135+
keepdims=False
136+
) -> ndarray:
137+
result = np.count_nonzero(x, axis=axis, keepdims=keepdims)
138+
if axis is None and not keepdims:
139+
return np.asarray(result)
140+
return result
141+
142+
130143
# These functions are completely new here. If the library already has them
131144
# (i.e., numpy 2.0), use the library version instead of our wrapper.
132145
if hasattr(np, 'vecdot'):
@@ -148,6 +161,6 @@ def astype(
148161
'acos', 'acosh', 'asin', 'asinh', 'atan',
149162
'atan2', 'atanh', 'bitwise_left_shift',
150163
'bitwise_invert', 'bitwise_right_shift',
151-
'bool', 'concat', 'pow']
164+
'bool', 'concat', 'count_nonzero', 'pow']
152165

153166
_all_ignore = ['np', 'get_xp']

array_api_compat/torch/_aliases.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -521,15 +521,22 @@ def diff(
521521
return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
522522

523523

524-
# torch uses `dim` instead of `axis`
524+
# torch uses `dim` instead of `axis`, does not have keepdims
525525
def count_nonzero(
526526
x: array,
527527
/,
528528
*,
529529
axis: Optional[Union[int, Tuple[int, ...]]] = None,
530530
keepdims: bool = False,
531531
) -> array:
532-
return torch.count_nonzero(x, dim=axis, keepdims=keepdims)
532+
result = torch.count_nonzero(x, dim=axis)
533+
if keepdims:
534+
if axis is not None:
535+
return result.unsqueeze(axis)
536+
return _axis_none_keepdims(result, x.ndim, keepdims)
537+
else:
538+
return result
539+
533540

534541

535542
def where(condition: array, x1: array, x2: array, /) -> array:

0 commit comments

Comments
 (0)