|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
16 | 16 |
|
17 |
| - |
| 17 | +import logging |
18 | 18 | from dataclasses import dataclass
|
19 | 19 | from typing import List, Optional
|
20 | 20 |
|
|
29 | 29 | scatter_v,
|
30 | 30 | )
|
31 | 31 |
|
| 32 | +logger = logging.getLogger(__name__) |
| 33 | + |
32 | 34 |
|
33 | 35 | @dataclass
|
34 | 36 | class GraphPartition:
|
@@ -67,6 +69,8 @@ class GraphPartition:
|
67 | 69 | partition_size: int
|
68 | 70 | partition_rank: int
|
69 | 71 | device: torch.device
|
| 72 | + # flag to indicate using adj matrix 1-D row-decomp |
| 73 | + matrix_decomp: bool = False |
70 | 74 |
|
71 | 75 | # data structures defining partition
|
72 | 76 | # set in after initialization or during execution
|
@@ -394,12 +398,177 @@ def partition_graph_with_id_mapping(
|
394 | 398 | return graph_partition
|
395 | 399 |
|
396 | 400 |
|
def partition_graph_with_matrix_decomposition(
    global_offsets: torch.Tensor,
    global_indices: torch.Tensor,
    num_nodes: int,
    partition_book: torch.Tensor,
    partition_size: int,
    partition_rank: int,
    device: torch.device,
) -> GraphPartition:
    """
    Utility function which partitions a global graph given as CSC structure based on its adjacency
    matrix using 1-D row-wise decomposition. This approach ensures a 1D uniform distribution of nodes
    and their associated 1-hop incoming edges. By treating source and destination nodes equivalently
    during partitioning, this approach assumes the graph is not bipartite.
    This decomposition also ensures that the graph convolution (spMM) remains local by maintaining a copy of
    the local incoming edge features and the local node outputs from the graph convolution.
    The memory complexity of this approach is O[(N/P + E/P)*hid_dim*L], where N/E are the number of nodes/edges.
    The transformation from local node storage to local edge storage is achieved using nccl `alltoall`.

    Key differences from the existing graph partition scheme (partition_graph_with_id_mapping):
    (1) This function partitions the global node ID space uniformly, without distinguishing
    between source and destination nodes (i.e., matrix row ordering or column ordering). Both
    src/dst or row/col nodes are indexed consistently within the adjacency matrix.
    (2) Each local graph (sub-matrix) can be defined/constructed by just node/edge offsets from
    global graph.
    (3) The partitioning is performed on a global graph stored in CPU memory, and then each device
    (rank) constructs its local graph independently from the global csc matrix.

    Parameters
    ----------
    global_offsets : torch.Tensor
        CSC offsets, can live on the CPU
    global_indices : torch.Tensor
        CSC indices, can live on the CPU
    num_nodes : int
        number of nodes in the global graph
    partition_book : torch.Tensor
        the boundaries of 1-D row-decomp of adj. matrix for all ranks
    partition_size : int
        number of process groups across which graph is partitioned,
        i.e. the number of graph partitions
    partition_rank : int
        rank within process group managing the distributed graph, i.e.
        the rank determining which partition the corresponding local rank
        will manage
    device : torch.device
        device connected to the passed partition rank, i.e. the device
        on which the local graph and related buffers will live on

    Returns
    -------
    GraphPartition
        the local graph partition owned by ``partition_rank``, with
        ``matrix_decomp`` set to ``True``

    Raises
    ------
    ValueError
        if the partition boundaries exceed the number of global nodes
    AssertionError
        if the computed all-to-all exchange sizes are inconsistent with
        the scatter index tensors
    """

    # initialize graph partition
    graph_partition = GraphPartition(
        partition_size=partition_size, partition_rank=partition_rank, device=device
    )
    dtype = global_indices.dtype
    # --------------------------------------------------------------
    # First partition the global row ptrs (dst nodes) to local row ptrs
    num_edges = global_indices.size(0)
    node_offset = partition_book[partition_rank]
    num_local_nodes = (
        partition_book[partition_rank + 1] - partition_book[partition_rank]
    )
    edge_partition_offset = global_offsets[node_offset]
    if node_offset + num_local_nodes > num_nodes:
        raise ValueError("Invalid node offset and number of local nodes")

    local_offsets = global_offsets[node_offset : node_offset + num_local_nodes + 1].to(
        device=device, non_blocking=True
    )
    # shift the sliced CSC offsets so the local sub-matrix starts at edge 0
    graph_partition.local_offsets = local_offsets - edge_partition_offset
    graph_partition.num_local_dst_nodes = num_local_nodes

    # Scan through all partitions and compress the source nodes (edges) for each partition
    # to fill the local send/recv buffers for all-to-all communications
    partition_book = partition_book.to(device=device)
    for to_partition in range(partition_size):
        local_indices = global_indices[
            global_offsets[partition_book[to_partition]] : global_offsets[
                partition_book[to_partition + 1]
            ]
        ].to(device=device, non_blocking=True)
        # compress the columns (src nodes or local_indices) for each partition and record mapping (inverse_indices)
        global_src_node_at_partition, inverse_indices = local_indices.unique(
            sorted=True, return_inverse=True
        )
        # owning rank of each unique src node: bucketize against the partition
        # boundaries (right=True so a boundary node belongs to the lower rank)
        global_src_node_at_partition_rank = (
            torch.bucketize(global_src_node_at_partition, partition_book, right=True)
            - 1
        )
        src_node_indices = torch.nonzero(
            global_src_node_at_partition_rank == partition_rank, as_tuple=False
        ).squeeze(1)
        # fill local send buffer for alltoalls (scatter selected nodes to_partition rank)
        graph_partition.scatter_indices[to_partition] = (
            global_src_node_at_partition[src_node_indices] - node_offset
        )
        # fill the numbers of indices (edges), dst nodes and src nodes for each partition
        graph_partition.num_indices_in_each_partition[
            to_partition
        ] = local_indices.size(0)
        graph_partition.num_dst_nodes_in_each_partition[to_partition] = (
            partition_book[to_partition + 1] - partition_book[to_partition]
        )
        graph_partition.num_src_nodes_in_each_partition[
            to_partition
        ] = global_src_node_at_partition.size(0)

        if to_partition == partition_rank:
            graph_partition.local_indices = inverse_indices
            graph_partition.num_local_indices = graph_partition.local_indices.size(0)
            graph_partition.num_local_src_nodes = global_src_node_at_partition.size(0)
            # map from local (compressed) column indices [0, ..., num_local_src_nodes] to their global node IDs
            graph_partition.map_partitioned_src_ids_to_global = (
                global_src_node_at_partition
            )

        for from_partition in range(partition_size):
            # fill all recv buffer sizes for alltoalls
            graph_partition.sizes[from_partition][to_partition] = torch.count_nonzero(
                global_src_node_at_partition_rank == from_partition
            )

    # trivial mappings due to 1D row-wise decomposition
    graph_partition.map_partitioned_dst_ids_to_global = torch.arange(
        node_offset, node_offset + num_local_nodes, dtype=dtype, device=device
    )
    graph_partition.map_partitioned_edge_ids_to_global = torch.arange(
        edge_partition_offset,
        edge_partition_offset + graph_partition.num_local_indices,
        dtype=dtype,
        device=device,
    )
    # trivial mappings due to 1D row-wise decomposition, with mem. cost O(E, N) at each dev; need to optimize
    graph_partition.map_concatenated_local_src_ids_to_global = torch.arange(
        num_nodes, dtype=dtype, device=device
    )
    graph_partition.map_concatenated_local_edge_ids_to_global = torch.arange(
        num_edges, dtype=dtype, device=device
    )
    # src/dst node domains are identical under 1-D row decomposition, so the
    # concatenated-ID mappings can be shared instead of duplicated
    graph_partition.map_concatenated_local_dst_ids_to_global = (
        graph_partition.map_concatenated_local_src_ids_to_global
    )
    graph_partition.map_global_src_ids_to_concatenated_local = (
        graph_partition.map_concatenated_local_src_ids_to_global
    )
    graph_partition.map_global_dst_ids_to_concatenated_local = (
        graph_partition.map_concatenated_local_src_ids_to_global
    )
    graph_partition.map_global_edge_ids_to_concatenated_local = (
        graph_partition.map_concatenated_local_edge_ids_to_global
    )
    graph_partition.matrix_decomp = True

    # sanity check: per-rank exchange sizes must match the scatter index tensors
    # (err_msg hoisted out of the loop; it is loop-invariant)
    err_msg = (
        "error in graph partition: list containing sizes of exchanged indices "
        "does not match the tensor of indices to be exchanged"
    )
    for r in range(graph_partition.partition_size):
        if (
            graph_partition.sizes[graph_partition.partition_rank][r]
            != graph_partition.scatter_indices[r].numel()
        ):
            raise AssertionError(err_msg)
    graph_partition = graph_partition.to(device=device)
    return graph_partition
| 563 | + |
| 564 | + |
397 | 565 | def partition_graph_nodewise(
|
398 | 566 | global_offsets: torch.Tensor,
|
399 | 567 | global_indices: torch.Tensor,
|
400 | 568 | partition_size: int,
|
401 | 569 | partition_rank: int,
|
402 | 570 | device: torch.device,
|
| 571 | + matrix_decomp: bool = False, |
403 | 572 | ) -> GraphPartition:
|
404 | 573 | """
|
405 | 574 | Utility function which partitions a global graph given as CSC structure naively
|
@@ -429,13 +598,42 @@ def partition_graph_nodewise(
|
429 | 598 | device : torch.device
|
430 | 599 | device connected to the passed partition rank, i.e. the device
|
431 | 600 | on which the local graph and related buffers will live on
|
| 601 | + matrix_decomp : bool |
| 602 | + flag to enable matrix decomposition for partitioning |
432 | 603 | """
|
433 |
| - |
434 | 604 | num_global_src_nodes = global_indices.max().item() + 1
|
435 | 605 | num_global_dst_nodes = global_offsets.size(0) - 1
|
436 | 606 | num_dst_nodes_per_partition = (
|
437 | 607 | num_global_dst_nodes + partition_size - 1
|
438 | 608 | ) // partition_size
|
| 609 | + |
| 610 | + if matrix_decomp: |
| 611 | + if num_global_src_nodes != num_global_dst_nodes: |
| 612 | + raise ValueError( |
| 613 | + "Must be square adj. matrix (num_src=num_dst) for matrix decomposition" |
| 614 | + ) |
| 615 | + partition_book = torch.arange( |
| 616 | + 0, |
| 617 | + num_global_dst_nodes, |
| 618 | + num_dst_nodes_per_partition, |
| 619 | + dtype=global_indices.dtype, |
| 620 | + ) |
| 621 | + partition_book = torch.cat( |
| 622 | + [ |
| 623 | + partition_book, |
| 624 | + torch.tensor([num_global_dst_nodes], dtype=global_indices.dtype), |
| 625 | + ] |
| 626 | + ) |
| 627 | + return partition_graph_with_matrix_decomposition( |
| 628 | + global_offsets, |
| 629 | + global_indices, |
| 630 | + num_global_dst_nodes, |
| 631 | + partition_book, |
| 632 | + partition_size, |
| 633 | + partition_rank, |
| 634 | + device, |
| 635 | + ) |
| 636 | + |
439 | 637 | num_src_nodes_per_partition = (
|
440 | 638 | num_global_src_nodes + partition_size - 1
|
441 | 639 | ) // partition_size
|
@@ -769,6 +967,19 @@ def get_src_node_features_in_partition(
|
769 | 967 | ) -> torch.Tensor: # pragma: no cover
|
770 | 968 | # if global features only on local rank 0 also scatter, split them
|
771 | 969 | # according to the partition and scatter them to other ranks
|
| 970 | + |
| 971 | + if self.graph_partition.matrix_decomp: |
| 972 | + logger.warning( |
| 973 | +            "Matrix decomposition assumes one type of node feature partition, and the graph " |
| 974 | +            "adjacency matrix is square with identical src/dst node domains. " |
| 975 | +            "So, only `get_dst_node_features_in_partition` is used/needed to get src or dst " |
| 976 | +            "node features within a partition." |
| 977 | + ) |
| 978 | + return self.get_dst_node_features_in_partition( |
| 979 | + global_node_features, |
| 980 | + scatter_features=scatter_features, |
| 981 | + src_rank=src_rank, |
| 982 | + ) |
772 | 983 | if scatter_features:
|
773 | 984 | global_node_features = global_node_features[
|
774 | 985 | self.graph_partition.map_concatenated_local_src_ids_to_global
|
@@ -872,6 +1083,19 @@ def get_global_src_node_features(
|
872 | 1083 | if partitioned_node_features.device != self.device:
|
873 | 1084 | raise AssertionError(error_msg)
|
874 | 1085 |
|
| 1086 | + if self.graph_partition.matrix_decomp: |
| 1087 | + logger.warning( |
| 1088 | + "Matrix decomposition assumes one type of node feature partition, and the graph" |
| 1089 | + "adjacency matrix is square with identical src/dst node domains. " |
| 1090 | + "So, only `get_global_dst_node_features` is used/needed to get global src or dst" |
| 1091 | + "node features." |
| 1092 | + ) |
| 1093 | + return self.get_global_dst_node_features( |
| 1094 | + partitioned_node_features, |
| 1095 | + get_on_all_ranks=get_on_all_ranks, |
| 1096 | + dst_rank=dst_rank, |
| 1097 | + ) |
| 1098 | + |
875 | 1099 | if not get_on_all_ranks:
|
876 | 1100 | global_node_feat = gather_v(
|
877 | 1101 | partitioned_node_features,
|
|
0 commit comments