
Commit c837c3e

Documentation for optimize_sharding.py (#38)
Signed-off-by: Edward Yang <[email protected]>
1 parent 8d5c4d1 commit c837c3e

File tree

1 file changed: +158 -0

autoparallel/optimize_sharding.py

Lines changed: 158 additions & 0 deletions
@@ -3,6 +3,80 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Sharding optimization using Integer Linear Programming (ILP).

This module solves the optimal sharding strategy problem by formulating it as an ILP
where each binary variable x_{i,a,o,j} ∈ {0,1} represents a choice of input placement j
and output placement o for operation i and argument a. The objective minimizes the
total cost:

    minimize: Σ_{i,a,o,j} c_{i,a,o,j} * x_{i,a,o,j}

where:
    - x_{i,a,o,j}: binary decision variable (1 if the strategy is selected, 0 otherwise)
    - c_{i,a,o,j}: total cost (communication + computation) for this strategy choice

subject to the following constraint categories:

1. UNIQUENESS CONSTRAINTS: Each operation-argument pair must select exactly one
   input-output placement combination.

       ∀i,a: Σ_{o,j} x_{i,a,o,j} = 1

   → Implemented in: add_unique_decision_constraint()

2. CONSISTENCY CONSTRAINTS: For multi-argument operations, all arguments must agree
   on the same output placement to ensure the operation can execute correctly.

       ∀i,o: Σ_j x_{i,0,o,j} = Σ_j x_{i,1,o,j} = ... = Σ_j x_{i,A_i-1,o,j}

   where A_i is the number of arguments for operation i.

   → Implemented in: add_same_output_across_args_constraint()

3. FLOW CONSTRAINTS: The output placement of producer operations must match the
   input placement of consumer operations (dataflow consistency).

       ∀(i→k): Σ_j x_{i,0,o,j} = Σ_j x_{k,a,j,o}

   where operation i feeds into operation k at argument position a.

   → Implemented in: add_output_input_consistent_constraint()

4. COST CONSTRAINTS: Variables with infinite cost (invalid configurations) are
   forced to zero.

       ∀i,a,o,j: c_{i,a,o,j} = ∞ ⟹ x_{i,a,o,j} = 0

   → Implemented in: add_inf_cost_constraint()

5. EFFICIENCY CONSTRAINTS: Penalize inefficient collective operations, such as
   non-batch-dimension shard-to-replicate conversions, and forbid invalid
   transitions, such as replicate-to-partial.

   - Shard(dim≠0) → Replicate: multiply cost by 4
   - Replicate → Partial: x_{i,a,o,j} = 0 (forbidden)
   - Partial → Shard(dim≠0): multiply cost by 4

   → Implemented in: penalize_inefficient_collectives()

6. USER CONSTRAINTS (optional): Force specific placements for inputs, outputs,
   or parameters, or bound parameter memory usage.

   6a. Input/output constraints: x_{i,a,o*,j*} = 1 for specified (o*,j*)
       → Implemented in: add_sharded_input_constraint(), add_sharded_output_constraint()

   6b. Memory constraints: Σ_{params} (size_ratio * x_{param}) ≤ memory_limit
       → Implemented in: add_parameter_memory_constraint()

   6c. Parameter-gradient consistency: x_{param} = x_{grad_param}
       → Implemented in: add_grad_param_constraints()

   6d. General node constraints: force a specific placement for any node
       → Implemented in: add_node_constraint()

The solver finds the globally optimal sharding strategy that minimizes total
runtime cost while satisfying all constraints.
"""

import math

import pulp
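To make the formulation concrete, here is a minimal PuLP sketch of the same pattern: an objective plus a category-1 uniqueness constraint. The toy costs and the two-strategy setup are invented for illustration and are not part of this module.

import pulp

# Toy instance: one operation with one argument and two candidate
# (output, input) placement pairs with made-up costs. x[(o, j)] plays
# the role of x_{i,a,o,j} from the docstring's notation.
costs = {(0, 0): 4.0, (0, 1): 1.0}

prob = pulp.LpProblem("ToySharding", pulp.LpMinimize)
x = {k: pulp.LpVariable(f"x_{k[0]}_{k[1]}", cat="Binary") for k in costs}

# Objective: minimize the total cost of the selected strategies.
prob += pulp.lpSum(costs[k] * x[k] for k in costs)

# Uniqueness constraint (category 1): exactly one strategy is chosen.
prob += (pulp.lpSum(x.values()) == 1, "unique_decision")

prob.solve(pulp.PULP_CBC_CMD(msg=False))
print([k for k, v in x.items() if v.value() == 1])  # -> [(0, 1)], the cheaper choice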
@@ -43,6 +117,9 @@ def __init__(self, gm, mesh):
        self.mesh = mesh
        self.node_map = {node: i for i, node in enumerate(self.graph.nodes)}
        self.strats = self.build_sharding_metadata()
        # ds: decision-variable dictionary mapping (s_i, argi, ss, ii) -> ILP variable data.
        # Each key represents a choice of input placement ii and output placement ss
        # for operation s_i and argument argi (corresponds to x_{i,a,o,j} in math notation).
        self.ds, self.num_inp_out, self.num_args = self.build_ds()
        self.validate()
        self.prob = pulp.LpProblem("AutoParallel", pulp.LpMinimize)
@@ -80,6 +157,26 @@ def build_sharding_metadata(self):
        return strats

    def build_ds(self):
        """
        Build decision variables (ds) for the ILP optimization.

        Creates binary variables x_{i,a,o,j} for each valid combination of:
        - s_i: operation index
        - argi: argument index
        - ss: output placement strategy index (o in math notation)
        - ii: input placement strategy index (j in math notation)

        Returns:
            ds: Dictionary mapping (s_i, argi, ss, ii) -> {
                "va": PuLP binary variable,
                "cost": communication + computation cost,
                "full_strat": complete strategy object,
                "out_strat": output placement specification,
                "inp_strat": input placement specification,
            }
            num_inp_out: Metadata about strategy counts per operation-argument
            num_args: Number of arguments per operation
        """
        strats = self.strats
        ds = {}
        num_inp_out = {}
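For orientation, a hypothetical sketch of a single ds entry matching the structure documented above; every value here is a placeholder, not output from the real build_ds().

import pulp

# Hypothetical single entry for operation 3, argument 0, output strategy 1,
# input strategy 2. The cost and placement values are placeholders.
key = (3, 0, 1, 2)  # (s_i, argi, ss, ii)
ds = {
    key: {
        "va": pulp.LpVariable("x_3_0_1_2", cat="Binary"),
        "cost": 128.0,             # communication + computation cost
        "full_strat": None,        # complete strategy object (placeholder)
        "out_strat": "Shard(0)",   # output placement (illustrative string)
        "inp_strat": "Replicate",  # input placement (illustrative string)
    }
}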
@@ -127,6 +224,12 @@ def walk_over_options(self, node, constrain_arg=None):
            yield argi, oi, ii

    def add_unique_decision_constraint(self):
        """
        UNIQUENESS CONSTRAINTS (Category 1): Each operation-argument pair must
        select exactly one input-output placement combination.

        Mathematical form: ∀i,a: Σ_{o,j} x_{i,a,o,j} = 1
        """
        # a single pair of input-output policy is chosen
        for s_i, node in enumerate(self.graph.nodes):
            if node.op not in {"placeholder", "call_function"}:
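A self-contained sketch of how such per-(operation, argument) uniqueness constraints can be emitted from a ds-shaped dictionary; the two-op setup and constraint names are invented for illustration.

from collections import defaultdict

import pulp

# Hypothetical mini-ds: two ops, one argument each, two output strategies.
ds = {
    (s_i, 0, ss, 0): {"va": pulp.LpVariable(f"x_{s_i}_0_{ss}_0", cat="Binary")}
    for s_i in (0, 1)
    for ss in (0, 1)
}

prob = pulp.LpProblem("Uniqueness", pulp.LpMinimize)

# Group variables by (operation, argument), then force one choice per group.
groups = defaultdict(list)
for (s_i, argi, ss, ii), entry in ds.items():
    groups[(s_i, argi)].append(entry["va"])
for (s_i, argi), vas in groups.items():
    prob += (pulp.lpSum(vas) == 1, f"unique_decision_{s_i}_{argi}")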
@@ -139,6 +242,12 @@ def add_unique_decision_constraint(self):
            self.prob += (pulp.lpSum(eqs) == 1, _get_next_name("unique_decision"))

    def add_same_output_across_args_constraint(self):
        """
        CONSISTENCY CONSTRAINTS (Category 2): For multi-argument operations, all
        arguments must agree on the same output placement to ensure the operation
        can execute correctly.

        Mathematical form: ∀i,o: Σ_j x_{i,0,o,j} = Σ_j x_{i,1,o,j} = ... = Σ_j x_{i,A_i-1,o,j}
        """
        # enforce that the same output policy is chosen
        # across arguments
        for s_i, node in enumerate(self.graph.nodes):
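A sketch of the category-2 pattern under simplifying assumptions (one op, two arguments, two output placements, a single input strategy); this is not the module's actual loop.

import pulp

# x[(argi, o)] stands in for Σ_j x_{i,argi,o,j} with a single input strategy.
x = {
    (argi, o): pulp.LpVariable(f"x_0_{argi}_{o}_0", cat="Binary")
    for argi in (0, 1)
    for o in (0, 1)
}

prob = pulp.LpProblem("Consistency", pulp.LpMinimize)

# For each output placement o, both arguments must make the same choice.
for o in (0, 1):
    prob += (x[(0, o)] == x[(1, o)], f"same_output_o{o}")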
@@ -163,6 +272,12 @@ def add_same_output_across_args_constraint(self):
            )

    def add_output_input_consistent_constraint(self):
        """
        FLOW CONSTRAINTS (Category 3): The output placement of producer operations
        must match the input placement of consumer operations (dataflow consistency).

        Mathematical form: ∀(i→k): Σ_j x_{i,0,o,j} = Σ_j x_{k,a,j,o}
        """
        # enforce that the input of strat_{i+1} == output of strat_{i}
        for s_i, node in enumerate(self.graph.nodes):
            if node.op == "output":
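A sketch of the category-3 dataflow coupling, again under invented simplifications: one producer i with a single input strategy, one consumer k at argument position a, and two placements.

import pulp

placements = (0, 1)
# Producer: x_{i,0,o,j} collapsed over its single input strategy j.
prod = {o: pulp.LpVariable(f"x_i_0_{o}_0", cat="Binary") for o in placements}
# Consumer: x_{k,a,j,o}, where o is the input placement received from i.
cons = {
    (j, o): pulp.LpVariable(f"x_k_a_{j}_{o}", cat="Binary")
    for j in placements
    for o in placements
}

prob = pulp.LpProblem("Flow", pulp.LpMinimize)

# Producer picks output placement o iff the consumer receives o as input,
# summed over the consumer's own output choices j.
for o in placements:
    prob += (
        prod[o] == pulp.lpSum(cons[(j, o)] for j in placements),
        f"flow_o{o}",
    )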
@@ -196,6 +311,12 @@ def add_output_input_consistent_constraint(self):
            )

    def add_inf_cost_constraint(self):
        """
        COST CONSTRAINTS (Category 4): Variables with infinite cost (invalid
        configurations) are forced to zero.

        Mathematical form: ∀i,a,o,j: c_{i,a,o,j} = ∞ ⟹ x_{i,a,o,j} = 0
        """
        # force inf cost values to be 0, as the solver doesn't accept inf
        for x in self.ds.values():
            if not math.isfinite(x["cost"]):
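The visible loop body suggests the following shape; a runnable sketch with made-up entries.

import math

import pulp

prob = pulp.LpProblem("InfCost", pulp.LpMinimize)
ds = {
    "valid": {"va": pulp.LpVariable("x_valid", cat="Binary"), "cost": 2.0},
    "invalid": {"va": pulp.LpVariable("x_invalid", cat="Binary"), "cost": math.inf},
}

# CBC (PuLP's default solver) rejects infinite coefficients, so rather than
# placing inf in the objective, pin the offending variable to zero.
for entry in ds.values():
    if not math.isfinite(entry["cost"]):
        prob += (entry["va"] == 0, f"zero_{entry['va'].name}")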
@@ -213,6 +334,13 @@ def add_default_constraints(self):

    def penalize_inefficient_collectives(self):
        """
        EFFICIENCY CONSTRAINTS (Category 5): Penalize inefficient collective operations like
        non-batch dimension shard-to-replicate conversions and forbid invalid transitions.

        - Shard(dim≠0) → Replicate: multiply cost by 4
        - Replicate → Partial: x_{i,a,o,j} = 0 (forbidden)
        - Partial → Shard(dim≠0): multiply cost by 4

        When performing shard_{n} -> replicate (for n != 0), there is additional
        computation cost associated. Let's penalize it here while we don't add
        the computation cost together in the comm cost
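A sketch of the two mechanisms the docstring describes: multiplying a cost before it enters the objective, and pinning a forbidden transition's variable to zero. The string transition tags are invented stand-ins for real placement objects.

import pulp

entries = [
    {"va": pulp.LpVariable("a", cat="Binary"), "cost": 10.0,
     "transition": ("Shard(1)", "Replicate")},
    {"va": pulp.LpVariable("b", cat="Binary"), "cost": 10.0,
     "transition": ("Replicate", "Partial")},
]

prob = pulp.LpProblem("Efficiency", pulp.LpMinimize)
PENALTY = 4  # the multiplier quoted in the docstring

for e in entries:
    src, dst = e["transition"]
    if src.startswith("Shard") and src != "Shard(0)" and dst == "Replicate":
        e["cost"] *= PENALTY  # penalize non-batch shard -> replicate
    elif src == "Replicate" and dst == "Partial":
        prob += (e["va"] == 0, "forbid_replicate_to_partial")  # forbidden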
@@ -440,6 +568,12 @@ def get_grad_param_nodes(self):
        return grad_param_nodes

    def add_grad_param_constraints(self):
        """
        USER CONSTRAINTS (Category 6c): Parameter-gradient consistency constraints.
        Ensures parameters and their gradients have matching sharding strategies.

        Mathematical form: x_{param} = x_{grad_param}
        """
        # TODO: need to make sure that the params and grads are aligned, which are not always the case
        # and we might have fewer gradients than parameters

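The consistency requirement amounts to equating corresponding indicator variables; a minimal sketch, assuming the parameter's and gradient's strategy lists are aligned index-by-index (which the TODO above notes is not always true).

import pulp

prob = pulp.LpProblem("GradParam", pulp.LpMinimize)

# For each strategy index s, the parameter's choice must equal the
# matching gradient's choice, so both receive the same sharding.
for s in range(2):
    x_param = pulp.LpVariable(f"x_param_{s}", cat="Binary")
    x_grad = pulp.LpVariable(f"x_grad_{s}", cat="Binary")
    prob += (x_param == x_grad, f"grad_param_match_{s}")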
@@ -478,6 +612,12 @@ def add_grad_param_constraints(self):
            )

    def add_parameter_memory_constraint(self, memory_factor_low, memory_factor_high):
        """
        USER CONSTRAINTS (Category 6b): Memory constraints for parameters.
        Ensures total parameter memory usage stays within the specified bounds.

        Mathematical form: Σ_{params} (size_ratio * x_{param}) ≤ memory_limit
        """
        # get all parameters
        param_nodes = self.get_param_nodes()
        elms = []
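A sketch of the memory bound with one parameter and invented size ratios; "memory_constraint_low" matches the name visible in the next hunk, while the high-bound name is an assumption.

import pulp

prob = pulp.LpProblem("Memory", pulp.LpMinimize)

# Hypothetical: one parameter, two candidate shardings. size_ratio is the
# fraction of the full parameter held per device under that sharding
# (e.g. 1.0 for Replicate, 0.25 when sharded across 4 devices).
choices = {"replicate": 1.0, "shard4": 0.25}
x = {name: pulp.LpVariable(f"x_{name}", cat="Binary") for name in choices}

elms = [ratio * x[name] for name, ratio in choices.items()]
memory_factor_low, memory_factor_high = 0.1, 0.5

prob += (pulp.lpSum(elms) <= memory_factor_high, "memory_constraint_high")
prob += (pulp.lpSum(elms) >= memory_factor_low, "memory_constraint_low")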
@@ -499,6 +639,12 @@ def add_parameter_memory_constraint(self, memory_factor_low, memory_factor_high)
        self.prob += (pulp.lpSum(elms) >= memory_factor_low, "memory_constraint_low")

    def add_node_constraint(self, node, placement=None, constraint_name=None):
        """
        USER CONSTRAINTS (Category 6d): General node constraints.
        Force a specific placement for any node.

        Mathematical form: x_{i,a,o*,j*} = 1 for the specified (o*,j*)
        """
        strat = self.strats[node]
        if placement is None:
            # default is Shard(0) to parallelize on the batch
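Pinning a node to a placement reduces to fixing one indicator variable; a sketch, with strategy index 0 standing in for whichever strategy matches the requested placement.

import pulp

prob = pulp.LpProblem("NodeConstraint", pulp.LpMinimize)

# Hypothetical: three candidate output strategies for one node; index 0
# stands in for the strategy whose placement is Shard(0), the default above.
x = [pulp.LpVariable(f"x_node_{oi}", cat="Binary") for oi in range(3)]
prob += (x[0] == 1, "node_constraint_shard0")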
@@ -514,6 +660,12 @@ def add_node_constraint(self, node, placement=None, constraint_name=None):
        self._add_node_constraint(node, oi=oi, constraint_name=constraint_name)

    def add_sharded_input_constraint(self, input_placements=None):
        """
        USER CONSTRAINTS (Category 6a): Input placement constraints.
        Force specific placements for input nodes and their corresponding gradient inputs.

        Mathematical form: x_{i,a,o*,j*} = 1 for the specified input placements (o*,j*)
        """
        input_nodes = self.get_input_nodes()
        if input_placements is None:
            input_placements = [None] * len(input_nodes)
@@ -538,6 +690,12 @@ def add_sharded_input_constraint(self, input_placements=None):
            )

    def add_sharded_output_constraint(self, output_placements=None):
        """
        USER CONSTRAINTS (Category 6a): Output placement constraints.
        Force specific placements for output nodes and their corresponding gradient outputs.

        Mathematical form: x_{i,a,o*,j*} = 1 for the specified output placements (o*,j*)
        """
        # add final constraint on the output strategy
        output_nodes = self.get_fn_output_nodes()
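The input and output variants apply the same pinning once per input/output node; a sketch with two outputs and invented names.

import pulp

prob = pulp.LpProblem("OutputPin", pulp.LpMinimize)

# Hypothetical: pin each function output to the strategy matching its
# requested (o*, j*) placement, here index 0 of each output's candidates.
for name in ("out0", "out1"):
    x = [pulp.LpVariable(f"x_{name}_{oi}", cat="Binary") for oi in range(2)]
    prob += (x[0] == 1, f"sharded_output_{name}")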
543701
