Skip to content

Commit 1d5fa54

Browse files
committed
ENH: searchorted: support int | float scalar x2
1 parent d6d7966 commit 1d5fa54

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

array_api_strict/_searching_functions.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ._array_object import Array
66
from ._dtypes import _real_numeric_dtypes, _result_type
77
from ._dtypes import bool as _bool
8-
from ._flags import requires_api_version, requires_data_dependent_shapes
8+
from ._flags import requires_api_version, requires_data_dependent_shapes, get_array_api_strict_flags
99
from ._helpers import _maybe_normalize_py_scalars
1010

1111

@@ -64,7 +64,7 @@ def count_nonzero(
6464
@requires_api_version('2023.12')
6565
def searchsorted(
6666
x1: Array,
67-
x2: Array,
67+
x2: Array | int | float,
6868
/,
6969
*,
7070
side: Literal["left", "right"] = "left",
@@ -75,6 +75,12 @@ def searchsorted(
7575
7676
See its docstring for more information.
7777
"""
78+
flags = get_array_api_strict_flags()
79+
if flags["api_version"] >= "2025.12":
80+
81+
if isinstance(x2, bool | int | float | complex):
82+
x2 = x1._promote_scalar(x2)
83+
7884
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
7985
raise TypeError("Only real numeric dtypes are allowed in searchsorted")
8086

0 commit comments

Comments
 (0)