From 3783da6f07d817080a3d0a2986750eeb20f113db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 5 Nov 2023 18:51:59 +0100 Subject: [PATCH] Only slice again if there are strides --- exla/lib/exla/defn.ex | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index c7927226215..865ab2476d7 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -1401,7 +1401,12 @@ defmodule EXLA.Defn do else zeros = List.duplicate(0, tuple_size(ans.shape)) slice = Value.dynamic_slice(tensor, start_indices, lengths) - Value.slice(slice, zeros, lengths, strides) + + if Enum.all?(strides, & &1 == 1) do + slice + else + Value.slice(slice, zeros, lengths, strides) + end end end @@ -1414,7 +1419,12 @@ defmodule EXLA.Defn do else zeros = List.duplicate(0, tuple_size(ans.shape)) slice = EXLA.Op.dynamic_slice(tensor, start_indices, lengths) - EXLA.Op.slice(slice, zeros, lengths, strides) + + if Enum.all?(strides, & &1 == 1) do + slice + else + EXLA.Op.slice(slice, zeros, lengths, strides) + end end end