Skip to content

Commit 1a9cc28

Browse files
committed
Fix electra loader
1 parent 7cdfa64 commit 1a9cc28

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

chebai/models/electra.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ def __call__(self, target, input):
111111
return gen_loss + disc_loss
112112

113113

114+
def filter_dict(d, filter_key):
    """Select entries of *d* whose (stringified) key starts with *filter_key*.

    Returns a new dict in which the matching keys have the ``filter_key``
    prefix stripped; values are passed through unchanged. Keys are compared
    via ``str(k)``, so non-string keys are supported.
    """
    prefix_len = len(filter_key)
    result = {}
    for key, value in d.items():
        key_str = str(key)
        # keep only keys under the requested prefix, dropping the prefix itself
        if key_str.startswith(filter_key):
            result[key_str[prefix_len:]] = value
    return result
118+
119+
114120
class Electra(JCIBaseNet):
115121
NAME = "Electra"
116122

@@ -151,26 +157,26 @@ def __init__(self, **kwargs):
151157
self.config = ElectraConfig(**kwargs["config"], output_attentions=True)
152158
self.word_dropout = nn.Dropout(kwargs["config"].get("word_dropout", 0))
153159
model_prefix = kwargs.get("load_prefix", None)
154-
if pretrained_checkpoint:
155-
with open(pretrained_checkpoint, "rb") as fin:
156-
model_dict = torch.load(fin,map_location=self.device)
157-
if model_prefix:
158-
state_dict = {str(k)[len(model_prefix):]:v for k,v in model_dict["state_dict"].items() if str(k).startswith(model_prefix)}
159-
else:
160-
state_dict = model_dict["state_dict"]
161-
self.electra = ElectraModel.from_pretrained(None, state_dict=state_dict, config=self.config)
162-
else:
163-
self.electra = ElectraModel(config=self.config)
164160

165161
in_d = self.config.hidden_size
166-
167162
self.output = nn.Sequential(
168163
nn.Dropout(self.config.hidden_dropout_prob),
169164
nn.Linear(in_d, in_d),
170165
nn.GELU(),
171166
nn.Dropout(self.config.hidden_dropout_prob),
172167
nn.Linear(in_d, self.config.num_labels),
173168
)
169+
if pretrained_checkpoint:
170+
with open(pretrained_checkpoint, "rb") as fin:
171+
model_dict = torch.load(fin,map_location=self.device)
172+
if model_prefix:
173+
state_dict = filter_dict(model_dict["state_dict"], model_prefix)
174+
else:
175+
state_dict = model_dict["state_dict"]
176+
self.electra = ElectraModel.from_pretrained(None, state_dict={k:v for (k,v) in state_dict.items() if k.startswith("electra.")}, config=self.config)
177+
self.output.load_state_dict(filter_dict(state_dict,"output."))
178+
else:
179+
self.electra = ElectraModel(config=self.config)
174180

175181
def _get_data_for_loss(self, model_output, labels):
176182
mask = model_output.get("target_mask")

0 commit comments

Comments
 (0)