|
10 | 10 | from torch_geometric.transforms import OneHotDegree
|
11 | 11 | from torch_geometric.data import Data, InMemoryDataset
|
12 | 12 | from torch_geometric.utils import add_remaining_self_loops, remove_self_loops, dense_to_sparse
|
13 |
| -from torch_geometric.utils import softmax, degree, sort_edge_index |
| 13 | +from torch_geometric.utils import softmax, degree |
14 | 14 | from torch_scatter import scatter
|
15 | 15 | from torch_cluster import random_walk
|
16 | 16 | from torch_sparse import spspmm, coalesce
|
17 | 17 |
|
18 | 18 |
|
| 19 | +def sort_edge_index(edge_index, edge_attr=None, num_nodes=None): |
| 20 | + r"""Row-wise sorts edge indices :obj:`edge_index`. |
| 21 | +
|
| 22 | + Args: |
| 23 | + edge_index (LongTensor): The edge indices. |
| 24 | + edge_attr (Tensor, optional): Edge weights or multi-dimensional |
| 25 | + edge features. (default: :obj:`None`) |
| 26 | + num_nodes (int, optional): The number of nodes, *i.e.* |
| 27 | + :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) |
| 28 | +
|
| 29 | + :rtype: (:class:`LongTensor`, :class:`Tensor`) |
| 30 | + """ |
| 31 | + |
| 32 | + idx = edge_index[0] * num_nodes + edge_index[1] |
| 33 | + perm = idx.argsort() |
| 34 | + |
| 35 | + return edge_index[:, perm], None if edge_attr is None else edge_attr[perm] |
| 36 | + |
| 37 | + |
19 | 38 | class BinaryFuncDataset(InMemoryDataset):
|
20 | 39 | def __init__(self, root, name, transform=None, pre_transform=None, pre_filter=None):
|
21 | 40 | self.dir_name = os.path.join(root, name)
|
|
0 commit comments