Draft

30 commits
52ea0c1
[WIP] Add basic DeepSeekV3
fmassa Jul 4, 2025
0d3ae2d
Lint
fmassa Jul 4, 2025
98d9dfd
Workarounds to make graph capture pass
fmassa Jul 4, 2025
61a63c4
Add dummy propagation rules just to see what we need to implement
fmassa Jul 4, 2025
67eb264
Cleanup
fmassa Jul 4, 2025
86d53ff
prims.fma comes from softmax_backward
fmassa Jul 4, 2025
7864f4d
Make _geenrate_dummy_strategy more generic
fmassa Jul 5, 2025
60ccf1a
Add proper redistribute_cost to dummy strategies
fmassa Jul 5, 2025
dbbc205
Hack around missing dtypes in compute estimation and handle grouped_m…
fmassa Jul 5, 2025
d92f8c6
Add representative batch size
fmassa Jul 5, 2025
e25ff7b
Fix grouped_mm stride issue
wconstab Jul 18, 2025
3b7e7fa
get DS3 running forward, OOM at backward
wconstab Jul 18, 2025
3833a06
WIP factory_strategy
wconstab Jul 18, 2025
3740b45
Start rebasing on top of main
fmassa Jul 25, 2025
39fedfd
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Jul 25, 2025
6bec5f5
Fixes so that it runs
fmassa Jul 25, 2025
ce1c0a5
[WIP] Plumb fake_mode to avoid materializing memory
fmassa Jul 26, 2025
5d79bec
Use more representative values for DS3 example
fmassa Jul 26, 2025
daea5a2
Add approximate flop formula to grouped_mm
fmassa Jul 26, 2025
6d350e0
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Jul 27, 2025
418ad55
Glimpses of having DeepSeekV3 returning a reasonable solution
fmassa Jul 27, 2025
fce321f
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Jul 30, 2025
6d5747a
Use with_implicit_strategies instead of my generate_dummy_strategy
fmassa Jul 30, 2025
e0ae8a2
[WIP] Convert view->mm->view into matmul
fmassa Jul 30, 2025
1b83581
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Jul 31, 2025
cf1229d
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Aug 4, 2025
4fe5a40
Merge branch 'main' of github.com:meta-pytorch/autoparallel into fmas…
fmassa Aug 9, 2025
67542ad
Remove sharding rules that have been since moved to PyTorch
fmassa Aug 9, 2025
779e808
Merge branch 'main' of github.com:meta-pytorch/autoparallel into fmas…
fmassa Sep 4, 2025
124034e
Fixes after rebase
fmassa Sep 4, 2025
7 changes: 6 additions & 1 deletion autoparallel/optimize_sharding.py
@@ -490,7 +490,12 @@ def print_costs_for_node(self, node, arg=0, **kwargs):
         from torch.distributed.tensor._op_schema import _pretty_print_spec

         tgt_strat = self.strats[node]
-        src_strat = self.strats[node.args[arg]]
+        # Use this instead of node.all_input_nodes because there could be
+        # duplicate nodes that get removed
+        all_input_nodes = [
+            x for x in tree_flatten(node.args)[0] if isinstance(x, torch.fx.Node)
+        ]
+        src_strat = self.strats[all_input_nodes[arg]]
         src_placements = [""] + [
             _pretty_print_spec(x.output_specs) for x in src_strat.strategies
         ]
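A note on the fix above: fx's node.all_input_nodes deduplicates, so an op that
consumes the same node twice reports it only once and positional indexing by
arg breaks. A minimal sketch of the difference (hypothetical toy graph, not
part of this PR):

import torch
from torch.utils._pytree import tree_flatten

def f(x):
    return x * x  # the same input node is used twice

gm = torch.fx.symbolic_trace(f)
mul = next(n for n in gm.graph.nodes if n.op == "call_function")
print(mul.all_input_nodes)        # [x]    -- duplicates collapsed
print(tree_flatten(mul.args)[0])  # [x, x] -- positions preserved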
121 changes: 121 additions & 0 deletions autoparallel/propagation_rules.py
@@ -668,6 +668,127 @@ def index_rule(mesh, op_schema):
    return out_strat


@register_opschema_rule(torch.ops.aten.sort.stable)
def sort_rule(mesh, op_schema):
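    # aten.sort.stable has no rule of its own here; delegate to the strategy
    # DTensor registers for aten.topk, on the assumption that both ops admit
    # the same shardings (any dim except the one being sorted / selected).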
    op = torch.ops.aten.topk.default
    out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
        op
    ](
        op_schema
    )
    return out_strat


@register_opschema_rule(torch.ops.aten.gather.default)
def gather_strategy(mesh, op_schema):
    from torch.distributed.tensor._op_schema import PlacementList
    from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
    from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy

    input_strategy = op_schema.args_schema[0]
    dim = op_schema.args_schema[1]
    index_strategy = op_schema.args_schema[2]

    input_shape = input_strategy.shape
    index_shape = index_strategy.shape

    single_mesh_dim_strategies = []

    # placement list stores placements of [output, input, index]
    # first we always have replicate all for inputs and output
    all_replicate: PlacementList = [Replicate()] * 3
    single_mesh_dim_strategies.append(all_replicate)

    # input sharding: input sharded, index accepts mask partial, output follows index
    # this only works when the input is sharded on the gather dimension, and
    # index has size 1 on the gather dimension
    if index_shape[dim] == 1:
        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
        input_sharding: PlacementList = [
            index_partial_placement,
            Shard(dim),
            index_partial_placement,
        ]
        single_mesh_dim_strategies.append(input_sharding)

    # index sharding: input replicated, index sharded, output follows index
    # this only works when the sharding dimension is the gather dimension
    index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)]
    single_mesh_dim_strategies.append(index_sharding)

    if len(input_shape) == len(index_shape):
        for d in range(len(input_shape)):
            if d != dim:
                sharding = [Shard(d), Shard(d), Shard(d)]
                single_mesh_dim_strategies.append(sharding)

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=1
    )
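A quick single-process sanity check for the last rule (a sketch with
hypothetical shapes, not a distributed test): gather along dim is independent
across every other dimension, so shards taken along any d != dim can be
gathered locally and concatenated back.

import torch

dim = 1
inp = torch.randn(4, 16)
index = torch.randint(0, 16, (4, 16))

full = torch.gather(inp, dim, index)
# emulate Shard(0) on both operands and gather per shard
parts = [
    torch.gather(i, dim, ix) for i, ix in zip(inp.chunk(2, 0), index.chunk(2, 0))
]
assert torch.equal(torch.cat(parts, 0), full)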


fmassa (Contributor, Author) commented:

@wconstab @zpcore can we double-check those added rules and make sure they are valid / make sense?

The strategy for scatter_add basically follows what I added for gather: all tensors can be sharded on any dimension other than the scatter dim.

@register_opschema_rule(torch.ops.aten.scatter_add.default)
def scatter_add_strategy(mesh, op_schema):
    from torch.distributed.tensor._op_schema import PlacementList

    # from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
    from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy

    input_strategy = op_schema.args_schema[0]
    dim = op_schema.args_schema[1]
    index_strategy = op_schema.args_schema[2]
    # src_strategy = op_schema.args_schema[3]

    input_shape = input_strategy.shape
    index_shape = index_strategy.shape

    single_mesh_dim_strategies = []

    # placement list stores placements of [output, input, index, src]
    # first we always have replicate all for inputs and output
    all_replicate: PlacementList = [Replicate()] * 4
    single_mesh_dim_strategies.append(all_replicate)

    """
    # input sharding, input sharded, index accepts mask partial, output follows index
    # this only works when the input is sharded on the gather dimension, and
    # index has size 1 on the gather dimension
    if index_shape[dim] == 1:
        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
        input_sharding: PlacementList = [
            index_partial_placement,
            Shard(dim),
            index_partial_placement,
        ]
        single_mesh_dim_strategies.append(input_sharding)
    """
    # index sharding: input replicated, index and src sharded, output follows index
    # this only works when the sharding dimension is the scatter dimension
    index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim), Shard(dim)]
    single_mesh_dim_strategies.append(index_sharding)

    if len(input_shape) == len(index_shape):
        for d in range(len(input_shape)):
            if d != dim:
                sharding = [Shard(d), Shard(d), Shard(d), Shard(d)]
                single_mesh_dim_strategies.append(sharding)

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=1
    )
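To help answer the review question above, the same row-independence argument
yields a single-process check for the d != dim rule of scatter_add
(hypothetical shapes again):

import torch

dim = 1
inp = torch.zeros(4, 8)
index = torch.randint(0, 8, (4, 8))
src = torch.randn(4, 8)

full = inp.scatter_add(dim, index, src)
parts = [
    i.scatter_add(dim, ix, s)
    for i, ix, s in zip(inp.chunk(2, 0), index.chunk(2, 0), src.chunk(2, 0))
]
assert torch.equal(torch.cat(parts, 0), full)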


@register_opschema_rule(torch.ops.aten.slice_scatter.default)
def slice_scatter_rule(mesh, op_schema):
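    # slice_scatter already has a strategy registered upstream in DTensor;
    # this thin wrapper just forwards to it.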
    op = torch.ops.aten.slice_scatter.default
    out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
        op
    ](
        op_schema
    )
    return out_strat


def sdpa_rule(op, mesh, op_schema):
    out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
        op
21 changes: 21 additions & 0 deletions autoparallel/utils.py
@@ -176,12 +176,32 @@ def _generate_dummy_strategy(
    return out_strat


def keep_unique_configs(op_strat):
    added = set()
    filtered_strats = []
    for strat in op_strat.strategies:
        input_specs = strat.input_specs
        output_specs = strat.output_specs
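        # specs may arrive as lists, which are unhashable; normalize them to
        # tuples so the (input, output) pair can serve as a set key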
        if isinstance(input_specs, list):
            input_specs = tuple(input_specs)
        if isinstance(output_specs, list):
            output_specs = tuple(output_specs)
        key = (input_specs, output_specs)
        if key in added:
            continue

        added.add(key)
        filtered_strats.append(strat)
    return OpStrategy(filtered_strats)


def get_placement_options(mesh, op, specs, user_args, user_kwargs, fake_mode):
    # print(op)

    if op in _op_rules:
        out_strat = _op_rules[op](mesh, specs)
        out_strat = remove_invalid_configs(out_strat, mesh)
        out_strat = keep_unique_configs(out_strat)
        return out_strat

    strat = []
@@ -224,6 +244,7 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs, fake_mode):
    propagate_tensor_meta(op, user_args, user_kwargs, out_strat, fake_mode)
    fill_missing_redistribute_cost(op, specs, out_strat)
    out_strat = remove_invalid_configs(out_strat, mesh)
    out_strat = keep_unique_configs(out_strat)

    return out_strat
