@@ -111,6 +111,12 @@ def __call__(self, target, input):
111 111         return gen_loss + disc_loss
112 112
113 113
    114 + def filter_dict(d, filter_key):
    115 +     return {str(k)[len(filter_key):]: v for k, v in
    116 +             d.items() if
    117 +             str(k).startswith(filter_key)}
    118 +
    119 +
114 120 class Electra(JCIBaseNet):
115 121     NAME = "Electra"
116 122
@@ -151,26 +157,26 @@ def __init__(self, **kwargs):
151 157         self.config = ElectraConfig(**kwargs["config"], output_attentions=True)
152 158         self.word_dropout = nn.Dropout(kwargs["config"].get("word_dropout", 0))
153 159         model_prefix = kwargs.get("load_prefix", None)
154     -         if pretrained_checkpoint:
155     -             with open(pretrained_checkpoint, "rb") as fin:
156     -                 model_dict = torch.load(fin, map_location=self.device)
157     -             if model_prefix:
158     -                 state_dict = {str(k)[len(model_prefix):]: v for k, v in model_dict["state_dict"].items() if str(k).startswith(model_prefix)}
159     -             else:
160     -                 state_dict = model_dict["state_dict"]
161     -             self.electra = ElectraModel.from_pretrained(None, state_dict=state_dict, config=self.config)
162     -         else:
163     -             self.electra = ElectraModel(config=self.config)
164 160
165 161         in_d = self.config.hidden_size
166     -
167 162         self.output = nn.Sequential(
168 163             nn.Dropout(self.config.hidden_dropout_prob),
169 164             nn.Linear(in_d, in_d),
170 165             nn.GELU(),
171 166             nn.Dropout(self.config.hidden_dropout_prob),
172 167             nn.Linear(in_d, self.config.num_labels),
173 168         )
    169 +         if pretrained_checkpoint:
    170 +             with open(pretrained_checkpoint, "rb") as fin:
    171 +                 model_dict = torch.load(fin, map_location=self.device)
    172 +             if model_prefix:
    173 +                 state_dict = filter_dict(model_dict["state_dict"], model_prefix)
    174 +             else:
    175 +                 state_dict = model_dict["state_dict"]
    176 +             self.electra = ElectraModel.from_pretrained(None, state_dict={k: v for (k, v) in state_dict.items() if k.startswith("electra.")}, config=self.config)
    177 +             self.output.load_state_dict(filter_dict(state_dict, "output."))
    178 +         else:
    179 +             self.electra = ElectraModel(config=self.config)
174 180
175 181     def _get_data_for_loss(self, model_output, labels):
176 182         mask = model_output.get("target_mask")