Commit 77f9701

wconstab authored and pytorchmergebot committed
Dynamo remaps legacy allgather to traceable one (pytorch#102232)
Pull Request resolved: pytorch#102232
Approved by: https://github.com/voznesenskym
1 parent c58264c commit 77f9701
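
With this remap, a `torch.compile`d function that calls the legacy in-place `all_gather_into_tensor` can be captured in one graph built from the functional collective, a wait, and a `copy_` back into the output buffer. A minimal sketch of how one might inspect that capture, not part of this commit, using a hypothetical graph-printing backend and assuming a process group is already initialized:

```python
import torch
import torch.distributed as dist

def legacy_allgather_fn(inp, out, *, pg):
    # Legacy in-place collective that Dynamo now rewrites to the traceable form
    dist.all_gather_into_tensor(out, inp, pg)

def print_graph_backend(gm, example_inputs):
    # Hypothetical torch.compile backend: print the captured FX graph, then run it eagerly
    gm.graph.print_tabular()
    return gm.forward

compiled = torch.compile(legacy_allgather_fn, backend=print_graph_backend, fullgraph=True)
# Assumes dist.init_process_group(...) was already called:
# compiled(inp, out, pg=dist.group.WORLD)
```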

File tree

4 files changed, +158 -1 lines changed


test/distributed/test_inductor_collectives.py

Lines changed: 50 additions & 0 deletions
@@ -375,6 +375,56 @@ def func(inp, *, pg):
         self.assertEqual(counter.op_count, 2)
         self.assertTrue(same(out, correct))
 
+    def test_dynamo_rewrite_dist_all_gather(self):
+
+        def func(inp, out, *, pg):
+            torch.distributed.all_gather_into_tensor(
+                out,
+                inp,
+                pg,
+            )
+        local_size = [4, 4]
+        # single-proc test
+        global_size = local_size
+
+        inputs = torch.ones(local_size, device=self.device)
+        outputs = torch.empty(global_size, device=self.device)
+        correct_outputs = torch.empty(global_size, device=self.device)
+        counter = CompileCounter()
+        compiled = torch.compile(func, backend=counter, fullgraph=True)
+        compiled(inputs, outputs, pg=GroupMember.WORLD)
+        func(inputs, correct_outputs, pg=GroupMember.WORLD)
+        assert counter.frame_count == 1
+
+        # should test more precisely, but the 3 is supposed to be (all_gather, wait, copy_)
+        assert counter.op_count == 3
+        assert same(outputs, correct_outputs)
+
+    def test_dynamo_rewrite_dist_reduce_scatter(self):
+
+        def func(inp, out, *, pg):
+            torch.distributed.reduce_scatter_tensor(
+                out,
+                inp,
+                group=pg,
+            )
+        local_size = [4, 4]
+        # single-proc test
+        global_size = local_size
+
+        inputs = torch.ones(local_size, device=self.device)
+        outputs = torch.empty(global_size, device=self.device)
+        correct_outputs = torch.empty(global_size, device=self.device)
+        counter = CompileCounter()
+        compiled = torch.compile(func, backend=counter, fullgraph=True)
+        compiled(inputs, outputs, pg=GroupMember.WORLD)
+        func(inputs, correct_outputs, pg=GroupMember.WORLD)
+        assert counter.frame_count == 1
+
+        # should test more precisely, but the 3 is supposed to be (reduce_scatter, wait, copy_)
+        assert counter.op_count == 3
+        assert same(outputs, correct_outputs)
+
     def test_dynamo_graphbreaks_unsupported_async_op(self):
 
         def func(inp, out, *, pg):

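Both new tests are single-process (`global_size == local_size`), so reproducing them interactively only needs a one-rank default group. A rough setup sketch, not from this PR; the backend and rendezvous address below are assumptions:

```python
import torch.distributed as dist

if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo",                        # assumption: CPU-friendly backend for a 1-rank run
        init_method="tcp://127.0.0.1:29500",   # hypothetical rendezvous address
        rank=0,
        world_size=1,
    )
# GroupMember.WORLD, used by the tests above, now refers to this one-rank default group.
```
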
torch/_dynamo/variables/builder.py

Lines changed: 12 additions & 1 deletion
@@ -72,7 +72,11 @@
     DefaultDictVariable,
     HFPretrainedConfigVariable,
 )
-from .functions import UserFunctionVariable, UserMethodVariable
+from .functions import (
+    CollectiveFunctionRewriteVariable,
+    UserFunctionVariable,
+    UserMethodVariable,
+)
 from .lists import (
     ListVariable,
     NamedTupleVariable,
@@ -447,6 +451,13 @@ def index_source(key):
                 guards=make_guards(GuardBuilder.FUNCTION_MATCH),
             )
         # NB: These can't be put in type_dispatch, they have to run later
+        elif CollectiveFunctionRewriteVariable.can_rewrite(value):
+            return CollectiveFunctionRewriteVariable(
+                CollectiveFunctionRewriteVariable.rewrite(value),
+                orig_fn=value,
+                source=self.source,
+                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
+            )
         elif istype(value, (types.FunctionType, torch.jit.ScriptFunction)):
             return UserFunctionVariable(
                 value,

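The new branch fires when `VariableBuilder` wraps a plain function that appears in the remap table. A hedged way to see which functions qualify at this commit, assuming a build with distributed available:

```python
import torch.distributed as dist
from torch._dynamo.variables.functions import CollectiveFunctionRewriteVariable

# Only the two legacy collectives added to traceable_collective_remaps qualify at this commit.
print(CollectiveFunctionRewriteVariable.can_rewrite(dist.all_gather_into_tensor))  # expected: True
print(CollectiveFunctionRewriteVariable.can_rewrite(dist.reduce_scatter_tensor))   # expected: True
print(CollectiveFunctionRewriteVariable.can_rewrite(dist.all_reduce))              # expected: False
```
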
torch/_dynamo/variables/functions.py

Lines changed: 52 additions & 0 deletions
@@ -534,3 +534,55 @@ def reconstruct(self, codegen):
         codegen.extend_output(create_call_function(1, True))
 
         return []
+
+
+def _traceable_collective_remaps():
+    # We can't rely on importing from distributed, since its not always built
+    if torch.distributed.is_available():
+        from torch.distributed._functional_collectives import (
+            traceable_collective_remaps,
+        )
+
+        return traceable_collective_remaps
+    return {}
+
+
+class CollectiveFunctionRewriteVariable(UserFunctionVariable):
+    """
+    Some of the torch.distributed.* collective APIs are possible to rewrite to 'traceable' collectives.
+
+    This class provides both a way to check if a function is remappable, and perform the remapping.
+
+    In the case that a function is 'remappable' but only for some combinations of call-time arguments,
+    we check the args at `call_function` time and fall back to graph-breaking if needed. This is no worse
+    than status-quo as we currently graph-break on all distributed.* collectives.
+    """
+
+    def __init__(self, fn, *, orig_fn, **kwargs):
+        # orig_fn lets us implement any fn-specific args/kwargs restrictions inside call_function
+        self.orig_fn = orig_fn
+
+        # remapped_fn gets stuffed in self.fn and used in super().call_function
+        super().__init__(fn, **kwargs)
+
+    @staticmethod
+    def can_rewrite(variable):
+        return (
+            inspect.isfunction(variable) and variable in _traceable_collective_remaps()
+        )
+
+    @staticmethod
+    def rewrite(fn):
+        return _traceable_collective_remaps()[fn]
+
+    def call_function(
+        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+    ) -> "VariableTracker":
+        # call_function must check any unsupported arguments and graph-break.
+        # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn,
+        # since that's the contract for putting a mapping in `traceable_collective_remaps`
+        if kwargs.get("async_op", False):
+            unimplemented(
+                f"CollectiveFunctionRewriteVariable can't support async_op=True for {self.orig_fn}"
+            )
+        return super().call_function(tx, args, kwargs)

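`call_function` only refuses `async_op=True`; everything else is forwarded to the remapped function. Under `fullgraph=True` that refusal surfaces as an error rather than a silent graph break. An illustrative sketch, not part of the commit:

```python
import torch
import torch.distributed as dist

def async_allgather(inp, out, *, pg):
    work = dist.all_gather_into_tensor(out, inp, pg, async_op=True)
    work.wait()

compiled = torch.compile(async_allgather, fullgraph=True)
# Expected (assumption): calling compiled(...) raises torch._dynamo.exc.Unsupported,
# because call_function hits unimplemented() instead of remapping async_op=True.
```
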
torch/distributed/_functional_collectives.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,3 +517,47 @@ def _register_ops():
517517
_register_ops()
518518
else:
519519
warnings.warn("PyTorch Distributed functional collectives do not work with torch::deploy.")
520+
521+
# We allow torchdynamo to convert calls from legacy inplace APIs into traceable APIs
522+
# via a pseudo-inplace version (like a decomp) that uses the functional collective
523+
# and a copy.
524+
#
525+
# These schemas intentionally match torch.distributed.distributed_c10d.* ops that we are trying to remap from
526+
def all_gather_tensor_inplace(
527+
output: torch.Tensor,
528+
input: torch.Tensor,
529+
group, # TODO add a type,
530+
async_op: bool = False,
531+
tag: str = "",
532+
gather_dim: int = 0
533+
):
534+
assert not async_op, "Can't remap async version of inplace op to functional collective"
535+
return output.copy_(all_gather_tensor(input, gather_dim, group, tag))
536+
537+
def reduce_scatter_tensor_inplace(
538+
output: torch.Tensor,
539+
input: torch.Tensor,
540+
op: str = "sum", # TODO type is actually c10d ReduceOp. is this ok?
541+
group=None, # TODO add a type
542+
async_op: bool = False,
543+
scatter_dim: int = 0,
544+
tag: str = "",
545+
):
546+
assert not async_op, "Can't remap async version of inplace op to functional collective"
547+
return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag))
548+
549+
from torch.distributed.distributed_c10d import (
550+
all_gather_into_tensor as legacy_allgather,
551+
reduce_scatter_tensor as legacy_reducescatter,
552+
)
553+
554+
"""
555+
This dict should contain sets of functions that dynamo is allowed to remap.
556+
557+
Functions in this set should accept the same args/kwargs 1:1 as their mapping.
558+
"""
559+
560+
traceable_collective_remaps = {
561+
legacy_allgather: all_gather_tensor_inplace,
562+
legacy_reducescatter: reduce_scatter_tensor_inplace,
563+
}

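Because the pseudo-inplace wrappers just run the functional collective and `copy_` the result into the caller's buffer, they should match the legacy eager ops exactly. A small eager-mode sanity check sketch, assuming a one-rank group is already initialized (not part of the commit):

```python
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_gather_tensor_inplace

inp = torch.arange(4.0)
out_legacy = torch.empty(4)    # world_size == 1, so the gathered size equals the input size
out_rewrite = torch.empty(4)

dist.all_gather_into_tensor(out_legacy, inp, dist.group.WORLD)
all_gather_tensor_inplace(out_rewrite, inp, dist.group.WORLD)
assert torch.equal(out_legacy, out_rewrite)
```
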