Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nx/guides/getting_started/quickstart.livemd
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ Nx.shape(tensor)
We can also create a new tensor with the given shape using `Nx.reshape/2`:

```elixir
Nx.reshape(tensor, {1, 4}, names: [:batches, :values])
Nx.reshape(tensor, {1, 6}, names: [:batches, :values])
```
Comment on lines 164 to 166
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a stray change. Can you open a separate PR with this?


This operation generally reuses all of the tensor data and simply
Expand Down
4 changes: 3 additions & 1 deletion torchx/lib/torchx/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1528,13 +1528,15 @@ defmodule Torchx.Backend do
|> then(unfold_flat)
|> then(function)

{device, _} = from_nx(tensor)

indices_to_flatten =
tensor
|> Nx.axes()
|> Enum.map(fn axis ->
tensor
|> Nx.shape()
|> Nx.iota(axis: axis, backend: Torchx.Backend)
|> Nx.iota(axis: axis, backend: {Torchx.Backend, device: device})
|> then(unfold_flat)
|> Nx.take_along_axis(Nx.new_axis(arg_idx, -1), axis: -1)
end)
Expand Down