@@ -13,6 +13,7 @@
 
 import torch
 from torch._inductor import inductor_prims
+from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo
 from torch._inductor.pattern_matcher import (
     MULTIPLE,
     CallFunction,
@@ -23,6 +24,7 @@
     PatternExpr,
     PatternMatcherPass,
 )
+from torch._logging import trace_structured
 from torch.utils._ordered_set import OrderedSet
 
 import autoparallel.asynctp_ops  # noqa: F401
@@ -34,7 +36,7 @@
 _micro_pipeline_tp_ag_transpose_mm_enabled = True
 
 # Check performance if overhead of decomposition outweights pipeline wins
-_micro_pipeline_tp_ag_mm_last_dim_enabled = False
+_micro_pipeline_tp_ag_mm_last_dim_enabled = True
 
 _micro_pipeline_tp_mm_rs_last_dim_enabled = True
 
@@ -720,7 +722,7 @@ def _insert_fused_all_gather_transpose_matmul(
         raise AssertionError(f"Unexpected matmul match type: {mm_type}")
 
 
-def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
+def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
     """
     Fused the pattern
 
@@ -755,6 +757,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
         all_gather.group_name,
     )
 
+    log_strs.append(f"fuse_agmm {all_gather}")
     if not is_symm_mem_enabled_for_group(group_name):
         return
 
@@ -774,6 +777,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
         for matmul in matmuls
         if all_gather.res_node not in matmul.arg_ancestor_nodes
     ]
+    log_strs.append(f"fuse_agmm matmuls:{matmuls}")
 
     if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1:
         return
@@ -870,6 +874,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
     for node in nodes_to_raise:
         if order[node] > order[fused_node]:
             fused_node.prepend(node)
+    log_strs.append("fuse_agmm DONE")
 
 
 def _scatter_dim_after_reshape(
@@ -990,7 +995,7 @@ def _insert_fused_matmul_reduce_scatter(
         raise AssertionError(f"Unexpected matmul match type: {type(matmul)}")
 
 
-def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
+def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch, log_strs) -> None:
     """
     Fused the pattern
 
@@ -1004,6 +1009,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
 
     Returns boolean indicating if fusion was successful or not.
     """
+    log_strs.append(f"fuse_mmrs {reduce_scatter}")
     if (
         not torch.distributed.is_available()
         or not torch.distributed.is_nccl_available()
@@ -1032,6 +1038,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
     )
 
     if not is_symm_mem_enabled_for_group(group_name):
+        log_strs.append("fuse_mmrs not symm mem group")
         return
 
     if (
@@ -1048,16 +1055,19 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
         log.warning(
             "matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion."
         )
+        log_strs.append("fuse_mmrs input.node.users != 1")
         return
 
     matmul = _find_producer_matmul(input_node)
     if matmul is None:
+        log_strs.append("fuse_mmrs no matmul")
         log.warning(
             "no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion"
         )
         return
 
     if rs_wait_tensor_node in matmul.arg_ancestor_nodes:
+        log_strs.append("fuse_mmrs wait in matmul.arg_ancestors")
         log.warning(
             "reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion"
         )
@@ -1123,6 +1133,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
         if order[node] > order[fused_node]:
             fused_node.prepend(node)
 
+    log_strs.append("fuse_mmrs DONE")
     log.debug("successfully fused matmul reduce scatter")
 
 
@@ -1173,6 +1184,7 @@ def is_collective(node) -> bool:
     return collective_to_overlappable_nodes
 
 
+# TODO: Convert return type to set
 def _get_unexposed_collectives(graph: torch.fx.Graph) -> list[torch.fx.Node]:
     """
     Find all unexposed collectives in the graph.
@@ -1209,26 +1221,71 @@ def _is_compute_intensive(node: torch.fx.Node) -> bool:
     return unexposed_collectives
 
 
-def micro_pipeline_tp_pass(graph: torch.fx.Graph):
+def micro_pipeline_tp_pass(
+    graph: torch.fx.Graph,
+    collective_info: Optional[dict[torch.fx.Node, CollectiveInfo]] = None,
+):
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "asynctp_pre_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: graph.owning_module.print_readable(False),
+    )
     all_gathers = find_all_gather_patterns(graph)
     reduce_scatters = find_reduce_scatter_patterns(graph)
 
     # When a collective can be hidden through either simple overlapping or
     # micro-pipeline TP, we prefer simple overlapping to avoid the overhead
     # associated with decomposition.
-    unexposed_collectives = _get_unexposed_collectives(graph)
+    # Known problem: collective info contains only pre-bucketing collectives
+    if collective_info is None:
+        unexposed_collectives = _get_unexposed_collectives(graph)
+    else:
+        unexposed_collectives = [
+            n
+            for n, col_info in collective_info.items()
+            if col_info.hiding_node is not None
+        ]
+
+    log_strs = []
+    log_strs.append(f"\nall_gathers:{all_gathers}")
+    log_strs.append(f"\nreduce_scatters:{reduce_scatters}")
+
     all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives]
     reduce_scatters = [
         x for x in reduce_scatters if x.reduce_scatter_node not in unexposed_collectives
     ]
+    for n, coll_info in (collective_info or {}).items():
+        log_strs.append(f"coll_info {n}: {coll_info}")
+    log_strs.append(f"\nunexposed_collectives:{unexposed_collectives}")
+    log_strs.append(f"\nall_gathers_exposed:{all_gathers}")
+    log_strs.append(f"\nreduce_scatters_exposed:{reduce_scatters}")
 
     if not all_gathers and not reduce_scatters:
         log.warning(
             "async TP found no matching all-gather/reduce-scatter patterns for fusion"
         )
 
     for reduce_scatter in reduce_scatters:
-        fuse_matmul_reduce_scatter(reduce_scatter)
+        fuse_matmul_reduce_scatter(reduce_scatter, log_strs)
 
     for all_gather in all_gathers:
-        fuse_all_gather_matmul(all_gather)
+        fuse_all_gather_matmul(all_gather, log_strs)
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "asynctp_log",
+            "encoding": "string",
+        },
+        payload_fn=lambda: "\n".join(log_strs),
+    )
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "asynctp_post_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: graph.owning_module.print_readable(False),
+    )
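
For reference, a minimal sketch of how the updated entry point might be driven once this change lands. Everything not visible in the diff is an assumption: the autoparallel.asynctp module path, the run_asynctp wrapper name, and the idea that an overlap scheduler has already produced the CollectiveInfo mapping.

# Hypothetical driver (sketch, not part of this diff). Assumes the pass lives in
# autoparallel/asynctp.py, `gm` is a post-grad torch.fx.GraphModule, and
# `overlap_info` is a dict[torch.fx.Node, CollectiveInfo] produced elsewhere by an
# overlap scheduler; passing None falls back to the _get_unexposed_collectives
# analysis, as in the code above.
import torch

from autoparallel.asynctp import micro_pipeline_tp_pass


def run_asynctp(gm: torch.fx.GraphModule, overlap_info=None) -> None:
    # Collectives whose CollectiveInfo.hiding_node is set are treated as already
    # hidden by simple overlap and skipped; the rest stay async-TP fusion candidates.
    micro_pipeline_tp_pass(gm.graph, collective_info=overlap_info)
    gm.recompile()

The three trace_structured artifacts emitted above (asynctp_pre_graph, asynctp_log, asynctp_post_graph) go into the structured trace, so they should be viewable with the usual tooling, e.g. a run with TORCH_TRACE set to an output directory followed by tlparse.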