Skip to content

Commit 68c8a07

Browse files
authored
Whc/knobs (#233)
* Move reordering pass config from example/llama3 into util — this lets us share the configs between the examples and torchtitan. * Update sweep.py to use the new titan CLI args for reorder/bucket.
1 parent 7fb094d commit 68c8a07

File tree

3 files changed

+75
-59
lines changed

3 files changed

+75
-59
lines changed

autoparallel/auto_bucketing.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# This source code is licensed under the BSD license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from functools import partial
7+
68
import torch
79
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
810

@@ -112,3 +114,46 @@ def aten_autobucketing_reordering_pass(
112114
max_in_flight_gb=configs.max_in_flight_gb,
113115
max_coll_distance=configs.max_coll_distance,
114116
)
117+
118+
119+
def configure_inductor_for_autobucketing(mode: str = "aten"):
120+
# allow configuring inductor comms optimizations from torchtitan commandline
121+
if mode == "aten":
122+
from autoparallel.auto_bucketing import (
123+
aten_autobucketing_config,
124+
aten_autobucketing_reordering_pass,
125+
)
126+
127+
# this is from the stacked pr in https://github.com/pytorch/pytorch/pull/163960
128+
torch._inductor.config.reorder_for_peak_memory = False
129+
torch._inductor.config.reorder_for_compute_comm_overlap = False
130+
aten_autobucketing_reordering_pass = partial(
131+
aten_autobucketing_reordering_pass,
132+
configs=aten_autobucketing_config, # type: ignore
133+
)
134+
torch._inductor.config.post_grad_custom_post_pass = (
135+
aten_autobucketing_reordering_pass # type: ignore
136+
)
137+
elif mode == "inductor":
138+
from autoparallel.auto_bucketing import (
139+
simple_fsdp_autobucketing_reordering_pass,
140+
simplefsdp_autobucketing_config,
141+
)
142+
143+
torch._inductor.config.allow_buffer_reuse = False
144+
torch._inductor.config.reorder_for_peak_memory = False
145+
torch._inductor.config.reorder_for_compute_comm_overlap = True
146+
simplefsdp_autobucketing_config.calibrate_number = 5
147+
simplefsdp_autobucketing_config.save_estimation_path = "./estimation_mast.pkl"
148+
simple_fsdp_autobucketing_reordering_pass = partial(
149+
simple_fsdp_autobucketing_reordering_pass,
150+
configs=simplefsdp_autobucketing_config, # type: ignore
151+
)
152+
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
153+
simple_fsdp_autobucketing_reordering_pass
154+
]
155+
elif mode == "none":
156+
torch._inductor.config.reorder_for_peak_memory = False
157+
torch._inductor.config.reorder_for_compute_comm_overlap = False
158+
else:
159+
raise ValueError(f"Unknown comms bucket reorder strategy: {mode}")

examples/example_llama3.py

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import time
7-
from functools import partial
87

98
import torch
109
from torch.distributed.fsdp import MixedPrecisionPolicy
@@ -13,12 +12,7 @@
1312

1413
from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs
1514
from autoparallel.api import AutoParallel
16-
from autoparallel.auto_bucketing import (
17-
aten_autobucketing_config,
18-
aten_autobucketing_reordering_pass,
19-
simple_fsdp_autobucketing_reordering_pass,
20-
simplefsdp_autobucketing_config,
21-
)
15+
from autoparallel.auto_bucketing import configure_inductor_for_autobucketing
2216

2317
world_size = 64
2418

@@ -89,35 +83,7 @@ def input_fn():
8983
return x
9084

9185

92-
autobucketing_level = "aten"
93-
94-
if autobucketing_level == "aten":
95-
# this is from the stacked pr in https://github.com/pytorch/pytorch/pull/163960
96-
torch._inductor.config.reorder_for_peak_memory = False
97-
torch._inductor.config.reorder_for_compute_comm_overlap = False
98-
aten_autobucketing_reordering_pass = partial(
99-
aten_autobucketing_reordering_pass,
100-
configs=aten_autobucketing_config,
101-
)
102-
torch._inductor.config.post_grad_custom_post_pass = (
103-
aten_autobucketing_reordering_pass
104-
)
105-
elif autobucketing_level == "inductor":
106-
torch._inductor.config.allow_buffer_reuse = False
107-
torch._inductor.config.reorder_for_peak_memory = False
108-
torch._inductor.config.reorder_for_compute_comm_overlap = True
109-
simplefsdp_autobucketing_config.calibrate_number = 5
110-
simplefsdp_autobucketing_config.save_estimation_path = "./estimation_mast.pkl"
111-
simple_fsdp_autobucketing_reordering_pass = partial(
112-
simple_fsdp_autobucketing_reordering_pass,
113-
configs=simplefsdp_autobucketing_config,
114-
)
115-
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
116-
simple_fsdp_autobucketing_reordering_pass
117-
]
118-
else:
119-
raise ValueError(f"Unknown autobucketing_level {autobucketing_level}")
120-
86+
configure_inductor_for_autobucketing("aten")
12187

