Add DeepSeekV3 and figure out what's needed to support it #29
Conversation
autoparallel/propagation_rules.py
Outdated
assert len(op_spec[i].strategies) == len(
    op_spec[0].strategies
), "Assume each cat input has same number of strategies"
# for i in range(1, num_tensors):
Why was this changed? Are we filtering to the least common strategies now? Should we be?
I just commented out those asserts so that I could move forward with what is needed for operator support. That's why I kept the torch.cat operator in the list of operators to implement: I basically butchered things just so it could run up to the solver construction.
The following three ops are already registered. Did you notice any issues with those?
IIRC, they assumed that the number of strategies for every input was the same, or something like that, and it caused issues in the past. It would be good to double-check whether things work as expected nowadays.

EDIT: I commented out the custom registration for those three ops, and here are the issues I found:
- aten.cat.default: (at least) the redistribute costs are missing, only one output strategy is proposed, and there is no tensor_meta
- aten.index_put.default: got this error in the PyTorch rule
- aten.slice_scatter.default: (at least) the TensorMeta information is missing, it doesn't take a different number of input shardings (for input and index) into account, and the redistribute costs are missing
Thanks for the detail!
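For reference, a minimal sketch (not from this PR) of how one could sanity-check what a registered strategy rule returns, assuming DTensor's current OpStrategy field names (strategies, redistribute_cost, output_specs, tensor_meta):

```python
from torch.distributed.tensor._op_schema import OpStrategy


def check_strategy_fields(strat: OpStrategy, num_tensor_inputs: int) -> list[str]:
    """Report strategies that lack redistribute costs or output tensor_meta."""
    problems = []
    for i, spec in enumerate(strat.strategies):
        cost = spec.redistribute_cost
        if cost is None or len(cost) != num_tensor_inputs:
            problems.append(f"strategy {i}: missing/short redistribute_cost")
        outs = spec.output_specs
        outs = outs if isinstance(outs, (list, tuple)) else (outs,)
        if any(o is not None and o.tensor_meta is None for o in outs):
            problems.append(f"strategy {i}: output spec has no tensor_meta")
    return problems
```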
autoparallel/export_module.py
Outdated
| print("grad_input") | ||
|
|
||
| # TODO: figure out and fix why this is not working | ||
| # rename_nodes(fx_g, grad_inputs, "grad_input", inputs_that_require_grad) |
Can someone have a look and see what was going wrong here and why we had more nodes than expected?
Took a look - if we're willing to support a copy_ mutable epilogue at the end of the joint graph (similar to what compile handles today), I can uncomment rename_nodes after patching in these PRs:
If the code is a "user defined Triton kernel" it should trace directly; the OmniFMv2 Triton kernels were traced in this way.
The issue I was having was that somewhere in the Triton kernel it was trying to grab the data_ptr of the tensor, and tracing through it didn't work.
Force-pushed from 3d2ad08 to b5eb863
Doesn't yet work, but the code for now is a copy-paste from https://github.com/pytorch/torchtitan/blob/deepseek-v3/torchtitan/models/deepseek_v3/model/moe.py, which will make it easier to track the changes
Still need to fix the grad_input renaming, which is not working for some reason
They are not correct, it's just to get a list of what we need for DeepSeekV3
prims.fma is probably easier to implement, but I'm removing this decomp just in case
Now should handle all ops properly, with correct shapes
…m cases with invalid strides. The grouped_mm should be handled in the sharding propagation, and those cases should just be removed, I think
Otherwise we can't shard on the batch dimension. With this change everything works up to executing the solver
I see, just sent out the PR to add back the missing field: pytorch/pytorch#159167.
Needs cleanup
There was no flop formula, which was making the solver think that computing this op is free
This is still approximate, as we can't evenly shard on the tokens, but I'm doing this first before seeing if we can introduce a DynamicShard primitive
autoparallel/propagation_rules.py
Outdated
@register_opschema_rule(torch.ops.aten.sort.stable)
def sort_rule(mesh, op_schema):
    op = torch.ops.aten.topk.default
    out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
        op
    ](
        op_schema
    )
    return out_strat


@register_opschema_rule(torch.ops.aten.gather.default)
def gather_strategy(mesh, op_schema):
    from torch.distributed.tensor._op_schema import PlacementList
    from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
    from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy

    input_strategy = op_schema.args_schema[0]
    dim = op_schema.args_schema[1]
    index_strategy = op_schema.args_schema[2]

    input_shape = input_strategy.shape
    index_shape = index_strategy.shape

    single_mesh_dim_strategies = []

    # placement list stores placements of [output, input, index]
    # first we always have replicate all for inputs and output
    all_replicate: PlacementList = [Replicate()] * 3
    single_mesh_dim_strategies.append(all_replicate)

    # input sharding, input sharded, index accepts mask partial, output follows index
    # this only works when the input is sharded on the gather dimension, and
    # index has size 1 on the gather dimension
    if index_shape[dim] == 1:
        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
        input_sharding: PlacementList = [
            index_partial_placement,
            Shard(dim),
            index_partial_placement,
        ]
        single_mesh_dim_strategies.append(input_sharding)

    # index sharding, input replicated, index sharded, output follows index
    # this only works when the sharding dimension is the gather dimension
    index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)]
    single_mesh_dim_strategies.append(index_sharding)

    if len(input_shape) == len(index_shape):
        for d in range(len(input_shape)):
            if d != dim:
                sharding = [Shard(d), Shard(d), Shard(d)]
                single_mesh_dim_strategies.append(sharding)

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=1
    )


@register_opschema_rule(torch.ops.aten.scatter_add.default)
def scatter_add_strategy(mesh, op_schema):
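The diff excerpt is cut off at scatter_add_strategy. As a point of reference only, here is a minimal sketch of how such a rule could look if it follows the same pattern as gather_strategy above; this is an assumption, not the PR's actual implementation (it assumes the same file, so register_opschema_rule is in scope):

```python
from torch.distributed.tensor import Replicate, Shard


@register_opschema_rule(torch.ops.aten.scatter_add.default)
def scatter_add_strategy(mesh, op_schema):
    from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy

    # aten.scatter_add.default(self, dim, index, src)
    input_strategy = op_schema.args_schema[0]
    dim = op_schema.args_schema[1]
    index_strategy = op_schema.args_schema[2]

    # placement list stores placements of [output, input, index, src]
    single_mesh_dim_strategies = [[Replicate()] * 4]

    # sharding all tensors identically on a non-scatter dimension keeps the
    # scatter local, since indices only move data along `dim`
    if len(input_strategy.shape) == len(index_strategy.shape):
        for d in range(len(input_strategy.shape)):
            if d != dim:
                single_mesh_dim_strategies.append([Shard(d)] * 4)

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=1
    )
```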
examples/example_ds3.py
Outdated
# mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))
mesh = torch.distributed.device_mesh.init_device_mesh(
    "cuda",
    (world_size // 32, 32),
This should instead be (world_size // 64, 64) if I want to follow exactly what DeepSeek is doing.
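For reference, a sketch of what the suggested change would look like; the diff above is cut off before mesh_dim_names, so the ("dp", "ep") names here are just an assumption:

```python
import torch

# Assumes torch.distributed is already initialized by the example script.
world_size = torch.distributed.get_world_size()
assert world_size % 64 == 0, "64-way grouping, following the comment above"
mesh = torch.distributed.device_mesh.init_device_mesh(
    "cuda",
    (world_size // 64, 64),
    mesh_dim_names=("dp", "ep"),  # hypothetical names, not from the diff
)
```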
Gather and scatter_add were merged yesterday in pytorch/pytorch#160140
…nsum (#26)
* [WIP] Replace view -> mm -> view with matmul. This tries to support CP-style sharding, by overcoming a limitation of DTensor. Doesn't yet work as _mm_strategy is failing
* Fix matmul propagation rule. Some things are starting to work, but we are not yet there
* Move function to graph_utils.py
* Pull improvements from #29
* Fix equation for einsum
* Cleanup code now that PyTorch has fixed _gen_einsum_strategies. Requires pytorch/pytorch#157593
* Generalize to more than 3d
* Generalize backward pass as well and make everything call into einsum
* Add note about future work
* Add einsum flops and generalize creation of sharded tensors. Before this, if we had a list of tensors we wouldn't shard the tensors inside the list
* Disable erroneous sdpa rule from backward
* Account for compute cost in collectives as well. This removes a long-standing hack to tell the solver that S(1) -> R is more expensive than S(0) -> R because of an additional data movement
* Account for compute cost in collectives as well. This removes a long-standing hack to tell the solver that S(1) -> R is more expensive than S(0) -> R because of an additional data movement
* Support getitem as well
* Improve comments and suppose 80% efficiency
* Suppose 70% efficiency for comms
* Add comment and set it to false by default
* Revert changes from another PR
* Add spaces back
This PR adds a self-contained implementation of the DeepSeekV3 MoE taken from https://github.com/pytorch/torchtitan/tree/deepseek-v3/torchtitan/models/deepseek_v3 (commit pytorch/torchtitan@0a6ab71).
We could also add the Attention-based part, but I decided to focus on the MoE first because I've addressed a few issues with the MLA-based attention in #7
There were quite a few problems that appeared. My goal with this PR is to create a list of work-items that will bring us closer to supporting DeepSeekV3 in AutoParallel.
With all those changes, we are able to get the solver to give a solution (although it will be inefficient, as unknown ops will be Replicate() everywhere). We should at least add a batch sharding fallback.

First problem: custom triton kernels
The code as is couldn't trace because of custom triton kernels. I wrapped the triton kernel in a custom op and things were fine (see 4fdf95c)
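For illustration, a self-contained sketch of that workaround; the real kernel lives in 4fdf95c, and this one is just a toy elementwise kernel. Wrapping the Triton launch in torch.library.custom_op keeps the data_ptr access inside the op, and the register_fake implementation gives tracing a shape/dtype-only path:

```python
import torch
import triton
import triton.language as tl
from torch.library import custom_op


@triton.jit
def _scale_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)


@custom_op("autoparallel_demo::scale2x", mutates_args=())
def scale2x(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    # The kernel launch grabs data_ptrs internally, hidden from the tracer.
    _scale_kernel[(triton.cdiv(n, 1024),)](x, out, n, BLOCK=1024)
    return out


@scale2x.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation only; no data_ptr access, so export can trace it.
    return torch.empty_like(x)
```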
Second problem: Node relabeling missing grad_input
For some reason, we seem to have one additional output in the joint graph, which makes the grad_input relabeling not work. For now I just commented it out, as it's not necessary for getting the list of ops that need to be supported.

Third problem: Missing sharding propagation for a number of ops
This was expected (and was the main goal of getting this work done). After a bunch of workarounds, I was able to get an exhaustive list of all the ops that seem to require re-implementation.
Here is a list of ops for which sharding propagation needs to be implemented to support DeepSeekV3:
- sort strategy: pytorch/pytorch#159189
- sort and scatter_add strategy: pytorch/pytorch#159022

Note that the _softmax_backward_data and the fma ops are redundant -- fma comes from the decomposition of _softmax_backward_data, so we can decide to either decompose it or not.
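One cheap way to cover prims.fma without relying on the decomposition would be to reuse an existing pointwise strategy, in the same spirit as sort_rule above delegating to topk. This is only a sketch under the assumption that aten.addcmul (also a pointwise ternary op) has a strategy registered in DTensor; it is not what the PR does, and it assumes the same propagation_rules.py context:

```python
@register_opschema_rule(torch.ops.prims.fma.default)
def fma_rule(mesh, op_schema):
    # fma(a, b, c) = a * b + c is pointwise, so a pointwise ternary strategy
    # (here: addcmul's) proposes the same shardings.
    op = torch.ops.aten.addcmul.default
    strategy_fn = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
        op
    ]
    return strategy_fn(op_schema)
```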
Fourth problem: Missing int64 (and other dtypes) flops for compute estimation
I defaulted to just ignoring unsupported dtypes and basically returning 0 flops, which is reasonable. But I think we should also add the memory read / write cost, like Inductor does (although there are issues with Inductor's estimation as well).
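A sketch of that cost model, with hypothetical names and made-up hardware constants (autoparallel's actual cost API may differ): unsupported dtypes contribute 0 flops, but the op is still charged an approximate memory read/write time so the solver doesn't see it as free.

```python
import torch

FLOAT_DTYPES = {torch.float16, torch.bfloat16, torch.float32, torch.float64}


def estimate_compute_time(flops, tensors, flops_per_s=2.0e14, mem_bw=2.0e12):
    """Return an estimated runtime in seconds for a single op."""
    bytes_moved = sum(t.numel() * t.element_size() for t in tensors)
    mem_time = bytes_moved / mem_bw
    if not any(t.dtype in FLOAT_DTYPES for t in tensors):
        # e.g. an int64 sort/gather: no flop formula, fall back to memory traffic
        return mem_time
    return max(flops / flops_per_s, mem_time)
```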
Fifth problem: _grouped_mm requires stride % 16 == 0
For now I just add an inf cost for this type of op, but we should remove the placements that are invalid beforehand.
Potential Fix PR: pytorch/pytorch#158245
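A sketch of the suggested cleanup, with a hypothetical helper (not autoparallel's API): drop candidate shardings that a validity check, such as _grouped_mm's stride-alignment requirement, rejects, instead of keeping them around with an infinite cost.

```python
from typing import Callable, Iterable, List


def filter_valid_placements(candidates: Iterable, is_valid: Callable[[object], bool]) -> List:
    candidates = list(candidates)
    kept = [c for c in candidates if is_valid(c)]
    # If everything was rejected, keep the original set so the op still has at
    # least one strategy; this mirrors the current "inf cost" fallback.
    return kept if kept else candidates
```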