28
28
from mcr_dl .cuda_accelerator import get_accelerator
29
29
from mcr_dl .comm import mpi_discovery
30
30
from mcr_dl .utils import set_mpi_dist_environemnt
31
-
32
- global dist
33
-
31
+ import mcr_dl
34
32
35
33
def env2int (env_list , default = - 1 ):
36
34
for e in env_list :
@@ -39,41 +37,14 @@ def env2int(env_list, default=-1):
39
37
return default
40
38
41
39
42
- def init_torch_distributed (backend ):
43
- global dist
44
- import torch .distributed as dist
45
- if backend == 'nccl' :
46
- mpi_discovery ()
47
- elif backend == 'mpi' :
48
- set_mpi_dist_environemnt ()
49
- dist .init_process_group (backend )
50
- local_rank = int (os .environ ['LOCAL_RANK' ])
51
- get_accelerator ().set_device (local_rank )
52
-
53
- def init_mcr_dl_comm (backend ):
54
- global dist
55
- import mcr_dl as dist
56
- dist .init_distributed (dist_backend = backend , use_mcr_dl = True )
57
- local_rank = int (os .environ ['LOCAL_RANK' ])
58
- get_accelerator ().set_device (local_rank )
59
-
60
-
61
- def init_processes (local_rank , args ):
62
- if args .dist == 'mcr_dl' :
63
- init_mcr_dl_comm (args .backend )
64
- elif args .dist == 'torch' :
65
- init_torch_distributed (args .backend )
66
- else :
67
- print_rank_0 (f"distributed framework { args .dist } not supported" )
68
- exit (0 )
69
-
70
-
71
40
def print_rank_0(message):
    """Emit *message* from the rank-0 process only, so multi-rank runs do not duplicate log lines."""
    # Look up the active distributed engine lazily; it is configured elsewhere
    # via mcr_dl (presumably by init-time setup — confirm against caller).
    engine = mcr_dl.get_distributed_engine()
    if engine.get_rank() != 0:
        return
    print(message)
74
44
75
45
76
46
def print_header (args , comm_op ):
47
+ dist = mcr_dl .get_distributed_engine ()
77
48
if comm_op == 'pt2pt' :
78
49
world_size = 2
79
50
else :
@@ -90,6 +61,7 @@ def print_header(args, comm_op):
90
61
91
62
92
63
def get_bw (comm_op , size , duration , args ):
64
+ dist = mcr_dl .get_distributed_engine ()
93
65
n = dist .get_world_size ()
94
66
tput = 0
95
67
busbw = 0
@@ -133,11 +105,13 @@ def get_metric_strings(args, tput, busbw, duration):
133
105
134
106
135
107
def sync_all():
    """Synchronize the local accelerator stream, then barrier across all ranks.

    Ensures every rank has finished its in-flight device work before any rank
    proceeds — used to delimit timed benchmark regions.
    """
    engine = mcr_dl.get_distributed_engine()
    # Device-side sync first so the barrier reflects completed kernels.
    get_accelerator().synchronize()
    engine.barrier()
138
111
139
112
140
113
def max_numel (comm_op , dtype , mem_factor , local_rank , args ):
114
+ dist = mcr_dl .get_distributed_engine ()
141
115
dtype_size = _element_size (dtype )
142
116
max_memory_per_gpu = get_accelerator ().total_memory (local_rank ) * mem_factor
143
117
if comm_op == 'all_reduce' or comm_op == 'pt2pt' or comm_op == 'broadcast' :
0 commit comments