graph pass dI/dW split example #212
base: main
Conversation
@@ -0,0 +1,105 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
The only real changes in this PR are example_llama3_di_dw.py and split_di_dw_graph.py; the rest is an artifact of me working on top of #205 and making the PR against main. I can move this graph pass somewhere else and/or wait for that other PR to land, as necessary.
I would say that it would be good to start landing things instead of working from branches of branches.
Agreed. I was prioritizing "having a graph pass that Sanket can dump into his pipeline runtime" over having a PR that is ready to land. I'm happy to clean up this example to make it more self-contained and turn it into a proper test.
def input_fn():
    # 8192 for 70B
    x = torch.randn(batch_size, seqlen, 4096, device=device, requires_grad=True)
In my example I need the input to the forward/backward region we are tracing to require grad, so I can properly get a backward graph that computes input gradients. In the real pipeline example, this will probably work by us using the pipeline graph split APIs to get a subgraph for each model layer, whose input should already require grad.
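Aside (not from the PR thread): a tiny standalone sketch of why the traced input needs requires_grad=True.

import torch

# If the region's input does not require grad, autograd produces no input
# gradient, so the traced backward graph would only contain weight gradients.
lin = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
(grad_x,) = torch.autograd.grad(lin(x).sum(), inputs=[x])
# With requires_grad=False the same torch.autograd.grad call would raise,
# because x would not be part of the autograd graph.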
    return len(grad_inputs)


# TODO: in theory we can infer num_weight_gradients from the graph metadata directly
def split_di_dw_graph(bw_gm: fx.GraphModule, *, num_weight_gradients) -> tuple[fx.GraphModule, fx.GraphModule]:
Help me understand what bw_gm is; is this ONLY the linear?
So my understanding is that bw_gm here should be the entire backward graph of a given pipeline stage. A few things to note in this specific example file though:
(1) I am not using any kind of pipeline carving logic to get a fw/bw graph of a single stage; my dumb test is just generating a fw/bw graph of the entire user model. I figured this is more self-contained, but it would be good to have an e2e test using the pipeline splitting logic too.
(2) Technically, my code is currently performing the graph split after Ivan's graph pass that splits out FSDP allgathers from the fw and bw into a separate epilogue. In practice the order should probably not matter here, but one thing I noticed is that after running his pass, tangents_X inputs to the backward no longer show up in the main graph (left a comment here: #201 (comment)).
# in the AOTAutograd bw graph, the first K outputs correspond
# to gradients for any params/buffers in the user model
num_weight_gradients = autop.joint_with_descriptors._aot_state.aot_config.num_params_buffers
main_b_gm_di, main_b_gm_dw = split_di_dw_graph(main_b_gm, num_weight_gradients=num_weight_gradients)
Nope, it's the full main backward graph.
Hmm, what do you mean? The flow I have here is:
(1) get a separate forward and backward graph of a given stage, operating on a microbatch. This is where the example in this file is not really correct: this test file is running autoparallel on the entire model and grabbing the joint / fw / bw graphs, but really this should correspond to the fw and bw graph of a single stage in the pipeline. Hopefully once we can integrate into @sanketpurandare's pipeline runtime we can operate on something more realistic.
(2) split off the FSDP allgathers / reduce scatters from this main fw and bw graph
(3) [this graph pass]: take the backward graph operating on a microbatch, and split it into separate subgraphs to compute input gradients and weight gradients (an eager-mode illustration of this split follows below).
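Aside (not from the PR thread): a minimal eager-mode analogue of step (3), showing that the input gradient and the weight gradients can be produced as two separate computations, which is what the dI/dW graph split makes explicit for a pipeline schedule.

import torch

# Eager-mode analogue of the dI/dW split (illustration only, not the graph pass):
# compute the input gradient first, and defer the weight gradients to a later step.
model = torch.nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)
loss = model(x).sum()

# dI: gradient w.r.t. the stage input, needed right away to feed the previous stage
(grad_x,) = torch.autograd.grad(loss, inputs=[x], retain_graph=True)

# dW: gradients w.r.t. the weights, which the schedule can run later
grad_w, grad_b = torch.autograd.grad(loss, inputs=list(model.parameters()))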
nvm, I read this comment before your comment above.
assert len(outputs) == 1
output = outputs[0]
assert isinstance(output.args[0], tuple)
grad_weights, grad_inputs = output.args[0][:num_weight_gradients], output.args[0][num_weight_gradients:]
Why not use descriptors here?
I should definitely use descriptors. I mostly wanted to have something Sanket could try out quickly, but I'm happy to refactor.
fmassa left a comment
This LGTM, thanks @bdhirsh!
I'm ok landing the PR as is, but I think it would be good to convert the example into a test.
Also, I'll let you decide wrt Ed's comment on descriptors; it might indeed be better to use them in the long run.
def group_mm_nodes_with_its_gradients(nodes):
    fwd_nodes = [n for n in nodes if "nn_module_stack" in n.meta]
    bwd_nodes = [n for n in nodes if "fwd_nn_module_stack" in n.meta]
    assert len(fwd_nodes) * 2 == len(bwd_nodes)
    res = {}
    for fwd_node in fwd_nodes:
        o = []
        for bwd_node in bwd_nodes:
            if fwd_node.meta["nn_module_stack"] == bwd_node.meta["fwd_nn_module_stack"]:
                o.append(bwd_node)
        assert len(o) == 2
        res[fwd_node] = o
    return res


def force_tp_constraints(autop, mm_nodes, feat_dim=1, bwd_constraint=False):
    # out = x @ w - S(0)R, RS(1) -> S(0)S(1)
    # g_w = g.T @ x - S(1)S(0), S(0)R -> PS(0)
    # g_x = g @ w.T - S(0)S(1), RS(0) -> S(0)P

    add_node_constraint = autop.sharding_optimizer.add_node_constraint
    fwd_bwd_groups = group_mm_nodes_with_its_gradients(mm_nodes)
    fwd_nodes = list(fwd_bwd_groups.keys())
    dim1 = 0 if feat_dim == 1 else 1
    dim2 = 1 if feat_dim == 1 else 0
    # assume there are 7 mm nodes per transformer block
    # skip last mm as it's the final projection layer
    assert (
        len(fwd_nodes) - 1
    ) % 7 == 0, f"expected 7 mm nodes per transformer block, {len(fwd_nodes) - 1}"
    for block in range(0, len(fwd_nodes) - 1, 7):
        fwd_nodes_block = fwd_nodes[block : block + 7]
        # force the first 3 mm nodes to be S(0)S(1)
        the_nodes = fwd_nodes_block[:3] + fwd_nodes_block[4:6]
        for n in the_nodes:
            add_node_constraint(n, (Shard(0), Shard(feat_dim)))
            add_node_constraint(n.all_input_nodes[0], (Shard(0), Replicate()))
            add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(1)))

            if bwd_constraint:
                bwd_nodes = fwd_bwd_groups[n]
                # first is g_w, second is g_x
                add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim1)))
                add_node_constraint(bwd_nodes[1], (Shard(0), Partial()))

        # add reduction to finish TP, yielding S(0)P
        the_nodes = fwd_nodes_block[3:4] + fwd_nodes_block[6:7]
        for n in the_nodes:
            add_node_constraint(n, (Shard(0), Partial()))
            add_node_constraint(n.all_input_nodes[0], (Shard(0), Shard(feat_dim)))
            add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(0)))

            if bwd_constraint:
                bwd_nodes = fwd_bwd_groups[n]
                # first is g_w, second is g_x
                add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim2)))
                add_node_constraint(bwd_nodes[1], (Shard(0), Shard(feat_dim)))


def add_tp_constraints(autop):
    mm_nodes = autop.gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.mm.default
    )
    einsum_nodes = autop.gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.einsum.default
    )
    assert (len(mm_nodes) > 0) ^ (
        len(einsum_nodes) > 0
    ), f"only one should be non-empty, got {len(mm_nodes)} and {len(einsum_nodes)}"
    feat_dim = 1 if len(mm_nodes) > 0 else 2
    tgt_nodes = mm_nodes + einsum_nodes
    force_tp_constraints(autop, tgt_nodes, feat_dim=feat_dim, bwd_constraint=True)

    if einsum_nodes:
        # add sequence parallelism if we have einsum nodes
        autop.sharding_optimizer.add_node_constraint(
            list(tgt_nodes[3].users)[0], (Shard(0), Shard(1))
        )
        autop.sharding_optimizer.add_node_constraint(
            list(list(tgt_nodes[3].users)[0].users)[0], (Shard(0), Shard(1))
        )
I think we can remove this part of the example (and maybe make it a test?), as it's not really going to be useful here I think.
if enable_asynctp:
    from torch.distributed._symmetric_memory import enable_symm_mem_for_group

    enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
    enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
    torch._inductor.config._micro_pipeline_tp = False
    from autoparallel.asynctp import micro_pipeline_tp_pass

    existing_post_grad_custom_post_pass = (
        torch._inductor.config.post_grad_custom_post_pass
    )

    def _pass(graph):
        if existing_post_grad_custom_post_pass is not None:
            existing_post_grad_custom_post_pass(graph)
        micro_pipeline_tp_pass(graph)

    torch._inductor.config.post_grad_custom_post_pass = _pass
What about removing this part as well to keep things simple?
Adding a graph pass to split up the backward graph into subgraphs for computing input gradients and weight gradients. I made my changes off of Sanket's PR here: #205 (we can also just abandon this PR and fold them in as necessary).
In the example, here are the relevant backward graphs:
I ended up implementing the partitioning by just re-using the default partitioner. We could also consider writing our own one-off pass, but the default partitioner seemed to mostly work out of the box. The two changes I had to make were (a rough sketch of the approach follows below):
(1) removing AC recompute metadata from the backward graph, so the default partitioner wouldn't think we are trying to run AC. This metadata shouldn't matter for now, although I could imagine a version of the dI/dW subgraph splitting where there are intermediates that we compute during dI that we want to recompute in the dW subgraph.
(2) flipping the order of the outputs of the backward graph, so it is (*grad_inputs, *grad_weights). This made it easier to re-use the default partitioner, which expects that the "first" graph that gets split out maps to the first K outputs of the original graph.
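For illustration, here is a rough, hedged sketch of that approach, not a copy of split_di_dw_graph.py: the meta keys being stripped and the exact partitioner invocation are assumptions based on the description above, and the real pass may differ in its details.

import torch.fx as fx
from torch._functorch.partitioners import default_partition


def split_di_dw_sketch(bw_gm: fx.GraphModule, bw_inputs, num_weight_gradients: int):
    # (1) drop AC/recompute metadata so the partitioner does not treat this as an
    # activation-checkpointing graph (the key names here are assumptions)
    for n in bw_gm.graph.nodes:
        n.meta.pop("recompute", None)
        n.meta.pop("ac_graph_id", None)

    # (2) reorder outputs to (*grad_inputs, *grad_weights): the partitioner maps
    # the first `num_fwd_outputs` outputs to the first subgraph it splits out
    out = next(n for n in bw_gm.graph.nodes if n.op == "output")
    grads = list(out.args[0])
    grad_weights, grad_inputs = grads[:num_weight_gradients], grads[num_weight_gradients:]
    out.args = ((*grad_inputs, *grad_weights),)
    bw_gm.recompile()

    # reuse the default partitioner: the "forward" half computes dI and the
    # "backward" half computes dW from whatever values the dI half saved
    di_gm, dw_gm = default_partition(bw_gm, bw_inputs, num_fwd_outputs=len(grad_inputs))
    return di_gm, dw_gm

With the reordering in place, the dI subgraph corresponds to the first outputs the partitioner splits out, which is the reason for change (2) above.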