Skip to content

Commit 05c8b0f

Browse files
authored
Merge pull request #61 from asmeurer/2023.12-default
Make 2023.12 the default version
2 parents 718f15b + 2aae491 commit 05c8b0f

8 files changed

+99
-98
lines changed

README.md

-4
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,4 @@ libraries. Consuming library code should use the
1515
support the array API. Rather, it is intended to be used in the test suites of
1616
consuming libraries to test their array API usage.
1717

18-
array-api-strict currently supports the 2022.12 version of the standard.
19-
2023.12 support is planned and is tracked by [this
20-
issue](https://github.com/data-apis/array-api-strict/issues/25).
21-
2218
See the documentation for more details https://data-apis.org/array-api-strict/

array_api_strict/_flags.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"2023.12",
2525
)
2626

27-
API_VERSION = default_version = "2022.12"
27+
API_VERSION = default_version = "2023.12"
2828

2929
BOOLEAN_INDEXING = True
3030

@@ -76,10 +76,6 @@ def set_array_api_strict_flags(
7676
Note that 2021.12 is supported, but currently gives the same thing as
7777
2022.12 (except that the fft extension will be disabled).
7878
79-
2023.12 support is experimental. Some features in 2023.12 may still be
80-
missing, and it hasn't been fully tested. A future version of
81-
array-api-strict will change the default version to 2023.12.
82-
8379
boolean_indexing : bool, optional
8480
Whether indexing by a boolean array is supported. This flag is enabled
8581
by default. Note that although boolean array indexing does result in
@@ -142,8 +138,6 @@ def set_array_api_strict_flags(
142138
raise ValueError(f"Unsupported standard version {api_version!r}")
143139
if api_version == "2021.12":
144140
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2)
145-
if api_version == "2023.12":
146-
warnings.warn("The 2023.12 version of the array API specification is still preliminary. Some functions are not yet implemented, and it has not been fully tested.", stacklevel=2)
147141
API_VERSION = api_version
148142
array_api_strict.__array_api_version__ = API_VERSION
149143

@@ -262,7 +256,9 @@ def reset_array_api_strict_flags():
262256
BOOLEAN_INDEXING = True
263257
DATA_DEPENDENT_SHAPES = True
264258
ENABLED_EXTENSIONS = default_extensions
265-
259+
array_api_strict.__all__[:] = sorted(set(ENABLED_EXTENSIONS) |
260+
set(array_api_strict.__all__) -
261+
set(default_extensions))
266262

267263
class ArrayAPIStrictFlags:
268264
"""

array_api_strict/tests/test_array_object.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -406,16 +406,15 @@ def test_array_keys_use_private_array():
406406
def test_array_namespace():
407407
a = ones((3, 3))
408408
assert a.__array_namespace__() == array_api_strict
409-
assert array_api_strict.__array_api_version__ == "2022.12"
409+
assert array_api_strict.__array_api_version__ == "2023.12"
410410

411411
assert a.__array_namespace__(api_version=None) is array_api_strict
412-
assert array_api_strict.__array_api_version__ == "2022.12"
412+
assert array_api_strict.__array_api_version__ == "2023.12"
413413

414414
assert a.__array_namespace__(api_version="2022.12") is array_api_strict
415415
assert array_api_strict.__array_api_version__ == "2022.12"
416416

417-
with pytest.warns(UserWarning):
418-
assert a.__array_namespace__(api_version="2023.12") is array_api_strict
417+
assert a.__array_namespace__(api_version="2023.12") is array_api_strict
419418
assert array_api_strict.__array_api_version__ == "2023.12"
420419

421420
with pytest.warns(UserWarning):
@@ -435,7 +434,7 @@ def test_iter():
435434

436435
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])
437436
def dlpack_2023_12(api_version):
438-
if api_version != '2022.12':
437+
if api_version == '2021.12':
439438
with pytest.warns(UserWarning):
440439
set_array_api_strict_flags(api_version=api_version)
441440
else:

array_api_strict/tests/test_elementwise_functions.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
)
1818
from .._flags import set_array_api_strict_flags
1919

20-
import pytest
2120

2221
def nargs(func):
2322
return len(getfullargspec(func).args)
@@ -111,8 +110,7 @@ def _array_vals():
111110
yield asarray(1.0, dtype=d)
112111

113112
# Use the latest version of the standard so all functions are included
114-
with pytest.warns(UserWarning):
115-
set_array_api_strict_flags(api_version="2023.12")
113+
set_array_api_strict_flags(api_version="2023.12")
116114

117115
for x in _array_vals():
118116
for func_name, types in elementwise_function_input_types.items():

array_api_strict/tests/test_flags.py

+57-48
Original file line numberDiff line numberDiff line change
@@ -18,33 +18,52 @@
1818

1919
import pytest
2020

21-
def test_flags():
22-
# Test defaults
21+
def test_flag_defaults():
2322
flags = get_array_api_strict_flags()
2423
assert flags == {
25-
'api_version': '2022.12',
24+
'api_version': '2023.12',
25+
'boolean_indexing': True,
26+
'data_dependent_shapes': True,
27+
'enabled_extensions': ('linalg', 'fft'),
28+
}
29+
30+
31+
def test_reset_flags():
32+
with pytest.warns(UserWarning):
33+
set_array_api_strict_flags(
34+
api_version='2021.12',
35+
boolean_indexing=False,
36+
data_dependent_shapes=False,
37+
enabled_extensions=())
38+
reset_array_api_strict_flags()
39+
flags = get_array_api_strict_flags()
40+
assert flags == {
41+
'api_version': '2023.12',
2642
'boolean_indexing': True,
2743
'data_dependent_shapes': True,
2844
'enabled_extensions': ('linalg', 'fft'),
2945
}
3046

31-
# Test setting flags
47+
48+
def test_setting_flags():
3249
set_array_api_strict_flags(data_dependent_shapes=False)
3350
flags = get_array_api_strict_flags()
3451
assert flags == {
35-
'api_version': '2022.12',
52+
'api_version': '2023.12',
3653
'boolean_indexing': True,
3754
'data_dependent_shapes': False,
3855
'enabled_extensions': ('linalg', 'fft'),
3956
}
4057
set_array_api_strict_flags(enabled_extensions=('fft',))
4158
flags = get_array_api_strict_flags()
4259
assert flags == {
43-
'api_version': '2022.12',
60+
'api_version': '2023.12',
4461
'boolean_indexing': True,
4562
'data_dependent_shapes': False,
4663
'enabled_extensions': ('fft',),
4764
}
65+
66+
def test_flags_api_version_2021_12():
4867
# Make sure setting the version to 2021.12 disables fft and issues a
4968
# warning.
5069
with pytest.warns(UserWarning) as record:
@@ -55,27 +74,23 @@ def test_flags():
5574
assert flags == {
5675
'api_version': '2021.12',
5776
'boolean_indexing': True,
58-
'data_dependent_shapes': False,
59-
'enabled_extensions': (),
77+
'data_dependent_shapes': True,
78+
'enabled_extensions': ('linalg',),
6079
}
61-
reset_array_api_strict_flags()
6280

63-
with pytest.warns(UserWarning):
64-
set_array_api_strict_flags(api_version='2021.12')
81+
def test_flags_api_version_2022_12():
82+
set_array_api_strict_flags(api_version='2022.12')
6583
flags = get_array_api_strict_flags()
6684
assert flags == {
67-
'api_version': '2021.12',
85+
'api_version': '2022.12',
6886
'boolean_indexing': True,
6987
'data_dependent_shapes': True,
70-
'enabled_extensions': ('linalg',),
88+
'enabled_extensions': ('linalg', 'fft'),
7189
}
72-
reset_array_api_strict_flags()
7390

74-
# 2023.12 should issue a warning
75-
with pytest.warns(UserWarning) as record:
76-
set_array_api_strict_flags(api_version='2023.12')
77-
assert len(record) == 1
78-
assert '2023.12' in str(record[0].message)
91+
92+
def test_flags_api_version_2023_12():
93+
set_array_api_strict_flags(api_version='2023.12')
7994
flags = get_array_api_strict_flags()
8095
assert flags == {
8196
'api_version': '2023.12',
@@ -84,6 +99,7 @@ def test_flags():
8499
'enabled_extensions': ('linalg', 'fft'),
85100
}
86101

102+
def test_setting_flags_invalid():
87103
# Test setting flags with invalid values
88104
pytest.raises(ValueError, lambda:
89105
set_array_api_strict_flags(api_version='2020.12'))
@@ -94,35 +110,15 @@ def test_flags():
94110
api_version='2021.12',
95111
enabled_extensions=('linalg', 'fft')))
96112

97-
# Test resetting flags
98-
with pytest.warns(UserWarning):
99-
set_array_api_strict_flags(
100-
api_version='2021.12',
101-
boolean_indexing=False,
102-
data_dependent_shapes=False,
103-
enabled_extensions=())
104-
reset_array_api_strict_flags()
105-
flags = get_array_api_strict_flags()
106-
assert flags == {
107-
'api_version': '2022.12',
108-
'boolean_indexing': True,
109-
'data_dependent_shapes': True,
110-
'enabled_extensions': ('linalg', 'fft'),
111-
}
112-
113113
def test_api_version():
114114
# Test defaults
115-
assert xp.__array_api_version__ == '2022.12'
115+
assert xp.__array_api_version__ == '2023.12'
116116

117117
# Test setting the version
118-
with pytest.warns(UserWarning):
119-
set_array_api_strict_flags(api_version='2021.12')
120-
assert xp.__array_api_version__ == '2021.12'
118+
set_array_api_strict_flags(api_version='2022.12')
119+
assert xp.__array_api_version__ == '2022.12'
121120

122121
def test_data_dependent_shapes():
123-
with pytest.warns(UserWarning):
124-
set_array_api_strict_flags(api_version='2023.12') # to enable repeat()
125-
126122
a = asarray([0, 0, 1, 2, 2])
127123
mask = asarray([True, False, True, False, True])
128124
repeats = asarray([1, 1, 2, 2, 2])
@@ -275,12 +271,16 @@ def test_fft(func_name):
275271
def test_api_version_2023_12(func_name):
276272
func = api_version_2023_12_examples[func_name]
277273

278-
# By default, these functions should error
274+
# By default, these functions should not error
275+
func()
276+
277+
# In 2022.12, these functions should error
278+
set_array_api_strict_flags(api_version='2022.12')
279279
pytest.raises(RuntimeError, func)
280280

281-
with pytest.warns(UserWarning):
282-
set_array_api_strict_flags(api_version='2023.12')
283-
func()
281+
# Test the behavior gets updated properly
282+
set_array_api_strict_flags(api_version='2023.12')
283+
func()
284284

285285
set_array_api_strict_flags(api_version='2022.12')
286286
pytest.raises(RuntimeError, func)
@@ -371,16 +371,25 @@ def test_disabled_extensions():
371371
assert 'linalg' not in ns
372372
assert 'fft' not in ns
373373

374+
reset_array_api_strict_flags()
375+
assert 'linalg' in xp.__all__
376+
assert 'fft' in xp.__all__
377+
xp.linalg # No error
378+
xp.fft # No error
379+
ns = {}
380+
exec('from array_api_strict import *', ns)
381+
assert 'linalg' in ns
382+
assert 'fft' in ns
374383

375384
def test_environment_variables():
376385
# Test that the environment variables work as expected
377386
subprocess_tests = [
378387
# ARRAY_API_STRICT_API_VERSION
379388
('''\
380389
import array_api_strict as xp
381-
assert xp.__array_api_version__ == '2022.12'
390+
assert xp.__array_api_version__ == '2023.12'
382391
383-
assert xp.get_array_api_strict_flags()['api_version'] == '2022.12'
392+
assert xp.get_array_api_strict_flags()['api_version'] == '2023.12'
384393
385394
''', {}),
386395
*[

array_api_strict/tests/test_linalg.py

+14-13
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,17 @@
88

99
# Technically this is linear_algebra, not linalg, but it's simpler to keep
1010
# both of these tests together
11-
def test_vecdot_2023_12():
12-
# Test the axis < 0 restriction for 2023.12, and also the 2022.12 axis >=
13-
# 0 behavior (which is primarily kept for backwards compatibility).
11+
12+
13+
# Test the axis < 0 restriction for 2023.12, and also the 2022.12 axis >=
14+
# 0 behavior (which is primarily kept for backwards compatibility).
15+
def test_vecdot_2022_12():
16+
# 2022.12 behavior, which is to apply axis >= 0 after broadcasting
17+
set_array_api_strict_flags(api_version='2022.12')
1418

1519
a = xp.ones((2, 3, 4, 5))
1620
b = xp.ones(( 3, 4, 1))
1721

18-
# 2022.12 behavior, which is to apply axis >= 0 after broadcasting
1922
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0))
2023
assert xp.linalg.vecdot(a, b, axis=1).shape == (2, 4, 5)
2124
assert xp.linalg.vecdot(a, b, axis=2).shape == (2, 3, 5)
@@ -34,10 +37,13 @@ def test_vecdot_2023_12():
3437
assert xp.linalg.vecdot(a, b, axis=-2).shape == (2, 3, 5)
3538
assert xp.linalg.vecdot(a, b, axis=-3).shape == (2, 4, 5)
3639

40+
def test_vecdot_2023_12():
3741
# 2023.12 behavior, which is to only allow axis < 0 and axis >=
3842
# min(x1.ndim, x2.ndim), which is unambiguous
39-
with pytest.warns(UserWarning):
40-
set_array_api_strict_flags(api_version='2023.12')
43+
set_array_api_strict_flags(api_version='2023.12')
44+
45+
a = xp.ones((2, 3, 4, 5))
46+
b = xp.ones(( 3, 4, 1))
4147

4248
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0))
4349
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=1))
@@ -56,7 +62,7 @@ def test_cross(api_version):
5662
# This test tests everything that should be the same across all supported
5763
# API versions.
5864

59-
if api_version != '2022.12':
65+
if api_version == '2021.12':
6066
with pytest.warns(UserWarning):
6167
set_array_api_strict_flags(api_version=api_version)
6268
else:
@@ -88,7 +94,7 @@ def test_cross_2022_12(api_version):
8894
# backwards compatibility. Note that unlike vecdot, array_api_strict
8995
# cross() never implemented the "after broadcasting" axis behavior, but
9096
# just reused NumPy cross(), which applies axes before broadcasting.
91-
if api_version != '2022.12':
97+
if api_version == '2021.12':
9298
with pytest.warns(UserWarning):
9399
set_array_api_strict_flags(api_version=api_version)
94100
else:
@@ -104,11 +110,6 @@ def test_cross_2022_12(api_version):
104110
assert xp.linalg.cross(a, b, axis=0).shape == (3, 2, 4, 5)
105111

106112
def test_cross_2023_12():
107-
# 2023.12 behavior, which is to only allow axis < 0 and axis >=
108-
# min(x1.ndim, x2.ndim), which is unambiguous
109-
with pytest.warns(UserWarning):
110-
set_array_api_strict_flags(api_version='2023.12')
111-
112113
a = xp.ones((3, 2, 4, 5))
113114
b = xp.ones((3, 2, 4, 1))
114115
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=0))

array_api_strict/tests/test_statistical_functions.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
import array_api_strict as xp
66

7+
# sum, prod, and trace were changed in 2023.12 to not upcast floating-point dtypes
8+
# with dtype=None
79
@pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace'])
8-
def test_sum_prod_trace_2023_12(func_name):
9-
# sum, prod, and trace were changed in 2023.12 to not upcast floating-point dtypes
10-
# with dtype=None
10+
def test_sum_prod_trace_2022_12(func_name):
11+
set_array_api_strict_flags(api_version='2022.12')
12+
1113
if func_name == 'trace':
1214
func = getattr(xp.linalg, func_name)
1315
else:
@@ -21,8 +23,16 @@ def test_sum_prod_trace_2023_12(func_name):
2123
assert func(a_complex).dtype == xp.complex128
2224
assert func(a_int).dtype == xp.int64
2325

24-
with pytest.warns(UserWarning):
25-
set_array_api_strict_flags(api_version='2023.12')
26+
@pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace'])
27+
def test_sum_prod_trace_2023_12(func_name):
28+
a_real = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.float32)
29+
a_complex = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.complex64)
30+
a_int = xp.asarray([[1, 2], [3, 4]], dtype=xp.int32)
31+
32+
if func_name == 'trace':
33+
func = getattr(xp.linalg, func_name)
34+
else:
35+
func = getattr(xp, func_name)
2636

2737
assert func(a_real).dtype == xp.float32
2838
assert func(a_complex).dtype == xp.complex64

0 commit comments

Comments
 (0)