18
18
19
19
import pytest
20
20
21
- def test_flags ():
22
- # Test defaults
21
+ def test_flag_defaults ():
23
22
flags = get_array_api_strict_flags ()
24
23
assert flags == {
25
- 'api_version' : '2022.12' ,
24
+ 'api_version' : '2023.12' ,
25
+ 'boolean_indexing' : True ,
26
+ 'data_dependent_shapes' : True ,
27
+ 'enabled_extensions' : ('linalg' , 'fft' ),
28
+ }
29
+
30
+
31
+ def test_reset_flags ():
32
+ with pytest .warns (UserWarning ):
33
+ set_array_api_strict_flags (
34
+ api_version = '2021.12' ,
35
+ boolean_indexing = False ,
36
+ data_dependent_shapes = False ,
37
+ enabled_extensions = ())
38
+ reset_array_api_strict_flags ()
39
+ flags = get_array_api_strict_flags ()
40
+ assert flags == {
41
+ 'api_version' : '2023.12' ,
26
42
'boolean_indexing' : True ,
27
43
'data_dependent_shapes' : True ,
28
44
'enabled_extensions' : ('linalg' , 'fft' ),
29
45
}
30
46
31
- # Test setting flags
47
+
48
+ def test_setting_flags ():
32
49
set_array_api_strict_flags (data_dependent_shapes = False )
33
50
flags = get_array_api_strict_flags ()
34
51
assert flags == {
35
- 'api_version' : '2022 .12' ,
52
+ 'api_version' : '2023 .12' ,
36
53
'boolean_indexing' : True ,
37
54
'data_dependent_shapes' : False ,
38
55
'enabled_extensions' : ('linalg' , 'fft' ),
39
56
}
40
57
set_array_api_strict_flags (enabled_extensions = ('fft' ,))
41
58
flags = get_array_api_strict_flags ()
42
59
assert flags == {
43
- 'api_version' : '2022 .12' ,
60
+ 'api_version' : '2023 .12' ,
44
61
'boolean_indexing' : True ,
45
62
'data_dependent_shapes' : False ,
46
63
'enabled_extensions' : ('fft' ,),
47
64
}
65
+
66
+ def test_flags_api_version_2021_12 ():
48
67
# Make sure setting the version to 2021.12 disables fft and issues a
49
68
# warning.
50
69
with pytest .warns (UserWarning ) as record :
@@ -55,27 +74,23 @@ def test_flags():
55
74
assert flags == {
56
75
'api_version' : '2021.12' ,
57
76
'boolean_indexing' : True ,
58
- 'data_dependent_shapes' : False ,
59
- 'enabled_extensions' : (),
77
+ 'data_dependent_shapes' : True ,
78
+ 'enabled_extensions' : ('linalg' , ),
60
79
}
61
- reset_array_api_strict_flags ()
62
80
63
- with pytest . warns ( UserWarning ):
64
- set_array_api_strict_flags (api_version = '2021 .12' )
81
+ def test_flags_api_version_2022_12 ( ):
82
+ set_array_api_strict_flags (api_version = '2022 .12' )
65
83
flags = get_array_api_strict_flags ()
66
84
assert flags == {
67
- 'api_version' : '2021 .12' ,
85
+ 'api_version' : '2022 .12' ,
68
86
'boolean_indexing' : True ,
69
87
'data_dependent_shapes' : True ,
70
- 'enabled_extensions' : ('linalg' ,),
88
+ 'enabled_extensions' : ('linalg' , 'fft' ),
71
89
}
72
- reset_array_api_strict_flags ()
73
90
74
- # 2023.12 should issue a warning
75
- with pytest .warns (UserWarning ) as record :
76
- set_array_api_strict_flags (api_version = '2023.12' )
77
- assert len (record ) == 1
78
- assert '2023.12' in str (record [0 ].message )
91
+
92
+ def test_flags_api_version_2023_12 ():
93
+ set_array_api_strict_flags (api_version = '2023.12' )
79
94
flags = get_array_api_strict_flags ()
80
95
assert flags == {
81
96
'api_version' : '2023.12' ,
@@ -84,6 +99,7 @@ def test_flags():
84
99
'enabled_extensions' : ('linalg' , 'fft' ),
85
100
}
86
101
102
+ def test_setting_flags_invalid ():
87
103
# Test setting flags with invalid values
88
104
pytest .raises (ValueError , lambda :
89
105
set_array_api_strict_flags (api_version = '2020.12' ))
@@ -94,35 +110,15 @@ def test_flags():
94
110
api_version = '2021.12' ,
95
111
enabled_extensions = ('linalg' , 'fft' )))
96
112
97
- # Test resetting flags
98
- with pytest .warns (UserWarning ):
99
- set_array_api_strict_flags (
100
- api_version = '2021.12' ,
101
- boolean_indexing = False ,
102
- data_dependent_shapes = False ,
103
- enabled_extensions = ())
104
- reset_array_api_strict_flags ()
105
- flags = get_array_api_strict_flags ()
106
- assert flags == {
107
- 'api_version' : '2022.12' ,
108
- 'boolean_indexing' : True ,
109
- 'data_dependent_shapes' : True ,
110
- 'enabled_extensions' : ('linalg' , 'fft' ),
111
- }
112
-
113
113
def test_api_version ():
114
114
# Test defaults
115
- assert xp .__array_api_version__ == '2022 .12'
115
+ assert xp .__array_api_version__ == '2023 .12'
116
116
117
117
# Test setting the version
118
- with pytest .warns (UserWarning ):
119
- set_array_api_strict_flags (api_version = '2021.12' )
120
- assert xp .__array_api_version__ == '2021.12'
118
+ set_array_api_strict_flags (api_version = '2022.12' )
119
+ assert xp .__array_api_version__ == '2022.12'
121
120
122
121
def test_data_dependent_shapes ():
123
- with pytest .warns (UserWarning ):
124
- set_array_api_strict_flags (api_version = '2023.12' ) # to enable repeat()
125
-
126
122
a = asarray ([0 , 0 , 1 , 2 , 2 ])
127
123
mask = asarray ([True , False , True , False , True ])
128
124
repeats = asarray ([1 , 1 , 2 , 2 , 2 ])
@@ -275,12 +271,16 @@ def test_fft(func_name):
275
271
def test_api_version_2023_12 (func_name ):
276
272
func = api_version_2023_12_examples [func_name ]
277
273
278
- # By default, these functions should error
274
+ # By default, these functions should not error
275
+ func ()
276
+
277
+ # In 2022.12, these functions should error
278
+ set_array_api_strict_flags (api_version = '2022.12' )
279
279
pytest .raises (RuntimeError , func )
280
280
281
- with pytest . warns ( UserWarning ):
282
- set_array_api_strict_flags (api_version = '2023.12' )
283
- func ()
281
+ # Test the behavior gets updated properly
282
+ set_array_api_strict_flags (api_version = '2023.12' )
283
+ func ()
284
284
285
285
set_array_api_strict_flags (api_version = '2022.12' )
286
286
pytest .raises (RuntimeError , func )
@@ -371,16 +371,25 @@ def test_disabled_extensions():
371
371
assert 'linalg' not in ns
372
372
assert 'fft' not in ns
373
373
374
+ reset_array_api_strict_flags ()
375
+ assert 'linalg' in xp .__all__
376
+ assert 'fft' in xp .__all__
377
+ xp .linalg # No error
378
+ xp .fft # No error
379
+ ns = {}
380
+ exec ('from array_api_strict import *' , ns )
381
+ assert 'linalg' in ns
382
+ assert 'fft' in ns
374
383
375
384
def test_environment_variables ():
376
385
# Test that the environment variables work as expected
377
386
subprocess_tests = [
378
387
# ARRAY_API_STRICT_API_VERSION
379
388
('''\
380
389
import array_api_strict as xp
381
- assert xp.__array_api_version__ == '2022 .12'
390
+ assert xp.__array_api_version__ == '2023 .12'
382
391
383
- assert xp.get_array_api_strict_flags()['api_version'] == '2022 .12'
392
+ assert xp.get_array_api_strict_flags()['api_version'] == '2023 .12'
384
393
385
394
''' , {}),
386
395
* [
0 commit comments