Skip to content

Commit cb2b91f

Browse files
committed
fix
1 parent e687a89 commit cb2b91f

File tree

2 files changed

+173
-70
lines changed

2 files changed

+173
-70
lines changed

.flake8

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ select = B,C,E,F,W,T4,B9
66
exclude = build,.git,.vector_cache,.github,dev,distribution,docs,__pycache__,imgs,
77
*.pyc,*.txt,.pkl,*.pt,*.tar.gz,*.yml,
88
examples/pytorch/kg_completion/spodernet,
9+
**/spodernet/**,
910
setup.py
1011
per-file-ignores =
1112
**/__init__.py:F401,F403,E402

graph4nlp/pytorch/models/base.py

+172-70
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from graph4nlp.pytorch.modules.graph_construction.node_embedding_based_graph_construction import (
1717
NodeEmbeddingBasedGraphConstruction,
1818
)
19-
from graph4nlp.pytorch.modules.graph_construction.node_embedding_based_refined_graph_construction import (
19+
from graph4nlp.pytorch.modules.graph_construction.node_embedding_based_refined_graph_construction import ( # noqa
2020
NodeEmbeddingBasedRefinedGraphConstruction,
2121
)
2222
from graph4nlp.pytorch.modules.graph_embedding.gat import GAT
@@ -26,62 +26,109 @@
2626

2727

2828
class Graph2XBase(nn.Module):
29-
def __init__(self, vocab_model, embedding_style, graph_type, emb_input_size, emb_hidden_size,
30-
31-
gnn, gnn_num_layers, gnn_direction_option, gnn_input_size, gnn_hidden_size, gnn_output_size,
32-
33-
# dropout
34-
emb_word_dropout=0.0, emb_rnn_dropout=0.0, emb_fix_word_emb=False, emb_fix_bert_emb=False,
35-
36-
gnn_feats_dropout=0.0, gnn_attn_dropout=0.0,
37-
**kwargs):
29+
def __init__(
30+
self,
31+
vocab_model,
32+
embedding_style,
33+
graph_type,
34+
emb_input_size,
35+
emb_hidden_size,
36+
gnn,
37+
gnn_num_layers,
38+
gnn_direction_option,
39+
gnn_input_size,
40+
gnn_hidden_size,
41+
gnn_output_size,
42+
# dropout
43+
emb_word_dropout=0.0,
44+
emb_rnn_dropout=0.0,
45+
emb_fix_word_emb=False,
46+
emb_fix_bert_emb=False,
47+
gnn_feats_dropout=0.0,
48+
gnn_attn_dropout=0.0,
49+
**kwargs
50+
):
3851
super(Graph2XBase, self).__init__()
3952
self.vocab_model = vocab_model
40-
self._build_embedding_encoder(graph_type=graph_type, embedding_style=embedding_style, vocab_model=vocab_model,
41-
emb_input_size=emb_input_size, emb_hidden_size=emb_hidden_size, emb_word_dropout=emb_word_dropout,
42-
emb_rnn_dropout=emb_rnn_dropout, emb_fix_word_emb=emb_fix_word_emb,
43-
emb_fix_bert_emb=emb_fix_bert_emb, **kwargs)
53+
self._build_embedding_encoder(
54+
graph_type=graph_type,
55+
embedding_style=embedding_style,
56+
vocab_model=vocab_model,
57+
emb_input_size=emb_input_size,
58+
emb_hidden_size=emb_hidden_size,
59+
emb_word_dropout=emb_word_dropout,
60+
emb_rnn_dropout=emb_rnn_dropout,
61+
emb_fix_word_emb=emb_fix_word_emb,
62+
emb_fix_bert_emb=emb_fix_bert_emb,
63+
**kwargs
64+
)
4465

45-
self._build_gnn_encoder(gnn=gnn, num_layers=gnn_num_layers,
46-
input_size=gnn_input_size, hidden_size=gnn_hidden_size, output_size=gnn_output_size,
47-
direction_option=gnn_direction_option, feats_dropout=gnn_feats_dropout,
48-
gnn_attn_dropout=gnn_attn_dropout, **kwargs)
66+
self._build_gnn_encoder(
67+
gnn=gnn,
68+
num_layers=gnn_num_layers,
69+
input_size=gnn_input_size,
70+
hidden_size=gnn_hidden_size,
71+
output_size=gnn_output_size,
72+
direction_option=gnn_direction_option,
73+
feats_dropout=gnn_feats_dropout,
74+
gnn_attn_dropout=gnn_attn_dropout,
75+
**kwargs
76+
)
4977

50-
def _build_embedding_encoder(self, graph_type, embedding_style, vocab_model,
51-
emb_input_size, emb_hidden_size, emb_rnn_dropout,
52-
emb_word_dropout,
53-
# dynamic parameters
54-
emb_sim_metric_type=None, emb_num_heads=None, emb_top_k_neigh=None,
55-
emb_epsilon_neigh=None,
56-
emb_smoothness_ratio=None, emb_connectivity_ratio=None, emb_sparsity_ratio=None,
57-
emb_alpha_fusion=None,
58-
emb_fix_word_emb=False, emb_fix_bert_emb=False, **kwargs):
78+
def _build_embedding_encoder(
79+
self,
80+
graph_type,
81+
embedding_style,
82+
vocab_model,
83+
emb_input_size,
84+
emb_hidden_size,
85+
emb_rnn_dropout,
86+
emb_word_dropout,
87+
# dynamic parameters
88+
emb_sim_metric_type=None,
89+
emb_num_heads=None,
90+
emb_top_k_neigh=None,
91+
emb_epsilon_neigh=None,
92+
emb_smoothness_ratio=None,
93+
emb_connectivity_ratio=None,
94+
emb_sparsity_ratio=None,
95+
emb_alpha_fusion=None,
96+
emb_fix_word_emb=False,
97+
emb_fix_bert_emb=False,
98+
**kwargs
99+
):
59100

60101
if not isinstance(graph_type, str):
61102
raise ValueError("graph_type parameter should be str")
62103

63104
if graph_type == "dependency":
64-
self.graph_topology = DependencyBasedGraphConstruction(embedding_style=embedding_style,
65-
vocab=vocab_model.in_word_vocab,
66-
hidden_size=emb_hidden_size,
67-
word_dropout=emb_word_dropout,
68-
rnn_dropout=emb_rnn_dropout,
69-
fix_word_emb=emb_fix_word_emb,
70-
fix_bert_emb=emb_fix_bert_emb)
105+
self.graph_topology = DependencyBasedGraphConstruction(
106+
embedding_style=embedding_style,
107+
vocab=vocab_model.in_word_vocab,
108+
hidden_size=emb_hidden_size,
109+
word_dropout=emb_word_dropout,
110+
rnn_dropout=emb_rnn_dropout,
111+
fix_word_emb=emb_fix_word_emb,
112+
fix_bert_emb=emb_fix_bert_emb,
113+
)
71114
elif graph_type == "constituency":
72-
self.graph_topology = ConstituencyBasedGraphConstruction(embedding_style=embedding_style,
73-
vocab=vocab_model.in_word_vocab,
74-
hidden_size=emb_hidden_size,
75-
word_dropout=emb_word_dropout,
76-
rnn_dropout=emb_rnn_dropout,
77-
fix_word_emb=emb_fix_word_emb)
115+
self.graph_topology = ConstituencyBasedGraphConstruction(
116+
embedding_style=embedding_style,
117+
vocab=vocab_model.in_word_vocab,
118+
hidden_size=emb_hidden_size,
119+
word_dropout=emb_word_dropout,
120+
rnn_dropout=emb_rnn_dropout,
121+
fix_word_emb=emb_fix_word_emb,
122+
)
78123
elif graph_type == "ie":
79-
self.graph_topology = IEBasedGraphConstruction(embedding_style=embedding_style,
80-
vocab=vocab_model.in_word_vocab,
81-
hidden_size=emb_hidden_size,
82-
word_dropout=emb_word_dropout,
83-
rnn_dropout=emb_rnn_dropout,
84-
fix_word_emb=emb_fix_word_emb)
124+
self.graph_topology = IEBasedGraphConstruction(
125+
embedding_style=embedding_style,
126+
vocab=vocab_model.in_word_vocab,
127+
hidden_size=emb_hidden_size,
128+
word_dropout=emb_word_dropout,
129+
rnn_dropout=emb_rnn_dropout,
130+
fix_word_emb=emb_fix_word_emb,
131+
)
85132
elif graph_type == "node_emb":
86133
self.graph_topology = NodeEmbeddingBasedGraphConstruction(
87134
vocab_model.in_word_vocab,
@@ -98,7 +145,8 @@ def _build_embedding_encoder(self, graph_type, embedding_style, vocab_model,
98145
fix_word_emb=emb_fix_word_emb,
99146
fix_bert_emb=emb_fix_bert_emb,
100147
word_dropout=emb_word_dropout,
101-
rnn_dropout=emb_rnn_dropout)
148+
rnn_dropout=emb_rnn_dropout,
149+
)
102150
elif graph_type == "node_emb_refined":
103151
self.graph_topology = NodeEmbeddingBasedRefinedGraphConstruction(
104152
vocab_model.in_word_vocab,
@@ -116,36 +164,90 @@ def _build_embedding_encoder(self, graph_type, embedding_style, vocab_model,
116164
fix_word_emb=emb_fix_word_emb,
117165
fix_bert_emb=emb_fix_bert_emb,
118166
word_dropout=emb_word_dropout,
119-
rnn_dropout=emb_rnn_dropout)
167+
rnn_dropout=emb_rnn_dropout,
168+
)
120169
else:
121170
raise NotImplementedError()
122-
self.enc_word_emb = self.graph_topology.embedding_layer.word_emb_layers['w2v'] if 'w2v' in self.graph_topology.embedding_layer.word_emb_layers else None
171+
self.enc_word_emb = (
172+
self.graph_topology.embedding_layer.word_emb_layers["w2v"]
173+
if "w2v" in self.graph_topology.embedding_layer.word_emb_layers
174+
else None
175+
)
123176

124-
def _build_gnn_encoder(self, gnn, num_layers, input_size, hidden_size, output_size, direction_option, feats_dropout,
125-
gnn_heads=None, gnn_residual=True, gnn_attn_dropout=0.0, gnn_activation=F.relu, # gat
126-
gnn_bias=True, gnn_allow_zero_in_degree=True, gnn_norm='both', gnn_weight=True,
127-
gnn_use_edge_weight=False, gnn_gcn_norm='both', # gcn
128-
gnn_n_etypes=1, # ggnn
129-
gnn_aggregator_type="lstm", # graphsage
130-
**kwargs):
177+
def _build_gnn_encoder(
178+
self,
179+
gnn,
180+
num_layers,
181+
input_size,
182+
hidden_size,
183+
output_size,
184+
direction_option,
185+
feats_dropout,
186+
gnn_heads=None,
187+
gnn_residual=True,
188+
gnn_attn_dropout=0.0,
189+
gnn_activation=F.relu, # gat
190+
gnn_bias=True,
191+
gnn_allow_zero_in_degree=True,
192+
gnn_norm="both",
193+
gnn_weight=True,
194+
gnn_use_edge_weight=False,
195+
gnn_gcn_norm="both", # gcn
196+
gnn_n_etypes=1, # ggnn
197+
gnn_aggregator_type="lstm", # graphsage
198+
**kwargs
199+
):
131200
if gnn == "gat":
132-
self.gnn_encoder = GAT(num_layers, input_size, hidden_size, output_size, gnn_heads,
133-
direction_option=direction_option,
134-
feat_drop=feats_dropout, attn_drop=gnn_attn_dropout, activation=gnn_activation,
135-
residual=gnn_residual, allow_zero_in_degree=gnn_allow_zero_in_degree)
201+
self.gnn_encoder = GAT(
202+
num_layers,
203+
input_size,
204+
hidden_size,
205+
output_size,
206+
gnn_heads,
207+
direction_option=direction_option,
208+
feat_drop=feats_dropout,
209+
attn_drop=gnn_attn_dropout,
210+
activation=gnn_activation,
211+
residual=gnn_residual,
212+
allow_zero_in_degree=gnn_allow_zero_in_degree,
213+
)
136214
elif gnn == "ggnn":
137-
self.gnn_encoder = GGNN(num_layers, input_size, hidden_size, output_size, direction_option=direction_option,
138-
use_edge_weight=gnn_use_edge_weight, feat_drop=feats_dropout, n_etypes=gnn_n_etypes)
215+
self.gnn_encoder = GGNN(
216+
num_layers,
217+
input_size,
218+
hidden_size,
219+
output_size,
220+
direction_option=direction_option,
221+
use_edge_weight=gnn_use_edge_weight,
222+
feat_drop=feats_dropout,
223+
n_etypes=gnn_n_etypes,
224+
)
139225
elif gnn == "graphsage":
140-
self.gnn_encoder = GraphSAGE(num_layers, input_size, hidden_size, output_size,
141-
aggregator_type=gnn_aggregator_type,
142-
direction_option=direction_option, feat_drop=feats_dropout,
143-
activation=gnn_activation, bias=gnn_bias, use_edge_weight=gnn_use_edge_weight)
226+
self.gnn_encoder = GraphSAGE(
227+
num_layers,
228+
input_size,
229+
hidden_size,
230+
output_size,
231+
aggregator_type=gnn_aggregator_type,
232+
direction_option=direction_option,
233+
feat_drop=feats_dropout,
234+
activation=gnn_activation,
235+
bias=gnn_bias,
236+
use_edge_weight=gnn_use_edge_weight,
237+
)
144238
elif gnn == "gcn":
145-
self.gnn_encoder = GCN(num_layers, input_size, hidden_size, output_size,
146-
direction_option=direction_option, weight=gnn_weight, gcn_norm=gnn_gcn_norm,
147-
allow_zero_in_degree=gnn_allow_zero_in_degree, activation=gnn_activation,
148-
use_edge_weight=gnn_use_edge_weight)
239+
self.gnn_encoder = GCN(
240+
num_layers,
241+
input_size,
242+
hidden_size,
243+
output_size,
244+
direction_option=direction_option,
245+
weight=gnn_weight,
246+
gcn_norm=gnn_gcn_norm,
247+
allow_zero_in_degree=gnn_allow_zero_in_degree,
248+
activation=gnn_activation,
249+
use_edge_weight=gnn_use_edge_weight,
250+
)
149251
else:
150252
raise NotImplementedError()
151253

0 commit comments

Comments
 (0)