forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_functional_collectives_impl.py
117 lines (95 loc) · 3.15 KB
/
_functional_collectives_impl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# mypy: allow-untyped-defs
from typing import List, Optional
import torch
import torch.distributed.distributed_c10d as c10d
"""
This file contains the op impls for the legacy (c10d_functional) functional collectives.
These impls simply call into the native (_c10d_functional) functional collectives.
"""
def _broadcast(input, src, tag, ranks, group_size):
    """Legacy ``broadcast`` impl that delegates to the native op.

    Args:
        input: Forwarded unchanged as the first argument of the native op.
        src: Forwarded unchanged as the second argument of the native op.
        tag: Used together with ``ranks`` to look up the group name.
        ranks: Used together with ``tag`` to look up the group name.
        group_size: Not referenced by this implementation.

    Returns:
        Whatever ``torch.ops._c10d_functional.broadcast`` returns.
    """
    # The legacy (ranks, tag) pair is translated into a group name, which is
    # what the native op takes as its last argument.
    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_broadcast = torch.ops._c10d_functional.broadcast
    return native_broadcast(input, src, resolved_group)
def _all_reduce(input, reduce_op, tag, ranks, group_size):
    """Legacy ``all_reduce`` impl that delegates to the native op.

    Args:
        input: Forwarded unchanged as the first argument of the native op.
        reduce_op: Forwarded unchanged as the second argument of the native op.
        tag: Used together with ``ranks`` to look up the group name.
        ranks: Used together with ``tag`` to look up the group name.
        group_size: Not referenced by this implementation.

    Returns:
        Whatever ``torch.ops._c10d_functional.all_reduce`` returns.
    """
    # Translate the legacy (ranks, tag) identifier into a group name first.
    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_all_reduce = torch.ops._c10d_functional.all_reduce
    return native_all_reduce(input, reduce_op, resolved_group)
def _all_reduce_coalesced(inputs, reduce_op, tag, ranks, group_size):
    """Legacy ``all_reduce_coalesced`` impl that delegates to the native op.

    Args:
        inputs: Forwarded unchanged as the first argument of the native op.
        reduce_op: Forwarded unchanged as the second argument of the native op.
        tag: Used together with ``ranks`` to look up the group name.
        ranks: Used together with ``tag`` to look up the group name.
        group_size: Not referenced by this implementation.

    Returns:
        Whatever ``torch.ops._c10d_functional.all_reduce_coalesced`` returns.
    """
    # Translate the legacy (ranks, tag) identifier into a group name first.
    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_op = torch.ops._c10d_functional.all_reduce_coalesced
    return native_op(inputs, reduce_op, resolved_group)
def _all_gather_into_tensor(input, tag, ranks, group_size):
    """Legacy ``all_gather_into_tensor`` impl that delegates to the native op.

    Args:
        input: Forwarded unchanged as the first argument of the native op.
        tag: Used together with ``ranks`` to look up the group name.
        ranks: Used together with ``tag`` to look up the group name.
        group_size: Forwarded unchanged as the second argument of the native op.

    Returns:
        Whatever ``torch.ops._c10d_functional.all_gather_into_tensor`` returns.
    """
    # Translate the legacy (ranks, tag) identifier into a group name first.
    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_op = torch.ops._c10d_functional.all_gather_into_tensor
    return native_op(input, group_size, resolved_group)
def _all_gather_into_tensor_coalesced(input, tag, ranks, group_size):
    """Legacy ``all_gather_into_tensor_coalesced`` impl that delegates to the native op.

    Args:
        input: Forwarded unchanged as the first argument of the native op.
        tag: Used together with ``ranks`` to look up the group name.
        ranks: Used together with ``tag`` to look up the group name.
        group_size: Forwarded unchanged as the second argument of the native op.

    Returns:
        Whatever ``torch.ops._c10d_functional.all_gather_into_tensor_coalesced``
        returns.
    """
    # Translate the legacy (ranks, tag) identifier into a group name first.
    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_op = torch.ops._c10d_functional.all_gather_into_tensor_coalesced
    return native_op(input, group_size, resolved_group)
def _reduce_scatter_tensor(
    input: torch.Tensor,
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    """Legacy ``reduce_scatter_tensor`` impl that delegates to the native op.

    Args:
        input: Forwarded unchanged as the first argument of the native op.
        reduce_op: Forwarded unchanged as the second argument of the native op.
        tag: Used together with ``ranks`` to look up the group name.
        ranks: Used together with ``tag`` to look up the group name.
        group_size: Forwarded unchanged as the third argument of the native op.

    Returns:
        Whatever ``torch.ops._c10d_functional.reduce_scatter_tensor`` returns.
    """
    # Translate the legacy (ranks, tag) identifier into a group name first.
    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_op = torch.ops._c10d_functional.reduce_scatter_tensor
    return native_op(input, reduce_op, group_size, resolved_group)
def _reduce_scatter_tensor_coalesced(
    inputs: List[torch.Tensor],
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    """Legacy ``reduce_scatter_tensor_coalesced`` impl that delegates to the native op.

    Args:
        inputs: Forwarded unchanged as the first argument of the native op.
        reduce_op: Forwarded unchanged as the second argument of the native op.
        tag: Used together with ``ranks`` to look up the group name.
        ranks: Used together with ``tag`` to look up the group name.
        group_size: Forwarded unchanged as the third argument of the native op.

    Returns:
        Whatever ``torch.ops._c10d_functional.reduce_scatter_tensor_coalesced``
        returns.
    """
    # Translate the legacy (ranks, tag) identifier into a group name first.
    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_op = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced
    return native_op(inputs, reduce_op, group_size, resolved_group)
def _all_to_all_single(
    input: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    tag: str,
    ranks: List[int],
    group_size: int,
):
    """Legacy ``all_to_all_single`` impl that delegates to the native op.

    When both split-size lists are ``None``, each is replaced by
    ``group_size`` copies of ``input.shape[0] // group_size`` before the
    native op is called.

    Args:
        input: Forwarded unchanged as the first argument of the native op; its
            leading dimension is read when the split sizes are defaulted.
        output_split_sizes: Split sizes forwarded as the second argument, or
            ``None`` to use the default described above.
        input_split_sizes: Split sizes forwarded as the third argument, or
            ``None`` to use the default described above.
        tag: Used together with ``ranks`` to look up the group name.
        ranks: Used together with ``tag`` to look up the group name.
        group_size: Number of chunks used when defaulting the split sizes.

    Returns:
        Whatever ``torch.ops._c10d_functional.all_to_all_single`` returns.

    Raises:
        AssertionError: If exactly one of ``output_split_sizes`` and
            ``input_split_sizes`` is ``None``.
    """
    output_missing = output_split_sizes is None
    input_missing = input_split_sizes is None
    if output_missing or input_missing:
        # Supplying only one of the two lists is rejected.
        assert output_missing and input_missing, (
            "output_split_sizes and input_split_sizes must either be "
            "specified together or both set to None"
        )
        # Floor division: any remainder of the leading dimension is not
        # represented in the defaulted split sizes.
        even_split = [input.shape[0] // group_size] * group_size
        output_split_sizes = even_split
        input_split_sizes = even_split

    # Translate the legacy (ranks, tag) identifier into a group name.
    resolved_group = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    native_op = torch.ops._c10d_functional.all_to_all_single
    return native_op(input, output_split_sizes, input_split_sizes, resolved_group)
def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
return torch.ops._c10d_functional.wait_tensor(tensor)