Skip to content

Commit 8f6042e

Browse files
authored
ENH: Introduce moveaxis function (#99)
* ENH: Introduce `moveaxis` function * Bump version
1 parent dfc1721 commit 8f6042e

File tree

5 files changed

+28
-4
lines changed

5 files changed

+28
-4
lines changed

pixi.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ channels = ["conda-forge"]
44
description = "Add a short description here"
55
name = "finch-tensor"
66
platforms = ["osx-arm64"]
7-
version = "0.2.9"
7+
version = "0.2.10"
88

99
[tasks]
1010
compile = "python -c 'import finch'"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "finch-tensor"
3-
version = "0.2.9"
3+
version = "0.2.10"
44
description = ""
55
authors = ["Willow Ahrens <[email protected]>"]
66
readme = "README.md"

src/finch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
eye,
5353
tensordot,
5454
permute_dims,
55+
moveaxis,
5556
where,
5657
nonzero,
5758
sum,
@@ -166,6 +167,7 @@
166167
"tensordot",
167168
"matmul",
168169
"permute_dims",
170+
"moveaxis",
169171
"where",
170172
"nonzero",
171173
"int_",

src/finch/tensor.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -868,11 +868,19 @@ def linspace(
868868
)
869869

870870

871-
def permute_dims(x: Tensor, axes: tuple[int, ...]):
871+
def permute_dims(x: Tensor, axes: tuple[int, ...]) -> Tensor:
872872
return x.permute_dims(axes)
873873

874874

875-
def astype(x: Tensor, dtype: DType, /, *, copy: bool = True):
875+
def moveaxis(x: Tensor, source: int, destination: int) -> Tensor:
876+
axes = list(range(x.ndim))
877+
norm_source = normalize_axis_index(source, x.ndim)
878+
norm_dest = normalize_axis_index(destination, x.ndim)
879+
axes.insert(norm_dest, axes.pop(norm_source))
880+
return x.permute_dims(tuple(axes))
881+
882+
883+
def astype(x: Tensor, dtype: DType, /, *, copy: bool = True) -> Tensor:
876884
if not copy:
877885
if x.dtype == dtype:
878886
return x

tests/test_sparse.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,20 @@ def test_permute_dims(arr3d, permutation, format, order, opt):
164164
assert_equal(actual_lazy_mode.todense(), expected)
165165

166166

167+
@pytest.mark.parametrize(
168+
"src_dest", [(0, 1), (1, 0), (-1, 2), (-2, -1), (1, 1)]
169+
)
170+
@parametrize_optimizer
171+
def test_moveaxis(arr3d, src_dest, opt):
172+
finch.set_optimizer(opt)
173+
src, dest = src_dest
174+
arr_finch = finch.Tensor(arr3d)
175+
176+
actual = finch.moveaxis(arr_finch, src, dest)
177+
expected = np.moveaxis(arr3d, src, dest)
178+
assert_equal(actual.todense(), expected)
179+
180+
167181
@pytest.mark.parametrize("order", ["C", "F"])
168182
@parametrize_optimizer
169183
def test_astype(arr3d, order, opt):

0 commit comments

Comments
 (0)