12288
# parallelize the model
12389
with torch.device("meta"):

mast/sweep.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,13 @@ def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]:
106106
+ [
107107
"--model.name=auto_parallel.llama3",
108108
"--compile.enable",
109+
"--experimental.comms_bucket_reorder_strategy=none",
109110
],
110-
"llama3_autop_1d_compile_bucket_reorder": llama3_1d_common_opts
111+
"llama3_autop_1d_compile_aten_bucket_reorder": llama3_1d_common_opts
111112
+ [
112113
"--model.name=auto_parallel.llama3",
113114
"--compile.enable",
114-
"--experimental.bucket_all_gathers_fx=fsdp",
115-
"--experimental.bucket_reduce_scatters_fx=fsdp",
116-
"--experimental.reorder_for_compute_comm_overlap",
115+
"--experimental.comms_bucket_reorder_strategy=aten",
117116
],
118117
}
119118

@@ -127,41 +126,31 @@ def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]:
127126
+ [
128127
"--model.name=auto_parallel.llama3",
129128
"--compile.enable",
129+
"--experimental.comms_bucket_reorder_strategy=none",
130130
],
131-
"llama3_autop_2d_compile_bucket_reorder": llama3_2d_common_opts
131+
"llama3_autop_2d_compile_aten_bucket_reorder": llama3_2d_common_opts
132132
+ [
133133
"--model.name=auto_parallel.llama3",
134134
"--compile.enable",
135-
"--experimental.bucket_all_gathers_fx=fsdp",
136-
"--experimental.bucket_reduce_scatters_fx=fsdp",
137-
"--experimental.reorder_for_compute_comm_overlap",
135+
"--experimental.comms_bucket_reorder_strategy=aten",
138136
],
139137
}
140138

141-
test_run = {
142-
"FSDP_tp_compile": llama3_2d_common_opts
143-
+ [
144-
"--model.name=llama3",
145-
"--compile.enable",
146-
],
147-
}
148-
149-
150139
all_runs = (
151140
llama3_1d
152141
| llama3_2d
153142
| {
154-
"llama3_autop_1d_compile_ruisi_bucket_reorder": llama3_1d_common_opts
143+
"llama3_autop_1d_compile_inductor_bucket_reorder": llama3_1d_common_opts
155144
+ [
156145
"--model.name=auto_parallel.llama3",
157146
"--compile.enable",
158-
"--experimental.enable_simplefsdp_passes",
147+
"--experimental.comms_bucket_reorder_strategy=inductor",
159148
],
160-
"llama3_autop_2d_compile_ruisi_bucket_reorder": llama3_2d_common_opts
149+
"llama3_autop_2d_compile_inductor_bucket_reorder": llama3_2d_common_opts
161150
+ [
162151
"--model.name=auto_parallel.llama3",
163152
"--compile.enable",
164-
"--experimental.enable_simplefsdp_passes",
153+
"--experimental.comms_bucket_reorder_strategy=inductor",
165154
],
166155
}
167156
)
@@ -178,10 +167,26 @@ def build_sweep(names):
178167
[
179168
"llama3_FSDP_compile",
180169
"llama3_autop_1d_compile",
181-
"llama3_autop_1d_compile_ruisi_bucket_reorder",
170+
"llama3_autop_1d_compile_inductor_bucket_reorder",
171+
"llama3_FSDP_tp_compile",
172+
"llama3_autop_2d_compile",
173+
"llama3_autop_2d_compile_inductor_bucket_reorder",
174+
]
175+
),
176+
"compare_1d_bucketing": build_sweep(
177+
[
178+
"llama3_FSDP_compile",
179+
"llama3_autop_1d_compile",
180+
"llama3_autop_1d_compile_aten_bucket_reorder",
181+
"llama3_autop_1d_compile_inductor_bucket_reorder",
182+
]
183+
),
184+
"compare_2d_bucketing": build_sweep(
185+
[
182186
"llama3_FSDP_tp_compile",
183187
"llama3_autop_2d_compile",
184-
"llama3_autop_2d_compile_ruisi_bucket_reorder",
188+
"llama3_autop_2d_compile_aten_bucket_reorder",
189+
"llama3_autop_2d_compile_inductor_bucket_reorder",
185190
]
186191
),
187192
}

0 commit comments

Comments
 (0)