diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 4de6df86..ba460fb0 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -43,4 +43,4 @@ jobs:
       - name: Install pylops-mpi
         run: pip install .
       - name: Testing using pytest-mpi
-        run: mpiexec -n ${{ matrix.rank }} pytest tests/ --with-mpi
+        run: mpiexec -n ${{ matrix.rank }} pytest --with-mpi
diff --git a/Makefile b/Makefile
index 33065808..aa62520c 100644
--- a/Makefile
+++ b/Makefile
@@ -36,10 +36,6 @@ lint:
 tests:
 	mpiexec -n $(NUM_PROCESSES) pytest tests/ --with-mpi
 
-# assuming NUM_PRCESS <= number of gpus available
-tests_nccl:
-	mpiexec -n $(NUM_PROCESSES) pytest tests_nccl/ --with-mpi
-
 doc:
 	cd docs && rm -rf source/api/generated && rm -rf source/gallery &&\
 	rm -rf source/tutorials && rm -rf build &&\
diff --git a/examples/matrixmul.py b/examples/matrixmul.py
new file mode 100644
index 00000000..c34c97f0
--- /dev/null
+++ b/examples/matrixmul.py
@@ -0,0 +1,104 @@
+import sys
+import math
+import numpy as np
+from mpi4py import MPI
+
+from pylops_mpi import DistributedArray, Partition
+from pylops_mpi.basicoperators.MatrixMultiply import SUMMAMatrixMult
+
+np.random.seed(42)
+
+comm = MPI.COMM_WORLD
+rank = comm.Get_rank()
+nProcs = comm.Get_size()
+
+
+P_prime = int(math.ceil(math.sqrt(nProcs)))
+C = int(math.ceil(nProcs / P_prime))
+assert P_prime * C >= nProcs
+
+# matrix dims
+M = 32  # any M
+K = 32  # any K
+N = 35  # any N
+
+blk_rows = int(math.ceil(M / P_prime))
+blk_cols = int(math.ceil(N / P_prime))
+
+my_group = rank % P_prime
+my_layer = rank // P_prime
+
+# sub-communicators
+layer_comm = comm.Split(color=my_layer, key=my_group)  # all procs in same layer
+group_comm = comm.Split(color=my_group, key=my_layer)  # all procs in same group
+
+# Each rank will end up with:
+#   A_p: shape (my_own_rows, K)
+#   B_p: shape (K, my_own_cols)
+# where
+row_start = my_group * blk_rows
+row_end = min(M, row_start + blk_rows)
+my_own_rows = row_end - row_start
+
+col_start = my_group * blk_cols  # note: same my_group index on cols
+col_end = min(N, col_start + blk_cols)
+my_own_cols = col_end - col_start
+
+# ======================= BROADCASTING THE SLICES =======================
+if rank == 0:
+    A = np.arange(M*K, dtype=np.float32).reshape(M, K)
+    B = np.arange(K*N, dtype=np.float32).reshape(K, N)
+    for dest in range(nProcs):
+        pg = dest % P_prime
+        rs = pg*blk_rows; re = min(M, rs+blk_rows)
+        cs = pg*blk_cols; ce = min(N, cs+blk_cols)
+        a_block, b_block = A[rs:re, :].copy(), B[:, cs:ce].copy()
+        if dest == 0:
+            A_p, B_p = a_block, b_block
+        else:
+            comm.Send(a_block, dest=dest, tag=100+dest)
+            comm.Send(b_block, dest=dest, tag=200+dest)
+else:
+    A_p = np.empty((my_own_rows, K), dtype=np.float32)
+    B_p = np.empty((K, my_own_cols), dtype=np.float32)
+    comm.Recv(A_p, source=0, tag=100+rank)
+    comm.Recv(B_p, source=0, tag=200+rank)
+
+comm.Barrier()
+
+Aop = SUMMAMatrixMult(A_p, N)
+col_lens = comm.allgather(my_own_cols)
+total_cols = np.add.reduce(col_lens, 0)
+x = DistributedArray(global_shape=K * total_cols,
+                     local_shapes=[K * col_len for col_len in col_lens],
+                     partition=Partition.SCATTER,
+                     mask=[i % P_prime for i in range(comm.Get_size())],
+                     dtype=np.float32)
+x[:] = B_p.flatten()
+y = Aop @ x
+
+# ======================= VERIFICATION =======================
+A = np.arange(M*K).reshape(M, K).astype(np.float32)
+B = np.arange(K*N).reshape(K, N).astype(np.float32)
+C_true = A @ B
+Z_true = (A.T.dot(C_true.conj())).conj()
+
+
+col_start = my_layer * blk_cols  # note: same my_group index
on cols +col_end = min(N, col_start + blk_cols) +my_own_cols = col_end - col_start +expected_y = C_true[:,col_start:col_end].flatten() + +if not np.allclose(y.local_array, expected_y, atol=1e-6, rtol=1e-14): + print(f"RANK {rank}: FORWARD VERIFICATION FAILED") + print(f'{rank} local: {y.local_array}, expected: {C_true[:,col_start:col_end]}') +else: + print(f"RANK {rank}: FORWARD VERIFICATION PASSED") + +z = Aop.H @ y +expected_z = Z_true[:,col_start:col_end].flatten() +if not np.allclose(z.local_array, expected_z, atol=1e-6, rtol=1e-14): + print(f"RANK {rank}: ADJOINT VERIFICATION FAILED") + print(f'{rank} local: {z.local_array}, expected: {Z_true[:,col_start:col_end]}') +else: + print(f"RANK {rank}: ADJOINT VERIFICATION PASSED") diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 50e54d3b..431ddb6e 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -1,25 +1,12 @@ -from enum import Enum -from numbers import Integral -from typing import Any, List, Optional, Tuple, Union, NewType - import numpy as np +from typing import Optional, Union, Tuple, List +from numbers import Integral from mpi4py import MPI +from enum import Enum + from pylops.utils import DTypeLike, NDArray -from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils from pylops.utils._internal import _value_or_sized_to_tuple -from pylops.utils.backend import get_array_module, get_module, get_module_name -from pylops_mpi.utils import deps - -cupy_message = pylops_deps.cupy_import("the DistributedArray module") -nccl_message = deps.nccl_import("the DistributedArray module") - -if nccl_message is None and cupy_message is None: - from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split - from cupy.cuda.nccl import NcclCommunicator -else: - NcclCommunicator = Any - -NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator) +from pylops.utils.backend import get_module, get_array_module, get_module_name class Partition(Enum): @@ -60,8 +47,8 @@ def local_split(global_shape: Tuple, base_comm: MPI.Comm, """ if partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]: local_shape = global_shape + # Split the array else: - # Split the array local_shape = list(global_shape) if base_comm.Get_rank() < (global_shape[axis] % base_comm.Get_size()): local_shape[axis] = global_shape[axis] // base_comm.Get_size() + 1 @@ -70,35 +57,6 @@ def local_split(global_shape: Tuple, base_comm: MPI.Comm, return tuple(local_shape) -def subcomm_split(mask, comm: Optional[Union[MPI.Comm, NcclCommunicatorType]] = MPI.COMM_WORLD): - """Create new communicators based on mask - - This method creates new communicators based on ``mask``. - - Parameters - ---------- - mask : :obj:`list` - Mask defining subsets of ranks to consider when performing 'global' - operations on the distributed array such as dot product or norm. - - comm : :obj:`mpi4py.MPI.Comm` or `cupy.cuda.nccl.NcclCommunicator`, optional - A Communicator over which array is distributed - Defaults to ``mpi4py.MPI.COMM_WORLD``. 
- - Returns: - ------- - sub_comm : :obj:`mpi4py.MPI.Comm` or :obj:`cupy.cuda.nccl.NcclCommunicator` - Subcommunicator according to mask - """ - # NcclCommunicatorType cannot be used with isinstance() so check the negate of MPI.Comm - if deps.nccl_enabled and not isinstance(comm, MPI.Comm): - sub_comm = nccl_split(mask) - else: - rank = comm.Get_rank() - sub_comm = comm.Split(color=mask[rank], key=rank) - return sub_comm - - class DistributedArray: r"""Distributed Numpy Arrays @@ -123,9 +81,6 @@ class DistributedArray: base_comm : :obj:`mpi4py.MPI.Comm`, optional MPI Communicator over which array is distributed. Defaults to ``mpi4py.MPI.COMM_WORLD``. - base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional - NCCL Communicator over which array is distributed. Whenever NCCL - Communicator is provided, the base_comm will be set to MPI.COMM_WORLD. partition : :obj:`Partition`, optional Broadcast, UnsafeBroadcast, or Scatter the array. Defaults to ``Partition.SCATTER``. axis : :obj:`int`, optional @@ -143,7 +98,6 @@ class DistributedArray: def __init__(self, global_shape: Union[Tuple, Integral], base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD, - base_comm_nccl: Optional[NcclCommunicatorType] = None, partition: Partition = Partition.SCATTER, axis: int = 0, local_shapes: Optional[List[Union[Tuple, Integral]]] = None, mask: Optional[List[Integral]] = None, @@ -157,25 +111,18 @@ def __init__(self, global_shape: Union[Tuple, Integral], if partition not in Partition: raise ValueError(f"Should be either {Partition.BROADCAST}, " f"{Partition.UNSAFE_BROADCAST} or {Partition.SCATTER}") - if base_comm_nccl and engine != "cupy": - raise ValueError("NCCL Communicator only works with engine `cupy`") - self.dtype = dtype self._global_shape = _value_or_sized_to_tuple(global_shape) - self._base_comm_nccl = base_comm_nccl - if base_comm_nccl is None: - self._base_comm = base_comm - else: - self._base_comm = MPI.COMM_WORLD + self._base_comm = base_comm self._partition = partition self._axis = axis self._mask = mask - self._sub_comm = (base_comm if base_comm_nccl is None else base_comm_nccl) if mask is None else subcomm_split(mask, (base_comm if base_comm_nccl is None else base_comm_nccl)) + self._sub_comm = base_comm if mask is None else base_comm.Split(color=mask[base_comm.rank], key=base_comm.rank) + local_shapes = local_shapes if local_shapes is None else [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes] self._check_local_shapes(local_shapes) - self._local_shape = local_shapes[self.rank] if local_shapes else local_split(global_shape, base_comm, - partition, axis) - + self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm, + partition, axis) self._engine = engine self._local_array = get_module(engine).empty(shape=self.local_shape, dtype=self.dtype) @@ -203,10 +150,7 @@ def __setitem__(self, index, value): the specified index positions. 
""" if self.partition is Partition.BROADCAST: - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - nccl_bcast(self.base_comm_nccl, self.local_array, index, value) - else: - self.local_array[index] = self.base_comm.bcast(value) + self.local_array[index] = self.base_comm.bcast(value) else: self.local_array[index] = value @@ -230,16 +174,6 @@ def base_comm(self): """ return self._base_comm - @property - def base_comm_nccl(self): - """Base NCCL Communicator - - Returns - ------- - base_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` - """ - return self._base_comm_nccl - @property def local_shape(self): """Local Shape of the Distributed array @@ -256,7 +190,7 @@ def mask(self): Returns ------- - mask : :obj:`list` + engine : :obj:`list` """ return self._mask @@ -339,15 +273,7 @@ def local_shapes(self): ------- local_shapes : :obj:`list` """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - # gather tuple of shapes from every rank and copy from GPU to CPU - all_tuples = self._allgather(self.local_shape).get() - # NCCL returns the flat array that packs every tuple as 1-dimensional array - # unpack each tuple from each rank - tuple_len = len(self.local_shape) - return [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)] - else: - return self._allgather(self.local_shape) + return self.base_comm.allgather(self.local_shape) @property def sub_comm(self): @@ -355,7 +281,7 @@ def sub_comm(self): Returns ------- - sub_comm : :obj:`MPI.Comm` or `cupy.cuda.nccl.NcclCommunicator` + sub_comm : :obj:`MPI.Comm` """ return self._sub_comm @@ -373,18 +299,13 @@ def asarray(self): if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]: # Get only self.local_array. return self.local_array - - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - return nccl_asarray(self.base_comm_nccl, self.local_array, self.local_shapes, self.axis) - else: - # Gather all the local arrays and apply concatenation. - final_array = self._allgather(self.local_array) - return np.concatenate(final_array, axis=self.axis) + # Gather all the local arrays and apply concatenation. + final_array = self.base_comm.allgather(self.local_array) + return np.concatenate(final_array, axis=self.axis) @classmethod def to_dist(cls, x: NDArray, base_comm: MPI.Comm = MPI.COMM_WORLD, - base_comm_nccl: NcclCommunicatorType = None, partition: Partition = Partition.SCATTER, axis: int = 0, local_shapes: Optional[List[Tuple]] = None, @@ -396,9 +317,7 @@ def to_dist(cls, x: NDArray, x : :obj:`numpy.ndarray` Global array. base_comm : :obj:`MPI.Comm`, optional - MPI base communicator - base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional - NCCL base communicator + Type of elements in input array. Defaults to ``MPI.COMM_WORLD`` partition : :obj:`Partition`, optional Distributes the array, Defaults to ``Partition.Scatter``. 
axis : :obj:`int`, optional @@ -416,7 +335,6 @@ def to_dist(cls, x: NDArray, """ dist_array = DistributedArray(global_shape=x.shape, base_comm=base_comm, - base_comm_nccl=base_comm_nccl, partition=partition, axis=axis, local_shapes=local_shapes, @@ -427,7 +345,7 @@ def to_dist(cls, x: NDArray, dist_array[:] = x else: slices = [slice(None)] * x.ndim - local_shapes = np.append([0], dist_array._allgather( + local_shapes = np.append([0], base_comm.allgather( dist_array.local_shape[axis])) sum_shapes = np.cumsum(local_shapes) slices[axis] = slice(sum_shapes[dist_array.rank], @@ -438,7 +356,7 @@ def to_dist(cls, x: NDArray, def _check_local_shapes(self, local_shapes): """Check if the local shapes align with the global shape""" if local_shapes: - if len(local_shapes) != self.size: + if len(local_shapes) != self.base_comm.size: raise ValueError(f"Length of local shapes is not equal to number of processes; " f"{len(local_shapes)} != {self.size}") # Check if local shape == global shape @@ -469,39 +387,22 @@ def _check_mask(self, dist_array): raise ValueError("Mask of both the arrays must be same") def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): - """Allreduce operation + """MPI Allreduce operation """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op) - else: - if recv_buf is None: - return self.base_comm.allreduce(send_buf, op) - # For MIN and MAX which require recv_buf - self.base_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf + if recv_buf is None: + return self.base_comm.allreduce(send_buf, op) + # For MIN and MAX which require recv_buf + self.base_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): - """Allreduce operation with subcommunicator - """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op) - else: - if recv_buf is None: - return self.sub_comm.allreduce(send_buf, op) - # For MIN and MAX which require recv_buf - self.sub_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf - - def _allgather(self, send_buf, recv_buf=None): - """Allgather operation + """MPI Allreduce operation with subcommunicator """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf) - else: - if recv_buf is None: - return self.base_comm.allgather(send_buf) - self.base_comm.Allgather(send_buf, recv_buf) - return recv_buf + if recv_buf is None: + return self.sub_comm.allreduce(send_buf, op) + # For MIN and MAX which require recv_buf + self.sub_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf def __neg__(self): arr = DistributedArray(global_shape=self.global_shape, @@ -585,14 +486,14 @@ def dot(self, dist_array): """ self._check_partition_shape(dist_array) self._check_mask(dist_array) - ncp = get_module(self.engine) + # Convert to Partition.SCATTER if Partition.BROADCAST x = DistributedArray.to_dist(x=self.local_array) \ if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self y = DistributedArray.to_dist(x=dist_array.local_array) \ if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array # Flatten the local arrays and calculate dot product - return self._allreduce_subcomm(ncp.dot(x.local_array.flatten(), y.local_array.flatten())) + return self._allreduce_subcomm(np.dot(x.local_array.flatten(), y.local_array.flatten())) def 
_compute_vector_norm(self, local_array: NDArray, axis: int, ord: Optional[int] = None): @@ -608,33 +509,32 @@ def _compute_vector_norm(self, local_array: NDArray, Order of the norm """ # Compute along any axis - ncp = get_module(self.engine) ord = 2 if ord is None else ord if local_array.ndim == 1: - recv_buf = ncp.empty(shape=1, dtype=np.float64) + recv_buf = np.empty(shape=1, dtype=np.float64) else: global_shape = list(self.global_shape) global_shape[axis] = 1 - recv_buf = ncp.empty(shape=global_shape, dtype=ncp.float64) + recv_buf = np.empty(shape=global_shape, dtype=np.float64) if ord in ['fro', 'nuc']: raise ValueError(f"norm-{ord} not possible for vectors") elif ord == 0: # Count non-zero then sum reduction - recv_buf = self._allreduce_subcomm(ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64)) - elif ord == ncp.inf: + recv_buf = self._allreduce_subcomm(np.count_nonzero(local_array, axis=axis).astype(np.float64)) + elif ord == np.inf: # Calculate max followed by max reduction - recv_buf = self._allreduce_subcomm(ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64), + recv_buf = self._allreduce_subcomm(np.max(np.abs(local_array), axis=axis).astype(np.float64), recv_buf, op=MPI.MAX) - recv_buf = ncp.squeeze(recv_buf, axis=axis) - elif ord == -ncp.inf: + recv_buf = np.squeeze(recv_buf, axis=axis) + elif ord == -np.inf: # Calculate min followed by min reduction - recv_buf = self._allreduce_subcomm(ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64), + recv_buf = self._allreduce_subcomm(np.min(np.abs(local_array), axis=axis).astype(np.float64), recv_buf, op=MPI.MIN) - recv_buf = ncp.squeeze(recv_buf, axis=axis) + recv_buf = np.squeeze(recv_buf, axis=axis) else: - recv_buf = self._allreduce_subcomm(ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis)) - recv_buf = ncp.power(recv_buf, 1.0 / ord) + recv_buf = self._allreduce_subcomm(np.sum(np.abs(np.float_power(local_array, ord)), axis=axis)) + recv_buf = np.power(recv_buf, 1. / ord) return recv_buf def zeros_like(self): @@ -748,7 +648,7 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, """ ghosted_array = self.local_array.copy() if cells_front is not None: - total_cells_front = self._allgather(cells_front) + [0] + total_cells_front = self.base_comm.allgather(cells_front) + [0] # Read cells_front which needs to be sent to rank + 1(cells_front for rank + 1) cells_front = total_cells_front[self.rank + 1] if self.rank != 0: @@ -764,7 +664,7 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, self.base_comm.send(np.take(self.local_array, np.arange(-cells_front, 0), axis=self.axis), dest=self.rank + 1, tag=1) if cells_back is not None: - total_cells_back = self._allgather(cells_back) + [0] + total_cells_back = self.base_comm.allgather(cells_back) + [0] # Read cells_back which needs to be sent to rank - 1(cells_back for rank - 1) cells_back = total_cells_back[self.rank - 1] if self.rank != 0: @@ -868,6 +768,7 @@ def __rmul__(self, x): def add(self, stacked_array): """Stacked Distributed Addition of arrays + """ self._check_stacked_size(stacked_array) SumArray = self.copy() diff --git a/pylops_mpi/LinearOperator.py b/pylops_mpi/LinearOperator.py index 266e55fe..a7bc9bea 100644 --- a/pylops_mpi/LinearOperator.py +++ b/pylops_mpi/LinearOperator.py @@ -12,27 +12,20 @@ class MPILinearOperator: - """MPI-enabled PyLops Linear Operator + """Common interface for performing matrix-vector products in distributed fashion. 
- Common interface for performing matrix-vector products in distributed fashion. + This class provides methods to perform matrix-vector product and adjoint matrix-vector + products using MPI. - In practice, this class provides methods to perform matrix-vector and - adjoint matrix-vector products between any :obj:`pylops.LinearOperator` - (which must be the same across ranks) and a :class:`pylops_mpi.DistributedArray` - with ``Partition.BROADCAST`` and ``Partition.UNSAFE_BROADCAST`` partition. It - internally handles the extraction of the local array from the distributed array - and the creation of the output :class:`pylops_mpi.DistributedArray`. - - Note that whilst this operator could also be used with different - :obj:`pylops.LinearOperator` across ranks, and with a - :class:`pylops_mpi.DistributedArray` with ``Partition.SCATTER``, it is however - reccomended to use the :class:`pylops_mpi.basicoperators.MPIBlockDiag` operator - instead as this can also handle distributed arrays with subcommunicators. + .. note:: End users of pylops-mpi should not use this class directly but simply + use operators that are already implemented. This class is meant for + developers only, it has to be used as the parent class of any new operator + developed within pylops-mpi. Parameters ---------- Op : :obj:`pylops.LinearOperator`, optional - PyLops Linear Operator to wrap. Defaults to ``None``. + Linear Operator. Defaults to ``None``. shape : :obj:`tuple(int, int)`, optional Shape of the MPI Linear Operator. Defaults to ``None``. dtype : :obj:`str`, optional diff --git a/pylops_mpi/StackedLinearOperator.py b/pylops_mpi/StackedLinearOperator.py index a214d47e..21339549 100644 --- a/pylops_mpi/StackedLinearOperator.py +++ b/pylops_mpi/StackedLinearOperator.py @@ -11,20 +11,16 @@ class MPIStackedLinearOperator(ABC): - """Stack of MPI-enabled PyLops Linear Operators + """Common interface for performing matrix-vector products in distributed fashion + for StackedLinearOperators. - Common interface for performing matrix-vector products in distributed fashion - for a stack of :class:`pylops_mpi.MPILinearOperator` operators. - - In practice, this class provides methods to perform matrix-vector and adjoint - matrix-vector products on a stack of :class:`pylops_mpi.MPILinearOperator` - operators, allowing the actual execution of each operator to be distributed, - whilst dispatching the execution of the different operators in sequential order. + This class provides methods to perform matrix-vector product and adjoint matrix-vector + products on a stack of :class:`pylops_mpi.MPILinearOperator` objects. .. note:: End users of pylops-mpi should not use this class directly but simply - use operators that are already implemented as extensions of this class. - This class is meant for developers only, it has to be used as the parent - class of any new stacked operator developed within pylops-mpi. + use operators that are already implemented. This class is meant for + developers only, it has to be used as the parent class of any new operator + developed within pylops-mpi. 
Parameters ---------- diff --git a/pylops_mpi/basicoperators/MatrixMultiply.py b/pylops_mpi/basicoperators/MatrixMultiply.py new file mode 100644 index 00000000..eb6a6cdc --- /dev/null +++ b/pylops_mpi/basicoperators/MatrixMultiply.py @@ -0,0 +1,142 @@ +import numpy as np +import math +from mpi4py import MPI +from pylops.utils.backend import get_module +from pylops.utils.typing import DTypeLike, NDArray + +from pylops_mpi import ( + DistributedArray, + MPILinearOperator, + Partition +) + + +class SUMMAMatrixMult(MPILinearOperator): + def __init__( + self, + A: NDArray, + N: int, + base_comm: MPI.Comm = MPI.COMM_WORLD, + dtype: DTypeLike = "float64", + ) -> None: + rank = base_comm.Get_rank() + size = base_comm.Get_size() + + # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size + self._P_prime = int(math.ceil(math.sqrt(size))) + self._C = int(math.ceil(size / self._P_prime)) + assert self._P_prime * self._C >= size + + # Compute this process's group and layer indices + self._group_id = rank % self._P_prime + self._layer_id = rank // self._P_prime + + # Split communicators by layer (rows) and by group (columns) + self.base_comm = base_comm + self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id) + self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id) + + self.dtype = np.dtype(dtype) + self.A = np.array(A, dtype=self.dtype, copy=False) + + self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM) + self.K = A.shape[1] + self.N = N + + # Determine how many columns each group holds + block_cols = int(math.ceil(self.N / self._P_prime)) + local_col_start = self._group_id * block_cols + local_col_end = min(self.N, local_col_start + block_cols) + local_ncols = local_col_end - local_col_start + + # Sum up the total number of input columns across all processes + total_ncols = base_comm.allreduce(local_ncols, op=MPI.SUM) + self.dims = (self.K, total_ncols) + + # Recompute how many output columns each layer holds + layer_col_start = self._layer_id * block_cols + layer_col_end = min(self.N, layer_col_start + block_cols) + layer_ncols = layer_col_end - layer_col_start + total_layer_cols = self.base_comm.allreduce(layer_ncols, op=MPI.SUM) + + self.dimsd = (self.M, total_layer_cols) + shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims))) + + super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm) + + def _matvec(self, x: DistributedArray) -> DistributedArray: + ncp = get_module(x.engine) + if x.partition != Partition.SCATTER: + raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...") + blk_cols = int(math.ceil(self.N / self._P_prime)) + col_start = self._group_id * blk_cols + col_end = min(self.N, col_start + blk_cols) + my_own_cols = max(0, col_end - col_start) + x = x.local_array.reshape((self.dims[0], my_own_cols)) + x = x.astype(self.dtype, copy=False) + B_block = self._layer_comm.bcast(x if self._group_id == self._layer_id else None, root=self._layer_id) + C_local = ncp.vstack( + self._layer_comm.allgather( + ncp.matmul(self.A, B_block) + ) + ) + + layer_col_start = self._layer_id * blk_cols + layer_col_end = min(self.N, layer_col_start + blk_cols) + layer_ncols = max(0, layer_col_end - layer_col_start) + layer_col_lens = self.base_comm.allgather(layer_ncols) + mask = [i // self._P_prime for i in range(self.size)] + + y = DistributedArray(global_shape= (self.M * self.dimsd[1]), + local_shapes=[(self.M * c) for c in layer_col_lens], + mask=mask, + #axis=1, + 
partition=Partition.SCATTER, + dtype=self.dtype) + y[:] = C_local.flatten() + return y + + def _rmatvec(self, x: DistributedArray) -> DistributedArray: + ncp = get_module(x.engine) + if x.partition != Partition.SCATTER: + raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.") + + # Determine local column block for this layer + blk_cols = int(math.ceil(self.N / self._P_prime)) + layer_col_start = self._layer_id * blk_cols + layer_col_end = min(self.N, layer_col_start + blk_cols) + layer_ncols = layer_col_end - layer_col_start + layer_col_lens = self.base_comm.allgather(layer_ncols) + x = x.local_array.reshape((self.M, layer_ncols)) + + # Determine local row block for this process group + blk_rows = int(math.ceil(self.M / self._P_prime)) + row_start = self._group_id * blk_rows + row_end = min(self.M, row_start + blk_rows) + + B_tile = x[row_start:row_end, :].astype(self.dtype, copy=False) + A_local = self.A.T.conj() + + m, b = A_local.shape + pad = (-m) % self._P_prime + r = (m + pad) // self._P_prime + A_pad = np.pad(A_local, ((0, pad), (0, 0)), mode='constant', constant_values=0) + A_batch = A_pad.reshape(self._P_prime, r, b) + + # Perform local matmul and unpad + Y_batch = ncp.matmul(A_batch, B_tile) + Y_pad = Y_batch.reshape(r * self._P_prime, -1) + y_local = Y_pad[:m, :] + y_layer = self._layer_comm.allreduce(y_local, op=MPI.SUM) + + mask = [i // self._P_prime for i in range(self.size)] + y = DistributedArray( + global_shape=(self.K * self.dimsd[1]), + local_shapes=[self.K * c for c in layer_col_lens], + mask=mask, + #axis=1 + partition=Partition.SCATTER, + dtype=self.dtype, + ) + y[:] = y_layer.flatten() + return y diff --git a/pylops_mpi/utils/__init__.py b/pylops_mpi/utils/__init__.py index df8d8b0e..1d064202 100644 --- a/pylops_mpi/utils/__init__.py +++ b/pylops_mpi/utils/__init__.py @@ -1,4 +1,2 @@ # isort: skip_file - from .dottest import * -from .deps import * diff --git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py deleted file mode 100644 index c3b02b71..00000000 --- a/pylops_mpi/utils/_nccl.py +++ /dev/null @@ -1,288 +0,0 @@ -__all__ = [ - "initialize_nccl_comm", - "nccl_split", - "nccl_allgather", - "nccl_allreduce", - "nccl_bcast", - "nccl_asarray" -] - -from enum import IntEnum -from mpi4py import MPI -import os -import numpy as np -import cupy as cp -import cupy.cuda.nccl as nccl - -cupy_to_nccl_dtype = { - "float32": nccl.NCCL_FLOAT32, - "float64": nccl.NCCL_FLOAT64, - "int32": nccl.NCCL_INT32, - "int64": nccl.NCCL_INT64, - "uint8": nccl.NCCL_UINT8, - "int8": nccl.NCCL_INT8, - "uint32": nccl.NCCL_UINT32, - "uint64": nccl.NCCL_UINT64, -} - - -class NcclOp(IntEnum): - SUM = nccl.NCCL_SUM - PROD = nccl.NCCL_PROD - MAX = nccl.NCCL_MAX - MIN = nccl.NCCL_MIN - - -def mpi_op_to_nccl(mpi_op) -> NcclOp: - """ Map MPI reduction operation to NCCL equivalent - - Parameters - ---------- - mpi_op : :obj:`MPI.Op` - A MPI reduction operation (e.g., MPI.SUM, MPI.PROD, MPI.MAX, MPI.MIN). - - Returns: - ------- - NcclOp : :obj:`IntEnum` - A corresponding NCCL reduction operation. - """ - if mpi_op is MPI.SUM: - return NcclOp.SUM - elif mpi_op is MPI.PROD: - return NcclOp.PROD - elif mpi_op is MPI.MAX: - return NcclOp.MAX - elif mpi_op is MPI.MIN: - return NcclOp.MIN - else: - raise ValueError(f"Unsupported MPI.Op for NCCL: {mpi_op}") - - -def initialize_nccl_comm() -> nccl.NcclCommunicator: - """ Initialize NCCL world communicator for every GPU device - - Each GPU must be managed by exactly one MPI process. - i.e. 
the number of MPI process launched must be equal to - number of GPUs in communications - - Returns: - ------- - nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` - A corresponding NCCL communicator - """ - comm = MPI.COMM_WORLD - rank = comm.Get_rank() - size = comm.Get_size() - device_id = int( - os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK") - or rank % cp.cuda.runtime.getDeviceCount() - ) - cp.cuda.Device(device_id).use() - - if rank == 0: - with cp.cuda.Device(device_id): - nccl_id_bytes = nccl.get_unique_id() - else: - nccl_id_bytes = None - nccl_id_bytes = comm.bcast(nccl_id_bytes, root=0) - - nccl_comm = nccl.NcclCommunicator(size, nccl_id_bytes, rank) - return nccl_comm - - -def nccl_split(mask) -> nccl.NcclCommunicator: - """ NCCL-equivalent of MPI.Split() - - Splitting the communicator into multiple NCCL subcommunicators - - Parameters - ---------- - mask : :obj:`list` - Mask defining subsets of ranks to consider when performing 'global' - operations on the distributed array such as dot product or norm. - - Returns: - ------- - sub_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` - Subcommunicator according to mask - """ - comm = MPI.COMM_WORLD - rank = comm.Get_rank() - sub_comm = comm.Split(color=mask[rank], key=rank) - - sub_rank = sub_comm.Get_rank() - sub_size = sub_comm.Get_size() - - if sub_rank == 0: - nccl_id_bytes = nccl.get_unique_id() - else: - nccl_id_bytes = None - nccl_id_bytes = sub_comm.bcast(nccl_id_bytes, root=0) - sub_comm = nccl.NcclCommunicator(sub_size, nccl_id_bytes, sub_rank) - - return sub_comm - - -def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray: - """ NCCL equivalent of MPI_Allgather. Gathers data from all GPUs - and distributes the concatenated result to all participants. - - Parameters - ---------- - nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` - The NCCL communicator over which data will be gathered. - send_buf : :obj:`cupy.ndarray` or array-like - The data buffer from the local GPU to be sent. - recv_buf : :obj:`cupy.ndarray`, optional - The buffer to receive data from all GPUs. If None, a new - buffer will be allocated with the appropriate shape. - - Returns - ------- - recv_buf : :obj:`cupy.ndarray` - A buffer containing the gathered data from all GPUs. - """ - send_buf = ( - send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf) - ) - if recv_buf is None: - recv_buf = cp.zeros( - MPI.COMM_WORLD.Get_size() * send_buf.size, - dtype=send_buf.dtype, - ) - nccl_comm.allGather( - send_buf.data.ptr, - recv_buf.data.ptr, - send_buf.size, - cupy_to_nccl_dtype[str(send_buf.dtype)], - cp.cuda.Stream.null.ptr, - ) - return recv_buf - - -def nccl_allreduce(nccl_comm, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM) -> cp.ndarray: - """ NCCL equivalent of MPI_Allreduce. Applies a reduction operation - (e.g., sum, max) across all GPUs and distributes the result. - - Parameters - ---------- - nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` - The NCCL communicator used for collective communication. - send_buf : :obj:`cupy.ndarray` or array-like - The data buffer from the local GPU to be reduced. - recv_buf : :obj:`cupy.ndarray`, optional - The buffer to store the result of the reduction. If None, - a new buffer will be allocated with the appropriate shape. - op : :obj:mpi4py.MPI.Op, optional - The reduction operation to apply. Defaults to MPI.SUM. - - Returns - ------- - recv_buf : :obj:`cupy.ndarray` - A buffer containing the result of the reduction, broadcasted - to all GPUs. 
- """ - send_buf = ( - send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf) - ) - if recv_buf is None: - recv_buf = cp.zeros(send_buf.size, dtype=send_buf.dtype) - - nccl_comm.allReduce( - send_buf.data.ptr, - recv_buf.data.ptr, - send_buf.size, - cupy_to_nccl_dtype[str(send_buf.dtype)], - mpi_op_to_nccl(op), - cp.cuda.Stream.null.ptr, - ) - return recv_buf - - -def nccl_bcast(nccl_comm, local_array, index, value) -> None: - """ NCCL equivalent of MPI_Bcast. Broadcasts a single value at the given index - from the root GPU (rank 0) to all other GPUs. - - Parameters - ---------- - nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` - The NCCL communicator used for collective communication. - local_array : :obj:`cupy.ndarray` - The local array on each GPU. The value at `index` will be broadcasted. - index : :obj:`int` - The index in the array to be broadcasted. - value : :obj:`scalar` - The value to broadcast (only used by the root GPU, rank 0). - - Returns - ------- - None - """ - if nccl_comm.rank_id() == 0: - local_array[index] = value - nccl_comm.bcast( - local_array[index].data.ptr, - local_array[index].size, - cupy_to_nccl_dtype[str(local_array[index].dtype)], - 0, - cp.cuda.Stream.null.ptr, - ) - - -def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray: - """Global view of the array - - Gather all local GPU arrays into a single global array via NCCL all-gather. - - Parameters - ---------- - nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator` - The NCCL communicator used for collective communication. - local_array : :obj:`cupy.ndarray` - The local array on the current GPU. - local_shapes : :obj:`list` - A list of shapes for each GPU local array (used to trim padding). - axis : :obj:`int` - The axis along which to concatenate the gathered arrays. - - Returns - ------- - final_array : :obj:`cupy.ndarray` - Global array gathered from all GPUs and concatenated along `axis`. - - Notes - ----- - NCCL's allGather requires the sending buffer to have the same size for every device. - Therefore, the padding is required when the array is not evenly partitioned across - all the ranks. The padding is applied such that the sending buffer has the size of - each dimension corresponding to the max possible size of that dimension. 
- """ - sizes_each_dim = list(zip(*local_shapes)) - - send_shape = tuple(map(max, sizes_each_dim)) - pad_size = [ - (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, local_array.shape) - ] - - send_buf = cp.pad( - local_array, pad_size, mode="constant", constant_values=0 - ) - - # NCCL recommends to use one MPI Process per GPU and so size of receiving buffer can be inferred - ndev = len(local_shapes) - recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype) - nccl_allgather(nccl_comm, send_buf, recv_buf) - - # extract an individual array from each device - chunk_size = np.prod(send_shape) - chunks = [ - recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev) - ] - - # Remove padding from each array: the padded value may appear somewhere - # in the middle of the flat array and thus the reshape and slicing for each dimension is required - for i in range(ndev): - slicing = tuple(slice(0, end) for end in local_shapes[i]) - chunks[i] = chunks[i].reshape(send_shape)[slicing] - # combine back to single global array - return cp.concatenate(chunks, axis=axis) diff --git a/pylops_mpi/utils/deps.py b/pylops_mpi/utils/deps.py deleted file mode 100644 index 443098e5..00000000 --- a/pylops_mpi/utils/deps.py +++ /dev/null @@ -1,45 +0,0 @@ -__all__ = [ - "nccl_enabled" -] - -import os -from importlib import import_module, util -from typing import Optional - - -# error message at import of available package -def nccl_import(message: Optional[str] = None) -> str: - nccl_test = ( - # detect if nccl is available and the user is expecting it to be used - # CuPy must be checked first otherwise util.find_spec assumes it presents and check nccl immediately and lead to crash - util.find_spec("cupy") is not None and util.find_spec("cupy.cuda.nccl") is not None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1 - ) - if nccl_test: - # try importing it - try: - import_module("cupy.cuda.nccl") # noqa: F401 - - # if succesful, set the message to None - nccl_message = None - # if unable to import but the package is installed - except (ImportError, ModuleNotFoundError) as e: - nccl_message = ( - f"Fail to import cupy.cuda.nccl, Falling back to pure MPI (error: {e})." - "Please ensure your CUDA NCCL environment is set up correctly " - "for more detials visit 'https://docs.cupy.dev/en/stable/install.html'" - ) - print(UserWarning(nccl_message)) - else: - nccl_message = ( - "cupy.cuda.nccl package not installed or os.getenv('NCCL_PYLOPS_MPI') == 0. 
" - f"In order to be able to use {message} " - "ensure 'os.getenv('NCCL_PYLOPS_MPI') == 1'" - "for more details for installing NCCL visit 'https://docs.cupy.dev/en/stable/install.html'" - ) - - return nccl_message - - -nccl_enabled: bool = ( - True if (nccl_import() is None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1) else False -) diff --git a/pylops_mpi/utils/dottest.py b/pylops_mpi/utils/dottest.py index e7c6c4cb..e91bc25b 100644 --- a/pylops_mpi/utils/dottest.py +++ b/pylops_mpi/utils/dottest.py @@ -4,7 +4,7 @@ import numpy as np -from pylops_mpi import DistributedArray +from pylops_mpi.DistributedArray import DistributedArray from pylops.utils.backend import to_numpy diff --git a/setup.cfg b/setup.cfg index 236f788b..0411c6c7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,9 +1,9 @@ [tool:pytest] addopts = --verbose -python_files = tests/*.py tests_nccl/*.py +python_files = tests/*.py [flake8] ignore = E203, E501, W503, E402 per-file-ignores = __init__.py: F401, F403, F405 -max-line-length = 88 +max-line-length = 88 \ No newline at end of file diff --git a/tests/test_distributedarray.py b/tests/test_distributedarray.py index 9eaea0f4..8ee47a85 100644 --- a/tests/test_distributedarray.py +++ b/tests/test_distributedarray.py @@ -201,7 +201,7 @@ def test_distributed_maskeddot(par1, par2): """Test Distributed Dot product with masked array""" # number of subcommunicators if MPI.COMM_WORLD.Get_size() % 2 == 0: - nsub = 2 + nsub = 2 elif MPI.COMM_WORLD.Get_size() % 3 == 0: nsub = 3 else: @@ -236,7 +236,7 @@ def test_distributed_maskednorm(par): """Test Distributed numpy.linalg.norm method with masked array""" # number of subcommunicators if MPI.COMM_WORLD.Get_size() % 2 == 0: - nsub = 2 + nsub = 2 elif MPI.COMM_WORLD.Get_size() % 3 == 0: nsub = 3 else: diff --git a/tests/test_fredholm.py b/tests/test_fredholm.py index 3d45a4c6..b3e34f73 100644 --- a/tests/test_fredholm.py +++ b/tests/test_fredholm.py @@ -135,7 +135,7 @@ def test_Fredholm1(par): y_adj_dist = Fop_MPI.H @ y_dist y_adj = y_adj_dist.asarray() # Dot test - dottest(Fop_MPI, x, y_dist, par["nsl"] * par["nx"] * par["nz"], par["nsl"] * par["ny"] * par["nz"]) + dottest(Fop_MPI, x, y_dist, par["nsl"] * par["nx"] * par["nz"],par["nsl"] * par["ny"] * par["nz"]) if rank == 0: Fop = pylops.signalprocessing.Fredholm1( diff --git a/tests/test_matrixmult.py b/tests/test_matrixmult.py new file mode 100644 index 00000000..b62e46b7 --- /dev/null +++ b/tests/test_matrixmult.py @@ -0,0 +1,191 @@ +import pytest +import numpy as np +from numpy.testing import assert_allclose +from mpi4py import MPI +import math +import sys + +from pylops_mpi import DistributedArray, Partition +from pylops_mpi.basicoperators.MatrixMultiply import SUMMAMatrixMult + +np.random.seed(42) + +comm = MPI.COMM_WORLD +rank = comm.Get_rank() +size = comm.Get_size() + +# Define test cases: (M, K, N, dtype_str) +# M, K, N are matrix dimensions A(M,K), B(K,N) +# P_prime will be ceil(sqrt(size)). +test_params = [ + pytest.param(37, 37, 37, "float32", id="f32_37_37_37"), + pytest.param(40, 30, 50, "float64", id="f64_40_30_50"), + pytest.param(16, 20, 22, "complex64", id="c64_16_20_22"), + pytest.param(5, 4, 3, "float32", id="f32_5_4_3"), + pytest.param(1, 2, 1, "float64", id="f64_1_2_1",), + pytest.param(3, 1, 2, "float32", id="f32_3_1_2",), +] + + +@pytest.mark.mpi(min_size=1) # SUMMA should also work for 1 process. 
+@pytest.mark.parametrize("M, K, N, dtype_str", test_params) +def test_SUMMAMatrixMult(M, K, N, dtype_str): + dtype = np.dtype(dtype_str) + + cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0 + base_float_dtype = np.float32 if dtype == np.complex64 else np.float64 + + P_prime = int(math.ceil(math.sqrt(size))) + C = int(math.ceil(size / P_prime)) + assert P_prime * C >= size # Ensure process grid covers all processes + + my_group = rank % P_prime + my_layer = rank // P_prime + + # Create sub-communicators + layer_comm = comm.Split(color=my_layer, key=my_group) + group_comm = comm.Split(color=my_group, key=my_layer) + + # Calculate local matrix dimensions + blk_rows_A = int(math.ceil(M / P_prime)) + row_start_A = my_group * blk_rows_A + row_end_A = min(M, row_start_A + blk_rows_A) + my_own_rows_A = max(0, row_end_A - row_start_A) + + blk_cols_BC = int(math.ceil(N / P_prime)) + col_start_B = my_group * blk_cols_BC + col_end_B = min(N, col_start_B + blk_cols_BC) + my_own_cols_B = max(0, col_end_B - col_start_B) + + # Initialize local matrices + A_p = np.empty((my_own_rows_A, K), dtype=dtype) + B_p = np.empty((K, my_own_cols_B), dtype=dtype) + + # Generate and distribute test matrices + A_glob, B_glob = None, None + if rank == 0: + # Create global matrices with complex components if needed + A_glob_real = np.arange(M * K, dtype=base_float_dtype).reshape(M, K) + A_glob_imag = np.arange(M * K, dtype=base_float_dtype).reshape(M, K) * 0.5 + A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype) + + B_glob_real = np.arange(K * N, dtype=base_float_dtype).reshape(K, N) + B_glob_imag = np.arange(K * N, dtype=base_float_dtype).reshape(K, N) * 0.7 + B_glob = (B_glob_real + cmplx * B_glob_imag).astype(dtype) + + # Distribute matrix blocks to all ranks + for dest_rank in range(size): + dest_my_group = dest_rank % P_prime + + # Calculate destination rank's block dimensions + dest_row_start_A = dest_my_group * blk_rows_A + dest_row_end_A = min(M, dest_row_start_A + blk_rows_A) + dest_my_own_rows_A = max(0, dest_row_end_A - dest_row_start_A) + + dest_col_start_B = dest_my_group * blk_cols_BC + dest_col_end_B = min(N, dest_col_start_B + blk_cols_BC) + dest_my_own_cols_B = max(0, dest_col_end_B - dest_col_start_B) + + A_block_send = A_glob[dest_row_start_A:dest_row_end_A, :].copy() + B_block_send = B_glob[:, dest_col_start_B:dest_col_end_B].copy() + + # Validate block shapes + assert A_block_send.shape == (dest_my_own_rows_A, K) + assert B_block_send.shape == (K, dest_my_own_cols_B) + + if dest_rank == 0: + A_p, B_p = A_block_send, B_block_send + else: + if A_block_send.size > 0: + comm.Send(A_block_send, dest=dest_rank, tag=100 + dest_rank) + if B_block_send.size > 0: + comm.Send(B_block_send, dest=dest_rank, tag=200 + dest_rank) + else: + if A_p.size > 0: + comm.Recv(A_p, source=0, tag=100 + rank) + if B_p.size > 0: + comm.Recv(B_p, source=0, tag=200 + rank) + + comm.Barrier() + + # Create SUMMAMatrixMult operator + Aop = SUMMAMatrixMult(A_p, N, base_comm=comm, dtype=dtype_str) + + # Create DistributedArray for input x (representing B flattened) + all_my_own_cols_B = comm.allgather(my_own_cols_B) + total_cols = sum(all_my_own_cols_B) + local_shapes_x = [(K * cl_b,) for cl_b in all_my_own_cols_B] + + x_dist = DistributedArray( + global_shape=(K * total_cols), + local_shapes=local_shapes_x, + partition=Partition.SCATTER, + base_comm=comm, + dtype=dtype + ) + + if B_p.size > 0: + x_dist.local_array[:] = B_p.ravel() + else: + assert x_dist.local_array.size == 0, ( + f"Rank {rank}: B_p empty but 
x_dist.local_array not empty " + f"(size {x_dist.local_array.size})" + ) + + # Forward operation: y = A @ B (distributed) + y_dist = Aop @ x_dist + + # Adjoint operation: z = A.H @ y (distributed y representing C) + z_dist = Aop.H @ y_dist + + if rank == 0: + if all(dim > 0 for dim in [M, K, N]): + C_true = A_glob @ B_glob + Z_true = A_glob.conj().T @ C_true + else: # Handle cases with zero dimensions + C_true = np.zeros((M, N), dtype=dtype) + Z_true = np.zeros((K if K > 0 else 0, N), dtype=dtype) if K > 0 else np.zeros((0, N), dtype=dtype) + else: + C_true = Z_true = None + + C_true = comm.bcast(C_true, root=0) + Z_true = comm.bcast(Z_true, root=0) + + col_start_C_dist = my_layer * blk_cols_BC + col_end_C_dist = min(N, col_start_C_dist + blk_cols_BC) + my_own_cols_C_dist = max(0, col_end_C_dist - col_start_C_dist) + expected_y_shape = (M * my_own_cols_C_dist,) + + assert y_dist.local_array.shape == expected_y_shape, ( + f"Rank {rank}: y_dist shape {y_dist.local_array.shape} != expected {expected_y_shape}" + ) + + if y_dist.local_array.size > 0 and C_true is not None and C_true.size > 0: + expected_y_slice = C_true[:, col_start_C_dist:col_end_C_dist] + assert_allclose( + y_dist.local_array, + expected_y_slice.ravel(), + rtol=1e-14, + atol=1e-7, + err_msg=f"Rank {rank}: Forward verification failed." + ) + + # Verify adjoint operation (z = A.H @ y) + expected_z_shape = (K * my_own_cols_C_dist,) + assert z_dist.local_array.shape == expected_z_shape, ( + f"Rank {rank}: z_dist shape {z_dist.local_array.shape} != expected {expected_z_shape}" + ) + + # Verify adjoint result values + if z_dist.local_array.size > 0 and Z_true is not None and Z_true.size > 0: + expected_z_slice = Z_true[:, col_start_C_dist:col_end_C_dist] + assert_allclose( + z_dist.local_array, + expected_z_slice.ravel(), + rtol=1e-14, + atol=1e-7, + err_msg=f"Rank {rank}: Adjoint verification failed." 
+ ) + + group_comm.Free() + layer_comm.Free() \ No newline at end of file diff --git a/tests_nccl/test_distributedarray_nccl.py b/tests_nccl/test_distributedarray_nccl.py deleted file mode 100644 index 3478c8a8..00000000 --- a/tests_nccl/test_distributedarray_nccl.py +++ /dev/null @@ -1,408 +0,0 @@ -"""Test the DistributedArray class -Designed to run with n GPUs (with 1 MPI process per GPU) -$ mpiexec -n 3 pytest test_distributedarray_nccl.py --with-mpi - -This file employs the same test sets as test_distributedarray under NCCL environment -""" - -import numpy as np -import cupy as cp -from mpi4py import MPI -import pytest -from numpy.testing import assert_allclose - -from pylops_mpi import DistributedArray, Partition -from pylops_mpi.DistributedArray import local_split -from pylops_mpi.utils._nccl import initialize_nccl_comm - -np.random.seed(42) - -nccl_comm = initialize_nccl_comm() - -par1 = { - "global_shape": (500, 501), - "partition": Partition.SCATTER, - "dtype": np.float64, - "axis": 1, -} - -par2 = { - "global_shape": (500, 501), - "partition": Partition.BROADCAST, - "dtype": np.float64, - "axis": 1, -} - -par3 = { - "global_shape": (200, 201, 101), - "partition": Partition.SCATTER, - "dtype": np.float64, - "axis": 1, -} - -par4 = { - "x": np.random.normal(100, 100, (500, 501)), - "partition": Partition.SCATTER, - "axis": 1, -} - -par5 = { - "x": np.random.normal(300, 300, (500, 501)), - "partition": Partition.SCATTER, - "axis": 1, -} - -par6 = { - "x": np.random.normal(100, 100, (600, 600)), - "partition": Partition.SCATTER, - "axis": 0, -} - -par6b = { - "x": np.random.normal(100, 100, (600, 600)), - "partition": Partition.BROADCAST, - "axis": 0, -} - -par7 = { - "x": np.random.normal(300, 300, (600, 600)), - "partition": Partition.SCATTER, - "axis": 0, -} - -par7b = { - "x": np.random.normal(300, 300, (600, 600)), - "partition": Partition.BROADCAST, - "axis": 0, -} - -par8 = { - "x": np.random.normal(100, 100, (1200,)), - "partition": Partition.SCATTER, - "axis": 0, -} - -par8b = { - "x": np.random.normal(100, 100, (1200,)), - "partition": Partition.BROADCAST, - "axis": 0, -} - -par9 = { - "x": np.random.normal(300, 300, (1200,)), - "partition": Partition.SCATTER, - "axis": 0, -} - -par9b = { - "x": np.random.normal(300, 300, (1200,)), - "partition": Partition.BROADCAST, - "axis": 0, -} - - -@pytest.mark.mpi(min_size=2) -@pytest.mark.parametrize("par", [(par1), (par2), (par3)]) -def test_creation_nccl(par): - """Test creation of local arrays""" - distributed_array = DistributedArray( - global_shape=par["global_shape"], - base_comm_nccl=nccl_comm, - partition=par["partition"], - dtype=par["dtype"], - axis=par["axis"], - engine="cupy", - ) - loc_shape = local_split( - distributed_array.global_shape, - distributed_array.base_comm, - distributed_array.partition, - distributed_array.axis, - ) - assert distributed_array.global_shape == par["global_shape"] - assert distributed_array.local_shape == loc_shape - assert isinstance(distributed_array, DistributedArray) - # Distributed array of ones - distributed_ones = DistributedArray( - global_shape=par["global_shape"], - base_comm_nccl=nccl_comm, - partition=par["partition"], - dtype=par["dtype"], - axis=par["axis"], - engine="cupy", - ) - distributed_ones[:] = 1 - # Distributed array of zeroes - distributed_zeroes = DistributedArray( - global_shape=par["global_shape"], - base_comm_nccl=nccl_comm, - partition=par["partition"], - dtype=par["dtype"], - axis=par["axis"], - engine="cupy", - ) - distributed_zeroes[:] = 0 - # Test for 
distributed ones - assert isinstance(distributed_ones, DistributedArray) - assert_allclose( - distributed_ones.local_array.get(), - np.ones(shape=distributed_ones.local_shape, dtype=par["dtype"]), - rtol=1e-14, - ) - assert_allclose( - distributed_ones.asarray().get(), - np.ones(shape=distributed_ones.global_shape, dtype=par["dtype"]), - rtol=1e-14, - ) - # Test for distributed zeroes - assert isinstance(distributed_zeroes, DistributedArray) - assert_allclose( - distributed_zeroes.local_array.get(), - np.zeros(shape=distributed_zeroes.local_shape, dtype=par["dtype"]), - rtol=1e-14, - ) - assert_allclose( - distributed_zeroes.asarray().get(), - np.zeros(shape=distributed_zeroes.global_shape, dtype=par["dtype"]), - rtol=1e-14, - ) - - -@pytest.mark.mpi(min_size=2) -@pytest.mark.parametrize("par", [(par4), (par5)]) -def test_to_dist_nccl(par): - """Test the ``to_dist`` method""" - x_gpu = cp.asarray(par["x"]) - dist_array = DistributedArray.to_dist( - x=x_gpu, - base_comm_nccl=nccl_comm, - partition=par["partition"], - axis=par["axis"], - ) - assert isinstance(dist_array, DistributedArray) - assert dist_array.global_shape == par["x"].shape - assert dist_array.axis == par["axis"] - - -@pytest.mark.mpi(min_size=2) -@pytest.mark.parametrize("par", [(par1), (par2), (par3)]) -def test_local_shapes_nccl(par): - """Test the `local_shapes` parameter in DistributedArray""" - # Reverse the local_shapes to test the local_shapes parameter - loc_shapes = MPI.COMM_WORLD.allgather( - local_split(par["global_shape"], MPI.COMM_WORLD, par["partition"], par["axis"]) - )[::-1] - distributed_array = DistributedArray( - global_shape=par["global_shape"], - base_comm_nccl=nccl_comm, - partition=par["partition"], - axis=par["axis"], - local_shapes=loc_shapes, - dtype=par["dtype"], - engine="cupy", - ) - assert isinstance(distributed_array, DistributedArray) - assert distributed_array.local_shape == loc_shapes[distributed_array.rank] - - # Distributed ones - distributed_array[:] = 1 - assert_allclose( - distributed_array.local_array.get(), - np.ones(loc_shapes[distributed_array.rank], dtype=par["dtype"]), - rtol=1e-14, - ) - assert_allclose( - distributed_array.asarray().get(), - np.ones(par["global_shape"], dtype=par["dtype"]), - rtol=1e-14, - ) - - -@pytest.mark.mpi(min_size=2) -@pytest.mark.parametrize("par1, par2", [(par4, par5)]) -def test_distributed_math_nccl(par1, par2): - """Test the Element-Wise Addition, Subtraction and Multiplication""" - x1_gpu = cp.asarray(par1["x"]) - x2_gpu = cp.asarray(par2["x"]) - arr1 = DistributedArray.to_dist( - x=x1_gpu, base_comm_nccl=nccl_comm, partition=par1["partition"] - ) - arr2 = DistributedArray.to_dist( - x=x2_gpu, base_comm_nccl=nccl_comm, partition=par2["partition"] - ) - - # Addition - sum_array = arr1 + arr2 - assert isinstance(sum_array, DistributedArray) - # Subtraction - sub_array = arr1 - arr2 - assert isinstance(sub_array, DistributedArray) - # Multiplication - mult_array = arr1 * arr2 - assert isinstance(mult_array, DistributedArray) - # Global array of Sum with np.add - - assert_allclose(sum_array.asarray().get(), np.add(par1["x"], par2["x"]), rtol=1e-14) - # Global array of Subtract with np.subtract - assert_allclose( - sub_array.asarray().get(), np.subtract(par1["x"], par2["x"]), rtol=1e-14 - ) - # Global array of Multiplication with np.multiply - assert_allclose( - mult_array.asarray().get(), np.multiply(par1["x"], par2["x"]), rtol=1e-14 - ) - - -@pytest.mark.mpi(min_size=2) -@pytest.mark.parametrize( - "par1, par2", [(par6, par7), (par6b, par7b), (par8, 
par9), (par8b, par9b)] -) -def test_distributed_dot_nccl(par1, par2): - """Test Distributed Dot product""" - x1_gpu = cp.asarray(par1["x"]) - x2_gpu = cp.asarray(par2["x"]) - arr1 = DistributedArray.to_dist( - x=x1_gpu, base_comm_nccl=nccl_comm, partition=par1["partition"], axis=par1["axis"] - ) - arr2 = DistributedArray.to_dist( - x=x2_gpu, base_comm_nccl=nccl_comm, partition=par2["partition"], axis=par2["axis"] - ) - assert_allclose( - (arr1.dot(arr2)).get(), - np.dot(par1["x"].flatten(), par2["x"].flatten()), - rtol=1e-14, - ) - - -@pytest.mark.mpi(min_size=2) -@pytest.mark.parametrize( - "par", - [ - (par4), - (par5), - (par6), - (par6b), - (par7), - (par7b), - (par8), - (par8b), - (par9), - (par9b), - ], -) -def test_distributed_norm_nccl(par): - """Test Distributed numpy.linalg.norm method""" - x_gpu = cp.asarray(par["x"]) - arr = DistributedArray.to_dist(x=x_gpu, base_comm_nccl=nccl_comm, axis=par["axis"]) - assert_allclose( - arr.norm(ord=1, axis=par["axis"]).get(), - np.linalg.norm(par["x"], ord=1, axis=par["axis"]), - rtol=1e-14, - ) - assert_allclose( - arr.norm(ord=np.inf, axis=par["axis"]).get(), - np.linalg.norm(par["x"], ord=np.inf, axis=par["axis"]), - rtol=1e-14, - ) - assert_allclose(arr.norm().get(), np.linalg.norm(par["x"].flatten()), rtol=1e-13) - - -@pytest.mark.mpi(min_size=2) -@pytest.mark.parametrize( - "par1, par2", [(par6, par7), (par6b, par7b), (par8, par9), (par8b, par9b)] -) -def test_distributed_maskeddot_nccl(par1, par2): - """Test Distributed Dot product with masked array""" - # number of subcommunicators - if MPI.COMM_WORLD.Get_size() % 2 == 0: - nsub = 2 - elif MPI.COMM_WORLD.Get_size() % 3 == 0: - nsub = 3 - else: - pass - subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub) - mask = np.repeat(np.arange(nsub), subsize) - # Replicate x1 and x2 as required in masked arrays - x1, x2 = par1["x"], par2["x"] - if par1["axis"] != 0: - x1 = np.swapaxes(x1, par1["axis"], 0) - for isub in range(1, nsub): - x1[(x1.shape[0] // nsub) * isub : (x1.shape[0] // nsub) * (isub + 1)] = x1[ - : x1.shape[0] // nsub - ] - if par1["axis"] != 0: - x1 = np.swapaxes(x1, 0, par1["axis"]) - if par2["axis"] != 0: - x2 = np.swapaxes(x2, par2["axis"], 0) - for isub in range(1, nsub): - x2[(x2.shape[0] // nsub) * isub : (x2.shape[0] // nsub) * (isub + 1)] = x2[ - : x2.shape[0] // nsub - ] - if par2["axis"] != 0: - x2 = np.swapaxes(x2, 0, par2["axis"]) - - x1_gpu, x2_gpu = cp.asarray(x1), cp.asarray(x2) - arr1 = DistributedArray.to_dist( - x=x1_gpu, - base_comm_nccl=nccl_comm, - partition=par1["partition"], - mask=mask, - axis=par1["axis"], - ) - arr2 = DistributedArray.to_dist( - x=x2_gpu, - base_comm_nccl=nccl_comm, - partition=par2["partition"], - mask=mask, - axis=par2["axis"], - ) - assert_allclose( - arr1.dot(arr2).get(), np.dot(x1.flatten(), x2.flatten()) / nsub, rtol=1e-14 - ) - - -@pytest.mark.mpi(min_size=2) -@pytest.mark.parametrize( - "par", [(par6), (par6b), (par7), (par7b), (par8), (par8b), (par9), (par9b)] -) -def test_distributed_maskednorm_nccl(par): - """Test Distributed numpy.linalg.norm method with masked array""" - # number of subcommunicators - if MPI.COMM_WORLD.Get_size() % 2 == 0: - nsub = 2 - elif MPI.COMM_WORLD.Get_size() % 3 == 0: - nsub = 3 - else: - pass - subsize = max(1, MPI.COMM_WORLD.Get_size() // nsub) - mask = np.repeat(np.arange(nsub), subsize) - # Replicate x as required in masked arrays - x = par["x"] - if par["axis"] != 0: - x = np.swapaxes(x, par["axis"], 0) - for isub in range(1, nsub): - x[(x.shape[0] // nsub) * isub : (x.shape[0] // nsub) * 
(isub + 1)] = x[ - : x.shape[0] // nsub - ] - if par["axis"] != 0: - x = np.swapaxes(x, 0, par["axis"]) - - x_gpu = cp.asarray(x) - arr = DistributedArray.to_dist( - x=x_gpu, base_comm_nccl=nccl_comm, mask=mask, axis=par["axis"] - ) - assert_allclose( - arr.norm(ord=1, axis=par["axis"]).get(), - np.linalg.norm(par["x"], ord=1, axis=par["axis"]) / nsub, - rtol=1e-14, - ) - assert_allclose( - arr.norm(ord=np.inf, axis=par["axis"]).get(), - np.linalg.norm(par["x"], ord=np.inf, axis=par["axis"]), - rtol=1e-14, - ) - assert_allclose( - arr.norm(ord=2, axis=par["axis"]).get(), - np.linalg.norm(par["x"], ord=2, axis=par["axis"]) / np.sqrt(nsub), - rtol=1e-13, - ) diff --git a/tutorials/mdd.py b/tutorials/mdd.py index 5946fc05..dd175dd4 100644 --- a/tutorials/mdd.py +++ b/tutorials/mdd.py @@ -14,6 +14,7 @@ """ import numpy as np +from scipy.signal import filtfilt from matplotlib import pyplot as plt from mpi4py import MPI