Skip to content

Commit

Permalink
Only slice again if there are strides
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Nov 5, 2023
1 parent 1fe44ce commit 3783da6
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1401,7 +1401,12 @@ defmodule EXLA.Defn do
else
zeros = List.duplicate(0, tuple_size(ans.shape))
slice = Value.dynamic_slice(tensor, start_indices, lengths)
Value.slice(slice, zeros, lengths, strides)

if Enum.all?(strides, & &1 == 1) do
slice
else
Value.slice(slice, zeros, lengths, strides)
end
end
end

Expand All @@ -1414,7 +1419,12 @@ defmodule EXLA.Defn do
else
zeros = List.duplicate(0, tuple_size(ans.shape))
slice = EXLA.Op.dynamic_slice(tensor, start_indices, lengths)
EXLA.Op.slice(slice, zeros, lengths, strides)

if Enum.all?(strides, & &1 == 1) do
slice
else
EXLA.Op.slice(slice, zeros, lengths, strides)
end
end
end

Expand Down

0 comments on commit 3783da6

Please sign in to comment.