71 changes: 64 additions & 7 deletions autoparallel/asynctp.py
@@ -13,6 +13,7 @@

import torch
from torch._inductor import inductor_prims
from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo
from torch._inductor.pattern_matcher import (
MULTIPLE,
CallFunction,
@@ -23,6 +24,7 @@
PatternExpr,
PatternMatcherPass,
)
from torch._logging import trace_structured
from torch.utils._ordered_set import OrderedSet

import autoparallel.asynctp_ops # noqa: F401
@@ -34,7 +36,7 @@
_micro_pipeline_tp_ag_transpose_mm_enabled = True

# Check performance if overhead of decomposition outweighs pipeline wins
_micro_pipeline_tp_ag_mm_last_dim_enabled = False
_micro_pipeline_tp_ag_mm_last_dim_enabled = True

_micro_pipeline_tp_mm_rs_last_dim_enabled = True

@@ -720,7 +722,7 @@ def _insert_fused_all_gather_transpose_matmul(
raise AssertionError(f"Unexpected matmul match type: {mm_type}")


def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
"""
Fuses the pattern

@@ -755,6 +757,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
all_gather.group_name,
)

log_strs.append(f"fuse_agmm {all_gather}")
if not is_symm_mem_enabled_for_group(group_name):
return

@@ -774,6 +777,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
for matmul in matmuls
if all_gather.res_node not in matmul.arg_ancestor_nodes
]
log_strs.append(f"fuse_agmm matmuls:{matmuls}")

if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1:
return
@@ -870,6 +874,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
for node in nodes_to_raise:
if order[node] > order[fused_node]:
fused_node.prepend(node)
log_strs.append("fuse_agmm DONE")


def _scatter_dim_after_reshape(
@@ -990,7 +995,7 @@ def _insert_fused_matmul_reduce_scatter(
raise AssertionError(f"Unexpected matmul match type: {type(matmul)}")


def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch, log_strs) -> None:
"""
Fuses the pattern

@@ -1004,6 +1009,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:

Returns boolean indicating if fusion was successful or not.
"""
log_strs.append(f"fuse_mmrs {reduce_scatter}")
if (
not torch.distributed.is_available()
or not torch.distributed.is_nccl_available()
@@ -1032,6 +1038,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
)

if not is_symm_mem_enabled_for_group(group_name):
log_strs.append("fuse_mmrs not symm mem group")
return

if (
@@ -1048,16 +1055,19 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
log.warning(
"matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion."
)
log_strs.append("fuse_mmrs input.node.users != 1")
return

matmul = _find_producer_matmul(input_node)
if matmul is None:
log_strs.append("fuse_mmrs no matmul")
log.warning(
"no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion"
)
return

if rs_wait_tensor_node in matmul.arg_ancestor_nodes:
log_strs.append("fuse_mmrs wait in matmul.arg_ancestors")
log.warning(
"reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion"
)
@@ -1123,6 +1133,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
if order[node] > order[fused_node]:
fused_node.prepend(node)

log_strs.append("fuse_mmrs DONE")
log.debug("successfully fused matmul reduce scatter")


@@ -1173,6 +1184,7 @@ def is_collective(node) -> bool:
return collective_to_overlappable_nodes


# TODO: Convert return type to set
def _get_unexposed_collectives(graph: torch.fx.Graph) -> list[torch.fx.Node]:
"""
Find all unexposed collectives in the graph.
@@ -1209,26 +1221,71 @@ def _is_compute_intensive(node: torch.fx.Node) -> bool:
return unexposed_collectives


def micro_pipeline_tp_pass(graph: torch.fx.Graph):
def micro_pipeline_tp_pass(
graph: torch.fx.Graph,
collective_info: Optional[dict[torch.fx.Node, CollectiveInfo]] = None,
):
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "asynctp_pre_graph",
"encoding": "string",
},
payload_fn=lambda: graph.owning_module.print_readable(False),
)
all_gathers = find_all_gather_patterns(graph)
reduce_scatters = find_reduce_scatter_patterns(graph)

# When a collective can be hidden through either simple overlapping or
# micro-pipeline TP, we prefer simple overlapping to avoid the overhead
# associated with decomposition.
unexposed_collectives = _get_unexposed_collectives(graph)
# Known problem: collective_info contains only pre-bucketing collectives
if collective_info is None:
unexposed_collectives = _get_unexposed_collectives(graph)
else:
unexposed_collectives = [
n
for n, col_info in collective_info.items()
if col_info.hiding_node is not None
]

log_strs = []
log_strs.append(f"\n all_gathers:{all_gathers}")
log_strs.append(f"\n reduce_scatters:{reduce_scatters}")

all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives]
reduce_scatters = [
x for x in reduce_scatters if x.reduce_scatter_node not in unexposed_collectives
]
if collective_info is not None:
    for n, coll_info in collective_info.items():
        log_strs.append(f"coll_info {n}: {coll_info}")
log_strs.append(f"\n unexposed_collectives:{unexposed_collectives}")
log_strs.append(f"\n all_gathers_exposed:{all_gathers}")
log_strs.append(f"\n reduce_scatters_exposed:{reduce_scatters}")

if not all_gathers and not reduce_scatters:
log.warning(
"async TP found no matching all-gather/reduce-scatter patterns for fusion"
)

for reduce_scatter in reduce_scatters:
fuse_matmul_reduce_scatter(reduce_scatter)
fuse_matmul_reduce_scatter(reduce_scatter, log_strs)

for all_gather in all_gathers:
fuse_all_gather_matmul(all_gather)
fuse_all_gather_matmul(all_gather, log_strs)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "asynctp_log",
"encoding": "string",
},
payload_fn=lambda: "\n".join(log_strs),
)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "asynctp_post_graph",
"encoding": "string",
},
payload_fn=lambda: graph.owning_module.print_readable(False),
)
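The optional collective_info argument changes how exposed collectives are computed: when it is provided, any collective whose CollectiveInfo.hiding_node is set is treated as already hidden by overlap scheduling and excluded from micro-pipeline fusion. A minimal sketch of driving the pass this way (not part of this diff; run_asynctp is a hypothetical helper, and the OverlapScheduler usage mirrors examples/example_llama3.py below):

import torch
from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler
from autoparallel.asynctp import micro_pipeline_tp_pass

def run_asynctp(graph: torch.fx.Graph) -> None:
    # Hypothetical helper: run overlap scheduling first so its CollectiveInfo
    # results can tell the async-TP pass which collectives are already hidden
    # (hiding_node set) and should not be decomposed.
    scheduler = OverlapScheduler(graph.owning_module)
    scheduler.run()
    # The pass still no-ops for process groups without symmetric memory enabled.
    micro_pipeline_tp_pass(graph, scheduler.collective_info)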
45 changes: 38 additions & 7 deletions examples/example_llama3.py
@@ -192,22 +192,53 @@ def add_tp_constraints(autop):
if enable_manual_constraint and not use_1d_mesh:
add_tp_constraints(autop)

if enable_asynctp:
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
enable_overlap_scheduling = True
enable_overlap_scheduling_bucketing = True
if enable_overlap_scheduling_bucketing:
assert (
enable_overlap_scheduling
), "bucketing cannot be used without overlap scheduling"
enable_asynctp = True

if (
enable_overlap_scheduling
or enable_overlap_scheduling_bucketing
or enable_asynctp
):
torch._inductor.config.reorder_for_peak_memory = False
torch._inductor.config.reorder_for_compute_comm_overlap = False
torch._inductor.config.allow_buffer_reuse = False
torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = (
enable_overlap_scheduling_bucketing
)

if enable_asynctp:
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
torch._inductor.config._micro_pipeline_tp = False
from autoparallel.asynctp import micro_pipeline_tp_pass
enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
torch._inductor.config._micro_pipeline_tp = False
# Disable Inductor's AsyncTP passes in favor of the Autoparallel fork of these passes.
# TODO: Switch to the Inductor AsyncTP passes once all additions have landed.
from autoparallel.asynctp import micro_pipeline_tp_pass

existing_post_grad_custom_post_pass = (
torch._inductor.config.post_grad_custom_post_pass
)
from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler

def _pass(graph):
if existing_post_grad_custom_post_pass is not None:
existing_post_grad_custom_post_pass(graph)
micro_pipeline_tp_pass(graph)

collective_info = None
if enable_overlap_scheduling:
overlap_scheduler = OverlapScheduler(graph.owning_module)
overlap_scheduler.run()
collective_info = overlap_scheduler.collective_info

if enable_asynctp:
micro_pipeline_tp_pass(graph, collective_info)

torch._inductor.config.post_grad_custom_post_pass = _pass
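One assumed way (not part of this change) to inspect the asynctp_pre_graph, asynctp_log and asynctp_post_graph artifacts that micro_pipeline_tp_pass now emits via trace_structured is to enable PyTorch's structured trace logging and read the dump afterwards, for example with tlparse:

import os

# TORCH_TRACE must point at a directory before torch is imported so the
# structured trace handler picks it up; the path here is a hypothetical example.
os.environ.setdefault("TORCH_TRACE", "/tmp/asynctp_trace")

import torch  # noqa: E402

# ... build, compile and run the model as in this example; the post-grad pass
# then writes its asynctp_* payloads into per-rank trace files under
# /tmp/asynctp_trace, which can be rendered with `tlparse /tmp/asynctp_trace`
# (assumed tooling).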
