From 273d54ea006d9370470b97c5d1707da82b5078a2 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Oct 2024 15:20:39 -0600 Subject: [PATCH] Update __array_api_version__ to 2023.12 --- array_api_compat/common/_helpers.py | 13 +++++++------ array_api_compat/cupy/__init__.py | 2 +- array_api_compat/dask/array/__init__.py | 2 +- array_api_compat/numpy/__init__.py | 2 +- array_api_compat/torch/__init__.py | 2 +- docs/index.md | 4 ++-- tests/test_array_namespace.py | 14 ++++++++++---- 7 files changed, 23 insertions(+), 16 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 2467793c..91056e24 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -317,7 +317,7 @@ def is_torch_namespace(xp) -> bool: is_array_api_strict_namespace """ return xp.__name__ in {'torch', _compat_module_name() + '.torch'} - + def is_ndonnx_namespace(xp): """ @@ -415,10 +415,11 @@ def is_array_api_strict_namespace(xp): return xp.__name__ == 'array_api_strict' def _check_api_version(api_version): - if api_version == '2021.12': - warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") - elif api_version is not None and api_version != '2022.12': - raise ValueError("Only the 2022.12 version of the array API specification is currently supported") + if api_version in ['2021.12', '2022.12']: + warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2023.12") + elif api_version is not None and api_version not in ['2021.12', '2022.12', + '2023.12']: + raise ValueError("Only the 2023.12 version of the array API specification is currently supported") def array_namespace(*xs, api_version=None, use_compat=None): """ @@ -431,7 +432,7 @@ def array_namespace(*xs, api_version=None, use_compat=None): api_version: str The newest version of the spec that you need support for (currently - the compat library wrapped APIs support v2022.12). + the compat library wrapped APIs support v2023.12). use_compat: bool or None If None (the default), the native namespace will be returned if it is diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 7968d68d..d8685761 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -13,4 +13,4 @@ from ..common._helpers import * # noqa: F401,F403 -__array_api_version__ = '2022.12' +__array_api_version__ = '2023.12' diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index ce0e609e..b49be6cf 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -3,7 +3,7 @@ # These imports may overwrite names from the import * above. from ._aliases import * # noqa: F403 -__array_api_version__ = '2022.12' +__array_api_version__ = '2023.12' __import__(__package__ + '.linalg') __import__(__package__ + '.fft') diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index b66f30a2..9bdbf312 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -27,4 +27,4 @@ except ImportError: pass -__array_api_version__ = '2022.12' +__array_api_version__ = '2023.12' diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 172f5279..cfa3acf8 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -21,4 +21,4 @@ from ..common._helpers import * # noqa: F403 -__array_api_version__ = '2022.12' +__array_api_version__ = '2023.12' diff --git a/docs/index.md b/docs/index.md index b268e61a..ef18265e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -12,8 +12,8 @@ each array library itself fully compatible with the array API, but this requires making backwards incompatible changes in many cases, so this will take some time. -Currently all libraries here are implemented against the [2022.12 -version](https://data-apis.org/array-api/2022.12/) of the standard. +Currently all libraries here are implemented against the [2023.12 +version](https://data-apis.org/array-api/2023.12/) of the standard. ## Installation diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index e35e31e1..f8e20437 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -13,7 +13,7 @@ from ._helpers import import_, all_libraries, wrapped_libraries @pytest.mark.parametrize("use_compat", [True, False, None]) -@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12"]) +@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"]) @pytest.mark.parametrize("library", all_libraries + ['array_api_strict']) def test_array_namespace(library, api_version, use_compat): xp = import_(library) @@ -94,14 +94,20 @@ def test_array_namespace_errors_torch(): def test_api_version(): x = torch.asarray([1, 2]) torch_ = import_("torch", wrapper=True) - assert array_namespace(x, api_version="2022.12") == torch_ + assert array_namespace(x, api_version="2023.12") == torch_ assert array_namespace(x, api_version=None) == torch_ assert array_namespace(x) == torch_ # Should issue a warning with warnings.catch_warnings(record=True) as w: assert array_namespace(x, api_version="2021.12") == torch_ - assert len(w) == 1 - assert "2021.12" in str(w[0].message) + assert len(w) == 1 + assert "2021.12" in str(w[0].message) + + # Should issue a warning + with warnings.catch_warnings(record=True) as w: + assert array_namespace(x, api_version="2022.12") == torch_ + assert len(w) == 1 + assert "2022.12" in str(w[0].message) pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12"))