
Conversation

Contributor

@fmassa fmassa commented Jul 4, 2025

This PR adds a self-contained implementation of the DeepSeekV3 MoE, taken from https://github.com/pytorch/torchtitan/tree/deepseek-v3/torchtitan/models/deepseek_v3 at commit pytorch/torchtitan@0a6ab71.
We could also add the attention-based part, but I decided to focus on the MoE first, since I've already addressed a few issues with the MLA-based attention in #7.

Quite a few problems came up along the way. My goal with this PR is to build a list of work items that will bring us closer to supporting DeepSeekV3 in AutoParallel.

With all these changes, the solver is able to produce a solution (although it will be inefficient, as unknown ops end up as Replicate() everywhere). We should at least add a batch-sharding fallback.
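
As a rough illustration of that fallback idea (a hypothetical sketch, names are made up, this is not code from this PR): for an op with no registered sharding rule, propose both "replicate everything" and "shard everything on the batch dimension" and let the solver pick.

from torch.distributed.tensor import Replicate, Shard

def fallback_placements(num_tensors: int):
    # one placement per tensor involved in the op: [output, inputs...]
    all_replicate = [Replicate()] * num_tensors
    batch_sharded = [Shard(0)] * num_tensors  # assumes dim 0 is the batch dim
    return [all_replicate, batch_sharded]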

First problem: custom triton kernels

The code as-is couldn't be traced because of custom Triton kernels. I wrapped the Triton kernel in a custom op and things were fine (see 4fdf95c)
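
For context, a minimal sketch of the wrapping pattern (the op name, shapes, and body below are made up, this is not the kernel from the PR): the Triton launch goes inside a torch.library custom op, with a fake implementation so tracing only needs shapes.

import torch

@torch.library.custom_op("moe_demo::token_gather", mutates_args=())
def token_gather(tokens: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # inside the custom op we can freely call data_ptr() / launch Triton;
    # the tracer treats the whole op as opaque
    return tokens[indices]  # stand-in for the real Triton launch

@token_gather.register_fake
def _(tokens, indices):
    # meta implementation: only shapes/dtypes, no data access
    return tokens.new_empty((indices.shape[0], tokens.shape[1]))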

Second problem: Node relabeling missing grad_input

For some reason we seem to have one additional output in the joint graph, which makes the grad_input relabeling fail. For now I just commented it out, as it's not necessary for getting the list of ops that need to be supported.

Third problem: Missing sharding propagation for a number of ops

This was expected (and was the main goal of getting this work done). After a bunch of workarounds, I was able to get an exhaustive list of all the ops that seem to require re-implementation.

Here is a list of ops for which sharding propagation needs to be implemented to support DeepSeekV3:

Note that the _softmax_backward_data and fma ops are redundant -- fma comes from the decomposition of _softmax_backward_data, so we can decide whether or not to decompose it.
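
To make the redundancy concrete (an illustrative sketch, not the exact decomposition PyTorch uses): softmax backward is dx = y*dy - y*sum(y*dy), and the final combine of the decomposed form is exactly an fma-style a*b + c.

import torch

def softmax_backward_fused(grad_out, out, dim):
    # what aten._softmax_backward_data computes
    return (grad_out - (grad_out * out).sum(dim, keepdim=True)) * out

def softmax_backward_decomposed(grad_out, out, dim):
    gy = grad_out * out
    s = gy.sum(dim, keepdim=True)
    # the last step is fma-like: gy + (-1) * out * s
    return torch.addcmul(gy, out, s, value=-1)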

Fourth problem: Missing int64 (and other dtype) flop formulas for compute estimation

For now I defaulted to ignoring unsupported dtypes and returning 0 flops, which is a reasonable fallback. But I think we should also account for the memory read/write cost, like Inductor does (although Inductor's estimation has issues of its own).
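
A hypothetical sketch of what such a fallback could look like (the function name and hardware constants below are made up, not code from this PR): when there is no flop formula for an op/dtype, fall back to a roofline-style cost based on bytes moved instead of treating the op as free.

import torch

def estimate_runtime_s(tensors, flops=0, tflops=300.0, gb_per_s=1000.0):
    # bytes read + written, approximated as the sizes of all involved tensors
    bytes_moved = sum(t.numel() * t.element_size() for t in tensors)
    compute_time = flops / (tflops * 1e12)
    memory_time = bytes_moved / (gb_per_s * 1e9)
    return max(compute_time, memory_time)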

Fifth problem: _grouped_mm requires stride % 16 == 0

For now I just assign an inf cost to these ops, but we should instead remove the invalid placements beforehand.

Potential Fix PR: pytorch/pytorch#158245

  • filters out the invalid (non-aligned) shardings earlier on so that we do not attempt to estimate flops for grouped_mm for invalid inputs
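
A hypothetical sketch of that kind of filter (names and numbers are made up for illustration): reject sharding candidates whose local shard would break the 16-alignment requirement before any flop estimation happens.

def is_aligned_for_grouped_mm(local_inner_dim: int, element_size: int = 2,
                              alignment_bytes: int = 16) -> bool:
    # the leading (row-major) stride in bytes must be a multiple of 16
    return (local_inner_dim * element_size) % alignment_bytes == 0

# e.g. K = 7168 sharded 64 ways gives a local K of 112 -> 224 bytes in bf16: aligned
assert is_aligned_for_grouped_mm(112)
assert not is_aligned_for_grouped_mm(129)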

@fmassa fmassa requested review from bdhirsh and wconstab July 4, 2025 14:07
@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Jul 4, 2025
fmassa added a commit that referenced this pull request Jul 4, 2025
fmassa added a commit that referenced this pull request Jul 4, 2025
assert len(op_spec[i].strategies) == len(
    op_spec[0].strategies
), "Assume each cat input has same number of strategies"
# for i in range(1, num_tensors):
Contributor

Why was this changed? Are we filtering to the least common strategies now? / should we be?

Contributor Author

@fmassa fmassa Jul 5, 2025

I just commented out those asserts so that I could move forward with figuring out what is needed for operator support. That's why I kept the torch.cat operator in the list of operators to implement: I basically butchered things just so it could run up to the solver construction.

Contributor

zpcore commented Jul 5, 2025

The following three ops are already registered. Did you notice any issues with those?

aten.cat.default
aten.index_put.default
aten.slice_scatter.default

Contributor Author

fmassa commented Jul 5, 2025

The following three ops are already registered. Did you notice any issues with those?

aten.cat.default
aten.index_put.default
aten.slice_scatter.default

IIRC, they assumed that the number of strategies for every input was the same, or something like that, and it caused issues in the past. It would be good to double-check if things work as expected nowadays.

EDIT: I commented out the custom registration for those three ops, and here were the issues I found:

aten.cat.default

(At least) the redistribute costs are missing, only one output strategy is proposed, and there is no tensor_meta.

aten.index_put.default

Got this error in the PyTorch rule:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/storage/home/fmassa/work/projects/autoparallel/examples/example_ds3.py", line 845, in <module>
[rank0]:     autop = AutoParallel(model, input_fn, mesh)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/storage/home/fmassa/work/projects/autoparallel/autoparallel/api.py", line 250, in __init__
[rank0]:     sharding_optimizer = ShardingOptimizer(self.gm, self.mesh)
[rank0]:                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/storage/home/fmassa/work/projects/autoparallel/autoparallel/optimize_sharding.py", line 45, in __init__
[rank0]:     self.strats = self.build_sharding_metadata()
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/storage/home/fmassa/work/projects/autoparallel/autoparallel/optimize_sharding.py", line 69, in build_sharding_metadata
[rank0]:     strat = get_placement_options(
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/storage/home/fmassa/work/projects/autoparallel/autoparallel/utils.py", line 176, in get_placement_options
[rank0]:     out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/storage/home/fmassa/micromamba/envs/ptdev/lib/python3.12/site-packages/torch/distributed/tensor/_ops/_tensor_ops.py", line 744, in prop_index_put
[rank0]:     in_spec, indices_spec, values_spec = op_schema.args_schema
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: ValueError: too many values to unpack (expected 3)

aten.slice_scatter.default

(At least) the TensorMeta information is missing, different numbers of input shardings (for input and index) are not taken into account, and the redistribute costs are missing.

Contributor

zpcore commented Jul 6, 2025

(quoting fmassa's reply above in full)

Thanks for the detail!

print("grad_input")

# TODO: figure out and fix why this is not working
# rename_nodes(fx_g, grad_inputs, "grad_input", inputs_that_require_grad)
Contributor Author

Can someone have a look and see what was going wrong here and why we had more nodes than expected?

Contributor

@bdhirsh bdhirsh Jul 7, 2025

Took a look - if we're willing to support a copy_ mutable epilogue at the end of the joint graph (similar to what compile handles today), I can uncomment rename_nodes after patching in these PRs:

pytorch/pytorch#157730

#32

Contributor

ezyang commented Jul 8, 2025

The code as-is couldn't be traced because of custom Triton kernels. I wrapped the Triton kernel in a custom op and things were fine (see 4fdf95c)

If the code is a "user-defined Triton kernel" it should trace directly; the OmniFMv2 Triton kernels traced that way.

Contributor Author

fmassa commented Jul 8, 2025

The issue I was having was that somewhere in the Triton kernel path it was trying to grab the data_ptr of the tensor, and tracing through that didn't work.

fmassa added 10 commits July 18, 2025 07:17
Doesn't yet work, but the code for now is a copy-paste from https://github.com/pytorch/torchtitan/blob/deepseek-v3/torchtitan/models/deepseek_v3/model/moe.py so it will make it easier to track the changes
Needs to fix the grad_input renaming which is not working for some reason
They are not correct, it's just to get a list of what we need for DeepSeekV3
prims.fma is probably easier to implement, but I'm removing this decomp just in case
Now should handle all ops properly, with correct shapes
…m cases with invalid strides

The grouped_mm should be handled in the sharding propagation and those cases should just be removed I think
Otherwise we can't shard on the batch dimension. With this change everything works up to executing the solver
Contributor

zpcore commented Jul 25, 2025

@zpcore I rebased the PR on top of latest main and I tried removing the _softmax_backward_data fallback, but it seems like it doesn't have an input_spec so things fail https://github.com/pytorch/pytorch/blob/92e93bb580f31d405a72ee58f30fe82908bbeacf/torch/distributed/tensor/_ops/_math_ops.py#L563

This should probably be fixed before we can enable it

I see; I just sent out a PR to add back the missing field: pytorch/pytorch#159167.

fmassa added 5 commits July 26, 2025 06:55
There was no flop formula, which was making the solver think that computing this op is free
This is still approximate as we can't evenly shard on the tokens, but I'm doing this first before seeing if we can introduce a DynamicShard primitive
Comment on lines 671 to 731
@register_opschema_rule(torch.ops.aten.sort.stable)
def sort_rule(mesh, op_schema):
    op = torch.ops.aten.topk.default
    out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
        op
    ](op_schema)
    return out_strat


@register_opschema_rule(torch.ops.aten.gather.default)
def gather_strategy(mesh, op_schema):
    from torch.distributed.tensor._op_schema import PlacementList
    from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
    from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy

    input_strategy = op_schema.args_schema[0]
    dim = op_schema.args_schema[1]
    index_strategy = op_schema.args_schema[2]

    input_shape = input_strategy.shape
    index_shape = index_strategy.shape

    single_mesh_dim_strategies = []

    # placement list stores placements of [output, input, index]
    # first we always have replicate all for inputs and output
    all_replicate: PlacementList = [Replicate()] * 3
    single_mesh_dim_strategies.append(all_replicate)

    # input sharding, input sharded, index accepts mask partial, output follows index
    # this only works when the input is sharded on the gather dimension, and
    # index has size 1 on the gather dimension
    if index_shape[dim] == 1:
        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
        input_sharding: PlacementList = [
            index_partial_placement,
            Shard(dim),
            index_partial_placement,
        ]
        single_mesh_dim_strategies.append(input_sharding)

    # index sharding, input replicated, index sharded, output follows index
    # this only works when the sharding dimension is the gather dimension
    index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)]
    single_mesh_dim_strategies.append(index_sharding)

    if len(input_shape) == len(index_shape):
        for d in range(len(input_shape)):
            if d != dim:
                sharding = [Shard(d), Shard(d), Shard(d)]
                single_mesh_dim_strategies.append(sharding)

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=1
    )


@register_opschema_rule(torch.ops.aten.scatter_add.default)
def scatter_add_strategy(mesh, op_schema):
Contributor Author

@wconstab @zpcore can we double-check those added rules and make sure they are valid / make sense?

The strategy for scatter_add basically follows what I've added for gather: we allow all tensors to be sharded on any dimension that is not the gather/scatter dim.
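
For concreteness, here is a sketch of the shape such a scatter_add rule could take under that assumption (it mirrors the gather rule above and is not necessarily the exact code in this diff; scatter_add also carries a src tensor, so the placement list covers [output, input, index, src]):

@register_opschema_rule(torch.ops.aten.scatter_add.default)
def scatter_add_strategy(mesh, op_schema):
    from torch.distributed.tensor._op_schema import PlacementList
    from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy

    input_strategy = op_schema.args_schema[0]
    dim = op_schema.args_schema[1]
    index_strategy = op_schema.args_schema[2]

    input_shape = input_strategy.shape
    index_shape = index_strategy.shape

    # placement list stores placements of [output, input, index, src]
    single_mesh_dim_strategies: list[PlacementList] = [[Replicate()] * 4]

    # shard all tensors on the same dimension, as long as it is not the scatter dim
    if len(input_shape) == len(index_shape):
        for d in range(len(input_shape)):
            if d != dim:
                single_mesh_dim_strategies.append([Shard(d)] * 4)

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=1
    )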

# mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))
mesh = torch.distributed.device_mesh.init_device_mesh(
    "cuda",
    (world_size // 32, 32),
Contributor Author

This should instead be

(world_size // 64, 64)

if I want to follow exactly what DeepSeek is doing

fmassa added a commit that referenced this pull request Aug 6, 2025
fmassa added a commit that referenced this pull request Aug 12, 2025
fmassa added a commit that referenced this pull request Aug 27, 2025
…nsum (#26)

* [WIP] Replace view -> mm -> view with matmul

This tries to support CP-style sharding, by overcoming a limitation of DTensor. Doesn't yet work as _mm_strategy is failing

* Fix matmul propagation rule

Some things are starting to work, but we are not yet there

* Move function to graph_utils.py

* Pull improvements from #29

* Fix equation for einsum

* Cleanup code now that PyTorch has fixed _gen_einsum_strategies

Requires pytorch/pytorch#157593

* Generalize to more than 3d

* Generalize backward pass as well and make everything call into einsum

* Add note about future work

* Add einsum flops and generalize creation of sharded tensors

Before this, if we had a list of tensors we wouldn't shard the tensors inside the list

* Disable erroneous sdpa rule from backward

* Account for compute cost in collectives as well

This removes a long-standing hack to tell the solver that S(1) -> R is more expensive than S(0) -> R because of an additional data movement

* Account for compute cost in collectives as well

This removes a long-standing hack to tell the solver that S(1) -> R is more expensive than S(0) -> R because of an additional data movement

* Support getitem as well

* Improve comments and suppose 80% efficiency

* Suppose 70% efficiency for comms

* Add comment and set it to false by default

* Revert changes from another PR

* Add spaces back
fmassa added a commit that referenced this pull request Sep 29, 2025
Taken from #3 and #29. Decomposing softmax_backward leads to prims.fma, which doesn't have a sharding rule and we end up having a Replicate showing up as only possible sharding
fmassa added a commit that referenced this pull request Oct 1, 2025
Taken from #3 and #29. Decomposing softmax_backward leads to prims.fma, which doesn't have a sharding rule and we end up having a Replicate showing up as only possible sharding