
Commit f0fd7cb

Implemented Adjoint and updated example
1 parent bca1956 commit f0fd7cb


2 files changed (+144, -49 lines)


examples/matixmul.py

Lines changed: 30 additions & 11 deletions
@@ -1,3 +1,4 @@
+import sys
 import math
 import numpy as np
 from mpi4py import MPI
@@ -17,9 +18,9 @@
 assert P_prime * C >= nProcs

 # matrix dims
-M = 5   # any M
-K = 4   # any K
-N = 5   # any N
+M = 37  # any M
+K = 37  # any K
+N = 37  # any N

 blk_rows = int(math.ceil(M / P_prime))
 blk_cols = int(math.ceil(N / P_prime))
@@ -65,7 +66,7 @@

 comm.Barrier()

-MMop_MPI = SUMMAMatrixMult(A_p, N)
+Aop = SUMMAMatrixMult(A_p, N)
 col_lens = comm.allgather(my_own_cols)
 total_cols = np.add.reduce(col_lens, 0)
 x = DistributedArray(global_shape=K * total_cols,
@@ -74,13 +75,31 @@
                      mask=[i % P_prime for i in range(comm.Get_size())],
                      dtype=np.float32)
 x[:] = B_p.flatten()
-y = MMop_MPI @ x
+y = Aop @ x

 # ======================= VERIFICATION =================-=============
-C_true = (np.arange(M*K).reshape(M, K).astype(np.float32)
-          @ np.arange(K*N).reshape(K, N).astype(np.float32))
-expect = C_true[row_start:row_end, :]
-if not np.allclose(y.local_array, expect, atol=1e-6):
-    print(f"RANK {rank}: VERIFICATION FAILED")
+A = np.arange(M*K).reshape(M, K).astype(np.float32)
+B = np.arange(K*N).reshape(K, N).astype(np.float32)
+C_true = A @ B
+Z_true = (A.T.dot(C_true.conj())).conj()
+
+
+col_start = my_layer * blk_cols  # note: same my_group index on cols
+col_end = min(N, col_start + blk_cols)
+my_own_cols = col_end - col_start
+expected_y = C_true[:,col_start:col_end].flatten()
+
+if not np.allclose(y.local_array, expected_y, atol=1e-6):
+    print(f"RANK {rank}: FORWARD VERIFICATION FAILED")
+    print(f'{rank} local: {y.local_array}, expected: {C_true[:,col_start:col_end]}')
+else:
+    print(f"RANK {rank}: FORWARD VERIFICATION PASSED")
+
+
+z = Aop.H @ y
+expected_z = Z_true[:,col_start:col_end].flatten()
+if not np.allclose(z.local_array, expected_z, atol=1e-6):
+    print(f"RANK {rank}: ADJOINT VERIFICATION FAILED")
+    print(f'{rank} local: {z.local_array}, expected: {Z_true[:,col_start:col_end]}')
 else:
-    print(f"RANK {rank}: VERIFICATION PASSED")
+    print(f"RANK {rank}: ADJOINT VERIFICATION PASSED")
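The updated example checks the forward product against C_true = A @ B and the adjoint against Z_true = A^H C_true on every rank. The same adjoint relationship can also be sanity-checked on a single process with the usual dot test, <A X, Y> = <X, A^H Y>; the snippet below is only an illustrative NumPy sketch with arbitrary sizes, not part of the committed example.

# Illustrative single-process dot test for the pair (X -> A @ X, Y -> A.conj().T @ Y).
# Sizes are arbitrary; nothing here touches MPI or pylops_mpi.
import numpy as np

rng = np.random.default_rng(0)
M, K, N = 37, 37, 37
A = rng.standard_normal((M, K))
X = rng.standard_normal((K, N))   # input of the forward operator
Y = rng.standard_normal((M, N))   # input of the adjoint operator

lhs = np.vdot(A @ X, Y)           # <A X, Y>
rhs = np.vdot(X, A.conj().T @ Y)  # <X, A^H Y>
assert np.isclose(lhs, rhs)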

pylops_mpi/basicoperators/MatrixMultiply.py

Lines changed: 114 additions & 38 deletions
@@ -10,61 +10,137 @@
     Partition
 )

+
 class SUMMAMatrixMult(MPILinearOperator):
     def __init__(
         self,
-        A: NDArray, #I am going to have to assume that the partitioning has been done correctly
+        A: NDArray,
         N: int,
         base_comm: MPI.Comm = MPI.COMM_WORLD,
         dtype: DTypeLike = "float64",
     ) -> None:
-        rank = base_comm.Get_rank()
-        nProcs = base_comm.Get_size()
-        self._P_prime = int(math.ceil(math.sqrt(nProcs)))
-        self._C = int(math.ceil(nProcs / self._P_prime))
-        assert self._P_prime * self._C >= nProcs
+        rank = base_comm.Get_rank()
+        size = base_comm.Get_size()
+
+        # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size
+        self._P_prime = int(math.ceil(math.sqrt(size)))
+        self._C = int(math.ceil(size / self._P_prime))
+        assert self._P_prime * self._C >= size
+
+        # Compute this process's group and layer indices
+        self._group_id = rank % self._P_prime
+        self._layer_id = rank // self._P_prime
+
+        # Split communicators by layer (rows) and by group (columns)
+        self.base_comm = base_comm
+        self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
+        self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)

-        self.N = N
         self.A = A
-        self._my_group = rank % self._P_prime
-        self._my_layer = rank // self._P_prime
-        self._layer_comm = base_comm.Split(color=self._my_layer, key=self._my_group)
-        self._group_comm = base_comm.Split(color=self._my_group, key=self._my_layer)
-        K_global = A.shape[1]
-
-        blk_cols = int(math.ceil(self.N / self._P_prime))
-        col_start = self._my_group * blk_cols
-        col_end = min(self.N, col_start + blk_cols)
-        my_own_cols = col_end - col_start
-        total_cols = base_comm.allreduce(my_own_cols, op=MPI.SUM)
-        self.dims = (K_global, total_cols)

-        super().__init__(shape=(1, int(np.prod(self.dims))), dtype=np.dtype(dtype), base_comm=base_comm)
+        self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
+        self.K = A.shape[1]
+        self.N = N

+        # Determine how many columns each group holds
+        block_cols = int(math.ceil(self.N / self._P_prime))
+        local_col_start = self._group_id * block_cols
+        local_col_end = min(self.N, local_col_start + block_cols)
+        local_ncols = local_col_end - local_col_start
+
+        # Sum up the total number of input columns across all processes
+        total_ncols = base_comm.allreduce(local_ncols, op=MPI.SUM)
+        self.dims = (self.K, total_ncols)
+
+        # Recompute how many output columns each layer holds
+        layer_col_start = self._layer_id * block_cols
+        layer_col_end = min(self.N, layer_col_start + block_cols)
+        layer_ncols = layer_col_end - layer_col_start
+        total_layer_cols = self.base_comm.allreduce(layer_ncols, op=MPI.SUM)
+
+        self.dimsd = (self.M, total_layer_cols)
+        shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
+
+        super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
+
     def _matvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
-        blk_cols = int(math.ceil(self.N / self._P_prime))
-        col_start = self._my_group * blk_cols
-        col_end = min(self.N, col_start + blk_cols)
-        my_own_cols = col_end - col_start
+        blk_cols = int(math.ceil(self.N / self._P_prime))
+        col_start = self._group_id * blk_cols
+        col_end = min(self.N, col_start + blk_cols)
+        my_own_cols = col_end - col_start
         x = x.local_array.reshape((self.dims[0], my_own_cols))
         C_local = None
         for t in range(self._P_prime):
             responsible_layer = t % self._C
-            if self._my_layer == responsible_layer:
-                B_block = self._layer_comm.bcast(x if self._my_group == t else None, root=t)
-                if t == self._my_layer: C_local = ncp.matmul(self.A, B_block)
-        self.base_comm.Barrier()
-        my_C_rows = ncp.hstack(self._group_comm.allgather(C_local))
-
-        mask = [i % self._P_prime for i in range(self.size)]
-        row_lens = self.base_comm.allgather(self.A.shape[0])
-        tot_row_lens = np.add.reduce(row_lens, 0)
-        y = DistributedArray(global_shape=(tot_row_lens, self.N),
-                             local_shapes=[(r, self.N) for r in row_lens],
-                             mask=mask,
-                             partition=Partition.SCATTER)
-        y[:] = my_C_rows
+            if self._layer_id == responsible_layer:
+                B_block = self._layer_comm.bcast(x if self._group_id == t else None, root=t)
+                if t == self._layer_id:
+                    C_local = ncp.vstack(
+                        self._layer_comm.allgather(
+                            ncp.matmul(self.A, B_block, dtype=self.dtype)
+                        )
+                    )
+
+        layer_col_start = self._layer_id * blk_cols
+        layer_col_end = min(self.N, layer_col_start + blk_cols)
+        layer_ncols = layer_col_end - layer_col_start
+        layer_col_lens = self.base_comm.allgather(layer_ncols)
+        mask = [i // self._P_prime for i in range(self.size)]
+
+        y = DistributedArray(global_shape=(self.M * self.dimsd[1]),
+                             local_shapes=[(self.M * c) for c in layer_col_lens],
+                             mask=mask,
+                             # axis=1,
+                             partition=Partition.SCATTER,
+                             dtype=self.dtype)
+        y[:] = C_local.flatten()
+        return y
+
+    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
+        ncp = get_module(x.engine)
+        if x.partition != Partition.SCATTER:
+            raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
+
+        # Determine local column block for this layer
+        blk_cols = int(math.ceil(self.N / self._P_prime))
+        layer_col_start = self._layer_id * blk_cols
+        layer_col_end = min(self.N, layer_col_start + blk_cols)
+        layer_ncols = layer_col_end - layer_col_start
+        layer_col_lens = self.base_comm.allgather(layer_ncols)
+        x = x.local_array.reshape((self.M, layer_ncols))
+
+        # Determine local row block for this process group
+        blk_rows = int(math.ceil(self.M / self._P_prime))
+        row_start = self._group_id * blk_rows
+        row_end = min(self.M, row_start + blk_rows)
+
+        B_tile = x[row_start:row_end, :]
+        A_local = self.A.T.conj()
+
+        # Pad A_local so its first dimension is divisible by _P_prime, then batch it
+        m, b = A_local.shape
+        r = math.ceil(m / self._P_prime)
+        A_pad = np.zeros((r * self._P_prime, b), dtype=self.dtype)
+        A_pad[:m, :] = A_local
+        A_batch = A_pad.reshape(self._P_prime, r, b)
+
+        # Perform local matmul and unpad
+        Y_batch = ncp.matmul(A_batch, B_tile)
+        Y_pad = Y_batch.reshape(r * self._P_prime, -1)
+        y_local = Y_pad[:m, :]
+        y_layer = self._layer_comm.allreduce(y_local, op=MPI.SUM)
+
+        mask = [i // self._P_prime for i in range(self.size)]
+        y = DistributedArray(
+            global_shape=(self.K * self.dimsd[1]),
+            local_shapes=[self.K * c for c in layer_col_lens],
+            mask=mask,
+            # axis=1,
+            partition=Partition.SCATTER,
+            dtype=self.dtype,
+        )
+        y[:] = y_layer.flatten()
         return y
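The constructor arranges the ranks on a P_prime x C grid with P_prime = ceil(sqrt(size)), C = ceil(size / P_prime), group_id = rank % P_prime and layer_id = rank // P_prime, and assigns each group a block of at most ceil(N / P_prime) columns. That layout can be previewed without MPI; the helper below is a hypothetical standalone sketch that mirrors the same arithmetic and is not part of the pylops_mpi API.

# Hypothetical helper reproducing the constructor's rank -> (group, layer) mapping
# and the per-group column blocks, for inspection on a single process.
import math

def grid_layout(size: int, N: int):
    P_prime = int(math.ceil(math.sqrt(size)))
    C = int(math.ceil(size / P_prime))
    assert P_prime * C >= size
    blk_cols = int(math.ceil(N / P_prime))
    layout = []
    for rank in range(size):
        group_id = rank % P_prime   # index within a layer
        layer_id = rank // P_prime  # which layer the rank belongs to
        col_start = group_id * blk_cols
        col_end = min(N, col_start + blk_cols)
        layout.append((rank, group_id, layer_id, col_start, col_end))
    return layout

for rank, g, l, c0, c1 in grid_layout(size=6, N=37):
    print(f"rank {rank}: group {g}, layer {l}, cols [{c0}, {c1})")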

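In _rmatvec, A.T.conj() is zero-padded so its row count becomes a multiple of P_prime, reshaped into a batch of P_prime row blocks, multiplied against a single tile of x, and then unpadded. The short NumPy check below (toy sizes, no MPI, purely illustrative) confirms that this pad/reshape/unpad sequence gives the same result as the plain matrix product.

# Illustrative check of the pad-and-batch trick used in _rmatvec, on toy sizes.
import math
import numpy as np

P_prime = 3
m, b, n = 7, 5, 4                        # m deliberately not divisible by P_prime
A_local = np.random.default_rng(1).standard_normal((m, b))
B_tile = np.random.default_rng(2).standard_normal((b, n))

r = math.ceil(m / P_prime)
A_pad = np.zeros((r * P_prime, b))
A_pad[:m, :] = A_local
A_batch = A_pad.reshape(P_prime, r, b)   # batch of P_prime row blocks

Y_batch = np.matmul(A_batch, B_tile)     # B_tile is broadcast over the batch dimension
Y = Y_batch.reshape(r * P_prime, -1)[:m, :]

assert np.allclose(Y, A_local @ B_tile)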