diff --git a/benchmarks/all_gather.py b/benchmarks/all_gather.py index 59ed3247..93577704 100644 --- a/benchmarks/all_gather.py +++ b/benchmarks/all_gather.py @@ -67,10 +67,7 @@ def timed_all_gather(input, output, start_event, end_event, args): def run_all_gather(local_rank, args): - if args.dist == 'torch': - import torch.distributed as dist - elif args.dist == 'mcr_dl': - import mcr_dl as dist + dist = mcr_dl.get_distributed_engine() # Prepare benchmark header print_header(args, 'all_gather') @@ -149,5 +146,5 @@ def run_all_gather(local_rank, args): if __name__ == "__main__": args = benchmark_parser().parse_args() rank = args.local_rank - init_processes(local_rank=rank, args=args) + mcr_dl.init_processes(args.dist, args.backend) run_all_gather(local_rank=rank, args=args) \ No newline at end of file diff --git a/benchmarks/all_reduce.py b/benchmarks/all_reduce.py index 041fc3a9..dcb146b7 100644 --- a/benchmarks/all_reduce.py +++ b/benchmarks/all_reduce.py @@ -28,10 +28,8 @@ def timed_all_reduce(input, start_event, end_event, args): - if args.dist == 'torch': - import torch.distributed as dist - elif args.dist == 'mcr_dl': - import mcr_dl as dist + import mcr_dl + dist = mcr_dl.get_distributed_engine() sync_all() # Warmups, establish connections, etc. @@ -62,10 +60,8 @@ def timed_all_reduce(input, start_event, end_event, args): def run_all_reduce(local_rank, args): - if args.dist == 'torch': - import torch.distributed as dist - elif args.dist == 'mcr_dl': - import mcr_dl as dist + import mcr_dl + dist = mcr_dl.get_distributed_engine() # Prepare benchmark header print_header(args, 'all_reduce') @@ -125,7 +121,8 @@ def run_all_reduce(local_rank, args): if __name__ == "__main__": + import mcr_dl args = benchmark_parser().parse_args() rank = args.local_rank - init_processes(local_rank=rank, args=args) + mcr_dl.init_processes(args.dist, args.backend) run_all_reduce(local_rank=rank, args=args) \ No newline at end of file diff --git a/benchmarks/all_to_all.py b/benchmarks/all_to_all.py index 3317d1e4..1c7def55 100644 --- a/benchmarks/all_to_all.py +++ b/benchmarks/all_to_all.py @@ -28,10 +28,7 @@ def timed_all_to_all(input, output, start_event, end_event, args): - if args.dist == 'torch': - import torch.distributed as dist - elif args.dist == 'mcr_dl': - import mcr_dl as dist + dist = mcr_dl.get_distributed_engine() sync_all() # Warmups, establish connections, etc. @@ -62,10 +59,7 @@ def timed_all_to_all(input, output, start_event, end_event, args): def run_all_to_all(local_rank, args): - if args.dist == 'torch': - import torch.distributed as dist - elif args.dist == 'mcr_dl': - import mcr_dl as dist + dist = mcr_dl.get_distributed_engine() world_size = dist.get_world_size() global_rank = dist.get_rank() @@ -147,5 +141,5 @@ def run_all_to_all(local_rank, args): if __name__ == "__main__": args = benchmark_parser().parse_args() rank = args.local_rank - init_processes(local_rank=rank, args=args) + mcr_dl.init_processes(args.dist, args.backend) run_all_to_all(local_rank=rank, args=args) \ No newline at end of file diff --git a/benchmarks/broadcast.py b/benchmarks/broadcast.py index f7ca6806..0e3ba250 100644 --- a/benchmarks/broadcast.py +++ b/benchmarks/broadcast.py @@ -28,10 +28,7 @@ def timed_broadcast(input, start_event, end_event, args): - if args.dist == 'torch': - import torch.distributed as dist - elif args.dist == 'mcr_dl': - import mcr_dl as dist + dist = mcr_dl.get_distributed_engine() sync_all() # Warmups, establish connections, etc. @@ -62,10 +59,7 @@ def timed_broadcast(input, start_event, end_event, args): def run_broadcast(local_rank, args): - if args.dist == 'torch': - import torch.distributed as dist - elif args.dist == 'mcr_dl': - import mcr_dl as dist + dist = mcr_dl.get_distributed_engine() # Prepare benchmark header print_header(args, 'broadcast') @@ -125,5 +119,5 @@ def run_broadcast(local_rank, args): if __name__ == "__main__": args = benchmark_parser().parse_args() rank = args.local_rank - init_processes(local_rank=rank, args=args) + mcr_dl.init_processes(args.dist, args.backend) run_broadcast(local_rank=rank, args=args) \ No newline at end of file diff --git a/benchmarks/pt2pt.py b/benchmarks/pt2pt.py index 9b51a8b0..5a680d92 100644 --- a/benchmarks/pt2pt.py +++ b/benchmarks/pt2pt.py @@ -28,10 +28,7 @@ def timed_pt2pt(input, start_event, end_event, args): - if args.dist == 'torch': - import torch.distributed as dist - elif args.dist == 'mcr_dl': - import mcr_dl as dist + dist = mcr_dl.get_distributed_engine() sync_all() # Warmups, establish connections, etc. @@ -81,10 +78,7 @@ def timed_pt2pt(input, start_event, end_event, args): def run_pt2pt(local_rank, args): - if args.dist == 'torch': - import torch.distributed as dist - elif args.dist == 'mcr_dl': - import mcr_dl as dist + dist = mcr_dl.get_distributed_engine() # Prepare benchmark header print_header(args, 'pt2pt') @@ -144,5 +138,5 @@ def run_pt2pt(local_rank, args): if __name__ == "__main__": args = benchmark_parser().parse_args() rank = args.local_rank - init_processes(local_rank=rank, args=args) + mcr_dl.init_processes(args.dist, args.backend) run_pt2pt(local_rank=rank, args=args) \ No newline at end of file diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py index f5695310..ef6adbd6 100644 --- a/benchmarks/run_all.py +++ b/benchmarks/run_all.py @@ -33,7 +33,7 @@ # For importing def main(args, rank): - init_processes(local_rank=rank, args=args) + mcr_dl.init_processes(args.dist, args.backend) ops_to_run = [] if args.all_reduce: diff --git a/benchmarks/utils.py b/benchmarks/utils.py index aca7bafe..35cf33ea 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -28,9 +28,7 @@ from mcr_dl.cuda_accelerator import get_accelerator from mcr_dl.comm import mpi_discovery from mcr_dl.utils import set_mpi_dist_environemnt - -global dist - +import mcr_dl def env2int(env_list, default=-1): for e in env_list: @@ -39,41 +37,14 @@ def env2int(env_list, default=-1): return default -def init_torch_distributed(backend): - global dist - import torch.distributed as dist - if backend == 'nccl': - mpi_discovery() - elif backend == 'mpi': - set_mpi_dist_environemnt() - dist.init_process_group(backend) - local_rank = int(os.environ['LOCAL_RANK']) - get_accelerator().set_device(local_rank) - -def init_mcr_dl_comm(backend): - global dist - import mcr_dl as dist - dist.init_distributed(dist_backend=backend, use_mcr_dl=True) - local_rank = int(os.environ['LOCAL_RANK']) - get_accelerator().set_device(local_rank) - - -def init_processes(local_rank, args): - if args.dist == 'mcr_dl': - init_mcr_dl_comm(args.backend) - elif args.dist == 'torch': - init_torch_distributed(args.backend) - else: - print_rank_0(f"distributed framework {args.dist} not supported") - exit(0) - - def print_rank_0(message): + dist = mcr_dl.get_distributed_engine() if dist.get_rank() == 0: print(message) def print_header(args, comm_op): + dist = mcr_dl.get_distributed_engine() if comm_op == 'pt2pt': world_size = 2 else: @@ -90,6 +61,7 @@ def print_header(args, comm_op): def get_bw(comm_op, size, duration, args): + dist = mcr_dl.get_distributed_engine() n = dist.get_world_size() tput = 0 busbw = 0 @@ -133,11 +105,13 @@ def get_metric_strings(args, tput, busbw, duration): def sync_all(): + dist = mcr_dl.get_distributed_engine() get_accelerator().synchronize() dist.barrier() def max_numel(comm_op, dtype, mem_factor, local_rank, args): + dist = mcr_dl.get_distributed_engine() dtype_size = _element_size(dtype) max_memory_per_gpu = get_accelerator().total_memory(local_rank) * mem_factor if comm_op == 'all_reduce' or comm_op == 'pt2pt' or comm_op == 'broadcast': diff --git a/mcr_dl/__init__.py b/mcr_dl/__init__.py index c276061f..33c5929e 100644 --- a/mcr_dl/__init__.py +++ b/mcr_dl/__init__.py @@ -17,4 +17,54 @@ # limitations under the License. from .utils import * -from .comm import * \ No newline at end of file +from .comm import * + +global __dist_engine +global __dist_backend + +__dist_engine = None +__dist_backend = None + +def init_torch_distributed(backend): + import torch.distributed as dist + if backend == 'nccl': + mpi_discovery() + elif backend == 'mpi': + set_mpi_dist_environemnt() + dist.init_process_group(backend=backend) + local_rank = int(os.environ['LOCAL_RANK']) + # get_accelerator().set_device(local_rank) + print(f'Rank : {dist.get_rank()} World_Size : {dist.get_world_size()}', flush = True) + +def init_mcr_dl_comm(backend): + import mcr_dl + mcr_dl.init_distributed(dist_backend=backend, use_mcr_dl=True) + local_rank = int(os.environ['LOCAL_RANK']) + #get_accelerator().set_device(local_rank) + +def init_processes(dist_engine, dist_backend, world_size = -1, rank = -1, timeout = None, init_method = None): + print(f'Comm : {dist_engine} Backend : {dist_backend}') + + global __dist_engine + global __dist_backend + __dist_engine = dist_engine + __dist_backend = dist_backend + if dist_engine == 'mcr_dl': + init_mcr_dl_comm(dist_backend) + elif dist_engine == 'torch': + init_torch_distributed(dist_backend) + else: + print(f"distributed framework {dist_engine} not supported") + exit(0) + +def get_distributed_engine(): + global __dist_engine + if __dist_engine is None: + return None + if __dist_engine == 'torch': + return torch.distributed + elif __dist_engine == 'mcr_dl': + import mcr_dl + return mcr_dl + print(f"Unsupported values for __dist_engine. Expected values 'torch' or 'mcr_dl'") + exit(0) \ No newline at end of file diff --git a/mcr_dl/mpi.py b/mcr_dl/mpi.py index 8bbd14c6..22cad991 100644 --- a/mcr_dl/mpi.py +++ b/mcr_dl/mpi.py @@ -72,12 +72,13 @@ def destroy_process_group(self, group=None): pass def new_group(self, ranks): - # TODO: Change this to use comm_op.new_group when the impl. is ready. + # TODO: Change this to use self.mpi_comm_op.new_group(ranks) when the impl. is ready. if not torch.distributed.is_initialized(): from mcr_dl.torch import TorchBackend - d = TorchBackend(rank=self.rank, size=self.size) + d = TorchBackend(rank=self.rank, world_size=self.size) logger.info(f"new group called with {ranks}") return torch.distributed.new_group(ranks) + # return self.mpi_comm_op.new_group(ranks) def get_rank(self, group=None): return self.mpi_comm_op.get_rank(0) diff --git a/mcr_dl/torch.py b/mcr_dl/torch.py index f3ab8c09..e93079a3 100644 --- a/mcr_dl/torch.py +++ b/mcr_dl/torch.py @@ -23,6 +23,7 @@ from .utils import * from .backend import * from .comm import * +from .constants import default_pg_timeout DS_COMM_ALL_GATHER_OFF = False DS_COMM_REDUCE_SCATTER_OFF = False @@ -119,7 +120,7 @@ class TorchBackend(Backend): needed. """ - def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'): + def __init__(self, backend="mpi", init_method = None, timeout = default_pg_timeout, rank=-1, world_size=-1, name='torch'): super(TorchBackend, self).__init__() self.has_all_reduce_coalesced = has_all_reduce_coalesced() self.has_coalescing_manager = has_coalescing_manager() @@ -131,7 +132,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name=' # The idea is to fake that dist backend is initialized even when # it is not so we can run on a single GPU without doing any init_process_group self.single_gpu_mode = True - self.init_process_group(backend, timeout, init_method, rank, world_size) + self.init_process_group(backend=backend, init_method=init_method, timeout= timeout, rank=rank, world_size= world_size) @classmethod def get_all_gather_function(self): diff --git a/tests/main.py b/tests/main.py index e22e72ca..76e6e805 100644 --- a/tests/main.py +++ b/tests/main.py @@ -20,7 +20,7 @@ import time import argparse import torch - +import mcr_dl from mcr_dl.constants import TORCH_DISTRIBUTED_DEFAULT_PORT from common import set_accelerator_visible from mcr_dl.cuda_accelerator import get_accelerator @@ -32,38 +32,8 @@ parser.add_argument("--dist", choices=['mcr_dl', 'torch'], help = "torch.distributed or mcr-dl for distributed") args = parser.parse_args() -def init_torch_distributed(backend): - global dist - import torch.distributed as dist - if backend == 'nccl': - mpi_discovery() - elif backend == 'mpi': - set_mpi_dist_environemnt() - dist.init_process_group(backend) - local_rank = int(os.environ['LOCAL_RANK']) - get_accelerator().set_device(local_rank) - - -def init_mcr_dl_comm(backend): - global dist - import mcr_dl - import mcr_dl as dist - mcr_dl.init_distributed(dist_backend=backend, use_mcr_dl=True) - local_rank = int(os.environ['LOCAL_RANK']) - #get_accelerator().set_device(local_rank) - - -def init_processes(args_dist, args_backend): - print(f'Comm : {args_dist} Backend : {args_backend}') - if args_dist == 'mcr_dl': - init_mcr_dl_comm(args_backend) - elif args_dist == 'torch': - init_torch_distributed(args_backend) - else: - print(f"distributed framework {args_dist} not supported") - exit(0) - def all_reduce(): + dist = mcr_dl.get_distributed_engine() x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1) sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2 result = torch.ones(1, 3).to(get_accelerator().device_name()) * sum_of_ranks @@ -71,6 +41,7 @@ def all_reduce(): assert torch.all(x == result) def all_reduce_benchmark(): + dist = mcr_dl.get_distributed_engine() start_events = [torch.cuda.Event(enable_timing=True) for _ in range(2, 30)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(2, 30)] rank = dist.get_rank() @@ -101,7 +72,7 @@ def all_reduce_benchmark(): if __name__ == "__main__": set_accelerator_visible() - init_processes(args_dist=args.dist, args_backend=args.backend) + mcr_dl.init_processes(dist_engine = args.dist, dist_backend = args.backend) # all_reduce() all_reduce_benchmark()