# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

+"""
+Sharding optimization using Integer Linear Programming (ILP).
+
+This module solves the optimal sharding strategy problem by formulating it as an ILP
+where each binary variable x_{i,a,o,j} ∈ {0,1} represents a choice of input placement j
+and output placement o for operation i and argument a. The objective minimizes total cost:
+
+    minimize: Σ_{i,a,o,j} c_{i,a,o,j} * x_{i,a,o,j}
+
+where:
+- x_{i,a,o,j}: binary decision variable (1 if the strategy is selected, 0 otherwise)
+- c_{i,a,o,j}: total cost (communication + computation) of this strategy choice
+
+subject to the following constraint categories:
+
+1. UNIQUENESS CONSTRAINTS: Each operation-argument pair must select exactly one
+   input-output placement combination.
+
+       ∀i,a: Σ_{o,j} x_{i,a,o,j} = 1
+
+   → Implemented in: add_unique_decision_constraint()
+
+2. CONSISTENCY CONSTRAINTS: For multi-argument operations, all arguments must agree
+   on the same output placement so that the operation can execute correctly.
+
+       ∀i,o: Σ_j x_{i,0,o,j} = Σ_j x_{i,1,o,j} = ... = Σ_j x_{i,A_i-1,o,j}
+
+   where A_i is the number of arguments of operation i.
+
+   → Implemented in: add_same_output_across_args_constraint()
+
+3. FLOW CONSTRAINTS: The output placement of a producer operation must match the
+   input placement of its consumer operations (dataflow consistency).
+
+       ∀(i→k), ∀o: Σ_j x_{i,0,o,j} = Σ_{o'} x_{k,a,o',o}
+
+   where operation i feeds into operation k at argument position a.
+
+   → Implemented in: add_output_input_consistent_constraint()
+
+4. COST CONSTRAINTS: Variables with infinite cost (invalid configurations) are
+   forced to zero.
+
+       ∀i,a,o,j: c_{i,a,o,j} = ∞ ⟹ x_{i,a,o,j} = 0
+
+   → Implemented in: add_inf_cost_constraint()
+
+5. EFFICIENCY CONSTRAINTS: Penalize inefficient collective operations, such as
+   non-batch-dimension shard-to-replicate conversions, and forbid invalid
+   transitions, such as replicate-to-partial.
+
+   - Shard(dim≠0) → Replicate: multiply cost by 4
+   - Replicate → Partial: x_{i,a,o,j} = 0 (forbidden)
+   - Partial → Shard(dim≠0): multiply cost by 4
+
+   → Implemented in: penalize_inefficient_collectives()
+
+6. USER CONSTRAINTS (optional): Force specific placements for inputs, outputs, and
+   parameters, or bound parameter memory usage.
+
+   6a. Input/Output constraints: x_{i,a,o*,j*} = 1 for a specified (o*,j*)
+       → Implemented in: add_sharded_input_constraint(), add_sharded_output_constraint()
+
+   6b. Memory constraints: Σ_{params} (size_ratio * x_{param}) ≤ memory_limit
+       → Implemented in: add_parameter_memory_constraint()
+
+   6c. Parameter-gradient consistency: x_{param} = x_{grad_param}
+       → Implemented in: add_grad_param_constraints()
+
+   6d. General node constraints: Force a specific placement for any node
+       → Implemented in: add_node_constraint()
+
+The solver finds the globally optimal sharding strategy that minimizes total
+runtime cost while satisfying all constraints.
+"""
+
import math

import pulp
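
To make the formulation above concrete, here is a minimal, self-contained PuLP sketch of the same structure: binary variables x_{i,a,o,j}, a cost-weighted objective, and the uniqueness and consistency constraints for one toy operation. The operation name, placements, and costs are invented for illustration; this is not the module's actual code path.

```python
# Toy version of the ILP described in the module docstring. Names, placements, and
# costs below are invented; the real solver derives them from the traced graph.
import pulp

ops = ["matmul_0"]                    # one operation (index i)
args = [0, 1]                         # its two arguments (index a)
placements = ["Shard0", "Replicate"]  # candidate placements (o for output, j for input)

# Toy cost table c_{i,a,o,j}: pretend mismatched input/output placements cost more.
cost = {
    (i, a, o, j): 1.0 if o == j else 4.0
    for i in ops for a in args for o in placements for j in placements
}

prob = pulp.LpProblem("toy_autoparallel", pulp.LpMinimize)

# Binary decision variables x_{i,a,o,j}.
x = {k: pulp.LpVariable("x_%s_%d_%s_%s" % k, cat="Binary") for k in cost}

# Objective: minimize the total cost of the selected strategies.
prob += pulp.lpSum(cost[k] * x[k] for k in x)

# Category 1 (uniqueness): each (operation, argument) picks exactly one (o, j) pair.
for i in ops:
    for a in args:
        prob += (
            pulp.lpSum(x[i, a, o, j] for o in placements for j in placements) == 1,
            f"unique_decision_{i}_{a}",
        )

# Category 2 (consistency): both arguments must agree on the output placement o.
for i in ops:
    for o in placements:
        prob += (
            pulp.lpSum(x[i, 0, o, j] for j in placements)
            == pulp.lpSum(x[i, 1, o, j] for j in placements),
            f"same_output_{i}_{o}",
        )

prob.solve(pulp.PULP_CBC_CMD(msg=False))
chosen = [k for k, v in x.items() if round(v.value()) == 1]
print(pulp.LpStatus[prob.status], chosen)
```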
@@ -43,6 +117,9 @@ def __init__(self, gm, mesh):
        self.mesh = mesh
        self.node_map = {node: i for i, node in enumerate(self.graph.nodes)}
        self.strats = self.build_sharding_metadata()
+        # ds: Decision variables dictionary mapping (s_i, argi, ss, ii) -> ILP variable data
+        # Each key represents a choice of input placement ii and output placement ss
+        # for operation s_i and argument argi (corresponds to x_{i,a,o,j} in math notation)
        self.ds, self.num_inp_out, self.num_args = self.build_ds()
        self.validate()
        self.prob = pulp.LpProblem("AutoParallel", pulp.LpMinimize)
@@ -80,6 +157,26 @@ def build_sharding_metadata(self):
        return strats

    def build_ds(self):
+        """
+        Build decision variables (ds) for the ILP optimization.
+
+        Creates binary variables x_{i,a,o,j} for each valid combination of:
+        - s_i: operation index
+        - argi: argument index
+        - ss: output placement strategy index (o in math notation)
+        - ii: input placement strategy index (j in math notation)
+
+        Returns:
+            ds: Dictionary mapping (s_i, argi, ss, ii) -> {
+                "va": PuLP binary variable,
+                "cost": communication + computation cost,
+                "full_strat": complete strategy object,
+                "out_strat": output placement specification,
+                "inp_strat": input placement specification,
+            }
+            num_inp_out: Metadata about strategy counts per operation-argument
+            num_args: Number of arguments per operation
+        """
        strats = self.strats
        ds = {}
        num_inp_out = {}
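
For orientation, one entry of the `ds` dictionary described above might look roughly like the following. Only the field names come from the docstring; the key and all concrete values are invented.

```python
import pulp

# Hypothetical shape of one ds entry, keyed by (s_i, argi, ss, ii).
# Reading of the key: operation 3, argument 0, output strategy 1, input strategy 2.
ds = {
    (3, 0, 1, 2): {
        "va": pulp.LpVariable("x_3_0_1_2", cat="Binary"),  # the binary variable x_{i,a,o,j}
        "cost": 12.5,                # communication + computation cost of this choice (made up)
        "full_strat": None,          # placeholder for the complete strategy object
        "out_strat": "Shard(0)",     # output placement (o / ss), shown here as a plain string
        "inp_strat": "Replicate()",  # input placement (j / ii), shown here as a plain string
    }
}
```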
@@ -127,6 +224,12 @@ def walk_over_options(self, node, constrain_arg=None):
            yield argi, oi, ii

    def add_unique_decision_constraint(self):
+        """
+        UNIQUENESS CONSTRAINTS (Category 1): Each operation-argument pair must select
+        exactly one input-output placement combination.
+
+        Mathematical form: ∀i,a: Σ_{o,j} x_{i,a,o,j} = 1
+        """
        # a single input-output policy pair is chosen
        for s_i, node in enumerate(self.graph.nodes):
            if node.op not in {"placeholder", "call_function"}:
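
A small sketch of how this constraint can be emitted with PuLP, grouping toy `ds`-style entries by (operation, argument) and requiring each group to sum to one (names and data are illustrative, not the module's internals):

```python
import collections
import pulp

# Toy ds-like dict: one operation (s_i=0), one argument (argi=0), 2x2 placement grid.
prob = pulp.LpProblem("uniqueness_demo", pulp.LpMinimize)
ds = {
    (0, 0, o, j): {"va": pulp.LpVariable(f"x_0_0_{o}_{j}", cat="Binary")}
    for o in range(2) for j in range(2)
}

# Group all variables that belong to the same (operation, argument) pair ...
groups = collections.defaultdict(list)
for (s_i, argi, ss, ii), entry in ds.items():
    groups[(s_i, argi)].append(entry["va"])

# ... and require exactly one of them to be selected: ∀ i,a: Σ_{o,j} x_{i,a,o,j} = 1
for (s_i, argi), variables in groups.items():
    prob += (pulp.lpSum(variables) == 1, f"unique_decision_{s_i}_{argi}")
```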
@@ -139,6 +242,12 @@ def add_unique_decision_constraint(self):
            self.prob += (pulp.lpSum(eqs) == 1, _get_next_name("unique_decision"))

    def add_same_output_across_args_constraint(self):
+        """
+        CONSISTENCY CONSTRAINTS (Category 2): For multi-argument operations, all arguments
+        must agree on the same output placement to ensure the operation can execute correctly.
+
+        Mathematical form: ∀i,o: Σ_j x_{i,0,o,j} = Σ_j x_{i,1,o,j} = ... = Σ_j x_{i,A_i-1,o,j}
+        """
        # enforce that the same output policy is chosen
        # across arguments
        for s_i, node in enumerate(self.graph.nodes):
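
The chained equality can be expressed pairwise against argument 0, as in this toy sketch (placement and argument counts are invented):

```python
import pulp

# Toy setup: one operation with three arguments and a 2x2 placement grid per argument.
prob = pulp.LpProblem("same_output_demo", pulp.LpMinimize)
num_args, num_out, num_inp = 3, 2, 2
x = {
    (a, o, j): pulp.LpVariable(f"x_{a}_{o}_{j}", cat="Binary")
    for a in range(num_args) for o in range(num_out) for j in range(num_inp)
}

# For every output placement o, every argument's marginal must equal argument 0's marginal.
for o in range(num_out):
    lhs = pulp.lpSum(x[0, o, j] for j in range(num_inp))      # Σ_j x_{i,0,o,j}
    for a in range(1, num_args):
        rhs = pulp.lpSum(x[a, o, j] for j in range(num_inp))  # Σ_j x_{i,a,o,j}
        prob += (lhs == rhs, f"same_output_o{o}_arg{a}")
```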
@@ -163,6 +272,12 @@ def add_same_output_across_args_constraint(self):
            )

    def add_output_input_consistent_constraint(self):
+        """
+        FLOW CONSTRAINTS (Category 3): The output placement of producer operations must match
+        the input placement of consumer operations (dataflow consistency).
+
+        Mathematical form: ∀(i→k), ∀o: Σ_j x_{i,0,o,j} = Σ_{o'} x_{k,a,o',o}
+        """
        # enforce that the input of strat_{i+1} == output of strat_{i}
        for s_i, node in enumerate(self.graph.nodes):
            if node.op == "output":
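
A toy sketch of the dataflow coupling between one producer i and one consumer k reading i's result at argument a (indices and placement counts are invented):

```python
import pulp

# Toy variables for a producer op i and a consumer op k (2 candidate placements each).
prob = pulp.LpProblem("flow_demo", pulp.LpMinimize)
num_placements = 2
x_i = {(o, j): pulp.LpVariable(f"xi_{o}_{j}", cat="Binary")
       for o in range(num_placements) for j in range(num_placements)}
x_k = {(o, j): pulp.LpVariable(f"xk_{o}_{j}", cat="Binary")
       for o in range(num_placements) for j in range(num_placements)}

# For every placement o: "i outputs o" (summed over i's inputs) must equal
# "k receives o as the input for argument a" (summed over k's outputs).
for o in range(num_placements):
    prob += (
        pulp.lpSum(x_i[o, j] for j in range(num_placements))
        == pulp.lpSum(x_k[o2, o] for o2 in range(num_placements)),
        f"output_input_consistent_o{o}",
    )
```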
@@ -196,6 +311,12 @@ def add_output_input_consistent_constraint(self):
            )

    def add_inf_cost_constraint(self):
+        """
+        COST CONSTRAINTS (Category 4): Variables with infinite cost (invalid configurations)
+        are forced to zero.
+
+        Mathematical form: ∀i,a,o,j: c_{i,a,o,j} = ∞ ⟹ x_{i,a,o,j} = 0
+        """
        # force inf cost values to be 0, as the solver doesn't accept inf
        for x in self.ds.values():
            if not math.isfinite(x["cost"]):
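
Since LP solvers cannot take infinite coefficients, the usual workaround is to pin those variables to zero and keep them out of the objective, as in this hedged sketch with a toy `ds`-like dict:

```python
import math
import pulp

# Toy ds-like dict with one valid and one invalid (infinite-cost) entry.
prob = pulp.LpProblem("inf_cost_demo", pulp.LpMinimize)
ds = {
    "valid": {"va": pulp.LpVariable("x_valid", cat="Binary"), "cost": 3.0},
    "invalid": {"va": pulp.LpVariable("x_invalid", cat="Binary"), "cost": math.inf},
}

# Pin infinite-cost choices to zero, since the solver cannot handle inf coefficients.
for key, entry in ds.items():
    if not math.isfinite(entry["cost"]):
        prob += (entry["va"] == 0, f"inf_cost_{key}")

# Keep only finite-cost terms in the objective.
prob += pulp.lpSum(e["cost"] * e["va"] for e in ds.values() if math.isfinite(e["cost"]))
```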
@@ -213,6 +334,13 @@ def add_default_constraints(self):

    def penalize_inefficient_collectives(self):
        """
+        EFFICIENCY CONSTRAINTS (Category 5): Penalize inefficient collective operations, such
+        as non-batch-dimension shard-to-replicate conversions, and forbid invalid transitions.
+
+        - Shard(dim≠0) → Replicate: multiply cost by 4
+        - Replicate → Partial: x_{i,a,o,j} = 0 (forbidden)
+        - Partial → Shard(dim≠0): multiply cost by 4
+
        When performing shard_{n} -> replicate (for n != 0), there is an additional
        computation cost involved. Penalize it here as long as the computation cost
        is not yet accounted for together with the comm cost
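
The penalty rules above can be summarized as a cost-adjustment function. This sketch uses plain strings for placements (the real module works with DTensor placement objects); the factor of 4 is the one stated in the docstring:

```python
import math

PENALTY = 4.0  # the multiplier stated in the docstring

def adjust_cost(inp_placement: str, out_placement: str, cost: float) -> float:
    """Return the adjusted cost for an input-placement -> output-placement transition."""
    non_batch_shard_in = inp_placement.startswith("Shard(") and inp_placement != "Shard(0)"
    non_batch_shard_out = out_placement.startswith("Shard(") and out_placement != "Shard(0)"
    if non_batch_shard_in and out_placement == "Replicate":
        return cost * PENALTY      # Shard(dim != 0) -> Replicate: extra compute, penalize
    if inp_placement == "Replicate" and out_placement == "Partial":
        return math.inf            # Replicate -> Partial: forbidden (later forced to 0)
    if inp_placement == "Partial" and non_batch_shard_out:
        return cost * PENALTY      # Partial -> Shard(dim != 0): penalize
    return cost

print(adjust_cost("Shard(1)", "Replicate", 2.0))  # 8.0
print(adjust_cost("Replicate", "Partial", 2.0))   # inf
print(adjust_cost("Shard(0)", "Replicate", 2.0))  # 2.0 (batch dim, no penalty)
```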
@@ -440,6 +568,12 @@ def get_grad_param_nodes(self):
        return grad_param_nodes

    def add_grad_param_constraints(self):
+        """
+        USER CONSTRAINTS (Category 6c): Parameter-gradient consistency constraints.
+        Ensures parameters and their gradients have matching sharding strategies.
+
+        Mathematical form: x_{param} = x_{grad_param}
+        """
        # TODO: need to make sure that the params and grads are aligned, which is not always
        # the case, and we might have fewer gradients than parameters

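
Equating a parameter's strategy with its gradient's strategy reduces to pairwise equalities between the corresponding binary variables, as in this toy sketch (the strategy count is invented):

```python
import pulp

# Toy variables: one parameter and its gradient, each with three candidate strategies.
prob = pulp.LpProblem("grad_param_demo", pulp.LpMinimize)
num_strategies = 3
x_param = [pulp.LpVariable(f"x_param_{s}", cat="Binary") for s in range(num_strategies)]
x_grad = [pulp.LpVariable(f"x_grad_{s}", cat="Binary") for s in range(num_strategies)]

# x_{param, s} = x_{grad_param, s} for every candidate strategy s.
for s in range(num_strategies):
    prob += (x_param[s] == x_grad[s], f"grad_param_match_{s}")
```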
@@ -478,6 +612,12 @@ def add_grad_param_constraints(self):
            )

    def add_parameter_memory_constraint(self, memory_factor_low, memory_factor_high):
+        """
+        USER CONSTRAINTS (Category 6b): Memory constraints for parameters.
+        Ensures total parameter memory usage stays within specified bounds.
+
+        Mathematical form: Σ_{params} (size_ratio * x_{param}) ≤ memory_limit
+        """
        # get all parameters
        param_nodes = self.get_param_nodes()
        elms = []
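
A hedged sketch of the memory bound: each candidate strategy of a parameter contributes a per-rank size ratio (for example 1.0 if replicated, 1/world_size if sharded along one mesh dimension) weighted by its binary variable, and the weighted sum is kept between the low and high factors. The numbers and the exact ratio semantics are assumptions for illustration.

```python
import pulp

# Toy setup: one parameter with two candidate strategies and an 8-way mesh.
prob = pulp.LpProblem("memory_demo", pulp.LpMinimize)
world_size = 8
size_ratio = {"replicate": 1.0, "shard0": 1.0 / world_size}  # per-rank fraction kept (assumed)
x = {name: pulp.LpVariable(f"x_{name}", cat="Binary") for name in size_ratio}

prob += (pulp.lpSum(x.values()) == 1, "unique_decision")     # the parameter picks one strategy

elms = [ratio * x[name] for name, ratio in size_ratio.items()]
memory_factor_low, memory_factor_high = 0.1, 0.5
prob += (pulp.lpSum(elms) <= memory_factor_high, "memory_constraint_high")
prob += (pulp.lpSum(elms) >= memory_factor_low, "memory_constraint_low")
```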
@@ -499,6 +639,12 @@ def add_parameter_memory_constraint(self, memory_factor_low, memory_factor_high)
        self.prob += (pulp.lpSum(elms) >= memory_factor_low, "memory_constraint_low")

    def add_node_constraint(self, node, placement=None, constraint_name=None):
+        """
+        USER CONSTRAINTS (Category 6d): General node constraints.
+        Force specific placement for any node.
+
+        Mathematical form: x_{i,a,o*,j*} = 1 for specified (o*,j*)
+        """
        strat = self.strats[node]
        if placement is None:
            # default is Shard(0) to parallelize on the batch
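
Forcing a node to a given placement can be written as: the binary variables of all strategies whose output placement equals the requested one must sum to 1. This is a sketch of that idea with an invented strategy list, not the method's actual implementation:

```python
import pulp

# Toy node with three candidate output placements; the requested one is Shard(0),
# the default used to parallelize on the batch dimension.
prob = pulp.LpProblem("node_constraint_demo", pulp.LpMinimize)
candidate_outputs = ["Shard(0)", "Shard(1)", "Replicate"]
x = {o: pulp.LpVariable(f"x_{i}", cat="Binary") for i, o in enumerate(candidate_outputs)}

desired = "Shard(0)"
prob += (pulp.lpSum(v for o, v in x.items() if o == desired) == 1, "node_constraint")
```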
@@ -514,6 +660,12 @@ def add_node_constraint(self, node, placement=None, constraint_name=None):
        self._add_node_constraint(node, oi=oi, constraint_name=constraint_name)

    def add_sharded_input_constraint(self, input_placements=None):
+        """
+        USER CONSTRAINTS (Category 6a): Input placement constraints.
+        Force specific placements for input nodes and their corresponding gradient inputs.
+
+        Mathematical form: x_{i,a,o*,j*} = 1 for specified input placements (o*,j*)
+        """
        input_nodes = self.get_input_nodes()
        if input_placements is None:
            input_placements = [None] * len(input_nodes)
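
In the simplest reading of the formula above, pinning a user-specified input placement means setting one particular variable to 1, as in this toy sketch (the pinned indices are invented):

```python
import pulp

# Toy 2x2 strategy grid for one input node; the user-requested pair is (o*, j*) = (0, 0).
prob = pulp.LpProblem("pin_input_demo", pulp.LpMinimize)
x = {(o, j): pulp.LpVariable(f"x_{o}_{j}", cat="Binary") for o in range(2) for j in range(2)}

o_star, j_star = 0, 0
prob += (x[o_star, j_star] == 1, "sharded_input")  # x_{i,a,o*,j*} = 1
```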
@@ -538,6 +690,12 @@ def add_sharded_input_constraint(self, input_placements=None):
            )

    def add_sharded_output_constraint(self, output_placements=None):
+        """
+        USER CONSTRAINTS (Category 6a): Output placement constraints.
+        Force specific placements for output nodes and their corresponding gradient outputs.
+
+        Mathematical form: x_{i,a,o*,j*} = 1 for specified output placements (o*,j*)
+        """
        # add final constraint on the output strategy
        output_nodes = self.get_fn_output_nodes()

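
Finally, once the objective and all constraint categories have been added, solving the problem and reading back which binaries are 1 yields the selected strategy per operation-argument pair. A toy end-to-end sketch (costs and variable layout invented):

```python
import pulp

# Toy problem: one operation-argument pair with four candidate (output, input) index pairs.
prob = pulp.LpProblem("readback_demo", pulp.LpMinimize)
cost = {(0, 0): 5.0, (0, 1): 1.0, (1, 0): 2.0, (1, 1): 3.0}
x = {k: pulp.LpVariable(f"x_{k[0]}_{k[1]}", cat="Binary") for k in cost}

prob += pulp.lpSum(cost[k] * x[k] for k in x)            # objective
prob += (pulp.lpSum(x.values()) == 1, "unique_decision")

prob.solve(pulp.PULP_CBC_CMD(msg=False))
chosen = [k for k, v in x.items() if round(v.value()) == 1]
print(pulp.LpStatus[prob.status], chosen)                # e.g. "Optimal" and [(0, 1)]
```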