Skip to content

Commit f052362

Browse files
committed
tests
1 parent d698683 commit f052362

File tree

4 files changed

+75
-70
lines changed

4 files changed

+75
-70
lines changed

docs/index.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ from mylib.vendored.array_api_compat import array_namespace as _array_namespace_
9292
def array_namespace(*xs, **kwargs):
9393
from mylib import MyArray
9494

95-
if any(isinstance(x, MyArray) for x in xs:
95+
if any(isinstance(x, MyArray) for x in xs):
9696
...
9797
else:
9898
return _array_namespace_orig(*xs, **kwargs)

src/array_api_extra/_lib/_utils.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def in1d(
1717
*,
1818
assume_unique: bool = False,
1919
invert: bool = False,
20-
xp: ModuleType,
20+
xp: ModuleType | None = None,
2121
) -> Array:
2222
"""Checks whether each element of an array is also present in a
2323
second array.
@@ -29,6 +29,8 @@ def in1d(
2929
present in numpy:
3030
https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
3131
"""
32+
if xp is None:
33+
xp = _compat.array_namespace(x1, x2)
3234

3335
# This code is run to make the code significantly faster
3436
if x2.shape[0] < 10 * x1.shape[0] ** 0.145:
@@ -71,11 +73,14 @@ def mean(
7173
*,
7274
axis: int | tuple[int, ...] | None = None,
7375
keepdims: bool = False,
74-
xp: ModuleType,
76+
xp: ModuleType | None = None,
7577
) -> Array:
7678
"""
7779
Complex mean, https://github.com/data-apis/array-api/issues/846.
7880
"""
81+
if xp is None:
82+
xp = _compat.array_namespace(x)
83+
7984
if xp.isdtype(x.dtype, "complex floating"):
8085
x_real = xp.real(x)
8186
x_imag = xp.imag(x)

0 commit comments

Comments
 (0)