diff --git a/nx/lib/nx/defn/graph.ex b/nx/lib/nx/defn/graph.ex new file mode 100644 index 0000000000..c11c1ef18b --- /dev/null +++ b/nx/lib/nx/defn/graph.ex @@ -0,0 +1,327 @@ +defmodule Nx.Defn.Graph do + @moduledoc """ + A module for splitting `Nx.Defn.Expr` into stages. + + This module is used to split an `Nx.Defn.Expr` into stages, which are then + executed in a chain. + + `split/2` and `t:Stage.t()` describe how to split + the graph and what's the expected result. + + `run/2` executes the given graph against the provided arguments in a sequential manner. + """ + alias Nx.Defn.Composite + + alias Nx.Tensor, as: T + alias Nx.Defn.Expr + + defmodule Stage do + @typedoc """ + A stage in the graph splitter. + + * `:arguments`: a list of maps that point to the source from which to fetch the corresponding + value for the given argument. + * `:expr`: the expression that represents the computation for the Stage. + * `:id`: the unique id for the Stage. + """ + @type t :: %__MODULE__{ + id: reference(), + expr: %{__struct__: Nx.Defn.Expr}, + arguments: [%{source: {reference() | nil, non_neg_integer()}}] + } + + defstruct [:id, :expr, :arguments] + end + + @doc """ + Splits the received Nx.Defn.Expr into stages given the rules. + + `expr_split_fn` is a function that receives an `Nx.Tensor` containing an `Nx.Defn.Expr` + and returns `true` when a split must happen, and `false` otherwise. + + ## Examples + + iex> expr = Nx.Defn.debug_expr(fn x, y -> x |> Nx.negate() |> Nx.sin() |> Nx.cos() |> Nx.add(y) end).(1, 2) + iex> [stage0, stage1] = Nx.Defn.Graph.split(expr, fn %Nx.Tensor{data: %Nx.Defn.Expr{op: op}} -> op == :cos end) + iex> {out0} = stage0.expr + iex> out0 + #Nx.Tensor< + f32 + \n\ + Nx.Defn.Expr + parameter a:0 s32 + b = negate a s32 + c = sin b f32 + > + iex> stage1.expr + #Nx.Tensor< + f32 + \n\ + Nx.Defn.Expr + parameter a:1 f32 + parameter c:0 s32 + b = cos a f32 + d = add b, c f32 + > + """ + def split(expr, expr_split_fn) when is_function(expr_split_fn, 1) do + {chain, _, _} = __split__(expr, expr_split_fn) + chain + end + + @doc """ + Executes the stage chain with the given arguments. + """ + def run(chain, args) do + scope = + Enum.with_index(args, fn arg, idx -> {{nil, idx}, arg} end) + |> Map.new() + + {result, _scope} = + Enum.reduce(chain, {nil, scope}, fn stage, {_result, scope} -> + %{id: id, expr: expr, arguments: arguments} = stage + + args = + Enum.map(arguments, fn %{source: source} -> + Map.fetch!(scope, source) + end) + + case Nx.Defn.jit_apply(fn _ -> expr end, [List.to_tuple(args)]) do + %T{} = tensor -> + {tensor, Map.put(scope, {id, 0}, tensor)} + + tuple -> + {_idx, scope} = + tuple + |> Tuple.to_list() + |> Enum.reduce({0, scope}, fn tensor, {idx, scope} -> + {idx + 1, Map.put(scope, {id, idx}, tensor)} + end) + + {tuple, scope} + end + end) + + result + end + + @doc false + def __split__(expr, expr_split_fn) do + # state.expression_chain is a reverse accumulation of the stages and + # snapshots of the state at each one so that we can properly remap parameters for each stage. + state = %{ + expression_chain: [], + nodes_to_replace: %{}, + expr_split_fn: expr_split_fn, + # args is a map of id -> {stage_id, output_container_position} + args: %{} + } + + cache = %{} + {expr, {cache, state}} = composite_eval(expr, state, cache) + + expr_chain = + Enum.reduce( + [{make_ref(), expr, state.nodes_to_replace} | state.expression_chain], + [], + fn {id, expr, nodes_to_replace}, acc -> + # TO-DO: we need to also do a pass to avoid recalculating results that have been previously calculated. + # For example: + # x = arg0 + arg1 + # y = arg0 - arg1 + # z = x + y + # ----- + # w = dot(z, arg1) + # y + w <- here, we currently have to recalculate y given that only z, arg0 and arg1 will be passed as arguments. + # ideally, we should also pass y as a value to avoid recalculating it. + # We might be able to calculate this in the first traversal somehow. + + {expr, %{used_args: used_args}} = + composite_rewrite_subtree( + expr, + %{state | nodes_to_replace: nodes_to_replace} + ) + + arg_remapping = + used_args + |> Enum.sort_by(fn {_id, %T{data: %Expr{op: :parameter, args: [idx]}}} -> idx end) + |> Enum.with_index(fn + {id, expr}, idx -> + {id, put_in(expr.data.args, [idx])} + end) + |> Map.new() + + {expr, _} = + composite_rewrite_subtree(expr, %{state | nodes_to_replace: arg_remapping}) + + arguments = + arg_remapping + |> Enum.map(fn {_id, arg_expr} -> + id = arg_expr.data.id + [idx] = arg_expr.data.args + source = Map.fetch!(state.args, id) + {idx, %{source: source}} + end) + |> Enum.sort_by(fn {idx, _} -> idx end) + |> Enum.map(fn {_, arg} -> arg end) + + [ + %Stage{ + id: id, + expr: expr, + arguments: arguments + } + | acc + ] + end + ) + + {expr_chain, cache, Map.delete(state, :expression_chain)} + end + + defp composite_eval(expr, state, cache) do + Composite.traverse(expr, {cache, state}, &eval/2) + end + + defp eval(%T{data: %Expr{id: id, op: op}} = ans, {cache, state}) do + case {cache, state.nodes_to_replace} do + {_, %{^id => res}} -> + # Replace the node with the corresponding parameter + {res, {Map.put(cache, id, res), state}} + + {%{^id => res}, _} -> + {res, {cache, state}} + + _ -> + if state.expr_split_fn.(ans) do + split_expr(ans, {cache, state}) + else + eval_apply(op, ans, {cache, state}) + end + end + end + + defp eval(other, {cache, state}) do + {other, {cache, state}} + end + + defp split_expr(expr, {cache, state}) do + {args, {cache, state}} = Nx.Defn.Tree.apply_args(expr, {cache, state}, &eval/2) + # We need to save this so that each previous stage + # isn't affected by following ones + nodes_to_replace = state.nodes_to_replace + + stage_id = make_ref() + + {args, {tensor_args, _out_position, state}} = + Enum.map_reduce(args, {[], 0, state}, fn + %T{} = expr, {tensor_args, out_position, state} -> + arg = Expr.parameter(expr, map_size(state.args)) + + state = %{ + state + | args: Map.put(state.args, arg.data.id, {stage_id, out_position}), + nodes_to_replace: Map.put(state.nodes_to_replace, expr.data.id, arg) + } + + {arg, {[expr | tensor_args], out_position + 1, state}} + + non_tensor_arg, acc -> + {non_tensor_arg, acc} + end) + + new_expr = put_in(expr.data.args, args) + + state = + update_in( + state.expression_chain, + &[ + {stage_id, List.to_tuple(Enum.reverse(tensor_args)), nodes_to_replace} + | &1 + ] + ) + + cache = Map.put(cache, new_expr.data.id, new_expr) + + {new_expr, {cache, state}} + end + + defp eval_apply(:parameter, %T{data: %Expr{id: id, args: [idx]}} = expr, {cache, state}) do + state = put_in(state.args[id], {nil, idx}) + {expr, {Map.put(cache, id, expr), state}} + end + + defp eval_apply(:elem, %T{data: %Expr{id: id, args: [tuple, i]}}, {cache, state}) do + {tuple, cache} = composite_eval(tuple, state, cache) + res = elem(tuple, i) + {res, {Map.put(cache, id, res), state}} + end + + defp eval_apply(_op, %T{data: %Expr{id: id}} = ans, {cache, state}) do + {args, {cache, state}} = Nx.Defn.Tree.apply_args(ans, {cache, state}, &eval/2) + ans = put_in(ans.data.args, args) + {ans, {Map.put(cache, id, ans), state}} + end + + defp composite_rewrite_subtree(container, state, acc \\ %{used_args: %{}}) + + defp composite_rewrite_subtree(container, state, acc) when is_list(container) do + Enum.map_reduce(container, acc, fn + %T{} = arg, acc -> + composite_rewrite_subtree(arg, state, acc) + + arg, acc when is_list(arg) -> + composite_rewrite_subtree(arg, state, acc) + + arg, acc -> + {arg, acc} + end) + end + + defp composite_rewrite_subtree(container, state, acc) do + Composite.traverse(container, acc, &rewrite_subtree(&1, state, &2)) + end + + defp rewrite_subtree(%T{data: %Expr{id: id, op: :parameter}} = expr, state, acc) do + case state.nodes_to_replace do + %{^id => res} -> + {res, put_in(acc.used_args[id], res)} + + _ -> + {expr, put_in(acc.used_args[id], expr)} + end + end + + defp rewrite_subtree( + %T{data: %Expr{op: :optional, id: id, args: [call, subexpr, fun]}} = expr, + state, + acc + ) do + case state.nodes_to_replace do + %{^id => res} -> + {res, put_in(acc.used_args[id], res)} + + _ -> + {call, acc} = rewrite_subtree(call, state, acc) + # `subexpr` is hermetic, in the sense that it is a self-contained scope + # from which the arguments always come from `call`, so we can + # keep it as is. + + {put_in(expr.data.args, [call, subexpr, fun]), acc} + end + end + + defp rewrite_subtree(%T{data: %Expr{id: id, args: args}} = expr, state, acc) do + case state.nodes_to_replace do + %{^id => res} -> + # nodes_to_replace always contains a param + {res, put_in(acc.used_args[id], res)} + + _ -> + {args, acc} = composite_rewrite_subtree(args, state, acc) + {put_in(expr.data.args, args), acc} + end + end + + defp rewrite_subtree(other, _, acc), do: {other, acc} +end diff --git a/nx/test/nx/defn/graph_test.exs b/nx/test/nx/defn/graph_test.exs new file mode 100644 index 0000000000..144b413873 --- /dev/null +++ b/nx/test/nx/defn/graph_test.exs @@ -0,0 +1,461 @@ +defmodule Nx.Defn.GraphTest do + use ExUnit.Case, async: true + + alias Nx.Defn.Graph + alias Nx.Defn.Graph.Stage + + alias Nx.Tensor, as: T + alias Nx.Defn.Expr + + doctest Nx.Defn.Graph + + describe "traverse/1" do + test "simple expression with 1 split and no common nodes" do + expr = + Nx.Defn.debug_expr(fn arg0, arg1 -> + x = Nx.add(arg0, arg1) + y = Nx.subtract(arg0, arg1) + z = Nx.dot(x, y) + w = Nx.multiply(z, 2) + Nx.divide(w, 4) + end).(Nx.tensor([1, 2]), Nx.tensor([3, 4])) + + split_fn = fn + %T{data: %Expr{op: :dot}} -> true + _ -> false + end + + {chain, cache, state} = Graph.__split__(expr, split_fn) + + assert [ + %Stage{ + id: stage_0_id, + expr: stage_0_expr, + arguments: stage_0_arguments + }, + %Stage{ + id: _stage_1_id, + expr: stage_1_expr, + arguments: stage_1_arguments + } + ] = chain + + assert [%{source: {nil, 0}}, %{source: {nil, 1}}] == stage_0_arguments + + assert [{2, arg_2_original_node_id, arg_2_id}, {3, arg_3_original_node_id, arg_3_id}] = + state.nodes_to_replace + |> Enum.map(fn {original_node_id, + %T{data: %Expr{id: id, op: :parameter, args: [idx]}}} -> + {idx, original_node_id, id} + end) + |> Enum.sort() + + # ensure that arg2 and arg3 map to the correct stage and output container position + assert [%{source: {stage_0_id, 0}}, %{source: {stage_0_id, 1}}] == stage_1_arguments + + # ensure that arg2 and arg3 are replacing the correct nodes + {_dot_node_id, %T{data: %Expr{args: [dot_arg_0, _, _, dot_arg_1, _, _]}}} = + Enum.find(cache, fn + {_, %T{data: %Expr{op: :dot}}} -> true + _ -> false + end) + + assert dot_arg_0.data.id == arg_2_id + assert dot_arg_1.data.id == arg_3_id + + # ensure that the output of the first stage contains the original nodes from dot(x, y) + # also assert on the rough shape for the expression + assert {%T{data: %Expr{id: ^arg_2_original_node_id}} = left, + %T{data: %Expr{id: ^arg_3_original_node_id}} = right} = stage_0_expr + + assert %T{ + data: %Expr{ + op: :add, + args: [ + %T{data: %Expr{op: :parameter, args: [0]}}, + %T{data: %Expr{op: :parameter, args: [1]}} + ] + } + } = left + + assert %T{ + data: %Expr{ + op: :subtract, + args: [ + %T{data: %Expr{op: :parameter, args: [0]}}, + %T{data: %Expr{op: :parameter, args: [1]}} + ] + } + } = right + + assert %T{ + data: %Expr{ + op: :divide, + args: [ + %T{ + data: %Expr{ + op: :multiply, + args: [ + %T{data: %Expr{op: :constant, args: [2]}}, + %T{ + data: %Expr{ + op: :dot, + args: [ + %T{data: %Expr{op: :parameter, args: [0]}}, + [0], + [], + %T{data: %Expr{op: :parameter, args: [1]}}, + [0], + [] + ] + } + } + ] + } + }, + %T{data: %Expr{op: :constant, args: [4]}} + ] + } + } = stage_1_expr + end + + test "expression with 2 splits, common nodes and argument separation" do + expr = + Nx.Defn.debug_expr(fn arg0, arg1, arg2 -> + x = Nx.add(arg0, arg1) + y = Nx.subtract(arg0, arg1) + z = Nx.dot(x, y) + w = Nx.multiply(z, 2) + a = Nx.sum(w) + + a + |> Nx.add(w) + |> Nx.subtract(arg2) + end).(Nx.tensor([[1, 2]]), Nx.tensor([[3], [4]]), Nx.tensor([5, 6])) + + split_fn = fn + %T{data: %Expr{op: :dot}} -> true + %T{data: %Expr{op: :sum}} -> true + _ -> false + end + + {chain, cache, state} = Graph.__split__(expr, split_fn) + + assert [ + %Stage{ + id: stage_0_id, + expr: stage_0_expr, + arguments: stage_0_arguments + }, + %Stage{ + id: stage_1_id, + expr: stage_1_expr, + arguments: stage_1_arguments + }, + %Stage{ + id: _stage_2_id, + expr: stage_2_expr, + arguments: stage_2_arguments + } + ] = chain + + assert [%{source: {nil, 0}}, %{source: {nil, 1}}] == stage_0_arguments + + assert map_size(state.args) == 6 + + original_args = + Enum.reduce(state.args, [], fn {id, _}, acc -> + if node = cache[id] do + [{hd(node.data.args), id} | acc] + else + acc + end + end) + |> Enum.sort() + |> Enum.map(fn {_, id} -> id end) + + [arg_0_id, arg_1_id, arg_2_id] = original_args + + assert [ + {2, arg_3_original_node_id, arg_3_id}, + {3, arg_4_original_node_id, arg_4_id}, + {4, arg_5_original_node_id, arg_5_id} + ] = + state.nodes_to_replace + |> Enum.map(fn {original_node_id, + %T{data: %Expr{id: id, op: :parameter, args: [idx]}}} -> + {idx, original_node_id, id} + end) + |> Enum.sort() + + assert arg_3_id not in original_args + assert arg_4_id not in original_args + assert arg_5_id not in original_args + + # ensure that arg3 and arg4 map to the correct stage and output container position + assert [%{source: {stage_0_id, 0}}, %{source: {stage_0_id, 1}}] == stage_1_arguments + + # ensure that arg3 and arg4 are replacing the correct nodes + {_dot_node_id, %T{data: %Expr{args: [dot_arg_0, _, _, dot_arg_1, _, _]}}} = + Enum.find(cache, fn + {_, %T{data: %Expr{op: :dot}}} -> true + _ -> false + end) + + assert dot_arg_0.data.id == arg_3_id + assert dot_arg_1.data.id == arg_4_id + + # ensure that the output of the first stage contains the original nodes from dot(x, y) + # also assert on the rough shape for the expression + assert {%T{data: %Expr{id: ^arg_3_original_node_id}} = left, + %T{data: %Expr{id: ^arg_4_original_node_id}} = right} = stage_0_expr + + assert %T{ + data: %Expr{ + op: :add, + args: [ + %T{data: %Expr{id: ^arg_0_id, op: :parameter, args: [0]}}, + %T{data: %Expr{id: ^arg_1_id, op: :parameter, args: [1]}} + ] + } + } = left + + assert %T{ + data: %Expr{ + op: :subtract, + args: [ + %T{data: %Expr{id: ^arg_0_id, op: :parameter, args: [0]}}, + %T{data: %Expr{id: ^arg_1_id, op: :parameter, args: [1]}} + ] + } + } = right + + assert {%T{ + data: %Expr{ + id: ^arg_5_original_node_id, + op: :multiply, + args: [ + %T{data: %Expr{op: :constant, args: [2]}}, + %T{ + data: %Expr{ + op: :dot, + args: [ + %T{data: %Expr{op: :parameter, args: [0]}}, + [1], + [], + %T{data: %Expr{op: :parameter, args: [1]}}, + [0], + [] + ] + } + } + ] + } + }} = stage_1_expr + + assert %T{data: %Expr{op: :subtract, args: [c, d]}} = stage_2_expr + assert %T{data: %Expr{op: :add, args: [b, a]}} = c + assert %T{data: %Expr{id: ^arg_2_id, op: :parameter, args: [0]}} = d + assert %T{data: %Expr{op: :sum, args: [^a, [axes: nil, keep_axes: false]]}} = b + assert %T{data: %Expr{id: ^arg_5_id, op: :parameter, args: [1]}} = a + + assert [%{source: {nil, 2}}, %{source: {stage_1_id, 0}}] == stage_2_arguments + end + + test "supports optional callbacks" do + arg0 = + Nx.u8([ + [1, 0, 1], + [1, 1, 1] + ]) + + expr = + Nx.Defn.debug_expr(fn a, b -> + x = Nx.add(b, 1) + y = Nx.sum(x, axes: [1]) + z = Nx.logical_not(y) + Nx.subtract(z, a) + end).(1, arg0) + + split_fn = fn + %T{data: %Expr{op: :sum}} -> true + _ -> false + end + + assert [%Stage{} = stage_0, %Stage{} = stage_1] = Graph.split(expr, split_fn) + + assert stage_0.arguments == [%{source: {nil, 1}}] + assert stage_1.arguments == [%{source: {nil, 0}}, %{source: {stage_0.id, 0}}] + + assert %T{data: %Expr{op: :subtract, args: [c, d]}} = stage_1.expr + assert %T{data: %Expr{op: :optional, args: [call, subexpr, _fun]}} = c + + assert %T{data: %Expr{id: arg_0_id, op: :parameter, args: [0]}} = d + + assert %T{data: %Expr{op: :logical_not, args: [b]}} = call + assert %T{data: %Expr{op: :sum, args: [a, [axes: [1], keep_axes: false]]}} = b + assert %T{data: %Expr{id: arg_1_id, op: :parameter, args: [1]}} = a + + assert %T{ + data: %Expr{ + op: :equal, + args: [ + %T{data: %Expr{id: subexpr_arg_0_id, op: :parameter, args: [0]}}, + %T{data: %Expr{op: :constant, args: [0]}} + ] + } + } = subexpr + + # ensure subexpr is hermetic + assert subexpr_arg_0_id != arg_0_id + assert subexpr_arg_0_id != arg_1_id + end + + test "supports in-line anonymous functions" do + arg0 = + Nx.u8([ + [1, 0, 1], + [1, 1, 1] + ]) + + expr = + Nx.Defn.debug_expr(fn a, b -> + x = Nx.add(b, 1) + y = Nx.sum(x, axes: [1]) + f = fn a -> Nx.equal(a, 0) end + z = f.(y) + Nx.subtract(z, a) + end).(1, arg0) + + split_fn = fn + %T{data: %Expr{op: :sum}} -> true + _ -> false + end + + assert [%Stage{} = stage_0, %Stage{} = stage_1] = Graph.split(expr, split_fn) + + assert [%{source: {nil, 1}}] == stage_0.arguments + + assert [%{source: {nil, 0}}, %{source: {stage_0.id, 0}}] == stage_1.arguments + assert %T{data: %Expr{op: :subtract, args: [c, d]}} = stage_1.expr + + assert %T{ + data: %Expr{ + op: :equal, + args: [ + left, + %T{data: %Expr{op: :constant, args: [0]}} + ] + } + } = c + + assert %T{data: %Expr{op: :parameter, args: [0]}} = d + + assert %T{data: %Expr{op: :sum, args: [a, [axes: [1], keep_axes: false]]}} = left + assert %T{data: %Expr{op: :parameter, args: [1]}} = a + end + end + + describe "run/2" do + test "executes the stages chain and returns the correct result" do + function = fn arg0, arg1 -> + # root + x = Nx.multiply(arg0, arg1) |> Nx.Defn.Expr.metadata(%{split: true}) + + # left side + w_left = Nx.multiply(x, arg1) |> Nx.Defn.Expr.metadata(%{split: true}) + + # right side + w_right = Nx.divide(x, arg1) |> Nx.Defn.Expr.metadata(%{split: true}) + + # merge + Nx.add(w_right, w_left) + end + + args = [Nx.tensor([1, 2]), Nx.tensor([3, 4])] + + # This is used in the final assertion of this test + expected_result = Nx.Defn.jit_apply(function, args) + + expr = apply(Nx.Defn.debug_expr(function), args) + + split_fn = fn + %T{data: %Expr{op: :metadata, args: [_expr, %{split: true}]}} -> true + _ -> false + end + + chain = Graph.split(expr, split_fn) + + assert [root, right, left, merge] = chain + + assert {%T{data: %Expr{op: :multiply, args: [arg0, arg1]}}} = root.expr + assert %T{data: %Expr{op: :parameter, args: [0]}} = arg0 + assert %T{data: %Expr{op: :parameter, args: [1]}} = arg1 + + # left should depend on exactly the same parameters as the root, as it's pulling from + # the global scope + assert {%T{data: %Expr{op: :multiply, args: [x, arg1_left]}}} = left.expr + + assert %T{ + data: %Expr{ + op: :metadata, + args: [ + %T{data: %Expr{op: :parameter, args: [1]}}, + %{split: true} + ] + } + } = x + + assert %T{data: %Expr{op: :parameter, args: [0]}} = arg1_left + + assert Enum.fetch!(left.arguments, 0).source == {nil, 1} + assert Enum.fetch!(left.arguments, 1).source == {root.id, 0} + + # right should depend on the result of the root and on arg1, but arg1 will be reindexed + # we should assert that the argument source for arg1_right is correct + assert {%T{data: %Expr{op: :divide, args: [x, arg1_right]}}} = right.expr + + assert %T{ + data: %Expr{ + op: :metadata, + args: [ + %T{data: %Expr{op: :parameter, args: [1]}}, + %{split: true} + ] + } + } = x + + assert %T{data: %Expr{op: :parameter, args: [0]}} = arg1_right + + assert Enum.fetch!(right.arguments, 0).source == {nil, 1} + assert Enum.fetch!(right.arguments, 1).source == {root.id, 0} + + assert %T{data: %Expr{op: :add, args: [w_right, w_left]}} = merge.expr + + assert %T{ + data: %Expr{ + op: :metadata, + args: [ + %T{data: %Expr{op: :parameter, args: [0]}}, + %{split: true} + ] + } + } = w_right + + assert %T{ + data: %Expr{ + op: :metadata, + args: [ + %T{data: %Expr{op: :parameter, args: [1]}}, + %{split: true} + ] + } + } = w_left + + assert Enum.fetch!(merge.arguments, 0).source == {right.id, 0} + assert Enum.fetch!(merge.arguments, 1).source == {left.id, 0} + + assert Graph.run(chain, args) == expected_result + end + end +end