generate.py
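"""Generate text from a trained TransformerLM checkpoint.

Loads the tokenizer and model weights named on the command line, samples a
batch of completions for the given prompt using temperature and top-p
sampling, and prints each decoded completion.
"""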
import torch
from tokenizer import Tokenizer
from model import TransformerLM
from utils import load, ModelConfig
from config_args import generate_args


def main():
    # Parse generation settings from the command line.
    file_args = generate_args()
    num_tokens_to_generate = file_args.num_tokens_to_generate
    temperature = file_args.temperature
    topp = file_args.top_p
    save_file_name = file_args.save_file_name
    backend = file_args.backend
    compile_model = file_args.compile_model
    tokenizer_file_name = file_args.tokenizer_file_name
    load_mistral_tokenizer = file_args.load_mistral_tokenizer

    # Load either the Mistral tokenizer or the project's own tokenizer file.
    tokenizer = Tokenizer()
    if load_mistral_tokenizer:
        tokenizer.load_mistral_tokenizer(tokenizer_file_name)
    else:
        tokenizer.load(tokenizer_file_name)

    # Encode the prompt and replicate it 10 times to sample a batch of completions.
    input_tokens = [tokenizer.encode(file_args.input_text) for _ in range(10)]

    # Pick the device, then restore the checkpoint and its model configuration.
    dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    checkpoint = load(save_file_name, map_location=dev)
    model_config: ModelConfig = load(
        save_file_name, "model_config.pt", False, map_location=dev.type
    )
    model_config.flash = False  # disable the flash attention path for generation

    # Build the model, optionally compile it, and load the trained weights.
    model = TransformerLM(model_config, tokenizer.vocab_size).to(dev)
    model: TransformerLM = torch.compile(
        model, backend=backend, disable=not compile_model
    )
    model.load_state_dict(checkpoint["model"])
    model.eval()

    # Stop generation at the end-of-sequence or unknown token.
    stop_tokens = [
        tokenizer.special_token["</s>"],
        tokenizer.special_token["<unk>"],
    ]
    generated_tokens = model.generate(
        dev,
        tokenizer.special_token["<pad>"],
        stop_tokens,
        input_tokens,
        num_tokens_to_generate,
        temperature,
        topp,
    )

    # Decode and print each sampled completion.
    for tokens in generated_tokens:
        print(tokenizer.decode(tokens))
        print("-" * 50)


if __name__ == "__main__":
    main()