We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ae1846e commit bca1956 (Copy full SHA for bca1956)
examples/matixmul.py
@@ -17,9 +17,9 @@
17
assert P_prime * C >= nProcs
18
19
# matrix dims
20
-M = 4 # any M
+M = 5 # any M
21
K = 4 # any K
22
-N = 4 # any N
+N = 5 # any N
23
24
blk_rows = int(math.ceil(M / P_prime))
25
blk_cols = int(math.ceil(N / P_prime))
@@ -65,9 +65,11 @@
65
66
comm.Barrier()
67
68
-MMop_MPI = SUMMAMatrixMult(A_p, N)
69
-
70
-x = DistributedArray(global_shape=K * blk_cols * nProcs,
+MMop_MPI = SUMMAMatrixMult(A_p, N)
+col_lens = comm.allgather(my_own_cols)
+total_cols = np.add.reduce(col_lens, 0)
71
+x = DistributedArray(global_shape=K * total_cols,
72
+ local_shapes=[K * col_len for col_len in col_lens],
73
partition=Partition.SCATTER,
74
mask=[i % P_prime for i in range(comm.Get_size())],
75
dtype=np.float32)
0 commit comments