diff --git a/nx/lib/nx/defn/sharding_compiler/passes/graph_splitter.ex b/nx/lib/nx/defn/sharding_compiler/passes/graph_splitter.ex index 8f45bb71cc..381810f1c4 100644 --- a/nx/lib/nx/defn/sharding_compiler/passes/graph_splitter.ex +++ b/nx/lib/nx/defn/sharding_compiler/passes/graph_splitter.ex @@ -57,6 +57,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do argument_sources = Map.take(state.args, Map.keys(arg_remapping)) + # TO-DO: collect shards for expr and arguments here and annotate them in the chain below [{id, category, expr, argument_sources} | acc] end ) @@ -129,6 +130,9 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do not (left_valid and right_valid) end + # default to true so that we can optimize this gradually + defp must_split_expr?(_, _, _), do: true + defp split_expr(expr, args, category, {cache, state}) do # We need to save this so that each previous stage # isn't affected by following ones