@@ -14891,6 +14891,7 @@ defmodule Nx do
14891
14891
should be applied
14892
14892
* `:direction` - Can be `:asc` or `:desc`. Defaults to `:asc`
14893
14893
* `:stable` - If the sorting is stable. Defaults to `false`
14894
+ * `:type` - The type of the resulting tensor. Defaults to `:s64`.
14894
14895
14895
14896
## Examples
14896
14897
@@ -14921,9 +14922,9 @@ defmodule Nx do
14921
14922
>
14922
14923
14923
14924
iex> t = Nx.tensor([[3, 1, 7], [2, 5, 4]], names: [:x, :y])
14924
- iex> Nx.argsort(t, axis: :y, direction: :asc)
14925
+ iex> Nx.argsort(t, axis: :y, direction: :asc, type: :u32 )
14925
14926
#Nx.Tensor<
14926
- s64 [x: 2][y: 3]
14927
+ u32 [x: 2][y: 3]
14927
14928
[
14928
14929
[1, 0, 2],
14929
14930
[0, 2, 1]
@@ -14985,7 +14986,7 @@ defmodule Nx do
14985
14986
"""
14986
14987
@ doc type: :ndim
14987
14988
def argsort ( tensor , opts \\ [ ] ) do
14988
- opts = keyword! ( opts , axis: 0 , direction: :asc , stable: false )
14989
+ opts = keyword! ( opts , axis: 0 , direction: :asc , stable: false , type: { :s , 64 } )
14989
14990
14990
14991
apply_vectorized ( tensor , fn tensor , offset ->
14991
14992
direction =
@@ -15007,7 +15008,7 @@ defmodule Nx do
15007
15008
Nx.Shared . raise_complex_not_supported ( type , :argsort , 2 )
15008
15009
15009
15010
impl! ( tensor ) . argsort (
15010
- % { tensor | type: { :s , 64 } } ,
15011
+ % { tensor | type: Nx.Type . normalize! ( opts [ :type ] ) } ,
15011
15012
tensor ,
15012
15013
axis: axis ,
15013
15014
direction: direction ,
0 commit comments