diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 66e1a373..fca1c1cd 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -42,7 +42,7 @@ Basic Operators .. autosummary:: :toctree: generated/ - MPIMatrixMult + MatrixMult.MPIMatrixMult MPIBlockDiag MPIStackedBlockDiag MPIVStack @@ -118,6 +118,16 @@ Utils local_split +.. currentmodule:: pylops_mpi.basicoperators.MatrixMult + +.. autosummary:: + :toctree: generated/ + + block_gather + local_block_split + active_grid_comm + + .. currentmodule:: pylops_mpi.utils.dottest .. autosummary:: diff --git a/examples/plot_matrixmult.py b/examples/plot_matrixmult.py index 9c7e2d35..082d924c 100644 --- a/examples/plot_matrixmult.py +++ b/examples/plot_matrixmult.py @@ -28,6 +28,7 @@ import pylops_mpi from pylops_mpi import Partition +from pylops_mpi.basicoperators.MatrixMult import active_grid_comm, MPIMatrixMult plt.close("all") @@ -88,8 +89,7 @@ # than the row or columm ranks. base_comm = MPI.COMM_WORLD -comm, rank, row_id, col_id, is_active = \ - pylops_mpi.MPIMatrixMult.active_grid_comm(base_comm, N, M) +comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M) print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}") if not is_active: exit(0) @@ -147,7 +147,7 @@ ################################################################################ # We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult` # operator and the input matrix :math:`\mathbf{X}` -Aop = pylops_mpi.MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32") +Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32", kind="block") col_lens = comm.allgather(my_own_cols) total_cols = np.sum(col_lens) diff --git a/examples/plot_summamatrixmult.py b/examples/plot_summamatrixmult.py new file mode 100644 index 00000000..dd3f0225 --- /dev/null +++ b/examples/plot_summamatrixmult.py @@ -0,0 +1,146 @@ +r""" +Distributed SUMMA Matrix Multiplication +======================================= +This example shows how to use the :py:class:`pylops_mpi.basicoperators._MPISummaMatrixMult` +operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}` +distributed in 2D blocks across a square process grid and matrices :math:`\mathbf{X}` +and :math:`\mathbf{Y}` distributed in 2D blocks across the same grid. Similarly, +the adjoint operation can be performed with a matrix :math:`\mathbf{Y}` distributed +in the same fashion as matrix :math:`\mathbf{X}`. + +Note that whilst the different blocks of matrix :math:`\mathbf{A}` are directly +stored in the operator on different ranks, the matrices :math:`\mathbf{X}` and +:math:`\mathbf{Y}` are effectively represented by 1-D :py:class:`pylops_mpi.DistributedArray` +objects where the different blocks are flattened and stored on different ranks. +Note that to optimize communications, the ranks are organized in a square grid and +blocks of :math:`\mathbf{A}` and :math:`\mathbf{X}` are systematically broadcast +across different ranks during computation - see below for details. +""" + +import math +import numpy as np +from mpi4py import MPI +from matplotlib import pyplot as plt + +import pylops_mpi +from pylops import Conj +from pylops_mpi.basicoperators.MatrixMult import (local_block_spit, MPIMatrixMult, active_grid_comm) + +plt.close("all") + +############################################################################### +# We set the seed such that all processes can create the input matrices filled +# with the same random number. 
In practical applications, such matrices will be
+# filled with data appropriate to the use case.
+np.random.seed(42)
+
+
+N, M, K = 6, 6, 6
+A_shape, x_shape, y_shape = (N, K), (K, M), (N, M)
+
+
+base_comm = MPI.COMM_WORLD
+comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M)
+print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
+
+
+###############################################################################
+# We are now ready to create the input matrices for our distributed matrix
+# multiplication example. We need to set up:
+# - Matrix :math:`\mathbf{A}` of size :math:`N \times K` (the left operand)
+# - Matrix :math:`\mathbf{X}` of size :math:`K \times M` (the right operand)
+# - The result will be :math:`\mathbf{Y} = \mathbf{A} \mathbf{X}` of size :math:`N \times M`
+#
+# For distributed computation, we arrange processes in a square grid of size
+# :math:`P' \times P'` where :math:`P' = \sqrt{P}` and :math:`P` is the total
+# number of MPI processes. Each process will own a block of each matrix
+# according to this 2D grid layout.
+
+p_prime = math.isqrt(comm.Get_size())
+print(f"Process grid: {p_prime} x {p_prime} = {comm.Get_size()} processes")
+
+# Create global test matrices with sequential values for easy verification
+# Matrix A: Each element :math:`A_{i,j} = i \cdot K + j` (row-major ordering)
+# Matrix X: Each element :math:`X_{i,j} = i \cdot M + j`
+A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape)
+x_data = np.arange(int(x_shape[0] * x_shape[1])).reshape(x_shape)
+
+print(f"Global matrix A shape: {A_shape} (N={A_shape[0]}, K={A_shape[1]})")
+print(f"Global matrix X shape: {x_shape} (K={x_shape[0]}, M={x_shape[1]})")
+print(f"Expected global result Y shape: ({A_shape[0]}, {x_shape[1]}) = (N, M)")
+
+################################################################################
+# Determine which block of each matrix this process should own.
+# The 2D block distribution ensures:
+# - Process at grid position :math:`(i,j)` gets block :math:`\mathbf{A}[i_{start}:i_{end}, j_{start}:j_{end}]`
+# - Block sizes are approximately :math:`\lceil N/P' \rceil \times \lceil K/P' \rceil` with edge processes handling the remainder
+#
+# .. raw:: html
+#
+#    <pre>
+# Example: 2x2 Process Grid with 6x6 Matrices +# +# Matrix A (6x6): Matrix X (6x6): +# ┌───────────┬───────────┐ ┌───────────┬───────────┐ +# │ 0 1 2 │ 3 4 5 │ │ 0 1 2 │ 3 4 5 │ +# │ 6 7 8 │ 9 10 11 │ │ 6 7 8 │ 9 10 11 │ +# │ 12 13 14 │ 15 16 17 │ │ 12 13 14 │ 15 16 17 │ +# ├───────────┼───────────┤ ├───────────┼───────────┤ +# │ 18 19 20 │ 21 22 23 │ │ 18 19 20 │ 21 22 23 │ +# │ 24 25 26 │ 27 28 29 │ │ 24 25 26 │ 27 28 29 │ +# │ 30 31 32 │ 33 34 35 │ │ 30 31 32 │ 33 34 35 │ +# └───────────┴───────────┘ └───────────┴───────────┘ +# +# Process (0,0): A[0:3, 0:3], X[0:3, 0:3] +# Process (0,1): A[0:3, 3:6], X[0:3, 3:6] +# Process (1,0): A[3:6, 0:3], X[3:6, 0:3] +# Process (1,1): A[3:6, 3:6], X[3:6, 3:6] +#
+# + +A_slice = local_block_spit(A_shape, rank, comm) +x_slice = local_block_spit(x_shape, rank, comm) +################################################################################ +# Extract the local portion of each matrix for this process +A_local = A_data[A_slice] +x_local = x_data[x_slice] + +print(f"Process {rank}: A_local shape {A_local.shape}, X_local shape {x_local.shape}") +print(f"Process {rank}: A slice {A_slice}, X slice {x_slice}") + +x_dist = pylops_mpi.DistributedArray(global_shape=(K * M), + local_shapes=comm.allgather(x_local.shape[0] * x_local.shape[1]), + base_comm=comm, + partition=pylops_mpi.Partition.SCATTER, + dtype=x_local.dtype) +x_dist[:] = x_local.flatten() + +################################################################################ +# We are now ready to create the SUMMA :py:class:`pylops_mpi.basicoperators.MPIMatrixMult` +# operator and the input matrix :math:`\mathbf{X}`. Given that we chose a block-block distribution +# of data we shall use SUMMA +Aop = MPIMatrixMult(A_local, M, base_comm=comm, kind="summa", dtype=A_local.dtype) + +################################################################################ +# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which +# effectively implements a distributed matrix-matrix multiplication +# :math:`Y = \mathbf{AX}`). Note :math:`\mathbf{Y}` is distributed in the same +# way as the input :math:`\mathbf{X}` in a block-block fashion. +y_dist = Aop @ x_dist + +############################################################################### +# Next we apply the adjoint pass :math:`\mathbf{x}_{adj} = \mathbf{A}^H \mathbf{x}` +# (which effectively implements a distributed summa matrix-matrix multiplication +# :math:`\mathbf{X}_{adj} = \mathbf{A}^H \mathbf{X}`). Note that +# :math:`\mathbf{X}_{adj}` is again distributed in the same way as the input +# :math:`\mathbf{X}` in a block-block fashion. +xadj_dist = Aop.H @ y_dist + +############################################################################### +# Finally, we show that the SUMMA :py:class:`pylops_mpi.basicoperators.MPIMatrixMult` +# operator can be combined with any other PyLops-MPI operator. We are going to +# apply here a conjugate operator to the output of the matrix multiplication. +Dop = Conj(dims=(A_local.shape[0], x_local.shape[1])) +DBop = pylops_mpi.MPIBlockDiag(ops=[Dop, ]) +Op = DBop @ Aop +y1 = Op @ x_dist diff --git a/pylops_mpi/LinearOperator.py b/pylops_mpi/LinearOperator.py index 49077325..7ad9f8c4 100644 --- a/pylops_mpi/LinearOperator.py +++ b/pylops_mpi/LinearOperator.py @@ -76,7 +76,6 @@ def matvec(self, x: DistributedArray) -> DistributedArray: """ M, N = self.shape - if x.global_shape != (N,): raise ValueError("dimension mismatch") diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py index 39eda45e..8738470c 100644 --- a/pylops_mpi/basicoperators/MatrixMult.py +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -1,6 +1,9 @@ -import numpy as np + import math +import numpy as np +from typing import Tuple, Union, Literal from mpi4py import MPI + from pylops.utils.backend import get_module from pylops.utils.typing import DTypeLike, NDArray @@ -11,8 +14,153 @@ ) -class MPIMatrixMult(MPILinearOperator): - r"""MPI Matrix multiplication +def active_grid_comm(base_comm: MPI.Comm, N: int, M: int): + r"""Configure active grid for distributed matrix multiplication. + + Configure a square process grid from a parent MPI communicator and + select a subset of "active" processes. 
Each process in ``base_comm`` + is assigned to a logical 2D grid of size :math:`P' \times P'`, + where :math:`P' = \bigl \lceil \sqrt{P} \bigr \rceil`. Only the first + :math:`active_dim x active_dim` processes + (by row-major order) are considered "active". Inactive ranks return + immediately with no new communicator. + + Parameters: + ----------- + base_comm : :obj:`mpi4py.MPI.Comm` + MPI Parent Communicator. (e.g., ``mpi4py.MPI.COMM_WORLD``). + N : :obj:`int` + Number of rows of the global data domain. + M : :obj:`int` + Number of columns of the global data domain. + + Returns: + -------- + comm : :obj:`mpi4py.MPI.Comm` + Sub-communicator including only active ranks. + rank : :obj:`int` + Rank within the new sub-communicator (or original rank + if inactive). + row : :obj:`int` + Grid row index of this process in the active grid (or original rank + if inactive). + col : :obj:`int` + Grid column index of this process in the active grid + (or original rank if inactive). + is_active : :obj:`bool` + Flag indicating whether this rank is in the active sub-grid. + + """ + rank = base_comm.Get_rank() + size = base_comm.Get_size() + p_prime = math.isqrt(size) + row, col = divmod(rank, p_prime) + active_dim = min(N, M, p_prime) + is_active = (row < active_dim and col < active_dim) + + if not is_active: + return None, rank, row, col, False + + active_ranks = [r for r in range(size) + if (r // p_prime) < active_dim and (r % p_prime) < active_dim] + new_group = base_comm.Get_group().Incl(active_ranks) + new_comm = base_comm.Create_group(new_group) + p_prime_new = math.isqrt(len(active_ranks)) + new_rank = new_comm.Get_rank() + new_row, new_col = divmod(new_rank, p_prime_new) + + return new_comm, new_rank, new_row, new_col, True + + +def local_block_spit(global_shape: Tuple[int, int], + rank: int, + comm: MPI.Comm) -> Tuple[slice, slice]: + r""" + Compute the local sub‐block of a 2D global array for a process in a square process grid. + + Parameters + ---------- + global_shape : Tuple[int, int] + Dimensions of the global 2D array (n_rows, n_cols). + rank : int + Rank of the MPI process in `comm` for which to get the owned block partition. + comm : MPI.Comm + MPI communicator whose total number of processes :math:`\mathbf{P}` + must be a perfect square :math:`\mathbf{P} = \sqrt{\mathbf{P'}}`. + + Returns + ------- + Tuple[slice, slice] + Two `slice` objects `(row_slice, col_slice)` indicating the sub‐block + of the global array owned by this rank. + + Raises + ------ + ValueError + if `rank` is out of range. + RuntimeError + If the number of processes participating in the provided communicator is not a perfect square. + """ + size = comm.Get_size() + p_prime = math.isqrt(size) + if p_prime * p_prime != size: + raise RuntimeError(f"Number of processes must be a square number, provided {size} instead...") + if not ( isinstance(rank, int) and 0 <= rank < size ): + raise ValueError(f"rank must be integer in [0, {size}), got {rank!r}") + + pr, pc = divmod(rank, p_prime) + orig_r, orig_c = global_shape + new_r = math.ceil(orig_r / p_prime) * p_prime + new_c = math.ceil(orig_c / p_prime) * p_prime + blkr, blkc = new_r // p_prime, new_c // p_prime + rs, cs = pr * blkr, pc * blkc + re, ce = min(rs + blkr, orig_r), min(cs + blkc, orig_c) + return slice(rs, re), slice(cs, ce) + + +def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Comm): + r""" + Gather distributed local blocks from 2D block distributed matrix distributed + amongst a square process grid into the full global array. 
+ + Parameters + ---------- + x : :obj:`pylops_mpi.DistributedArray` + The distributed array to gather locally. + orig_shape : Tuple[int, int] + Original shape `(N, M)` of the global array to be gathered. + comm : MPI.Comm + MPI communicator whose size must be a perfect square (P = p_prime**2). + + Returns + ------- + Array + The reconstructed 2D array of shape `orig_shape`, assembled from + the distributed blocks. + + Raises + ------ + RuntimeError + If the number of processes participating in the provided communicator is not a perfect square. + """ + ncp = get_module(x.engine) + p_prime = math.isqrt(comm.Get_size()) + if p_prime * p_prime != comm.Get_size(): + raise RuntimeError(f"Communicator size must be a perfect square, got {comm.Get_size()!r}") + + all_blks = comm.allgather(x.local_array) + nr, nc = orig_shape + br, bc = math.ceil(nr / p_prime), math.ceil(nc / p_prime) + C = ncp.zeros((nr, nc), dtype=all_blks[0].dtype) + for rank in range(p_prime * p_prime): + pr, pc = divmod(rank, p_prime) + rs, cs = pr * br, pc * bc + re, ce = min(rs + br, nr), min(cs + bc, nc) + C[rs:re, cs:ce] = all_blks[rank].reshape(re - rs, cs - ce) + return C + +class _MPIBlockMatrixMult(MPILinearOperator): + r"""MPI Blocked Matrix multiplication Implement distributed matrix-matrix multiplication between a matrix :math:`\mathbf{A}` blocked over rows (i.e., blocks of rows are stored @@ -29,7 +177,7 @@ class MPIMatrixMult(MPILinearOperator): Global leading dimension (i.e., number of columns) of the matrices representing the input model and data vectors. saveAt : :obj:`bool`, optional - Save ``A`` and ``A.H`` to speed up the computation of adjoint + Save :math:`\mathbf{A}` and ``A.H`` to speed up the computation of adjoint (``True``) or create ``A.H`` on-the-fly (``False``) Note that ``saveAt=True`` will double the amount of required memory. Default is ``False``. @@ -68,22 +216,22 @@ class MPIMatrixMult(MPILinearOperator): processes by a factor equivalent to :math:`\sqrt{P}` across a square process grid (:math:`\sqrt{P}\times\sqrt{P}`). More specifically: - - The matrix ``A`` is distributed across MPI processes in a block-row fashion - and each process holds a local block of ``A`` with shape + - The matrix :math:`\mathbf{A}` is distributed across MPI processes in a block-row fashion + and each process holds a local block of :math:`\mathbf{A}` with shape :math:`[N_{loc} \times K]` - - The operand matrix ``X`` is distributed in a block-column fashion and - each process holds a local block of ``X`` with shape + - The operand matrix :math:`\mathbf{X}` is distributed in a block-column fashion and + each process holds a local block of :math:`\mathbf{X}` with shape :math:`[K \times M_{loc}]` - Communication is minimized by using a 2D process grid layout **Forward Operation step-by-step** - 1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X`` + 1. **Input Preparation**: The input vector ``x`` (flattened from matrix :math:`\mathbf{X}` of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local`` is the number of columns assigned to the current process. 2. **Local Computation**: Each process computes ``A_local @ X_local`` where: - - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``) + - ``A_local`` is the local block of matrix :math:`\mathbf{A}` (shape ``N_local x K``) - ``X_local`` is the broadcasted operand (shape ``K x M_local``) 3. 
**Row-wise Gather**: Results from all processes in each row are gathered @@ -98,10 +246,10 @@ class MPIMatrixMult(MPILinearOperator): representing the local columns of the input matrix. 2. **Local Adjoint Computation**: Each process computes - ``A_local.H @ X_tile`` where ``A_local.H`` is either i) Pre-computed - and stored in ``At`` (if ``saveAt=True``), ii) computed on-the-fly as + ``A_local.H @ X_tile`` where ``A_local.H`` is either pre-computed + and stored in ``At`` (if ``saveAt=True``), or computed on-the-fly as ``A.T.conj()`` (if ``saveAt=False``). Each process multiplies its - transposed local ``A`` block ``A_local^H`` (shape ``K x N_block``) + transposed local :math:`\mathbf{A}` block ``A_local^H`` (shape ``K x N_block``) with the extracted ``X_tile`` (shape ``N_block x M_local``), producing a partial result of shape ``(K, M_local)``. This computes the local contribution of columns of ``A^H`` to the final @@ -163,81 +311,23 @@ def __init__( shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims))) super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm) - @staticmethod - def active_grid_comm(base_comm: MPI.Comm, N: int, M: int): - r"""Configure active grid - - Configure a square process grid from a parent MPI communicator and - select a subset of "active" processes. Each process in ``base_comm`` - is assigned to a logical 2D grid of size :math:`P' \times P'`, - where :math:`P' = \bigl \lceil \sqrt{P} \bigr \rceil`. Only the first - :math:`active_dim x active_dim` processes - (by row-major order) are considered "active". Inactive ranks return - immediately with no new communicator. - - Parameters: - ----------- - base_comm : :obj:`mpi4py.MPI.Comm` - MPI Parent Communicator. (e.g., ``mpi4py.MPI.COMM_WORLD``). - N : :obj:`int` - Number of rows of the global data domain. - M : :obj:`int` - Number of columns of the global data domain. - - Returns: - -------- - comm : :obj:`mpi4py.MPI.Comm` - Sub-communicator including only active ranks. - rank : :obj:`int` - Rank within the new sub-communicator (or original rank - if inactive). - row : :obj:`int` - Grid row index of this process in the active grid (or original rank - if inactive). - col : :obj:`int` - Grid column index of this process in the active grid - (or original rank if inactive). - is_active : :obj:`bool` - Flag indicating whether this rank is in the active sub-grid. 
- - """ - rank = base_comm.Get_rank() - size = base_comm.Get_size() - p_prime = math.isqrt(size) - row, col = divmod(rank, p_prime) - active_dim = min(N, M, p_prime) - is_active = (row < active_dim and col < active_dim) - - if not is_active: - return None, rank, row, col, False - - active_ranks = [r for r in range(size) - if (r // p_prime) < active_dim and (r % p_prime) < active_dim] - new_group = base_comm.Get_group().Incl(active_ranks) - new_comm = base_comm.Create_group(new_group) - p_prime_new = math.isqrt(len(active_ranks)) - new_rank = new_comm.Get_rank() - new_row, new_col = divmod(new_rank, p_prime_new) - - return new_comm, new_rank, new_row, new_col, True - def _matvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) if x.partition != Partition.SCATTER: raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...") - + output_dtype = np.result_type(self.dtype, x.dtype) y = DistributedArray( global_shape=(self.N * self.dimsd[1]), local_shapes=[(self.N * c) for c in self._rank_col_lens], mask=x.mask, partition=Partition.SCATTER, - dtype=self.dtype, + dtype=output_dtype, base_comm=self.base_comm ) my_own_cols = self._rank_col_lens[self.rank] x_arr = x.local_array.reshape((self.dims[0], my_own_cols)) - X_local = x_arr.astype(self.dtype) + X_local = x_arr.astype(output_dtype) Y_local = ncp.vstack( self._row_comm.allgather( ncp.matmul(self.A, X_local) @@ -251,19 +341,418 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: if x.partition != Partition.SCATTER: raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.") + # - If A is real: A^H = A^T, + # so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real) + # - If A is complex: A^H is complex, + # so result will be complex regardless of x + if np.iscomplexobj(self.A): + output_dtype = np.result_type(self.dtype, x.dtype) + else: + # Real matrix: A^T @ x preserves input type complexity + output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype + # But still need to check type promotion for precision + output_dtype = np.result_type(self.dtype, output_dtype) + y = DistributedArray( global_shape=(self.K * self.dimsd[1]), local_shapes=[self.K * c for c in self._rank_col_lens], mask=x.mask, partition=Partition.SCATTER, - dtype=self.dtype, + dtype=output_dtype, base_comm=self.base_comm ) - x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype) + x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(output_dtype) X_tile = x_arr[self._row_start:self._row_end, :] A_local = self.At if hasattr(self, "At") else self.A.T.conj() Y_local = ncp.matmul(A_local, X_tile) y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM) y[:] = y_layer.flatten() return y + +class _MPISummaMatrixMult(MPILinearOperator): + r"""MPI SUMMA Matrix multiplication + + Implements distributed matrix-matrix multiplication using the SUMMA algorithm + between a matrix :math:`\mathbf{A}` distributed over a 2D process grid and + input model and data vectors, which are both interpreted as matrices + distributed in block fashion wherein each process owns a tile of the matrix. + + Parameters + ---------- + A : :obj:`numpy.ndarray` + Local block of the matrix of shape :math:`[N_{loc} \times K_{loc}]` + where :math:`N_{loc}` and :math:`K_{loc}` are the number of rows and + columns stored on this MPI rank. 
+    M : :obj:`int`
+        Global number of columns of the matrices representing the input model
+        and data vectors.
+    saveAt : :obj:`bool`, optional
+        Save :math:`\mathbf{A}` and ``A.H`` to speed up the computation of adjoint
+        (``True``) or create ``A.H`` on-the-fly (``False``).
+        Note that ``saveAt=True`` will double the amount of required memory.
+        Default is ``False``.
+    base_comm : :obj:`mpi4py.MPI.Comm`, optional
+        MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
+    dtype : :obj:`str`, optional
+        Type of elements in input array.
+
+    Attributes
+    ----------
+    shape : :obj:`tuple`
+        Operator shape
+
+    Raises
+    ------
+    Exception
+        If the operator is created with a non-square number of MPI ranks.
+    ValueError
+        If input vector does not have the correct partition type.
+
+    Notes
+    -----
+    This operator performs distributed matrix-matrix multiplication using the
+    SUMMA (Scalable Universal Matrix Multiplication Algorithm), whose forward
+    operation can be described as :math:`\mathbf{Y} = \mathbf{A} \cdot \mathbf{X}` where:
+
+    - :math:`\mathbf{A}` is the distributed matrix operator of shape :math:`[N \times K]`
+    - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]`
+    - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]`
+
+    The adjoint operation is represented by
+    :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}` where
+    :math:`\mathbf{A}^H` is the complex conjugate transpose of :math:`\mathbf{A}`.
+
+    This implementation is based on a 2D block distribution across a square process
+    grid (:math:`\sqrt{P}\times\sqrt{P}`). The matrices are distributed as follows:
+
+    - The matrix :math:`\mathbf{A}` is distributed across MPI processes in 2D blocks where
+      each process holds a local block of :math:`\mathbf{A}` with shape :math:`[N_{loc} \times K_{loc}]`
+      where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`K_{loc} = \frac{K}{\sqrt{P}}`.
+
+    - The operand matrix :math:`\mathbf{X}` is also distributed across MPI processes in 2D blocks where
+      each process holds a local block of :math:`\mathbf{X}` with shape :math:`[K_{loc} \times M_{loc}]`
+      where :math:`K_{loc} = \frac{K}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`.
+
+    - The result matrix :math:`\mathbf{Y}` is also distributed across MPI processes in 2D blocks where
+      each process holds a local block of :math:`\mathbf{Y}` with shape :math:`[N_{loc} \times M_{loc}]`
+      where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`.
+
+
+    **Forward Operation (SUMMA Algorithm)**
+
+    The forward operation implements the SUMMA algorithm:
+
+    1. **Input Preparation**: The input vector ``x`` is reshaped to ``(K_{loc}, M_{loc})`` representing
+       the local block assigned to the current process.
+
+    2. **SUMMA Iteration**: For each step ``k`` in the SUMMA algorithm, with :math:`k \in [0, \sqrt{P})`:
+
+       a. **Broadcast A blocks**: Process in column ``k`` broadcasts its :math:`\mathbf{A}`
+          block to all other processes in the same process row.
+
+       b. **Broadcast X blocks**: Process in row ``k`` broadcasts its :math:`\mathbf{X}`
+          block to all other processes in the same process column.
+
+       c. **Local Computation**: Each process computes the partial matrix
+          product ``A_broadcast @ X_broadcast`` and accumulates it to its
+          local result.
+
+    3. **Result Assembly**: After all :math:`\sqrt{P}` SUMMA iterations, each process has computed
+       its local block of the result matrix :math:`\mathbf{Y}`.
+ + **Adjoint Operation (SUMMA Algorithm)** + + The adjoint operation performs the conjugate transpose multiplication using + a modified SUMMA algorithm: + + 1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(N_{loc}, M_{loc})`` + representing the local block of the input matrix. + + 2. **SUMMA Adjoint Iteration**: For each step ``k`` in the adjoint SUMMA algorithm: + + a. **Broadcast A^H blocks**: The conjugate transpose of :math:`\mathbf{A}` blocks is + communicated between processes. If ``saveAt=True``, the pre-computed + ``A.H`` is used; otherwise, ``A.T.conj()`` is computed on-the-fly. + + b. **Broadcast Y blocks**: Process in row ``k`` broadcasts its :math:`\mathbf{Y}` + block to all other processes in the same process column. + + c. **Local Adjoint Computation**: Each process computes the partial + matrix product ``A_H_broadcast @ Y_broadcast`` and accumulates it + to the local result. + + 3. **Result Assembly**: After all adjoint SUMMA iterations, each process has + computed its local block of the result matrix ``X_{adj}``. + + The implementation handles padding automatically to ensure proper block sizes + for the square process grid, and unpadding is performed before returning results. + + """ + def __init__( + self, + A: NDArray, + M: int, + saveAt: bool = False, + base_comm: MPI.Comm = MPI.COMM_WORLD, + dtype: DTypeLike = "float64", + ) -> None: + rank = base_comm.Get_rank() + size = base_comm.Get_size() + + # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size + self._P_prime = math.isqrt(size) + if self._P_prime * self._P_prime != size: + raise Exception(f"Number of processes must be a square number, provided {size} instead...") + + self._row_id, self._col_id = divmod(rank, self._P_prime) + + self.base_comm = base_comm + self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id) + self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id) + + self.A = A.astype(np.dtype(dtype)) + + self.N = self._col_comm.allreduce(A.shape[0]) + self.K = self._row_comm.allreduce(A.shape[1]) + self.M = M + + self._N_padded = math.ceil(self.N / self._P_prime) * self._P_prime + self._K_padded = math.ceil(self.K / self._P_prime) * self._P_prime + self._M_padded = math.ceil(self.M / self._P_prime) * self._P_prime + + bn = self._N_padded // self._P_prime + bk = self._K_padded // self._P_prime + bm = self._M_padded // self._P_prime + + pr = (bn - A.shape[0]) if self._row_id == self._P_prime - 1 else 0 + pc = (bk - A.shape[1]) if self._col_id == self._P_prime - 1 else 0 + + if pr > 0 or pc > 0: + self.A = np.pad(self.A, [(0, pr), (0, pc)], mode='constant') + + if saveAt: + self.At = self.A.T.conj() + + self.dims = (self.K, self.M) + self.dimsd = (self.N, self.M) + shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims))) + super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm) + + def _matvec(self, x: DistributedArray) -> DistributedArray: + ncp = get_module(x.engine) + if x.partition != Partition.SCATTER: + raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...") + + output_dtype = np.result_type(self.dtype, x.dtype) + # Calculate local shapes for block distribution + bn = self._N_padded // self._P_prime # block size in N dimension + bm = self._M_padded // self._P_prime # block size in M dimension + + local_n = bn if self._row_id != self._P_prime - 1 else self.N - (self._P_prime - 1) * bn + local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm + + 
local_shapes = self.base_comm.allgather(local_n * local_m) + + y = DistributedArray(global_shape=(self.N * self.M), + mask=x.mask, + local_shapes=local_shapes, + partition=Partition.SCATTER, + dtype=output_dtype, + base_comm=self.base_comm) + + # Calculate expected padded dimensions for x + bk = self._K_padded // self._P_prime # block size in K dimension + + # The input x corresponds to blocks from matrix B (K x M) + # This process should receive a block of size (local_k x local_m) + local_k = bk if self._row_id != self._P_prime - 1 else self.K - (self._P_prime - 1) * bk + + # Reshape x.local_array to its 2D block form + x_block = x.local_array.reshape((local_k, local_m)) + + # Pad the block to the full padded size if necessary + pad_k = bk - local_k + pad_m = bm - local_m + + if pad_k > 0 or pad_m > 0: + x_block = ncp.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant') + + Y_local = ncp.zeros((self.A.shape[0], bm),dtype=output_dtype) + + for k in range(self._P_prime): + Atemp = self.A.copy() if self._col_id == k else ncp.empty_like(self.A) + Xtemp = x_block.copy() if self._row_id == k else ncp.empty_like(x_block) + self._row_comm.Bcast(Atemp, root=k) + self._col_comm.Bcast(Xtemp, root=k) + Y_local += ncp.dot(Atemp, Xtemp) + + Y_local_unpadded = Y_local[:local_n, :local_m] + y[:] = Y_local_unpadded.flatten() + return y + + def _rmatvec(self, x: DistributedArray) -> DistributedArray: + ncp = get_module(x.engine) + if x.partition != Partition.SCATTER: + raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.") + + # Calculate local shapes for block distribution + bk = self._K_padded // self._P_prime # block size in K dimension + bm = self._M_padded // self._P_prime # block size in M dimension + + # Calculate actual local shape for this process (considering original dimensions) + # Adjust for edge/corner processes that might have smaller blocks + local_k = bk if self._row_id != self._P_prime - 1 else self.K - (self._P_prime - 1) * bk + local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm + + local_shapes = self.base_comm.allgather(local_k * local_m) + # - If A is real: A^H = A^T, + # so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real) + # - If A is complex: A^H is complex, + # so result will be complex regardless of x + if np.iscomplexobj(self.A): + output_dtype = np.result_type(self.dtype, x.dtype) + else: + # Real matrix: A^T @ x preserves input type complexity + output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype + # But still need to check type promotion for precision + output_dtype = np.result_type(self.dtype, output_dtype) + + y = DistributedArray( + global_shape=(self.K * self.M), + mask=x.mask, + local_shapes=local_shapes, + partition=Partition.SCATTER, + dtype=output_dtype, + base_comm=self.base_comm + ) + + # Calculate expected padded dimensions for x + bn = self._N_padded // self._P_prime # block size in N dimension + + # The input x corresponds to blocks from the result (N x M) + # This process should receive a block of size (local_n x local_m) + local_n = bn if self._row_id != self._P_prime - 1 else self.N - (self._P_prime - 1) * bn + + # Reshape x.local_array to its 2D block form + x_block = x.local_array.reshape((local_n, local_m)) + + # Pad the block to the full padded size if necessary + pad_n = bn - local_n + pad_m = bm - local_m + + if pad_n > 0 or pad_m > 0: + x_block = ncp.pad(x_block, [(0, pad_n), (0, pad_m)], mode='constant') + + A_local = 
self.At if hasattr(self, "At") else self.A.T.conj() + Y_local = ncp.zeros((self.A.shape[1], bm), dtype=output_dtype) + + for k in range(self._P_prime): + requests = [] + ATtemp = ncp.empty_like(A_local) + srcA = k * self._P_prime + self._row_id + tagA = (100 + k) * 1000 + self.rank + requests.append(self.base_comm.Irecv(ATtemp, source=srcA, tag=tagA)) + if self._row_id == k: + fixed_col = self._col_id + for moving_col in range(self._P_prime): + destA = fixed_col * self._P_prime + moving_col + tagA = (100 + k) * 1000 + destA + requests.append(self.base_comm.Isend(A_local, dest=destA, tag=tagA)) + Xtemp = x_block.copy() if self._row_id == k else ncp.empty_like(x_block) + requests.append(self._col_comm.Ibcast(Xtemp, root=k)) + MPI.Request.Waitall(requests) + Y_local += ncp.dot(ATtemp, Xtemp) + + Y_local_unpadded = Y_local[:local_k, :local_m] + y[:] = Y_local_unpadded.flatten() + return y + +def MPIMatrixMult( + A: NDArray, + M: int, + saveAt: bool = False, + base_comm: MPI.Comm = MPI.COMM_WORLD, + kind: Literal["summa", "block"] = "summa", + dtype: DTypeLike = "float64", + ): + r""" + MPI Distributed Matrix Multiplication Operator + + This general operator performs distributed matrix-matrix multiplication + using either the SUMMA (Scalable Universal Matrix Multiplication Algorithm) + or a 1D block-row decomposition algorithm, depending on the specified + ``kind`` parameter. + + The forward operation computes:: + :math:`\mathbf{Y} = \mathbf{A} \cdot \mathbf{X}` + + where: + - :math:`\mathbf{A}` is the distributed operator matrix of shape :math:`[N \times K]` + - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]` + - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]` + + The adjoint (conjugate-transpose) operation computes:: + :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}` + + where :math:`\mathbf{A}^H` is the complex-conjugate transpose of :math:`\mathbf{A}`. + + Distribution Layouts + -------------------- + :summa: + 2D block-grid distribution over a square process grid :math:`[\sqrt{P} \times \sqrt{P}]`: + - :math:`\mathbf{A}` and :math:`\mathbf{X}` are partitioned into :math:`[N_loc \times K_loc]` and + :math:`[K_loc \times M_loc]` tiles on each rank, respectively. + - Each SUMMA iteration broadcasts row- and column-blocks of :math:`\mathbf{A}` and + :math:`\mathbf{X}` and accumulates local partial products. + + :block: + 1D block-row distribution over a :math:`[1 \times P]` grid: + - :math:`\mathbf{A}` is partitioned into :math:`[N_loc \times K]` blocks across ranks. + - :math:`\mathbf{X}` (and result :math:`\mathbf{Y}`) are partitioned into :math:`[K \times M_loc]` blocks. + - Local multiplication is followed by row-wise gather (forward) or + allreduce (adjoint) across ranks. + + Parameters + ---------- + A : NDArray + Local block of the matrix operator. + M : int + Global number of columns in the operand and result matrices. + saveAt : bool, optional + If ``True``, store both :math:`\mathbf{A}` and its conjugate transpose :math:`\mathbf{A}^H` + to accelerate adjoint operations (uses twice the memory). + Default is ``False``. + base_comm : mpi4py.MPI.Comm, optional + MPI communicator to use. Defaults to ``MPI.COMM_WORLD``. + kind : {'summa', 'block'}, optional + Algorithm to use: ``'summa'`` for the SUMMA 2D algorithm, or + ``'block'`` for the block-row-col decomposition. Default is ``'summa'``. + dtype : DTypeLike, optional + Numeric data type for computations. Default is ``np.float64``. 
+ + Attributes + ---------- + shape : :obj:`tuple` + Operator shape + comm : mpi4py.MPI.Comm + The MPI communicator in use. + kind : str + Selected distributed matrix multiply algorithm ('summa' or 'block'). + + Raises + ------ + NotImplementedError + If ``kind`` is not one of ``'summa'`` or ``'block'``. + Exception + If the MPI communicator does not form a compatible grid for the + selected algorithm. + """ + if kind == "summa": + return _MPISummaMatrixMult(A,M,saveAt,base_comm,dtype) + elif kind == "block": + return _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype) + else: + raise NotImplementedError("kind must be summa or block") + +__all__ = ["active_grid_comm", "block_gather", "local_block_spit", "MPIMatrixMult"] diff --git a/tests/test_matrixmult.py b/tests/test_matrixmult.py index 7def7807..b659a979 100644 --- a/tests/test_matrixmult.py +++ b/tests/test_matrixmult.py @@ -8,9 +8,10 @@ from mpi4py import MPI import pytest -from pylops.basicoperators import FirstDerivative, Identity +from pylops.basicoperators import Conj, FirstDerivative, Identity from pylops_mpi import DistributedArray, Partition -from pylops_mpi.basicoperators import MPIMatrixMult, MPIBlockDiag +from pylops_mpi.basicoperators import MPIBlockDiag, MPIMatrixMult, local_block_spit, block_gather, \ + active_grid_comm np.random.seed(42) base_comm = MPI.COMM_WORLD @@ -55,8 +56,7 @@ def test_MPIMatrixMult(N, K, M, dtype_str): cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0 base_float_dtype = np.float32 if dtype == np.complex64 else np.float64 - comm, rank, row_id, col_id, is_active = \ - MPIMatrixMult.active_grid_comm(base_comm, N, M) + comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M) if not is_active: return size = comm.Get_size() @@ -86,7 +86,7 @@ def test_MPIMatrixMult(N, K, M, dtype_str): X_p = X_glob[:, col_start_X:col_end_X] # Create MPIMatrixMult operator - Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype=dtype_str) + Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype=dtype_str, kind="block") # Create DistributedArray for input x (representing B flattened) all_local_col_len = comm.allgather(local_col_X_len) @@ -160,3 +160,102 @@ def test_MPIMatrixMult(N, K, M, dtype_str): rtol=np.finfo(np.dtype(dtype)).resolution, err_msg=f"Rank {rank}: Adjoint verification failed." 
) + +@pytest.mark.mpi(min_size=1) +@pytest.mark.parametrize("N, K, M, dtype_str", test_params) +def test_MPISummaMatrixMult(N, K, M, dtype_str): + dtype = np.dtype(dtype_str) + + cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0 + base_float_dtype = np.float32 if dtype == np.complex64 else np.float64 + + comm, rank, row_id, col_id, is_active = \ + active_grid_comm(base_comm, N, M) + if not is_active: return + + size = comm.Get_size() + + # Fill local matrices + A_glob_real = np.arange(N * K, dtype=base_float_dtype).reshape(N, K) + A_glob_imag = np.arange(N * K, dtype=base_float_dtype).reshape(N, K) * 0.5 + A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype) + + X_glob_real = np.arange(K * M, dtype=base_float_dtype).reshape(K, M) + X_glob_imag = np.arange(K * M, dtype=base_float_dtype).reshape(K, M) * 0.7 + X_glob = (X_glob_real + cmplx * X_glob_imag).astype(dtype) + + A_slice = local_block_spit((N, K), rank, comm) + X_slice = local_block_spit((K, M), rank, comm) + + A_p = A_glob[A_slice] + X_p = X_glob[X_slice] + + # Create MPIMatrixMult operator + Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype=dtype_str, kind="summa") + + x_dist = DistributedArray( + global_shape=(K * M), + local_shapes=comm.allgather(X_p.shape[0] * X_p.shape[1]), + partition=Partition.SCATTER, + base_comm=comm, + dtype=dtype + ) + + x_dist.local_array[:] = X_p.ravel() + + # Forward operation: y = A @ x (distributed) + y_dist = Aop @ x_dist + + # Adjoint operation: xadj = A.H @ y (distributed) + xadj_dist = Aop.H @ y_dist + + # Re-organize in local matrix + y = block_gather(y_dist, (N,M), comm) + xadj = block_gather(xadj_dist, (K,M), comm) + + if rank == 0: + y_loc = A_glob @ X_glob + assert_allclose( + y.squeeze(), + y_loc.squeeze(), + rtol=np.finfo(np.dtype(dtype)).resolution, + err_msg=f"Rank {rank}: Forward verification failed." + ) + + xadj_loc = A_glob.conj().T @ y_loc + assert_allclose( + xadj.squeeze(), + xadj_loc.squeeze(), + rtol=np.finfo(np.dtype(dtype)).resolution, + err_msg=f"Rank {rank}: Adjoint verification failed." + ) + + # Chain with another operator + Dop = Conj(dims=(A_p.shape[0], X_p.shape[1])) + DBop = MPIBlockDiag(ops=[Dop,], base_comm=comm) + Op = DBop @ Aop + + y1_dist = Op @ x_dist + xadj1_dist = Op.H @ y1_dist + + # Re-organize in local matrix + y1 = block_gather(y1_dist, (N, M), comm) + xadj1 = block_gather(xadj1_dist, (K,M), comm) + + if rank == 0: + y1_loc = ((A_glob @ X_glob).conj().ravel()).reshape(N, M) + + assert_allclose( + y1.squeeze(), + y1_loc.squeeze(), + rtol=np.finfo(y1_loc.dtype).resolution, + err_msg=f"Rank {rank}: Forward verification failed." + ) + + xadj1_loc = ((A_glob.conj().T @ y1_loc.conj()).ravel()).reshape(K, M) + assert_allclose( + xadj1.squeeze().ravel(), + xadj1_loc.squeeze().ravel(), + rtol=np.finfo(xadj1_loc.dtype).resolution, + err_msg=f"Rank {rank}: Adjoint verification failed." + )
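
As a quick usage reference for the SUMMA path added in this patch, the following minimal sketch wires together the new helpers (``active_grid_comm``, ``local_block_spit``, ``block_gather``) with the ``kind="summa"`` flavour of ``MPIMatrixMult``. It mirrors ``examples/plot_summamatrixmult.py`` and ``tests/test_matrixmult.py`` rather than introducing any new API; it assumes a perfect-square number of ranks (e.g. ``mpiexec -n 4 python summa_check.py``, where the script name is arbitrary) and real-valued ``float64`` data.

import numpy as np
from mpi4py import MPI

import pylops_mpi
from pylops_mpi.basicoperators.MatrixMult import (
    MPIMatrixMult,
    active_grid_comm,
    block_gather,
    local_block_spit,
)

N, K, M = 6, 6, 6

# Restrict to a square grid of "active" ranks; inactive ranks simply exit
comm, rank, row_id, col_id, is_active = active_grid_comm(MPI.COMM_WORLD, N, M)
if not is_active:
    exit(0)

# Global operands, replicated on every rank only to extract the local blocks
# and to verify the distributed result on rank 0
A_glob = np.arange(N * K, dtype=np.float64).reshape(N, K)
X_glob = np.arange(K * M, dtype=np.float64).reshape(K, M)

# Local 2D blocks owned by this rank (block-block distribution)
A_loc = A_glob[local_block_spit((N, K), rank, comm)]
X_loc = X_glob[local_block_spit((K, M), rank, comm)]

# SUMMA-backed operator acting on a flattened, block-distributed X
Aop = MPIMatrixMult(A_loc, M, base_comm=comm, kind="summa", dtype="float64")
x = pylops_mpi.DistributedArray(
    global_shape=K * M,
    local_shapes=comm.allgather(X_loc.shape[0] * X_loc.shape[1]),
    partition=pylops_mpi.Partition.SCATTER,
    base_comm=comm,
    dtype=np.float64,
)
x[:] = X_loc.flatten()

# Forward and adjoint, then gather the distributed results for checking
y_dist = Aop @ x            # Y = A X, block-block distributed like X
xadj_dist = Aop.H @ y_dist  # Xadj = A^H Y
y = block_gather(y_dist, (N, M), comm)
xadj = block_gather(xadj_dist, (K, M), comm)
if rank == 0:
    np.testing.assert_allclose(y, A_glob @ X_glob, rtol=1e-10)
    np.testing.assert_allclose(xadj, A_glob.conj().T @ (A_glob @ X_glob), rtol=1e-10)

Because both the model and data vectors are flattened, block-distributed ``DistributedArray`` objects, the same pattern composes directly with other PyLops-MPI operators, as done with ``MPIBlockDiag`` in the example and test above.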