Skip to content

Commit 1fe44ce

Browse files
committed
Support :type in argsort
1 parent 04e4df5 commit 1fe44ce

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

nx/lib/nx.ex

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14891,6 +14891,7 @@ defmodule Nx do
1489114891
should be applied
1489214892
* `:direction` - Can be `:asc` or `:desc`. Defaults to `:asc`
1489314893
* `:stable` - If the sorting is stable. Defaults to `false`
14894+
* `:type` - The type of the resulting tensor. Defaults to `:s64`.
1489414895
1489514896
## Examples
1489614897
@@ -14921,9 +14922,9 @@ defmodule Nx do
1492114922
>
1492214923
1492314924
iex> t = Nx.tensor([[3, 1, 7], [2, 5, 4]], names: [:x, :y])
14924-
iex> Nx.argsort(t, axis: :y, direction: :asc)
14925+
iex> Nx.argsort(t, axis: :y, direction: :asc, type: :u32)
1492514926
#Nx.Tensor<
14926-
s64[x: 2][y: 3]
14927+
u32[x: 2][y: 3]
1492714928
[
1492814929
[1, 0, 2],
1492914930
[0, 2, 1]
@@ -14985,7 +14986,7 @@ defmodule Nx do
1498514986
"""
1498614987
@doc type: :ndim
1498714988
def argsort(tensor, opts \\ []) do
14988-
opts = keyword!(opts, axis: 0, direction: :asc, stable: false)
14989+
opts = keyword!(opts, axis: 0, direction: :asc, stable: false, type: {:s, 64})
1498914990

1499014991
apply_vectorized(tensor, fn tensor, offset ->
1499114992
direction =
@@ -15007,7 +15008,7 @@ defmodule Nx do
1500715008
Nx.Shared.raise_complex_not_supported(type, :argsort, 2)
1500815009

1500915010
impl!(tensor).argsort(
15010-
%{tensor | type: {:s, 64}},
15011+
%{tensor | type: Nx.Type.normalize!(opts[:type])},
1501115012
tensor,
1501215013
axis: axis,
1501315014
direction: direction,

0 commit comments

Comments
 (0)