
Conversation

@bdhirsh (Contributor) commented Oct 29, 2025

This PR adds a graph pass that splits the backward graph into subgraphs for computing input gradients and weight gradients. I made my changes on top of Sanket's PR here: #205 (we can also just abandon this PR and fold the changes in as necessary).

In the example, here are the relevant backward graphs:

  • full backward: P2012417569
  • bw inputs only: P2012419228
  • bw weights only: P2012420054

I ended up implementing the partitioning by just re-using the default partitioner. We could also consider writing our own one-off pass, but the default partitioner seemed to mostly work out of the box. The two changes I had to make were:

(1) removing AC recompute metadata from the backward graph, so the default partitioner doesn't think we are trying to run AC. This metadata shouldn't matter for now, although I could imagine a version of the dI/dW subgraph split where some intermediates computed during dI need to be recomputed in the dW subgraph.

(2) flipping the order of the outputs of the backward graph so it is (*grad_inputs, *grad_weights). This made it easier to re-use the default partitioner, which expects the "first" graph that gets split out to map to the first K outputs of the original graph.
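
For reference, here is a minimal sketch of those two pre-processing steps as an FX pass. This is illustrative only: the "recompute" metadata key is an assumption about what the default partitioner inspects, and the helper name is hypothetical, not from this PR.

import torch.fx as fx


def _prepare_bw_for_partitioner(
    bw_gm: fx.GraphModule, num_weight_gradients: int
) -> fx.GraphModule:
    # (1) strip AC recompute metadata so the default partitioner does not
    # think this graph is doing activation checkpointing
    for n in bw_gm.graph.nodes:
        n.meta.pop("recompute", None)  # assumed metadata key

    # (2) flip the outputs from (*grad_weights, *grad_inputs) to
    # (*grad_inputs, *grad_weights); the partitioner maps the "first"
    # split graph to the first K outputs of the original graph
    output_node = next(n for n in bw_gm.graph.nodes if n.op == "output")
    outs = list(output_node.args[0])
    output_node.args = (
        tuple(outs[num_weight_gradients:] + outs[:num_weight_gradients]),
    )

    bw_gm.graph.lint()
    bw_gm.recompile()
    return bw_gm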

The meta-cla bot added the CLA Signed label on Oct 29, 2025.
@@ -0,0 +1,105 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
@bdhirsh (Contributor, Author):

The only real changes in this PR are example_llama3_di_dw.py and split_di_dw_graph.py; the rest is an artifact of me working on top of #205 and making this PR against main. I can move this graph pass somewhere else and/or wait for the other PR to land, as necessary.

Contributor:

I would say that it would be good to start landing things instead of working from branches of branches.

@bdhirsh (Contributor, Author):

Agreed. I was prioritizing "having a graph pass that Sanket can dump into his pipeline runtime" over having a PR that is ready to land. I'm happy to clean up this example to make it more self-contained and turn it into a proper test.


def input_fn():
    # 8192 for 70B
    x = torch.randn(batch_size, seqlen, 4096, device=device, requires_grad=True)
@bdhirsh (Contributor, Author):

In my example, I need the input to the forward/backward region we are tracing to require grad, so that I can get a backward graph that actually computes input gradients. In the real pipeline example, this will probably work by using the pipeline graph-split APIs to get a subgraph for each model layer, whose input should already require grad.
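
As a minimal illustration of that requirement (toy model, not from this PR): if the input does not require grad, autograd produces no input gradient at all, so a traced backward graph would have no dI outputs to split out.

import torch

lin = torch.nn.Linear(4, 4)

x = torch.randn(2, 4)  # requires_grad=False by default
lin(x).sum().backward()
print(x.grad)  # None: no input gradient is computed

x = torch.randn(2, 4, requires_grad=True)
lin(x).sum().backward()
print(x.grad.shape)  # torch.Size([2, 4]): dI exists now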

    return len(grad_inputs)


# TODO: in theory we can infer num_weight_gradients from the graph metadata directly
def split_di_dw_graph(
    bw_gm: fx.GraphModule, *, num_weight_gradients
) -> tuple[fx.GraphModule, fx.GraphModule]:
Contributor:

Help me understand what bw_gm is; is this ONLY the linear?

@bdhirsh (Contributor, Author):

So my understanding is that bw_gm here should be the entire backward graph of a given pipeline stage. A few things to note about this specific example file, though:

(1) I am not using any kind of pipeline-carving logic to get a fw/bw graph of a single stage; my dumb test just generates a fw/bw graph of the entire user model. I figured this is more self-contained, but it would be good to have an e2e test using the pipeline splitting logic too.

(2) technically, my code currently performs the graph split after Ivan's graph pass that splits the FSDP allgathers out of the fw and bw into a separate epilogue. In practice the order should probably not matter here, but one thing I noticed is that after running his pass, the tangents_X inputs to the backward no longer show up in the main graph (I left a comment here: #201 (comment)).

# in the AOTAutograd bw graph, the first K outputs correspond
# to gradients for any params/buffers in the user model
num_weight_gradients = autop.joint_with_descriptors._aot_state.aot_config.num_params_buffers
main_b_gm_di, main_b_gm_dw = split_di_dw_graph(main_b_gm, num_weight_gradients=num_weight_gradients)
Contributor:

Nope, it's the full main backward graph.

@bdhirsh (Contributor, Author):

Hmm, what do you mean? The flow I have here is:

(1) get a separate forward and backward graph of a given stage, operating on a microbatch. This is where the example in this file is not really correct: this test file runs autoparallel on the entire model and grabs the joint/fw/bw graphs, but really it should correspond to the fw and bw graph of a single stage in the pipeline. Hopefully once we can integrate with @sanketpurandare's pipeline runtime we can operate on something more realistic.

(2) split off the FSDP allgathers / reduce scatters from this main fw and bw graph

(3) [this graph pass]: take the backward graph operating on a micro batch, and split it into separate subgraphs to compute input gradients and weight gradients.
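
A rough sketch of that three-step flow (the stage-graph and FSDP-split helpers here are hypothetical placeholders; only split_di_dw_graph comes from this PR):

# (1) fw/bw graphs for one pipeline stage, one microbatch (hypothetical helper)
fw_gm, bw_gm = get_stage_fw_bw_graphs(stage_mod, microbatch)

# (2) split the FSDP allgathers / reduce-scatters out of the main backward
# (hypothetical name for Ivan's pass)
main_bw_gm, fsdp_epilogue = split_fsdp_collectives(bw_gm)

# (3) this PR: split the microbatch backward into dI and dW subgraphs
bw_di_gm, bw_dw_gm = split_di_dw_graph(
    main_bw_gm, num_weight_gradients=num_weight_gradients
)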

@bdhirsh (Contributor, Author):

nvm, I read this comment before your comment above.

assert len(outputs) == 1
output = outputs[0]
assert isinstance(output.args[0], tuple)
grad_weights, grad_inputs = output.args[0][:num_weight_gradients], output.args[0][num_weight_gradients:]
Contributor:

Why not use descriptors here?

@bdhirsh (Contributor, Author):

I should definitely use descriptors. I mostly wanted to have something Sanket could try out quickly, but I'm happy to refactor.

@fmassa (Contributor) left a comment:

This LGTM, thanks @bdhirsh!

I'm ok landing the PR as is, but I think it would be good to convert the example into a test.

Also, I'll let you decide wrt Ed's comment on descriptors; it might indeed be better to use them in the long run.


Comment on lines +139 to +219
def group_mm_nodes_with_its_gradients(nodes):
    fwd_nodes = [n for n in nodes if "nn_module_stack" in n.meta]
    bwd_nodes = [n for n in nodes if "fwd_nn_module_stack" in n.meta]
    assert len(fwd_nodes) * 2 == len(bwd_nodes)
    res = {}
    for fwd_node in fwd_nodes:
        o = []
        for bwd_node in bwd_nodes:
            if fwd_node.meta["nn_module_stack"] == bwd_node.meta["fwd_nn_module_stack"]:
                o.append(bwd_node)
        assert len(o) == 2
        res[fwd_node] = o
    return res


def force_tp_constraints(autop, mm_nodes, feat_dim=1, bwd_constraint=False):
    # out = x @ w - S(0)R, RS(1) -> S(0)S(1)
    # g_w = g.T @ x - S(1)S(0), S(0)R -> PS(0)
    # g_x = g @ w.T - S(0)S(1), RS(0) -> S(0)P

    add_node_constraint = autop.sharding_optimizer.add_node_constraint
    fwd_bwd_groups = group_mm_nodes_with_its_gradients(mm_nodes)
    fwd_nodes = list(fwd_bwd_groups.keys())
    dim1 = 0 if feat_dim == 1 else 1
    dim2 = 1 if feat_dim == 1 else 0
    # assume there are 7 mm nodes per transformer block
    # skip last mm as it's the final projection layer
    assert (
        len(fwd_nodes) - 1
    ) % 7 == 0, f"expected 7 mm nodes per transformer block, {len(fwd_nodes) - 1}"
    for block in range(0, len(fwd_nodes) - 1, 7):
        fwd_nodes_block = fwd_nodes[block : block + 7]
        # force the first 3 mm nodes to be S(0)S(1)
        the_nodes = fwd_nodes_block[:3] + fwd_nodes_block[4:6]
        for n in the_nodes:
            add_node_constraint(n, (Shard(0), Shard(feat_dim)))
            add_node_constraint(n.all_input_nodes[0], (Shard(0), Replicate()))
            add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(1)))

            if bwd_constraint:
                bwd_nodes = fwd_bwd_groups[n]
                # first is g_w, second is g_x
                add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim1)))
                add_node_constraint(bwd_nodes[1], (Shard(0), Partial()))

        # add reduction to finish TP, yielding S(0)P
        the_nodes = fwd_nodes_block[3:4] + fwd_nodes_block[6:7]
        for n in the_nodes:
            add_node_constraint(n, (Shard(0), Partial()))
            add_node_constraint(n.all_input_nodes[0], (Shard(0), Shard(feat_dim)))
            add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(0)))

            if bwd_constraint:
                bwd_nodes = fwd_bwd_groups[n]
                # first is g_w, second is g_x
                add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim2)))
                add_node_constraint(bwd_nodes[1], (Shard(0), Shard(feat_dim)))


def add_tp_constraints(autop):
    mm_nodes = autop.gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.mm.default
    )
    einsum_nodes = autop.gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.einsum.default
    )
    assert (len(mm_nodes) > 0) ^ (
        len(einsum_nodes) > 0
    ), f"only one should be non-empty, got {len(mm_nodes)} and {len(einsum_nodes)}"
    feat_dim = 1 if len(mm_nodes) > 0 else 2
    tgt_nodes = mm_nodes + einsum_nodes
    force_tp_constraints(autop, tgt_nodes, feat_dim=feat_dim, bwd_constraint=True)

    if einsum_nodes:
        # add sequence parallelism if we have einsum nodes
        autop.sharding_optimizer.add_node_constraint(
            list(tgt_nodes[3].users)[0], (Shard(0), Shard(1))
        )
        autop.sharding_optimizer.add_node_constraint(
            list(list(tgt_nodes[3].users)[0].users)[0], (Shard(0), Shard(1))
        )
Contributor:

I think we can remove this part of the example (and maybe make it a test?), as it's not really going to be useful here.

Comment on lines +242 to +259
if enable_asynctp:
    from torch.distributed._symmetric_memory import enable_symm_mem_for_group

    enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
    enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
    torch._inductor.config._micro_pipeline_tp = False
    from autoparallel.asynctp import micro_pipeline_tp_pass

    existing_post_grad_custom_post_pass = (
        torch._inductor.config.post_grad_custom_post_pass
    )

    def _pass(graph):
        if existing_post_grad_custom_post_pass is not None:
            existing_post_grad_custom_post_pass(graph)
        micro_pipeline_tp_pass(graph)

    torch._inductor.config.post_grad_custom_post_pass = _pass
Contributor:

What about removing this part as well to keep things simple?
