Skip to content

Commit

Permalink
fix: make tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed Oct 17, 2024
1 parent 4206732 commit 0034521
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions nx/lib/nx/defn/sharding_compiler/passes/graph_splitter.ex
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
end

defp must_split_expr?(:dot, [t0, c0, _b0, t1, c1, _b1], shards) do
left_shards = shards[t0.data.id].shards
left_shards =
case shards[t0.data.id] do
%{shards: shards} -> shards
_ -> nil
end

left_valid =
Enum.all?(c0, fn axis ->
Expand All @@ -138,7 +142,11 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
end
end)

right_shards = shards[t1.data.id].shards
right_shards =
case shards[t1.data.id] do
%{shards: shards} -> shards
_ -> nil
end

right_valid =
Enum.all?(c1, fn axis ->
Expand Down

0 comments on commit 0034521

Please sign in to comment.