
Commit ef70738

Fix for invalid sharding and add self-contained example for Llama3 model from TorchTitan (#22)
* Add Llama3 example: a self-contained copy-paste from TorchTitan that works
* Cleanup
* Fix license
* Remove invalid configurations that would yield empty shapes
* Cleanup
1 parent 1306427 commit ef70738

File tree

2 files changed: +661 −0 lines


autoparallel/utils.py

Lines changed: 19 additions & 0 deletions
@@ -119,6 +119,25 @@ def get_placement_options(mesh, op, specs, user_args):

     propagate_tensor_meta(op, user_args, out_strat)
     fill_missing_redistribute_cost(op, specs, out_strat)
+
+    # Keep only the strategies whose shardings are valid for the actual
+    # tensor shapes: a tensor dim can be sharded across a mesh dim only if
+    # its size divides evenly; otherwise the local shard would be empty.
+    kept = []
+    for strategy in out_strat.strategies:
+        is_valid = True
+        for input_spec in strategy.input_specs:
+            shape = list(input_spec.tensor_meta.shape)
+            for mesh_dim_size, plc in zip(mesh.shape, input_spec.placements):
+                if plc.is_shard():
+                    dim = plc.dim
+                    if shape[dim] % mesh_dim_size == 0:
+                        # Integer division: track the remaining local shard size
+                        # so nested sharding of the same dim is also checked.
+                        shape[dim] //= mesh_dim_size
+                    else:
+                        is_valid = False
+                        break
+        if is_valid:
+            kept.append(strategy)
+
+    out_strat = OpStrategy(kept)
     return out_strat
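For intuition, here is a minimal standalone sketch of the same divisibility rule. The is_valid_sharding helper, shapes, and mesh sizes below are illustrative assumptions, not code from the repo; it mirrors the check above that drops configurations which would yield empty shapes.

# Hypothetical helper, not part of autoparallel: a tensor dim can be
# sharded across a mesh dim only if its size divides evenly.
def is_valid_sharding(tensor_shape, mesh_shape, shard_dims):
    """shard_dims[i] is the tensor dim sharded on mesh dim i, or None."""
    shape = list(tensor_shape)
    for mesh_dim_size, dim in zip(mesh_shape, shard_dims):
        if dim is None:  # replicated on this mesh dim
            continue
        if shape[dim] % mesh_dim_size != 0:
            return False  # local shard would have a ragged/empty shape
        shape[dim] //= mesh_dim_size  # remaining local size on that dim
    return True

# A (5, 8) tensor on a 1-D mesh of 4 devices:
assert not is_valid_sharding((5, 8), (4,), [0])  # 5 % 4 != 0 -> invalid
assert is_valid_sharding((5, 8), (4,), [1])      # 8 % 4 == 0 -> 2 per rank
# On a 2-D (2, 4) mesh, shard dim 0 on mesh dim 0 and dim 1 on mesh dim 1:
assert is_valid_sharding((6, 8), (2, 4), [0, 1])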
