@@ -106,14 +106,13 @@ def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]:
106106 + [
107107 "--model.name=auto_parallel.llama3" ,
108108 "--compile.enable" ,
109+ "--experimental.comms_bucket_reorder_strategy=none" ,
109110 ],
110- "llama3_autop_1d_compile_bucket_reorder " : llama3_1d_common_opts
111+ "llama3_autop_1d_compile_aten_bucket_reorder " : llama3_1d_common_opts
111112 + [
112113 "--model.name=auto_parallel.llama3" ,
113114 "--compile.enable" ,
114- "--experimental.bucket_all_gathers_fx=fsdp" ,
115- "--experimental.bucket_reduce_scatters_fx=fsdp" ,
116- "--experimental.reorder_for_compute_comm_overlap" ,
115+ "--experimental.comms_bucket_reorder_strategy=aten" ,
117116 ],
118117}
119118
@@ -127,41 +126,31 @@ def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]:
127126 + [
128127 "--model.name=auto_parallel.llama3" ,
129128 "--compile.enable" ,
129+ "--experimental.comms_bucket_reorder_strategy=none" ,
130130 ],
131- "llama3_autop_2d_compile_bucket_reorder " : llama3_2d_common_opts
131+ "llama3_autop_2d_compile_aten_bucket_reorder " : llama3_2d_common_opts
132132 + [
133133 "--model.name=auto_parallel.llama3" ,
134134 "--compile.enable" ,
135- "--experimental.bucket_all_gathers_fx=fsdp" ,
136- "--experimental.bucket_reduce_scatters_fx=fsdp" ,
137- "--experimental.reorder_for_compute_comm_overlap" ,
135+ "--experimental.comms_bucket_reorder_strategy=aten" ,
138136 ],
139137}
140138
141- test_run = {
142- "FSDP_tp_compile" : llama3_2d_common_opts
143- + [
144- "--model.name=llama3" ,
145- "--compile.enable" ,
146- ],
147- }
148-
149-
150139all_runs = (
151140 llama3_1d
152141 | llama3_2d
153142 | {
154- "llama3_autop_1d_compile_ruisi_bucket_reorder " : llama3_1d_common_opts
143+ "llama3_autop_1d_compile_inductor_bucket_reorder " : llama3_1d_common_opts
155144 + [
156145 "--model.name=auto_parallel.llama3" ,
157146 "--compile.enable" ,
158- "--experimental.enable_simplefsdp_passes " ,
147+ "--experimental.comms_bucket_reorder_strategy=inductor " ,
159148 ],
160- "llama3_autop_2d_compile_ruisi_bucket_reorder " : llama3_2d_common_opts
149+ "llama3_autop_2d_compile_inductor_bucket_reorder " : llama3_2d_common_opts
161150 + [
162151 "--model.name=auto_parallel.llama3" ,
163152 "--compile.enable" ,
164- "--experimental.enable_simplefsdp_passes " ,
153+ "--experimental.comms_bucket_reorder_strategy=inductor " ,
165154 ],
166155 }
167156)
@@ -178,10 +167,26 @@ def build_sweep(names):
178167 [
179168 "llama3_FSDP_compile" ,
180169 "llama3_autop_1d_compile" ,
181- "llama3_autop_1d_compile_ruisi_bucket_reorder" ,
170+ "llama3_autop_1d_compile_inductor_bucket_reorder" ,
171+ "llama3_FSDP_tp_compile" ,
172+ "llama3_autop_2d_compile" ,
173+ "llama3_autop_2d_compile_inductor_bucket_reorder" ,
174+ ]
175+ ),
176+ "compare_1d_bucketing" : build_sweep (
177+ [
178+ "llama3_FSDP_compile" ,
179+ "llama3_autop_1d_compile" ,
180+ "llama3_autop_1d_compile_aten_bucket_reorder" ,
181+ "llama3_autop_1d_compile_inductor_bucket_reorder" ,
182+ ]
183+ ),
184+ "compare_2d_bucketing" : build_sweep (
185+ [
182186 "llama3_FSDP_tp_compile" ,
183187 "llama3_autop_2d_compile" ,
184- "llama3_autop_2d_compile_ruisi_bucket_reorder" ,
188+ "llama3_autop_2d_compile_aten_bucket_reorder" ,
189+ "llama3_autop_2d_compile_inductor_bucket_reorder" ,
185190 ]
186191 ),
187192}
0 commit comments