Draft
Commits (30)
52ea0c1
[WIP] Add basic DeepSeekV3
fmassa Jul 4, 2025
0d3ae2d
Lint
fmassa Jul 4, 2025
98d9dfd
Workarounds to make graph capture pass
fmassa Jul 4, 2025
61a63c4
Add dummy propagation rules just to see what we need to implement
fmassa Jul 4, 2025
67eb264
Cleanup
fmassa Jul 4, 2025
86d53ff
prims.fma comes from softmax_backward
fmassa Jul 4, 2025
7864f4d
Make _geenrate_dummy_strategy more generic
fmassa Jul 5, 2025
60ccf1a
Add proper redistribute_cost to dummy strategies
fmassa Jul 5, 2025
dbbc205
Hack around missing dtypes in compute estimation and handle grouped_m…
fmassa Jul 5, 2025
d92f8c6
Add representative batch size
fmassa Jul 5, 2025
e25ff7b
Fix grouped_mm stride issue
wconstab Jul 18, 2025
3b7e7fa
get DS3 running forward, OOM at backward
wconstab Jul 18, 2025
3833a06
WIP factory_strategy
wconstab Jul 18, 2025
3740b45
Start rebasing on top of main
fmassa Jul 25, 2025
39fedfd
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Jul 25, 2025
6bec5f5
Fixes so that it runs
fmassa Jul 25, 2025
ce1c0a5
[WIP] Plumb fake_mode to avoid materializing memory
fmassa Jul 26, 2025
5d79bec
Use more representative values for DS3 example
fmassa Jul 26, 2025
daea5a2
Add approximate flop formula to grouped_mm
fmassa Jul 26, 2025
6d350e0
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Jul 27, 2025
418ad55
Glimpses of having DeepSeekV3 returning a reasonable solution
fmassa Jul 27, 2025
fce321f
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Jul 30, 2025
6d5747a
Use with_implicit_strategies instead of my generate_dummy_strategy
fmassa Jul 30, 2025
e0ae8a2
[WIP] Convert view->mm->view into matmul
fmassa Jul 30, 2025
1b83581
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Jul 31, 2025
cf1229d
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa Aug 4, 2025
4fe5a40
Merge branch 'main' of github.com:meta-pytorch/autoparallel into fmas…
fmassa Aug 9, 2025
67542ad
Remove sharding rules that have been since moved to PyTorch
fmassa Aug 9, 2025
779e808
Merge branch 'main' of github.com:meta-pytorch/autoparallel into fmas…
fmassa Sep 4, 2025
124034e
Fixes after rebase
fmassa Sep 4, 2025
3 changes: 2 additions & 1 deletion autoparallel/api.py
@@ -184,6 +184,7 @@ def _get_decomp_table():
decomp_table.pop(torch.ops.aten.native_layer_norm.default)
decomp_table.pop(torch.ops.aten.embedding_dense_backward.default)
decomp_table.pop(torch.ops.aten.native_layer_norm_backward.default)
decomp_table.pop(torch.ops.aten._softmax_backward_data.default)

# decompose addmm to allow for TP on mm
decomp_table.pop(torch.ops.aten.addmm.default)
@@ -246,7 +247,7 @@ def __init__(self, model, input_fn, mesh: DeviceMesh):
self.mesh = mesh
self.build_model_graph()

sharding_optimizer = ShardingOptimizer(self.gm, self.mesh)
sharding_optimizer = ShardingOptimizer(self.gm, self.mesh, self.fake_mode)
# makes sharding of params and gradients the same
sharding_optimizer.add_grad_param_constraints()
self.sharding_optimizer = sharding_optimizer
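For context on the new `decomp_table.pop` line above: per the commit notes, decomposing `aten._softmax_backward_data` introduces `prims.fma`, which has no sharding rule, so the op is kept whole. A minimal sketch of the pattern, assuming the table starts from `core_aten_decompositions()` (an assumption here; `_get_decomp_table` may build it differently):

```python
import torch
from torch._decomp import core_aten_decompositions

# Start from the default core ATen decompositions and keep selected ops whole
# so their dedicated sharding rules apply instead of decomposed prims.
decomp_table = core_aten_decompositions()
for op in (
    torch.ops.aten.native_layer_norm.default,
    torch.ops.aten._softmax_backward_data.default,
):
    decomp_table.pop(op, None)  # pop(..., None): no-op if the op is absent
```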
46 changes: 39 additions & 7 deletions autoparallel/compute_estimation.py
@@ -8,7 +8,38 @@

import torch
from torch.utils._pytree import tree_map_only
from torch.utils.flop_counter import FlopCounterMode
from torch.utils.flop_counter import FlopCounterMode, register_flop_formula


@register_flop_formula(torch.ops.aten._grouped_mm)
def gmm_flop(
a_shape, b_shape, offs_shape=None, bias_shape=None, out_shape=None, **kwargs
) -> int:
"""Count flops for the gmm operation."""
# a_shape and b_shape are the shapes of the two operands; offs_shape, when
# present, gives the number of groups for the jagged (2D) layout.
if len(a_shape) == 2:
assert offs_shape is not None
(b,) = offs_shape
m0, k = a_shape
# assumption: groups are roughly balanced, so this falls back to a bmm estimate
m = m0 // b
else:
assert offs_shape is None
b, m, k = a_shape
if len(b_shape) == 2:
assert offs_shape is not None
(b2,) = offs_shape
k2, n0 = b_shape
# assumption: groups are roughly balanced, so this falls back to a bmm estimate
n = n0 // b2
else:
b2, k2, n = b_shape
assert b == b2
assert k == k2
# NB(chilli): Should be 2 * k - 1 technically for FLOPs.
flop = b * m * n * 2 * k
return flop


@dataclass
@@ -147,12 +178,13 @@ def _get_device_tflops(dtype):
f"Unsupported device: {device_name}. Supported devices: {[limit.name for limit in DEVICE_LIMITS]}"
)

if dtype not in device_limit.gemm_tflops:
raise ValueError(
f"Dtype {dtype} not supported on {device_limit.name}. Supported dtypes: {list(device_limit.gemm_tflops.keys())}"
)
# TODO: add proper support for int64 etc
# if dtype not in device_limit.gemm_tflops:
# raise ValueError(
# f"Dtype {dtype} not supported on {device_limit.name}. Supported dtypes: {list(device_limit.gemm_tflops.keys())}"
# )

return device_limit.gemm_tflops[dtype]
return device_limit.gemm_tflops.get(dtype, 1)


def _get_sharded_shape_stride(spec):
@@ -213,7 +245,7 @@ def estimate_strategy_runtime_cost(node, strategy):

# TODO: maybe cache the flop_counter to avoid recreating it
# all the time
with FlopCounterMode(display=False) as flop_counter:
with FlopCounterMode(display=False) as flop_counter, fake_mode:
node.target(*args, **kwargs)

flops = flop_counter.get_total_flops()
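As a sanity check of the `gmm_flop` formula added above, here is a self-contained restatement evaluated on made-up shapes (the helper name and shapes are illustrative, not part of the diff):

```python
def approx_grouped_mm_flops(a_shape, b_shape, offs_shape=None):
    # Mirrors gmm_flop above: when an operand is 2D, the offsets tensor gives
    # the number of groups and rows are assumed to split roughly evenly.
    if len(a_shape) == 2:
        (groups,) = offs_shape
        m0, k = a_shape
        m = m0 // groups
    else:
        groups, m, k = a_shape
    if len(b_shape) == 2:
        k2, n0 = b_shape
        n = n0 // offs_shape[0]
    else:
        _, k2, n = b_shape
    assert k == k2
    return groups * m * n * 2 * k  # ~2*m*k*n multiply-adds per (balanced) group


# e.g. 8192 tokens routed across 4 experts, each expert a 1024x2048 matmul
assert approx_grouped_mm_flops(
    (8192, 1024), (4, 1024, 2048), offs_shape=(4,)
) == 4 * 2048 * 2048 * 2 * 1024
```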
11 changes: 10 additions & 1 deletion autoparallel/export_module.py
@@ -204,32 +204,41 @@ def rename_nodes(fx_g, nodes, new_name, idxs=None):
# TODO: align number of grad names with inputs everywhere?
all_output_nodes = fx_g.graph.find_nodes(op="output")[0].all_input_nodes
output_nodes = all_output_nodes[: metadata.num_outputs]
print("output")
rename_nodes(fx_g, output_nodes, "output")
param_grad = all_output_nodes[
metadata.num_outputs : metadata.num_outputs + params_len
]
print("grad_param")
rename_nodes(fx_g, param_grad, "grad_param")
grad_inputs = all_output_nodes[metadata.num_outputs + params_len :]
inputs_that_require_grad = [
i for i, n in enumerate(metadata.input_info[params_len:]) if n.requires_grad
]
rename_nodes(fx_g, grad_inputs, "grad_input", inputs_that_require_grad)
print("grad_input")

# TODO: figure out and fix why this is not working
# rename_nodes(fx_g, grad_inputs, "grad_input", inputs_that_require_grad)
Review comment (Contributor Author): Can someone have a look and see what was going wrong here and why we had more nodes than expected?

Reply (@bdhirsh, Jul 7, 2025): Took a look - if we're willing to support a copy_ mutable epilogue at the end of the joint graph (similar to what compile handles today), I can uncomment rename_nodes after patching in these PRs:

pytorch/pytorch#157730

#32


tangent_nodes = fx_g.graph.find_nodes(op="placeholder")[
-len(metadata.traced_tangents) :
]
outputs_that_require_grad = [
i for i, n in enumerate(metadata.output_info) if n.requires_grad
]
print("tangents")
rename_nodes(fx_g, tangent_nodes, "tangents", outputs_that_require_grad)
input_nodes = fx_g.graph.find_nodes(op="placeholder")[
params_len + buffer_len : -len(metadata.traced_tangents)
]
print("input")
rename_nodes(fx_g, input_nodes, "input")
param_nodes = fx_g.graph.find_nodes(op="placeholder")[:params_len]
print("param")
rename_nodes(fx_g, param_nodes, "param")

buffer_nodes = fx_g.graph.find_nodes(op="placeholder")[
params_len : params_len + buffer_len
]
print("buffer")
rename_nodes(fx_g, buffer_nodes, "buffer")
20 changes: 16 additions & 4 deletions autoparallel/optimize_sharding.py
@@ -114,10 +114,11 @@ def _get_next_name(name):


class ShardingOptimizer:
def __init__(self, gm, mesh):
def __init__(self, gm, mesh, fake_mode):
self.gm = gm
self.graph = gm.graph
self.mesh = mesh
self.fake_mode = fake_mode
self.node_map = {node: i for i, node in enumerate(self.graph.nodes)}
self.strats = self.build_sharding_metadata()
# ds: Decision variables dictionary mapping (s_i, argi, ss, ii) -> ILP variable data
@@ -168,7 +169,12 @@ def build_sharding_metadata(self):
strats[node] = strat
else:
strat = get_placement_options(
self.mesh, node.target, user_strats, user_args, user_kwargs
self.mesh,
node.target,
user_strats,
user_args,
user_kwargs,
self.fake_mode,
)
strats[node] = strat
elif node.op == "output":
@@ -215,7 +221,8 @@ def build_ds(self):
"num_output_strat": len(s.strategies),
}
for ss, ssi in enumerate(s.strategies):
compute_cost = estimate_strategy_runtime_cost(node, ssi)
with self.fake_mode:
compute_cost = estimate_strategy_runtime_cost(node, ssi)
for argi, xxi in enumerate(ssi.redistribute_cost):
for ii, comm_cost in enumerate(xxi):
va = pulp.LpVariable(
@@ -483,7 +490,12 @@ def print_costs_for_node(self, node, arg=0, **kwargs):
from torch.distributed.tensor._op_schema import _pretty_print_spec

tgt_strat = self.strats[node]
src_strat = self.strats[node.args[arg]]
# Use this instead of node.all_input_nodes because there could be
# duplicate nodes that get removed
all_input_nodes = [
x for x in tree_flatten(node.args)[0] if isinstance(x, torch.fx.Node)
]
src_strat = self.strats[all_input_nodes[arg]]
src_placements = [""] + [
_pretty_print_spec(x.output_specs) for x in src_strat.strategies
]
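The reason `fake_mode` is now threaded through `ShardingOptimizer` and wrapped around `estimate_strategy_runtime_cost` is so operators can be re-executed for flop counting without allocating real memory. A minimal sketch of the pattern, with a standalone `FakeTensorMode` standing in for the graph's fake mode (an assumption for illustration):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils.flop_counter import FlopCounterMode

fake_mode = FakeTensorMode()
with FlopCounterMode(display=False) as flop_counter, fake_mode:
    a = torch.randn(4096, 1024)  # fake tensors: shapes/dtypes only, no storage
    b = torch.randn(1024, 2048)
    torch.mm(a, b)

print(flop_counter.get_total_flops())  # 2 * 4096 * 1024 * 2048
```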
153 changes: 142 additions & 11 deletions autoparallel/propagation_rules.py
@@ -50,11 +50,15 @@
_op_rules = {}


def register_rule(op):
def register_rule(ops):
global _op_rules

def wrapper(impl):
_op_rules[op] = impl
if isinstance(ops, list):
for op in ops:
_op_rules[op] = impl
else:
_op_rules[ops] = impl
return impl

return wrapper
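A self-contained sketch of what the extended `register_rule` enables, i.e. one implementation registered under several ops at once (the op keys and rule body are made up for illustration):

```python
_rules = {}

def register_rule(ops):
    def wrapper(impl):
        # Accept either a single op or a list of ops, as in the diff above.
        for op in ops if isinstance(ops, list) else [ops]:
            _rules[op] = impl
        return impl
    return wrapper

@register_rule(["aten.sin.default", "aten.cos.default"])  # string stand-ins for real overloads
def pointwise_rule(mesh, op_schema):
    return "same strategy for both ops"

assert _rules["aten.sin.default"] is _rules["aten.cos.default"] is pointwise_rule
```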
@@ -382,7 +386,7 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
This util applies to any factory function that takes 'size' as the first argument,
and supports Replication and Shard placements all at zero cost.
"""
assert isinstance(op_schema.args_schema[0], torch.Size)
assert isinstance(op_schema.args_schema[0], (torch.Size, list))
shape = op_schema.args_schema[0]
x = torch.empty(shape, device="meta")
stride = x.stride()
@@ -424,8 +428,11 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
* len(strategy_combs)
]

# TODO: should we add an input_spec here, so that we can ensure we always
# have input and output specs? For now I hacked it in utils.py
strategy = OpSpec(
output_specs=output_specs,
input_specs=[output_specs],
redistribute_cost=redistribute_cost,
)
all_strategies.append(strategy)
@@ -617,23 +624,26 @@ def _unsafe_index_rule(mesh, op_schema):
raise NotImplementedError()


@register_opschema_rule(torch.ops.aten.index.Tensor)
# Disabled: this rule's implementation is inferior to the baseline
# @register_opschema_rule(torch.ops.aten.index.Tensor)
def index_rule(mesh, op_schema):
raise NotImplementedError("Needs hardening, only tested on a few cases")
print(f"Ops that need to be implemented {torch.ops.aten.index.Tensor}")
# raise NotImplementedError("Needs hardening, only tested on a few cases")
strat = op_schema.args_schema
specs = strat # TODO: clean this up
res = []
idxs_placements = [(Replicate(), Replicate()), (Shard(0), Replicate())]
if strat[1].childs[0] is None:
idxs_placements = idxs_placements[:1]
else:
idxs_placements = idxs_placements[1:]
idxs_placements = [(Replicate(),) * mesh.ndim]
# if strat[1].childs[0] is None:
# idxs_placements = idxs_placements[:1]
# else:
# idxs_placements = idxs_placements[1:]
# TODO: this is a nasty hack and won't work for most of the cases
for i, ss in enumerate(strat[0].strategies):
for i, ss in enumerate(strat[0].strategies[:1]):
for plt in idxs_placements:
ispec = ss.input_specs[0]
ospec = DTensorSpec(mesh=mesh, placements=ispec.placements)
assert ss.output_spec == ispec
# assert ss.output_spec == ispec, f"{ss.output_spec}, {ispec}"
idxs_strats = [
DTensorSpec(mesh, placements=plt)
for x in strat[1].childs
@@ -658,6 +668,127 @@ def index_rule(mesh, op_schema):
return out_strat


@register_opschema_rule(torch.ops.aten.sort.stable)
def sort_rule(mesh, op_schema):
op = torch.ops.aten.topk.default
out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
op
](
op_schema
)
return out_strat


@register_opschema_rule(torch.ops.aten.gather.default)
def gather_strategy(mesh, op_schema):
from torch.distributed.tensor._op_schema import PlacementList
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy

input_strategy = op_schema.args_schema[0]
dim = op_schema.args_schema[1]
index_strategy = op_schema.args_schema[2]

input_shape = input_strategy.shape
index_shape = index_strategy.shape

single_mesh_dim_strategies = []

# placement list stores placements of [output, input, index]
# first we always have replicate all for inputs and output
all_replicate: PlacementList = [Replicate()] * 3
single_mesh_dim_strategies.append(all_replicate)

# input sharding, input sharded, index accepts mask partial, output follows index
# this only works when the input is sharded on the gather dimension, and
# index has size 1 on the gather dimension
if index_shape[dim] == 1:
index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
input_sharding: PlacementList = [
index_partial_placement,
Shard(dim),
index_partial_placement,
]
single_mesh_dim_strategies.append(input_sharding)

# index sharding, input replicated, index sharded, output follows index
# this only works when the sharding dimension is the gather dimension
index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)]
single_mesh_dim_strategies.append(index_sharding)

if len(input_shape) == len(index_shape):
for d in range(len(input_shape)):
if d != dim:
sharding = [Shard(d), Shard(d), Shard(d)]
single_mesh_dim_strategies.append(sharding)

return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=1
)


@register_opschema_rule(torch.ops.aten.scatter_add.default)
def scatter_add_strategy(mesh, op_schema):
Review comment (Contributor Author): @wconstab @zpcore can we double-check these added rules and make sure they are valid / make sense?

The strategy for scatter_add is basically following what I've added for gather, which is that we can allow all tensors to be sharded on any dimension which is not the dim from gather.

from torch.distributed.tensor._op_schema import PlacementList

# from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy

input_strategy = op_schema.args_schema[0]
dim = op_schema.args_schema[1]
index_strategy = op_schema.args_schema[2]
# src_strategy = op_schema.args_schema[3]

input_shape = input_strategy.shape
index_shape = index_strategy.shape

single_mesh_dim_strategies = []

# placement list stores placements of [output, input, index]
# first we always have replicate all for inputs and output
all_replicate: PlacementList = [Replicate()] * 4
single_mesh_dim_strategies.append(all_replicate)

"""
# input sharding, input sharded, index accepts mask partial, output follows index
# this only works when the input is sharded on the gather dimension, and
# index has size 1 on the gather dimension
if index_shape[dim] == 1:
index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
input_sharding: PlacementList = [
index_partial_placement,
Shard(dim),
index_partial_placement,
]
single_mesh_dim_strategies.append(input_sharding)
"""
# index sharding, input replicated, index sharded, output follows index
# this only works when the sharding dimension is the gather dimension
index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim), Shard(dim)]
single_mesh_dim_strategies.append(index_sharding)

if len(input_shape) == len(index_shape):
for d in range(len(input_shape)):
if d != dim:
sharding = [Shard(d), Shard(d), Shard(d), Shard(d)]
single_mesh_dim_strategies.append(sharding)

return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=1
)
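To support the double-check requested in the review comment above, here is a local, single-process sanity check of the invariant behind the `Shard(d)`, `d != dim` strategies: gather along `dim` commutes with splitting every operand along another dimension, so such shardings need no communication; the same argument carries over to `scatter_add`. Shapes are illustrative only:

```python
import torch

dim = 1
x = torch.randn(8, 16)
idx = torch.randint(0, 16, (8, 4))

# Replicated result vs. "sharded" result: chunk input and index along dim 0
# (!= gather dim), gather each shard locally, then concatenate the pieces.
full = torch.gather(x, dim, idx)
pieces = [
    torch.gather(xs, dim, ixs)
    for xs, ixs in zip(x.chunk(2, dim=0), idx.chunk(2, dim=0))
]
torch.testing.assert_close(full, torch.cat(pieces, dim=0))
```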


@register_opschema_rule(torch.ops.aten.slice_scatter.default)
def slice_scatter_rule(mesh, op_schema):
op = torch.ops.aten.slice_scatter.default
out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
op
](
op_schema
)
return out_strat
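`sort_rule`, `slice_scatter_rule`, and `sdpa_rule` below all use the same delegation pattern: look up a strategy function DTensor already registered and reuse it for a related op. A minimal sketch of the lookup (whether a given op has a registered strategy depends on the PyTorch build, so treat this as illustrative):

```python
import torch
from torch.distributed.tensor import DTensor

propagator = DTensor._op_dispatcher.sharding_propagator
topk_strategy = propagator.op_strategy_funcs.get(torch.ops.aten.topk.default)
# sort_rule above feeds the sort op's OpSchema to this function, relying on
# aten.sort.stable and aten.topk sharding the same way along the sorted dim.
```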


def sdpa_rule(op, mesh, op_schema):
out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
op