diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 452f4668..b011f08d 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -492,9 +492,7 @@ def your_function(x, y): namespaces = set() for x in xs: - if isinstance(x, (bool, int, float, complex, type(None))): - continue - elif is_numpy_array(x): + if is_numpy_array(x): from .. import numpy as numpy_namespace import numpy as np if use_compat is True: @@ -558,6 +556,8 @@ def your_function(x, y): if use_compat is True: raise ValueError("The given array does not have an array-api-compat wrapper") namespaces.add(x.__array_namespace__(api_version=api_version)) + elif isinstance(x, (bool, int, float, complex, type(None))): + continue else: # TODO: Support Python scalars? raise TypeError(f"{type(x).__name__} is not a supported array type")