from graph4nlp.pytorch.modules.graph_construction.node_embedding_based_graph_construction import (
    NodeEmbeddingBasedGraphConstruction,
)
-from graph4nlp.pytorch.modules.graph_construction.node_embedding_based_refined_graph_construction import (
+from graph4nlp.pytorch.modules.graph_construction.node_embedding_based_refined_graph_construction import (  # noqa
    NodeEmbeddingBasedRefinedGraphConstruction,
)
from graph4nlp.pytorch.modules.graph_embedding.gat import GAT


class Graph2XBase(nn.Module):
-    def __init__(self, vocab_model, embedding_style, graph_type, emb_input_size, emb_hidden_size,
-
-                 gnn, gnn_num_layers, gnn_direction_option, gnn_input_size, gnn_hidden_size, gnn_output_size,
-
-                 # dropout
-                 emb_word_dropout=0.0, emb_rnn_dropout=0.0, emb_fix_word_emb=False, emb_fix_bert_emb=False,
-
-                 gnn_feats_dropout=0.0, gnn_attn_dropout=0.0,
-                 **kwargs):
+    def __init__(
+        self,
+        vocab_model,
+        embedding_style,
+        graph_type,
+        emb_input_size,
+        emb_hidden_size,
+        gnn,
+        gnn_num_layers,
+        gnn_direction_option,
+        gnn_input_size,
+        gnn_hidden_size,
+        gnn_output_size,
+        # dropout
+        emb_word_dropout=0.0,
+        emb_rnn_dropout=0.0,
+        emb_fix_word_emb=False,
+        emb_fix_bert_emb=False,
+        gnn_feats_dropout=0.0,
+        gnn_attn_dropout=0.0,
+        **kwargs
+    ):
        super(Graph2XBase, self).__init__()
        self.vocab_model = vocab_model
-        self._build_embedding_encoder(graph_type=graph_type, embedding_style=embedding_style, vocab_model=vocab_model,
-                                      emb_input_size=emb_input_size, emb_hidden_size=emb_hidden_size, emb_word_dropout=emb_word_dropout,
-                                      emb_rnn_dropout=emb_rnn_dropout, emb_fix_word_emb=emb_fix_word_emb,
-                                      emb_fix_bert_emb=emb_fix_bert_emb, **kwargs)
+        self._build_embedding_encoder(
+            graph_type=graph_type,
+            embedding_style=embedding_style,
+            vocab_model=vocab_model,
+            emb_input_size=emb_input_size,
+            emb_hidden_size=emb_hidden_size,
+            emb_word_dropout=emb_word_dropout,
+            emb_rnn_dropout=emb_rnn_dropout,
+            emb_fix_word_emb=emb_fix_word_emb,
+            emb_fix_bert_emb=emb_fix_bert_emb,
+            **kwargs
+        )

-        self._build_gnn_encoder(gnn=gnn, num_layers=gnn_num_layers,
-                                input_size=gnn_input_size, hidden_size=gnn_hidden_size, output_size=gnn_output_size,
-                                direction_option=gnn_direction_option, feats_dropout=gnn_feats_dropout,
-                                gnn_attn_dropout=gnn_attn_dropout, **kwargs)
+        self._build_gnn_encoder(
+            gnn=gnn,
+            num_layers=gnn_num_layers,
+            input_size=gnn_input_size,
+            hidden_size=gnn_hidden_size,
+            output_size=gnn_output_size,
+            direction_option=gnn_direction_option,
+            feats_dropout=gnn_feats_dropout,
+            gnn_attn_dropout=gnn_attn_dropout,
+            **kwargs
+        )

-    def _build_embedding_encoder(self, graph_type, embedding_style, vocab_model,
-                                 emb_input_size, emb_hidden_size, emb_rnn_dropout,
-                                 emb_word_dropout,
-                                 # dynamic parameters
-                                 emb_sim_metric_type=None, emb_num_heads=None, emb_top_k_neigh=None,
-                                 emb_epsilon_neigh=None,
-                                 emb_smoothness_ratio=None, emb_connectivity_ratio=None, emb_sparsity_ratio=None,
-                                 emb_alpha_fusion=None,
-                                 emb_fix_word_emb=False, emb_fix_bert_emb=False, **kwargs):
+    def _build_embedding_encoder(
+        self,
+        graph_type,
+        embedding_style,
+        vocab_model,
+        emb_input_size,
+        emb_hidden_size,
+        emb_rnn_dropout,
+        emb_word_dropout,
+        # dynamic parameters
+        emb_sim_metric_type=None,
+        emb_num_heads=None,
+        emb_top_k_neigh=None,
+        emb_epsilon_neigh=None,
+        emb_smoothness_ratio=None,
+        emb_connectivity_ratio=None,
+        emb_sparsity_ratio=None,
+        emb_alpha_fusion=None,
+        emb_fix_word_emb=False,
+        emb_fix_bert_emb=False,
+        **kwargs
+    ):

        if not isinstance(graph_type, str):
            raise ValueError("graph_type parameter should be str")

        if graph_type == "dependency":
-            self.graph_topology = DependencyBasedGraphConstruction(embedding_style=embedding_style,
-                                                                   vocab=vocab_model.in_word_vocab,
-                                                                   hidden_size=emb_hidden_size,
-                                                                   word_dropout=emb_word_dropout,
-                                                                   rnn_dropout=emb_rnn_dropout,
-                                                                   fix_word_emb=emb_fix_word_emb,
-                                                                   fix_bert_emb=emb_fix_bert_emb)
+            self.graph_topology = DependencyBasedGraphConstruction(
+                embedding_style=embedding_style,
+                vocab=vocab_model.in_word_vocab,
+                hidden_size=emb_hidden_size,
+                word_dropout=emb_word_dropout,
+                rnn_dropout=emb_rnn_dropout,
+                fix_word_emb=emb_fix_word_emb,
+                fix_bert_emb=emb_fix_bert_emb,
+            )
        elif graph_type == "constituency":
-            self.graph_topology = ConstituencyBasedGraphConstruction(embedding_style=embedding_style,
-                                                                     vocab=vocab_model.in_word_vocab,
-                                                                     hidden_size=emb_hidden_size,
-                                                                     word_dropout=emb_word_dropout,
-                                                                     rnn_dropout=emb_rnn_dropout,
-                                                                     fix_word_emb=emb_fix_word_emb)
+            self.graph_topology = ConstituencyBasedGraphConstruction(
+                embedding_style=embedding_style,
+                vocab=vocab_model.in_word_vocab,
+                hidden_size=emb_hidden_size,
+                word_dropout=emb_word_dropout,
+                rnn_dropout=emb_rnn_dropout,
+                fix_word_emb=emb_fix_word_emb,
+            )
        elif graph_type == "ie":
-            self.graph_topology = IEBasedGraphConstruction(embedding_style=embedding_style,
-                                                           vocab=vocab_model.in_word_vocab,
-                                                           hidden_size=emb_hidden_size,
-                                                           word_dropout=emb_word_dropout,
-                                                           rnn_dropout=emb_rnn_dropout,
-                                                           fix_word_emb=emb_fix_word_emb)
+            self.graph_topology = IEBasedGraphConstruction(
+                embedding_style=embedding_style,
+                vocab=vocab_model.in_word_vocab,
+                hidden_size=emb_hidden_size,
+                word_dropout=emb_word_dropout,
+                rnn_dropout=emb_rnn_dropout,
+                fix_word_emb=emb_fix_word_emb,
+            )
        elif graph_type == "node_emb":
            self.graph_topology = NodeEmbeddingBasedGraphConstruction(
                vocab_model.in_word_vocab,
@@ -98,7 +145,8 @@ def _build_embedding_encoder(self, graph_type, embedding_style, vocab_model,
                fix_word_emb=emb_fix_word_emb,
                fix_bert_emb=emb_fix_bert_emb,
                word_dropout=emb_word_dropout,
-                rnn_dropout=emb_rnn_dropout)
+                rnn_dropout=emb_rnn_dropout,
+            )
        elif graph_type == "node_emb_refined":
            self.graph_topology = NodeEmbeddingBasedRefinedGraphConstruction(
                vocab_model.in_word_vocab,
@@ -116,36 +164,90 @@ def _build_embedding_encoder(self, graph_type, embedding_style, vocab_model,
                fix_word_emb=emb_fix_word_emb,
                fix_bert_emb=emb_fix_bert_emb,
                word_dropout=emb_word_dropout,
-                rnn_dropout=emb_rnn_dropout)
+                rnn_dropout=emb_rnn_dropout,
+            )
        else:
            raise NotImplementedError()
-        self.enc_word_emb = self.graph_topology.embedding_layer.word_emb_layers['w2v'] if 'w2v' in self.graph_topology.embedding_layer.word_emb_layers else None
+        self.enc_word_emb = (
+            self.graph_topology.embedding_layer.word_emb_layers["w2v"]
+            if "w2v" in self.graph_topology.embedding_layer.word_emb_layers
+            else None
+        )

-    def _build_gnn_encoder(self, gnn, num_layers, input_size, hidden_size, output_size, direction_option, feats_dropout,
-                           gnn_heads=None, gnn_residual=True, gnn_attn_dropout=0.0, gnn_activation=F.relu,  # gat
-                           gnn_bias=True, gnn_allow_zero_in_degree=True, gnn_norm='both', gnn_weight=True,
-                           gnn_use_edge_weight=False, gnn_gcn_norm='both',  # gcn
-                           gnn_n_etypes=1,  # ggnn
-                           gnn_aggregator_type="lstm",  # graphsage
-                           **kwargs):
+    def _build_gnn_encoder(
+        self,
+        gnn,
+        num_layers,
+        input_size,
+        hidden_size,
+        output_size,
+        direction_option,
+        feats_dropout,
+        gnn_heads=None,
+        gnn_residual=True,
+        gnn_attn_dropout=0.0,
+        gnn_activation=F.relu,  # gat
+        gnn_bias=True,
+        gnn_allow_zero_in_degree=True,
+        gnn_norm="both",
+        gnn_weight=True,
+        gnn_use_edge_weight=False,
+        gnn_gcn_norm="both",  # gcn
+        gnn_n_etypes=1,  # ggnn
+        gnn_aggregator_type="lstm",  # graphsage
+        **kwargs
+    ):
        if gnn == "gat":
-            self.gnn_encoder = GAT(num_layers, input_size, hidden_size, output_size, gnn_heads,
-                                   direction_option=direction_option,
-                                   feat_drop=feats_dropout, attn_drop=gnn_attn_dropout, activation=gnn_activation,
-                                   residual=gnn_residual, allow_zero_in_degree=gnn_allow_zero_in_degree)
+            self.gnn_encoder = GAT(
+                num_layers,
+                input_size,
+                hidden_size,
+                output_size,
+                gnn_heads,
+                direction_option=direction_option,
+                feat_drop=feats_dropout,
+                attn_drop=gnn_attn_dropout,
+                activation=gnn_activation,
+                residual=gnn_residual,
+                allow_zero_in_degree=gnn_allow_zero_in_degree,
+            )
        elif gnn == "ggnn":
-            self.gnn_encoder = GGNN(num_layers, input_size, hidden_size, output_size, direction_option=direction_option,
-                                    use_edge_weight=gnn_use_edge_weight, feat_drop=feats_dropout, n_etypes=gnn_n_etypes)
+            self.gnn_encoder = GGNN(
+                num_layers,
+                input_size,
+                hidden_size,
+                output_size,
+                direction_option=direction_option,
+                use_edge_weight=gnn_use_edge_weight,
+                feat_drop=feats_dropout,
+                n_etypes=gnn_n_etypes,
+            )
        elif gnn == "graphsage":
-            self.gnn_encoder = GraphSAGE(num_layers, input_size, hidden_size, output_size,
-                                         aggregator_type=gnn_aggregator_type,
-                                         direction_option=direction_option, feat_drop=feats_dropout,
-                                         activation=gnn_activation, bias=gnn_bias, use_edge_weight=gnn_use_edge_weight)
+            self.gnn_encoder = GraphSAGE(
+                num_layers,
+                input_size,
+                hidden_size,
+                output_size,
+                aggregator_type=gnn_aggregator_type,
+                direction_option=direction_option,
+                feat_drop=feats_dropout,
+                activation=gnn_activation,
+                bias=gnn_bias,
+                use_edge_weight=gnn_use_edge_weight,
+            )
        elif gnn == "gcn":
-            self.gnn_encoder = GCN(num_layers, input_size, hidden_size, output_size,
-                                   direction_option=direction_option, weight=gnn_weight, gcn_norm=gnn_gcn_norm,
-                                   allow_zero_in_degree=gnn_allow_zero_in_degree, activation=gnn_activation,
-                                   use_edge_weight=gnn_use_edge_weight)
+            self.gnn_encoder = GCN(
+                num_layers,
+                input_size,
+                hidden_size,
+                output_size,
+                direction_option=direction_option,
+                weight=gnn_weight,
+                gcn_norm=gnn_gcn_norm,
+                allow_zero_in_degree=gnn_allow_zero_in_degree,
+                activation=gnn_activation,
+                use_edge_weight=gnn_use_edge_weight,
+            )
        else:
            raise NotImplementedError()
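
Usage sketch (not part of this diff): the snippet below illustrates how a concrete subclass might call the reformatted Graph2XBase.__init__ with the keyword arguments shown above. Only the parameter names come from the diff; the import path, the ToyGraph2Seq name, and all concrete values (sizes, "bi_sep", gnn_heads=[4, 1]) are illustrative assumptions.

# Hypothetical sketch: keyword names taken from Graph2XBase.__init__ above;
# everything else is a placeholder, not part of this change.
from graph4nlp.pytorch.models.base import Graph2XBase  # import path assumed


class ToyGraph2Seq(Graph2XBase):
    def __init__(self, vocab_model, embedding_style, **kwargs):
        super(ToyGraph2Seq, self).__init__(
            vocab_model=vocab_model,
            embedding_style=embedding_style,
            graph_type="dependency",        # or "constituency", "ie", "node_emb", "node_emb_refined"
            emb_input_size=300,             # placeholder size
            emb_hidden_size=300,
            gnn="gat",                      # or "ggnn", "graphsage", "gcn"
            gnn_num_layers=2,
            gnn_direction_option="bi_sep",  # assumed option name
            gnn_input_size=300,
            gnn_hidden_size=300,
            gnn_output_size=300,
            emb_word_dropout=0.4,
            emb_rnn_dropout=0.1,
            gnn_feats_dropout=0.2,
            gnn_attn_dropout=0.2,
            gnn_heads=[4, 1],               # GAT-only kwarg, forwarded to _build_gnn_encoder via **kwargs
            **kwargs
        )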