Skip to content

Commit 71712ac

Browse files
committedApr 16, 2022
Update utils.py
1 parent a2cf278 commit 71712ac

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed
 

‎utils.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,31 @@
1010
from torch_geometric.transforms import OneHotDegree
1111
from torch_geometric.data import Data, InMemoryDataset
1212
from torch_geometric.utils import add_remaining_self_loops, remove_self_loops, dense_to_sparse
13-
from torch_geometric.utils import softmax, degree, sort_edge_index
13+
from torch_geometric.utils import softmax, degree
1414
from torch_scatter import scatter
1515
from torch_cluster import random_walk
1616
from torch_sparse import spspmm, coalesce
1717

1818

19+
def sort_edge_index(edge_index, edge_attr=None, num_nodes=None):
20+
r"""Row-wise sorts edge indices :obj:`edge_index`.
21+
22+
Args:
23+
edge_index (LongTensor): The edge indices.
24+
edge_attr (Tensor, optional): Edge weights or multi-dimensional
25+
edge features. (default: :obj:`None`)
26+
num_nodes (int, optional): The number of nodes, *i.e.*
27+
:obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
28+
29+
:rtype: (:class:`LongTensor`, :class:`Tensor`)
30+
"""
31+
32+
idx = edge_index[0] * num_nodes + edge_index[1]
33+
perm = idx.argsort()
34+
35+
return edge_index[:, perm], None if edge_attr is None else edge_attr[perm]
36+
37+
1938
class BinaryFuncDataset(InMemoryDataset):
2039
def __init__(self, root, name, transform=None, pre_transform=None, pre_filter=None):
2140
self.dir_name = os.path.join(root, name)

0 commit comments

Comments
 (0)