From 75c815b98d2f8b8de89f6da2cdd3dd2cb0333507 Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Sun, 29 Jun 2025 02:03:03 +0200 Subject: [PATCH 01/25] inital-example --- examples/matrixmul.py | 52 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 examples/matrixmul.py diff --git a/examples/matrixmul.py b/examples/matrixmul.py new file mode 100644 index 00000000..f69c3cef --- /dev/null +++ b/examples/matrixmul.py @@ -0,0 +1,52 @@ +import numpy as np +from mpi4py import MPI +import math +import pylops_mpi + +comm = MPI.COMM_WORLD +rank = comm.Get_rank() +size = comm.Get_size() + +M = 16 #512 +N = 16 #512 +K = 16 #512 + +A_shape = (M,K) +B_shape = (K,N) +C_shape = (M,N) + +p_prime = math.isqrt(size) +assert p_prime*p_prime == size +A = np.arange(int(A_shape[0]*A_shape[1])).reshape(A_shape).reshape((M//p_prime,-1)) +B = np.arange(int(B_shape[0]*B_shape[1])).reshape(B_shape).reshape((K//p_prime,-1)) + +A_dist = pylops_mpi.DistributedArray.to_dist(A, + partition=pylops_mpi.Partition.SCATTER, + axis=1) +B_dist = pylops_mpi.DistributedArray.to_dist(B, + partition=pylops_mpi.Partition.SCATTER, + axis=1) + +C_dist = pylops_mpi.DistributedArray(global_shape=(M // p_prime, N * p_prime), + partition=pylops_mpi.Partition.SCATTER, + axis=1) + + + +p = int(np.sqrt(size)) +i, j = divmod(rank, p) +row_comm = comm.Split(color=i, key=j) +col_comm = comm.Split(color=j, key=i) + +c = np.zeros((M//p, N//p), dtype=np.float32) +for k in range(p): + Atemp = A_dist.local_array.copy() if j==k else np.empty_like(A_dist.local_array) + Btemp = B_dist.local_array.copy() if i==k else np.empty_like(B_dist.local_array) + if rank==0: print(k,"cast") + row_comm.Bcast([Atemp, MPI.FLOAT], root=k) + col_comm.Bcast([Btemp, MPI.FLOAT], root=k) + c += Atemp @ Btemp + if rank == 0: print(k,"after") + +C_dist.local_array[:] = c +if rank==0: print(C_dist.asarray()) \ No newline at end of file From 069e5dd1ac5a150dd7c27338aa28edb335402be7 Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Sun, 29 Jun 2025 03:37:50 +0200 Subject: [PATCH 02/25] working simple example --- examples/matrixmul.py | 51 +++++++++++++++++++++++++------------------ 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/examples/matrixmul.py b/examples/matrixmul.py index f69c3cef..ad5fbd41 100644 --- a/examples/matrixmul.py +++ b/examples/matrixmul.py @@ -7,9 +7,9 @@ rank = comm.Get_rank() size = comm.Get_size() -M = 16 #512 -N = 16 #512 -K = 16 #512 +M = 8 #512 +N = 8 #512 +K = 8 #512 A_shape = (M,K) B_shape = (K,N) @@ -17,8 +17,14 @@ p_prime = math.isqrt(size) assert p_prime*p_prime == size -A = np.arange(int(A_shape[0]*A_shape[1])).reshape(A_shape).reshape((M//p_prime,-1)) -B = np.arange(int(B_shape[0]*B_shape[1])).reshape(B_shape).reshape((K//p_prime,-1)) + +# Create A with 2D block-cyclic structure +A_data = np.arange(int(A_shape[0]*A_shape[1])).reshape(A_shape) +A = A_data.reshape(p_prime, M//p_prime, p_prime, K//p_prime).transpose(1, 0, 2, 3).reshape(M//p_prime, -1) + +# Create B with 2D block-cyclic structure +B_data = np.arange(int(B_shape[0]*B_shape[1])).reshape(B_shape) +B = B_data.reshape(p_prime, K//p_prime, p_prime, N//p_prime).transpose(1, 0, 2, 3).reshape(K//p_prime, -1) A_dist = pylops_mpi.DistributedArray.to_dist(A, partition=pylops_mpi.Partition.SCATTER, @@ -30,23 +36,26 @@ C_dist = pylops_mpi.DistributedArray(global_shape=(M // p_prime, N * p_prime), partition=pylops_mpi.Partition.SCATTER, axis=1) +if rank == 
0: print(A_dist.local_array) - - -p = int(np.sqrt(size)) -i, j = divmod(rank, p) +i, j = divmod(rank, p_prime) row_comm = comm.Split(color=i, key=j) col_comm = comm.Split(color=j, key=i) -c = np.zeros((M//p, N//p), dtype=np.float32) -for k in range(p): - Atemp = A_dist.local_array.copy() if j==k else np.empty_like(A_dist.local_array) - Btemp = B_dist.local_array.copy() if i==k else np.empty_like(B_dist.local_array) - if rank==0: print(k,"cast") - row_comm.Bcast([Atemp, MPI.FLOAT], root=k) - col_comm.Bcast([Btemp, MPI.FLOAT], root=k) - c += Atemp @ Btemp - if rank == 0: print(k,"after") - -C_dist.local_array[:] = c -if rank==0: print(C_dist.asarray()) \ No newline at end of file +c_local = np.zeros((M//p_prime, N//p_prime)) +for k in range(p_prime): + Atemp=A_dist.local_array.copy() if j==k else np.empty_like(A_dist.local_array) + Btemp=B_dist.local_array.copy() if i==k else np.empty_like(B_dist.local_array) + rootA=i*p_prime+k; rootB=k*p_prime+j + row_comm.Bcast([Atemp,MPI.FLOAT],root=k) + col_comm.Bcast([Btemp,MPI.FLOAT],root=k) + # print(f"[Rank {rank}] iter{k} after : received A from {rootA}, B from {rootB}, A0={Atemp.flat[0]},B0={Btemp.flat[0]}") + c_local += Atemp @ Btemp + +C_dist.local_array[:] = c_local +C = C_dist.asarray().reshape((M,N)) +A_ = A_dist.asarray().reshape((M,K)) +B_ = B_dist.asarray().reshape((K,N)) +if rank == 0 : + print(A_data @ B_data) + print(C) \ No newline at end of file From 5fcbad3a0eb8698679eb7a16b5f18fe40f6ff380 Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Sun, 29 Jun 2025 03:55:12 +0200 Subject: [PATCH 03/25] untransformed C --- examples/matrixmul.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/matrixmul.py b/examples/matrixmul.py index ad5fbd41..6a73ff12 100644 --- a/examples/matrixmul.py +++ b/examples/matrixmul.py @@ -16,7 +16,7 @@ C_shape = (M,N) p_prime = math.isqrt(size) -assert p_prime*p_prime == size +assert p_prime*p_prime == size, "Number of processes must be a perfect square" # Create A with 2D block-cyclic structure A_data = np.arange(int(A_shape[0]*A_shape[1])).reshape(A_shape) @@ -53,9 +53,9 @@ c_local += Atemp @ Btemp C_dist.local_array[:] = c_local -C = C_dist.asarray().reshape((M,N)) -A_ = A_dist.asarray().reshape((M,K)) -B_ = B_dist.asarray().reshape((K,N)) +C_temp = C_dist.asarray().reshape((M,N)) +C = C_temp.reshape(M//p_prime, p_prime, p_prime, N//p_prime).transpose(1, 0, 2, 3).reshape(M, N) + if rank == 0 : - print(A_data @ B_data) - print(C) \ No newline at end of file + print("expected:\n",A_data @ B_data) + print("calculated:\n",C) \ No newline at end of file From d8d946322d564cbc3f7836556db8c4c94b74e420 Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Sun, 29 Jun 2025 22:09:54 +0200 Subject: [PATCH 04/25] Initial impl of SUMMA matmul --- examples/matrixmul.py | 71 ++++++++++++++++++------------------------- 1 file changed, 29 insertions(+), 42 deletions(-) diff --git a/examples/matrixmul.py b/examples/matrixmul.py index 6a73ff12..5b4873d8 100644 --- a/examples/matrixmul.py +++ b/examples/matrixmul.py @@ -1,61 +1,48 @@ -import numpy as np from mpi4py import MPI import math import pylops_mpi +from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult +import numpy as np comm = MPI.COMM_WORLD rank = comm.Get_rank() size = comm.Get_size() -M = 8 #512 -N = 8 #512 -K = 8 #512 +N = 8 +M = 8 +K = 8 -A_shape = (M,K) -B_shape = (K,N) -C_shape = (M,N) +A_shape = (N, K) +B_shape = (K, M) +C_shape = (N, M) 
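
Note: the reshape/transpose pattern introduced in PATCH 02 (and inverted in PATCH 03) is the core of the block layout used throughout these patches: it packs the p' x p' grid blocks of a row-major matrix side by side so that an even column-wise scatter hands rank r = i*p_prime + j exactly block (i, j). The sketch below is illustrative only and not part of the patch; it assumes a 2x2 grid with dimensions divisible by p_prime, and checks both the forward packing and the inverse reassembly with plain NumPy.

    import numpy as np

    p_prime = 2              # assumed 2x2 process grid (size = 4)
    M = K = 8                # global sizes divisible by p_prime, as in PATCH 02

    A = np.arange(M * K).reshape(M, K)

    # Forward packing (PATCH 02): axes become
    # (row_in_block, block_row, block_col, col_in_block), flattened over the
    # last three, so the column strip of index (i*p_prime + j) and width
    # K//p_prime holds grid block (i, j).
    A_packed = (A.reshape(p_prime, M // p_prime, p_prime, K // p_prime)
                 .transpose(1, 0, 2, 3)
                 .reshape(M // p_prime, -1))

    # Block that rank r = i*p_prime + j receives after an even axis=1 scatter
    i, j = 1, 0
    bk = K // p_prime
    blk = A_packed[:, (i * p_prime + j) * bk:(i * p_prime + j + 1) * bk]
    assert np.array_equal(blk, A[i * (M // p_prime):(i + 1) * (M // p_prime),
                                 j * bk:(j + 1) * bk])

    # Inverse transform (PATCH 03): reassembles the packed result back into
    # ordinary row-major order.
    A_back = (A_packed.reshape(M // p_prime, p_prime, p_prime, K // p_prime)
                      .transpose(1, 0, 2, 3)
                      .reshape(M, K))
    assert np.array_equal(A_back, A)

The same inverse transform is what PATCH 03 applies to C_dist.asarray() to recover the global product before comparing it against A_data @ B_data.
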
p_prime = math.isqrt(size) -assert p_prime*p_prime == size, "Number of processes must be a perfect square" +assert p_prime * p_prime == size, "Number of processes must be a perfect square" -# Create A with 2D block-cyclic structure -A_data = np.arange(int(A_shape[0]*A_shape[1])).reshape(A_shape) -A = A_data.reshape(p_prime, M//p_prime, p_prime, K//p_prime).transpose(1, 0, 2, 3).reshape(M//p_prime, -1) +A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape) +B_data = np.arange(int(B_shape[0] * B_shape[1])).reshape(B_shape) -# Create B with 2D block-cyclic structure -B_data = np.arange(int(B_shape[0]*B_shape[1])).reshape(B_shape) -B = B_data.reshape(p_prime, K//p_prime, p_prime, N//p_prime).transpose(1, 0, 2, 3).reshape(K//p_prime, -1) +N_starts, N_ends = MPIMatrixMult.block_distribute(N, p_prime) +M_starts, M_ends = MPIMatrixMult.block_distribute(M, p_prime) +K_starts, K_ends = MPIMatrixMult.block_distribute(K, p_prime) -A_dist = pylops_mpi.DistributedArray.to_dist(A, - partition=pylops_mpi.Partition.SCATTER, - axis=1) -B_dist = pylops_mpi.DistributedArray.to_dist(B, - partition=pylops_mpi.Partition.SCATTER, - axis=1) +i, j = divmod(rank, p_prime) +A_local = A_data[N_starts[i]:N_ends[i], K_starts[j]:K_ends[j]] +B_local = B_data[K_starts[i]:K_ends[i], M_starts[j]:M_ends[j]] -C_dist = pylops_mpi.DistributedArray(global_shape=(M // p_prime, N * p_prime), - partition=pylops_mpi.Partition.SCATTER, - axis=1) -if rank == 0: print(A_dist.local_array) +B_dist = pylops_mpi.DistributedArray(global_shape=(K*M), + local_shapes=comm.allgather(B_local.shape[0] * B_local.shape[1]), + base_comm=comm, + partition=pylops_mpi.Partition.SCATTER) +B_dist.local_array[:] = B_local.flatten() -i, j = divmod(rank, p_prime) -row_comm = comm.Split(color=i, key=j) -col_comm = comm.Split(color=j, key=i) - -c_local = np.zeros((M//p_prime, N//p_prime)) -for k in range(p_prime): - Atemp=A_dist.local_array.copy() if j==k else np.empty_like(A_dist.local_array) - Btemp=B_dist.local_array.copy() if i==k else np.empty_like(B_dist.local_array) - rootA=i*p_prime+k; rootB=k*p_prime+j - row_comm.Bcast([Atemp,MPI.FLOAT],root=k) - col_comm.Bcast([Btemp,MPI.FLOAT],root=k) - # print(f"[Rank {rank}] iter{k} after : received A from {rootA}, B from {rootB}, A0={Atemp.flat[0]},B0={Btemp.flat[0]}") - c_local += Atemp @ Btemp - -C_dist.local_array[:] = c_local -C_temp = C_dist.asarray().reshape((M,N)) -C = C_temp.reshape(M//p_prime, p_prime, p_prime, N//p_prime).transpose(1, 0, 2, 3).reshape(M, N) +print(rank, A_local.shape) +Aop = MPIMatrixMult(A_local, M, base_comm=comm) +C_dist = Aop @ B_dist +C_temp = C_dist.asarray().reshape((N, M)) +C = C_temp.reshape(N // p_prime, p_prime, p_prime, M // p_prime).transpose(1, 0, 2, 3).reshape(N, M) if rank == 0 : - print("expected:\n",A_data @ B_data) + # print("expected:\n",np.allclose(A_data @ B_data, C)) + print("expected:\n", A_data @ B_data) print("calculated:\n",C) \ No newline at end of file From a9e679ec4bf513be3c155b6840a7bad03312937f Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Sun, 29 Jun 2025 22:35:41 +0200 Subject: [PATCH 05/25] matmul with padding --- examples/matrixmul.py | 30 ++++--- pylops_mpi/basicoperators/MatrixMult.py | 114 ++++++++++++++---------- 2 files changed, 82 insertions(+), 62 deletions(-) diff --git a/examples/matrixmul.py b/examples/matrixmul.py index 5b4873d8..d83e81cc 100644 --- a/examples/matrixmul.py +++ b/examples/matrixmul.py @@ -1,16 +1,24 @@ from mpi4py import MPI import math import pylops_mpi +from 
pylops_mpi.DistributedArray import local_split from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult import numpy as np + +import numpy as np +import math + + + + comm = MPI.COMM_WORLD rank = comm.Get_rank() size = comm.Get_size() -N = 8 -M = 8 -K = 8 +N = 9 +M = 9 +K = 9 A_shape = (N, K) B_shape = (K, M) @@ -22,25 +30,21 @@ A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape) B_data = np.arange(int(B_shape[0] * B_shape[1])).reshape(B_shape) -N_starts, N_ends = MPIMatrixMult.block_distribute(N, p_prime) -M_starts, M_ends = MPIMatrixMult.block_distribute(M, p_prime) -K_starts, K_ends = MPIMatrixMult.block_distribute(K, p_prime) - i, j = divmod(rank, p_prime) -A_local = A_data[N_starts[i]:N_ends[i], K_starts[j]:K_ends[j]] -B_local = B_data[K_starts[i]:K_ends[i], M_starts[j]:M_ends[j]] -B_dist = pylops_mpi.DistributedArray(global_shape=(K*M), +A_local, (N_new, K_new) = MPIMatrixMult.block_distribute(A_data, i, j,comm) +B_local, (K_new, M_new) = MPIMatrixMult.block_distribute(B_data, i, j,comm) + +B_dist = pylops_mpi.DistributedArray(global_shape=(K_new*M_new), local_shapes=comm.allgather(B_local.shape[0] * B_local.shape[1]), base_comm=comm, partition=pylops_mpi.Partition.SCATTER) B_dist.local_array[:] = B_local.flatten() print(rank, A_local.shape) -Aop = MPIMatrixMult(A_local, M, base_comm=comm) +Aop = MPIMatrixMult(A_local, M_new, base_comm=comm) C_dist = Aop @ B_dist -C_temp = C_dist.asarray().reshape((N, M)) -C = C_temp.reshape(N // p_prime, p_prime, p_prime, M // p_prime).transpose(1, 0, 2, 3).reshape(N, M) +C = MPIMatrixMult.block_gather(C_dist, (N_new,M_new), (N,M), comm) if rank == 0 : # print("expected:\n",np.allclose(A_data @ B_data, C)) diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py index 2aaffbcf..4305dec6 100644 --- a/pylops_mpi/basicoperators/MatrixMult.py +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -131,12 +131,10 @@ def __init__( # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size self._P_prime = math.isqrt(size) - self._C = self._P_prime - if self._P_prime * self._C != size: + if self._P_prime * self._P_prime != size: raise Exception(f"Number of processes must be a square number, provided {size} instead...") - self._col_id = rank % self._P_prime - self._row_id = rank // self._P_prime + self._row_id, self._col_id = divmod(rank, self._P_prime) self.base_comm = base_comm self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id) @@ -145,67 +143,85 @@ def __init__( self.A = A.astype(np.dtype(dtype)) if saveAt: self.At = A.T.conj() - self.N = self._row_comm.allreduce(self.A.shape[0], op=MPI.SUM) - self.K = A.shape[1] + self.N = self._col_comm.allreduce(A.shape[0]) + self.K = self._row_comm.allreduce(A.shape[1]) self.M = M - block_cols = int(math.ceil(self.M / self._P_prime)) - blk_rows = int(math.ceil(self.N / self._P_prime)) + self.dims = (self.K, self.M) + self.dimsd = (self.N, self.M) + shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims))) + super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm) - self._row_start = self._col_id * blk_rows - self._row_end = min(self.N, self._row_start + blk_rows) + @staticmethod + def block_distribute(array, proc_i, proc_j, comm): + p_prime = math.isqrt(comm.Get_size()) + orig_r, orig_c = array.shape - self._col_start = self._row_id * block_cols - self._col_end = min(self.M, self._col_start + block_cols) + new_r = math.ceil(orig_r / p_prime) * p_prime + new_c = math.ceil(orig_c / p_prime) * p_prime - self._local_ncols 
= self._col_end - self._col_start - self._rank_col_lens = self.base_comm.allgather(self._local_ncols) - total_ncols = np.sum(self._rank_col_lens) + br, bc = new_r // p_prime, new_c // p_prime + i0, j0 = proc_i * br, proc_j * bc + i1, j1 = min(i0 + br, orig_r), min(j0 + bc, orig_c) + + block = array[i0:i1, j0:j1] + pr = (new_r - orig_r) if proc_i == p_prime - 1 else 0 + pc = (new_c - orig_c) if proc_j == p_prime - 1 else 0 + if pr or pc: + block = np.pad(block, [(0, pr), (0, pc)], mode='constant') + + return block, (new_r, new_c) + + @staticmethod + def block_gather(x, new_shape, orig_shape, comm): + ncp = get_module(x.engine) + p_prime = math.isqrt(comm.Get_size()) + all_blks = comm.allgather(x.local_array) + nr, nc = new_shape + orr, orc = orig_shape + br, bc = nr // p_prime, nc // p_prime + C = ncp.array(all_blks).reshape(p_prime, p_prime, br, bc).transpose(0, 2, 1, 3).reshape(nr, nc) + return C[:orr, :orc] - self.dims = (self.K, total_ncols) - self.dimsd = (self.N, total_ncols) - shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims))) - super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm) def _matvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) if x.partition != Partition.SCATTER: raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...") - - y = DistributedArray(global_shape=(self.N * self.dimsd[1]), - local_shapes=[(self.N * c) for c in self._rank_col_lens], + y = DistributedArray(global_shape=(self.N // self._P_prime, self.M * self._P_prime), mask=x.mask, partition=Partition.SCATTER, - dtype=self.dtype) - - my_own_cols = self._rank_col_lens[self.rank] - x_arr = x.local_array.reshape((self.dims[0], my_own_cols)) - X_local = x_arr.astype(self.dtype) - Y_local = ncp.vstack( - self._row_comm.allgather( - ncp.matmul(self.A, X_local) - ) - ) - y[:] = Y_local.flatten() + dtype=self.dtype, + axis=1) + + x = x.local_array.reshape((self.A.shape[1], -1)) + c_local = np.zeros((self.A.shape[0], x.shape[1])) + for k in range(self._P_prime): + Atemp = self.A.copy() if self._col_id == k else np.empty_like(self.A) + Xtemp = x.copy() if self._row_id == k else np.empty_like(x) + self._row_comm.Bcast(Atemp, root=k) + self._col_comm.Bcast(Xtemp, root=k) + c_local += ncp.dot(Atemp, Xtemp) + y[:] = c_local return y def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) if x.partition != Partition.SCATTER: raise ValueError(f"x should have partition={Partition.SCATTER}. 
Got {x.partition} instead.") - - y = DistributedArray( - global_shape=(self.K * self.dimsd[1]), - local_shapes=[self.K * c for c in self._rank_col_lens], - mask=x.mask, - partition=Partition.SCATTER, - dtype=self.dtype, - ) - - x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype) - X_tile = x_arr[self._row_start:self._row_end, :] - A_local = self.At if hasattr(self, "At") else self.A.T.conj() - Y_local = ncp.matmul(A_local, X_tile) - y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM) - y[:] = y_layer.flatten() - return y + return None + # y = DistributedArray( + # global_shape=(self.K * self.dimsd[1]), + # local_shapes=[self.K * c for c in self._rank_col_lens], + # mask=x.mask, + # partition=Partition.SCATTER, + # dtype=self.dtype, + # ) + # + # x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype) + # X_tile = x_arr[self._row_start:self._row_end, :] + # A_local = self.At if hasattr(self, "At") else self.A.T.conj() + # Y_local = ncp.matmul(A_local, X_tile) + # y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM) + # y[:] = y_layer.flatten() + # return y From 8142d440d3baeea62c16fdc0c9a5047aeba209e0 Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Sun, 13 Jul 2025 07:32:46 +0200 Subject: [PATCH 06/25] impl adjoint --- examples/matrixmul.py | 26 ++--- pylops_mpi/basicoperators/MatrixMult.py | 120 ++++++++++++++++++++---- 2 files changed, 113 insertions(+), 33 deletions(-) diff --git a/examples/matrixmul.py b/examples/matrixmul.py index d83e81cc..3f8a063a 100644 --- a/examples/matrixmul.py +++ b/examples/matrixmul.py @@ -1,16 +1,9 @@ -from mpi4py import MPI import math -import pylops_mpi -from pylops_mpi.DistributedArray import local_split -from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult import numpy as np +from mpi4py import MPI - -import numpy as np -import math - - - +import pylops_mpi +from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult comm = MPI.COMM_WORLD rank = comm.Get_rank() @@ -44,9 +37,16 @@ print(rank, A_local.shape) Aop = MPIMatrixMult(A_local, M_new, base_comm=comm) C_dist = Aop @ B_dist -C = MPIMatrixMult.block_gather(C_dist, (N_new,M_new), (N,M), comm) +Z_dist = Aop.H @ C_dist +C = MPIMatrixMult.block_gather(C_dist, (N_new,M_new), (N,M), comm) +Z = MPIMatrixMult.block_gather(Z_dist, (K_new,M_new), (K,M), comm) if rank == 0 : + print("expected:\n", np.allclose((A_data.T.dot((A_data @ B_data).conj())).conj(), Z.astype(np.int32))) + # print("expected:\n", (A_data.T.dot((A_data @ B_data).conj())).conj()) + # print("calculated:\n",Z.astype(np.int32)) + # print("calculated:\n", (A_data.T.dot((A_data @ B_data).conj())).conj() == Z.astype(np.int32)) + # print("expected:\n",np.allclose(A_data @ B_data, C)) - print("expected:\n", A_data @ B_data) - print("calculated:\n",C) \ No newline at end of file + # print("expected:\n", A_data @ B_data) + # print("calculated:\n",C) \ No newline at end of file diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py index 4305dec6..3409d3df 100644 --- a/pylops_mpi/basicoperators/MatrixMult.py +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -152,6 +152,64 @@ def __init__( shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims))) super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm) + @staticmethod + def active_grid_comm(base_comm: MPI.Comm, N: int, M: int): + r"""Configure active grid + + Configure a square process grid from a parent MPI communicator and + select a 
subset of "active" processes. Each process in ``base_comm`` + is assigned to a logical 2D grid of size :math:`P' \times P'`, + where :math:`P' = \bigl \lceil \sqrt{P} \bigr \rceil`. Only the first + :math:`active_dim x active_dim` processes + (by row-major order) are considered "active". Inactive ranks return + immediately with no new communicator. + + Parameters: + ----------- + base_comm : :obj:`mpi4py.MPI.Comm` + MPI Parent Communicator. (e.g., ``mpi4py.MPI.COMM_WORLD``). + N : :obj:`int` + Number of rows of the global data domain. + M : :obj:`int` + Number of columns of the global data domain. + + Returns: + -------- + comm : :obj:`mpi4py.MPI.Comm` + Sub-communicator including only active ranks. + rank : :obj:`int` + Rank within the new sub-communicator (or original rank + if inactive). + row : :obj:`int` + Grid row index of this process in the active grid (or original rank + if inactive). + col : :obj:`int` + Grid column index of this process in the active grid + (or original rank if inactive). + is_active : :obj:`bool` + Flag indicating whether this rank is in the active sub-grid. + + """ + rank = base_comm.Get_rank() + size = base_comm.Get_size() + p_prime = math.isqrt(size) + row, col = divmod(rank, p_prime) + active_dim = min(N, M, p_prime) + is_active = (row < active_dim and col < active_dim) + + if not is_active: + return None, rank, row, col, False + + active_ranks = [r for r in range(size) + if (r // p_prime) < active_dim and (r % p_prime) < active_dim] + new_group = base_comm.Get_group().Incl(active_ranks) + new_comm = base_comm.Create_group(new_group) + p_prime_new = math.isqrt(len(active_ranks)) + new_rank = new_comm.Get_rank() + new_row, new_col = divmod(new_rank, p_prime_new) + + return new_comm, new_rank, new_row, new_col, True + @staticmethod def block_distribute(array, proc_i, proc_j, comm): p_prime = math.isqrt(comm.Get_size()) @@ -188,11 +246,12 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) if x.partition != Partition.SCATTER: raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...") - y = DistributedArray(global_shape=(self.N // self._P_prime, self.M * self._P_prime), + local_shape = (self.N // self._P_prime) * ( self.M * self._P_prime // self.size) + y = DistributedArray(global_shape=((self.N // self._P_prime) * self.M * self._P_prime), mask=x.mask, + local_shapes=[ local_shape for _ in range(self.size)], partition=Partition.SCATTER, - dtype=self.dtype, - axis=1) + dtype=self.dtype) x = x.local_array.reshape((self.A.shape[1], -1)) c_local = np.zeros((self.A.shape[0], x.shape[1])) @@ -202,26 +261,47 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: self._row_comm.Bcast(Atemp, root=k) self._col_comm.Bcast(Xtemp, root=k) c_local += ncp.dot(Atemp, Xtemp) - y[:] = c_local + y[:] = c_local.flatten() return y + def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) if x.partition != Partition.SCATTER: raise ValueError(f"x should have partition={Partition.SCATTER}. 
Got {x.partition} instead.") - return None - # y = DistributedArray( - # global_shape=(self.K * self.dimsd[1]), - # local_shapes=[self.K * c for c in self._rank_col_lens], - # mask=x.mask, - # partition=Partition.SCATTER, - # dtype=self.dtype, - # ) - # - # x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype) - # X_tile = x_arr[self._row_start:self._row_end, :] - # A_local = self.At if hasattr(self, "At") else self.A.T.conj() - # Y_local = ncp.matmul(A_local, X_tile) - # y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM) - # y[:] = y_layer.flatten() - # return y + + local_shape = (self.K // self._P_prime) * (self.M * self._P_prime // self.size) + y = DistributedArray( + global_shape=((self.K // self._P_prime) * self.M * self._P_prime), + mask=x.mask, + local_shapes=[local_shape for _ in range(self.size)], + partition=Partition.SCATTER, + dtype=self.dtype, + ) + x_reshaped = x.local_array.reshape((self.A.shape[0], -1)) + A_local = self.At if hasattr(self, "At") else self.A.T.conj() + c_local = np.zeros((self.A.shape[1], x_reshaped.shape[1])) + P = self._P_prime + + for k in range(P): + temps = {} + requests = [] + for buf, owner, base, name in ( + (A_local, self._row_id, 100, 'A'), + (x_reshaped, self._col_id, 200, 'B'), + ): + tmp = np.empty_like(buf) + temps[name] = tmp + src, tag = k * P + owner, (base + k) * 1000 + self.rank + requests.append(self.base_comm.Irecv(tmp, source=src, tag=tag)) + + if self.rank // P == k: + fixed = self.rank % P + for moving in range(P): + dest = (fixed * P + moving) if name == 'A' else moving * P + fixed + tag = (base + k) * 1000 + dest + requests.append(self.base_comm.Isend(buf, dest=dest, tag=tag)) + MPI.Request.Waitall(requests) + c_local += ncp.dot(temps['A'], temps['B']) + y[:] = c_local.flatten() + return y \ No newline at end of file From 7363431138fe3e41529141a15ce5e7e87e382bc1 Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Sun, 13 Jul 2025 15:20:21 +0200 Subject: [PATCH 07/25] cleanedup adjoint impl --- pylops_mpi/basicoperators/MatrixMult.py | 55 +++++++++++-------------- 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py index 3409d3df..c29b1eb3 100644 --- a/pylops_mpi/basicoperators/MatrixMult.py +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -246,22 +246,22 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) if x.partition != Partition.SCATTER: raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...") - local_shape = (self.N // self._P_prime) * ( self.M * self._P_prime // self.size) - y = DistributedArray(global_shape=((self.N // self._P_prime) * self.M * self._P_prime), + local_shape = ((self.N * self.M) // self.size) + y = DistributedArray(global_shape=(self.N * self.M), mask=x.mask, - local_shapes=[ local_shape for _ in range(self.size)], + local_shapes=[local_shape] * self.size, partition=Partition.SCATTER, dtype=self.dtype) x = x.local_array.reshape((self.A.shape[1], -1)) - c_local = np.zeros((self.A.shape[0], x.shape[1])) + Y_local = np.zeros((self.A.shape[0], x.shape[1])) for k in range(self._P_prime): Atemp = self.A.copy() if self._col_id == k else np.empty_like(self.A) Xtemp = x.copy() if self._row_id == k else np.empty_like(x) self._row_comm.Bcast(Atemp, root=k) self._col_comm.Bcast(Xtemp, root=k) - c_local += ncp.dot(Atemp, Xtemp) - y[:] = c_local.flatten() + Y_local += ncp.dot(Atemp, Xtemp) + 
y[:] = Y_local.flatten() return y @@ -270,38 +270,33 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: if x.partition != Partition.SCATTER: raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.") - local_shape = (self.K // self._P_prime) * (self.M * self._P_prime // self.size) + local_shape = ((self.K * self.M ) // self.size) y = DistributedArray( - global_shape=((self.K // self._P_prime) * self.M * self._P_prime), + global_shape=(self.K * self.M), mask=x.mask, - local_shapes=[local_shape for _ in range(self.size)], + local_shapes=[local_shape] * self.size, partition=Partition.SCATTER, dtype=self.dtype, ) x_reshaped = x.local_array.reshape((self.A.shape[0], -1)) A_local = self.At if hasattr(self, "At") else self.A.T.conj() - c_local = np.zeros((self.A.shape[1], x_reshaped.shape[1])) - P = self._P_prime + Y_local = np.zeros((self.A.shape[1], x_reshaped.shape[1])) - for k in range(P): - temps = {} + for k in range(self._P_prime): requests = [] - for buf, owner, base, name in ( - (A_local, self._row_id, 100, 'A'), - (x_reshaped, self._col_id, 200, 'B'), - ): - tmp = np.empty_like(buf) - temps[name] = tmp - src, tag = k * P + owner, (base + k) * 1000 + self.rank - requests.append(self.base_comm.Irecv(tmp, source=src, tag=tag)) - - if self.rank // P == k: - fixed = self.rank % P - for moving in range(P): - dest = (fixed * P + moving) if name == 'A' else moving * P + fixed - tag = (base + k) * 1000 + dest - requests.append(self.base_comm.Isend(buf, dest=dest, tag=tag)) + ATtemp = np.empty_like(A_local) + srcA = k * self._P_prime + self._row_id + tagA = (100 + k) * 1000 + self.rank + requests.append(self.base_comm.Irecv(ATtemp, source=srcA, tag=tagA)) + if self._row_id == k: + fixed_col = self._col_id + for moving_col in range(self._P_prime): + destA = fixed_col * self._P_prime + moving_col + tagA = (100 + k) * 1000 + destA + requests.append(self.base_comm.Isend(A_local, dest=destA,tag=tagA)) + Xtemp = x_reshaped.copy() if self._row_id == k else np.empty_like(x_reshaped) + requests.append(self._col_comm.Ibcast(Xtemp, root=k)) MPI.Request.Waitall(requests) - c_local += ncp.dot(temps['A'], temps['B']) - y[:] = c_local.flatten() + Y_local += ncp.dot(ATtemp, Xtemp) + y[:] = Y_local.flatten() return y \ No newline at end of file From dc00226f3acc197c1d017b92a776f4d31a30c3c6 Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Sun, 13 Jul 2025 20:59:04 +0200 Subject: [PATCH 08/25] added handling for padding --- examples/matrixmul.py | 28 ++-- pylops_mpi/basicoperators/MatrixMult.py | 191 ++++++++++++++++++++---- 2 files changed, 181 insertions(+), 38 deletions(-) diff --git a/examples/matrixmul.py b/examples/matrixmul.py index b3b080c1..e4b7a9e1 100644 --- a/examples/matrixmul.py +++ b/examples/matrixmul.py @@ -24,28 +24,30 @@ B_data = np.arange(int(B_shape[0] * B_shape[1])).reshape(B_shape) i, j = divmod(rank, p_prime) - A_local, (N_new, K_new) = MPIMatrixMult.block_distribute(A_data, i, j,comm) B_local, (K_new, M_new) = MPIMatrixMult.block_distribute(B_data, i, j,comm) -B_dist = pylops_mpi.DistributedArray(global_shape=(K_new*M_new), +B_dist = pylops_mpi.DistributedArray(global_shape=(K * M), local_shapes=comm.allgather(B_local.shape[0] * B_local.shape[1]), base_comm=comm, partition=pylops_mpi.Partition.SCATTER) B_dist.local_array[:] = B_local.flatten() -Aop = MPIMatrixMult(A_local, M_new, base_comm=comm) +Aop = MPIMatrixMult(A_local, M, base_comm=comm) C_dist = Aop @ B_dist Z_dist = Aop.H @ C_dist -C = 
MPIMatrixMult.block_gather(C_dist, (N_new,M_new), (N,M), comm) -Z = MPIMatrixMult.block_gather(Z_dist, (K_new,M_new), (K,M), comm) +C = MPIMatrixMult.block_gather(C_dist, (N,M), (N,M), comm) +Z = MPIMatrixMult.block_gather(Z_dist, (K,M), (K,M), comm) if rank == 0 : - print("expected:\n", np.allclose((A_data.T.dot((A_data @ B_data).conj())).conj(), Z.astype(np.int32))) - # print("expected:\n", (A_data.T.dot((A_data @ B_data).conj())).conj()) - # print("calculated:\n",Z.astype(np.int32)) - # print("calculated:\n", (A_data.T.dot((A_data @ B_data).conj())).conj() == Z.astype(np.int32)) - - # print("expected:\n",np.allclose(A_data @ B_data, C)) - # print("expected:\n", A_data @ B_data) - # print("calculated:\n",C) \ No newline at end of file + C_correct = np.allclose(A_data @ B_data, C) + print("C expected: ", C_correct) + if not C_correct: + print("expected:\n", A_data @ B_data) + print("calculated:\n",C) + + Z_correct = np.allclose((A_data.T.dot((A_data @ B_data).conj())).conj(), Z.astype(np.int32)) + print("Z expected: ", Z_correct) + if not Z_correct: + print("expected:\n", (A_data.T.dot((A_data @ B_data).conj())).conj()) + print("calculated:\n", Z.astype(np.int32)) diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py index c617e114..4e3f6b12 100644 --- a/pylops_mpi/basicoperators/MatrixMult.py +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -136,13 +136,32 @@ def __init__( self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id) self.A = A.astype(np.dtype(dtype)) - if saveAt: - self.At = A.T.conj() self.N = self._col_comm.allreduce(A.shape[0]) self.K = self._row_comm.allreduce(A.shape[1]) self.M = M + self._N_padded = math.ceil(self.N / self._P_prime) * self._P_prime + self._K_padded = math.ceil(self.K / self._P_prime) * self._P_prime + self._M_padded = math.ceil(self.M / self._P_prime) * self._P_prime + + bn = self._N_padded // self._P_prime + bk = self._K_padded // self._P_prime + bm = self._M_padded // self._P_prime + + pr = (bn - A.shape[0]) if self._row_id == self._P_prime - 1 else 0 + pc = (bk - A.shape[1]) if self._col_id == self._P_prime - 1 else 0 + + if pr < 0 or pc < 0: + raise Exception(f"Improper distribution of A expected local shape " + f"( ≤ {bn}, ≤ {bk}) but got ({A.shape[0]},{A.shape[1]})") + + if pr > 0 or pc > 0: + self.A = np.pad(self.A, [(0, pr), (0, pc)], mode='constant') + + if saveAt: + self.At = self.A.T.conj() + self.dims = (self.K, self.M) self.dimsd = (self.N, self.M) shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims))) @@ -218,12 +237,14 @@ def block_distribute(array, proc_i, proc_j, comm): i0, j0 = proc_i * br, proc_j * bc i1, j1 = min(i0 + br, orig_r), min(j0 + bc, orig_c) - block = array[i0:i1, j0:j1] + i_end = None if proc_i == p_prime - 1 else i1 + j_end = None if proc_j == p_prime - 1 else j1 + block = array[i0:i_end, j0:j_end] + pr = (new_r - orig_r) if proc_i == p_prime - 1 else 0 pc = (new_c - orig_c) if proc_j == p_prime - 1 else 0 - if pr or pc: - block = np.pad(block, [(0, pr), (0, pc)], mode='constant') - + #comment the padding to get the block as unpadded + # if pr or pc: block = np.pad(block, [(0, pr), (0, pc)], mode='constant') return block, (new_r, new_c) @staticmethod @@ -231,52 +252,170 @@ def block_gather(x, new_shape, orig_shape, comm): ncp = get_module(x.engine) p_prime = math.isqrt(comm.Get_size()) all_blks = comm.allgather(x.local_array) - nr, nc = new_shape + + nr, nc = new_shape orr, orc = orig_shape - br, bc = nr // p_prime, nc // p_prime - C = 
ncp.array(all_blks).reshape(p_prime, p_prime, br, bc).transpose(0, 2, 1, 3).reshape(nr, nc) + + # Calculate base block sizes + br_base = nr // p_prime + bc_base = nc // p_prime + + # Calculate remainder rows/cols that need to be distributed + r_remainder = nr % p_prime + c_remainder = nc % p_prime + + # Create the output matrix + C = ncp.zeros((nr, nc), dtype=all_blks[0].dtype) + + # Place each block in the correct position + for rank in range(p_prime * p_prime): + # Convert linear rank to 2D grid position + proc_row = rank // p_prime + proc_col = rank % p_prime + + # Calculate this process's block dimensions + block_rows = br_base + (1 if proc_row < r_remainder else 0) + block_cols = bc_base + (1 if proc_col < c_remainder else 0) + + # Calculate starting position in global matrix + start_row = proc_row * br_base + min(proc_row, r_remainder) + start_col = proc_col * bc_base + min(proc_col, c_remainder) + + # Place the block + block = all_blks[rank] + if block.ndim == 1: + block = block.reshape(block_rows, block_cols) + C[start_row:start_row + block_rows, start_col:start_col + block_cols] = block return C[:orr, :orc] def _matvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) if x.partition != Partition.SCATTER: raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...") - local_shape = ((self.N * self.M) // self.size) - y = DistributedArray(global_shape=(self.N * self.M), + + # Calculate local shapes for block distribution + bn = self._N_padded // self._P_prime # block size in N dimension + bm = self._M_padded // self._P_prime # block size in M dimension + + # Calculate actual local shape for this process (considering original dimensions) + local_n = bn + local_m = bm + + # Adjust for edge/corner processes that might have smaller blocks + if self._row_id == self._P_prime - 1: + local_n = self.N - (self._P_prime - 1) * bn + if self._col_id == self._P_prime - 1: + local_m = self.M - (self._P_prime - 1) * bm + + local_shape = local_n * local_m + + # Create local_shapes array for all processes + local_shapes = [] + for rank in range(self.size): + row_id, col_id = divmod(rank, self._P_prime) + proc_n = bn if row_id != self._P_prime - 1 else self.N - (self._P_prime - 1) * bn + proc_m = bm if col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm + local_shapes.append(proc_n * proc_m) + + y = DistributedArray(global_shape=(self.N * self.M), mask=x.mask, - local_shapes=[local_shape] * self.size, + local_shapes=local_shapes, partition=Partition.SCATTER, - dtype=self.dtype) + dtype=self.dtype, + base_comm=self.base_comm + ) + + # Calculate expected padded dimensions for x + bk = self._K_padded // self._P_prime # block size in K dimension + + # The input x corresponds to blocks from matrix B (K x M) + # This process should receive a block of size (local_k x local_m) + local_k = bk + if self._row_id == self._P_prime - 1: + local_k = self.K - (self._P_prime - 1) * bk + + # Reshape x.local_array to its 2D block form + x_block = x.local_array.reshape((local_k, local_m)) + + # Pad the block to the full padded size if necessary + pad_k = bk - local_k + pad_m = bm - local_m + + if pad_k > 0 or pad_m > 0: + x_block = np.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant') + + Y_local = np.zeros((self.A.shape[0], bm)) - x = x.local_array.reshape((self.A.shape[1], -1)) - Y_local = np.zeros((self.A.shape[0], x.shape[1])) for k in range(self._P_prime): Atemp = self.A.copy() if self._col_id == k else np.empty_like(self.A) - Xtemp = 
x.copy() if self._row_id == k else np.empty_like(x) + Xtemp = x_block.copy() if self._row_id == k else np.empty_like(x_block) self._row_comm.Bcast(Atemp, root=k) self._col_comm.Bcast(Xtemp, root=k) Y_local += ncp.dot(Atemp, Xtemp) - y[:] = Y_local.flatten() - return y + Y_local_unpadded = Y_local[:local_n, :local_m] + y[:] = Y_local_unpadded.flatten() + return y def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) if x.partition != Partition.SCATTER: raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.") - local_shape = ((self.K * self.M ) // self.size) + # Calculate local shapes for block distribution + bk = self._K_padded // self._P_prime # block size in K dimension + bm = self._M_padded // self._P_prime # block size in M dimension + + # Calculate actual local shape for this process (considering original dimensions) + local_k = bk + local_m = bm + + # Adjust for edge/corner processes that might have smaller blocks + if self._row_id == self._P_prime - 1: + local_k = self.K - (self._P_prime - 1) * bk + if self._col_id == self._P_prime - 1: + local_m = self.M - (self._P_prime - 1) * bm + + local_shape = local_k * local_m + + # Create local_shapes array for all processes + local_shapes = [] + for rank in range(self.size): + row_id, col_id = divmod(rank, self._P_prime) + proc_k = bk if row_id != self._P_prime - 1 else self.K - (self._P_prime - 1) * bk + proc_m = bm if col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm + local_shapes.append(proc_k * proc_m) + y = DistributedArray( global_shape=(self.K * self.M), mask=x.mask, - local_shapes=[local_shape] * self.size, + local_shapes=local_shapes, partition=Partition.SCATTER, dtype=self.dtype, base_comm=self.base_comm ) - x_reshaped = x.local_array.reshape((self.A.shape[0], -1)) + + # Calculate expected padded dimensions for x + bn = self._N_padded // self._P_prime # block size in N dimension + + # The input x corresponds to blocks from the result (N x M) + # This process should receive a block of size (local_n x local_m) + local_n = bn + if self._row_id == self._P_prime - 1: + local_n = self.N - (self._P_prime - 1) * bn + + # Reshape x.local_array to its 2D block form + x_block = x.local_array.reshape((local_n, local_m)) + + # Pad the block to the full padded size if necessary + pad_n = bn - local_n + pad_m = bm - local_m + + if pad_n > 0 or pad_m > 0: + x_block = np.pad(x_block, [(0, pad_n), (0, pad_m)], mode='constant') + A_local = self.At if hasattr(self, "At") else self.A.T.conj() - Y_local = np.zeros((self.A.shape[1], x_reshaped.shape[1])) + Y_local = np.zeros((self.A.shape[1], bm)) for k in range(self._P_prime): requests = [] @@ -289,10 +428,12 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: for moving_col in range(self._P_prime): destA = fixed_col * self._P_prime + moving_col tagA = (100 + k) * 1000 + destA - requests.append(self.base_comm.Isend(A_local, dest=destA,tag=tagA)) - Xtemp = x_reshaped.copy() if self._row_id == k else np.empty_like(x_reshaped) + requests.append(self.base_comm.Isend(A_local, dest=destA, tag=tagA)) + Xtemp = x_block.copy() if self._row_id == k else np.empty_like(x_block) requests.append(self._col_comm.Ibcast(Xtemp, root=k)) MPI.Request.Waitall(requests) Y_local += ncp.dot(ATtemp, Xtemp) - y[:] = Y_local.flatten() + + Y_local_unpadded = Y_local[:local_k, :local_m] + y[:] = Y_local_unpadded.flatten() return y \ No newline at end of file From 58d3cebf52cd1a87d81229771af993238cafd698 Mon Sep 17 00:00:00 2001 From: 
astroC86 <66444189+astroC86@users.noreply.github.com> Date: Tue, 15 Jul 2025 01:47:45 +0200 Subject: [PATCH 09/25] Cleanup --- .../{matrixmul.py => plot_summamatrixmult.py} | 19 +- pylops_mpi/basicoperators/MatrixMult.py | 531 ++++++++++++------ 2 files changed, 381 insertions(+), 169 deletions(-) rename examples/{matrixmul.py => plot_summamatrixmult.py} (67%) diff --git a/examples/matrixmul.py b/examples/plot_summamatrixmult.py similarity index 67% rename from examples/matrixmul.py rename to examples/plot_summamatrixmult.py index e4b7a9e1..4aa85535 100644 --- a/examples/matrixmul.py +++ b/examples/plot_summamatrixmult.py @@ -3,7 +3,9 @@ from mpi4py import MPI import pylops_mpi -from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult +from pylops_mpi.basicoperators.MatrixMult import (local_block_spit, + block_gather, + MPISummaMatrixMult) comm = MPI.COMM_WORLD rank = comm.Get_rank() @@ -23,9 +25,12 @@ A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape) B_data = np.arange(int(B_shape[0] * B_shape[1])).reshape(B_shape) -i, j = divmod(rank, p_prime) -A_local, (N_new, K_new) = MPIMatrixMult.block_distribute(A_data, i, j,comm) -B_local, (K_new, M_new) = MPIMatrixMult.block_distribute(B_data, i, j,comm) +A_slice = local_block_spit(A_shape, rank, comm) +B_slice = local_block_spit(B_shape, rank, comm) +A_local = A_data[A_slice] +B_local = B_data[B_slice] +# A_local, (N_new, K_new) = block_distribute(A_data,rank, comm) +# B_local, (K_new, M_new) = block_distribute(B_data,rank, comm) B_dist = pylops_mpi.DistributedArray(global_shape=(K * M), local_shapes=comm.allgather(B_local.shape[0] * B_local.shape[1]), @@ -33,12 +38,12 @@ partition=pylops_mpi.Partition.SCATTER) B_dist.local_array[:] = B_local.flatten() -Aop = MPIMatrixMult(A_local, M, base_comm=comm) +Aop = MPISummaMatrixMult(A_local, M, base_comm=comm) C_dist = Aop @ B_dist Z_dist = Aop.H @ C_dist -C = MPIMatrixMult.block_gather(C_dist, (N,M), (N,M), comm) -Z = MPIMatrixMult.block_gather(Z_dist, (K,M), (K,M), comm) +C = block_gather(C_dist, (N,M), (N,M), comm) +Z = block_gather(Z_dist, (K,M), (K,M), comm) if rank == 0 : C_correct = np.allclose(A_data @ B_data, C) print("C expected: ", C_correct) diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py index 4e3f6b12..6df75ae0 100644 --- a/pylops_mpi/basicoperators/MatrixMult.py +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -1,6 +1,8 @@ -import numpy as np import math +import numpy as np +from typing import Tuple from mpi4py import MPI + from pylops.utils.backend import get_module from pylops.utils.typing import DTypeLike, NDArray @@ -11,6 +13,148 @@ ) +def active_grid_comm(base_comm: MPI.Comm, N: int, M: int): + r"""Configure active grid + + Configure a square process grid from a parent MPI communicator and + select a subset of "active" processes. Each process in ``base_comm`` + is assigned to a logical 2D grid of size :math:`P' \times P'`, + where :math:`P' = \bigl \lceil \sqrt{P} \bigr \rceil`. Only the first + :math:`active_dim x active_dim` processes + (by row-major order) are considered "active". Inactive ranks return + immediately with no new communicator. + + Parameters: + ----------- + base_comm : :obj:`mpi4py.MPI.Comm` + MPI Parent Communicator. (e.g., ``mpi4py.MPI.COMM_WORLD``). + N : :obj:`int` + Number of rows of the global data domain. + M : :obj:`int` + Number of columns of the global data domain. + + Returns: + -------- + comm : :obj:`mpi4py.MPI.Comm` + Sub-communicator including only active ranks. 
+ rank : :obj:`int` + Rank within the new sub-communicator (or original rank + if inactive). + row : :obj:`int` + Grid row index of this process in the active grid (or original rank + if inactive). + col : :obj:`int` + Grid column index of this process in the active grid + (or original rank if inactive). + is_active : :obj:`bool` + Flag indicating whether this rank is in the active sub-grid. + + """ + rank = base_comm.Get_rank() + size = base_comm.Get_size() + p_prime = math.isqrt(size) + row, col = divmod(rank, p_prime) + active_dim = min(N, M, p_prime) + is_active = (row < active_dim and col < active_dim) + + if not is_active: + return None, rank, row, col, False + + active_ranks = [r for r in range(size) + if (r // p_prime) < active_dim and (r % p_prime) < active_dim] + new_group = base_comm.Get_group().Incl(active_ranks) + new_comm = base_comm.Create_group(new_group) + p_prime_new = math.isqrt(len(active_ranks)) + new_rank = new_comm.Get_rank() + new_row, new_col = divmod(new_rank, p_prime_new) + + return new_comm, new_rank, new_row, new_col, True + +def block_distribute(array:NDArray, rank:int, comm: MPI.Comm, pad:bool=False): + size = comm.Get_size() + p_prime = math.isqrt(size) + if p_prime * p_prime != size: + raise Exception(f"Number of processes must be a square number, provided {size} instead...") + + proc_i, proc_j = divmod(rank, p_prime) + orig_r, orig_c = array.shape + + new_r = math.ceil(orig_r / p_prime) * p_prime + new_c = math.ceil(orig_c / p_prime) * p_prime + + br, bc = new_r // p_prime, new_c // p_prime + i0, j0 = proc_i * br, proc_j * bc + i1, j1 = min(i0 + br, orig_r), min(j0 + bc, orig_c) + + i_end = None if proc_i == p_prime - 1 else i1 + j_end = None if proc_j == p_prime - 1 else j1 + block = array[i0:i_end, j0:j_end] + + pr = (new_r - orig_r) if proc_i == p_prime - 1 else 0 + pc = (new_c - orig_c) if proc_j == p_prime - 1 else 0 + if pad and (pr or pc): block = np.pad(block, [(0, pr), (0, pc)], mode='constant') + return block, (new_r, new_c) + +def local_block_spit(global_shape: Tuple[int, int], rank: int, comm: MPI.Comm) -> Tuple[slice, slice]: + size = comm.Get_size() + p_prime = math.isqrt(size) + if p_prime * p_prime != size: + raise Exception(f"Number of processes must be a square number, provided {size} instead...") + + proc_i, proc_j = divmod(rank, p_prime) + orig_r, orig_c = global_shape + new_r = math.ceil(orig_r / p_prime) * p_prime + new_c = math.ceil(orig_c / p_prime) * p_prime + + br, bc = new_r // p_prime, new_c // p_prime + i0, j0 = proc_i * br, proc_j * bc + i1, j1 = min(i0 + br, orig_r), min(j0 + bc, orig_c) + + i_end = None if proc_i == p_prime - 1 else i1 + j_end = None if proc_j == p_prime - 1 else j1 + return slice(i0, i_end), slice(j0, j_end) + +def block_gather(x, new_shape, orig_shape, comm): + ncp = get_module(x.engine) + p_prime = math.isqrt(comm.Get_size()) + all_blks = comm.allgather(x.local_array) + + nr, nc = new_shape + orr, orc = orig_shape + + # Calculate base block sizes + br_base = nr // p_prime + bc_base = nc // p_prime + + # Calculate remainder rows/cols that need to be distributed + r_remainder = nr % p_prime + c_remainder = nc % p_prime + + # Create the output matrix + C = ncp.zeros((nr, nc), dtype=all_blks[0].dtype) + + # Place each block in the correct position + for rank in range(p_prime * p_prime): + # Convert linear rank to 2D grid position + proc_row = rank // p_prime + proc_col = rank % p_prime + + # Calculate this process's block dimensions + block_rows = br_base + (1 if proc_row < r_remainder else 0) + block_cols = 
bc_base + (1 if proc_col < c_remainder else 0) + + # Calculate starting position in global matrix + start_row = proc_row * br_base + min(proc_row, r_remainder) + start_col = proc_col * bc_base + min(proc_col, c_remainder) + + # Place the block + block = all_blks[rank] + if block.ndim == 1: + block = block.reshape(block_rows, block_cols) + C[start_row:start_row + block_rows, start_col:start_col + block_cols] = block + return C[:orr, :orc] + + class MPIMatrixMult(MPILinearOperator): r"""MPI Matrix multiplication @@ -112,6 +256,222 @@ class MPIMatrixMult(MPILinearOperator): the same row perform an ``allreduce`` sum to combine their partial results. This gives the complete ``(K, M_local)`` result for their assigned column. + """ + def __init__( + self, + A: NDArray, + M: int, + saveAt: bool = False, + base_comm: MPI.Comm = MPI.COMM_WORLD, + dtype: DTypeLike = "float64", + ) -> None: + rank = base_comm.Get_rank() + size = base_comm.Get_size() + + # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size + self._P_prime = math.isqrt(size) + self._C = self._P_prime + if self._P_prime * self._C != size: + raise Exception(f"Number of processes must be a square number, provided {size} instead...") + + self._col_id = rank % self._P_prime + self._row_id = rank // self._P_prime + + self.base_comm = base_comm + self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id) + self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id) + + self.A = A.astype(np.dtype(dtype)) + if saveAt: + self.At = A.T.conj() + + self.N = self._row_comm.allreduce(self.A.shape[0], op=MPI.SUM) + self.K = A.shape[1] + self.M = M + + block_cols = int(math.ceil(self.M / self._P_prime)) + blk_rows = int(math.ceil(self.N / self._P_prime)) + + self._row_start = self._col_id * blk_rows + self._row_end = min(self.N, self._row_start + blk_rows) + + self._col_start = self._row_id * block_cols + self._col_end = min(self.M, self._col_start + block_cols) + + self._local_ncols = max(0, self._col_end - self._col_start) + self._rank_col_lens = self.base_comm.allgather(self._local_ncols) + total_ncols = np.sum(self._rank_col_lens) + + self.dims = (self.K, total_ncols) + self.dimsd = (self.N, total_ncols) + shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims))) + super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm) + + def _matvec(self, x: DistributedArray) -> DistributedArray: + ncp = get_module(x.engine) + if x.partition != Partition.SCATTER: + raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...") + + y = DistributedArray( + global_shape=(self.N * self.dimsd[1]), + local_shapes=[(self.N * c) for c in self._rank_col_lens], + mask=x.mask, + partition=Partition.SCATTER, + dtype=self.dtype, + base_comm=self.base_comm + ) + + my_own_cols = self._rank_col_lens[self.rank] + x_arr = x.local_array.reshape((self.dims[0], my_own_cols)) + X_local = x_arr.astype(self.dtype) + Y_local = ncp.vstack( + self._row_comm.allgather( + ncp.matmul(self.A, X_local) + ) + ) + y[:] = Y_local.flatten() + return y + + def _rmatvec(self, x: DistributedArray) -> DistributedArray: + ncp = get_module(x.engine) + if x.partition != Partition.SCATTER: + raise ValueError(f"x should have partition={Partition.SCATTER}. 
Got {x.partition} instead.") + + y = DistributedArray( + global_shape=(self.K * self.dimsd[1]), + local_shapes=[self.K * c for c in self._rank_col_lens], + mask=x.mask, + partition=Partition.SCATTER, + dtype=self.dtype, + base_comm=self.base_comm + ) + + x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype) + X_tile = x_arr[self._row_start:self._row_end, :] + A_local = self.At if hasattr(self, "At") else self.A.T.conj() + Y_local = ncp.matmul(A_local, X_tile) + y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM) + y[:] = y_layer.flatten() + return y + +class MPISummaMatrixMult(MPILinearOperator): + r"""MPI SUMMA Matrix multiplication + + Implements distributed matrix-matrix multiplication using the SUMMA algorithm + between a matrix :math:`\mathbf{A}` distributed over a 2D process grid and + input model and data vectors, which are both interpreted as matrices + distributed in block-column fashion. + + Parameters + ---------- + A : :obj:`numpy.ndarray` + Local block of the matrix of shape :math:`[N_{loc} \times K_{loc}]` + where :math:`N_{loc}` and :math:`K_{loc}` are the number of rows and + columns stored on this MPI rank. + M : :obj:`int` + Global number of columns of the matrices representing the input model + and data vectors. + saveAt : :obj:`bool`, optional + Save ``A`` and ``A.H`` to speed up the computation of adjoint + (``True``) or create ``A.H`` on-the-fly (``False``). + Note that ``saveAt=True`` will double the amount of required memory. + Default is ``False``. + base_comm : :obj:`mpi4py.MPI.Comm`, optional + MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``. + dtype : :obj:`str`, optional + Type of elements in input array. + + Attributes + ---------- + shape : :obj:`tuple` + Operator shape + + Raises + ------ + Exception + If the operator is created with a non-square number of MPI ranks. + ValueError + If input vector does not have the correct partition type. + + Notes + ----- + This operator performs distributed matrix-matrix multiplication using the + SUMMA (Scalable Universal Matrix Multiplication Algorithm), whose forward + operation can be described as :math:`\mathbf{Y} = \mathbf{A} \cdot \mathbf{X}` where: + + - :math:`\mathbf{A}` is the distributed matrix operator of shape :math:`[N \times K]` + - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]` + - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]` + + The adjoint operation is represented by + :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}` where + :math:`\mathbf{A}^H` is the complex conjugate transpose of :math:`\mathbf{A}`. + + This implementation is based on a 2D block distribution across a square process + grid (:math:`\sqrt{P}\times\sqrt{P}`). The matrices are distributed as follows: + + - The matrix ``A`` is distributed across MPI processes in 2D blocks where + each process holds a local block of ``A`` with shape :math:`[N_{loc} \times K_{loc}]` + where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`K_{loc} = \frac{K}{\sqrt{P}}`. + + - The operand matrix ``X`` is also distributed across MPI processes in 2D blocks where + each process holds a local block of ``X`` with shape :math:`[K_{loc} \times M_{loc}]` + where :math:`K_{loc} = \frac{K}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`. 
+ + - The result matrix ``Y`` is also distributed across MPI processes in 2D blocks where + each process holds a local block of ``Y`` with shape :math:`[N_{loc} \times M_{loc}]` + where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`. + + + **Forward Operation (SUMMA Algorithm)** + + The forward operation implements the SUMMA algorithm: + + 1. **Input Preparation**: The input vector ``x``is reshaped to ``(K_{loc}, M_{loc})`` representing + the local block assigned to the current process. + + 2. **SUMMA Iteration**: For each step ``k`` in the SUMMA algorithm -- :math:`k \in \[ 0, \sqrt{P} \)}` : + + a. **Broadcast A blocks**: Process in column ``k`` broadcasts its ``A`` + block to all other processes in the same process row. + + b. **Broadcast X blocks**: Process in row ``k`` broadcasts its ``X`` + block to all other processes in the same process column. + + c. **Local Computation**: Each process computes the partial matrix + product ``A_broadcast @ X_broadcast`` and accumulates it to its + local result. + + 3. **Result Assembly**: After all k SUMMA iterations, each process has computed + its local block of the result matrix ``Y``. + + **Adjoint Operation (SUMMA Algorithm)** + + The adjoint operation performs the conjugate transpose multiplication using + a modified SUMMA algorithm: + + 1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(N_{loc}, M_{loc})`` + representing the local block of the input matrix. + + 2. **SUMMA Adjoint Iteration**: For each step ``k`` in the adjoint SUMMA algorithm: + + a. **Broadcast A^H blocks**: The conjugate transpose of ``A`` blocks is + communicated between processes. If ``saveAt=True``, the pre-computed + ``A.H`` is used; otherwise, ``A.T.conj()`` is computed on-the-fly. + + b. **Broadcast Y blocks**: Process in row ``k`` broadcasts its ``Y`` + block to all other processes in the same process column. + + c. **Local Adjoint Computation**: Each process computes the partial + matrix product ``A_H_broadcast @ Y_broadcast`` and accumulates it + to the local result. + + 3. **Result Assembly**: After all adjoint SUMMA iterations, each process has + computed its local block of the result matrix ``X_{adj}``. + + The implementation handles padding automatically to ensure proper block sizes + for the square process grid, and unpadding is performed before returning results. + """ def __init__( self, @@ -167,127 +527,6 @@ def __init__( shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims))) super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm) - @staticmethod - def active_grid_comm(base_comm: MPI.Comm, N: int, M: int): - r"""Configure active grid - - Configure a square process grid from a parent MPI communicator and - select a subset of "active" processes. Each process in ``base_comm`` - is assigned to a logical 2D grid of size :math:`P' \times P'`, - where :math:`P' = \bigl \lceil \sqrt{P} \bigr \rceil`. Only the first - :math:`active_dim x active_dim` processes - (by row-major order) are considered "active". Inactive ranks return - immediately with no new communicator. - - Parameters: - ----------- - base_comm : :obj:`mpi4py.MPI.Comm` - MPI Parent Communicator. (e.g., ``mpi4py.MPI.COMM_WORLD``). - N : :obj:`int` - Number of rows of the global data domain. - M : :obj:`int` - Number of columns of the global data domain. - - Returns: - -------- - comm : :obj:`mpi4py.MPI.Comm` - Sub-communicator including only active ranks. 
- rank : :obj:`int` - Rank within the new sub-communicator (or original rank - if inactive). - row : :obj:`int` - Grid row index of this process in the active grid (or original rank - if inactive). - col : :obj:`int` - Grid column index of this process in the active grid - (or original rank if inactive). - is_active : :obj:`bool` - Flag indicating whether this rank is in the active sub-grid. - - """ - rank = base_comm.Get_rank() - size = base_comm.Get_size() - p_prime = math.isqrt(size) - row, col = divmod(rank, p_prime) - active_dim = min(N, M, p_prime) - is_active = (row < active_dim and col < active_dim) - - if not is_active: - return None, rank, row, col, False - - active_ranks = [r for r in range(size) - if (r // p_prime) < active_dim and (r % p_prime) < active_dim] - new_group = base_comm.Get_group().Incl(active_ranks) - new_comm = base_comm.Create_group(new_group) - p_prime_new = math.isqrt(len(active_ranks)) - new_rank = new_comm.Get_rank() - new_row, new_col = divmod(new_rank, p_prime_new) - - return new_comm, new_rank, new_row, new_col, True - - @staticmethod - def block_distribute(array, proc_i, proc_j, comm): - p_prime = math.isqrt(comm.Get_size()) - orig_r, orig_c = array.shape - - new_r = math.ceil(orig_r / p_prime) * p_prime - new_c = math.ceil(orig_c / p_prime) * p_prime - - br, bc = new_r // p_prime, new_c // p_prime - i0, j0 = proc_i * br, proc_j * bc - i1, j1 = min(i0 + br, orig_r), min(j0 + bc, orig_c) - - i_end = None if proc_i == p_prime - 1 else i1 - j_end = None if proc_j == p_prime - 1 else j1 - block = array[i0:i_end, j0:j_end] - - pr = (new_r - orig_r) if proc_i == p_prime - 1 else 0 - pc = (new_c - orig_c) if proc_j == p_prime - 1 else 0 - #comment the padding to get the block as unpadded - # if pr or pc: block = np.pad(block, [(0, pr), (0, pc)], mode='constant') - return block, (new_r, new_c) - - @staticmethod - def block_gather(x, new_shape, orig_shape, comm): - ncp = get_module(x.engine) - p_prime = math.isqrt(comm.Get_size()) - all_blks = comm.allgather(x.local_array) - - nr, nc = new_shape - orr, orc = orig_shape - - # Calculate base block sizes - br_base = nr // p_prime - bc_base = nc // p_prime - - # Calculate remainder rows/cols that need to be distributed - r_remainder = nr % p_prime - c_remainder = nc % p_prime - - # Create the output matrix - C = ncp.zeros((nr, nc), dtype=all_blks[0].dtype) - - # Place each block in the correct position - for rank in range(p_prime * p_prime): - # Convert linear rank to 2D grid position - proc_row = rank // p_prime - proc_col = rank % p_prime - - # Calculate this process's block dimensions - block_rows = br_base + (1 if proc_row < r_remainder else 0) - block_cols = bc_base + (1 if proc_col < c_remainder else 0) - - # Calculate starting position in global matrix - start_row = proc_row * br_base + min(proc_row, r_remainder) - start_col = proc_col * bc_base + min(proc_col, c_remainder) - - # Place the block - block = all_blks[rank] - if block.ndim == 1: - block = block.reshape(block_rows, block_cols) - C[start_row:start_row + block_rows, start_col:start_col + block_cols] = block - return C[:orr, :orc] - def _matvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) if x.partition != Partition.SCATTER: @@ -297,25 +536,10 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: bn = self._N_padded // self._P_prime # block size in N dimension bm = self._M_padded // self._P_prime # block size in M dimension - # Calculate actual local shape for this process (considering original dimensions) - 
local_n = bn - local_m = bm + local_n = bn if self._row_id != self._P_prime - 1 else self.N - (self._P_prime - 1) * bn + local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm - # Adjust for edge/corner processes that might have smaller blocks - if self._row_id == self._P_prime - 1: - local_n = self.N - (self._P_prime - 1) * bn - if self._col_id == self._P_prime - 1: - local_m = self.M - (self._P_prime - 1) * bm - - local_shape = local_n * local_m - - # Create local_shapes array for all processes - local_shapes = [] - for rank in range(self.size): - row_id, col_id = divmod(rank, self._P_prime) - proc_n = bn if row_id != self._P_prime - 1 else self.N - (self._P_prime - 1) * bn - proc_m = bm if col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm - local_shapes.append(proc_n * proc_m) + local_shapes = self.base_comm.allgather(local_n * local_m) y = DistributedArray(global_shape=(self.N * self.M), mask=x.mask, @@ -330,9 +554,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: # The input x corresponds to blocks from matrix B (K x M) # This process should receive a block of size (local_k x local_m) - local_k = bk - if self._row_id == self._P_prime - 1: - local_k = self.K - (self._P_prime - 1) * bk + local_k = bk if self._row_id != self._P_prime - 1 else self.K - (self._P_prime - 1) * bk # Reshape x.local_array to its 2D block form x_block = x.local_array.reshape((local_k, local_m)) @@ -367,24 +589,11 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: bm = self._M_padded // self._P_prime # block size in M dimension # Calculate actual local shape for this process (considering original dimensions) - local_k = bk - local_m = bm - # Adjust for edge/corner processes that might have smaller blocks - if self._row_id == self._P_prime - 1: - local_k = self.K - (self._P_prime - 1) * bk - if self._col_id == self._P_prime - 1: - local_m = self.M - (self._P_prime - 1) * bm - - local_shape = local_k * local_m + local_k = bk if self._row_id != self._P_prime - 1 else self.K - (self._P_prime - 1) * bk + local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm - # Create local_shapes array for all processes - local_shapes = [] - for rank in range(self.size): - row_id, col_id = divmod(rank, self._P_prime) - proc_k = bk if row_id != self._P_prime - 1 else self.K - (self._P_prime - 1) * bk - proc_m = bm if col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm - local_shapes.append(proc_k * proc_m) + local_shapes = self.base_comm.allgather(local_k * local_m) y = DistributedArray( global_shape=(self.K * self.M), @@ -400,9 +609,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: # The input x corresponds to blocks from the result (N x M) # This process should receive a block of size (local_n x local_m) - local_n = bn - if self._row_id == self._P_prime - 1: - local_n = self.N - (self._P_prime - 1) * bn + local_n = bn if self._row_id != self._P_prime - 1 else self.N - (self._P_prime - 1) * bn # Reshape x.local_array to its 2D block form x_block = x.local_array.reshape((local_n, local_m)) From 1ef09ab7895625d77b1c22b4ede3cb96fc2674f8 Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Wed, 23 Jul 2025 21:29:56 +0200 Subject: [PATCH 10/25] converted Bcast into bcast --- pylops_mpi/basicoperators/MatrixMult.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pylops_mpi/basicoperators/MatrixMult.py 
b/pylops_mpi/basicoperators/MatrixMult.py index 6df75ae0..bf48db5c 100644 --- a/pylops_mpi/basicoperators/MatrixMult.py +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -496,6 +496,7 @@ def __init__( self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id) self.A = A.astype(np.dtype(dtype)) + if saveAt: self.At = A.T.conj() self.N = self._col_comm.allreduce(A.shape[0]) self.K = self._row_comm.allreduce(A.shape[1]) @@ -571,8 +572,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: for k in range(self._P_prime): Atemp = self.A.copy() if self._col_id == k else np.empty_like(self.A) Xtemp = x_block.copy() if self._row_id == k else np.empty_like(x_block) - self._row_comm.Bcast(Atemp, root=k) - self._col_comm.Bcast(Xtemp, root=k) + self._row_comm.bcast(Atemp, root=k) + self._col_comm.bcast(Xtemp, root=k) Y_local += ncp.dot(Atemp, Xtemp) Y_local_unpadded = Y_local[:local_n, :local_m] From 56e9414659bdd237718bfafd31c6fb7c785c3a52 Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Wed, 23 Jul 2025 22:01:54 +0200 Subject: [PATCH 11/25] Added docstring --- pylops_mpi/basicoperators/MatrixMult.py | 87 +++++++++++++++++++++---- 1 file changed, 76 insertions(+), 11 deletions(-) diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py index bf48db5c..6d4e0649 100644 --- a/pylops_mpi/basicoperators/MatrixMult.py +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -95,28 +95,89 @@ def block_distribute(array:NDArray, rank:int, comm: MPI.Comm, pad:bool=False): if pad and (pr or pc): block = np.pad(block, [(0, pr), (0, pc)], mode='constant') return block, (new_r, new_c) -def local_block_spit(global_shape: Tuple[int, int], rank: int, comm: MPI.Comm) -> Tuple[slice, slice]: +def local_block_spit(global_shape: Tuple[int, int], + rank: int, + comm: MPI.Comm) -> Tuple[slice, slice]: + """ + Compute the local sub‐block of a 2D global array for a process in a square process grid. + + Parameters + ---------- + global_shape : Tuple[int, int] + Dimensions of the global 2D array (n_rows, n_cols). + rank : int + Rank of the MPI process in `comm` for which to get the owned block partition. + comm : MPI.Comm + MPI communicator whose total number of processes :math:`\mathbf{P}` + must be a perfect square :math:`\mathbf{P} = \sqrt{\mathbf{P'}}`. + + Returns + ------- + Tuple[slice, slice] + Two `slice` objects `(row_slice, col_slice)` indicating the sub‐block + of the global array owned by this rank. + + Raises + ------ + ValueError + if `rank` is out of range. + RuntimeError + If the number of processes participating in the provided communicator is not a perfect square. 
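As a quick aside, the splitting rule documented here can be checked without MPI. The snippet below is only a sketch of the arithmetic (``block_slices`` is a hypothetical stand-in, not a call into the helper itself), assuming 4 ranks arranged as a 2 x 2 grid over a 9 x 9 array, so the last grid row and column receive the smaller leftover tiles:

    import math
    import numpy as np

    def block_slices(global_shape, rank, nprocs):
        # mirrors the rule described above: square sqrt(P) x sqrt(P) grid,
        # row-major rank ordering, nominal tile size ceil(dim / sqrt(P)),
        # and the last grid row/column keeping whatever remains
        p = math.isqrt(nprocs)
        assert p * p == nprocs, "number of processes must be a perfect square"
        i, j = divmod(rank, p)
        nr, nc = global_shape
        br, bc = math.ceil(nr / p), math.ceil(nc / p)
        return (slice(i * br, min((i + 1) * br, nr)),
                slice(j * bc, min((j + 1) * bc, nc)))

    A = np.arange(81).reshape(9, 9)
    for r in range(4):                          # pretend 4 ranks, a 2 x 2 grid
        rs, cs = block_slices(A.shape, r, 4)
        print(r, A[rs, cs].shape)               # (5, 5), (5, 4), (4, 5), (4, 4)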
+ """ size = comm.Get_size() p_prime = math.isqrt(size) if p_prime * p_prime != size: - raise Exception(f"Number of processes must be a square number, provided {size} instead...") + raise RuntimeError(f"Number of processes must be a square number, provided {size} instead...") + if not ( isinstance(rank, int) and 0 <= rank < size ): + raise ValueError(f"rank must be integer in [0, {size}), got {rank!r}") proc_i, proc_j = divmod(rank, p_prime) orig_r, orig_c = global_shape + new_r = math.ceil(orig_r / p_prime) * p_prime new_c = math.ceil(orig_c / p_prime) * p_prime - br, bc = new_r // p_prime, new_c // p_prime - i0, j0 = proc_i * br, proc_j * bc - i1, j1 = min(i0 + br, orig_r), min(j0 + bc, orig_c) + blkr, blkc = new_r // p_prime, new_c // p_prime - i_end = None if proc_i == p_prime - 1 else i1 - j_end = None if proc_j == p_prime - 1 else j1 - return slice(i0, i_end), slice(j0, j_end) + i0, j0 = proc_i * blkr, proc_j * blkc + i1, j1 = min(i0 + blkr, orig_r), min(j0 + blkc, orig_c) + + return slice(i0, i1), slice(j0, j1) + + +def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tuple[int, int], comm: MPI.Comm): + """ + Gather distributed local blocks from 2D block distributed matrix distributed + amongst a square process grid into the full global array. + + Parameters + ---------- + x : :obj:`pylops_mpi.DistributedArray` + The distributed array to gather locally. + new_shape : Tuple[int, int] + Shape `(N', M')` of the padded global array, where both dimensions + are multiples of :math:`\sqrt{\mathbf{P}}`. + orig_shape : Tuple[int, int] + Original shape `(N, M)` of the global array before padding. + comm : MPI.Comm + MPI communicator whose size must be a perfect square (P = p_prime**2). + + Returns + ------- + Array + The reconstructed 2D array of shape `orig_shape`, assembled from + the distributed blocks. -def block_gather(x, new_shape, orig_shape, comm): + Raises + ------ + RuntimeError + If the number of processes participating in the provided communicator is not a perfect square. + """ ncp = get_module(x.engine) p_prime = math.isqrt(comm.Get_size()) + if p_prime * p_prime != comm.Get_size(): + raise RuntimeError(f"Communicator size must be a perfect square, got {comm.Get_size()!r}") + all_blks = comm.allgather(x.local_array) nr, nc = new_shape @@ -151,10 +212,14 @@ def block_gather(x, new_shape, orig_shape, comm): block = all_blks[rank] if block.ndim == 1: block = block.reshape(block_rows, block_cols) - C[start_row:start_row + block_rows, start_col:start_col + block_cols] = block + C[start_row:start_row + block_rows, + start_col:start_col + block_cols] = block + + # Trim off any padding return C[:orr, :orc] + class MPIMatrixMult(MPILinearOperator): r"""MPI Matrix multiplication @@ -360,7 +425,7 @@ class MPISummaMatrixMult(MPILinearOperator): Implements distributed matrix-matrix multiplication using the SUMMA algorithm between a matrix :math:`\mathbf{A}` distributed over a 2D process grid and input model and data vectors, which are both interpreted as matrices - distributed in block-column fashion. + distributed in block fashion wherein each process owns a tile of the matrix. 
Parameters ---------- From f3d19180973074356703e1018f8ef781ec37afb6 Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Wed, 23 Jul 2025 22:02:21 +0200 Subject: [PATCH 12/25] removed block distribute function --- pylops_mpi/basicoperators/MatrixMult.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py index 6d4e0649..d4f9a0f2 100644 --- a/pylops_mpi/basicoperators/MatrixMult.py +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -70,30 +70,6 @@ def active_grid_comm(base_comm: MPI.Comm, N: int, M: int): return new_comm, new_rank, new_row, new_col, True -def block_distribute(array:NDArray, rank:int, comm: MPI.Comm, pad:bool=False): - size = comm.Get_size() - p_prime = math.isqrt(size) - if p_prime * p_prime != size: - raise Exception(f"Number of processes must be a square number, provided {size} instead...") - - proc_i, proc_j = divmod(rank, p_prime) - orig_r, orig_c = array.shape - - new_r = math.ceil(orig_r / p_prime) * p_prime - new_c = math.ceil(orig_c / p_prime) * p_prime - - br, bc = new_r // p_prime, new_c // p_prime - i0, j0 = proc_i * br, proc_j * bc - i1, j1 = min(i0 + br, orig_r), min(j0 + bc, orig_c) - - i_end = None if proc_i == p_prime - 1 else i1 - j_end = None if proc_j == p_prime - 1 else j1 - block = array[i0:i_end, j0:j_end] - - pr = (new_r - orig_r) if proc_i == p_prime - 1 else 0 - pc = (new_c - orig_c) if proc_j == p_prime - 1 else 0 - if pad and (pr or pc): block = np.pad(block, [(0, pr), (0, pc)], mode='constant') - return block, (new_r, new_c) def local_block_spit(global_shape: Tuple[int, int], rank: int, From 9d36b0cd63d22aa0663cb084af665f2ae0444510 Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Wed, 23 Jul 2025 22:04:44 +0200 Subject: [PATCH 13/25] removed unnecessary check on local matrix A --- pylops_mpi/basicoperators/MatrixMult.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py index d4f9a0f2..81445d8d 100644 --- a/pylops_mpi/basicoperators/MatrixMult.py +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -554,10 +554,6 @@ def __init__( pr = (bn - A.shape[0]) if self._row_id == self._P_prime - 1 else 0 pc = (bk - A.shape[1]) if self._col_id == self._P_prime - 1 else 0 - if pr < 0 or pc < 0: - raise Exception(f"Improper distribution of A expected local shape " - f"( ≤ {bn}, ≤ {bk}) but got ({A.shape[0]},{A.shape[1]})") - if pr > 0 or pc > 0: self.A = np.pad(self.A, [(0, pr), (0, pc)], mode='constant') From 66e3296219cf2590f7cae8efeaf7fcd362125de9 Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Thu, 24 Jul 2025 00:21:09 +0200 Subject: [PATCH 14/25] Added Generic MatMulOp with docstring --- pylops_mpi/basicoperators/MatrixMult.py | 109 ++++++++++++++++++++++-- 1 file changed, 104 insertions(+), 5 deletions(-) diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py index 81445d8d..59d6eb8f 100644 --- a/pylops_mpi/basicoperators/MatrixMult.py +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -1,6 +1,6 @@ import math import numpy as np -from typing import Tuple +from typing import Tuple, Union, Literal from mpi4py import MPI from pylops.utils.backend import get_module @@ -196,8 +196,8 @@ def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tu -class MPIMatrixMult(MPILinearOperator): - r"""MPI 
Matrix multiplication +class _MPIBlockMatrixMult(MPILinearOperator): + r"""MPI Blocked Matrix multiplication Implement distributed matrix-matrix multiplication between a matrix :math:`\mathbf{A}` blocked over rows (i.e., blocks of rows are stored @@ -395,7 +395,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: y[:] = y_layer.flatten() return y -class MPISummaMatrixMult(MPILinearOperator): +class _MPISummaMatrixMult(MPILinearOperator): r"""MPI SUMMA Matrix multiplication Implements distributed matrix-matrix multiplication using the SUMMA algorithm @@ -681,4 +681,103 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: Y_local_unpadded = Y_local[:local_k, :local_m] y[:] = Y_local_unpadded.flatten() - return y \ No newline at end of file + return y + +class MPIMatrixMult(MPILinearOperator): + r""" + MPI Distributed Matrix Multiplication Operator + + This general operator performs distributed matrix-matrix multiplication + using either the SUMMA (Scalable Universal Matrix Multiplication Algorithm) + or a 1D block-row decomposition algorithm, depending on the specified + ``kind`` parameter. + + The forward operation computes:: + + Y = A @ X + + where: + - ``A`` is the distributed operator matrix of shape ``[N x K]`` + - ``X`` is the distributed operand matrix of shape ``[K x M]`` + - ``Y`` is the resulting distributed matrix of shape ``[N x M]`` + + The adjoint (conjugate-transpose) operation computes:: + + X_adj = A.H @ Y + + where ``A.H`` is the complex-conjugate transpose of ``A``. + + Distribution Layouts + -------------------- + :summa: + 2D block-grid distribution over a square process grid :math:`[\sqrt{P} \times \sqrt{P}]`: + - ``A`` and ``X`` are partitioned into :math:`[N_loc \times K_loc]` and + :math:`[K_loc \times M_loc]` tiles on each rank, respectively. + - Each SUMMA iteration broadcasts row- and column-blocks of ``A`` and + ``X`` and accumulates local partial products. + + :block: + 1D block-row distribution over a 1 x P grid: + - ``A`` is partitioned into :math:`[N_loc \times K]` blocks across ranks. + - ``X`` (and result ``Y``) are partitioned into :math:`[K \times M_loc]` blocks. + - Local multiplication is followed by row-wise gather (forward) or + allreduce (adjoint) across ranks. + + Parameters + ---------- + A : NDArray + Local block of the matrix operator. + M : int + Global number of columns in the operand and result matrices. + saveAt : bool, optional + If ``True``, store both ``A`` and its conjugate transpose ``A.H`` + to accelerate adjoint operations (uses twice the memory). + Default is ``False``. + base_comm : mpi4py.MPI.Comm, optional + MPI communicator to use. Defaults to ``MPI.COMM_WORLD``. + kind : {'summa', 'block'}, optional + Algorithm to use: ``'summa'`` for the SUMMA 2D algorithm, or + ``'block'`` for the block-row-col decomposition. Default is ``'summa'``. + dtype : DTypeLike, optional + Numeric data type for computations. Default is ``np.float64``. + + Attributes + ---------- + shape : :obj:`tuple` + Operator shape + comm : mpi4py.MPI.Comm + The MPI communicator in use. + kind : str + Selected distributed matrix multiply algorithm ('summa' or 'block'). + + Raises + ------ + NotImplementedError + If ``kind`` is not one of ``'summa'`` or ``'block'``. + Exception + If the MPI communicator does not form a compatible grid for the + selected algorithm. 
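As with the SUMMA sketch further up, the ``block`` layout can be emulated serially. This is only an illustration of the arithmetic described under Distribution Layouts and does not use MPI: each column block of :math:`\mathbf{Y}` is the vertical stack of the local products :math:`\mathbf{A}_i \mathbf{X}_j` contributed by the ranks sharing the column block :math:`\mathbf{X}_j`, and the adjoint sums per-block partials, which is what the allreduce performs:

    import numpy as np

    P = 2                                   # assumed number of row/column blocks
    N = K = M = 4
    A = np.arange(N * K, dtype=float).reshape(N, K)
    X = np.arange(K * M, dtype=float).reshape(K, M)

    A_blocks = np.vsplit(A, P)              # row blocks A_i of shape (N/P, K)
    X_blocks = np.hsplit(X, P)              # column blocks X_j of shape (K, M/P)

    # forward: the j-th column block of Y is the vertical stack of A_i @ X_j
    Y_cols = [np.vstack([A_i @ X_blocks[j] for A_i in A_blocks]) for j in range(P)]
    assert np.allclose(np.hstack(Y_cols), A @ X)

    # adjoint: each row block of A is applied (conjugate-transposed) to the
    # matching row slice of Y and the partial results are summed
    Xadj_cols = [sum(A_blocks[i].conj().T @ np.vsplit(Y_cols[j], P)[i]
                     for i in range(P))
                 for j in range(P)]
    assert np.allclose(np.hstack(Xadj_cols), A.conj().T @ (A @ X))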
+ """ + def __init__( + self, + A: NDArray, + M: int, + saveAt: bool = False, + base_comm: MPI.Comm = MPI.COMM_WORLD, + kind:Literal["summa", "block"] = "summa", + dtype: DTypeLike = "float64", + ): + if kind == "summa": + self._f = _MPISummaMatrixMult(A,M,saveAt,base_comm,dtype) + elif kind == "block": + self._f = _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype) + else: + raise NotImplementedError("kind must be summa or block") + self.kind = kind + super().__init__(shape=self._f.shape, dtype=dtype, base_comm=base_comm) + + def _matvec(self, x: DistributedArray) -> DistributedArray: + return self._f.matvec(x) + + def _rmatvec(self, x: DistributedArray) -> DistributedArray: + return self._f.rmatvec(x) \ No newline at end of file From 0956e7b7b0182769f82f63e1cb179c92429df174 Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Thu, 24 Jul 2025 02:04:29 +0200 Subject: [PATCH 15/25] Converted it to a function --- pylops_mpi/basicoperators/MatrixMult.py | 111 +++++++++++------------- 1 file changed, 50 insertions(+), 61 deletions(-) diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py index 59d6eb8f..8faa1236 100644 --- a/pylops_mpi/basicoperators/MatrixMult.py +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -195,7 +195,6 @@ def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tu return C[:orr, :orc] - class _MPIBlockMatrixMult(MPILinearOperator): r"""MPI Blocked Matrix multiplication @@ -214,7 +213,7 @@ class _MPIBlockMatrixMult(MPILinearOperator): Global leading dimension (i.e., number of columns) of the matrices representing the input model and data vectors. saveAt : :obj:`bool`, optional - Save ``A`` and ``A.H`` to speed up the computation of adjoint + Save :math:`\mathbf{A}` and ``A.H`` to speed up the computation of adjoint (``True``) or create ``A.H`` on-the-fly (``False``) Note that ``saveAt=True`` will double the amount of required memory. Default is ``False``. @@ -253,22 +252,22 @@ class _MPIBlockMatrixMult(MPILinearOperator): processes by a factor equivalent to :math:`\sqrt{P}` across a square process grid (:math:`\sqrt{P}\times\sqrt{P}`). More specifically: - - The matrix ``A`` is distributed across MPI processes in a block-row fashion - and each process holds a local block of ``A`` with shape + - The matrix :math:`\mathbf{A}` is distributed across MPI processes in a block-row fashion + and each process holds a local block of :math:`\mathbf{A}` with shape :math:`[N_{loc} \times K]` - - The operand matrix ``X`` is distributed in a block-column fashion and - each process holds a local block of ``X`` with shape + - The operand matrix :math:`\mathbf{X}` is distributed in a block-column fashion and + each process holds a local block of :math:`\mathbf{X}` with shape :math:`[K \times M_{loc}]` - Communication is minimized by using a 2D process grid layout **Forward Operation step-by-step** - 1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X`` + 1. **Input Preparation**: The input vector ``x`` (flattened from matrix :math:`\mathbf{X}` of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local`` is the number of columns assigned to the current process. 2. 
**Local Computation**: Each process computes ``A_local @ X_local`` where: - - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``) + - ``A_local`` is the local block of matrix :math:`\mathbf{A}` (shape ``N_local x K``) - ``X_local`` is the broadcasted operand (shape ``K x M_local``) 3. **Row-wise Gather**: Results from all processes in each row are gathered @@ -283,10 +282,10 @@ class _MPIBlockMatrixMult(MPILinearOperator): representing the local columns of the input matrix. 2. **Local Adjoint Computation**: Each process computes - ``A_local.H @ X_tile`` where ``A_local.H`` is either i) Pre-computed - and stored in ``At`` (if ``saveAt=True``), ii) computed on-the-fly as + ``A_local.H @ X_tile`` where ``A_local.H`` is either pre-computed + and stored in ``At`` (if ``saveAt=True``), or computed on-the-fly as ``A.T.conj()`` (if ``saveAt=False``). Each process multiplies its - transposed local ``A`` block ``A_local^H`` (shape ``K x N_block``) + transposed local :math:`\mathbf{A}` block ``A_local^H`` (shape ``K x N_block``) with the extracted ``X_tile`` (shape ``N_block x M_local``), producing a partial result of shape ``(K, M_local)``. This computes the local contribution of columns of ``A^H`` to the final @@ -413,7 +412,7 @@ class _MPISummaMatrixMult(MPILinearOperator): Global number of columns of the matrices representing the input model and data vectors. saveAt : :obj:`bool`, optional - Save ``A`` and ``A.H`` to speed up the computation of adjoint + Save :math:`\mathbf{A}` and ``A.H`` to speed up the computation of adjoint (``True``) or create ``A.H`` on-the-fly (``False``). Note that ``saveAt=True`` will double the amount of required memory. Default is ``False``. @@ -451,16 +450,16 @@ class _MPISummaMatrixMult(MPILinearOperator): This implementation is based on a 2D block distribution across a square process grid (:math:`\sqrt{P}\times\sqrt{P}`). The matrices are distributed as follows: - - The matrix ``A`` is distributed across MPI processes in 2D blocks where - each process holds a local block of ``A`` with shape :math:`[N_{loc} \times K_{loc}]` + - The matrix :math:`\mathbf{A}` is distributed across MPI processes in 2D blocks where + each process holds a local block of :math:`\mathbf{A}` with shape :math:`[N_{loc} \times K_{loc}]` where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`K_{loc} = \frac{K}{\sqrt{P}}`. - - The operand matrix ``X`` is also distributed across MPI processes in 2D blocks where - each process holds a local block of ``X`` with shape :math:`[K_{loc} \times M_{loc}]` + - The operand matrix :math:`\mathbf{X}` is also distributed across MPI processes in 2D blocks where + each process holds a local block of :math:`\mathbf{X}` with shape :math:`[K_{loc} \times M_{loc}]` where :math:`K_{loc} = \frac{K}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`. - - The result matrix ``Y`` is also distributed across MPI processes in 2D blocks where - each process holds a local block of ``Y`` with shape :math:`[N_{loc} \times M_{loc}]` + - The result matrix :math:`\mathbf{Y}` is also distributed across MPI processes in 2D blocks where + each process holds a local block of :math:`\mathbf{Y}` with shape :math:`[N_{loc} \times M_{loc}]` where :math:`N_{loc} = \frac{N}{\sqrt{P}}` and :math:`M_{loc} = \frac{M}{\sqrt{P}}`. @@ -473,10 +472,10 @@ class _MPISummaMatrixMult(MPILinearOperator): 2. **SUMMA Iteration**: For each step ``k`` in the SUMMA algorithm -- :math:`k \in \[ 0, \sqrt{P} \)}` : - a. 
**Broadcast A blocks**: Process in column ``k`` broadcasts its ``A`` + a. **Broadcast A blocks**: Process in column ``k`` broadcasts its :math:`\mathbf{A}` block to all other processes in the same process row. - b. **Broadcast X blocks**: Process in row ``k`` broadcasts its ``X`` + b. **Broadcast X blocks**: Process in row ``k`` broadcasts its :math:`\mathbf{X}` block to all other processes in the same process column. c. **Local Computation**: Each process computes the partial matrix @@ -484,7 +483,7 @@ class _MPISummaMatrixMult(MPILinearOperator): local result. 3. **Result Assembly**: After all k SUMMA iterations, each process has computed - its local block of the result matrix ``Y``. + its local block of the result matrix :math:`\mathbf{Y}`. **Adjoint Operation (SUMMA Algorithm)** @@ -496,11 +495,11 @@ class _MPISummaMatrixMult(MPILinearOperator): 2. **SUMMA Adjoint Iteration**: For each step ``k`` in the adjoint SUMMA algorithm: - a. **Broadcast A^H blocks**: The conjugate transpose of ``A`` blocks is + a. **Broadcast A^H blocks**: The conjugate transpose of :math:`\mathbf{A}` blocks is communicated between processes. If ``saveAt=True``, the pre-computed ``A.H`` is used; otherwise, ``A.T.conj()`` is computed on-the-fly. - b. **Broadcast Y blocks**: Process in row ``k`` broadcasts its ``Y`` + b. **Broadcast Y blocks**: Process in row ``k`` broadcasts its :math:`\mathbf{Y}` block to all other processes in the same process column. c. **Local Adjoint Computation**: Each process computes the partial @@ -683,7 +682,14 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: y[:] = Y_local_unpadded.flatten() return y -class MPIMatrixMult(MPILinearOperator): +def MPIMatrixMult( + A: NDArray, + M: int, + saveAt: bool = False, + base_comm: MPI.Comm = MPI.COMM_WORLD, + kind: Literal["summa", "block"] = "summa", + dtype: DTypeLike = "float64", + ): r""" MPI Distributed Matrix Multiplication Operator @@ -694,32 +700,32 @@ class MPIMatrixMult(MPILinearOperator): The forward operation computes:: - Y = A @ X + :math:`\mathbf{Y} = \mathbf{A} \cdot \mathbf{X}` where: - - ``A`` is the distributed operator matrix of shape ``[N x K]`` - - ``X`` is the distributed operand matrix of shape ``[K x M]`` - - ``Y`` is the resulting distributed matrix of shape ``[N x M]`` + - :math:`\mathbf{A}` is the distributed operator matrix of shape :math:`[N \times K]` + - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]` + - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]` The adjoint (conjugate-transpose) operation computes:: + + :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}` - X_adj = A.H @ Y - - where ``A.H`` is the complex-conjugate transpose of ``A``. + where :math:`\mathbf{A}^H` is the complex-conjugate transpose of :math:`\mathbf{A}`. Distribution Layouts -------------------- :summa: 2D block-grid distribution over a square process grid :math:`[\sqrt{P} \times \sqrt{P}]`: - - ``A`` and ``X`` are partitioned into :math:`[N_loc \times K_loc]` and + - :math:`\mathbf{A}` and :math:`\mathbf{X}` are partitioned into :math:`[N_loc \times K_loc]` and :math:`[K_loc \times M_loc]` tiles on each rank, respectively. - - Each SUMMA iteration broadcasts row- and column-blocks of ``A`` and - ``X`` and accumulates local partial products. + - Each SUMMA iteration broadcasts row- and column-blocks of :math:`\mathbf{A}` and + :math:`\mathbf{X}` and accumulates local partial products. 
:block: - 1D block-row distribution over a 1 x P grid: - - ``A`` is partitioned into :math:`[N_loc \times K]` blocks across ranks. - - ``X`` (and result ``Y``) are partitioned into :math:`[K \times M_loc]` blocks. + 1D block-row distribution over a :math:`[1 \times P]` grid: + - :math:`\mathbf{A}` is partitioned into :math:`[N_loc \times K]` blocks across ranks. + - :math:`\mathbf{X}` (and result :math:`\mathbf{Y}`) are partitioned into :math:`[K \times M_loc]` blocks. - Local multiplication is followed by row-wise gather (forward) or allreduce (adjoint) across ranks. @@ -730,7 +736,7 @@ class MPIMatrixMult(MPILinearOperator): M : int Global number of columns in the operand and result matrices. saveAt : bool, optional - If ``True``, store both ``A`` and its conjugate transpose ``A.H`` + If ``True``, store both :math:`\mathbf{A}` and its conjugate transpose :math:`\mathbf{A}^H` to accelerate adjoint operations (uses twice the memory). Default is ``False``. base_comm : mpi4py.MPI.Comm, optional @@ -758,26 +764,9 @@ class MPIMatrixMult(MPILinearOperator): If the MPI communicator does not form a compatible grid for the selected algorithm. """ - def __init__( - self, - A: NDArray, - M: int, - saveAt: bool = False, - base_comm: MPI.Comm = MPI.COMM_WORLD, - kind:Literal["summa", "block"] = "summa", - dtype: DTypeLike = "float64", - ): - if kind == "summa": - self._f = _MPISummaMatrixMult(A,M,saveAt,base_comm,dtype) - elif kind == "block": - self._f = _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype) - else: - raise NotImplementedError("kind must be summa or block") - self.kind = kind - super().__init__(shape=self._f.shape, dtype=dtype, base_comm=base_comm) - - def _matvec(self, x: DistributedArray) -> DistributedArray: - return self._f.matvec(x) - - def _rmatvec(self, x: DistributedArray) -> DistributedArray: - return self._f.rmatvec(x) \ No newline at end of file + if kind == "summa": + return _MPISummaMatrixMult(A,M,saveAt,base_comm,dtype) + elif kind == "block": + return _MPIBlockMatrixMult(A, M, saveAt, base_comm, dtype) + else: + raise NotImplementedError("kind must be summa or block") From b053f5bb5a668a903c2fa0cbe5f2218ee4ce14eb Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Sat, 26 Jul 2025 21:00:00 +0200 Subject: [PATCH 16/25] Added SUMMA tests and fixed dtype problem --- examples/plot_matrixmult.py | 6 +- examples/plot_summamatrixmult.py | 78 ++++++++++------- pylops_mpi/basicoperators/MatrixMult.py | 54 ++++++++---- tests/test_matrixmult.py | 110 ++++++++++++++++++++++-- 4 files changed, 192 insertions(+), 56 deletions(-) diff --git a/examples/plot_matrixmult.py b/examples/plot_matrixmult.py index 9c7e2d35..082d924c 100644 --- a/examples/plot_matrixmult.py +++ b/examples/plot_matrixmult.py @@ -28,6 +28,7 @@ import pylops_mpi from pylops_mpi import Partition +from pylops_mpi.basicoperators.MatrixMult import active_grid_comm, MPIMatrixMult plt.close("all") @@ -88,8 +89,7 @@ # than the row or columm ranks. 
base_comm = MPI.COMM_WORLD -comm, rank, row_id, col_id, is_active = \ - pylops_mpi.MPIMatrixMult.active_grid_comm(base_comm, N, M) +comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M) print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}") if not is_active: exit(0) @@ -147,7 +147,7 @@ ################################################################################ # We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult` # operator and the input matrix :math:`\mathbf{X}` -Aop = pylops_mpi.MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32") +Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32", kind="block") col_lens = comm.allgather(my_own_cols) total_cols = np.sum(col_lens) diff --git a/examples/plot_summamatrixmult.py b/examples/plot_summamatrixmult.py index 4aa85535..50499287 100644 --- a/examples/plot_summamatrixmult.py +++ b/examples/plot_summamatrixmult.py @@ -1,11 +1,28 @@ +r""" +Distributed SUMMA Matrix Multiplication +======================================= +This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPISummaMatrixMult` +operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}` +distributed in 2D blocks across a square process grid and matrices :math:`\mathbf{X}` +and :math:`\mathbf{Y}` distributed in 2D blocks across the same grid. Similarly, +the adjoint operation can be performed with a matrix :math:`\mathbf{Y}` distributed +in the same fashion as matrix :math:`\mathbf{X}`. + +Note that whilst the different blocks of matrix :math:`\mathbf{A}` are directly +stored in the operator on different ranks, the matrices :math:`\mathbf{X}` and +:math:`\mathbf{Y}` are effectively represented by 1-D :py:class:`pylops_mpi.DistributedArray` +objects where the different blocks are flattened and stored on different ranks. +Note that to optimize communications, the ranks are organized in a square grid and +blocks of :math:`\mathbf{A}` and :math:`\mathbf{X}` are systematically broadcast +across different ranks during computation - see below for details. 
+""" + import math import numpy as np from mpi4py import MPI import pylops_mpi -from pylops_mpi.basicoperators.MatrixMult import (local_block_spit, - block_gather, - MPISummaMatrixMult) +from pylops_mpi.basicoperators.MatrixMult import (local_block_spit, block_gather, MPIMatrixMult) comm = MPI.COMM_WORLD rank = comm.Get_rank() @@ -16,43 +33,40 @@ K = 9 A_shape = (N, K) -B_shape = (K, M) -C_shape = (N, M) +x_shape = (K, M) +y_shape = (N, M) p_prime = math.isqrt(size) -assert p_prime * p_prime == size, "Number of processes must be a perfect square" - A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape) -B_data = np.arange(int(B_shape[0] * B_shape[1])).reshape(B_shape) +x_data = np.arange(int(x_shape[0] * x_shape[1])).reshape(x_shape) A_slice = local_block_spit(A_shape, rank, comm) -B_slice = local_block_spit(B_shape, rank, comm) +x_slice = local_block_spit(x_shape, rank, comm) A_local = A_data[A_slice] -B_local = B_data[B_slice] -# A_local, (N_new, K_new) = block_distribute(A_data,rank, comm) -# B_local, (K_new, M_new) = block_distribute(B_data,rank, comm) +x_local = x_data[x_slice] -B_dist = pylops_mpi.DistributedArray(global_shape=(K * M), - local_shapes=comm.allgather(B_local.shape[0] * B_local.shape[1]), +x_dist = pylops_mpi.DistributedArray(global_shape=(K * M), + local_shapes=comm.allgather(x_local.shape[0] * x_local.shape[1]), base_comm=comm, - partition=pylops_mpi.Partition.SCATTER) -B_dist.local_array[:] = B_local.flatten() + partition=pylops_mpi.Partition.SCATTER, + dtype=x_local.dtype) +x_dist.local_array[:] = x_local.flatten() -Aop = MPISummaMatrixMult(A_local, M, base_comm=comm) -C_dist = Aop @ B_dist -Z_dist = Aop.H @ C_dist +Aop = MPIMatrixMult(A_local, M, base_comm=comm, kind="summa", dtype=A_local.dtype) +y_dist = Aop @ x_dist +xadj_dist = Aop.H @ y_dist -C = block_gather(C_dist, (N,M), (N,M), comm) -Z = block_gather(Z_dist, (K,M), (K,M), comm) +y = block_gather(y_dist, (N,M), (N,M), comm) +xadj = block_gather(xadj_dist, (K,M), (K,M), comm) if rank == 0 : - C_correct = np.allclose(A_data @ B_data, C) - print("C expected: ", C_correct) - if not C_correct: - print("expected:\n", A_data @ B_data) - print("calculated:\n",C) - - Z_correct = np.allclose((A_data.T.dot((A_data @ B_data).conj())).conj(), Z.astype(np.int32)) - print("Z expected: ", Z_correct) - if not Z_correct: - print("expected:\n", (A_data.T.dot((A_data @ B_data).conj())).conj()) - print("calculated:\n", Z.astype(np.int32)) + y_correct = np.allclose(A_data @ x_data, y) + print("y expected: ", y_correct) + if not y_correct: + print("expected:\n", A_data @ x_data) + print("calculated:\n",y) + + xadj_correct = np.allclose((A_data.T.dot((A_data @ x_data).conj())).conj(), xadj.astype(np.int32)) + print("xadj expected: ", xadj_correct) + if not xadj_correct: + print("expected:\n", (A_data.T.dot((A_data @ x_data).conj())).conj()) + print("calculated:\n", xadj.astype(np.int32)) diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py index 8faa1236..c466ca2b 100644 --- a/pylops_mpi/basicoperators/MatrixMult.py +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -74,7 +74,7 @@ def active_grid_comm(base_comm: MPI.Comm, N: int, M: int): def local_block_spit(global_shape: Tuple[int, int], rank: int, comm: MPI.Comm) -> Tuple[slice, slice]: - """ + r""" Compute the local sub‐block of a 2D global array for a process in a square process grid. 
Parameters @@ -122,7 +122,7 @@ def local_block_spit(global_shape: Tuple[int, int], def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tuple[int, int], comm: MPI.Comm): - """ + r""" Gather distributed local blocks from 2D block distributed matrix distributed amongst a square process grid into the full global array. @@ -351,19 +351,19 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) if x.partition != Partition.SCATTER: raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...") - + output_dtype = np.result_type(self.dtype, x.dtype) y = DistributedArray( global_shape=(self.N * self.dimsd[1]), local_shapes=[(self.N * c) for c in self._rank_col_lens], mask=x.mask, partition=Partition.SCATTER, - dtype=self.dtype, + dtype=output_dtype, base_comm=self.base_comm ) my_own_cols = self._rank_col_lens[self.rank] x_arr = x.local_array.reshape((self.dims[0], my_own_cols)) - X_local = x_arr.astype(self.dtype) + X_local = x_arr.astype(output_dtype) Y_local = ncp.vstack( self._row_comm.allgather( ncp.matmul(self.A, X_local) @@ -377,16 +377,28 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: if x.partition != Partition.SCATTER: raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.") + # - If A is real: A^H = A^T, + # so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real) + # - If A is complex: A^H is complex, + # so result will be complex regardless of x + if np.iscomplexobj(self.A): + output_dtype = np.result_type(self.dtype, x.dtype) + else: + # Real matrix: A^T @ x preserves input type complexity + output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype + # But still need to check type promotion for precision + output_dtype = np.result_type(self.dtype, output_dtype) + y = DistributedArray( global_shape=(self.K * self.dimsd[1]), local_shapes=[self.K * c for c in self._rank_col_lens], mask=x.mask, partition=Partition.SCATTER, - dtype=self.dtype, + dtype=output_dtype, base_comm=self.base_comm ) - x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype) + x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(output_dtype) X_tile = x_arr[self._row_start:self._row_end, :] A_local = self.At if hasattr(self, "At") else self.A.T.conj() Y_local = ncp.matmul(A_local, X_tile) @@ -536,7 +548,6 @@ def __init__( self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id) self.A = A.astype(np.dtype(dtype)) - if saveAt: self.At = A.T.conj() self.N = self._col_comm.allreduce(A.shape[0]) self.K = self._row_comm.allreduce(A.shape[1]) @@ -569,6 +580,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: if x.partition != Partition.SCATTER: raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...") + output_dtype = np.result_type(self.dtype, x.dtype) # Calculate local shapes for block distribution bn = self._N_padded // self._P_prime # block size in N dimension bm = self._M_padded // self._P_prime # block size in M dimension @@ -582,9 +594,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: mask=x.mask, local_shapes=local_shapes, partition=Partition.SCATTER, - dtype=self.dtype, - base_comm=self.base_comm - ) + dtype=output_dtype, + base_comm=self.base_comm) # Calculate expected padded dimensions for x bk = self._K_padded // self._P_prime # block size in K dimension @@ -603,13 +614,13 @@ def _matvec(self, x: DistributedArray) -> 
DistributedArray: if pad_k > 0 or pad_m > 0: x_block = np.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant') - Y_local = np.zeros((self.A.shape[0], bm)) + Y_local = np.zeros((self.A.shape[0], bm),dtype=output_dtype) for k in range(self._P_prime): Atemp = self.A.copy() if self._col_id == k else np.empty_like(self.A) Xtemp = x_block.copy() if self._row_id == k else np.empty_like(x_block) - self._row_comm.bcast(Atemp, root=k) - self._col_comm.bcast(Xtemp, root=k) + self._row_comm.Bcast(Atemp, root=k) + self._col_comm.Bcast(Xtemp, root=k) Y_local += ncp.dot(Atemp, Xtemp) Y_local_unpadded = Y_local[:local_n, :local_m] @@ -631,13 +642,24 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm local_shapes = self.base_comm.allgather(local_k * local_m) + # - If A is real: A^H = A^T, + # so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real) + # - If A is complex: A^H is complex, + # so result will be complex regardless of x + if np.iscomplexobj(self.A): + output_dtype = np.result_type(self.dtype, x.dtype) + else: + # Real matrix: A^T @ x preserves input type complexity + output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype + # But still need to check type promotion for precision + output_dtype = np.result_type(self.dtype, output_dtype) y = DistributedArray( global_shape=(self.K * self.M), mask=x.mask, local_shapes=local_shapes, partition=Partition.SCATTER, - dtype=self.dtype, + dtype=output_dtype, base_comm=self.base_comm ) @@ -659,7 +681,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: x_block = np.pad(x_block, [(0, pad_n), (0, pad_m)], mode='constant') A_local = self.At if hasattr(self, "At") else self.A.T.conj() - Y_local = np.zeros((self.A.shape[1], bm)) + Y_local = np.zeros((self.A.shape[1], bm), dtype=output_dtype) for k in range(self._P_prime): requests = [] diff --git a/tests/test_matrixmult.py b/tests/test_matrixmult.py index 7def7807..de5f0385 100644 --- a/tests/test_matrixmult.py +++ b/tests/test_matrixmult.py @@ -8,9 +8,10 @@ from mpi4py import MPI import pytest -from pylops.basicoperators import FirstDerivative, Identity +from pylops.basicoperators import Conj, FirstDerivative, Identity from pylops_mpi import DistributedArray, Partition -from pylops_mpi.basicoperators import MPIMatrixMult, MPIBlockDiag +from pylops_mpi.basicoperators import MPIBlockDiag, MPIMatrixMult, local_block_spit, block_gather, \ + active_grid_comm np.random.seed(42) base_comm = MPI.COMM_WORLD @@ -55,8 +56,7 @@ def test_MPIMatrixMult(N, K, M, dtype_str): cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0 base_float_dtype = np.float32 if dtype == np.complex64 else np.float64 - comm, rank, row_id, col_id, is_active = \ - MPIMatrixMult.active_grid_comm(base_comm, N, M) + comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M) if not is_active: return size = comm.Get_size() @@ -86,7 +86,7 @@ def test_MPIMatrixMult(N, K, M, dtype_str): X_p = X_glob[:, col_start_X:col_end_X] # Create MPIMatrixMult operator - Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype=dtype_str) + Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype=dtype_str, kind="block") # Create DistributedArray for input x (representing B flattened) all_local_col_len = comm.allgather(local_col_X_len) @@ -160,3 +160,103 @@ def test_MPIMatrixMult(N, K, M, dtype_str): rtol=np.finfo(np.dtype(dtype)).resolution, err_msg=f"Rank {rank}: Adjoint verification failed." 
) + +@pytest.mark.mpi(min_size=1) +@pytest.mark.parametrize("N, K, M, dtype_str", test_params) +def test_MPISummaMatrixMult(N, K, M, dtype_str): + dtype = np.dtype(dtype_str) + + cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0 + base_float_dtype = np.float32 if dtype == np.complex64 else np.float64 + + comm, rank, row_id, col_id, is_active = \ + active_grid_comm(base_comm, N, M) + if not is_active: return + + size = comm.Get_size() + + # Fill local matrices + A_glob_real = np.arange(N * K, dtype=base_float_dtype).reshape(N, K) + A_glob_imag = np.arange(N * K, dtype=base_float_dtype).reshape(N, K) * 0.5 + A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype) + + X_glob_real = np.arange(K * M, dtype=base_float_dtype).reshape(K, M) + X_glob_imag = np.arange(K * M, dtype=base_float_dtype).reshape(K, M) * 0.7 + X_glob = (X_glob_real + cmplx * X_glob_imag).astype(dtype) + + A_slice = local_block_spit((N, K), rank, comm) + X_slice = local_block_spit((K, M), rank, comm) + + A_p = A_glob[A_slice] + X_p = X_glob[X_slice] + + # Create MPIMatrixMult operator + Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype=dtype_str, kind="summa") + + x_dist = DistributedArray( + global_shape=(K * M), + local_shapes=comm.allgather(X_p.shape[0] * X_p.shape[1]), + partition=Partition.SCATTER, + base_comm=comm, + dtype=dtype + ) + + x_dist.local_array[:] = X_p.ravel() + + # Forward operation: y = A @ x (distributed) + y_dist = Aop @ x_dist + + # Adjoint operation: xadj = A.H @ y (distributed) + xadj_dist = Aop.H @ y_dist + + # Re-organize in local matrix + y = block_gather(y_dist, (N,M), (N,M), comm) + xadj = block_gather(xadj_dist, (K,M), (K,M), comm) + + if rank == 0: + y_loc = A_glob @ X_glob + assert_allclose( + y.squeeze(), + y_loc.squeeze(), + rtol=np.finfo(np.dtype(dtype)).resolution, + err_msg=f"Rank {rank}: Forward verification failed." + ) + + xadj_loc = A_glob.conj().T @ y_loc + assert_allclose( + xadj.squeeze(), + xadj_loc.squeeze(), + rtol=np.finfo(np.dtype(dtype)).resolution, + err_msg=f"Rank {rank}: Adjoint verification failed." + ) + + # Chain with another operator + Dop = Conj(dims=(A_p.shape[0], X_p.shape[1])) + DBop = MPIBlockDiag(ops=[Dop,], base_comm=comm) + Op = DBop @ Aop + + y1_dist = Op @ x_dist + xadj1_dist = Op.H @ y1_dist + + # Re-organize in local matrix + y1 = block_gather(y1_dist, (N, M), (N, M), comm) + xadj1 = block_gather(xadj1_dist, (K,M), (K,M), comm) + + if rank == 0: + y1_loc = ((A_glob @ X_glob).conj().ravel()).reshape(N, M) + 1.0j + y1_loc = y1_loc - 1.0j + + assert_allclose( + y1.squeeze(), + y1_loc.squeeze(), + rtol=np.finfo(y1_loc.dtype).resolution, + err_msg=f"Rank {rank}: Forward verification failed." + ) + + xadj1_loc = ((A_glob.conj().T @ y1_loc.conj()).ravel()).reshape(K, M) + assert_allclose( + xadj1.squeeze().ravel(), + xadj1_loc.squeeze().ravel(), + rtol=np.finfo(xadj1_loc.dtype).resolution, + err_msg=f"Rank {rank}: Adjoint verification failed." 
+ ) From 8851e0599e993fe6133aacbecdd0b4a11f0175f2 Mon Sep 17 00:00:00 2001 From: astroC86 <66444189+astroC86@users.noreply.github.com> Date: Sun, 27 Jul 2025 03:07:31 +0200 Subject: [PATCH 17/25] Added documentation and example explination --- docs/source/api/index.rst | 12 ++- examples/plot_summamatrixmult.py | 128 +++++++++++++++++++----- pylops_mpi/LinearOperator.py | 1 - pylops_mpi/basicoperators/MatrixMult.py | 7 +- 4 files changed, 116 insertions(+), 32 deletions(-) diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 66e1a373..fca1c1cd 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -42,7 +42,7 @@ Basic Operators .. autosummary:: :toctree: generated/ - MPIMatrixMult + MatrixMult.MPIMatrixMult MPIBlockDiag MPIStackedBlockDiag MPIVStack @@ -118,6 +118,16 @@ Utils local_split +.. currentmodule:: pylops_mpi.basicoperators.MatrixMult + +.. autosummary:: + :toctree: generated/ + + block_gather + local_block_split + active_grid_comm + + .. currentmodule:: pylops_mpi.utils.dottest .. autosummary:: diff --git a/examples/plot_summamatrixmult.py b/examples/plot_summamatrixmult.py index 50499287..dd3f0225 100644 --- a/examples/plot_summamatrixmult.py +++ b/examples/plot_summamatrixmult.py @@ -1,7 +1,7 @@ r""" Distributed SUMMA Matrix Multiplication ======================================= -This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPISummaMatrixMult` +This example shows how to use the :py:class:`pylops_mpi.basicoperators._MPISummaMatrixMult` operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}` distributed in 2D blocks across a square process grid and matrices :math:`\mathbf{X}` and :math:`\mathbf{Y}` distributed in 2D blocks across the same grid. Similarly, @@ -20,53 +20,127 @@ import math import numpy as np from mpi4py import MPI +from matplotlib import pyplot as plt import pylops_mpi -from pylops_mpi.basicoperators.MatrixMult import (local_block_spit, block_gather, MPIMatrixMult) +from pylops import Conj +from pylops_mpi.basicoperators.MatrixMult import (local_block_spit, MPIMatrixMult, active_grid_comm) -comm = MPI.COMM_WORLD -rank = comm.Get_rank() -size = comm.Get_size() +plt.close("all") -N = 9 -M = 9 -K = 9 +############################################################################### +# We set the seed such that all processes can create the input matrices filled +# with the same random number. In practical application, such matrices will be +# filled with data that is appropriate that is appropriate the use-case. +np.random.seed(42) -A_shape = (N, K) -x_shape = (K, M) -y_shape = (N, M) -p_prime = math.isqrt(size) +N, M, K = 6, 6, 6 +A_shape, x_shape, y_shape= (N, K), (K, M), (N, M) + + +base_comm = MPI.COMM_WORLD +comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M) +print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}") + + +############################################################################### +# We are now ready to create the input matrices for our distributed matrix +# multiplication example. 
We need to set up: +# - Matrix :math:`\mathbf{A}` of size :math:`N \times K` (the left operand) +# - Matrix :math:`\mathbf{X}` of size :math:`K \times M` (the right operand) +# - The result will be :math:`\mathbf{Y} = \mathbf{A} \mathbf{X}` of size :math:`N \times M` +# +# For distributed computation, we arrange processes in a square grid of size +# :math:`P' \times P'` where :math:`P' = \sqrt{P}` and :math:`P` is the total +# number of MPI processes. Each process will own a block of each matrix +# according to this 2D grid layout. + +p_prime = math.isqrt(comm.Get_size()) +print(f"Process grid: {p_prime} x {p_prime} = {comm.Get_size()} processes") + +# Create global test matrices with sequential values for easy verification +# Matrix A: Each element :math:`A_{i,j} = i \cdot K + j` (row-major ordering) +# Matrix X: Each element :math:`X_{i,j} = i \cdot M + j` A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape) x_data = np.arange(int(x_shape[0] * x_shape[1])).reshape(x_shape) +print(f"Global matrix A shape: {A_shape} (N={A_shape[0]}, K={A_shape[1]})") +print(f"Global matrix X shape: {x_shape} (K={x_shape[0]}, M={x_shape[1]})") +print(f"Expected Global result Y shape: ({A_shape[0]}, {x_shape[1]}) = (N, M)") + +################################################################################ +# Determine which block of each matrix this process should own +# The 2D block distribution ensures: +# - Process at grid position :math:`(i,j)` gets block :math:`\mathbf{A}[i_{start}:i_{end}, j_{start}:j_{end}]` +# - Block sizes are approximately :math:`\lceil N/P' \rceil \times \lceil K/P' \rceil` with edge processes handling remainder +# +# .. raw:: html +# +#
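To make the ownership explicit, here is a serial sketch for an assumed 4-rank (2 x 2 grid) run of this example; ``tile`` is a hypothetical stand-in mirroring the block-splitting rule used above, and with :math:`N = M = K = 6` every rank owns an exact 3 x 3 tile, so the flattened per-rank sizes handed to ``DistributedArray`` as ``local_shapes`` are all 9:

    import math

    P = 4                                    # assumed number of MPI ranks
    p_prime = math.isqrt(P)                  # 2 x 2 process grid
    N, K, M = 6, 6, 6

    def tile(shape, rank):
        # hypothetical stand-in for the block-splitting helper used above
        i, j = divmod(rank, p_prime)
        br, bc = math.ceil(shape[0] / p_prime), math.ceil(shape[1] / p_prime)
        return (slice(i * br, min((i + 1) * br, shape[0])),
                slice(j * bc, min((j + 1) * bc, shape[1])))

    x_local_sizes = []
    for r in range(P):
        (ra, ca), (rx, cx) = tile((N, K), r), tile((K, M), r)
        print(f"rank {r}: A[{ra.start}:{ra.stop}, {ca.start}:{ca.stop}], "
              f"X[{rx.start}:{rx.stop}, {cx.start}:{cx.stop}]")
        x_local_sizes.append((rx.stop - rx.start) * (cx.stop - cx.start))
    print(x_local_sizes)                     # -> [9, 9, 9, 9]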