-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmessage_passing2.py
145 lines (116 loc) · 5.78 KB
/
message_passing2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import sys
import inspect
import torch
from torch_geometric.utils import scatter_
# Parameter names of `message` that are NOT looked up in propagate()'s
# **kwargs; propagate() supplies them itself (edge_index / size, optionally
# indexed with the _i / _j suffix).
special_args = [
    'edge_index',
    'edge_index_i',
    'edge_index_j',
    'size',
    'size_i',
    'size_j',
]

# Raised by propagate() when tensors mapped to the same side of edge_index
# disagree in length along the propagation axis.
__size_error_msg__ = (
    'All tensors which should get mapped to the same source '
    'or target nodes must be of same size in dimension 0.'
)

is_python2 = sys.version_info.major < 3

# Pick the signature-introspection function that exists on this interpreter
# (getargspec on Python 2, getfullargspec on Python 3).
if is_python2:
    getargspec = inspect.getargspec
else:
    getargspec = inspect.getfullargspec
class MessagePassing2(torch.nn.Module):
    r"""Base class for creating message passing layers

    .. math::
        \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
        \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
        \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right),

    where :math:`\square` denotes a differentiable, permutation invariant
    function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}`
    and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as
    MLPs.

    Note that, unlike the full scheme above, :meth:`propagate` in this class
    only evaluates the message function :math:`\phi_{\mathbf{\Theta}}`
    (:meth:`message`) for every edge and returns its result directly; no
    aggregation :math:`\square` and no update :math:`\gamma_{\mathbf{\Theta}}`
    is applied. NOTE(review): presumably callers or subclasses aggregate the
    returned per-edge messages themselves -- confirm.

    See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
    create_gnn.html>`__ for the accompanying tutorial.

    Args:
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`). It is validated
            and stored as :obj:`self.aggr`, but :meth:`propagate` of this
            class does not use it.
            (default: :obj:`"add"`)
        flow (string, optional): The flow direction of message passing
            (:obj:`"source_to_target"` or :obj:`"target_to_source"`).
            (default: :obj:`"source_to_target"`)
        node_dim (int, optional): The axis along which to propagate.
            (default: :obj:`0`)
    """

    __aggr__ = ("add", "mean", "max")

    def __init__(self, aggr='add', flow='source_to_target', node_dim=0):
        super(MessagePassing2, self).__init__()

        self.aggr = aggr
        assert self.aggr in self.__aggr__, (
            'aggr must be one of {}, got {!r}'.format(self.__aggr__, aggr))

        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source'], (
            "flow must be 'source_to_target' or 'target_to_source', "
            'got {!r}'.format(flow))

        self.node_dim = node_dim
        assert self.node_dim >= 0, (
            'node_dim must be non-negative, got {!r}'.format(node_dim))

        # Parameter names of the (possibly overridden) `message`, without
        # `self`. Special args are remembered together with their position so
        # propagate() can re-insert them in the right place; every other name
        # is resolved from propagate()'s **kwargs.
        self.__message_args__ = getargspec(self.message)[0][1:]
        self.__special_args__ = [(i, arg)
                                 for i, arg in enumerate(self.__message_args__)
                                 if arg in special_args]
        self.__message_args__ = [
            arg for arg in self.__message_args__ if arg not in special_args
        ]

    def propagate(self, edge_index, size=None, **kwargs):
        r"""The initial call to start propagating messages.

        Builds the argument list of :meth:`message` from :obj:`kwargs`, then
        calls :meth:`message` once and returns its result unchanged.

        Args:
            edge_index (Tensor): The indices of a general (sparse) assignment
                matrix with shape :obj:`[N, M]` (can be directed or
                undirected).
            size (list or tuple, optional): The size :obj:`[N, M]` of the
                assignment matrix. If set to :obj:`None`, the size is tried to
                get automatically inferred and assumed to be symmetric.
                (default: :obj:`None`)
            **kwargs: Any additional data which is needed to construct
                messages. A :meth:`message` parameter named :obj:`x_i` or
                :obj:`x_j` is filled from :obj:`kwargs["x"]` by gathering rows
                along :obj:`node_dim` with the matching row of
                :obj:`edge_index`. :obj:`kwargs["x"]` may also be a pair
                :obj:`(x_0, x_1)` holding one tensor per row of
                :obj:`edge_index`.

        Returns:
            Whatever :meth:`message` returns. No aggregation or update step
            is applied here.

        Raises:
            ValueError: If tensors mapped to the same row of
                :obj:`edge_index` disagree in size along :obj:`node_dim`
                (or disagree with an explicitly given :obj:`size`).
            AssertionError: If :obj:`size` or a paired kwarg does not have
                exactly two entries.
        """
        dim = self.node_dim
        size = [None, None] if size is None else list(size)
        assert len(size) == 2, 'size must have exactly two entries'

        # Row of edge_index that holds the "_i" nodes and the "_j" nodes.
        # With source_to_target, edge_index[0] holds the j nodes and
        # edge_index[1] holds the i nodes; target_to_source swaps them.
        i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0)
        ij = {"_i": i, "_j": j}

        message_args = []
        for arg in self.__message_args__:
            suffix = arg[-2:]
            if suffix not in ij:
                # Plain argument: passed through from kwargs unchanged.
                message_args.append(kwargs.get(arg, None))
                continue

            tmp = kwargs.get(arg[:-2], None)
            if tmp is None:  # pragma: no cover
                message_args.append(tmp)
                continue

            idx = ij[suffix]
            if isinstance(tmp, (tuple, list)):
                # One tensor per edge_index row. The opposite entry is only
                # used to infer or check the size of the other side.
                assert len(tmp) == 2, 'paired kwargs must have two entries'
                other = tmp[1 - idx]
                if other is not None:
                    if size[1 - idx] is None:
                        size[1 - idx] = other.size(dim)
                    if size[1 - idx] != other.size(dim):
                        raise ValueError(__size_error_msg__)
                tmp = tmp[idx]
                if tmp is None:
                    message_args.append(tmp)
                    continue

            # Note: despite the message text, the check is along `node_dim`.
            if size[idx] is None:
                size[idx] = tmp.size(dim)
            if size[idx] != tmp.size(dim):
                raise ValueError(__size_error_msg__)

            # Gather one entry per edge from the node tensor.
            message_args.append(torch.index_select(tmp, dim, edge_index[idx]))

        # If only one side of the assignment matrix is known, assume it is
        # square.
        size[0] = size[1] if size[0] is None else size[0]
        size[1] = size[0] if size[1] is None else size[1]

        kwargs['edge_index'] = edge_index
        kwargs['size'] = size

        # Re-insert the special args at their original parameter positions.
        # __special_args__ is ordered by ascending position, so each insert
        # lands at its correct index.
        for (idx, arg) in self.__special_args__:
            if arg[-2:] in ij:
                message_args.insert(idx, kwargs[arg[:-2]][ij[arg[-2:]]])
            else:
                message_args.insert(idx, kwargs[arg])

        out = self.message(*message_args)
        return out

    def message(self, x_j):  # pragma: no cover
        r"""Constructs messages to node :math:`i` in analogy to
        :math:`\phi_{\mathbf{\Theta}}` for each edge in
        :math:`(j,i) \in \mathcal{E}` if :obj:`flow="source_to_target"` and
        :math:`(i,j) \in \mathcal{E}` if :obj:`flow="target_to_source"`.

        Can take any argument which was initially passed to :meth:`propagate`.
        In addition, tensors passed to :meth:`propagate` can be mapped to the
        respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or
        :obj:`_j` to the variable name, *e.g.* :obj:`x_i` and :obj:`x_j`.

        The default implementation returns the gathered :obj:`x_j` unchanged.
        """
        return x_j