
Commit 641cbbf

Make 2d sweeps use batch size 2, add 'update3' sweep (#113)
1 parent 19f1fab commit 641cbbf

1 file changed: +47 -22 lines changed


mast/sweep.py

Lines changed: 47 additions & 22 deletions
@@ -88,80 +88,105 @@ def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]:
     return None
 
 
+llama3_1d_common_opts = [
+    "--training.local_batch_size=1",
+    "--parallelism.tensor_parallel_degree=1",
+]
+llama3_2d_common_opts = [
+    "--training.local_batch_size=2",
+    "--parallelism.tensor_parallel_degree=8",
+]
 llama3_1d = {
-    "llama3_FSDP_compile": [
+    "llama3_FSDP_compile": llama3_1d_common_opts
+    + [
         "--model.name=llama3",
         "--training.compile",
-        "--parallelism.tensor_parallel_degree=1",
     ],
-    "llama3_autop_1d_compile": [
+    "llama3_autop_1d_compile": llama3_1d_common_opts
+    + [
         "--model.name=llama3_auto_parallel",
         "--training.compile",
-        "--parallelism.tensor_parallel_degree=1",
     ],
-    "llama3_autop_1d_compile_bucket_reorder": [
+    "llama3_autop_1d_compile_bucket_reorder": llama3_1d_common_opts
+    + [
         "--model.name=llama3_auto_parallel",
         "--training.compile",
-        "--parallelism.tensor_parallel_degree=1",
         "--experimental.bucket_all_gathers_fx=fsdp",
         "--experimental.bucket_reduce_scatters_fx=fsdp",
         "--experimental.reorder_for_compute_comm_overlap",
     ],
 }
 
 llama3_2d = {
-    "llama3_FSDP_tp_compile": [
+    "llama3_FSDP_tp_compile": llama3_2d_common_opts
+    + [
         "--model.name=llama3",
         "--training.compile",
-        "--parallelism.tensor_parallel_degree=8",
     ],
-    "llama3_autop_2d_compile": [
+    "llama3_autop_2d_compile": llama3_2d_common_opts
+    + [
         "--model.name=llama3_auto_parallel",
         "--training.compile",
-        "--parallelism.tensor_parallel_degree=8",
     ],
-    "llama3_autop_2d_compile_bucket_reorder": [
+    "llama3_autop_2d_compile_bucket_reorder": llama3_2d_common_opts
+    + [
         "--model.name=llama3_auto_parallel",
         "--training.compile",
-        "--parallelism.tensor_parallel_degree=8",
         "--experimental.bucket_all_gathers_fx=fsdp",
         "--experimental.bucket_reduce_scatters_fx=fsdp",
         "--experimental.reorder_for_compute_comm_overlap",
     ],
 }
 
 test_run = {
-    "FSDP_tp_compile": [
+    "FSDP_tp_compile": llama3_2d_common_opts
+    + [
         "--model.name=llama3",
         "--training.compile",
-        "--parallelism.tensor_parallel_degree=8",
     ],
 }
 
-sweeps = {
-    "llama3_1d": llama3_1d,
-    "llama3_2d": llama3_2d,
-}
+
 all_runs = (
     llama3_1d
     | llama3_2d
     | {
-        "llama3_autop_1d_compile_ruisi_bucket_reorder": [
+        "llama3_autop_1d_compile_ruisi_bucket_reorder": llama3_1d_common_opts
+        + [
             "--model.name=llama3_auto_parallel",
             "--training.compile",
-            "--parallelism.tensor_parallel_degree=1",
             "--experimental.enable_simplefsdp_passes",
         ],
-        "llama3_autop_2d_compile_ruisi_bucket_reorder": [
+        "llama3_autop_2d_compile_ruisi_bucket_reorder": llama3_2d_common_opts
+        + [
             "--model.name=llama3_auto_parallel",
             "--training.compile",
-            "--parallelism.tensor_parallel_degree=8",
             "--experimental.enable_simplefsdp_passes",
         ],
     }
 )
 
 
+def build_sweep(names):
+    return {name: all_runs[name] for name in names}
+
+
+sweeps = {
+    "llama3_1d": llama3_1d,
+    "llama3_2d": llama3_2d,
+    "update3": build_sweep(
+        [
+            "llama3_FSDP_compile",
+            "llama3_autop_1d_compile",
+            "llama3_autop_1d_compile_ruisi_bucket_reorder",
+            "llama3_FSDP_tp_compile",
+            "llama3_autop_2d_compile",
+            "llama3_autop_2d_compile_ruisi_bucket_reorder",
+        ]
+    ),
+}
+
+
 def run(args: argparse.Namespace) -> None:
 
     if args.runs:
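
For readers skimming the diff, here is a minimal, self-contained Python sketch of how the refactored pieces fit together: the shared option lists are prepended to each run's flags by plain list concatenation, and the new build_sweep helper picks named runs out of all_runs to form the "update3" sweep. This snippet only mirrors a trimmed subset of mast/sweep.py rather than importing it; the reduced all_runs contents and the print call are illustrative assumptions, not part of the commit.

# Standalone sketch, not the real mast/sweep.py: shared 2d options are
# prepended to run-specific flags via list concatenation.
llama3_2d_common_opts = [
    "--training.local_batch_size=2",
    "--parallelism.tensor_parallel_degree=8",
]

all_runs = {
    "llama3_FSDP_tp_compile": llama3_2d_common_opts
    + [
        "--model.name=llama3",
        "--training.compile",
    ],
}


def build_sweep(names):
    # Look each requested run up in all_runs; an unknown name raises KeyError.
    return {name: all_runs[name] for name in names}


sweeps = {"update3": build_sweep(["llama3_FSDP_tp_compile"])}

print(sweeps["update3"]["llama3_FSDP_tp_compile"])
# ['--training.local_batch_size=2', '--parallelism.tensor_parallel_degree=8',
#  '--model.name=llama3', '--training.compile']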
