
Commit 1982e59

mcr_dl_megatron changes
Signed-off-by: Radha Guhane <[email protected]>
1 parent: cee0b27 · commit: 1982e59

File tree: 4 files changed (+19, -10 lines)

  mcr_dl/__init__.py
  mcr_dl/constants.py
  mcr_dl/mpi.py
  mcr_dl/torch.py


mcr_dl/__init__.py (+12, -5)

@@ -22,30 +22,33 @@
 global __dist_engine
 global __dist_backend
 
+__dist_engine = None
+__dist_backend = None
+
 def init_torch_distributed(backend):
     import torch.distributed as dist
     if backend == 'nccl':
         mpi_discovery()
     elif backend == 'mpi':
         set_mpi_dist_environemnt()
-    dist.init_process_group(backend)
+    dist.init_process_group(backend=backend)
     local_rank = int(os.environ['LOCAL_RANK'])
-    get_accelerator().set_device(local_rank)
+    # get_accelerator().set_device(local_rank)
+    print(f'Rank : {dist.get_rank()} World_Size : {dist.get_world_size()}', flush = True)
 
 def init_mcr_dl_comm(backend):
     import mcr_dl
     mcr_dl.init_distributed(dist_backend=backend, use_mcr_dl=True)
     local_rank = int(os.environ['LOCAL_RANK'])
     #get_accelerator().set_device(local_rank)
 
-def init_processes(dist_engine, dist_backend):
+def init_processes(dist_engine, dist_backend, world_size = -1, rank = -1, timeout = None, init_method = None):
     print(f'Comm : {dist_engine} Backend : {dist_backend}')
 
     global __dist_engine
     global __dist_backend
     __dist_engine = dist_engine
     __dist_backend = dist_backend
-
     if dist_engine == 'mcr_dl':
         init_mcr_dl_comm(dist_backend)
     elif dist_engine == 'torch':

@@ -56,8 +59,12 @@ def init_processes(dist_engine, dist_backend):
 
 def get_distributed_engine():
     global __dist_engine
+    if __dist_engine is None:
+        return None
     if __dist_engine == 'torch':
         return torch.distributed
     elif __dist_engine == 'mcr_dl':
         import mcr_dl
-        return mcr_dl
+        return mcr_dl
+    print(f"Unsupported values for __dist_engine. Expected values 'torch' or 'mcr_dl'")
+    exit(0)
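Note (not part of the commit): a minimal usage sketch of the updated entry point, assuming the launcher has already exported the usual environment variables (LOCAL_RANK, MASTER_ADDR, MASTER_PORT, ...); the engine/backend values are illustrative.

# Hypothetical caller, e.g. a Megatron-style training script. Only dist_engine
# and dist_backend are required; world_size, rank, timeout and init_method
# fall back to the new defaults (-1, -1, None, None).
import mcr_dl

mcr_dl.init_processes(dist_engine='torch', dist_backend='nccl')

dist = mcr_dl.get_distributed_engine()   # torch.distributed for the 'torch' engine
if dist is not None:                     # returns None until init_processes() has run
    print(f'rank {dist.get_rank()} of {dist.get_world_size()}')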

mcr_dl/constants.py (+1, -1)

@@ -70,7 +70,7 @@
 #############################################
 # Torch distributed constants
 #############################################
-TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
+TORCH_DISTRIBUTED_DEFAULT_PORT = 29600
 
 # Default process group wide timeout, if applicable.
 # This only applies to the gloo and nccl backends
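Note (not part of the commit): the only change here is the default rendezvous port, 29500 to 29600. As a hedged illustration of where such a constant typically ends up, a hypothetical helper that seeds MASTER_PORT when the launcher has not set it:

# Illustrative only: ensure_master_port() is not code from this repository;
# the constant itself comes from mcr_dl/constants.py.
import os
from mcr_dl.constants import TORCH_DISTRIBUTED_DEFAULT_PORT   # now 29600

def ensure_master_port():
    os.environ.setdefault('MASTER_PORT', str(TORCH_DISTRIBUTED_DEFAULT_PORT))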

mcr_dl/mpi.py (+3, -2)

@@ -72,12 +72,13 @@ def destroy_process_group(self, group=None):
         pass
 
     def new_group(self, ranks):
-        # TODO: Change this to use comm_op.new_group when the impl. is ready.
+        # TODO: Change this to use self.mpi_comm_op.new_group(ranks) when the impl. is ready.
         if not torch.distributed.is_initialized():
             from mcr_dl.torch import TorchBackend
-            d = TorchBackend(rank=self.rank, size=self.size)
+            d = TorchBackend(rank=self.rank, world_size=self.size)
         logger.info(f"new group called with {ranks}")
         return torch.distributed.new_group(ranks)
+        # return self.mpi_comm_op.new_group(ranks)
 
     def get_rank(self, group=None):
         return self.mpi_comm_op.get_rank(0)
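Note (not part of the commit): the keyword rename matters because TorchBackend takes world_size, not size, so the old fallback call could not bind; combined with the new constructor defaults in mcr_dl/torch.py below, the call now works. A small sketch of what new_group() falls back to when torch.distributed is not yet initialized (rank/size values are illustrative):

# Sketch only: values are illustrative; backend, init_method and timeout come
# from the new TorchBackend defaults shown in the next file.
from mcr_dl.torch import TorchBackend

d = TorchBackend(rank=0, world_size=2)   # formerly passed size=2, which TorchBackend does not accept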

mcr_dl/torch.py (+3, -2)

@@ -23,6 +23,7 @@
 from .utils import *
 from .backend import *
 from .comm import *
+from .constants import default_pg_timeout
 
 DS_COMM_ALL_GATHER_OFF = False
 DS_COMM_REDUCE_SCATTER_OFF = False

@@ -119,7 +120,7 @@ class TorchBackend(Backend):
     needed.
     """
 
-    def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
+    def __init__(self, backend="mpi", init_method = None, timeout = default_pg_timeout, rank=-1, world_size=-1, name='torch'):
         super(TorchBackend, self).__init__()
         self.has_all_reduce_coalesced = has_all_reduce_coalesced()
         self.has_coalescing_manager = has_coalescing_manager()

@@ -131,7 +132,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
         # The idea is to fake that dist backend is initialized even when
         # it is not so we can run on a single GPU without doing any init_process_group
         self.single_gpu_mode = True
-        self.init_process_group(backend, timeout, init_method, rank, world_size)
+        self.init_process_group(backend=backend, init_method=init_method, timeout= timeout, rank=rank, world_size= world_size)
 
     @classmethod
     def get_all_gather_function(self):
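Note (not part of the commit): with this change every constructor argument of TorchBackend has a default (previously backend, timeout and init_method were required positional parameters), and the arguments are now forwarded to init_process_group by keyword. A brief sketch of the equivalence (values are illustrative):

# Sketch only: the short form relies on the new defaults backend="mpi",
# init_method=None and the imported default_pg_timeout.
from mcr_dl.torch import TorchBackend

tb = TorchBackend(rank=0, world_size=1)
# ...is now equivalent to the fully spelled-out call:
# TorchBackend(backend="mpi", init_method=None, timeout=default_pg_timeout,
#              rank=0, world_size=1, name='torch')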
