-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathall_gather.py
151 lines (130 loc) · 5.88 KB
/
all_gather.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# Copyright 2023, The Ohio State University. All rights reserved.
# The MVAPICH software package is developed by the team members of
# The Ohio State University's Network-Based Computing Laboratory (NBCL),
# headed by Professor Dhabaleswar K. (DK) Panda.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import sys, os, time
COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../")
sys.path.append(COMMS_BENCH_DIR)
from utils import *
from constants import *
from mcr_dl.cuda_accelerator import get_accelerator
from mcr_dl import TorchBackend
# Run all_gather and print metrics
def timed_all_gather(input, output, start_event, end_event, args):
    """Time repeated all-gathers of ``input`` into ``output`` and print one result row.

    Runs ``args.warmups`` untimed all-gathers, then records ``start_event``,
    issues ``args.trials`` all-gathers, records ``end_event`` and reports the
    average per-trial latency together with the throughput / bus bandwidth
    computed by ``get_bw``. The row is printed through ``print_rank_0``.

    Args:
        input: Tensor this rank contributes to the gather.
        output: Tensor that receives the gathered data.
        start_event: Timing event recorded before the timed trials.
        end_event: Timing event recorded after the timed trials.
        args: Parsed benchmark arguments. Reads ``dist``, ``warmups``,
            ``trials``, ``async_op`` and ``raw`` here and forwards ``args`` to
            ``get_bw`` / ``get_metric_strings``.

    Raises:
        ValueError: If ``args.dist`` is neither ``'torch'`` nor ``'mcr_dl'``.
    """
    # Select the all-gather implementation for the chosen backend. Failing fast
    # here avoids an UnboundLocalError on `all_gather_func` further down.
    if args.dist == 'torch':
        all_gather_func = TorchBackend.get_all_gather_function()
    elif args.dist == 'mcr_dl':
        import mcr_dl as dist
        all_gather_func = dist.allgather_fn
    else:
        raise ValueError(
            f"Unsupported --dist value {args.dist!r}; expected 'torch' or 'mcr_dl'")

    sync_all()
    # Warmups, establish connections, etc.
    for _ in range(args.warmups):
        all_gather_func(output, input, group=None, async_op=args.async_op)
    sync_all()

    # Time `trials` back-to-back ops and average them.
    start_event.record()
    for _ in range(args.trials):
        all_gather_func(output, input, group=None, async_op=args.async_op)
    end_event.record()
    sync_all()
    # Event.elapsed_time() is in milliseconds; convert to seconds.
    duration = start_event.elapsed_time(end_event) / 1000

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    tput, busbw = get_bw('all_gather', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not args.raw:
        size = convert_size(size)

    print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
def run_all_gather(local_rank, args):
    """Run the all-gather benchmark for this process.

    With ``args.scan`` set, sweeps per-row element counts ``2**1`` through
    ``2**(args.maxsize - 1)``. On the first size that raises an out-of-memory
    RuntimeError, rank 0 prints a warning and the sweep stops. Otherwise it
    benchmarks a single message whose per-rank element count comes from
    ``max_numel``. If that allocation runs out of memory, rank 0 prints a
    warning and the function returns.

    Args:
        local_rank: Local device index passed to
            ``get_accelerator().device_name``.
        args: Parsed benchmark arguments. Reads ``scan``, ``maxsize``,
            ``dtype``, ``dist`` and ``mem_factor`` directly and forwards
            ``args`` to ``print_header``, ``max_numel`` and
            ``timed_all_gather``.

    Raises:
        RuntimeError: Re-raised unchanged for allocation errors whose message
            does not contain 'out of memory'.
    """
    # The package is already loaded by `from mcr_dl import TorchBackend` at file
    # level; this only binds the name explicitly instead of relying on a star
    # import to re-export it.
    import mcr_dl

    dist = mcr_dl.get_distributed_engine()

    # Prepare benchmark header
    print_header(args, 'all_gather')
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    dtype = getattr(torch, args.dtype)
    device = get_accelerator().device_name(local_rank)

    if args.scan:
        # Per-row element counts: 2, 4, ..., 2**(maxsize - 1).
        message_sizes = [2**p for p in range(1, args.maxsize)]
        # Report once from rank 0; printing from every rank on every size
        # interleaved with the result table.
        print_rank_0(f"#######All gather world size : {world_size}")
        sync_all()
        # loop over various tensor sizes
        for M in message_sizes:
            try:
                mat = torch.ones(world_size, M, dtype=dtype).to(device)
                sync_all()
                # Fill with this rank's id so the gathered output is easy to inspect.
                input = mat.mul_(float(global_rank)).view(-1)
                # `input` is a view of `mat`, so this only drops the extra name;
                # the storage stays alive through `input`.
                del mat
                get_accelerator().empty_cache()
                output = torch.zeros(input.nelement() * world_size, dtype=dtype).to(device)
            except RuntimeError as e:
                if 'out of memory' not in str(e):
                    raise
                if global_rank == 0:
                    print('WARNING: Ran out of GPU memory. Exiting comm op.')
                sync_all()
                break
            sync_all()
            timed_all_gather(input, output, start_event, end_event, args)
    else:
        # all_gather_into_tensor saves memory, so allow a larger mem_factor when
        # the backend provides it.
        if args.dist == 'torch':
            # The original read `TorchBackend.has_all_gather_into_tensor` without
            # calling it; a method object is always truthy, so the bonus was
            # applied unconditionally. Probe torch.distributed directly instead.
            # NOTE(review): assumes that attribute is a method rather than a
            # bool flag — confirm against TorchBackend.
            import torch.distributed as torch_dist
            has_into_tensor = hasattr(torch_dist, 'all_gather_into_tensor')
        elif args.dist == 'mcr_dl':
            has_into_tensor = dist.has_all_gather_into_tensor()
        else:
            has_into_tensor = False
        mem_factor = args.mem_factor + 0.2 if has_into_tensor else args.mem_factor

        # Send the biggest message size our GPUs can fit. If you're facing OOM
        # errors, reduce the mem_factor
        sync_all()
        elements_per_gpu = max_numel(comm_op='all_gather',
                                     dtype=dtype,
                                     mem_factor=mem_factor,
                                     local_rank=local_rank,
                                     args=args)
        try:
            mat = torch.ones(elements_per_gpu, dtype=dtype).to(device)
            # multiply each GPU's tensor by the rank to ease debugging
            input = mat.mul_(float(global_rank)).view(-1)
            # `input` is a view of `mat`; this only drops the extra name.
            del mat
            get_accelerator().empty_cache()
            output = torch.zeros(elements_per_gpu * world_size, dtype=dtype).to(device)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            if global_rank == 0:
                print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
            sync_all()
            return
        sync_all()
        timed_all_gather(input, output, start_event, end_event, args)
if __name__ == "__main__":
    # `from mcr_dl import TorchBackend` above already loaded the package; bind the
    # name explicitly (at module scope, so run_all_gather can see it too)
    # instead of relying on a star import to re-export it.
    import mcr_dl

    args = benchmark_parser().parse_args()
    rank = args.local_rank
    mcr_dl.init_processes(args.dist, args.backend)
    run_all_gather(local_rank=rank, args=args)