diff --git a/benchmarks/all_gather.py b/benchmarks/all_gather.py
index 59ed3247..93577704 100644
--- a/benchmarks/all_gather.py
+++ b/benchmarks/all_gather.py
@@ -67,10 +67,7 @@ def timed_all_gather(input, output, start_event, end_event, args):
 
 
 def run_all_gather(local_rank, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()
 
     # Prepare benchmark header
     print_header(args, 'all_gather')
@@ -149,5 +146,5 @@ def run_all_gather(local_rank, args):
 if __name__ == "__main__":
     args = benchmark_parser().parse_args()
     rank = args.local_rank
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
     run_all_gather(local_rank=rank, args=args)
\ No newline at end of file
diff --git a/benchmarks/all_reduce.py b/benchmarks/all_reduce.py
index 041fc3a9..dcb146b7 100644
--- a/benchmarks/all_reduce.py
+++ b/benchmarks/all_reduce.py
@@ -28,10 +28,8 @@
 
 
 def timed_all_reduce(input, start_event, end_event, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    import mcr_dl
+    dist = mcr_dl.get_distributed_engine()
 
     sync_all()
     # Warmups, establish connections, etc.
@@ -62,10 +60,8 @@ def timed_all_reduce(input, start_event, end_event, args):
 
 
 def run_all_reduce(local_rank, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    import mcr_dl
+    dist = mcr_dl.get_distributed_engine()
 
     # Prepare benchmark header
     print_header(args, 'all_reduce')
@@ -125,7 +121,8 @@ def run_all_reduce(local_rank, args):
 
 
 if __name__ == "__main__":
+    import mcr_dl
     args = benchmark_parser().parse_args()
     rank = args.local_rank
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
     run_all_reduce(local_rank=rank, args=args)
\ No newline at end of file
diff --git a/benchmarks/all_to_all.py b/benchmarks/all_to_all.py
index 3317d1e4..1c7def55 100644
--- a/benchmarks/all_to_all.py
+++ b/benchmarks/all_to_all.py
@@ -28,10 +28,7 @@
 
 
 def timed_all_to_all(input, output, start_event, end_event, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()
 
     sync_all()
     # Warmups, establish connections, etc.
@@ -62,10 +59,7 @@ def timed_all_to_all(input, output, start_event, end_event, args):
 
 
 def run_all_to_all(local_rank, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()
 
     world_size = dist.get_world_size()
     global_rank = dist.get_rank()
@@ -147,5 +141,5 @@ def run_all_to_all(local_rank, args):
 if __name__ == "__main__":
     args = benchmark_parser().parse_args()
     rank = args.local_rank
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
     run_all_to_all(local_rank=rank, args=args)
\ No newline at end of file
diff --git a/benchmarks/broadcast.py b/benchmarks/broadcast.py
index f7ca6806..0e3ba250 100644
--- a/benchmarks/broadcast.py
+++ b/benchmarks/broadcast.py
@@ -28,10 +28,7 @@
 
 
 def timed_broadcast(input, start_event, end_event, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()
 
     sync_all()
     # Warmups, establish connections, etc.
@@ -62,10 +59,7 @@ def timed_broadcast(input, start_event, end_event, args):
 
 
 def run_broadcast(local_rank, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()
 
     # Prepare benchmark header
     print_header(args, 'broadcast')
@@ -125,5 +119,5 @@ def run_broadcast(local_rank, args):
 if __name__ == "__main__":
     args = benchmark_parser().parse_args()
     rank = args.local_rank
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
     run_broadcast(local_rank=rank, args=args)
\ No newline at end of file
diff --git a/benchmarks/pt2pt.py b/benchmarks/pt2pt.py
index 9b51a8b0..5a680d92 100644
--- a/benchmarks/pt2pt.py
+++ b/benchmarks/pt2pt.py
@@ -28,10 +28,7 @@
 
 
 def timed_pt2pt(input, start_event, end_event, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()
 
     sync_all()
     # Warmups, establish connections, etc.
@@ -81,10 +78,7 @@ def timed_pt2pt(input, start_event, end_event, args):
 
 
 def run_pt2pt(local_rank, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()
 
     # Prepare benchmark header
     print_header(args, 'pt2pt')
@@ -144,5 +138,5 @@ def run_pt2pt(local_rank, args):
 if __name__ == "__main__":
     args = benchmark_parser().parse_args()
     rank = args.local_rank
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
     run_pt2pt(local_rank=rank, args=args)
\ No newline at end of file
diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py
index f5695310..ef6adbd6 100644
--- a/benchmarks/run_all.py
+++ b/benchmarks/run_all.py
@@ -33,7 +33,7 @@
 # For importing
 def main(args, rank):
 
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
 
     ops_to_run = []
     if args.all_reduce:
diff --git a/benchmarks/utils.py b/benchmarks/utils.py
index aca7bafe..35cf33ea 100644
--- a/benchmarks/utils.py
+++ b/benchmarks/utils.py
@@ -28,9 +28,7 @@
 from mcr_dl.cuda_accelerator import get_accelerator
 from mcr_dl.comm import mpi_discovery
 from mcr_dl.utils import set_mpi_dist_environemnt
-
-global dist
-
+import mcr_dl
 
 def env2int(env_list, default=-1):
     for e in env_list:
@@ -39,41 +37,14 @@ def env2int(env_list, default=-1):
     return default
 
 
-def init_torch_distributed(backend):
-    global dist
-    import torch.distributed as dist
-    if backend == 'nccl':
-        mpi_discovery()
-    elif backend == 'mpi':
-        set_mpi_dist_environemnt()
-    dist.init_process_group(backend)
-    local_rank = int(os.environ['LOCAL_RANK'])
-    get_accelerator().set_device(local_rank)
-
-def init_mcr_dl_comm(backend):
-    global dist
-    import mcr_dl as dist
-    dist.init_distributed(dist_backend=backend, use_mcr_dl=True)
-    local_rank = int(os.environ['LOCAL_RANK'])
-    get_accelerator().set_device(local_rank)
-
-
-def init_processes(local_rank, args):
-    if args.dist == 'mcr_dl':
-        init_mcr_dl_comm(args.backend)
-    elif args.dist == 'torch':
-        init_torch_distributed(args.backend)
-    else:
-        print_rank_0(f"distributed framework {args.dist} not supported")
-        exit(0)
-
-
 def print_rank_0(message):
+    dist = mcr_dl.get_distributed_engine()
     if dist.get_rank() == 0:
         print(message)
 
 
 def print_header(args, comm_op):
+    dist = mcr_dl.get_distributed_engine()
     if comm_op == 'pt2pt':
         world_size = 2
     else:
@@ -90,6 +61,7 @@ def print_header(args, comm_op):
 
 
 def get_bw(comm_op, size, duration, args):
+    dist = mcr_dl.get_distributed_engine()
     n = dist.get_world_size()
     tput = 0
     busbw = 0
@@ -133,11 +105,13 @@ def get_metric_strings(args, tput, busbw, duration):
 
 
 def sync_all():
+    dist = mcr_dl.get_distributed_engine()
     get_accelerator().synchronize()
     dist.barrier()
 
 
 def max_numel(comm_op, dtype, mem_factor, local_rank, args):
+    dist = mcr_dl.get_distributed_engine()
     dtype_size = _element_size(dtype)
     max_memory_per_gpu = get_accelerator().total_memory(local_rank) * mem_factor
     if comm_op == 'all_reduce' or comm_op == 'pt2pt' or comm_op == 'broadcast':
diff --git a/mcr_dl/__init__.py b/mcr_dl/__init__.py
index c276061f..33c5929e 100644
--- a/mcr_dl/__init__.py
+++ b/mcr_dl/__init__.py
@@ -17,4 +17,62 @@
 # limitations under the License.
 
 from .utils import *
-from .comm import *
\ No newline at end of file
+from .comm import *
+import os
+import torch.distributed
+
+__dist_engine = None
+__dist_backend = None
+
+
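+# Initialize torch.distributed as the distributed engine for the given backend.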
+def init_torch_distributed(backend):
+    import torch.distributed as dist
+    if backend == 'nccl':
+        mpi_discovery()
+    elif backend == 'mpi':
+        set_mpi_dist_environemnt()
+    dist.init_process_group(backend=backend)
+    local_rank = int(os.environ['LOCAL_RANK'])
+    # get_accelerator().set_device(local_rank)
+    print(f'Rank : {dist.get_rank()}  World_Size : {dist.get_world_size()}', flush=True)
+
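+# Initialize MCR-DL's native communication stack via mcr_dl.init_distributed().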
+def init_mcr_dl_comm(backend):
+    import mcr_dl
+    mcr_dl.init_distributed(dist_backend=backend, use_mcr_dl=True)
+    local_rank = int(os.environ['LOCAL_RANK'])
+    # get_accelerator().set_device(local_rank)
+
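+# Select and initialize the distributed engine ('torch' or 'mcr_dl') with the given communication backend.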
+def init_processes(dist_engine, dist_backend, world_size=-1, rank=-1, timeout=None, init_method=None):
+    print(f'Comm : {dist_engine}  Backend : {dist_backend}')
+
+    global __dist_engine
+    global __dist_backend
+    __dist_engine = dist_engine
+    __dist_backend = dist_backend
+    if dist_engine == 'mcr_dl':
+        init_mcr_dl_comm(dist_backend)
+    elif dist_engine == 'torch':
+        init_torch_distributed(dist_backend)
+    else:
+        print(f"distributed framework {dist_engine} not supported")
+        exit(0)
+
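+# Return the active engine module (torch.distributed or mcr_dl); None if init_processes() was never called.
+# Typical usage (sketch; assumes the launcher sets LOCAL_RANK and the MPI/torch env needed by the backend):
+#     mcr_dl.init_processes('torch', 'nccl')
+#     dist = mcr_dl.get_distributed_engine()
+#     dist.barrier()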
+def get_distributed_engine():
+    global __dist_engine
+    if __dist_engine is None:
+        return None
+    if __dist_engine == 'torch':
+        return torch.distributed
+    elif __dist_engine == 'mcr_dl':
+        import mcr_dl
+        return mcr_dl
+    print(f"Unsupported values for __dist_engine. Expected values 'torch' or 'mcr_dl'")
+    exit(0)
\ No newline at end of file
diff --git a/mcr_dl/mpi.py b/mcr_dl/mpi.py
index 8bbd14c6..22cad991 100644
--- a/mcr_dl/mpi.py
+++ b/mcr_dl/mpi.py
@@ -72,12 +72,13 @@ def destroy_process_group(self, group=None):
         pass
 
     def new_group(self, ranks):
-        # TODO: Change this to use comm_op.new_group when the impl. is ready.
+        # TODO: Change this to use self.mpi_comm_op.new_group(ranks) when the impl. is ready.
         if not torch.distributed.is_initialized():
             from mcr_dl.torch import TorchBackend
-            d = TorchBackend(rank=self.rank, size=self.size)
+            d = TorchBackend(rank=self.rank, world_size=self.size)  # side effect: initializes torch.distributed
         logger.info(f"new group called with {ranks}")
         return torch.distributed.new_group(ranks)
+        # return self.mpi_comm_op.new_group(ranks)
 
     def get_rank(self, group=None):
         return self.mpi_comm_op.get_rank(0)
diff --git a/mcr_dl/torch.py b/mcr_dl/torch.py
index f3ab8c09..e93079a3 100644
--- a/mcr_dl/torch.py
+++ b/mcr_dl/torch.py
@@ -23,6 +23,7 @@
 from .utils import *
 from .backend import *
 from .comm import *
+from .constants import default_pg_timeout
 
 DS_COMM_ALL_GATHER_OFF = False
 DS_COMM_REDUCE_SCATTER_OFF = False
@@ -119,7 +120,7 @@ class TorchBackend(Backend):
         needed.
     """
 
-    def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
+    def __init__(self, backend="mpi", init_method = None, timeout = default_pg_timeout, rank=-1, world_size=-1, name='torch'):
         super(TorchBackend, self).__init__()
         self.has_all_reduce_coalesced = has_all_reduce_coalesced()
         self.has_coalescing_manager = has_coalescing_manager()
@@ -131,7 +132,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
         # The idea is to fake that dist backend is initialized even when
         # it is not so we can run on a single GPU without doing any init_process_group
         self.single_gpu_mode = True
-        self.init_process_group(backend, timeout, init_method, rank, world_size)
+        self.init_process_group(backend=backend, init_method=init_method, timeout=timeout, rank=rank, world_size=world_size)
 
     @classmethod
     def get_all_gather_function(self):
diff --git a/tests/main.py b/tests/main.py
index e22e72ca..76e6e805 100644
--- a/tests/main.py
+++ b/tests/main.py
@@ -20,7 +20,7 @@
 import time
 import argparse
 import torch
-
+import mcr_dl
 from mcr_dl.constants import TORCH_DISTRIBUTED_DEFAULT_PORT
 from common import set_accelerator_visible
 from mcr_dl.cuda_accelerator import get_accelerator
@@ -32,38 +32,8 @@
 parser.add_argument("--dist", choices=['mcr_dl', 'torch'], help = "torch.distributed or mcr-dl for distributed")
 args = parser.parse_args()
 
-def init_torch_distributed(backend):
-    global dist
-    import torch.distributed as dist
-    if backend == 'nccl':
-        mpi_discovery()
-    elif backend == 'mpi':
-        set_mpi_dist_environemnt()
-    dist.init_process_group(backend)
-    local_rank = int(os.environ['LOCAL_RANK'])
-    get_accelerator().set_device(local_rank)
-
-
-def init_mcr_dl_comm(backend):
-    global dist
-    import mcr_dl
-    import mcr_dl as dist
-    mcr_dl.init_distributed(dist_backend=backend, use_mcr_dl=True)
-    local_rank = int(os.environ['LOCAL_RANK'])
-    #get_accelerator().set_device(local_rank)
-
-
-def init_processes(args_dist, args_backend):
-    print(f'Comm : {args_dist}  Backend : {args_backend}')
-    if args_dist == 'mcr_dl':
-        init_mcr_dl_comm(args_backend)
-    elif args_dist == 'torch':
-        init_torch_distributed(args_backend)
-    else:
-        print(f"distributed framework {args_dist} not supported")
-        exit(0)
-
 def all_reduce():
+    dist = mcr_dl.get_distributed_engine()
     x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
     sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
     result = torch.ones(1, 3).to(get_accelerator().device_name()) * sum_of_ranks
@@ -71,6 +41,7 @@ def all_reduce():
     assert torch.all(x == result)
 
 def all_reduce_benchmark():
+    dist = mcr_dl.get_distributed_engine()
     start_events = [torch.cuda.Event(enable_timing=True) for _ in range(2, 30)]
     end_events = [torch.cuda.Event(enable_timing=True) for _ in range(2, 30)]
     rank = dist.get_rank()
@@ -101,7 +72,7 @@ def all_reduce_benchmark():
 
 if __name__ == "__main__":
     set_accelerator_visible()
-    init_processes(args_dist=args.dist, args_backend=args.backend)
+    mcr_dl.init_processes(dist_engine=args.dist, dist_backend=args.backend)
     # all_reduce()
     all_reduce_benchmark()