Initial implementation of SUMMA-like Matrix Mul #129

Draft · wants to merge 8 commits into base: main
105 changes: 105 additions & 0 deletions examples/matixmul.py
@@ -0,0 +1,105 @@
import sys
import math
import numpy as np
from mpi4py import MPI

from pylops_mpi import DistributedArray, Partition
from pylops_mpi.basicoperators.MatrixMultiply import SUMMAMatrixMult

np.random.seed(42)

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
nProcs = comm.Get_size()
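# NOTE: launch through an MPI launcher, e.g. `mpiexec -n 4 python examples/matixmul.py`;
# a perfect-square number of ranks keeps the process grid below fully populated.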


P_prime = int(math.ceil(math.sqrt(nProcs)))
C = int(math.ceil(nProcs / P_prime))
assert P_prime * C >= nProcs
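# The nProcs ranks are viewed as a (roughly square) P_prime x C grid: ranks sharing
# `rank % P_prime` form a "group" (they hold the same columns of B), while ranks sharing
# `rank // P_prime` form a "layer" (they cooperate on the same columns of the product).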

# matrix dims
M = 37 # any M
K = 37 # any K
N = 37 # any N

blk_rows = int(math.ceil(M / P_prime))
blk_cols = int(math.ceil(N / P_prime))

my_group = rank % P_prime
my_layer = rank // P_prime

# sub-communicators
layer_comm = comm.Split(color=my_layer, key=my_group) # all procs in same layer
group_comm = comm.Split(color=my_group, key=my_layer) # all procs in same group

# Each rank will end up with:
# A_p: shape (my_own_rows, K)
# B_p: shape (K, my_own_cols)
# where
row_start = my_group * blk_rows
row_end = min(M, row_start + blk_rows)
my_own_rows = row_end - row_start

col_start = my_group * blk_cols # note: same my_group index on cols
col_end = min(N, col_start + blk_cols)
my_own_cols = col_end - col_start

# ======================= BROADCASTING THE SLICES =======================
if rank == 0:
    A = np.arange(M*K, dtype=np.float32).reshape(M, K)
    B = np.arange(K*N, dtype=np.float32).reshape(K, N)
    for dest in range(nProcs):
        pg = dest % P_prime
        rs = pg*blk_rows; re = min(M, rs+blk_rows)
        cs = pg*blk_cols; ce = min(N, cs+blk_cols)
        a_block, b_block = A[rs:re, :].copy(), B[:, cs:ce].copy()
        if dest == 0:
            A_p, B_p = a_block, b_block
        else:
            comm.Send(a_block, dest=dest, tag=100+dest)
            comm.Send(b_block, dest=dest, tag=200+dest)
else:
    A_p = np.empty((my_own_rows, K), dtype=np.float32)
    B_p = np.empty((K, my_own_cols), dtype=np.float32)
    comm.Recv(A_p, source=0, tag=100+rank)
    comm.Recv(B_p, source=0, tag=200+rank)

comm.Barrier()
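# The operator is built from this rank's rows of A; the model vector x carries the
# flattened local columns of B, masked by group id so that ranks holding identical
# B columns share the same mask value.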

Aop = SUMMAMatrixMult(A_p, N)
col_lens = comm.allgather(my_own_cols)
total_cols = np.add.reduce(col_lens, 0)
x = DistributedArray(global_shape=K * total_cols,
                     local_shapes=[K * col_len for col_len in col_lens],
                     partition=Partition.SCATTER,
                     mask=[i % P_prime for i in range(comm.Get_size())],
                     dtype=np.float32)
x[:] = B_p.flatten()
y = Aop @ x
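# y is masked by layer id inside the operator: each rank's local_array holds the
# flattened columns of C = A @ B assigned to its layer (verified below).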

# ======================= VERIFICATION =======================
A = np.arange(M*K).reshape(M, K).astype(np.float32)
B = np.arange(K*N).reshape(K, N).astype(np.float32)
C_true = A @ B
Z_true = (A.T.dot(C_true.conj())).conj()
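# Z_true = A^H @ C_true is the reference for the adjoint check; the conjugations
# are no-ops here since the data are real-valued.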


col_start = my_layer * blk_cols  # output columns are indexed by layer (not group)
col_end = min(N, col_start + blk_cols)
my_own_cols = col_end - col_start
expected_y = C_true[:,col_start:col_end].flatten()

if not np.allclose(y.local_array, expected_y, atol=1e-6):
    print(f"RANK {rank}: FORWARD VERIFICATION FAILED")
    print(f'{rank} local: {y.local_array}, expected: {C_true[:,col_start:col_end]}')
else:
    print(f"RANK {rank}: FORWARD VERIFICATION PASSED")


z = Aop.H @ y
expected_z = Z_true[:, col_start:col_end].flatten()
if not np.allclose(z.local_array, expected_z, atol=1e-6):
    print(f"RANK {rank}: ADJOINT VERIFICATION FAILED")
    print(f'{rank} local: {z.local_array}, expected: {Z_true[:,col_start:col_end]}')
else:
    print(f"RANK {rank}: ADJOINT VERIFICATION PASSED")
143 changes: 143 additions & 0 deletions pylops_mpi/basicoperators/MatrixMultiply.py
@@ -0,0 +1,143 @@
import math
import numpy as np
from mpi4py import MPI
from pylops.utils.backend import get_module
from pylops.utils.typing import DTypeLike, NDArray

from pylops_mpi import (
    DistributedArray,
    MPILinearOperator,
    Partition
)


class SUMMAMatrixMult(MPILinearOperator):
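    """Draft SUMMA-like distributed matrix-matrix multiplication operator.

    Each rank owns a horizontal block ``A`` of shape ``(local_rows, K)`` of the global
    ``M x K`` matrix and a block of columns of the ``K x N`` operand, passed in as a
    flattened :obj:`pylops_mpi.DistributedArray`. Ranks are arranged on a
    ``P_prime x C`` grid with layer-wise and group-wise sub-communicators.
    """
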
    def __init__(
        self,
        A: NDArray,
        N: int,
        base_comm: MPI.Comm = MPI.COMM_WORLD,
        dtype: DTypeLike = "float64",
    ) -> None:
        rank = base_comm.Get_rank()
        size = base_comm.Get_size()

        # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size
        self._P_prime = int(math.ceil(math.sqrt(size)))
        self._C = int(math.ceil(size / self._P_prime))
        assert self._P_prime * self._C >= size

        # Compute this process's group and layer indices
        self._group_id = rank % self._P_prime
        self._layer_id = rank // self._P_prime

        # Split communicators by layer (rows) and by group (columns)
        self.base_comm = base_comm
        self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
        self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)

        self.A = A

        self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
        self.K = A.shape[1]
        self.N = N

        # Determine how many columns each group holds
        block_cols = int(math.ceil(self.N / self._P_prime))
        local_col_start = self._group_id * block_cols
        local_col_end = min(self.N, local_col_start + block_cols)
        local_ncols = local_col_end - local_col_start

        # Sum up the total number of input columns across all processes
        total_ncols = base_comm.allreduce(local_ncols, op=MPI.SUM)
        self.dims = (self.K, total_ncols)

        # Recompute how many output columns each layer holds
        layer_col_start = self._layer_id * block_cols
        layer_col_end = min(self.N, layer_col_start + block_cols)
        layer_ncols = layer_col_end - layer_col_start
        total_layer_cols = self.base_comm.allreduce(layer_ncols, op=MPI.SUM)

        self.dimsd = (self.M, total_layer_cols)
        shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))

        super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
        blk_cols = int(math.ceil(self.N / self._P_prime))
        col_start = self._group_id * blk_cols
        col_end = min(self.N, col_start + blk_cols)
        my_own_cols = col_end - col_start
        x = x.local_array.reshape((self.dims[0], my_own_cols))
        B_block = self._layer_comm.bcast(x if self._group_id == self._layer_id else None,
                                         root=self._layer_id)
        C_local = ncp.vstack(
            self._layer_comm.allgather(
                ncp.matmul(self.A, B_block, dtype=self.dtype)
            )
        )
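        # At this point C_local holds the full set of columns of C = A @ B assigned to
        # this layer: the layer's root rank broadcast its block of B, every rank in the
        # layer multiplied its local rows of A with it, and the allgather stacked the
        # resulting partial row blocks.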

        layer_col_start = self._layer_id * blk_cols
        layer_col_end = min(self.N, layer_col_start + blk_cols)
        layer_ncols = layer_col_end - layer_col_start
        layer_col_lens = self.base_comm.allgather(layer_ncols)
        mask = [i // self._P_prime for i in range(self.size)]

        y = DistributedArray(global_shape=(self.M * self.dimsd[1]),
                             local_shapes=[(self.M * c) for c in layer_col_lens],
                             mask=mask,
                             # axis=1,
                             partition=Partition.SCATTER,
                             dtype=self.dtype)
        y[:] = C_local.flatten()
        return y

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")

        # Determine local column block for this layer
        blk_cols = int(math.ceil(self.N / self._P_prime))
        layer_col_start = self._layer_id * blk_cols
        layer_col_end = min(self.N, layer_col_start + blk_cols)
        layer_ncols = layer_col_end - layer_col_start
        layer_col_lens = self.base_comm.allgather(layer_ncols)
        x = x.local_array.reshape((self.M, layer_ncols))

        # Determine local row block for this process group
        blk_rows = int(math.ceil(self.M / self._P_prime))
        row_start = self._group_id * blk_rows
        row_end = min(self.M, row_start + blk_rows)

        B_tile = x[row_start:row_end, :]
        A_local = self.A.T.conj()

        # Pad A_local so its first dimension is divisible by _P_prime, then batch it
        m, b = A_local.shape
        r = math.ceil(m / self._P_prime)
        A_pad = np.zeros((r * self._P_prime, b), dtype=self.dtype)
        A_pad[:m, :] = A_local
        A_batch = A_pad.reshape(self._P_prime, r, b)

        # Perform local matmul and unpad
        Y_batch = ncp.matmul(A_batch, B_tile)
        Y_pad = Y_batch.reshape(r * self._P_prime, -1)
        y_local = Y_pad[:m, :]
        y_layer = self._layer_comm.allreduce(y_local, op=MPI.SUM)
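        # At this point y_layer holds this layer's columns of A^H @ X: each rank computed
        # the partial contribution of its own rows of A, and the allreduce summed those
        # partials across all ranks in the layer.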

        # Build the output DistributedArray with SCATTER partition
        mask = [i // self._P_prime for i in range(self.size)]
        y = DistributedArray(
            global_shape=(self.K * self.dimsd[1]),
            local_shapes=[self.K * c for c in layer_col_lens],
            mask=mask,
            # axis=1,
            partition=Partition.SCATTER,
            dtype=self.dtype,
        )
        y[:] = y_layer.flatten()
        return y