Skip to content

Commit bca1956

Browse files
authored
Allowed row and column counts (M, N) that are not divisible by P_prime
1 parent ae1846e commit bca1956

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

examples/matixmul.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -17,9 +17,9 @@
1717
assert P_prime * C >= nProcs
1818

1919
# matrix dims
20-
M = 4 # any M
20+
M = 5 # any M
2121
K = 4 # any K
22-
N = 4 # any N
22+
N = 5 # any N
2323

2424
blk_rows = int(math.ceil(M / P_prime))
2525
blk_cols = int(math.ceil(N / P_prime))
@@ -65,9 +65,11 @@
6565

6666
comm.Barrier()
6767

68-
MMop_MPI = SUMMAMatrixMult(A_p, N)
69-
70-
x = DistributedArray(global_shape=K * blk_cols * nProcs,
68+
MMop_MPI = SUMMAMatrixMult(A_p, N)
69+
col_lens = comm.allgather(my_own_cols)
70+
total_cols = np.add.reduce(col_lens, 0)
71+
x = DistributedArray(global_shape=K * total_cols,
72+
local_shapes=[K * col_len for col_len in col_lens],
7173
partition=Partition.SCATTER,
7274
mask=[i % P_prime for i in range(comm.Get_size())],
7375
dtype=np.float32)

0 commit comments

Comments
 (0)