From dcb22c08bd7594d578f9c6d0447dc1179b6f6c3c Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 21 Apr 2025 14:07:41 +0200 Subject: [PATCH 1/4] ENH: test take_along_axis --- array_api_tests/test_indexing_functions.py | 52 ++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py index 3ef01cb7..a7eb67f9 100644 --- a/array_api_tests/test_indexing_functions.py +++ b/array_api_tests/test_indexing_functions.py @@ -60,3 +60,55 @@ def test_take(x, data): # sanity check with pytest.raises(StopIteration): next(out_indices) + + + +@pytest.mark.unvectorized +@pytest.mark.min_version("2024.12") +@given( + x=hh.arrays(hh.all_dtypes, hh.shapes(min_dims=1, min_side=1)), + data=st.data(), +) +def test_take_along_axis(x, data): + # TODO + # 1. negative axis + # 2. negative indices + # 3. different dtypes for indices + axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis") + len_axis = data.draw(st.integers(0, 2*x.shape[axis]), label="len_axis") + + idx_shape = x.shape[:axis] + (len_axis,) + x.shape[axis+1:] + indices = data.draw( + hh.arrays( + shape=idx_shape, + dtype=dh.default_int, + elements={"min_value": 0, "max_value": x.shape[axis]-1} + ), + label="indices" + ) + note(f"{indices=} {idx_shape=}") + + out = xp.take_along_axis(x, indices, axis=axis) + + ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape( + "take_along_axis", + out_shape=out.shape, + expected=x.shape[:axis] + (len_axis,) + x.shape[axis+1:], + kw=dict( + x=x, + indices=indices, + axis=axis, + ), + ) + + # value test: notation is from `np.take_along_axis` docstring + Ni, Nk = x.shape[:axis], x.shape[axis+1:] + for ii in sh.ndindex(Ni): + for kk in sh.ndindex(Nk): + a_1d = x[ii + (slice(None),) + kk] + i_1d = indices[ii + (slice(None),) + kk] + o_1d = out[ii + (slice(None),) + kk] + for j in range(len_axis): + assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}' + From 0a181ed8f90b26dc4e756fa88d8dc35b212ee89d Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 21 Apr 2025 14:25:39 +0200 Subject: [PATCH 2/4] ENH: test test_along_axis with axis<0 --- array_api_tests/test_indexing_functions.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py index a7eb67f9..4b5631d7 100644 --- a/array_api_tests/test_indexing_functions.py +++ b/array_api_tests/test_indexing_functions.py @@ -71,13 +71,13 @@ def test_take(x, data): ) def test_take_along_axis(x, data): # TODO - # 1. negative axis # 2. negative indices # 3. different dtypes for indices - axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis") + axis = data.draw(st.integers(-x.ndim, max(x.ndim - 1, 0)), label="axis") len_axis = data.draw(st.integers(0, 2*x.shape[axis]), label="len_axis") - idx_shape = x.shape[:axis] + (len_axis,) + x.shape[axis+1:] + n_axis = axis + x.ndim if axis < 0 else axis + idx_shape = x.shape[:n_axis] + (len_axis,) + x.shape[n_axis+1:] indices = data.draw( hh.arrays( shape=idx_shape, @@ -94,7 +94,7 @@ def test_take_along_axis(x, data): ph.assert_shape( "take_along_axis", out_shape=out.shape, - expected=x.shape[:axis] + (len_axis,) + x.shape[axis+1:], + expected=x.shape[:n_axis] + (len_axis,) + x.shape[n_axis+1:], kw=dict( x=x, indices=indices, @@ -103,7 +103,7 @@ def test_take_along_axis(x, data): ) # value test: notation is from `np.take_along_axis` docstring - Ni, Nk = x.shape[:axis], x.shape[axis+1:] + Ni, Nk = x.shape[:n_axis], x.shape[n_axis+1:] for ii in sh.ndindex(Ni): for kk in sh.ndindex(Nk): a_1d = x[ii + (slice(None),) + kk] @@ -111,4 +111,3 @@ def test_take_along_axis(x, data): o_1d = out[ii + (slice(None),) + kk] for j in range(len_axis): assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}' - From 27927b51676dcbe361b5a31a805da101baef5df7 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 21 Apr 2025 14:40:10 +0200 Subject: [PATCH 3/4] ENH: test take_along_axis default axis=-1 --- array_api_tests/test_indexing_functions.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py index 4b5631d7..1baa7175 100644 --- a/array_api_tests/test_indexing_functions.py +++ b/array_api_tests/test_indexing_functions.py @@ -73,22 +73,30 @@ def test_take_along_axis(x, data): # TODO # 2. negative indices # 3. different dtypes for indices - axis = data.draw(st.integers(-x.ndim, max(x.ndim - 1, 0)), label="axis") - len_axis = data.draw(st.integers(0, 2*x.shape[axis]), label="len_axis") + axis = data.draw( + st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(), + label="axis" + ) + if axis is None: + axis_kw = {} + n_axis = x.ndim - 1 + else: + axis_kw = {"axis": axis} + n_axis = axis + x.ndim if axis < 0 else axis - n_axis = axis + x.ndim if axis < 0 else axis + len_axis = data.draw(st.integers(0, 2*x.shape[n_axis]), label="len_axis") idx_shape = x.shape[:n_axis] + (len_axis,) + x.shape[n_axis+1:] indices = data.draw( hh.arrays( shape=idx_shape, dtype=dh.default_int, - elements={"min_value": 0, "max_value": x.shape[axis]-1} + elements={"min_value": 0, "max_value": x.shape[n_axis]-1} ), label="indices" ) note(f"{indices=} {idx_shape=}") - out = xp.take_along_axis(x, indices, axis=axis) + out = xp.take_along_axis(x, indices, **axis_kw) ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape( From 5191ff4031344d7108d8aa7cb8c49a659c582781 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 21 Apr 2025 19:18:55 +0200 Subject: [PATCH 4/4] . --- array_api_tests/test_indexing_functions.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py index 1baa7175..a599d218 100644 --- a/array_api_tests/test_indexing_functions.py +++ b/array_api_tests/test_indexing_functions.py @@ -73,6 +73,7 @@ def test_take_along_axis(x, data): # TODO # 2. negative indices # 3. different dtypes for indices + # 4. "broadcast-compatible" indices axis = data.draw( st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(), label="axis" @@ -84,8 +85,8 @@ def test_take_along_axis(x, data): axis_kw = {"axis": axis} n_axis = axis + x.ndim if axis < 0 else axis - len_axis = data.draw(st.integers(0, 2*x.shape[n_axis]), label="len_axis") - idx_shape = x.shape[:n_axis] + (len_axis,) + x.shape[n_axis+1:] + new_len = data.draw(st.integers(0, 2*x.shape[n_axis]), label="new_len") + idx_shape = x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:] indices = data.draw( hh.arrays( shape=idx_shape, @@ -102,7 +103,7 @@ def test_take_along_axis(x, data): ph.assert_shape( "take_along_axis", out_shape=out.shape, - expected=x.shape[:n_axis] + (len_axis,) + x.shape[n_axis+1:], + expected=x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:], kw=dict( x=x, indices=indices, @@ -117,5 +118,5 @@ def test_take_along_axis(x, data): a_1d = x[ii + (slice(None),) + kk] i_1d = indices[ii + (slice(None),) + kk] o_1d = out[ii + (slice(None),) + kk] - for j in range(len_axis): + for j in range(new_len): assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}'