@@ -88,80 +88,105 @@ def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]:
8888 return None
8989
9090
91+ llama3_1d_common_opts = [
92+ "--training.local_batch_size=1" ,
93+ "--parallelism.tensor_parallel_degree=1" ,
94+ ]
95+ llama3_2d_common_opts = [
96+ "--training.local_batch_size=2" ,
97+ "--parallelism.tensor_parallel_degree=8" ,
98+ ]
9199llama3_1d = {
92- "llama3_FSDP_compile" : [
100+ "llama3_FSDP_compile" : llama3_1d_common_opts
101+ + [
93102 "--model.name=llama3" ,
94103 "--training.compile" ,
95- "--parallelism.tensor_parallel_degree=1" ,
96104 ],
97- "llama3_autop_1d_compile" : [
105+ "llama3_autop_1d_compile" : llama3_1d_common_opts
106+ + [
98107 "--model.name=llama3_auto_parallel" ,
99108 "--training.compile" ,
100- "--parallelism.tensor_parallel_degree=1" ,
101109 ],
102- "llama3_autop_1d_compile_bucket_reorder" : [
110+ "llama3_autop_1d_compile_bucket_reorder" : llama3_1d_common_opts
111+ + [
103112 "--model.name=llama3_auto_parallel" ,
104113 "--training.compile" ,
105- "--parallelism.tensor_parallel_degree=1" ,
106114 "--experimental.bucket_all_gathers_fx=fsdp" ,
107115 "--experimental.bucket_reduce_scatters_fx=fsdp" ,
108116 "--experimental.reorder_for_compute_comm_overlap" ,
109117 ],
110118}
111119
112120llama3_2d = {
113- "llama3_FSDP_tp_compile" : [
121+ "llama3_FSDP_tp_compile" : llama3_2d_common_opts
122+ + [
114123 "--model.name=llama3" ,
115124 "--training.compile" ,
116- "--parallelism.tensor_parallel_degree=8" ,
117125 ],
118- "llama3_autop_2d_compile" : [
126+ "llama3_autop_2d_compile" : llama3_2d_common_opts
127+ + [
119128 "--model.name=llama3_auto_parallel" ,
120129 "--training.compile" ,
121- "--parallelism.tensor_parallel_degree=8" ,
122130 ],
123- "llama3_autop_2d_compile_bucket_reorder" : [
131+ "llama3_autop_2d_compile_bucket_reorder" : llama3_2d_common_opts
132+ + [
124133 "--model.name=llama3_auto_parallel" ,
125134 "--training.compile" ,
126- "--parallelism.tensor_parallel_degree=8" ,
127135 "--experimental.bucket_all_gathers_fx=fsdp" ,
128136 "--experimental.bucket_reduce_scatters_fx=fsdp" ,
129137 "--experimental.reorder_for_compute_comm_overlap" ,
130138 ],
131139}
132140
133141test_run = {
134- "FSDP_tp_compile" : [
142+ "FSDP_tp_compile" : llama3_2d_common_opts
143+ + [
135144 "--model.name=llama3" ,
136145 "--training.compile" ,
137- "--parallelism.tensor_parallel_degree=8" ,
138146 ],
139147}
140148
141- sweeps = {
142- "llama3_1d" : llama3_1d ,
143- "llama3_2d" : llama3_2d ,
144- }
149+
145150all_runs = (
146151 llama3_1d
147152 | llama3_2d
148153 | {
149- "llama3_autop_1d_compile_ruisi_bucket_reorder" : [
154+ "llama3_autop_1d_compile_ruisi_bucket_reorder" : llama3_1d_common_opts
155+ + [
150156 "--model.name=llama3_auto_parallel" ,
151157 "--training.compile" ,
152- "--parallelism.tensor_parallel_degree=1" ,
153158 "--experimental.enable_simplefsdp_passes" ,
154159 ],
155- "llama3_autop_2d_compile_ruisi_bucket_reorder" : [
160+ "llama3_autop_2d_compile_ruisi_bucket_reorder" : llama3_2d_common_opts
161+ + [
156162 "--model.name=llama3_auto_parallel" ,
157163 "--training.compile" ,
158- "--parallelism.tensor_parallel_degree=8" ,
159164 "--experimental.enable_simplefsdp_passes" ,
160165 ],
161166 }
162167)
163168
164169
170+ def build_sweep (names ):
171+ return {name : all_runs [name ] for name in names }
172+
173+
174+ sweeps = {
175+ "llama3_1d" : llama3_1d ,
176+ "llama3_2d" : llama3_2d ,
177+ "update3" : build_sweep (
178+ [
179+ "llama3_FSDP_compile" ,
180+ "llama3_autop_1d_compile" ,
181+ "llama3_autop_1d_compile_ruisi_bucket_reorder" ,
182+ "llama3_FSDP_tp_compile" ,
183+ "llama3_autop_2d_compile" ,
184+ "llama3_autop_2d_compile_ruisi_bucket_reorder" ,
185+ ]
186+ ),
187+ }
188+
189+
165190def run (args : argparse .Namespace ) -> None :
166191
167192 if args .runs :
0 commit comments