diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ba460fb0..4de6df86 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -43,4 +43,4 @@ jobs: - name: Install pylops-mpi run: pip install . - name: Testing using pytest-mpi - run: mpiexec -n ${{ matrix.rank }} pytest --with-mpi + run: mpiexec -n ${{ matrix.rank }} pytest tests/ --with-mpi diff --git a/Makefile b/Makefile index aa62520c..33065808 100644 --- a/Makefile +++ b/Makefile @@ -36,6 +36,10 @@ lint: tests: mpiexec -n $(NUM_PROCESSES) pytest tests/ --with-mpi +# assuming NUM_PRCESS <= number of gpus available +tests_nccl: + mpiexec -n $(NUM_PROCESSES) pytest tests_nccl/ --with-mpi + doc: cd docs && rm -rf source/api/generated && rm -rf source/gallery &&\ rm -rf source/tutorials && rm -rf build &&\ diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 431ddb6e..345cbde1 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -1,12 +1,25 @@ -import numpy as np -from typing import Optional, Union, Tuple, List -from numbers import Integral -from mpi4py import MPI from enum import Enum +from numbers import Integral +from typing import Any, List, Optional, Tuple, Union, NewType +import numpy as np +from mpi4py import MPI from pylops.utils import DTypeLike, NDArray +from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils from pylops.utils._internal import _value_or_sized_to_tuple -from pylops.utils.backend import get_module, get_array_module, get_module_name +from pylops.utils.backend import get_array_module, get_module, get_module_name +from pylops_mpi.utils import deps + +cupy_message = pylops_deps.cupy_import("the DistributedArray module") +nccl_message = deps.nccl_import("the DistributedArray module") + +if nccl_message is None and cupy_message is None: + from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split + from cupy.cuda.nccl import NcclCommunicator +else: + NcclCommunicator = Any + +NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator) class Partition(Enum): @@ -47,8 +60,8 @@ def local_split(global_shape: Tuple, base_comm: MPI.Comm, """ if partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]: local_shape = global_shape - # Split the array else: + # Split the array local_shape = list(global_shape) if base_comm.Get_rank() < (global_shape[axis] % base_comm.Get_size()): local_shape[axis] = global_shape[axis] // base_comm.Get_size() + 1 @@ -57,6 +70,35 @@ def local_split(global_shape: Tuple, base_comm: MPI.Comm, return tuple(local_shape) +def subcomm_split(mask, comm: Optional[Union[MPI.Comm, NcclCommunicatorType]] = MPI.COMM_WORLD): + """Create new communicators based on mask + + This method creates new communicators based on ``mask``. + + Parameters + ---------- + mask : :obj:`list` + Mask defining subsets of ranks to consider when performing 'global' + operations on the distributed array such as dot product or norm. + + comm : :obj:`mpi4py.MPI.Comm` or `cupy.cuda.nccl.NcclCommunicator`, optional + A Communicator over which array is distributed + Defaults to ``mpi4py.MPI.COMM_WORLD``. 
+ + Returns: + ------- + sub_comm : :obj:`mpi4py.MPI.Comm` or :obj:`cupy.cuda.nccl.NcclCommunicator` + Subcommunicator according to mask + """ + # NcclCommunicatorType cannot be used with isinstance() so check the negate of MPI.Comm + if deps.nccl_enabled and not isinstance(comm, MPI.Comm): + sub_comm = nccl_split(mask) + else: + rank = comm.Get_rank() + sub_comm = comm.Split(color=mask[rank], key=rank) + return sub_comm + + class DistributedArray: r"""Distributed Numpy Arrays @@ -81,6 +123,9 @@ class DistributedArray: base_comm : :obj:`mpi4py.MPI.Comm`, optional MPI Communicator over which array is distributed. Defaults to ``mpi4py.MPI.COMM_WORLD``. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional + NCCL Communicator over which array is distributed. Whenever NCCL + Communicator is provided, the base_comm will be set to MPI.COMM_WORLD. partition : :obj:`Partition`, optional Broadcast, UnsafeBroadcast, or Scatter the array. Defaults to ``Partition.SCATTER``. axis : :obj:`int`, optional @@ -98,6 +143,7 @@ class DistributedArray: def __init__(self, global_shape: Union[Tuple, Integral], base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD, + base_comm_nccl: Optional[NcclCommunicatorType] = None, partition: Partition = Partition.SCATTER, axis: int = 0, local_shapes: Optional[List[Union[Tuple, Integral]]] = None, mask: Optional[List[Integral]] = None, @@ -111,18 +157,25 @@ def __init__(self, global_shape: Union[Tuple, Integral], if partition not in Partition: raise ValueError(f"Should be either {Partition.BROADCAST}, " f"{Partition.UNSAFE_BROADCAST} or {Partition.SCATTER}") + if base_comm_nccl and engine != "cupy": + raise ValueError("NCCL Communicator only works with engine `cupy`") + self.dtype = dtype self._global_shape = _value_or_sized_to_tuple(global_shape) - self._base_comm = base_comm + self._base_comm_nccl = base_comm_nccl + if base_comm_nccl is None: + self._base_comm = base_comm + else: + self._base_comm = MPI.COMM_WORLD self._partition = partition self._axis = axis self._mask = mask - self._sub_comm = base_comm if mask is None else base_comm.Split(color=mask[base_comm.rank], key=base_comm.rank) - + self._sub_comm = (base_comm if base_comm_nccl is None else base_comm_nccl) if mask is None else subcomm_split(mask, (base_comm if base_comm_nccl is None else base_comm_nccl)) local_shapes = local_shapes if local_shapes is None else [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes] self._check_local_shapes(local_shapes) - self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm, - partition, axis) + self._local_shape = local_shapes[self.rank] if local_shapes else local_split(global_shape, base_comm, + partition, axis) + self._engine = engine self._local_array = get_module(engine).empty(shape=self.local_shape, dtype=self.dtype) @@ -150,7 +203,10 @@ def __setitem__(self, index, value): the specified index positions. 
""" if self.partition is Partition.BROADCAST: - self.local_array[index] = self.base_comm.bcast(value) + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + nccl_bcast(self.base_comm_nccl, self.local_array, index, value) + else: + self.local_array[index] = self.base_comm.bcast(value) else: self.local_array[index] = value @@ -174,6 +230,16 @@ def base_comm(self): """ return self._base_comm + @property + def base_comm_nccl(self): + """Base NCCL Communicator + + Returns + ------- + base_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` + """ + return self._base_comm_nccl + @property def local_shape(self): """Local Shape of the Distributed array @@ -222,7 +288,11 @@ def rank(self): ------- rank : :obj:`int` """ - return self.base_comm.Get_rank() + # cp.cuda.Device().id will give local rank + # It works ok in the single-node multi-gpu environment. + # But in multi-node environment, the function will break. + # So we have to use MPI.COMM_WORLD() in both cases of base_comm (MPI and NCCL) + return MPI.COMM_WORLD.Get_rank() @property def size(self): @@ -233,7 +303,7 @@ def size(self): ------- size : :obj:`int` """ - return self.base_comm.Get_size() + return MPI.COMM_WORLD.Get_size() @property def axis(self): @@ -273,7 +343,15 @@ def local_shapes(self): ------- local_shapes : :obj:`list` """ - return self.base_comm.allgather(self.local_shape) + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + # gather tuple of shapes from every rank and copy from GPU to CPU + all_tuples = self._allgather(self.local_shape).get() + # NCCL returns the flat array that packs every tuple as 1-dimensional array + # unpack each tuple from each rank + tuple_len = len(self.local_shape) + return [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)] + else: + return self._allgather(self.local_shape) @property def sub_comm(self): @@ -281,7 +359,7 @@ def sub_comm(self): Returns ------- - sub_comm : :obj:`MPI.Comm` + sub_comm : :obj:`MPI.Comm` or `cupy.cuda.nccl.NcclCommunicator` """ return self._sub_comm @@ -299,13 +377,18 @@ def asarray(self): if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]: # Get only self.local_array. return self.local_array - # Gather all the local arrays and apply concatenation. - final_array = self.base_comm.allgather(self.local_array) - return np.concatenate(final_array, axis=self.axis) + + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + return nccl_asarray(self.base_comm_nccl, self.local_array, self.local_shapes, self.axis) + else: + # Gather all the local arrays and apply concatenation. + final_array = self._allgather(self.local_array) + return np.concatenate(final_array, axis=self.axis) @classmethod def to_dist(cls, x: NDArray, base_comm: MPI.Comm = MPI.COMM_WORLD, + base_comm_nccl: NcclCommunicatorType = None, partition: Partition = Partition.SCATTER, axis: int = 0, local_shapes: Optional[List[Tuple]] = None, @@ -317,7 +400,9 @@ def to_dist(cls, x: NDArray, x : :obj:`numpy.ndarray` Global array. base_comm : :obj:`MPI.Comm`, optional - Type of elements in input array. Defaults to ``MPI.COMM_WORLD`` + MPI base communicator + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional + NCCL base communicator partition : :obj:`Partition`, optional Distributes the array, Defaults to ``Partition.Scatter``. 
axis : :obj:`int`, optional @@ -335,6 +420,7 @@ def to_dist(cls, x: NDArray, """ dist_array = DistributedArray(global_shape=x.shape, base_comm=base_comm, + base_comm_nccl=base_comm_nccl, partition=partition, axis=axis, local_shapes=local_shapes, @@ -345,7 +431,7 @@ def to_dist(cls, x: NDArray, dist_array[:] = x else: slices = [slice(None)] * x.ndim - local_shapes = np.append([0], base_comm.allgather( + local_shapes = np.append([0], dist_array._allgather( dist_array.local_shape[axis])) sum_shapes = np.cumsum(local_shapes) slices[axis] = slice(sum_shapes[dist_array.rank], @@ -356,7 +442,7 @@ def to_dist(cls, x: NDArray, def _check_local_shapes(self, local_shapes): """Check if the local shapes align with the global shape""" if local_shapes: - if len(local_shapes) != self.base_comm.size: + if len(local_shapes) != self.size: raise ValueError(f"Length of local shapes is not equal to number of processes; " f"{len(local_shapes)} != {self.size}") # Check if local shape == global shape @@ -387,22 +473,39 @@ def _check_mask(self, dist_array): raise ValueError("Mask of both the arrays must be same") def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): - """MPI Allreduce operation + """Allreduce operation """ - if recv_buf is None: - return self.base_comm.allreduce(send_buf, op) - # For MIN and MAX which require recv_buf - self.base_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op) + else: + if recv_buf is None: + return self.base_comm.allreduce(send_buf, op) + # For MIN and MAX which require recv_buf + self.base_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): - """MPI Allreduce operation with subcommunicator + """Allreduce operation with subcommunicator """ - if recv_buf is None: - return self.sub_comm.allreduce(send_buf, op) - # For MIN and MAX which require recv_buf - self.sub_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op) + else: + if recv_buf is None: + return self.sub_comm.allreduce(send_buf, op) + # For MIN and MAX which require recv_buf + self.sub_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf + + def _allgather(self, send_buf, recv_buf=None): + """Allgather operation + """ + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf) + else: + if recv_buf is None: + return self.base_comm.allgather(send_buf) + self.base_comm.Allgather(send_buf, recv_buf) + return recv_buf def __neg__(self): arr = DistributedArray(global_shape=self.global_shape, @@ -486,14 +589,14 @@ def dot(self, dist_array): """ self._check_partition_shape(dist_array) self._check_mask(dist_array) - + ncp = get_module(self.engine) # Convert to Partition.SCATTER if Partition.BROADCAST x = DistributedArray.to_dist(x=self.local_array) \ if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self y = DistributedArray.to_dist(x=dist_array.local_array) \ if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array # Flatten the local arrays and calculate dot product - return self._allreduce_subcomm(np.dot(x.local_array.flatten(), y.local_array.flatten())) + return self._allreduce_subcomm(ncp.dot(x.local_array.flatten(), y.local_array.flatten())) def 
_compute_vector_norm(self, local_array: NDArray, axis: int, ord: Optional[int] = None): @@ -509,32 +612,33 @@ def _compute_vector_norm(self, local_array: NDArray, Order of the norm """ # Compute along any axis + ncp = get_module(self.engine) ord = 2 if ord is None else ord if local_array.ndim == 1: - recv_buf = np.empty(shape=1, dtype=np.float64) + recv_buf = ncp.empty(shape=1, dtype=np.float64) else: global_shape = list(self.global_shape) global_shape[axis] = 1 - recv_buf = np.empty(shape=global_shape, dtype=np.float64) + recv_buf = ncp.empty(shape=global_shape, dtype=ncp.float64) if ord in ['fro', 'nuc']: raise ValueError(f"norm-{ord} not possible for vectors") elif ord == 0: # Count non-zero then sum reduction - recv_buf = self._allreduce_subcomm(np.count_nonzero(local_array, axis=axis).astype(np.float64)) - elif ord == np.inf: + recv_buf = self._allreduce_subcomm(ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64)) + elif ord == ncp.inf: # Calculate max followed by max reduction - recv_buf = self._allreduce_subcomm(np.max(np.abs(local_array), axis=axis).astype(np.float64), + recv_buf = self._allreduce_subcomm(ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64), recv_buf, op=MPI.MAX) - recv_buf = np.squeeze(recv_buf, axis=axis) - elif ord == -np.inf: + recv_buf = ncp.squeeze(recv_buf, axis=axis) + elif ord == -ncp.inf: # Calculate min followed by min reduction - recv_buf = self._allreduce_subcomm(np.min(np.abs(local_array), axis=axis).astype(np.float64), + recv_buf = self._allreduce_subcomm(ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64), recv_buf, op=MPI.MIN) - recv_buf = np.squeeze(recv_buf, axis=axis) + recv_buf = ncp.squeeze(recv_buf, axis=axis) else: - recv_buf = self._allreduce_subcomm(np.sum(np.abs(np.float_power(local_array, ord)), axis=axis)) - recv_buf = np.power(recv_buf, 1. 
/ ord) + recv_buf = self._allreduce_subcomm(ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis)) + recv_buf = ncp.power(recv_buf, 1.0 / ord) return recv_buf def zeros_like(self): @@ -648,7 +752,7 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, """ ghosted_array = self.local_array.copy() if cells_front is not None: - total_cells_front = self.base_comm.allgather(cells_front) + [0] + total_cells_front = self._allgather(cells_front) + [0] # Read cells_front which needs to be sent to rank + 1(cells_front for rank + 1) cells_front = total_cells_front[self.rank + 1] if self.rank != 0: @@ -664,7 +768,7 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, self.base_comm.send(np.take(self.local_array, np.arange(-cells_front, 0), axis=self.axis), dest=self.rank + 1, tag=1) if cells_back is not None: - total_cells_back = self.base_comm.allgather(cells_back) + [0] + total_cells_back = self._allgather(cells_back) + [0] # Read cells_back which needs to be sent to rank - 1(cells_back for rank - 1) cells_back = total_cells_back[self.rank - 1] if self.rank != 0: @@ -708,8 +812,8 @@ def __init__(self, distarrays: List, base_comm: MPI.Comm = MPI.COMM_WORLD): self.distarrays = distarrays self.narrays = len(distarrays) self.base_comm = base_comm - self.rank = base_comm.Get_rank() - self.size = base_comm.Get_size() + self.rank = MPI.COMM_WORLD.Get_rank() + self.size = MPI.COMM_WORLD.Get_size() def __getitem__(self, index): return self.distarrays[index] @@ -768,7 +872,6 @@ def __rmul__(self, x): def add(self, stacked_array): """Stacked Distributed Addition of arrays - """ self._check_stacked_size(stacked_array) SumArray = self.copy() diff --git a/pylops_mpi/utils/__init__.py b/pylops_mpi/utils/__init__.py index 1d064202..03204685 100644 --- a/pylops_mpi/utils/__init__.py +++ b/pylops_mpi/utils/__init__.py @@ -1,2 +1,5 @@ # isort: skip_file -from .dottest import * + +# currently dottest create circular dependency with DistributedArray.py +# from .dottest import * +from .deps import * diff --git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py new file mode 100644 index 00000000..c3b02b71 --- /dev/null +++ b/pylops_mpi/utils/_nccl.py @@ -0,0 +1,288 @@ +__all__ = [ + "initialize_nccl_comm", + "nccl_split", + "nccl_allgather", + "nccl_allreduce", + "nccl_bcast", + "nccl_asarray" +] + +from enum import IntEnum +from mpi4py import MPI +import os +import numpy as np +import cupy as cp +import cupy.cuda.nccl as nccl + +cupy_to_nccl_dtype = { + "float32": nccl.NCCL_FLOAT32, + "float64": nccl.NCCL_FLOAT64, + "int32": nccl.NCCL_INT32, + "int64": nccl.NCCL_INT64, + "uint8": nccl.NCCL_UINT8, + "int8": nccl.NCCL_INT8, + "uint32": nccl.NCCL_UINT32, + "uint64": nccl.NCCL_UINT64, +} + + +class NcclOp(IntEnum): + SUM = nccl.NCCL_SUM + PROD = nccl.NCCL_PROD + MAX = nccl.NCCL_MAX + MIN = nccl.NCCL_MIN + + +def mpi_op_to_nccl(mpi_op) -> NcclOp: + """ Map MPI reduction operation to NCCL equivalent + + Parameters + ---------- + mpi_op : :obj:`MPI.Op` + A MPI reduction operation (e.g., MPI.SUM, MPI.PROD, MPI.MAX, MPI.MIN). + + Returns: + ------- + NcclOp : :obj:`IntEnum` + A corresponding NCCL reduction operation. 
+ """ + if mpi_op is MPI.SUM: + return NcclOp.SUM + elif mpi_op is MPI.PROD: + return NcclOp.PROD + elif mpi_op is MPI.MAX: + return NcclOp.MAX + elif mpi_op is MPI.MIN: + return NcclOp.MIN + else: + raise ValueError(f"Unsupported MPI.Op for NCCL: {mpi_op}") + + +def initialize_nccl_comm() -> nccl.NcclCommunicator: + """ Initialize NCCL world communicator for every GPU device + + Each GPU must be managed by exactly one MPI process. + i.e. the number of MPI process launched must be equal to + number of GPUs in communications + + Returns: + ------- + nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` + A corresponding NCCL communicator + """ + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + size = comm.Get_size() + device_id = int( + os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK") + or rank % cp.cuda.runtime.getDeviceCount() + ) + cp.cuda.Device(device_id).use() + + if rank == 0: + with cp.cuda.Device(device_id): + nccl_id_bytes = nccl.get_unique_id() + else: + nccl_id_bytes = None + nccl_id_bytes = comm.bcast(nccl_id_bytes, root=0) + + nccl_comm = nccl.NcclCommunicator(size, nccl_id_bytes, rank) + return nccl_comm + + +def nccl_split(mask) -> nccl.NcclCommunicator: + """ NCCL-equivalent of MPI.Split() + + Splitting the communicator into multiple NCCL subcommunicators + + Parameters + ---------- + mask : :obj:`list` + Mask defining subsets of ranks to consider when performing 'global' + operations on the distributed array such as dot product or norm. + + Returns: + ------- + sub_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` + Subcommunicator according to mask + """ + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + sub_comm = comm.Split(color=mask[rank], key=rank) + + sub_rank = sub_comm.Get_rank() + sub_size = sub_comm.Get_size() + + if sub_rank == 0: + nccl_id_bytes = nccl.get_unique_id() + else: + nccl_id_bytes = None + nccl_id_bytes = sub_comm.bcast(nccl_id_bytes, root=0) + sub_comm = nccl.NcclCommunicator(sub_size, nccl_id_bytes, sub_rank) + + return sub_comm + + +def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray: + """ NCCL equivalent of MPI_Allgather. Gathers data from all GPUs + and distributes the concatenated result to all participants. + + Parameters + ---------- + nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` + The NCCL communicator over which data will be gathered. + send_buf : :obj:`cupy.ndarray` or array-like + The data buffer from the local GPU to be sent. + recv_buf : :obj:`cupy.ndarray`, optional + The buffer to receive data from all GPUs. If None, a new + buffer will be allocated with the appropriate shape. + + Returns + ------- + recv_buf : :obj:`cupy.ndarray` + A buffer containing the gathered data from all GPUs. + """ + send_buf = ( + send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf) + ) + if recv_buf is None: + recv_buf = cp.zeros( + MPI.COMM_WORLD.Get_size() * send_buf.size, + dtype=send_buf.dtype, + ) + nccl_comm.allGather( + send_buf.data.ptr, + recv_buf.data.ptr, + send_buf.size, + cupy_to_nccl_dtype[str(send_buf.dtype)], + cp.cuda.Stream.null.ptr, + ) + return recv_buf + + +def nccl_allreduce(nccl_comm, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM) -> cp.ndarray: + """ NCCL equivalent of MPI_Allreduce. Applies a reduction operation + (e.g., sum, max) across all GPUs and distributes the result. + + Parameters + ---------- + nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` + The NCCL communicator used for collective communication. 
+ send_buf : :obj:`cupy.ndarray` or array-like + The data buffer from the local GPU to be reduced. + recv_buf : :obj:`cupy.ndarray`, optional + The buffer to store the result of the reduction. If None, + a new buffer will be allocated with the appropriate shape. + op : :obj:mpi4py.MPI.Op, optional + The reduction operation to apply. Defaults to MPI.SUM. + + Returns + ------- + recv_buf : :obj:`cupy.ndarray` + A buffer containing the result of the reduction, broadcasted + to all GPUs. + """ + send_buf = ( + send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf) + ) + if recv_buf is None: + recv_buf = cp.zeros(send_buf.size, dtype=send_buf.dtype) + + nccl_comm.allReduce( + send_buf.data.ptr, + recv_buf.data.ptr, + send_buf.size, + cupy_to_nccl_dtype[str(send_buf.dtype)], + mpi_op_to_nccl(op), + cp.cuda.Stream.null.ptr, + ) + return recv_buf + + +def nccl_bcast(nccl_comm, local_array, index, value) -> None: + """ NCCL equivalent of MPI_Bcast. Broadcasts a single value at the given index + from the root GPU (rank 0) to all other GPUs. + + Parameters + ---------- + nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` + The NCCL communicator used for collective communication. + local_array : :obj:`cupy.ndarray` + The local array on each GPU. The value at `index` will be broadcasted. + index : :obj:`int` + The index in the array to be broadcasted. + value : :obj:`scalar` + The value to broadcast (only used by the root GPU, rank 0). + + Returns + ------- + None + """ + if nccl_comm.rank_id() == 0: + local_array[index] = value + nccl_comm.bcast( + local_array[index].data.ptr, + local_array[index].size, + cupy_to_nccl_dtype[str(local_array[index].dtype)], + 0, + cp.cuda.Stream.null.ptr, + ) + + +def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray: + """Global view of the array + + Gather all local GPU arrays into a single global array via NCCL all-gather. + + Parameters + ---------- + nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` + The NCCL communicator used for collective communication. + local_array : :obj:`cupy.ndarray` + The local array on the current GPU. + local_shapes : :obj:`list` + A list of shapes for each GPU local array (used to trim padding). + axis : :obj:`int` + The axis along which to concatenate the gathered arrays. + + Returns + ------- + final_array : :obj:`cupy.ndarray` + Global array gathered from all GPUs and concatenated along `axis`. + + Notes + ----- + NCCL's allGather requires the sending buffer to have the same size for every device. + Therefore, the padding is required when the array is not evenly partitioned across + all the ranks. The padding is applied such that the sending buffer has the size of + each dimension corresponding to the max possible size of that dimension. 
+ """ + sizes_each_dim = list(zip(*local_shapes)) + + send_shape = tuple(map(max, sizes_each_dim)) + pad_size = [ + (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, local_array.shape) + ] + + send_buf = cp.pad( + local_array, pad_size, mode="constant", constant_values=0 + ) + + # NCCL recommends to use one MPI Process per GPU and so size of receiving buffer can be inferred + ndev = len(local_shapes) + recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype) + nccl_allgather(nccl_comm, send_buf, recv_buf) + + # extract an individual array from each device + chunk_size = np.prod(send_shape) + chunks = [ + recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev) + ] + + # Remove padding from each array: the padded value may appear somewhere + # in the middle of the flat array and thus the reshape and slicing for each dimension is required + for i in range(ndev): + slicing = tuple(slice(0, end) for end in local_shapes[i]) + chunks[i] = chunks[i].reshape(send_shape)[slicing] + # combine back to single global array + return cp.concatenate(chunks, axis=axis) diff --git a/pylops_mpi/utils/deps.py b/pylops_mpi/utils/deps.py new file mode 100644 index 00000000..443098e5 --- /dev/null +++ b/pylops_mpi/utils/deps.py @@ -0,0 +1,45 @@ +__all__ = [ + "nccl_enabled" +] + +import os +from importlib import import_module, util +from typing import Optional + + +# error message at import of available package +def nccl_import(message: Optional[str] = None) -> str: + nccl_test = ( + # detect if nccl is available and the user is expecting it to be used + # CuPy must be checked first otherwise util.find_spec assumes it presents and check nccl immediately and lead to crash + util.find_spec("cupy") is not None and util.find_spec("cupy.cuda.nccl") is not None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1 + ) + if nccl_test: + # try importing it + try: + import_module("cupy.cuda.nccl") # noqa: F401 + + # if succesful, set the message to None + nccl_message = None + # if unable to import but the package is installed + except (ImportError, ModuleNotFoundError) as e: + nccl_message = ( + f"Fail to import cupy.cuda.nccl, Falling back to pure MPI (error: {e})." + "Please ensure your CUDA NCCL environment is set up correctly " + "for more detials visit 'https://docs.cupy.dev/en/stable/install.html'" + ) + print(UserWarning(nccl_message)) + else: + nccl_message = ( + "cupy.cuda.nccl package not installed or os.getenv('NCCL_PYLOPS_MPI') == 0. 
" + f"In order to be able to use {message} " + "ensure 'os.getenv('NCCL_PYLOPS_MPI') == 1'" + "for more details for installing NCCL visit 'https://docs.cupy.dev/en/stable/install.html'" + ) + + return nccl_message + + +nccl_enabled: bool = ( + True if (nccl_import() is None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1) else False +) diff --git a/setup.cfg b/setup.cfg index 0411c6c7..c5bee689 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,9 +1,9 @@ [tool:pytest] addopts = --verbose -python_files = tests/*.py +python_files = tests/*.py tests_nccl/*.py [flake8] ignore = E203, E501, W503, E402 per-file-ignores = - __init__.py: F401, F403, F405 -max-line-length = 88 \ No newline at end of file +__init__.py: F401, F403, F405 +max-line-length = 88 diff --git a/tests/test_distributedarray.py b/tests/test_distributedarray.py index 8ee47a85..9eaea0f4 100644 --- a/tests/test_distributedarray.py +++ b/tests/test_distributedarray.py @@ -201,7 +201,7 @@ def test_distributed_maskeddot(par1, par2): """Test Distributed Dot product with masked array""" # number of subcommunicators if MPI.COMM_WORLD.Get_size() % 2 == 0: - nsub = 2 + nsub = 2 elif MPI.COMM_WORLD.Get_size() % 3 == 0: nsub = 3 else: @@ -236,7 +236,7 @@ def test_distributed_maskednorm(par): """Test Distributed numpy.linalg.norm method with masked array""" # number of subcommunicators if MPI.COMM_WORLD.Get_size() % 2 == 0: - nsub = 2 + nsub = 2 elif MPI.COMM_WORLD.Get_size() % 3 == 0: nsub = 3 else: diff --git a/tests_nccl/test_distributedarray_nccl.py b/tests_nccl/test_distributedarray_nccl.py new file mode 100644 index 00000000..3478c8a8 --- /dev/null +++ b/tests_nccl/test_distributedarray_nccl.py @@ -0,0 +1,408 @@ +"""Test the DistributedArray class +Designed to run with n GPUs (with 1 MPI process per GPU) +$ mpiexec -n 3 pytest test_distributedarray_nccl.py --with-mpi + +This file employs the same test sets as test_distributedarray under NCCL environment +""" + +import numpy as np +import cupy as cp +from mpi4py import MPI +import pytest +from numpy.testing import assert_allclose + +from pylops_mpi import DistributedArray, Partition +from pylops_mpi.DistributedArray import local_split +from pylops_mpi.utils._nccl import initialize_nccl_comm + +np.random.seed(42) + +nccl_comm = initialize_nccl_comm() + +par1 = { + "global_shape": (500, 501), + "partition": Partition.SCATTER, + "dtype": np.float64, + "axis": 1, +} + +par2 = { + "global_shape": (500, 501), + "partition": Partition.BROADCAST, + "dtype": np.float64, + "axis": 1, +} + +par3 = { + "global_shape": (200, 201, 101), + "partition": Partition.SCATTER, + "dtype": np.float64, + "axis": 1, +} + +par4 = { + "x": np.random.normal(100, 100, (500, 501)), + "partition": Partition.SCATTER, + "axis": 1, +} + +par5 = { + "x": np.random.normal(300, 300, (500, 501)), + "partition": Partition.SCATTER, + "axis": 1, +} + +par6 = { + "x": np.random.normal(100, 100, (600, 600)), + "partition": Partition.SCATTER, + "axis": 0, +} + +par6b = { + "x": np.random.normal(100, 100, (600, 600)), + "partition": Partition.BROADCAST, + "axis": 0, +} + +par7 = { + "x": np.random.normal(300, 300, (600, 600)), + "partition": Partition.SCATTER, + "axis": 0, +} + +par7b = { + "x": np.random.normal(300, 300, (600, 600)), + "partition": Partition.BROADCAST, + "axis": 0, +} + +par8 = { + "x": np.random.normal(100, 100, (1200,)), + "partition": Partition.SCATTER, + "axis": 0, +} + +par8b = { + "x": np.random.normal(100, 100, (1200,)), + "partition": Partition.BROADCAST, + "axis": 0, +} + +par9 = { + "x": np.random.normal(300, 
300, (1200,)), + "partition": Partition.SCATTER, + "axis": 0, +} + +par9b = { + "x": np.random.normal(300, 300, (1200,)), + "partition": Partition.BROADCAST, + "axis": 0, +} + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par1), (par2), (par3)]) +def test_creation_nccl(par): + """Test creation of local arrays""" + distributed_array = DistributedArray( + global_shape=par["global_shape"], + base_comm_nccl=nccl_comm, + partition=par["partition"], + dtype=par["dtype"], + axis=par["axis"], + engine="cupy", + ) + loc_shape = local_split( + distributed_array.global_shape, + distributed_array.base_comm, + distributed_array.partition, + distributed_array.axis, + ) + assert distributed_array.global_shape == par["global_shape"] + assert distributed_array.local_shape == loc_shape + assert isinstance(distributed_array, DistributedArray) + # Distributed array of ones + distributed_ones = DistributedArray( + global_shape=par["global_shape"], + base_comm_nccl=nccl_comm, + partition=par["partition"], + dtype=par["dtype"], + axis=par["axis"], + engine="cupy", + ) + distributed_ones[:] = 1 + # Distributed array of zeroes + distributed_zeroes = DistributedArray( + global_shape=par["global_shape"], + base_comm_nccl=nccl_comm, + partition=par["partition"], + dtype=par["dtype"], + axis=par["axis"], + engine="cupy", + ) + distributed_zeroes[:] = 0 + # Test for distributed ones + assert isinstance(distributed_ones, DistributedArray) + assert_allclose( + distributed_ones.local_array.get(), + np.ones(shape=distributed_ones.local_shape, dtype=par["dtype"]), + rtol=1e-14, + ) + assert_allclose( + distributed_ones.asarray().get(), + np.ones(shape=distributed_ones.global_shape, dtype=par["dtype"]), + rtol=1e-14, + ) + # Test for distributed zeroes + assert isinstance(distributed_zeroes, DistributedArray) + assert_allclose( + distributed_zeroes.local_array.get(), + np.zeros(shape=distributed_zeroes.local_shape, dtype=par["dtype"]), + rtol=1e-14, + ) + assert_allclose( + distributed_zeroes.asarray().get(), + np.zeros(shape=distributed_zeroes.global_shape, dtype=par["dtype"]), + rtol=1e-14, + ) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par4), (par5)]) +def test_to_dist_nccl(par): + """Test the ``to_dist`` method""" + x_gpu = cp.asarray(par["x"]) + dist_array = DistributedArray.to_dist( + x=x_gpu, + base_comm_nccl=nccl_comm, + partition=par["partition"], + axis=par["axis"], + ) + assert isinstance(dist_array, DistributedArray) + assert dist_array.global_shape == par["x"].shape + assert dist_array.axis == par["axis"] + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par", [(par1), (par2), (par3)]) +def test_local_shapes_nccl(par): + """Test the `local_shapes` parameter in DistributedArray""" + # Reverse the local_shapes to test the local_shapes parameter + loc_shapes = MPI.COMM_WORLD.allgather( + local_split(par["global_shape"], MPI.COMM_WORLD, par["partition"], par["axis"]) + )[::-1] + distributed_array = DistributedArray( + global_shape=par["global_shape"], + base_comm_nccl=nccl_comm, + partition=par["partition"], + axis=par["axis"], + local_shapes=loc_shapes, + dtype=par["dtype"], + engine="cupy", + ) + assert isinstance(distributed_array, DistributedArray) + assert distributed_array.local_shape == loc_shapes[distributed_array.rank] + + # Distributed ones + distributed_array[:] = 1 + assert_allclose( + distributed_array.local_array.get(), + np.ones(loc_shapes[distributed_array.rank], dtype=par["dtype"]), + rtol=1e-14, + ) + assert_allclose( + 
distributed_array.asarray().get(), + np.ones(par["global_shape"], dtype=par["dtype"]), + rtol=1e-14, + ) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize("par1, par2", [(par4, par5)]) +def test_distributed_math_nccl(par1, par2): + """Test the Element-Wise Addition, Subtraction and Multiplication""" + x1_gpu = cp.asarray(par1["x"]) + x2_gpu = cp.asarray(par2["x"]) + arr1 = DistributedArray.to_dist( + x=x1_gpu, base_comm_nccl=nccl_comm, partition=par1["partition"] + ) + arr2 = DistributedArray.to_dist( + x=x2_gpu, base_comm_nccl=nccl_comm, partition=par2["partition"] + ) + + # Addition + sum_array = arr1 + arr2 + assert isinstance(sum_array, DistributedArray) + # Subtraction + sub_array = arr1 - arr2 + assert isinstance(sub_array, DistributedArray) + # Multiplication + mult_array = arr1 * arr2 + assert isinstance(mult_array, DistributedArray) + # Global array of Sum with np.add + + assert_allclose(sum_array.asarray().get(), np.add(par1["x"], par2["x"]), rtol=1e-14) + # Global array of Subtract with np.subtract + assert_allclose( + sub_array.asarray().get(), np.subtract(par1["x"], par2["x"]), rtol=1e-14 + ) + # Global array of Multiplication with np.multiply + assert_allclose( + mult_array.asarray().get(), np.multiply(par1["x"], par2["x"]), rtol=1e-14 + ) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize( + "par1, par2", [(par6, par7), (par6b, par7b), (par8, par9), (par8b, par9b)] +) +def test_distributed_dot_nccl(par1, par2): + """Test Distributed Dot product""" + x1_gpu = cp.asarray(par1["x"]) + x2_gpu = cp.asarray(par2["x"]) + arr1 = DistributedArray.to_dist( + x=x1_gpu, base_comm_nccl=nccl_comm, partition=par1["partition"], axis=par1["axis"] + ) + arr2 = DistributedArray.to_dist( + x=x2_gpu, base_comm_nccl=nccl_comm, partition=par2["partition"], axis=par2["axis"] + ) + assert_allclose( + (arr1.dot(arr2)).get(), + np.dot(par1["x"].flatten(), par2["x"].flatten()), + rtol=1e-14, + ) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize( + "par", + [ + (par4), + (par5), + (par6), + (par6b), + (par7), + (par7b), + (par8), + (par8b), + (par9), + (par9b), + ], +) +def test_distributed_norm_nccl(par): + """Test Distributed numpy.linalg.norm method""" + x_gpu = cp.asarray(par["x"]) + arr = DistributedArray.to_dist(x=x_gpu, base_comm_nccl=nccl_comm, axis=par["axis"]) + assert_allclose( + arr.norm(ord=1, axis=par["axis"]).get(), + np.linalg.norm(par["x"], ord=1, axis=par["axis"]), + rtol=1e-14, + ) + assert_allclose( + arr.norm(ord=np.inf, axis=par["axis"]).get(), + np.linalg.norm(par["x"], ord=np.inf, axis=par["axis"]), + rtol=1e-14, + ) + assert_allclose(arr.norm().get(), np.linalg.norm(par["x"].flatten()), rtol=1e-13) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize( + "par1, par2", [(par6, par7), (par6b, par7b), (par8, par9), (par8b, par9b)] +) +def test_distributed_maskeddot_nccl(par1, par2): + """Test Distributed Dot product with masked array""" + # number of subcommunicators + if MPI.COMM_WORLD.Get_size() % 2 == 0: + nsub = 2 + elif MPI.COMM_WORLD.Get_size() % 3 == 0: + nsub = 3 + else: + pass + subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub) + mask = np.repeat(np.arange(nsub), subsize) + # Replicate x1 and x2 as required in masked arrays + x1, x2 = par1["x"], par2["x"] + if par1["axis"] != 0: + x1 = np.swapaxes(x1, par1["axis"], 0) + for isub in range(1, nsub): + x1[(x1.shape[0] // nsub) * isub : (x1.shape[0] // nsub) * (isub + 1)] = x1[ + : x1.shape[0] // nsub + ] + if par1["axis"] != 0: + x1 = np.swapaxes(x1, 0, par1["axis"]) + if par2["axis"] != 
0: + x2 = np.swapaxes(x2, par2["axis"], 0) + for isub in range(1, nsub): + x2[(x2.shape[0] // nsub) * isub : (x2.shape[0] // nsub) * (isub + 1)] = x2[ + : x2.shape[0] // nsub + ] + if par2["axis"] != 0: + x2 = np.swapaxes(x2, 0, par2["axis"]) + + x1_gpu, x2_gpu = cp.asarray(x1), cp.asarray(x2) + arr1 = DistributedArray.to_dist( + x=x1_gpu, + base_comm_nccl=nccl_comm, + partition=par1["partition"], + mask=mask, + axis=par1["axis"], + ) + arr2 = DistributedArray.to_dist( + x=x2_gpu, + base_comm_nccl=nccl_comm, + partition=par2["partition"], + mask=mask, + axis=par2["axis"], + ) + assert_allclose( + arr1.dot(arr2).get(), np.dot(x1.flatten(), x2.flatten()) / nsub, rtol=1e-14 + ) + + +@pytest.mark.mpi(min_size=2) +@pytest.mark.parametrize( + "par", [(par6), (par6b), (par7), (par7b), (par8), (par8b), (par9), (par9b)] +) +def test_distributed_maskednorm_nccl(par): + """Test Distributed numpy.linalg.norm method with masked array""" + # number of subcommunicators + if MPI.COMM_WORLD.Get_size() % 2 == 0: + nsub = 2 + elif MPI.COMM_WORLD.Get_size() % 3 == 0: + nsub = 3 + else: + pass + subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub) + mask = np.repeat(np.arange(nsub), subsize) + # Replicate x as required in masked arrays + x = par["x"] + if par["axis"] != 0: + x = np.swapaxes(x, par["axis"], 0) + for isub in range(1, nsub): + x[(x.shape[0] // nsub) * isub : (x.shape[0] // nsub) * (isub + 1)] = x[ + : x.shape[0] // nsub + ] + if par["axis"] != 0: + x = np.swapaxes(x, 0, par["axis"]) + + x_gpu = cp.asarray(x) + arr = DistributedArray.to_dist( + x=x_gpu, base_comm_nccl=nccl_comm, mask=mask, axis=par["axis"] + ) + assert_allclose( + arr.norm(ord=1, axis=par["axis"]).get(), + np.linalg.norm(par["x"], ord=1, axis=par["axis"]) / nsub, + rtol=1e-14, + ) + assert_allclose( + arr.norm(ord=np.inf, axis=par["axis"]).get(), + np.linalg.norm(par["x"], ord=np.inf, axis=par["axis"]), + rtol=1e-14, + ) + assert_allclose( + arr.norm(ord=2, axis=par["axis"]).get(), + np.linalg.norm(par["x"], ord=2, axis=par["axis"]) / np.sqrt(nsub), + rtol=1e-13, + )
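For reference, a minimal usage sketch of the NCCL-backed DistributedArray introduced in this patch, mirroring the calls exercised in tests_nccl/test_distributedarray_nccl.py. The script name, array shape and process count are illustrative; it assumes one MPI process per GPU, at least two GPUs, and CuPy built with NCCL support (NCCL_PYLOPS_MPI=0 disables the NCCL code paths and falls back to pure MPI):

# example_nccl.py (hypothetical script) -- run with, e.g., mpiexec -n 2 python example_nccl.py
import numpy as np
import cupy as cp

from pylops_mpi import DistributedArray, Partition
from pylops_mpi.utils._nccl import initialize_nccl_comm

# one NCCL communicator per MPI process/GPU; rank 0 broadcasts the NCCL unique id over MPI
nccl_comm = initialize_nccl_comm()

# identical global array on every rank, scattered across GPUs along axis 0
np.random.seed(42)
x = cp.asarray(np.random.normal(0., 1., (1000, 10)))
dist_x = DistributedArray.to_dist(x=x, base_comm_nccl=nccl_comm,
                                  partition=Partition.SCATTER, axis=0)

# collective operations (allreduce/allgather) now run over NCCL instead of mpi4py
print(dist_x.dot(dist_x))    # global dot product via nccl_allreduce
print(dist_x.norm())         # global 2-norm
x_global = dist_x.asarray()  # gathered global view via nccl_asarray

The new test suite itself is run with the added Makefile target, make tests_nccl, which assumes NUM_PROCESSES does not exceed the number of available GPUs.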
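A second sketch covers the masked case: when mask is passed together with base_comm_nccl, subcomm_split creates NCCL subcommunicators through nccl_split, so reductions such as dot and norm are restricted to each subset of ranks. The values below are illustrative and assume the number of processes divides evenly by nsub:

# masked_nccl.py (hypothetical script) -- run with one MPI process per GPU
import numpy as np
import cupy as cp
from mpi4py import MPI

from pylops_mpi import DistributedArray
from pylops_mpi.utils._nccl import initialize_nccl_comm

nccl_comm = initialize_nccl_comm()
size = MPI.COMM_WORLD.Get_size()

nsub = 2                                        # number of subcommunicators
mask = np.repeat(np.arange(nsub), size // nsub)  # e.g. [0, 0, 1, 1] for 4 ranks

np.random.seed(42)
x = cp.asarray(np.random.normal(0., 1., (1200,)))
arr = DistributedArray.to_dist(x=x, base_comm_nccl=nccl_comm, mask=mask, axis=0)

# dot and norm reduce over the NCCL subcommunicator of each masked group
print(arr.dot(arr), arr.norm())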