
Commit 1c60810

Enable reorder, bucketing, and async_tp in the example
stack-info: PR: #174, branch: IvanKobzarev/stack/4
1 parent: 2b9ef0a

File tree: 2 files changed (+102, -14 lines)

autoparallel/asynctp.py

Lines changed: 64 additions & 7 deletions
@@ -13,6 +13,7 @@
 
 import torch
 from torch._inductor import inductor_prims
+from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo
 from torch._inductor.pattern_matcher import (
     MULTIPLE,
     CallFunction,
@@ -23,6 +24,7 @@
     PatternExpr,
     PatternMatcherPass,
 )
+from torch._logging import trace_structured
 from torch.utils._ordered_set import OrderedSet
 
 import autoparallel.asynctp_ops  # noqa: F401
@@ -34,7 +36,7 @@
 _micro_pipeline_tp_ag_transpose_mm_enabled = True
 
 # Check performance if overhead of decomposition outweights pipeline wins
-_micro_pipeline_tp_ag_mm_last_dim_enabled = False
+_micro_pipeline_tp_ag_mm_last_dim_enabled = True
 
 _micro_pipeline_tp_mm_rs_last_dim_enabled = True
 
@@ -720,7 +722,7 @@ def _insert_fused_all_gather_transpose_matmul(
         raise AssertionError(f"Unexpected matmul match type: {mm_type}")
 
 
-def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
+def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
     """
     Fused the pattern
 
@@ -755,6 +757,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
         all_gather.group_name,
     )
 
+    log_strs.append(f"fuse_agmm {all_gather}")
     if not is_symm_mem_enabled_for_group(group_name):
         return
 
@@ -774,6 +777,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
         for matmul in matmuls
         if all_gather.res_node not in matmul.arg_ancestor_nodes
     ]
+    log_strs.append(f"fuse_agmm matmuls:{matmuls}")
 
     if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1:
         return
@@ -870,6 +874,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
     for node in nodes_to_raise:
         if order[node] > order[fused_node]:
             fused_node.prepend(node)
+    log_strs.append("fuse_agmm DONE")
 
 
 def _scatter_dim_after_reshape(
@@ -990,7 +995,7 @@ def _insert_fused_matmul_reduce_scatter(
         raise AssertionError(f"Unexpected matmul match type: {type(matmul)}")
 
 
-def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
+def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch, log_strs) -> None:
     """
     Fused the pattern
 
@@ -1004,6 +1009,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
 
     Returns boolean indicating if fusion was successful or not.
     """
+    log_strs.append(f"fuse_mmrs {reduce_scatter}")
     if (
         not torch.distributed.is_available()
         or not torch.distributed.is_nccl_available()
@@ -1032,6 +1038,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
     )
 
     if not is_symm_mem_enabled_for_group(group_name):
+        log_strs.append("fuse_mmrs not symm mem group")
         return
 
     if (
@@ -1048,16 +1055,19 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
         log.warning(
             "matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion."
         )
+        log_strs.append("fuse_mmrs input.node.users != 1")
         return
 
     matmul = _find_producer_matmul(input_node)
     if matmul is None:
+        log_strs.append("fuse_mmrs no matmul")
         log.warning(
             "no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion"
         )
         return
 
     if rs_wait_tensor_node in matmul.arg_ancestor_nodes:
+        log_strs.append("fuse_mmrs wait in matmul.arg_ancestors")
         log.warning(
             "reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion"
         )
@@ -1123,6 +1133,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
         if order[node] > order[fused_node]:
             fused_node.prepend(node)
 
+    log_strs.append("fuse_mmrs DONE")
     log.debug("successfully fused matmul reduce scatter")
 
 
@@ -1173,6 +1184,7 @@ def is_collective(node) -> bool:
     return collective_to_overlappable_nodes
 
 
+# TODO: Convert return type to set
 def _get_unexposed_collectives(graph: torch.fx.Graph) -> list[torch.fx.Node]:
     """
     Find all unexposed collectives in the graph.
@@ -1209,26 +1221,71 @@ def _is_compute_intensive(node: torch.fx.Node) -> bool:
     return unexposed_collectives
 
 
-def micro_pipeline_tp_pass(graph: torch.fx.Graph):
+def micro_pipeline_tp_pass(
+    graph: torch.fx.Graph,
+    collective_info: Optional[dict[torch.fx.Node, CollectiveInfo]] = None,
+):
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "asynctp_pre_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: graph.owning_module.print_readable(False),
+    )
     all_gathers = find_all_gather_patterns(graph)
     reduce_scatters = find_reduce_scatter_patterns(graph)
 
     # When a collective can be hidden through either simple overlapping or
     # micro-pipeline TP, we prefer simple overlapping to avoid the overhead
     # associated with decomposition.
-    unexposed_collectives = _get_unexposed_collectives(graph)
+    # Known problem: collective info contains only pre-bucketing collectives
+    if collective_info is None:
+        unexposed_collectives = _get_unexposed_collectives(graph)
+    else:
+        unexposed_collectives = [
+            n
+            for n, col_info in collective_info.items()
+            if col_info.hiding_node is not None
+        ]
+
+    log_strs = []
+    log_strs.append(f"\n all_gathers:{all_gathers}")
+    log_strs.append(f"\n reduce_scatters:{reduce_scatters}")
+
     all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives]
     reduce_scatters = [
         x for x in reduce_scatters if x.reduce_scatter_node not in unexposed_collectives
     ]
+    for n, coll_info in (collective_info or {}).items():  # collective_info may be None
+        log_strs.append(f"coll_info {n}: {coll_info}")
+    log_strs.append(f"\n unexposed_collectives:{unexposed_collectives}")
+    log_strs.append(f"\n all_gathers_exposed:{all_gathers}")
+    log_strs.append(f"\n reduce_scatters_exposed:{reduce_scatters}")
 
     if not all_gathers and not reduce_scatters:
         log.warning(
             "async TP found no matching all-gather/reduce-scatter patterns for fusion"
         )
 
     for reduce_scatter in reduce_scatters:
-        fuse_matmul_reduce_scatter(reduce_scatter)
+        fuse_matmul_reduce_scatter(reduce_scatter, log_strs)
 
     for all_gather in all_gathers:
-        fuse_all_gather_matmul(all_gather)
+        fuse_all_gather_matmul(all_gather, log_strs)
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "asynctp_log",
+            "encoding": "string",
+        },
+        payload_fn=lambda: "\n".join(log_strs),
+    )
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "asynctp_post_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: graph.owning_module.print_readable(False),
+    )
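
The diagnostics in this file follow Inductor's structured-trace pattern: the fusion helpers append human-readable notes to a shared log_strs list, and micro_pipeline_tp_pass emits that list once as an "asynctp_log" artifact, bracketed by "asynctp_pre_graph" and "asynctp_post_graph" dumps of the module. A minimal sketch of the same pattern in isolation (the pass name and artifact name below are illustrative, not part of this commit):

# Sketch only: mirrors the trace_structured usage added above.
# "demo_pass_log" is a made-up artifact name for illustration.
import torch
from torch._logging import trace_structured


def demo_pass(graph: torch.fx.Graph) -> None:
    log_strs = []
    log_strs.append(f"graph has {len(list(graph.nodes))} nodes")
    # ... rewrite the graph here, appending a note for each decision ...
    trace_structured(
        "artifact",
        metadata_fn=lambda: {"name": "demo_pass_log", "encoding": "string"},
        payload_fn=lambda: "\n".join(log_strs),
    )

When a compile is run with TORCH_TRACE set, artifacts emitted this way should show up in the tlparse report alongside the pre/post graph dumps.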

examples/example_llama3.py

Lines changed: 38 additions & 7 deletions
@@ -192,22 +192,53 @@ def add_tp_constraints(autop):
 if enable_manual_constraint and not use_1d_mesh:
     add_tp_constraints(autop)
 
-if enable_asynctp:
-    from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+enable_overlap_scheduling = True
+enable_overlap_scheduling_bucketing = True
+if enable_overlap_scheduling_bucketing:
+    assert (
+        enable_overlap_scheduling
+    ), "bucketing can not be used without overlap scheduling"
+enable_asynctp = True
+
+if (
+    enable_overlap_scheduling
+    or enable_overlap_scheduling_bucketing
+    or enable_asynctp
+):
+    torch._inductor.config.reorder_for_peak_memory = False
+    torch._inductor.config.reorder_for_compute_comm_overlap = False
+    torch._inductor.config.allow_buffer_reuse = False
+    torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = (
+        enable_overlap_scheduling_bucketing
+    )
+
+if enable_asynctp:
+    from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 
-    enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
-    enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
-    torch._inductor.config._micro_pipeline_tp = False
-    from autoparallel.asynctp import micro_pipeline_tp_pass
+    enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
+    enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
+    torch._inductor.config._micro_pipeline_tp = False
+    # Disable inductor AsyncTP passes, in favor of using Autoparallel passes fork.
+    # TODO: Switch to Inductor AsyncTP passes, when all additions landed.
+    from autoparallel.asynctp import micro_pipeline_tp_pass
 
     existing_post_grad_custom_post_pass = (
         torch._inductor.config.post_grad_custom_post_pass
     )
+    from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler
 
     def _pass(graph):
         if existing_post_grad_custom_post_pass is not None:
             existing_post_grad_custom_post_pass(graph)
-        micro_pipeline_tp_pass(graph)
+
+        collective_info = None
+        if enable_overlap_scheduling:
+            overlap_scheduler = OverlapScheduler(graph.owning_module)
+            overlap_scheduler.run()
+            collective_info = overlap_scheduler.collective_info
+
+        if enable_asynctp:
+            micro_pipeline_tp_pass(graph, collective_info)
 
     torch._inductor.config.post_grad_custom_post_pass = _pass
 
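
Condensed, the example's wiring is: disable Inductor's own reordering and micro-pipeline TP, enable symmetric memory on the mesh's process groups, and register a post-grad pass that runs OverlapScheduler and feeds its collective_info into the forked micro_pipeline_tp_pass. A sketch of just that registration, assuming all the flags above are enabled (the helper name register_asynctp and the bare mesh argument are illustrative; flag handling and chaining of a pre-existing pass are omitted):

# Sketch only: condensed from the example above, not a drop-in replacement.
# Assumes a PyTorch build that ships OverlapScheduler and an autoparallel
# checkout providing the forked micro_pipeline_tp_pass.
import torch
from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

from autoparallel.asynctp import micro_pipeline_tp_pass


def register_asynctp(mesh) -> None:
    enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
    enable_symm_mem_for_group(mesh["dp"].get_group().group_name)

    # Hand overlap/fusion decisions to the custom pass.
    torch._inductor.config.reorder_for_peak_memory = False
    torch._inductor.config.reorder_for_compute_comm_overlap = False
    torch._inductor.config.allow_buffer_reuse = False
    torch._inductor.config._micro_pipeline_tp = False

    def _pass(graph: torch.fx.Graph) -> None:
        scheduler = OverlapScheduler(graph.owning_module)
        scheduler.run()
        micro_pipeline_tp_pass(graph, scheduler.collective_info)

    torch._inductor.config.post_grad_custom_post_pass = _pass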
