Skip to content

Commit 0034521

Browse files
committed
fix: make tests pass
1 parent 4206732 commit 0034521

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

nx/lib/nx/defn/sharding_compiler/passes/graph_splitter.ex

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,11 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
128128
end
129129

130130
defp must_split_expr?(:dot, [t0, c0, _b0, t1, c1, _b1], shards) do
131-
left_shards = shards[t0.data.id].shards
131+
left_shards =
132+
case shards[t0.data.id] do
133+
%{shards: shards} -> shards
134+
_ -> nil
135+
end
132136

133137
left_valid =
134138
Enum.all?(c0, fn axis ->
@@ -138,7 +142,11 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
138142
end
139143
end)
140144

141-
right_shards = shards[t1.data.id].shards
145+
right_shards =
146+
case shards[t1.data.id] do
147+
%{shards: shards} -> shards
148+
_ -> nil
149+
end
142150

143151
right_valid =
144152
Enum.all?(c1, fn axis ->

0 commit comments

Comments
 (0)