-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathhf_publisher.py
31 lines (23 loc) · 873 Bytes
/
hf_publisher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from engine.lit.lightning_module import TaskTrainer
import argparse
import torch.nn as nn
# Publish the model and tokenizer stored in a TaskTrainer checkpoint to the
# Hugging Face Hub as a private repository.
#
# Usage:
#   python hf_publisher.py --hf_name <repo> --model_ckpt <path> [--mode <mode>]

parser = argparse.ArgumentParser(
    description="Push a checkpoint's model and tokenizer to the Hugging Face Hub."
)
# Both values are used unconditionally below, so report a missing one as a
# clear command-line error instead of passing None downstream.
parser.add_argument(
    "--hf_name",
    required=True,
    help="Name of the Hub repository to push to.",
)
parser.add_argument(
    "--model_ckpt",
    required=True,
    help="Path of the TaskTrainer checkpoint to load.",
)
parser.add_argument(
    "--mode",
    help=(
        '"mlm" publishes task.mlm_model; any other value publishes '
        'task.lm_model; "add_head" additionally attaches an output head '
        "copied from the input embeddings."
    ),
)
args = parser.parse_args()

hf_name = args.hf_name
model_ckpt = args.model_ckpt
mode = args.mode

# The checkpoint is loaded with map_location="cpu".
task_trainer = TaskTrainer.load_from_checkpoint(model_ckpt, map_location="cpu")
tokenizer = task_trainer.task.tokenizer

# Only "mlm" selects the masked-LM model; every other mode (including an
# omitted --mode) falls back to the LM model.
if mode == "mlm":
    model = task_trainer.task.mlm_model
else:
    model = task_trainer.task.lm_model

if mode == "add_head":
    # Attach a bias-free output projection whose weights are an independent
    # copy (clone) of the input embedding matrix. The embedding is fetched
    # once so the head's shape and its weights come from the same tensor.
    # NOTE(review): presumably needed because the trained model has no
    # usable `embed_out` of its own -- confirm against the training task.
    input_embeddings = model.get_input_embeddings()
    vocab_len, hs = input_embeddings.weight.shape
    model.embed_out = nn.Linear(hs, vocab_len, bias=False)
    model.embed_out.weight.data = input_embeddings.weight.data.clone()

# Model and tokenizer go to the same repository; both uploads are private.
model.push_to_hub(hf_name, private=True)
tokenizer.push_to_hub(hf_name, private=True)