File tree 4 files changed +75
-70
lines changed
4 files changed +75
-70
lines changed Original file line number Diff line number Diff line change @@ -92,7 +92,7 @@ from mylib.vendored.array_api_compat import array_namespace as _array_namespace_
92
92
def array_namespace (* xs , ** kwargs ):
93
93
from mylib import MyArray
94
94
95
- if any (isinstance (x, MyArray) for x in xs:
95
+ if any (isinstance (x, MyArray) for x in xs) :
96
96
...
97
97
else :
98
98
return _array_namespace_orig(* xs, ** kwargs)
Original file line number Diff line number Diff line change @@ -17,7 +17,7 @@ def in1d(
17
17
* ,
18
18
assume_unique : bool = False ,
19
19
invert : bool = False ,
20
- xp : ModuleType ,
20
+ xp : ModuleType | None = None ,
21
21
) -> Array :
22
22
"""Checks whether each element of an array is also present in a
23
23
second array.
@@ -29,6 +29,8 @@ def in1d(
29
29
present in numpy:
30
30
https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
31
31
"""
32
+ if xp is None :
33
+ xp = _compat .array_namespace (x1 , x2 )
32
34
33
35
# This code is run to make the code significantly faster
34
36
if x2 .shape [0 ] < 10 * x1 .shape [0 ] ** 0.145 :
@@ -71,11 +73,14 @@ def mean(
71
73
* ,
72
74
axis : int | tuple [int , ...] | None = None ,
73
75
keepdims : bool = False ,
74
- xp : ModuleType ,
76
+ xp : ModuleType | None = None ,
75
77
) -> Array :
76
78
"""
77
79
Complex mean, https://github.com/data-apis/array-api/issues/846.
78
80
"""
81
+ if xp is None :
82
+ xp = _compat .array_namespace (x )
83
+
79
84
if xp .isdtype (x .dtype , "complex floating" ):
80
85
x_real = xp .real (x )
81
86
x_imag = xp .imag (x )
You can’t perform that action at this time.
0 commit comments