     Partition
 )
 
+
 class SUMMAMatrixMult(MPILinearOperator):
     def __init__(
         self,
-        A: NDArray, #I am going to have to assume that the partitioning has been done correctly
+        A: NDArray,
         N: int,
         base_comm: MPI.Comm = MPI.COMM_WORLD,
         dtype: DTypeLike = "float64",
     ) -> None:
-        rank = base_comm.Get_rank()
-        nProcs = base_comm.Get_size()
-        self._P_prime = int(math.ceil(math.sqrt(nProcs)))
-        self._C = int(math.ceil(nProcs / self._P_prime))
-        assert self._P_prime * self._C >= nProcs
+        rank = base_comm.Get_rank()
+        size = base_comm.Get_size()
+
+        # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size
+        self._P_prime = int(math.ceil(math.sqrt(size)))
+        self._C = int(math.ceil(size / self._P_prime))
+        assert self._P_prime * self._C >= size
+
+        # Compute this process's group and layer indices
+        self._group_id = rank % self._P_prime
+        self._layer_id = rank // self._P_prime
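+        # (Illustration only, not enforced by the code: with size=6 this gives
+        #  P_prime=3 and C=2, so rank 4 lands in group 1, layer 1 of a 3×2 grid)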
+
+        # Split communicators by layer (rows) and by group (columns)
+        self.base_comm = base_comm
+        self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
+        self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)
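+        # layer_comm links the up-to-P_prime ranks that share a layer;
+        # group_comm links the up-to-C ranks that share a group across layers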
 
-        self.N = N
         self.A = A
-        self._my_group = rank % self._P_prime
-        self._my_layer = rank // self._P_prime
-        self._layer_comm = base_comm.Split(color=self._my_layer, key=self._my_group)
-        self._group_comm = base_comm.Split(color=self._my_group, key=self._my_layer)
-        K_global = A.shape[1]
-
-        blk_cols = int(math.ceil(self.N / self._P_prime))
-        col_start = self._my_group * blk_cols
-        col_end = min(self.N, col_start + blk_cols)
-        my_own_cols = col_end - col_start
-        total_cols = base_comm.allreduce(my_own_cols, op=MPI.SUM)
-        self.dims = (K_global, total_cols)
 
-        super().__init__(shape=(1, int(np.prod(self.dims))), dtype=np.dtype(dtype), base_comm=base_comm)
+        self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
+        self.K = A.shape[1]
+        self.N = N
 
+        # Determine how many columns each group holds
+        block_cols = int(math.ceil(self.N / self._P_prime))
+        local_col_start = self._group_id * block_cols
+        local_col_end = min(self.N, local_col_start + block_cols)
+        local_ncols = local_col_end - local_col_start
+
+        # Sum up the total number of input columns across all processes
+        total_ncols = base_comm.allreduce(local_ncols, op=MPI.SUM)
+        self.dims = (self.K, total_ncols)
+
+        # Recompute how many output columns each layer holds
+        layer_col_start = self._layer_id * block_cols
+        layer_col_end = min(self.N, layer_col_start + block_cols)
+        layer_ncols = layer_col_end - layer_col_start
+        total_layer_cols = self.base_comm.allreduce(layer_ncols, op=MPI.SUM)
+
+        self.dimsd = (self.M, total_layer_cols)
+        shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
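+        # (Illustration: M=K=N=8 on 4 ranks gives a 2×2 grid; every rank holds
+        #  an (8, 4) column block of X, each group's columns are counted once
+        #  per layer, so dims=dimsd=(8, 16) and the flat shape is (128, 128))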
+
+        super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
+
     def _matvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         if x.partition != Partition.SCATTER:
-            raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
+            raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
-        blk_cols = int(math.ceil(self.N / self._P_prime))
-        col_start = self._my_group * blk_cols
-        col_end = min(self.N, col_start + blk_cols)
-        my_own_cols = col_end - col_start
+        blk_cols = int(math.ceil(self.N / self._P_prime))
+        col_start = self._group_id * blk_cols
+        col_end = min(self.N, col_start + blk_cols)
+        my_own_cols = col_end - col_start
         x = x.local_array.reshape((self.dims[0], my_own_cols))
         C_local = None
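+        # SUMMA-style loop: at step t, group t's rank broadcasts its block of
+        # x within the responsible layer, which multiplies it by its local A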
         for t in range(self._P_prime):
             responsible_layer = t % self._C
-            if self._my_layer == responsible_layer:
-                B_block = self._layer_comm.bcast(x if self._my_group == t else None, root=t)
-                if t == self._my_layer: C_local = ncp.matmul(self.A, B_block)
-        self.base_comm.Barrier()
-        my_C_rows = ncp.hstack(self._group_comm.allgather(C_local))
-
-        mask = [i % self._P_prime for i in range(self.size)]
-        row_lens = self.base_comm.allgather(self.A.shape[0])
-        tot_row_lens = np.add.reduce(row_lens, 0)
-        y = DistributedArray(global_shape=(tot_row_lens, self.N),
-                             local_shapes=[(r, self.N) for r in row_lens],
-                             mask = mask,
-                             partition=Partition.SCATTER)
-        y[:] = my_C_rows
+            if self._layer_id == responsible_layer:
+                B_block = self._layer_comm.bcast(x if self._group_id == t else None, root=t)
+                if t == self._layer_id:
+                    C_local = ncp.vstack(
+                        self._layer_comm.allgather(
+                            ncp.matmul(self.A, B_block, dtype=self.dtype)
+                        )
+                    )
+
+        layer_col_start = self._layer_id * blk_cols
+        layer_col_end = min(self.N, layer_col_start + blk_cols)
+        layer_ncols = layer_col_end - layer_col_start
+        layer_col_lens = self.base_comm.allgather(layer_ncols)
+        mask = [i // self._P_prime for i in range(self.size)]
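+        # All ranks of a layer hold the same copy of the result; giving them a
+        # shared mask value lets global reductions count each copy only once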
+
+        y = DistributedArray(global_shape=(self.M * self.dimsd[1]),
+                             local_shapes=[(self.M * c) for c in layer_col_lens],
+                             mask=mask,
+                             # axis=1,
+                             partition=Partition.SCATTER,
+                             dtype=self.dtype)
+        y[:] = C_local.flatten()
+        return y
+
+    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
+        ncp = get_module(x.engine)
+        if x.partition != Partition.SCATTER:
+            raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
+
+        # Determine local column block for this layer
+        blk_cols = int(math.ceil(self.N / self._P_prime))
+        layer_col_start = self._layer_id * blk_cols
+        layer_col_end = min(self.N, layer_col_start + blk_cols)
+        layer_ncols = layer_col_end - layer_col_start
+        layer_col_lens = self.base_comm.allgather(layer_ncols)
+        x = x.local_array.reshape((self.M, layer_ncols))
+
+        # Determine local row block for this process group
+        blk_rows = int(math.ceil(self.M / self._P_prime))
+        row_start = self._group_id * blk_rows
+        row_end = min(self.M, row_start + blk_rows)
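+        # Each group takes the slice of x's rows matching its share of A's
+        # rows; the layer-wide reduction below then completes A^H @ X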
+
+        B_tile = x[row_start:row_end, :]
+        A_local = self.A.T.conj()
+
+        # Pad A_local so its first dimension is divisible by _P_prime, then
+        # batch it (ncp.zeros, not np.zeros, to stay on the input's engine)
+        m, b = A_local.shape
+        r = math.ceil(m / self._P_prime)
+        A_pad = ncp.zeros((r * self._P_prime, b), dtype=self.dtype)
+        A_pad[:m, :] = A_local
+        A_batch = A_pad.reshape(self._P_prime, r, b)
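+        # Each of the P_prime row chunks multiplies the same B_tile, so the
+        # batched product below is a blocked equivalent of A_local @ B_tile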
+
+        # Perform local matmul and unpad
+        Y_batch = ncp.matmul(A_batch, B_tile)
+        Y_pad = Y_batch.reshape(r * self._P_prime, -1)
+        y_local = Y_pad[:m, :]
+        y_layer = self._layer_comm.allreduce(y_local, op=MPI.SUM)
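+        # Every group contributed a partial A_g^H @ X_g; summing across the
+        # layer yields the full adjoint product for this layer's columns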
+
+        mask = [i // self._P_prime for i in range(self.size)]
+        y = DistributedArray(
+            global_shape=(self.K * self.dimsd[1]),
+            local_shapes=[self.K * c for c in layer_col_lens],
+            mask=mask,
+            # axis=1,
+            partition=Partition.SCATTER,
+            dtype=self.dtype,
+        )
+        y[:] = y_layer.flatten()
         return y
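For reviewers, a minimal driver sketch for the new operator. This is illustrative only: the sizes, the 4-rank grid, the list-of-ints local_shapes, and the mask all mirror the constructor's own partitioning logic above, and `@` is assumed to dispatch to matvec as in pylops. Run with mpiexec -n 4.

import math
import numpy as np
from mpi4py import MPI
from pylops_mpi import DistributedArray, Partition

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
P_prime = int(math.ceil(math.sqrt(size)))          # 2 for size=4
group_id = rank % P_prime

M = K = N = 8                                      # hypothetical global sizes
A_local = np.ones((M // P_prime, K))               # this rank's row block of A
Aop = SUMMAMatrixMult(A_local, N, base_comm=comm)

# Build a scattered x holding this group's (K, ncols) column block, flattened
blk_cols = int(math.ceil(N / P_prime))
ncols = min(N, (group_id + 1) * blk_cols) - group_id * blk_cols
x = DistributedArray(global_shape=Aop.shape[1],
                     local_shapes=comm.allgather(K * ncols),
                     mask=[i % P_prime for i in range(size)],
                     partition=Partition.SCATTER)
x[:] = np.ones(K * ncols)

y = Aop @ x                                        # distributed A @ X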