diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst
index 66e1a373..fca1c1cd 100644
--- a/docs/source/api/index.rst
+++ b/docs/source/api/index.rst
@@ -42,7 +42,7 @@ Basic Operators
.. autosummary::
:toctree: generated/
- MPIMatrixMult
+ MatrixMult.MPIMatrixMult
MPIBlockDiag
MPIStackedBlockDiag
MPIVStack
@@ -118,6 +118,16 @@ Utils
local_split
+.. currentmodule:: pylops_mpi.basicoperators.MatrixMult
+
+.. autosummary::
+ :toctree: generated/
+
+ block_gather
+ local_block_split
+ active_grid_comm
+
+
.. currentmodule:: pylops_mpi.utils.dottest
.. autosummary::
diff --git a/examples/plot_matrixmult.py b/examples/plot_matrixmult.py
index 9c7e2d35..082d924c 100644
--- a/examples/plot_matrixmult.py
+++ b/examples/plot_matrixmult.py
@@ -28,6 +28,7 @@
import pylops_mpi
from pylops_mpi import Partition
+from pylops_mpi.basicoperators.MatrixMult import active_grid_comm, MPIMatrixMult
plt.close("all")
@@ -88,8 +89,7 @@
# than the row or columm ranks.
base_comm = MPI.COMM_WORLD
-comm, rank, row_id, col_id, is_active = \
- pylops_mpi.MPIMatrixMult.active_grid_comm(base_comm, N, M)
+comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M)
print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
if not is_active: exit(0)
@@ -147,7 +147,7 @@
################################################################################
# We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
# operator and the input matrix :math:`\mathbf{X}`
-Aop = pylops_mpi.MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32")
+Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32", kind="block")
col_lens = comm.allgather(my_own_cols)
total_cols = np.sum(col_lens)
diff --git a/examples/plot_summamatrixmult.py b/examples/plot_summamatrixmult.py
new file mode 100644
index 00000000..dd3f0225
--- /dev/null
+++ b/examples/plot_summamatrixmult.py
@@ -0,0 +1,146 @@
+r"""
+Distributed SUMMA Matrix Multiplication
+=======================================
+This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
+operator with ``kind="summa"`` to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}`
+distributed in 2D blocks across a square process grid and matrices :math:`\mathbf{X}`
+and :math:`\mathbf{Y}` distributed in 2D blocks across the same grid. Similarly,
+the adjoint operation can be performed with a matrix :math:`\mathbf{Y}` distributed
+in the same fashion as matrix :math:`\mathbf{X}`.
+
+Note that whilst the different blocks of matrix :math:`\mathbf{A}` are directly
+stored in the operator on different ranks, the matrices :math:`\mathbf{X}` and
+:math:`\mathbf{Y}` are effectively represented by 1-D :py:class:`pylops_mpi.DistributedArray`
+objects where the different blocks are flattened and stored on different ranks.
+Note that to optimize communications, the ranks are organized in a square grid and
+blocks of :math:`\mathbf{A}` and :math:`\mathbf{X}` are systematically broadcast
+across different ranks during computation - see below for details.
+"""
+
+import math
+import numpy as np
+from mpi4py import MPI
+from matplotlib import pyplot as plt
+
+import pylops_mpi
+from pylops import Conj
+from pylops_mpi.basicoperators.MatrixMult import (local_block_split, MPIMatrixMult, active_grid_comm)
+
+plt.close("all")
+
+###############################################################################
+# We set the seed for reproducibility. Here the input matrices are filled with
+# sequential values so that results are easy to verify; in practical
+# applications, they will be filled with data appropriate to the use case.
+np.random.seed(42)
+
+
+N, M, K = 6, 6, 6
+A_shape, x_shape, y_shape = (N, K), (K, M), (N, M)
+
+
+base_comm = MPI.COMM_WORLD
+comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M)
+print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
+
+
+###############################################################################
+# We are now ready to create the input matrices for our distributed matrix
+# multiplication example. We need to set up:
+# - Matrix :math:`\mathbf{A}` of size :math:`N \times K` (the left operand)
+# - Matrix :math:`\mathbf{X}` of size :math:`K \times M` (the right operand)
+# - The result will be :math:`\mathbf{Y} = \mathbf{A} \mathbf{X}` of size :math:`N \times M`
+#
+# For distributed computation, we arrange processes in a square grid of size
+# :math:`P' \times P'` where :math:`P' = \sqrt{P}` and :math:`P` is the total
+# number of MPI processes. Each process will own a block of each matrix
+# according to this 2D grid layout.
+
+p_prime = math.isqrt(comm.Get_size())
+print(f"Process grid: {p_prime} x {p_prime} = {comm.Get_size()} processes")
+
+# Create global test matrices with sequential values for easy verification
+# Matrix A: Each element :math:`A_{i,j} = i \cdot K + j` (row-major ordering)
+# Matrix X: Each element :math:`X_{i,j} = i \cdot M + j`
+A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape)
+x_data = np.arange(int(x_shape[0] * x_shape[1])).reshape(x_shape)
+
+print(f"Global matrix A shape: {A_shape} (N={A_shape[0]}, K={A_shape[1]})")
+print(f"Global matrix X shape: {x_shape} (K={x_shape[0]}, M={x_shape[1]})")
+print(f"Expected Global result Y shape: ({A_shape[0]}, {x_shape[1]}) = (N, M)")
+
+################################################################################
+# Determine which block of each matrix this process should own.
+# The 2D block distribution ensures that:
+# - Process at grid position :math:`(i,j)` gets block :math:`\mathbf{A}[i_{start}:i_{end}, j_{start}:j_{end}]`
+# - Block sizes are approximately :math:`\lceil N/P' \rceil \times \lceil K/P' \rceil` with edge processes handling remainder
+#
+# For example, on a 2x2 process grid with 6x6 matrices, the blocks are::
+#
+# Matrix A (6x6): Matrix X (6x6):
+# ┌───────────┬───────────┐ ┌───────────┬───────────┐
+# │ 0 1 2 │ 3 4 5 │ │ 0 1 2 │ 3 4 5 │
+# │ 6 7 8 │ 9 10 11 │ │ 6 7 8 │ 9 10 11 │
+# │ 12 13 14 │ 15 16 17 │ │ 12 13 14 │ 15 16 17 │
+# ├───────────┼───────────┤ ├───────────┼───────────┤
+# │ 18 19 20 │ 21 22 23 │ │ 18 19 20 │ 21 22 23 │
+# │ 24 25 26 │ 27 28 29 │ │ 24 25 26 │ 27 28 29 │
+# │ 30 31 32 │ 33 34 35 │ │ 30 31 32 │ 33 34 35 │
+# └───────────┴───────────┘ └───────────┴───────────┘
+#
+# Process (0,0): A[0:3, 0:3], X[0:3, 0:3]
+# Process (0,1): A[0:3, 3:6], X[0:3, 3:6]
+# Process (1,0): A[3:6, 0:3], X[3:6, 0:3]
+# Process (1,1): A[3:6, 3:6], X[3:6, 3:6]
+#
+#
+
+A_slice = local_block_split(A_shape, rank, comm)
+x_slice = local_block_split(x_shape, rank, comm)
+
+################################################################################
+# Extract the local portion of each matrix for this process
+A_local = A_data[A_slice]
+x_local = x_data[x_slice]
+
+print(f"Process {rank}: A_local shape {A_local.shape}, X_local shape {x_local.shape}")
+print(f"Process {rank}: A slice {A_slice}, X slice {x_slice}")
+
+x_dist = pylops_mpi.DistributedArray(global_shape=(K * M),
+ local_shapes=comm.allgather(x_local.shape[0] * x_local.shape[1]),
+ base_comm=comm,
+ partition=pylops_mpi.Partition.SCATTER,
+ dtype=x_local.dtype)
+x_dist[:] = x_local.flatten()
+
+################################################################################
+# We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
+# operator. Since we have chosen a block-block distribution of the data, we
+# select the SUMMA algorithm by passing ``kind="summa"``.
+Aop = MPIMatrixMult(A_local, M, base_comm=comm, kind="summa", dtype=A_local.dtype)
+
+################################################################################
+# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which
+# effectively implements a distributed matrix-matrix multiplication
+# :math:`\mathbf{Y} = \mathbf{AX}`). Note :math:`\mathbf{Y}` is distributed in the same
+# way as the input :math:`\mathbf{X}` in a block-block fashion.
+y_dist = Aop @ x_dist
+
+###############################################################################
+# Next we apply the adjoint pass :math:`\mathbf{x}_{adj} = \mathbf{A}^H \mathbf{y}`
+# (which effectively implements a distributed SUMMA matrix-matrix multiplication
+# :math:`\mathbf{X}_{adj} = \mathbf{A}^H \mathbf{Y}`). Note that
+# :math:`\mathbf{X}_{adj}` is again distributed in the same way as the input
+# :math:`\mathbf{X}` in a block-block fashion.
+xadj_dist = Aop.H @ y_dist
+
+###############################################################################
+# Finally, we show that the SUMMA :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
+# operator can be combined with any other PyLops-MPI operator. We are going to
+# apply here a conjugate operator to the output of the matrix multiplication.
+Dop = Conj(dims=(A_local.shape[0], x_local.shape[1]))
+DBop = pylops_mpi.MPIBlockDiag(ops=[Dop, ], base_comm=comm)
+Op = DBop @ Aop
+y1 = Op @ x_dist
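+
+###############################################################################
+# Finally, to inspect the results on a single rank, we can gather the
+# distributed blocks of :math:`\mathbf{Y}` and :math:`\mathbf{X}_{adj}` back
+# into full matrices with
+# :py:func:`pylops_mpi.basicoperators.MatrixMult.block_gather`. This is only a
+# convenience for small examples; for large matrices one would normally keep
+# the results distributed.
+from pylops_mpi.basicoperators.MatrixMult import block_gather
+
+y = block_gather(y_dist, y_shape, comm)
+xadj = block_gather(xadj_dist, x_shape, comm)
+if rank == 0:
+    print("Y = A @ X:\n", y)
+    print("Xadj = A^H @ Y:\n", xadj)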
diff --git a/pylops_mpi/LinearOperator.py b/pylops_mpi/LinearOperator.py
index 49077325..7ad9f8c4 100644
--- a/pylops_mpi/LinearOperator.py
+++ b/pylops_mpi/LinearOperator.py
@@ -76,7 +76,6 @@ def matvec(self, x: DistributedArray) -> DistributedArray:
"""
M, N = self.shape
-
if x.global_shape != (N,):
raise ValueError("dimension mismatch")
diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py
index 39eda45e..8738470c 100644
--- a/pylops_mpi/basicoperators/MatrixMult.py
+++ b/pylops_mpi/basicoperators/MatrixMult.py
@@ -1,6 +1,9 @@
-import numpy as np
+
import math
+import numpy as np
+from typing import Literal, Tuple
from mpi4py import MPI
+
from pylops.utils.backend import get_module
from pylops.utils.typing import DTypeLike, NDArray
@@ -11,8 +14,153 @@
)
-class MPIMatrixMult(MPILinearOperator):
- r"""MPI Matrix multiplication
+def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
+ r"""Configure active grid for distributed matrix multiplication.
+
+ Configure a square process grid from a parent MPI communicator and
+ select a subset of "active" processes. Each process in ``base_comm``
+ is assigned to a logical 2D grid of size :math:`P' \times P'`,
+    where :math:`P' = \bigl\lfloor \sqrt{P} \bigr\rfloor` and ranks are assigned
+    in row-major order. Only processes whose grid row and column indices are
+    both smaller than :math:`\min(N, M, P')` are considered "active". Inactive ranks return
+ immediately with no new communicator.
+
+    Parameters
+    ----------
+ base_comm : :obj:`mpi4py.MPI.Comm`
+ MPI Parent Communicator. (e.g., ``mpi4py.MPI.COMM_WORLD``).
+ N : :obj:`int`
+ Number of rows of the global data domain.
+ M : :obj:`int`
+ Number of columns of the global data domain.
+
+    Returns
+    -------
+ comm : :obj:`mpi4py.MPI.Comm`
+ Sub-communicator including only active ranks.
+ rank : :obj:`int`
+ Rank within the new sub-communicator (or original rank
+ if inactive).
+ row : :obj:`int`
+        Grid row index of this process in the active grid (or the row index
+        in the parent grid if inactive).
+    col : :obj:`int`
+        Grid column index of this process in the active grid (or the column
+        index in the parent grid if inactive).
+ is_active : :obj:`bool`
+ Flag indicating whether this rank is in the active sub-grid.
+
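+    Examples
+    --------
+    A typical usage sketch (with ``N`` and ``M`` being the global number of
+    rows and columns of the data), mirroring the gallery examples, where ranks
+    outside the active grid simply stop::
+
+        comm, rank, row, col, is_active = active_grid_comm(MPI.COMM_WORLD, N, M)
+        if not is_active:
+            exit(0)
+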
+ """
+ rank = base_comm.Get_rank()
+ size = base_comm.Get_size()
+ p_prime = math.isqrt(size)
+ row, col = divmod(rank, p_prime)
+ active_dim = min(N, M, p_prime)
+ is_active = (row < active_dim and col < active_dim)
+
+ if not is_active:
+ return None, rank, row, col, False
+
+ active_ranks = [r for r in range(size)
+ if (r // p_prime) < active_dim and (r % p_prime) < active_dim]
+ new_group = base_comm.Get_group().Incl(active_ranks)
+ new_comm = base_comm.Create_group(new_group)
+ p_prime_new = math.isqrt(len(active_ranks))
+ new_rank = new_comm.Get_rank()
+ new_row, new_col = divmod(new_rank, p_prime_new)
+
+ return new_comm, new_rank, new_row, new_col, True
+
+
+def local_block_split(global_shape: Tuple[int, int],
+                      rank: int,
+                      comm: MPI.Comm) -> Tuple[slice, slice]:
+ r"""
+ Compute the local sub‐block of a 2D global array for a process in a square process grid.
+
+ Parameters
+ ----------
+ global_shape : Tuple[int, int]
+ Dimensions of the global 2D array (n_rows, n_cols).
+ rank : int
+ Rank of the MPI process in `comm` for which to get the owned block partition.
+ comm : MPI.Comm
+        MPI communicator whose total number of processes :math:`P`
+        must be a perfect square, i.e. :math:`P = P'^{2}` for some integer :math:`P'`.
+
+ Returns
+ -------
+ Tuple[slice, slice]
+ Two `slice` objects `(row_slice, col_slice)` indicating the sub‐block
+ of the global array owned by this rank.
+
+ Raises
+ ------
+ ValueError
+        If ``rank`` is out of range.
+ RuntimeError
+ If the number of processes participating in the provided communicator is not a perfect square.
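+
+    Examples
+    --------
+    Illustrative sketch for a ``comm`` with 4 ranks (a 2 x 2 process grid) and
+    a ``6 x 6`` global array; rank 1 sits at grid position (0, 1) and therefore
+    owns the top-right block::
+
+        local_block_split((6, 6), 1, comm)
+        # -> (slice(0, 3, None), slice(3, 6, None))
+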
+ """
+ size = comm.Get_size()
+ p_prime = math.isqrt(size)
+ if p_prime * p_prime != size:
+ raise RuntimeError(f"Number of processes must be a square number, provided {size} instead...")
+    if not (isinstance(rank, int) and 0 <= rank < size):
+        raise ValueError(f"rank must be an integer in [0, {size}), got {rank!r}")
+
+ pr, pc = divmod(rank, p_prime)
+ orig_r, orig_c = global_shape
+ new_r = math.ceil(orig_r / p_prime) * p_prime
+ new_c = math.ceil(orig_c / p_prime) * p_prime
+ blkr, blkc = new_r // p_prime, new_c // p_prime
+ rs, cs = pr * blkr, pc * blkc
+ re, ce = min(rs + blkr, orig_r), min(cs + blkc, orig_c)
+ return slice(rs, re), slice(cs, ce)
+
+
+def block_gather(x: DistributedArray, orig_shape: Tuple[int, int], comm: MPI.Comm):
+ r"""
+    Gather the local blocks of a 2D block-distributed matrix, distributed
+    over a square process grid, into the full global array.
+
+ Parameters
+ ----------
+ x : :obj:`pylops_mpi.DistributedArray`
+ The distributed array to gather locally.
+ orig_shape : Tuple[int, int]
+ Original shape `(N, M)` of the global array to be gathered.
+ comm : MPI.Comm
+        MPI communicator whose size :math:`P` must be a perfect square (:math:`P = P'^{2}`).
+
+ Returns
+ -------
+    NDArray
+ The reconstructed 2D array of shape `orig_shape`, assembled from
+ the distributed blocks.
+
+ Raises
+ ------
+ RuntimeError
+ If the number of processes participating in the provided communicator is not a perfect square.
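+
+    Examples
+    --------
+    A sketch of the intended round trip with :func:`local_block_split`, where
+    ``x`` is a :obj:`pylops_mpi.DistributedArray` holding this rank's flattened
+    block of a ``6 x 6`` global array ``X``::
+
+        X_rec = block_gather(x, (6, 6), comm)   # full X is returned on every rank
+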
+ """
+ ncp = get_module(x.engine)
+ p_prime = math.isqrt(comm.Get_size())
+ if p_prime * p_prime != comm.Get_size():
+ raise RuntimeError(f"Communicator size must be a perfect square, got {comm.Get_size()!r}")
+
+ all_blks = comm.allgather(x.local_array)
+ nr, nc = orig_shape
+ br, bc = math.ceil(nr / p_prime), math.ceil(nc / p_prime)
+ C = ncp.zeros((nr, nc), dtype=all_blks[0].dtype)
+ for rank in range(p_prime * p_prime):
+ pr, pc = divmod(rank, p_prime)
+ rs, cs = pr * br, pc * bc
+ re, ce = min(rs + br, nr), min(cs + bc, nc)
+        C[rs:re, cs:ce] = all_blks[rank].reshape(re - rs, ce - cs)
+ return C
+
+class _MPIBlockMatrixMult(MPILinearOperator):
+ r"""MPI Blocked Matrix multiplication
Implement distributed matrix-matrix multiplication between a matrix
:math:`\mathbf{A}` blocked over rows (i.e., blocks of rows are stored
@@ -29,7 +177,7 @@ class MPIMatrixMult(MPILinearOperator):
Global leading dimension (i.e., number of columns) of the matrices
representing the input model and data vectors.
saveAt : :obj:`bool`, optional
- Save ``A`` and ``A.H`` to speed up the computation of adjoint
+ Save :math:`\mathbf{A}` and ``A.H`` to speed up the computation of adjoint
(``True``) or create ``A.H`` on-the-fly (``False``)
Note that ``saveAt=True`` will double the amount of required memory.
Default is ``False``.
@@ -68,22 +216,22 @@ class MPIMatrixMult(MPILinearOperator):
processes by a factor equivalent to :math:`\sqrt{P}` across a square process
grid (:math:`\sqrt{P}\times\sqrt{P}`). More specifically:
- - The matrix ``A`` is distributed across MPI processes in a block-row fashion
- and each process holds a local block of ``A`` with shape
+ - The matrix :math:`\mathbf{A}` is distributed across MPI processes in a block-row fashion
+ and each process holds a local block of :math:`\mathbf{A}` with shape
:math:`[N_{loc} \times K]`
- - The operand matrix ``X`` is distributed in a block-column fashion and
- each process holds a local block of ``X`` with shape
+ - The operand matrix :math:`\mathbf{X}` is distributed in a block-column fashion and
+ each process holds a local block of :math:`\mathbf{X}` with shape
:math:`[K \times M_{loc}]`
- Communication is minimized by using a 2D process grid layout
**Forward Operation step-by-step**
- 1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X``
+ 1. **Input Preparation**: The input vector ``x`` (flattened from matrix :math:`\mathbf{X}`
of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local``
is the number of columns assigned to the current process.
2. **Local Computation**: Each process computes ``A_local @ X_local`` where:
- - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``)
+ - ``A_local`` is the local block of matrix :math:`\mathbf{A}` (shape ``N_local x K``)
- ``X_local`` is the broadcasted operand (shape ``K x M_local``)
3. **Row-wise Gather**: Results from all processes in each row are gathered
@@ -98,10 +246,10 @@ class MPIMatrixMult(MPILinearOperator):
representing the local columns of the input matrix.
2. **Local Adjoint Computation**: Each process computes
- ``A_local.H @ X_tile`` where ``A_local.H`` is either i) Pre-computed
- and stored in ``At`` (if ``saveAt=True``), ii) computed on-the-fly as
+ ``A_local.H @ X_tile`` where ``A_local.H`` is either pre-computed
+ and stored in ``At`` (if ``saveAt=True``), or computed on-the-fly as
``A.T.conj()`` (if ``saveAt=False``). Each process multiplies its
- transposed local ``A`` block ``A_local^H`` (shape ``K x N_block``)
+ transposed local :math:`\mathbf{A}` block ``A_local^H`` (shape ``K x N_block``)
with the extracted ``X_tile`` (shape ``N_block x M_local``),
producing a partial result of shape ``(K, M_local)``.
This computes the local contribution of columns of ``A^H`` to the final
@@ -163,81 +311,23 @@ def __init__(
shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
- @staticmethod
- def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
- r"""Configure active grid
-
- Configure a square process grid from a parent MPI communicator and
- select a subset of "active" processes. Each process in ``base_comm``
- is assigned to a logical 2D grid of size :math:`P' \times P'`,
- where :math:`P' = \bigl \lceil \sqrt{P} \bigr \rceil`. Only the first
- :math:`active_dim x active_dim` processes
- (by row-major order) are considered "active". Inactive ranks return
- immediately with no new communicator.
-
- Parameters:
- -----------
- base_comm : :obj:`mpi4py.MPI.Comm`
- MPI Parent Communicator. (e.g., ``mpi4py.MPI.COMM_WORLD``).
- N : :obj:`int`
- Number of rows of the global data domain.
- M : :obj:`int`
- Number of columns of the global data domain.
-
- Returns:
- --------
- comm : :obj:`mpi4py.MPI.Comm`
- Sub-communicator including only active ranks.
- rank : :obj:`int`
- Rank within the new sub-communicator (or original rank
- if inactive).
- row : :obj:`int`
- Grid row index of this process in the active grid (or original rank
- if inactive).
- col : :obj:`int`
- Grid column index of this process in the active grid
- (or original rank if inactive).
- is_active : :obj:`bool`
- Flag indicating whether this rank is in the active sub-grid.
-
- """
- rank = base_comm.Get_rank()
- size = base_comm.Get_size()
- p_prime = math.isqrt(size)
- row, col = divmod(rank, p_prime)
- active_dim = min(N, M, p_prime)
- is_active = (row < active_dim and col < active_dim)
-
- if not is_active:
- return None, rank, row, col, False
-
- active_ranks = [r for r in range(size)
- if (r // p_prime) < active_dim and (r % p_prime) < active_dim]
- new_group = base_comm.Get_group().Incl(active_ranks)
- new_comm = base_comm.Create_group(new_group)
- p_prime_new = math.isqrt(len(active_ranks))
- new_rank = new_comm.Get_rank()
- new_row, new_col = divmod(new_rank, p_prime_new)
-
- return new_comm, new_rank, new_row, new_col, True
-
def _matvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
if x.partition != Partition.SCATTER:
raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
-
+ output_dtype = np.result_type(self.dtype, x.dtype)
y = DistributedArray(
global_shape=(self.N * self.dimsd[1]),
local_shapes=[(self.N * c) for c in self._rank_col_lens],
mask=x.mask,
partition=Partition.SCATTER,
- dtype=self.dtype,
+ dtype=output_dtype,
base_comm=self.base_comm
)
my_own_cols = self._rank_col_lens[self.rank]
x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
- X_local = x_arr.astype(self.dtype)
+ X_local = x_arr.astype(output_dtype)
Y_local = ncp.vstack(
self._row_comm.allgather(
ncp.matmul(self.A, X_local)
@@ -251,19 +341,418 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
if x.partition != Partition.SCATTER:
raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
+ # - If A is real: A^H = A^T,
+ # so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real)
+ # - If A is complex: A^H is complex,
+ # so result will be complex regardless of x
+ if np.iscomplexobj(self.A):
+ output_dtype = np.result_type(self.dtype, x.dtype)
+ else:
+ # Real matrix: A^T @ x preserves input type complexity
+ output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype
+ # But still need to check type promotion for precision
+ output_dtype = np.result_type(self.dtype, output_dtype)
+
y = DistributedArray(
global_shape=(self.K * self.dimsd[1]),
local_shapes=[self.K * c for c in self._rank_col_lens],
mask=x.mask,
partition=Partition.SCATTER,
- dtype=self.dtype,
+ dtype=output_dtype,
base_comm=self.base_comm
)
- x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype)
+ x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(output_dtype)
X_tile = x_arr[self._row_start:self._row_end, :]
A_local = self.At if hasattr(self, "At") else self.A.T.conj()
Y_local = ncp.matmul(A_local, X_tile)
y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM)
y[:] = y_layer.flatten()
return y
+
+class _MPISummaMatrixMult(MPILinearOperator):
+ r"""MPI SUMMA Matrix multiplication
+
+ Implements distributed matrix-matrix multiplication using the SUMMA algorithm
+ between a matrix :math:`\mathbf{A}` distributed over a 2D process grid and
+ input model and data vectors, which are both interpreted as matrices
+ distributed in block fashion wherein each process owns a tile of the matrix.
+
+ Parameters
+ ----------
+ A : :obj:`numpy.ndarray`
+ Local block of the matrix of shape :math:`[N_{loc} \times K_{loc}]`
+ where :math:`N_{loc}` and :math:`K_{loc}` are the number of rows and
+ columns stored on this MPI rank.
+ M : :obj:`int`
+ Global number of columns of the matrices representing the input model
+ and data vectors.
+ saveAt : :obj:`bool`, optional
+ Save :math:`\mathbf{A}` and ``A.H`` to speed up the computation of adjoint
+ (``True``) or create ``A.H`` on-the-fly (``False``).
+ Note that ``saveAt=True`` will double the amount of required memory.
+ Default is ``False``.
+ base_comm : :obj:`mpi4py.MPI.Comm`, optional
+ MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
+ dtype : :obj:`str`, optional
+ Type of elements in input array.
+
+ Attributes
+ ----------
+ shape : :obj:`tuple`
+ Operator shape
+
+ Raises
+ ------
+ Exception
+ If the operator is created with a non-square number of MPI ranks.
+ ValueError
+ If input vector does not have the correct partition type.
+
+ Notes
+ -----
+ This operator performs distributed matrix-matrix multiplication using the
+ SUMMA (Scalable Universal Matrix Multiplication Algorithm), whose forward
+ operation can be described as :math:`\mathbf{Y} = \mathbf{A} \cdot \mathbf{X}` where:
+
+ - :math:`\mathbf{A}` is the distributed matrix operator of shape :math:`[N \times K]`
+ - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]`
+ - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]`
+
+ The adjoint operation is represented by
+ :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}` where
+ :math:`\mathbf{A}^H` is the complex conjugate transpose of :math:`\mathbf{A}`.
+
+ This implementation is based on a 2D block distribution across a square process
+ grid (:math:`\sqrt{P}\times\sqrt{P}`). The matrices are distributed as follows:
+
+ - The matrix :math:`\mathbf{A}` is distributed across MPI processes in 2D blocks where
+ each process holds a local block of :math:`\mathbf{A}` with shape :math:`[N_{loc} \times K_{loc}]`
+ where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`K_{loc} = \frac{K}{\sqrt{P}}`.
+
+ - The operand matrix :math:`\mathbf{X}` is also distributed across MPI processes in 2D blocks where
+ each process holds a local block of :math:`\mathbf{X}` with shape :math:`[K_{loc} \times M_{loc}]`
+ where :math:`K_{loc} = \frac{K}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`.
+
+ - The result matrix :math:`\mathbf{Y}` is also distributed across MPI processes in 2D blocks where
+ each process holds a local block of :math:`\mathbf{Y}` with shape :math:`[N_{loc} \times M_{loc}]`
+ where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`.
+
+
+ **Forward Operation (SUMMA Algorithm)**
+
+ The forward operation implements the SUMMA algorithm:
+
+    1. **Input Preparation**: The input vector ``x`` is reshaped to :math:`(K_{loc}, M_{loc})` representing
+ the local block assigned to the current process.
+
+    2. **SUMMA Iteration**: For each step :math:`k \in [0, \sqrt{P})` of the SUMMA algorithm:
+
+ a. **Broadcast A blocks**: Process in column ``k`` broadcasts its :math:`\mathbf{A}`
+ block to all other processes in the same process row.
+
+ b. **Broadcast X blocks**: Process in row ``k`` broadcasts its :math:`\mathbf{X}`
+ block to all other processes in the same process column.
+
+ c. **Local Computation**: Each process computes the partial matrix
+ product ``A_broadcast @ X_broadcast`` and accumulates it to its
+ local result.
+
+    3. **Result Assembly**: After all :math:`\sqrt{P}` SUMMA iterations, each process has computed
+ its local block of the result matrix :math:`\mathbf{Y}`.
+
+ **Adjoint Operation (SUMMA Algorithm)**
+
+ The adjoint operation performs the conjugate transpose multiplication using
+ a modified SUMMA algorithm:
+
+    1. **Input Reshaping**: The input vector ``x`` is reshaped to :math:`(N_{loc}, M_{loc})`
+ representing the local block of the input matrix.
+
+ 2. **SUMMA Adjoint Iteration**: For each step ``k`` in the adjoint SUMMA algorithm:
+
+ a. **Broadcast A^H blocks**: The conjugate transpose of :math:`\mathbf{A}` blocks is
+ communicated between processes. If ``saveAt=True``, the pre-computed
+ ``A.H`` is used; otherwise, ``A.T.conj()`` is computed on-the-fly.
+
+ b. **Broadcast Y blocks**: Process in row ``k`` broadcasts its :math:`\mathbf{Y}`
+ block to all other processes in the same process column.
+
+ c. **Local Adjoint Computation**: Each process computes the partial
+ matrix product ``A_H_broadcast @ Y_broadcast`` and accumulates it
+ to the local result.
+
+ 3. **Result Assembly**: After all adjoint SUMMA iterations, each process has
+       computed its local block of the result matrix :math:`\mathbf{X}_{adj}`.
+
+ The implementation handles padding automatically to ensure proper block sizes
+ for the square process grid, and unpadding is performed before returning results.
+
+ """
+ def __init__(
+ self,
+ A: NDArray,
+ M: int,
+ saveAt: bool = False,
+ base_comm: MPI.Comm = MPI.COMM_WORLD,
+ dtype: DTypeLike = "float64",
+ ) -> None:
+ rank = base_comm.Get_rank()
+ size = base_comm.Get_size()
+
+        # Determine the square process grid dimension (P_prime x P_prime); the number of ranks must be a perfect square
+ self._P_prime = math.isqrt(size)
+ if self._P_prime * self._P_prime != size:
+ raise Exception(f"Number of processes must be a square number, provided {size} instead...")
+
+ self._row_id, self._col_id = divmod(rank, self._P_prime)
+
+ self.base_comm = base_comm
+ self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id)
+ self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id)
+
+ self.A = A.astype(np.dtype(dtype))
+
+ self.N = self._col_comm.allreduce(A.shape[0])
+ self.K = self._row_comm.allreduce(A.shape[1])
+ self.M = M
+
+ self._N_padded = math.ceil(self.N / self._P_prime) * self._P_prime
+ self._K_padded = math.ceil(self.K / self._P_prime) * self._P_prime
+ self._M_padded = math.ceil(self.M / self._P_prime) * self._P_prime
+
+ bn = self._N_padded // self._P_prime
+ bk = self._K_padded // self._P_prime
+ bm = self._M_padded // self._P_prime
+
+ pr = (bn - A.shape[0]) if self._row_id == self._P_prime - 1 else 0
+ pc = (bk - A.shape[1]) if self._col_id == self._P_prime - 1 else 0
+
+ if pr > 0 or pc > 0:
+ self.A = np.pad(self.A, [(0, pr), (0, pc)], mode='constant')
+
+ if saveAt:
+ self.At = self.A.T.conj()
+
+ self.dims = (self.K, self.M)
+ self.dimsd = (self.N, self.M)
+ shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
+ super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
+
+ def _matvec(self, x: DistributedArray) -> DistributedArray:
+ ncp = get_module(x.engine)
+ if x.partition != Partition.SCATTER:
+ raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
+
+ output_dtype = np.result_type(self.dtype, x.dtype)
+ # Calculate local shapes for block distribution
+ bn = self._N_padded // self._P_prime # block size in N dimension
+ bm = self._M_padded // self._P_prime # block size in M dimension
+
+ local_n = bn if self._row_id != self._P_prime - 1 else self.N - (self._P_prime - 1) * bn
+ local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm
+
+ local_shapes = self.base_comm.allgather(local_n * local_m)
+
+ y = DistributedArray(global_shape=(self.N * self.M),
+ mask=x.mask,
+ local_shapes=local_shapes,
+ partition=Partition.SCATTER,
+ dtype=output_dtype,
+ base_comm=self.base_comm)
+
+ # Calculate expected padded dimensions for x
+ bk = self._K_padded // self._P_prime # block size in K dimension
+
+ # The input x corresponds to blocks from matrix B (K x M)
+ # This process should receive a block of size (local_k x local_m)
+ local_k = bk if self._row_id != self._P_prime - 1 else self.K - (self._P_prime - 1) * bk
+
+ # Reshape x.local_array to its 2D block form
+ x_block = x.local_array.reshape((local_k, local_m))
+
+ # Pad the block to the full padded size if necessary
+ pad_k = bk - local_k
+ pad_m = bm - local_m
+
+ if pad_k > 0 or pad_m > 0:
+ x_block = ncp.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant')
+
+        Y_local = ncp.zeros((self.A.shape[0], bm), dtype=output_dtype)
+
+ for k in range(self._P_prime):
+ Atemp = self.A.copy() if self._col_id == k else ncp.empty_like(self.A)
+ Xtemp = x_block.copy() if self._row_id == k else ncp.empty_like(x_block)
+ self._row_comm.Bcast(Atemp, root=k)
+ self._col_comm.Bcast(Xtemp, root=k)
+ Y_local += ncp.dot(Atemp, Xtemp)
+
+ Y_local_unpadded = Y_local[:local_n, :local_m]
+ y[:] = Y_local_unpadded.flatten()
+ return y
+
+ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
+ ncp = get_module(x.engine)
+ if x.partition != Partition.SCATTER:
+ raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
+
+ # Calculate local shapes for block distribution
+ bk = self._K_padded // self._P_prime # block size in K dimension
+ bm = self._M_padded // self._P_prime # block size in M dimension
+
+ # Calculate actual local shape for this process (considering original dimensions)
+ # Adjust for edge/corner processes that might have smaller blocks
+ local_k = bk if self._row_id != self._P_prime - 1 else self.K - (self._P_prime - 1) * bk
+ local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm
+
+ local_shapes = self.base_comm.allgather(local_k * local_m)
+ # - If A is real: A^H = A^T,
+ # so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real)
+ # - If A is complex: A^H is complex,
+ # so result will be complex regardless of x
+ if np.iscomplexobj(self.A):
+ output_dtype = np.result_type(self.dtype, x.dtype)
+ else:
+ # Real matrix: A^T @ x preserves input type complexity
+ output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype
+ # But still need to check type promotion for precision
+ output_dtype = np.result_type(self.dtype, output_dtype)
+
+ y = DistributedArray(
+ global_shape=(self.K * self.M),
+ mask=x.mask,
+ local_shapes=local_shapes,
+ partition=Partition.SCATTER,
+ dtype=output_dtype,
+ base_comm=self.base_comm
+ )
+
+ # Calculate expected padded dimensions for x
+ bn = self._N_padded // self._P_prime # block size in N dimension
+
+ # The input x corresponds to blocks from the result (N x M)
+ # This process should receive a block of size (local_n x local_m)
+ local_n = bn if self._row_id != self._P_prime - 1 else self.N - (self._P_prime - 1) * bn
+
+ # Reshape x.local_array to its 2D block form
+ x_block = x.local_array.reshape((local_n, local_m))
+
+ # Pad the block to the full padded size if necessary
+ pad_n = bn - local_n
+ pad_m = bm - local_m
+
+ if pad_n > 0 or pad_m > 0:
+ x_block = ncp.pad(x_block, [(0, pad_n), (0, pad_m)], mode='constant')
+
+ A_local = self.At if hasattr(self, "At") else self.A.T.conj()
+ Y_local = ncp.zeros((self.A.shape[1], bm), dtype=output_dtype)
+
+ for k in range(self._P_prime):
+ requests = []
+ ATtemp = ncp.empty_like(A_local)
+ srcA = k * self._P_prime + self._row_id
+ tagA = (100 + k) * 1000 + self.rank
+ requests.append(self.base_comm.Irecv(ATtemp, source=srcA, tag=tagA))
+ if self._row_id == k:
+ fixed_col = self._col_id
+ for moving_col in range(self._P_prime):
+ destA = fixed_col * self._P_prime + moving_col
+ tagA = (100 + k) * 1000 + destA
+ requests.append(self.base_comm.Isend(A_local, dest=destA, tag=tagA))
+ Xtemp = x_block.copy() if self._row_id == k else ncp.empty_like(x_block)
+ requests.append(self._col_comm.Ibcast(Xtemp, root=k))
+ MPI.Request.Waitall(requests)
+ Y_local += ncp.dot(ATtemp, Xtemp)
+
+ Y_local_unpadded = Y_local[:local_k, :local_m]
+ y[:] = Y_local_unpadded.flatten()
+ return y
+
+def MPIMatrixMult(
+ A: NDArray,
+ M: int,
+ saveAt: bool = False,
+ base_comm: MPI.Comm = MPI.COMM_WORLD,
+ kind: Literal["summa", "block"] = "summa",
+ dtype: DTypeLike = "float64",
+ ):
+ r"""
+ MPI Distributed Matrix Multiplication Operator
+
+ This general operator performs distributed matrix-matrix multiplication
+ using either the SUMMA (Scalable Universal Matrix Multiplication Algorithm)
+ or a 1D block-row decomposition algorithm, depending on the specified
+ ``kind`` parameter.
+
+    The forward operation computes
+    :math:`\mathbf{Y} = \mathbf{A} \cdot \mathbf{X}`
+
+ where:
+ - :math:`\mathbf{A}` is the distributed operator matrix of shape :math:`[N \times K]`
+ - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]`
+ - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]`
+
+    The adjoint (conjugate-transpose) operation computes
+    :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}`
+
+ where :math:`\mathbf{A}^H` is the complex-conjugate transpose of :math:`\mathbf{A}`.
+
+ Distribution Layouts
+ --------------------
+ :summa:
+ 2D block-grid distribution over a square process grid :math:`[\sqrt{P} \times \sqrt{P}]`:
+        - :math:`\mathbf{A}` and :math:`\mathbf{X}` are partitioned into :math:`[N_{loc} \times K_{loc}]` and
+          :math:`[K_{loc} \times M_{loc}]` tiles on each rank, respectively.
+ - Each SUMMA iteration broadcasts row- and column-blocks of :math:`\mathbf{A}` and
+ :math:`\mathbf{X}` and accumulates local partial products.
+
+ :block:
+ 1D block-row distribution over a :math:`[1 \times P]` grid:
+        - :math:`\mathbf{A}` is partitioned into :math:`[N_{loc} \times K]` blocks across ranks.
+        - :math:`\mathbf{X}` (and result :math:`\mathbf{Y}`) are partitioned into :math:`[K \times M_{loc}]` blocks.
+ - Local multiplication is followed by row-wise gather (forward) or
+ allreduce (adjoint) across ranks.
+
+ Parameters
+ ----------
+ A : NDArray
+ Local block of the matrix operator.
+ M : int
+ Global number of columns in the operand and result matrices.
+ saveAt : bool, optional
+ If ``True``, store both :math:`\mathbf{A}` and its conjugate transpose :math:`\mathbf{A}^H`
+ to accelerate adjoint operations (uses twice the memory).
+ Default is ``False``.
+ base_comm : mpi4py.MPI.Comm, optional
+ MPI communicator to use. Defaults to ``MPI.COMM_WORLD``.
+ kind : {'summa', 'block'}, optional
+ Algorithm to use: ``'summa'`` for the SUMMA 2D algorithm, or
+        ``'block'`` for the 1D block-row decomposition. Default is ``'summa'``.
+ dtype : DTypeLike, optional
+ Numeric data type for computations. Default is ``np.float64``.
+
+ Attributes
+ ----------
+ shape : :obj:`tuple`
+ Operator shape
+ comm : mpi4py.MPI.Comm
+ The MPI communicator in use.
+ kind : str
+ Selected distributed matrix multiply algorithm ('summa' or 'block').
+
+ Raises
+ ------
+ NotImplementedError
+ If ``kind`` is not one of ``'summa'`` or ``'block'``.
+ Exception
+ If the MPI communicator does not form a compatible grid for the
+ selected algorithm.
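+
+    Examples
+    --------
+    A minimal usage sketch (assuming ``A_local`` is this rank's block of
+    :math:`\mathbf{A}` and ``x_dist`` is a :py:class:`pylops_mpi.DistributedArray`
+    holding the flattened local block of :math:`\mathbf{X}`, as created in the
+    SUMMA gallery example)::
+
+        Aop = MPIMatrixMult(A_local, M, kind="summa", dtype="float32")
+        y_dist = Aop @ x_dist       # forward:  Y = A @ X
+        xadj_dist = Aop.H @ y_dist  # adjoint:  Xadj = A^H @ Y
+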
+ """
+ if kind == "summa":
+        return _MPISummaMatrixMult(A, M, saveAt, base_comm, dtype)
+ elif kind == "block":
+ return _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype)
+ else:
+ raise NotImplementedError("kind must be summa or block")
+
+__all__ = ["active_grid_comm", "block_gather", "local_block_split", "MPIMatrixMult"]
diff --git a/tests/test_matrixmult.py b/tests/test_matrixmult.py
index 7def7807..b659a979 100644
--- a/tests/test_matrixmult.py
+++ b/tests/test_matrixmult.py
@@ -8,9 +8,10 @@
from mpi4py import MPI
import pytest
-from pylops.basicoperators import FirstDerivative, Identity
+from pylops.basicoperators import Conj, FirstDerivative, Identity
from pylops_mpi import DistributedArray, Partition
-from pylops_mpi.basicoperators import MPIMatrixMult, MPIBlockDiag
+from pylops_mpi.basicoperators import MPIBlockDiag, MPIMatrixMult, local_block_split, block_gather, \
+ active_grid_comm
np.random.seed(42)
base_comm = MPI.COMM_WORLD
@@ -55,8 +56,7 @@ def test_MPIMatrixMult(N, K, M, dtype_str):
cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
base_float_dtype = np.float32 if dtype == np.complex64 else np.float64
- comm, rank, row_id, col_id, is_active = \
- MPIMatrixMult.active_grid_comm(base_comm, N, M)
+ comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M)
if not is_active: return
size = comm.Get_size()
@@ -86,7 +86,7 @@ def test_MPIMatrixMult(N, K, M, dtype_str):
X_p = X_glob[:, col_start_X:col_end_X]
# Create MPIMatrixMult operator
- Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype=dtype_str)
+ Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype=dtype_str, kind="block")
# Create DistributedArray for input x (representing B flattened)
all_local_col_len = comm.allgather(local_col_X_len)
@@ -160,3 +160,102 @@ def test_MPIMatrixMult(N, K, M, dtype_str):
rtol=np.finfo(np.dtype(dtype)).resolution,
err_msg=f"Rank {rank}: Adjoint verification failed."
)
+
+@pytest.mark.mpi(min_size=1)
+@pytest.mark.parametrize("N, K, M, dtype_str", test_params)
+def test_MPISummaMatrixMult(N, K, M, dtype_str):
+ dtype = np.dtype(dtype_str)
+
+ cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
+ base_float_dtype = np.float32 if dtype == np.complex64 else np.float64
+
+    comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M)
+ if not is_active: return
+
+ size = comm.Get_size()
+
+ # Fill local matrices
+ A_glob_real = np.arange(N * K, dtype=base_float_dtype).reshape(N, K)
+ A_glob_imag = np.arange(N * K, dtype=base_float_dtype).reshape(N, K) * 0.5
+ A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype)
+
+ X_glob_real = np.arange(K * M, dtype=base_float_dtype).reshape(K, M)
+ X_glob_imag = np.arange(K * M, dtype=base_float_dtype).reshape(K, M) * 0.7
+ X_glob = (X_glob_real + cmplx * X_glob_imag).astype(dtype)
+
+    A_slice = local_block_split((N, K), rank, comm)
+    X_slice = local_block_split((K, M), rank, comm)
+
+ A_p = A_glob[A_slice]
+ X_p = X_glob[X_slice]
+
+ # Create MPIMatrixMult operator
+ Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype=dtype_str, kind="summa")
+
+ x_dist = DistributedArray(
+ global_shape=(K * M),
+ local_shapes=comm.allgather(X_p.shape[0] * X_p.shape[1]),
+ partition=Partition.SCATTER,
+ base_comm=comm,
+ dtype=dtype
+ )
+
+ x_dist.local_array[:] = X_p.ravel()
+
+ # Forward operation: y = A @ x (distributed)
+ y_dist = Aop @ x_dist
+
+ # Adjoint operation: xadj = A.H @ y (distributed)
+ xadj_dist = Aop.H @ y_dist
+
+ # Re-organize in local matrix
+ y = block_gather(y_dist, (N,M), comm)
+ xadj = block_gather(xadj_dist, (K,M), comm)
+
+ if rank == 0:
+ y_loc = A_glob @ X_glob
+ assert_allclose(
+ y.squeeze(),
+ y_loc.squeeze(),
+ rtol=np.finfo(np.dtype(dtype)).resolution,
+ err_msg=f"Rank {rank}: Forward verification failed."
+ )
+
+ xadj_loc = A_glob.conj().T @ y_loc
+ assert_allclose(
+ xadj.squeeze(),
+ xadj_loc.squeeze(),
+ rtol=np.finfo(np.dtype(dtype)).resolution,
+ err_msg=f"Rank {rank}: Adjoint verification failed."
+ )
+
+ # Chain with another operator
+ Dop = Conj(dims=(A_p.shape[0], X_p.shape[1]))
+ DBop = MPIBlockDiag(ops=[Dop,], base_comm=comm)
+ Op = DBop @ Aop
+
+ y1_dist = Op @ x_dist
+ xadj1_dist = Op.H @ y1_dist
+
+ # Re-organize in local matrix
+ y1 = block_gather(y1_dist, (N, M), comm)
+ xadj1 = block_gather(xadj1_dist, (K,M), comm)
+
+ if rank == 0:
+ y1_loc = ((A_glob @ X_glob).conj().ravel()).reshape(N, M)
+
+ assert_allclose(
+ y1.squeeze(),
+ y1_loc.squeeze(),
+ rtol=np.finfo(y1_loc.dtype).resolution,
+ err_msg=f"Rank {rank}: Forward verification failed."
+ )
+
+ xadj1_loc = ((A_glob.conj().T @ y1_loc.conj()).ravel()).reshape(K, M)
+ assert_allclose(
+ xadj1.squeeze().ravel(),
+ xadj1_loc.squeeze().ravel(),
+ rtol=np.finfo(xadj1_loc.dtype).resolution,
+ err_msg=f"Rank {rank}: Adjoint verification failed."
+ )
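+
+
+@pytest.mark.mpi(min_size=1)
+def test_local_block_split_block_gather_roundtrip():
+    """Round-trip sanity check (sketch): splitting a global matrix into 2D
+    blocks with ``local_block_split`` and re-assembling it with
+    ``block_gather`` should return the original matrix."""
+    comm, rank, _, _, is_active = active_grid_comm(base_comm, 8, 8)
+    if not is_active: return
+
+    X_glob = np.arange(8 * 8, dtype=np.float64).reshape(8, 8)
+    sl = local_block_split(X_glob.shape, rank, comm)
+    X_p = X_glob[sl]
+
+    x_dist = DistributedArray(
+        global_shape=X_glob.size,
+        local_shapes=comm.allgather(X_p.size),
+        partition=Partition.SCATTER,
+        base_comm=comm,
+        dtype=X_glob.dtype
+    )
+    x_dist.local_array[:] = X_p.ravel()
+
+    X_rec = block_gather(x_dist, X_glob.shape, comm)
+    assert_allclose(X_rec, X_glob, err_msg=f"Rank {rank}: block gather round-trip failed.")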