Commit b7b6265

chang-l and mnabian authored
[Feature] Add row-decomposition of adj. matrix to reduce graph partitioning overhead (#720)
* Add matrix partition to replace current graph partition
* Refactor and renaming
* Add back orig impl and test
* Minor update to remove notimpl blocks
* Address comments
* Update change log
* re-format

Co-authored-by: Mohammad Amin Nabian <[email protected]>
1 parent 7571751 commit b7b6265

File tree: 3 files changed, +276 -2 lines

CHANGELOG.md (+1)

@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added

 - DoMINO model architecture, datapipe and training recipe
+- Added matrix decomposition scheme to improve graph partitioning
 - DrivAerML dataset support in FIGConvNet example.

 ### Changed

modulus/models/gnn_layers/distributed_graph.py (+226 -2)

@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
+import logging
 from dataclasses import dataclass
 from typing import List, Optional

@@ -29,6 +29,8 @@
     scatter_v,
 )

+logger = logging.getLogger(__name__)
+

 @dataclass
 class GraphPartition:
@@ -67,6 +69,8 @@ class GraphPartition:
     partition_size: int
     partition_rank: int
     device: torch.device
+    # flag to indicate using 1-D row-decomposition of the adj. matrix
+    matrix_decomp: bool = False

     # data structures defining partition
     # set in after initialization or during execution
@@ -394,12 +398,177 @@ def partition_graph_with_id_mapping(
     return graph_partition


+def partition_graph_with_matrix_decomposition(
+    global_offsets: torch.Tensor,
+    global_indices: torch.Tensor,
+    num_nodes: int,
+    partition_book: torch.Tensor,
+    partition_size: int,
+    partition_rank: int,
+    device: torch.device,
+) -> GraphPartition:
+    """
+    Utility function which partitions a global graph given as CSC structure based on its adjacency
+    matrix using 1-D row-wise decomposition. This approach ensures a 1-D uniform distribution of nodes
+    and their associated 1-hop incoming edges. By treating source and destination nodes equivalently
+    during partitioning, this approach assumes the graph is not bipartite.
+    This decomposition also ensures that the graph convolution (spMM) remains local by maintaining a copy of
+    the local incoming edge features and the local node outputs from the graph convolution.
+    The memory complexity of this approach is O[(N/P + E/P)*hid_dim*L], where N/E are the numbers of nodes/edges.
+    The transformation from local node storage to local edge storage is achieved using NCCL `alltoall`.
+
+    Key differences from the existing graph partition scheme (partition_graph_with_id_mapping):
+    (1) This function partitions the global node ID space uniformly, without distinguishing
+    between source and destination nodes (i.e., matrix row ordering or column ordering). Both
+    src/dst or row/col nodes are indexed consistently within the adjacency matrix.
+    (2) Each local graph (sub-matrix) can be defined/constructed by just node/edge offsets into the
+    global graph.
+    (3) The partitioning is performed on a global graph stored in CPU memory, and then each device
+    (rank) constructs its local graph independently from the global CSC matrix.
+
+    Parameters
+    ----------
+    global_offsets : torch.Tensor
+        CSC offsets, can live on the CPU
+    global_indices : torch.Tensor
+        CSC indices, can live on the CPU
+    num_nodes : int
+        number of nodes in the global graph
+    partition_book : torch.Tensor
+        the boundaries of the 1-D row-decomposition of the adj. matrix for all ranks
+    partition_size : int
+        number of process groups across which graph is partitioned,
+        i.e. the number of graph partitions
+    partition_rank : int
+        rank within process group managing the distributed graph, i.e.
+        the rank determining which partition the corresponding local rank
+        will manage
+    device : torch.device
+        device connected to the passed partition rank, i.e. the device
+        on which the local graph and related buffers will live
+    """
+
+    # initialize graph partition
+    graph_partition = GraphPartition(
+        partition_size=partition_size, partition_rank=partition_rank, device=device
+    )
+    dtype = global_indices.dtype
+    # --------------------------------------------------------------
+    # First partition the global row ptrs (dst nodes) to local row ptrs
+    num_edges = global_indices.size(0)
+    node_offset = partition_book[partition_rank]
+    num_local_nodes = (
+        partition_book[partition_rank + 1] - partition_book[partition_rank]
+    )
+    edge_partition_offset = global_offsets[node_offset]
+    if node_offset + num_local_nodes > num_nodes:
+        raise ValueError("Invalid node offset and number of local nodes")
+
+    local_offsets = global_offsets[node_offset : node_offset + num_local_nodes + 1].to(
+        device=device, non_blocking=True
+    )
+    graph_partition.local_offsets = local_offsets - edge_partition_offset
+    graph_partition.num_local_dst_nodes = num_local_nodes
+
+    # Scan through all partitions and compress the source nodes (edges) for each partition
+    # to fill the local send/recv buffers for all-to-all communications
+    partition_book = partition_book.to(device=device)
+    for to_partition in range(partition_size):
+        local_indices = global_indices[
+            global_offsets[partition_book[to_partition]] : global_offsets[
+                partition_book[to_partition + 1]
+            ]
+        ].to(device=device, non_blocking=True)
+        # compress the columns (src nodes or local_indices) for each partition and record mapping (inverse_indices)
+        global_src_node_at_partition, inverse_indices = local_indices.unique(
+            sorted=True, return_inverse=True
+        )
+        global_src_node_at_partition_rank = (
+            torch.bucketize(global_src_node_at_partition, partition_book, right=True)
+            - 1
+        )
+        src_node_indices = torch.nonzero(
+            global_src_node_at_partition_rank == partition_rank, as_tuple=False
+        ).squeeze(1)
+        # fill local send buffer for alltoalls (scatter selected nodes to_partition rank)
+        graph_partition.scatter_indices[to_partition] = (
+            global_src_node_at_partition[src_node_indices] - node_offset
+        )
+        # fill the numbers of indices (edges), dst nodes and src nodes for each partition
+        graph_partition.num_indices_in_each_partition[
+            to_partition
+        ] = local_indices.size(0)
+        graph_partition.num_dst_nodes_in_each_partition[to_partition] = (
+            partition_book[to_partition + 1] - partition_book[to_partition]
+        )
+        graph_partition.num_src_nodes_in_each_partition[
+            to_partition
+        ] = global_src_node_at_partition.size(0)
+
+        if to_partition == partition_rank:
+            graph_partition.local_indices = inverse_indices
+            graph_partition.num_local_indices = graph_partition.local_indices.size(0)
+            graph_partition.num_local_src_nodes = global_src_node_at_partition.size(0)
+            # map from local (compressed) column indices [0, ..., num_local_src_nodes] to their global node IDs
+            graph_partition.map_partitioned_src_ids_to_global = (
+                global_src_node_at_partition
+            )
+
+        for from_partition in range(partition_size):
+            # fill all recv buffer sizes for alltoalls
+            graph_partition.sizes[from_partition][to_partition] = torch.count_nonzero(
+                global_src_node_at_partition_rank == from_partition
+            )
+
+    # trivial mappings due to 1-D row-wise decomposition
+    graph_partition.map_partitioned_dst_ids_to_global = torch.arange(
+        node_offset, node_offset + num_local_nodes, dtype=dtype, device=device
+    )
+    graph_partition.map_partitioned_edge_ids_to_global = torch.arange(
+        edge_partition_offset,
+        edge_partition_offset + graph_partition.num_local_indices,
+        dtype=dtype,
+        device=device,
+    )
+    # trivial mappings due to 1-D row-wise decomposition, with mem. cost O(E, N) on each device; needs optimization
+    graph_partition.map_concatenated_local_src_ids_to_global = torch.arange(
+        num_nodes, dtype=dtype, device=device
+    )
+    graph_partition.map_concatenated_local_edge_ids_to_global = torch.arange(
+        num_edges, dtype=dtype, device=device
+    )
+    graph_partition.map_concatenated_local_dst_ids_to_global = (
+        graph_partition.map_concatenated_local_src_ids_to_global
+    )
+    graph_partition.map_global_src_ids_to_concatenated_local = (
+        graph_partition.map_concatenated_local_src_ids_to_global
+    )
+    graph_partition.map_global_dst_ids_to_concatenated_local = (
+        graph_partition.map_concatenated_local_src_ids_to_global
+    )
+    graph_partition.map_global_edge_ids_to_concatenated_local = (
+        graph_partition.map_concatenated_local_edge_ids_to_global
+    )
+    graph_partition.matrix_decomp = True
+
+    for r in range(graph_partition.partition_size):
+        err_msg = "error in graph partition: list containing sizes of exchanged indices does not match the tensor of indices to be exchanged"
+        if (
+            graph_partition.sizes[graph_partition.partition_rank][r]
+            != graph_partition.scatter_indices[r].numel()
+        ):
+            raise AssertionError(err_msg)
+    graph_partition = graph_partition.to(device=device)
+    return graph_partition
+
+
 def partition_graph_nodewise(
     global_offsets: torch.Tensor,
     global_indices: torch.Tensor,
     partition_size: int,
     partition_rank: int,
     device: torch.device,
+    matrix_decomp: bool = False,
 ) -> GraphPartition:
     """
     Utility function which partitions a global graph given as CSC structure naively
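For reference, a minimal single-process sketch of how partition_graph_with_matrix_decomposition can be exercised (no distributed setup; the toy square graph mirrors the test fixture added further down, and the import path simply follows the file location, so adjust it if the symbol is re-exported elsewhere):

import torch

from modulus.models.gnn_layers.distributed_graph import (
    partition_graph_with_matrix_decomposition,
)

# toy square CSC graph: 4 nodes, 8 edges, in-degree 2 per node
offsets = torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64)
indices = torch.tensor([0, 3, 2, 1, 1, 0, 1, 2], dtype=torch.int64)

# uniform 1-D row decomposition: rank r owns rows [book[r], book[r + 1])
partition_book = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64)

pg = partition_graph_with_matrix_decomposition(
    offsets,
    indices,
    4,  # num_nodes
    partition_book,
    4,  # partition_size
    0,  # partition_rank
    torch.device("cpu"),
)
# rank 0 owns row 0 (one dst node) and sees two unique src nodes, 0 and 3
print(pg.num_local_dst_nodes, pg.map_partitioned_src_ids_to_global)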
@@ -429,13 +598,42 @@ def partition_graph_nodewise(
     device : torch.device
         device connected to the passed partition rank, i.e. the device
         on which the local graph and related buffers will live on
+    matrix_decomp : bool
+        flag to enable matrix decomposition for partitioning
     """
-
     num_global_src_nodes = global_indices.max().item() + 1
     num_global_dst_nodes = global_offsets.size(0) - 1
     num_dst_nodes_per_partition = (
         num_global_dst_nodes + partition_size - 1
     ) // partition_size
+
+    if matrix_decomp:
+        if num_global_src_nodes != num_global_dst_nodes:
+            raise ValueError(
+                "Must be square adj. matrix (num_src=num_dst) for matrix decomposition"
+            )
+        partition_book = torch.arange(
+            0,
+            num_global_dst_nodes,
+            num_dst_nodes_per_partition,
+            dtype=global_indices.dtype,
+        )
+        partition_book = torch.cat(
+            [
+                partition_book,
+                torch.tensor([num_global_dst_nodes], dtype=global_indices.dtype),
+            ]
+        )
+        return partition_graph_with_matrix_decomposition(
+            global_offsets,
+            global_indices,
+            num_global_dst_nodes,
+            partition_book,
+            partition_size,
+            partition_rank,
+            device,
+        )
+
     num_src_nodes_per_partition = (
         num_global_src_nodes + partition_size - 1
     ) // partition_size
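The partition_book built in the matrix_decomp branch above is just a uniform split of the row (node) ID space, with the last rank taking the remainder. A small worked sketch with illustrative values:

import torch

num_global_dst_nodes, partition_size = 10, 4
num_dst_nodes_per_partition = (
    num_global_dst_nodes + partition_size - 1
) // partition_size  # ceil(10 / 4) = 3

partition_book = torch.cat(
    [
        torch.arange(
            0, num_global_dst_nodes, num_dst_nodes_per_partition, dtype=torch.int64
        ),
        torch.tensor([num_global_dst_nodes], dtype=torch.int64),
    ]
)
print(partition_book)  # tensor([ 0,  3,  6,  9, 10]): ranks own 3, 3, 3 and 1 rows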
@@ -769,6 +967,19 @@ def get_src_node_features_in_partition(
     ) -> torch.Tensor:  # pragma: no cover
         # if global features only on local rank 0 also scatter, split them
         # according to the partition and scatter them to other ranks
+
+        if self.graph_partition.matrix_decomp:
+            logger.warning(
+                "Matrix decomposition assumes one type of node feature partition, and the graph "
+                "adjacency matrix is square with identical src/dst node domains. "
+                "So, only `get_dst_node_features_in_partition` is used/needed to get src or dst "
+                "node features within a partition."
+            )
+            return self.get_dst_node_features_in_partition(
+                global_node_features,
+                scatter_features=scatter_features,
+                src_rank=src_rank,
+            )
         if scatter_features:
             global_node_features = global_node_features[
                 self.graph_partition.map_concatenated_local_src_ids_to_global
@@ -872,6 +1083,19 @@ def get_global_src_node_features(
         if partitioned_node_features.device != self.device:
             raise AssertionError(error_msg)

+        if self.graph_partition.matrix_decomp:
+            logger.warning(
+                "Matrix decomposition assumes one type of node feature partition, and the graph "
+                "adjacency matrix is square with identical src/dst node domains. "
+                "So, only `get_global_dst_node_features` is used/needed to get global src or dst "
+                "node features."
+            )
+            return self.get_global_dst_node_features(
+                partitioned_node_features,
+                get_on_all_ranks=get_on_all_ranks,
+                dst_rank=dst_rank,
+            )
+
         if not get_on_all_ranks:
             global_node_feat = gather_v(
                 partitioned_node_features,

test/models/test_graph_partition.py (+49)

@@ -36,6 +36,18 @@ def global_graph():
     return (offsets, indices, num_src_nodes, num_dst_nodes)


+@pytest.fixture
+def global_graph_square():
+    """test fixture: simple non-bipartite graph with a degree of 2 per node"""
+    # num_src_nodes = 4
+    # num_dst_nodes = 4
+    # num_edges = 8
+    offsets = torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64)
+    indices = torch.tensor([0, 3, 2, 1, 1, 0, 1, 2], dtype=torch.int64)
+
+    return (offsets, indices, 4, 4)
+
+
 def assert_partitions_are_equal(a, b):
     """test utility: check if a matches b"""
     attributes = [
@@ -163,6 +175,43 @@ def test_gp_nodewise(global_graph, device):
     assert_partitions_are_equal(pg, pg_expected)


+@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+def test_gp_matrixdecomp(global_graph_square, device):
+    offsets, indices, num_src_nodes, num_dst_nodes = global_graph_square
+    partition_size = 4
+    partition_rank = 0
+
+    pg = partition_graph_nodewise(
+        offsets, indices, partition_size, partition_rank, device, matrix_decomp=True
+    )
+
+    pg_expected = GraphPartition(
+        partition_size=4,
+        partition_rank=0,
+        device=device,
+        local_offsets=torch.tensor([0, 2]),
+        local_indices=torch.tensor([0, 1]),
+        num_local_src_nodes=2,
+        num_local_dst_nodes=1,
+        num_local_indices=2,
+        map_partitioned_src_ids_to_global=torch.tensor([0, 3]),
+        map_partitioned_dst_ids_to_global=torch.tensor([0]),
+        map_partitioned_edge_ids_to_global=torch.tensor([0, 1]),
+        sizes=[[1, 0, 1, 0], [0, 1, 1, 1], [0, 1, 0, 1], [1, 0, 0, 0]],
+        scatter_indices=[
+            torch.tensor([0]),
+            torch.tensor([], dtype=torch.int64),
+            torch.tensor([0]),
+            torch.tensor([], dtype=torch.int64),
+        ],
+        num_src_nodes_in_each_partition=[2, 2, 2, 2],
+        num_dst_nodes_in_each_partition=[1, 1, 1, 1],
+        num_indices_in_each_partition=[2, 2, 2, 2],
+    ).to(device=device)
+
+    assert_partitions_are_equal(pg, pg_expected)
+
+
 @pytest.mark.parametrize("device", ["cuda:0", "cpu"])
 def test_gp_coordinate_bbox(global_graph, device):
     offsets, indices, num_src_nodes, num_dst_nodes = global_graph
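To trace where the expected values in test_gp_matrixdecomp come from, the square fixture can be unrolled row by row; with partition_size=4 each rank owns exactly one row of the adjacency matrix plus its incoming edges. A purely illustrative reading of the fixture:

import torch

offsets = torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64)
indices = torch.tensor([0, 3, 2, 1, 1, 0, 1, 2], dtype=torch.int64)

for dst in range(offsets.numel() - 1):
    src = indices[offsets[dst] : offsets[dst + 1]].tolist()
    print(f"dst node {dst}: incoming src nodes {src}")
# dst 0: [0, 3], dst 1: [2, 1], dst 2: [1, 0], dst 3: [1, 2]
# Rank 0 therefore holds edges 0-1 and compresses its columns to the
# global src IDs [0, 3], matching map_partitioned_src_ids_to_global above.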
