@@ -73,22 +73,30 @@ def test_take_along_axis(x, data):
73
73
# TODO
74
74
# 2. negative indices
75
75
# 3. different dtypes for indices
76
- axis = data .draw (st .integers (- x .ndim , max (x .ndim - 1 , 0 )), label = "axis" )
77
- len_axis = data .draw (st .integers (0 , 2 * x .shape [axis ]), label = "len_axis" )
76
+ axis = data .draw (
77
+ st .integers (- x .ndim , max (x .ndim - 1 , 0 )) | st .none (),
78
+ label = "axis"
79
+ )
80
+ if axis is None :
81
+ axis_kw = {}
82
+ n_axis = x .ndim - 1
83
+ else :
84
+ axis_kw = {"axis" : axis }
85
+ n_axis = axis + x .ndim if axis < 0 else axis
78
86
79
- n_axis = axis + x . ndim if axis < 0 else axis
87
+ len_axis = data . draw ( st . integers ( 0 , 2 * x . shape [ n_axis ]), label = "len_axis" )
80
88
idx_shape = x .shape [:n_axis ] + (len_axis ,) + x .shape [n_axis + 1 :]
81
89
indices = data .draw (
82
90
hh .arrays (
83
91
shape = idx_shape ,
84
92
dtype = dh .default_int ,
85
- elements = {"min_value" : 0 , "max_value" : x .shape [axis ]- 1 }
93
+ elements = {"min_value" : 0 , "max_value" : x .shape [n_axis ]- 1 }
86
94
),
87
95
label = "indices"
88
96
)
89
97
note (f"{ indices = } { idx_shape = } " )
90
98
91
- out = xp .take_along_axis (x , indices , axis = axis )
99
+ out = xp .take_along_axis (x , indices , ** axis_kw )
92
100
93
101
ph .assert_dtype ("take_along_axis" , in_dtype = x .dtype , out_dtype = out .dtype )
94
102
ph .assert_shape (
0 commit comments