feat: Nx.Defn.Graph #1544

Merged · 50 commits · Mar 18, 2025
Commits
25788c0
feat: add basic backend layout
polvalente Sep 25, 2024
6585353
proof of concept solution
polvalente Sep 25, 2024
54e7abf
feat: add working POC for sharding
polvalente Sep 26, 2024
5fe92ca
refactor: use only sharding compiler
polvalente Sep 26, 2024
09b844a
refactor: move things into __compile__
polvalente Sep 26, 2024
f6bd4ed
feat: working EXLA example
polvalente Sep 26, 2024
4c8200f
wip: initial work on dot (doesn't work)
polvalente Sep 26, 2024
2f9370a
wip
polvalente Sep 26, 2024
918118e
wip: refactor input sharding calculation
polvalente Sep 28, 2024
bd06388
wip: refactor to support broadcasts
polvalente Sep 28, 2024
5d7b106
wip: rework sharding representation (each slice is a shard)
polvalente Sep 29, 2024
1f7dfb3
feat: deal with broadcasting and re-slicing
polvalente Sep 29, 2024
243eee2
refactor: build parents tree into each shard
polvalente Oct 10, 2024
4bb2d97
feat: support implicit broadcasting
polvalente Oct 10, 2024
87f1c35
chore: remove unused var'
polvalente Oct 10, 2024
bca3943
feat: support dot product without contraction sharding
polvalente Oct 10, 2024
73a3553
feat: support constants
polvalente Oct 10, 2024
52fc860
add :tensor op to example function
polvalente Oct 10, 2024
7706678
feat: support squeeze
polvalente Oct 10, 2024
0ddb69a
chore: remove empty file
polvalente Oct 10, 2024
489d655
refactor: remove TensorSharding module
polvalente Oct 11, 2024
42b5ca6
chore: add stubs for missing callbacks
polvalente Oct 11, 2024
59034c1
test: add tests
polvalente Oct 11, 2024
d6da7a8
fix: transpose axis
polvalente Oct 11, 2024
68565bc
chore: remove example .exs files
polvalente Oct 11, 2024
fb86713
Update nx/lib/nx/defn/sharding_compiler.ex
polvalente Oct 14, 2024
0c111df
chore: format
polvalente Oct 14, 2024
5c5c881
chore: remove __stream__
polvalente Oct 14, 2024
5e05bbb
Merge remote-tracking branch 'origin/main' into pv-feat/experimental-…
polvalente Oct 15, 2024
6eb6fba
feat: add graph splitter for all-gather/all-reduce operations (#1545)
polvalente Oct 17, 2024
dce2c60
feat: add shard execution workflow (#1557)
polvalente Nov 28, 2024
7a2071e
Merge remote-tracking branch 'origin/main' into pv-feat/experimental-…
polvalente Mar 14, 2025
d4641de
refactor: remove sharding in favor of graph splits
polvalente Mar 14, 2025
28e1c65
refactor: only return the chain in the public interface
polvalente Mar 14, 2025
f7bb062
chore: clean up even more code
polvalente Mar 14, 2025
9a5d84a
feat: apply topsort to expr chain
polvalente Mar 14, 2025
e3c68d0
chore: style
polvalente Mar 14, 2025
384c770
feat: add GraphSplitter.run/2
polvalente Mar 16, 2025
76ef352
refactor: remove topsort because expr chain is already sorted as it's…
polvalente Mar 16, 2025
3a66244
refactor: simplify argument mapping representation
polvalente Mar 17, 2025
591b6f4
refactor: unify arguments and argument sources
polvalente Mar 17, 2025
b8dd06a
refactor: represent arguments as a list
polvalente Mar 17, 2025
d63bba1
refactor: return final results from Enum.reduce directly
polvalente Mar 17, 2025
cd9014a
chore: rename to __traverse__
polvalente Mar 17, 2025
1c0ca22
docs: add example
polvalente Mar 17, 2025
f7e01d1
docs: fix outdated comment
polvalente Mar 17, 2025
867d268
chore: simplify result accumulation
polvalente Mar 17, 2025
ab1adaa
docs formatting
polvalente Mar 17, 2025
f6090c3
docs: add more docs
polvalente Mar 17, 2025
8cb72cc
Apply suggestions from code review
polvalente Mar 18, 2025
327 changes: 327 additions & 0 deletions nx/lib/nx/defn/graph.ex
@@ -0,0 +1,327 @@
defmodule Nx.Defn.Graph do
  @moduledoc """
  A module for splitting `Nx.Defn.Expr` into stages.

  This module is used to split an `Nx.Defn.Expr` into stages, which are then
  executed in a chain.

  `split/2` and `t:Stage.t()` describe how the graph is split and what the
  expected result looks like.

  `run/2` executes the given graph against the provided arguments in a sequential manner.
  """
  alias Nx.Defn.Composite

  alias Nx.Tensor, as: T
  alias Nx.Defn.Expr

  defmodule Stage do
    @typedoc """
    A stage in the graph splitter.

      * `:arguments`: a list of maps that point to the source from which to fetch the
        corresponding value for the given argument.
      * `:expr`: the expression that represents the computation for the Stage.
      * `:id`: the unique id for the Stage.
    """
    @type t :: %__MODULE__{
            id: reference(),
            expr: %{__struct__: Nx.Defn.Expr},
            arguments: [%{source: {reference() | nil, non_neg_integer()}}]
          }

    defstruct [:id, :expr, :arguments]
  end
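
  # For illustration, a stage produced by `split/2` has roughly this shape
  # (the names below are placeholders; ids are references created at split
  # time, a {nil, idx} source reads the idx-th function argument, and a
  # {stage_id, idx} source reads the idx-th output of a previous stage):
  #
  #     %Stage{
  #       id: stage_ref,
  #       expr: expr_for_this_stage,
  #       arguments: [%{source: {nil, 0}}, %{source: {previous_stage_ref, 1}}]
  #     }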

  @doc """
  Splits the received `Nx.Defn.Expr` into stages according to the given split rule.

  `expr_split_fn` is a function that receives an `Nx.Tensor` containing an `Nx.Defn.Expr`
  and returns `true` when a split must happen, and `false` otherwise.

  ## Examples

      iex> expr = Nx.Defn.debug_expr(fn x, y -> x |> Nx.negate() |> Nx.sin() |> Nx.cos() |> Nx.add(y) end).(1, 2)
      iex> [stage0, stage1] = Nx.Defn.Graph.split(expr, fn %Nx.Tensor{data: %Nx.Defn.Expr{op: op}} -> op == :cos end)
      iex> {out0} = stage0.expr
      iex> out0
      #Nx.Tensor<
        f32
      \n\
        Nx.Defn.Expr
        parameter a:0 s32
        b = negate a s32
        c = sin b f32
      >
      iex> stage1.expr
      #Nx.Tensor<
        f32
      \n\
        Nx.Defn.Expr
        parameter a:1 f32
        parameter c:0 s32
        b = cos a f32
        d = add b, c f32
      >
  """
  def split(expr, expr_split_fn) when is_function(expr_split_fn, 1) do
    {chain, _, _} = __split__(expr, expr_split_fn)
    chain
  end

  @doc """
  Executes the stage chain with the given arguments.
  """
  def run(chain, args) do
    scope =
      Enum.with_index(args, fn arg, idx -> {{nil, idx}, arg} end)
      |> Map.new()

    {result, _scope} =
      Enum.reduce(chain, {nil, scope}, fn stage, {_result, scope} ->
        %{id: id, expr: expr, arguments: arguments} = stage

        args =
          Enum.map(arguments, fn %{source: source} ->
            Map.fetch!(scope, source)
          end)

        case Nx.Defn.jit_apply(fn _ -> expr end, [List.to_tuple(args)]) do
          %T{} = tensor ->
            {tensor, Map.put(scope, {id, 0}, tensor)}

          tuple ->
            {_idx, scope} =
              tuple
              |> Tuple.to_list()
              |> Enum.reduce({0, scope}, fn tensor, {idx, scope} ->
                {idx + 1, Map.put(scope, {id, idx}, tensor)}
              end)

            {tuple, scope}
        end
      end)

    result
  end
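
  # The scope above threads values between stages by source key. After running
  # a two-stage chain it has roughly this shape (references abbreviated):
  #
  #     %{
  #       {nil, 0} => arg0,                # original arguments
  #       {nil, 1} => arg1,
  #       {stage0_id, 0} => stage0_out0,   # outputs of stage 0
  #       {stage1_id, 0} => stage1_out0    # outputs of stage 1
  #     }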

  @doc false
  def __split__(expr, expr_split_fn) do
    # state.expression_chain is a reverse accumulation of the stages and
    # snapshots of the state at each one, so that we can properly remap
    # parameters for each stage.
    state = %{
      expression_chain: [],
      nodes_to_replace: %{},
      expr_split_fn: expr_split_fn,
      # args is a map of id -> {stage_id, output_container_position}
      args: %{}
    }

    cache = %{}
    {expr, {cache, state}} = composite_eval(expr, state, cache)

    expr_chain =
      Enum.reduce(
        [{make_ref(), expr, state.nodes_to_replace} | state.expression_chain],
        [],
        fn {id, expr, nodes_to_replace}, acc ->
          # TO-DO: we need to also do a pass to avoid recalculating results that have been previously calculated.
          # For example:
          # x = arg0 + arg1
          # y = arg0 - arg1
          # z = x + y
          # -----
          # w = dot(z, arg1)
          # y + w <- here, we currently have to recalculate y given that only z, arg0 and arg1 will be passed as arguments.
          # Ideally, we should also pass y as a value to avoid recalculating it.
          # We might be able to calculate this in the first traversal somehow.

          {expr, %{used_args: used_args}} =
            composite_rewrite_subtree(
              expr,
              %{state | nodes_to_replace: nodes_to_replace}
            )

          arg_remapping =
            used_args
            |> Enum.sort_by(fn {_id, %T{data: %Expr{op: :parameter, args: [idx]}}} -> idx end)
            |> Enum.with_index(fn
              {id, expr}, idx ->
                {id, put_in(expr.data.args, [idx])}
            end)
            |> Map.new()

          {expr, _} =
            composite_rewrite_subtree(expr, %{state | nodes_to_replace: arg_remapping})

          arguments =
            arg_remapping
            |> Enum.map(fn {_id, arg_expr} ->
              id = arg_expr.data.id
              [idx] = arg_expr.data.args
              source = Map.fetch!(state.args, id)
              {idx, %{source: source}}
            end)
            |> Enum.sort_by(fn {idx, _} -> idx end)
            |> Enum.map(fn {_, arg} -> arg end)

          [
            %Stage{
              id: id,
              expr: expr,
              arguments: arguments
            }
            | acc
          ]
        end
      )

    {expr_chain, cache, Map.delete(state, :expression_chain)}
  end

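  # To summarize the reduce in __split__: the chain is rebuilt from the last
  # stage to the first. For each stage we (1) substitute split-out nodes with
  # their replacement parameters, (2) renumber the used parameters into a
  # dense 0..n-1 range, and (3) record, for each parameter, the
  # {stage_id, index} source its value must be fetched from at run time.
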
  defp composite_eval(expr, state, cache) do
    Composite.traverse(expr, {cache, state}, &eval/2)
  end

  defp eval(%T{data: %Expr{id: id, op: op}} = ans, {cache, state}) do
    case {cache, state.nodes_to_replace} do
      {_, %{^id => res}} ->
        # Replace the node with the corresponding parameter
        {res, {Map.put(cache, id, res), state}}

      {%{^id => res}, _} ->
        {res, {cache, state}}

      _ ->
        if state.expr_split_fn.(ans) do
          split_expr(ans, {cache, state})
        else
          eval_apply(op, ans, {cache, state})
        end
    end
  end

  defp eval(other, {cache, state}) do
    {other, {cache, state}}
  end

  defp split_expr(expr, {cache, state}) do
    {args, {cache, state}} = Nx.Defn.Tree.apply_args(expr, {cache, state}, &eval/2)
    # We need to save this so that each previous stage
    # isn't affected by following ones
    nodes_to_replace = state.nodes_to_replace

    stage_id = make_ref()

    {args, {tensor_args, _out_position, state}} =
      Enum.map_reduce(args, {[], 0, state}, fn
        %T{} = expr, {tensor_args, out_position, state} ->
          arg = Expr.parameter(expr, map_size(state.args))

          state = %{
            state
            | args: Map.put(state.args, arg.data.id, {stage_id, out_position}),
              nodes_to_replace: Map.put(state.nodes_to_replace, expr.data.id, arg)
          }

          {arg, {[expr | tensor_args], out_position + 1, state}}

        non_tensor_arg, acc ->
          {non_tensor_arg, acc}
      end)

    new_expr = put_in(expr.data.args, args)

    state =
      update_in(
        state.expression_chain,
        &[
          {stage_id, List.to_tuple(Enum.reverse(tensor_args)), nodes_to_replace}
          | &1
        ]
      )

    cache = Map.put(cache, new_expr.data.id, new_expr)

    {new_expr, {cache, state}}
  end
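
  # In other words: when the split predicate matches a node, each tensor
  # argument of that node becomes an output of the stage being closed, and the
  # matched node itself starts the next stage with fresh parameters pointing
  # at those outputs. E.g. splitting at :cos in add(cos(x), y): the first
  # stage ends at x's subtree, and cos's input is rewritten to a parameter
  # sourced from that stage.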

  defp eval_apply(:parameter, %T{data: %Expr{id: id, args: [idx]}} = expr, {cache, state}) do
    state = put_in(state.args[id], {nil, idx})
    {expr, {Map.put(cache, id, expr), state}}
  end

  defp eval_apply(:elem, %T{data: %Expr{id: id, args: [tuple, i]}}, {cache, state}) do
    # composite_eval returns {result, {cache, state}}, so destructure both
    # before caching the extracted element
    {tuple, {cache, state}} = composite_eval(tuple, state, cache)
    res = elem(tuple, i)
    {res, {Map.put(cache, id, res), state}}
  end

  defp eval_apply(_op, %T{data: %Expr{id: id}} = ans, {cache, state}) do
    {args, {cache, state}} = Nx.Defn.Tree.apply_args(ans, {cache, state}, &eval/2)
    ans = put_in(ans.data.args, args)
    {ans, {Map.put(cache, id, ans), state}}
  end

  defp composite_rewrite_subtree(container, state, acc \\ %{used_args: %{}})

  defp composite_rewrite_subtree(container, state, acc) when is_list(container) do
    Enum.map_reduce(container, acc, fn
      %T{} = arg, acc ->
        composite_rewrite_subtree(arg, state, acc)

      arg, acc when is_list(arg) ->
        composite_rewrite_subtree(arg, state, acc)

      arg, acc ->
        {arg, acc}
    end)
  end

  defp composite_rewrite_subtree(container, state, acc) do
    Composite.traverse(container, acc, &rewrite_subtree(&1, state, &2))
  end

  defp rewrite_subtree(%T{data: %Expr{id: id, op: :parameter}} = expr, state, acc) do
    case state.nodes_to_replace do
      %{^id => res} ->
        {res, put_in(acc.used_args[id], res)}

      _ ->
        {expr, put_in(acc.used_args[id], expr)}
    end
  end

  defp rewrite_subtree(
         %T{data: %Expr{op: :optional, id: id, args: [call, subexpr, fun]}} = expr,
         state,
         acc
       ) do
    case state.nodes_to_replace do
      %{^id => res} ->
        {res, put_in(acc.used_args[id], res)}

      _ ->
        {call, acc} = rewrite_subtree(call, state, acc)
        # `subexpr` is hermetic: it is a self-contained scope whose arguments
        # always come from `call`, so we can keep it as is.
        {put_in(expr.data.args, [call, subexpr, fun]), acc}
    end
  end

  defp rewrite_subtree(%T{data: %Expr{id: id, args: args}} = expr, state, acc) do
    case state.nodes_to_replace do
      %{^id => res} ->
        # nodes_to_replace always contains a param
        {res, put_in(acc.used_args[id], res)}

      _ ->
        {args, acc} = composite_rewrite_subtree(args, state, acc)
        {put_in(expr.data.args, args), acc}
    end
  end

  defp rewrite_subtree(other, _, acc), do: {other, acc}
end