Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to simplify access to MCR-DL #16

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__pycache__
mcr_dl/git_version_info_installed.py
mcr_dl.egg-info/
mcr_dl/config.yml
mcr_dl/build_config.yml
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ python setup.py install
```

### Update Configurations
Update mpi, cuda, and nccl paths appropriately in [mcr_dl/config.yml](/mcr_dl/config.yml)
Update mpi, cuda, and nccl paths appropriately in [mcr_dl/build_config.yml](/mcr_dl/build_config.yml)

### The MCR-DL Communication Benchmarking Suite

Expand Down
8 changes: 3 additions & 5 deletions benchmarks/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,7 @@ def timed_all_gather(input, output, start_event, end_event, args):


def run_all_gather(local_rank, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'mcr_dl':
import mcr_dl as dist
dist = mcr_dl.get_distributed_engine()

# Prepare benchmark header
print_header(args, 'all_gather')
Expand Down Expand Up @@ -98,6 +95,7 @@ def run_all_gather(local_rank, args):
# Delete original mat to avoid OOM
del mat
get_accelerator().empty_cache()
print(f"#######All gather world size : {world_size}")
output = torch.zeros(input.nelement() * world_size,
dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
except RuntimeError as e:
Expand Down Expand Up @@ -149,5 +147,5 @@ def run_all_gather(local_rank, args):
if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
init_processes(local_rank=rank, args=args)
mcr_dl.init_processes(args.dist, args.backend)
run_all_gather(local_rank=rank, args=args)
19 changes: 10 additions & 9 deletions benchmarks/all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@


def timed_all_reduce(input, start_event, end_event, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'mcr_dl':
import mcr_dl as dist
import mcr_dl
dist = mcr_dl.get_distributed_engine()

sync_all()
# Warmups, establish connections, etc.
Expand Down Expand Up @@ -62,10 +60,12 @@ def timed_all_reduce(input, start_event, end_event, args):


def run_all_reduce(local_rank, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'mcr_dl':
import mcr_dl as dist
# if args.dist == 'torch':
# import torch.distributed as dist
# elif args.dist == 'mcr_dl':
# import mcr_dl as dist
import mcr_dl
dist = mcr_dl.get_distributed_engine()

# Prepare benchmark header
print_header(args, 'all_reduce')
Expand Down Expand Up @@ -125,7 +125,8 @@ def run_all_reduce(local_rank, args):


if __name__ == "__main__":
import mcr_dl
args = benchmark_parser().parse_args()
rank = args.local_rank
init_processes(local_rank=rank, args=args)
mcr_dl.init_processes(args.dist, args.backend)
run_all_reduce(local_rank=rank, args=args)
12 changes: 3 additions & 9 deletions benchmarks/all_to_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@


def timed_all_to_all(input, output, start_event, end_event, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'mcr_dl':
import mcr_dl as dist
dist = mcr_dl.get_distributed_engine()

sync_all()
# Warmups, establish connections, etc.
Expand Down Expand Up @@ -62,10 +59,7 @@ def timed_all_to_all(input, output, start_event, end_event, args):


def run_all_to_all(local_rank, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'mcr_dl':
import mcr_dl as dist
dist = mcr_dl.get_distributed_engine()

world_size = dist.get_world_size()
global_rank = dist.get_rank()
Expand Down Expand Up @@ -147,5 +141,5 @@ def run_all_to_all(local_rank, args):
if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
init_processes(local_rank=rank, args=args)
mcr_dl.init_processes(args.dist, args.backend)
run_all_to_all(local_rank=rank, args=args)
12 changes: 3 additions & 9 deletions benchmarks/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@


def timed_broadcast(input, start_event, end_event, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'mcr_dl':
import mcr_dl as dist
dist = mcr_dl.get_distributed_engine()

sync_all()
# Warmups, establish connections, etc.
Expand Down Expand Up @@ -62,10 +59,7 @@ def timed_broadcast(input, start_event, end_event, args):


def run_broadcast(local_rank, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'mcr_dl':
import mcr_dl as dist
dist = mcr_dl.get_distributed_engine()

# Prepare benchmark header
print_header(args, 'broadcast')
Expand Down Expand Up @@ -125,5 +119,5 @@ def run_broadcast(local_rank, args):
if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
init_processes(local_rank=rank, args=args)
mcr_dl.init_processes(args.dist, args.backend)
run_broadcast(local_rank=rank, args=args)
12 changes: 3 additions & 9 deletions benchmarks/pt2pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@


def timed_pt2pt(input, start_event, end_event, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'mcr_dl':
import mcr_dl as dist
dist = mcr_dl.get_distributed_engine()

sync_all()
# Warmups, establish connections, etc.
Expand Down Expand Up @@ -81,10 +78,7 @@ def timed_pt2pt(input, start_event, end_event, args):


def run_pt2pt(local_rank, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'mcr_dl':
import mcr_dl as dist
dist = mcr_dl.get_distributed_engine()

# Prepare benchmark header
print_header(args, 'pt2pt')
Expand Down Expand Up @@ -144,5 +138,5 @@ def run_pt2pt(local_rank, args):
if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
init_processes(local_rank=rank, args=args)
mcr_dl.init_processes(args.dist, args.backend)
run_pt2pt(local_rank=rank, args=args)
2 changes: 1 addition & 1 deletion benchmarks/run_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# For importing
def main(args, rank):

init_processes(local_rank=rank, args=args)
mcr_dl.init_processes(args.dist, args.backend)

ops_to_run = []
if args.all_reduce:
Expand Down
38 changes: 6 additions & 32 deletions benchmarks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
from mcr_dl.cuda_accelerator import get_accelerator
from mcr_dl.comm import mpi_discovery
from mcr_dl.utils import set_mpi_dist_environemnt

global dist

import mcr_dl

def env2int(env_list, default=-1):
for e in env_list:
Expand All @@ -39,41 +37,14 @@ def env2int(env_list, default=-1):
return default


def init_torch_distributed(backend):
global dist
import torch.distributed as dist
if backend == 'nccl':
mpi_discovery()
elif backend == 'mpi':
set_mpi_dist_environemnt()
dist.init_process_group(backend)
local_rank = int(os.environ['LOCAL_RANK'])
get_accelerator().set_device(local_rank)

def init_mcr_dl_comm(backend):
global dist
import mcr_dl as dist
dist.init_distributed(dist_backend=backend, use_mcr_dl=True)
local_rank = int(os.environ['LOCAL_RANK'])
get_accelerator().set_device(local_rank)


def init_processes(local_rank, args):
if args.dist == 'mcr_dl':
init_mcr_dl_comm(args.backend)
elif args.dist == 'torch':
init_torch_distributed(args.backend)
else:
print_rank_0(f"distributed framework {args.dist} not supported")
exit(0)


def print_rank_0(message):
dist = mcr_dl.get_distributed_engine()
if dist.get_rank() == 0:
print(message)


def print_header(args, comm_op):
dist = mcr_dl.get_distributed_engine()
if comm_op == 'pt2pt':
world_size = 2
else:
Expand All @@ -90,6 +61,7 @@ def print_header(args, comm_op):


def get_bw(comm_op, size, duration, args):
dist = mcr_dl.get_distributed_engine()
n = dist.get_world_size()
tput = 0
busbw = 0
Expand Down Expand Up @@ -133,11 +105,13 @@ def get_metric_strings(args, tput, busbw, duration):


def sync_all():
dist = mcr_dl.get_distributed_engine()
get_accelerator().synchronize()
dist.barrier()


def max_numel(comm_op, dtype, mem_factor, local_rank, args):
dist = mcr_dl.get_distributed_engine()
dtype_size = _element_size(dtype)
max_memory_per_gpu = get_accelerator().total_memory(local_rank) * mem_factor
if comm_op == 'all_reduce' or comm_op == 'pt2pt' or comm_op == 'broadcast':
Expand Down
45 changes: 44 additions & 1 deletion mcr_dl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,47 @@
# limitations under the License.

from .utils import *
from .comm import *
from .comm import *

global __dist_engine
global __dist_backend

def init_torch_distributed(backend):
    """Initialize ``torch.distributed`` with the given backend string.

    For 'nccl' and 'mpi' backends the rank/world-size environment variables
    are first populated from the MPI launcher environment; other backends
    are assumed to have them set already.
    """
    import torch.distributed as dist
    if backend == 'nccl':
        mpi_discovery()
    elif backend == 'mpi':
        set_mpi_dist_environemnt()
    dist.init_process_group(backend)
    # Bind this process to its local accelerator device.
    device_index = int(os.environ['LOCAL_RANK'])
    get_accelerator().set_device(device_index)

def init_mcr_dl_comm(backend):
    """Initialize the MCR-DL communication engine with the given backend.

    Reads LOCAL_RANK so a missing launcher environment fails fast with a
    clear KeyError here rather than later inside a collective.
    """
    import mcr_dl
    mcr_dl.init_distributed(dist_backend=backend, use_mcr_dl=True)
    local_rank = int(os.environ['LOCAL_RANK'])
    # NOTE(review): unlike init_torch_distributed, no explicit
    # get_accelerator().set_device(local_rank) call is made here — confirm
    # that init_distributed() performs device selection itself.

def init_processes(dist_engine, dist_backend):
    """Select and initialize a distributed engine for this process.

    Parameters
    ----------
    dist_engine : str
        Communication framework to use: 'mcr_dl' or 'torch'.
    dist_backend : str
        Backend passed to the chosen engine (e.g. 'nccl', 'mpi').

    The chosen engine/backend pair is recorded in module globals so that
    get_distributed_engine() can return the matching module later.
    Exits with a non-zero status if *dist_engine* is not supported.
    """
    print(f'Comm : {dist_engine} Backend : {dist_backend}')

    global __dist_engine
    global __dist_backend
    __dist_engine = dist_engine
    __dist_backend = dist_backend

    if dist_engine == 'mcr_dl':
        init_mcr_dl_comm(dist_backend)
    elif dist_engine == 'torch':
        init_torch_distributed(dist_backend)
    else:
        print(f"distributed framework {dist_engine} not supported")
        # Bug fix: exit(0) signalled success on an error path; report failure.
        exit(1)

def get_distributed_engine():
    """Return the communication module selected by init_processes().

    Returns ``torch.distributed`` when the 'torch' engine is active, the
    ``mcr_dl`` package when 'mcr_dl' is active, and None for any other
    value. Raises NameError if init_processes() was never called.
    """
    global __dist_engine
    if __dist_engine == 'torch':
        # Bug fix: the bare name `torch` was not guaranteed to be bound in
        # this module's namespace (it was only imported as a local inside
        # init_torch_distributed), so `return torch.distributed` could raise
        # NameError. Import locally to make the binding explicit.
        import torch.distributed
        return torch.distributed
    elif __dist_engine == 'mcr_dl':
        import mcr_dl
        return mcr_dl
11 changes: 0 additions & 11 deletions mcr_dl/config.yml

This file was deleted.

3 changes: 1 addition & 2 deletions mcr_dl/ops/op_builder/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@

class ConfigPath():
def __init__(self, file_path = None):
self.file_path = os.path.join(os.path.dirname(mcr_dl.__file__), "config.yml") if file_path is None else file_path
print(self.file_path)
self.file_path = os.path.join(os.path.dirname(mcr_dl.__file__), "build_config.yml") if file_path is None else file_path
self.config_data = self.load_config()
self.mpi_path = self.config_data.get("mpi", {}).get("path")
self.mpi_include = self.config_data.get("mpi", {}).get("include")
Expand Down
6 changes: 3 additions & 3 deletions mcr_dl/utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,12 @@ def set_mpi_dist_environemnt(master_addr = None):
if master_addr is not None:
os.environ['MASTER_ADDR'] = master_addr
local_rank = env2int(
['LOCAL_RANK', 'MPI_LOCALRANKID', 'OMPI_COMM_WORLD_LOCAL_RANK', 'MV2_COMM_WORLD_LOCAL_RANK', 'SLURM_LOCALID'])
['LOCAL_RANK', 'MPI_LOCALRANKID', 'OMPI_COMM_WORLD_LOCAL_RANK', 'MV2_COMM_WORLD_LOCAL_RANK', 'SLURM_LOCALID', 'MVP_COMM_WORLD_LOCAL_RANK'])
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(local_rank)
rank = env2int(['RANK', 'MPI_RANKID', 'OMPI_COMM_WORLD_RANK', 'MV2_COMM_WORLD_RANK', 'SLURM_PROCID'])
rank = env2int(['RANK', 'MPI_RANKID', 'OMPI_COMM_WORLD_RANK', 'MV2_COMM_WORLD_RANK', 'SLURM_PROCID', 'MVP_COMM_WORLD_LOCAL_RANK'])
if 'RANK' not in os.environ:
os.environ['RANK'] = str(rank)
world_size = env2int(['WORLD_SIZE', 'OMPI_COMM_WORLD_SIZE', 'MV2_COMM_WORLD_SIZE', 'SLURM_NPROCS'])
world_size = env2int(['WORLD_SIZE', 'OMPI_COMM_WORLD_SIZE', 'MV2_COMM_WORLD_SIZE', 'SLURM_NPROCS', 'MVP_COMM_WORLD_LOCAL_RANK'])
if 'WORLD_SIZE' not in os.environ:
os.environ['WORLD_SIZE'] = str(world_size)
Loading
Loading