Commit cee0b27

Simplify access to MCR-DL
Signed-off-by: Radha Guhane <[email protected]>
1 parent 870751f commit cee0b27

9 files changed: +77 -108 lines changed

benchmarks/all_gather.py (+3 -5)

@@ -67,10 +67,7 @@ def timed_all_gather(input, output, start_event, end_event, args):
 
 
 def run_all_gather(local_rank, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()
 
     # Prepare benchmark header
     print_header(args, 'all_gather')
@@ -98,6 +95,7 @@ def run_all_gather(local_rank, args):
             # Delete original mat to avoid OOM
             del mat
             get_accelerator().empty_cache()
+            print(f"#######All gather world size : {world_size}")
            output = torch.zeros(input.nelement() * world_size,
                                 dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
         except RuntimeError as e:
@@ -149,5 +147,5 @@ def run_all_gather(local_rank, args):
 if __name__ == "__main__":
     args = benchmark_parser().parse_args()
     rank = args.local_rank
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
     run_all_gather(local_rank=rank, args=args)
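
Across all five benchmark scripts, the per-file "if args.dist == 'torch' ... elif args.dist == 'mcr_dl' ..." import blocks collapse into the two new entry points used above. A minimal caller sketch of the resulting pattern (not part of the commit; it assumes mcr_dl is installed, a launcher such as mpirun or torchrun has set LOCAL_RANK, and the flag names mirror the benchmark parser):

# Hypothetical minimal caller; argument names follow the benchmarks' --dist/--backend flags.
import argparse
import torch
import mcr_dl

parser = argparse.ArgumentParser()
parser.add_argument("--dist", default="mcr_dl", choices=["mcr_dl", "torch"])
parser.add_argument("--backend", default="mpi", choices=["mpi", "nccl"])
args = parser.parse_args()

# One call replaces the old per-file init_processes(local_rank=..., args=...) helper.
mcr_dl.init_processes(args.dist, args.backend)

# Either engine exposes the torch.distributed-style collective API.
dist = mcr_dl.get_distributed_engine()
x = torch.ones(4) * (dist.get_rank() + 1)
dist.all_reduce(x)  # sums the tensor across all ranks in place
if dist.get_rank() == 0:
    print(f"world size {dist.get_world_size()}, reduced tensor {x.tolist()}")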

benchmarks/all_reduce.py (+10 -9)

@@ -28,10 +28,8 @@
 
 
 def timed_all_reduce(input, start_event, end_event, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    import mcr_dl
+    dist = mcr_dl.get_distributed_engine()
 
     sync_all()
     # Warmups, establish connections, etc.
@@ -62,10 +60,12 @@ def timed_all_reduce(input, start_event, end_event, args):
 
 
 def run_all_reduce(local_rank, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    # if args.dist == 'torch':
+    #     import torch.distributed as dist
+    # elif args.dist == 'mcr_dl':
+    #     import mcr_dl as dist
+    import mcr_dl
+    dist = mcr_dl.get_distributed_engine()
 
     # Prepare benchmark header
     print_header(args, 'all_reduce')
@@ -125,7 +125,8 @@ def run_all_reduce(local_rank, args):
 
 
 if __name__ == "__main__":
+    import mcr_dl
     args = benchmark_parser().parse_args()
     rank = args.local_rank
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
     run_all_reduce(local_rank=rank, args=args)

benchmarks/all_to_all.py (+3 -9)

@@ -28,10 +28,7 @@
 
 
 def timed_all_to_all(input, output, start_event, end_event, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()
 
     sync_all()
     # Warmups, establish connections, etc.
@@ -62,10 +59,7 @@ def timed_all_to_all(input, output, start_event, end_event, args):
 
 
 def run_all_to_all(local_rank, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()
 
     world_size = dist.get_world_size()
     global_rank = dist.get_rank()
@@ -147,5 +141,5 @@ def run_all_to_all(local_rank, args):
 if __name__ == "__main__":
     args = benchmark_parser().parse_args()
     rank = args.local_rank
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
     run_all_to_all(local_rank=rank, args=args)

benchmarks/broadcast.py (+3 -9)

@@ -28,10 +28,7 @@
 
 
 def timed_broadcast(input, start_event, end_event, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()
 
     sync_all()
     # Warmups, establish connections, etc.
@@ -62,10 +59,7 @@ def timed_broadcast(input, start_event, end_event, args):
 
 
 def run_broadcast(local_rank, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()
 
     # Prepare benchmark header
     print_header(args, 'broadcast')
@@ -125,5 +119,5 @@ def run_broadcast(local_rank, args):
 if __name__ == "__main__":
     args = benchmark_parser().parse_args()
     rank = args.local_rank
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
     run_broadcast(local_rank=rank, args=args)

benchmarks/pt2pt.py (+3 -9)

@@ -28,10 +28,7 @@
 
 
 def timed_pt2pt(input, start_event, end_event, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()
 
     sync_all()
     # Warmups, establish connections, etc.
@@ -81,10 +78,7 @@ def timed_pt2pt(input, start_event, end_event, args):
 
 
 def run_pt2pt(local_rank, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()
 
     # Prepare benchmark header
     print_header(args, 'pt2pt')
@@ -144,5 +138,5 @@ def run_pt2pt(local_rank, args):
 if __name__ == "__main__":
     args = benchmark_parser().parse_args()
     rank = args.local_rank
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
     run_pt2pt(local_rank=rank, args=args)

benchmarks/run_all.py (+1 -1)

@@ -33,7 +33,7 @@
 # For importing
 def main(args, rank):
 
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
 
     ops_to_run = []
     if args.all_reduce:

benchmarks/utils.py (+6 -32)

@@ -28,9 +28,7 @@
 from mcr_dl.cuda_accelerator import get_accelerator
 from mcr_dl.comm import mpi_discovery
 from mcr_dl.utils import set_mpi_dist_environemnt
-
-global dist
-
+import mcr_dl
 
 def env2int(env_list, default=-1):
     for e in env_list:
@@ -39,41 +37,14 @@ def env2int(env_list, default=-1):
     return default
 
 
-def init_torch_distributed(backend):
-    global dist
-    import torch.distributed as dist
-    if backend == 'nccl':
-        mpi_discovery()
-    elif backend == 'mpi':
-        set_mpi_dist_environemnt()
-    dist.init_process_group(backend)
-    local_rank = int(os.environ['LOCAL_RANK'])
-    get_accelerator().set_device(local_rank)
-
-def init_mcr_dl_comm(backend):
-    global dist
-    import mcr_dl as dist
-    dist.init_distributed(dist_backend=backend, use_mcr_dl=True)
-    local_rank = int(os.environ['LOCAL_RANK'])
-    get_accelerator().set_device(local_rank)
-
-
-def init_processes(local_rank, args):
-    if args.dist == 'mcr_dl':
-        init_mcr_dl_comm(args.backend)
-    elif args.dist == 'torch':
-        init_torch_distributed(args.backend)
-    else:
-        print_rank_0(f"distributed framework {args.dist} not supported")
-        exit(0)
-
-
 def print_rank_0(message):
+    dist = mcr_dl.get_distributed_engine()
     if dist.get_rank() == 0:
         print(message)
 
 
 def print_header(args, comm_op):
+    dist = mcr_dl.get_distributed_engine()
     if comm_op == 'pt2pt':
         world_size = 2
     else:
@@ -90,6 +61,7 @@ def print_header(args, comm_op):
 
 
 def get_bw(comm_op, size, duration, args):
+    dist = mcr_dl.get_distributed_engine()
     n = dist.get_world_size()
     tput = 0
     busbw = 0
@@ -133,11 +105,13 @@ def get_metric_strings(args, tput, busbw, duration):
 
 
 def sync_all():
+    dist = mcr_dl.get_distributed_engine()
     get_accelerator().synchronize()
     dist.barrier()
 
 
 def max_numel(comm_op, dtype, mem_factor, local_rank, args):
+    dist = mcr_dl.get_distributed_engine()
     dtype_size = _element_size(dtype)
     max_memory_per_gpu = get_accelerator().total_memory(local_rank) * mem_factor
     if comm_op == 'all_reduce' or comm_op == 'pt2pt' or comm_op == 'broadcast':
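
The helpers in utils.py no longer depend on a module-level "global dist" that the old init_processes injected at start-up; each helper resolves the engine when it is called. A rough sketch of why that late lookup matters (illustrative only; the real lookup is mcr_dl.get_distributed_engine()):

# Illustrative: late binding lets helpers be imported before any engine exists.
import mcr_dl

def print_rank_0(message):
    # Resolved on every call, so this works even though utils is imported
    # before mcr_dl.init_processes() runs in the benchmark's __main__ block.
    dist = mcr_dl.get_distributed_engine()
    if dist.get_rank() == 0:
        print(message)

# Typical order in the benchmarks:
#   1. import benchmarks.utils                          (no engine selected yet)
#   2. mcr_dl.init_processes(args.dist, args.backend)
#   3. print_rank_0("...")                              (engine resolved here)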

mcr_dl/__init__.py (+44 -1)

@@ -17,4 +17,47 @@
 # limitations under the License.
 
 from .utils import *
-from .comm import *
+from .comm import *
+
+global __dist_engine
+global __dist_backend
+
+def init_torch_distributed(backend):
+    import torch.distributed as dist
+    if backend == 'nccl':
+        mpi_discovery()
+    elif backend == 'mpi':
+        set_mpi_dist_environemnt()
+    dist.init_process_group(backend)
+    local_rank = int(os.environ['LOCAL_RANK'])
+    get_accelerator().set_device(local_rank)
+
+def init_mcr_dl_comm(backend):
+    import mcr_dl
+    mcr_dl.init_distributed(dist_backend=backend, use_mcr_dl=True)
+    local_rank = int(os.environ['LOCAL_RANK'])
+    #get_accelerator().set_device(local_rank)
+
+def init_processes(dist_engine, dist_backend):
+    print(f'Comm : {dist_engine} Backend : {dist_backend}')
+
+    global __dist_engine
+    global __dist_backend
+    __dist_engine = dist_engine
+    __dist_backend = dist_backend
+
+    if dist_engine == 'mcr_dl':
+        init_mcr_dl_comm(dist_backend)
+    elif dist_engine == 'torch':
+        init_torch_distributed(dist_backend)
+    else:
+        print(f"distributed framework {dist_engine} not supported")
+        exit(0)
+
+def get_distributed_engine():
+    global __dist_engine
+    if __dist_engine == 'torch':
+        return torch.distributed
+    elif __dist_engine == 'mcr_dl':
+        import mcr_dl
+        return mcr_dl
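
Because get_distributed_engine() hands back either torch.distributed or the mcr_dl package itself, and both expose the same collective surface (get_rank, get_world_size, barrier, all_reduce, ...), downstream code can be written once against the returned handle. A hedged sketch of an engine-agnostic helper (the function name and tensor are illustrative, not part of the commit):

import torch
import mcr_dl

def sum_across_ranks(values):
    # Works unchanged whether init_processes selected 'torch' or 'mcr_dl'.
    dist = mcr_dl.get_distributed_engine()
    t = torch.tensor(values, dtype=torch.float32)
    dist.all_reduce(t)  # in-place sum across ranks; the commit's test in tests/main.py relies on this default
    return t

# Assumes mcr_dl.init_processes(<engine>, <backend>) already ran in this process.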

tests/main.py (+4 -33)

@@ -20,7 +20,7 @@
 import time
 import argparse
 import torch
-
+import mcr_dl
 from mcr_dl.constants import TORCH_DISTRIBUTED_DEFAULT_PORT
 from common import set_accelerator_visible
 from mcr_dl.cuda_accelerator import get_accelerator
@@ -32,45 +32,16 @@
 parser.add_argument("--dist", choices=['mcr_dl', 'torch'], help = "torch.distributed or mcr-dl for distributed")
 args = parser.parse_args()
 
-def init_torch_distributed(backend):
-    global dist
-    import torch.distributed as dist
-    if backend == 'nccl':
-        mpi_discovery()
-    elif backend == 'mpi':
-        set_mpi_dist_environemnt()
-    dist.init_process_group(backend)
-    local_rank = int(os.environ['LOCAL_RANK'])
-    get_accelerator().set_device(local_rank)
-
-
-def init_mcr_dl_comm(backend):
-    global dist
-    import mcr_dl
-    import mcr_dl as dist
-    mcr_dl.init_distributed(dist_backend=backend, use_mcr_dl=True)
-    local_rank = int(os.environ['LOCAL_RANK'])
-    #get_accelerator().set_device(local_rank)
-
-
-def init_processes(args_dist, args_backend):
-    print(f'Comm : {args_dist} Backend : {args_backend}')
-    if args_dist == 'mcr_dl':
-        init_mcr_dl_comm(args_backend)
-    elif args_dist == 'torch':
-        init_torch_distributed(args_backend)
-    else:
-        print(f"distributed framework {args_dist} not supported")
-        exit(0)
-
 def all_reduce():
+    dist = mcr_dl.get_distributed_engine()
     x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
     sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
     result = torch.ones(1, 3).to(get_accelerator().device_name()) * sum_of_ranks
     dist.all_reduce(x)
     assert torch.all(x == result)
 
 def all_reduce_benchmark():
+    dist = mcr_dl.get_distributed_engine()
     start_events = [torch.cuda.Event(enable_timing=True) for _ in range(2, 30)]
     end_events = [torch.cuda.Event(enable_timing=True) for _ in range(2, 30)]
     rank = dist.get_rank()
@@ -101,7 +72,7 @@ def all_reduce_benchmark():
 
 if __name__ == "__main__":
     set_accelerator_visible()
-    init_processes(args_dist=args.dist, args_backend=args.backend)
+    mcr_dl.init_processes(dist_engine = args.dist, dist_backend = args.backend)
     # all_reduce()
     all_reduce_benchmark()
