# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""A simple web interactive chat demo based on Gradio."""

import gc
from argparse import ArgumentParser
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

DEFAULT_CKPT_PATH = '../Qwen1.5-7B-Chat'
def _get_args():
    parser = ArgumentParser()
    parser.add_argument("-m", "--base_model", type=str, default=DEFAULT_CKPT_PATH,
                        help="Checkpoint name or path, default to %(default)r")
    parser.add_argument("--cpu_only", action="store_true", help="Run demo with CPU only")
    parser.add_argument("--server_port", type=int, default=8000,
                        help="Demo server port.")
    parser.add_argument("--server_name", type=str, default="0.0.0.0",
                        help="Demo server name.")

    args = parser.parse_args()
    return args
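
# Example invocations (flags defined above; --base_model takes any local
# checkpoint directory or Hugging Face hub name, not just the default):
#   python web_multiturn.py -m ../Qwen1.5-7B-Chat --server_port 8000
#   python web_multiturn.py --cpu_only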
def _load_model_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model, resume_download=True, trust_remote_code=True,
    )

    # Place the model on available GPU(s); --cpu_only forces everything onto CPU.
    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype="auto",
        device_map=device_map,
        resume_download=True,
        trust_remote_code=True,
    ).eval()
    model.generation_config.max_new_tokens = 2048    # For chat.

    return model, tokenizer
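
# Generation runs in a background thread while TextIteratorStreamer yields
# decoded text pieces on the calling thread, so the Gradio UI can update
# incrementally instead of waiting for the full completion.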
def _chat_stream(model, tokenizer, query, system, history):
    # Fall back to a default system prompt when the System textbox is empty.
    system_prompt = system if system else 'You are a helpful assistant.'
    conversation = [
        {'role': 'system', 'content': system_prompt},
    ]
    for query_h, response_h in history:
        conversation.append({'role': 'user', 'content': query_h})
        conversation.append({'role': 'assistant', 'content': response_h})
    conversation.append({'role': 'user', 'content': query})

    inputs = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors='pt',
    )
    inputs = inputs.to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer, skip_prompt=True, timeout=60.0, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs,
        streamer=streamer,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    for new_text in streamer:
        yield new_text
def _gc():
    # Release Python-side references, then free cached CUDA memory so a
    # cleared conversation actually returns VRAM to the allocator.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
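
# UI wiring: the callbacks below stream model output into the Chatbot widget;
# task_history (a gr.State holding a list of (query, response) tuples) is the
# canonical multi-turn history that gets fed back into _chat_stream.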
def _launch_demo(args, model, tokenizer):

    def predict(_query, _system, _chatbot, _task_history):
        print(f"System: {_system}")
        print(f"User: {_query}")
        _chatbot.append((_query, ""))
        full_response = ""
        response = ""
        for new_text in _chat_stream(model, tokenizer, _query, _system, history=_task_history):
            response += new_text
            _chatbot[-1] = (_query, response)
            yield _chatbot
        full_response = response

        print(f"History: {_task_history}")
        _task_history.append((_query, full_response))
        print(f"Chatbot: {full_response}")

    def regenerate(_system, _chatbot, _task_history):
        if not _task_history:
            yield _chatbot
            return
        # Drop the last completed turn and replay its query.
        item = _task_history.pop(-1)
        _chatbot.pop(-1)
        yield from predict(item[0], _system, _chatbot, _task_history)

    def reset_user_input():
        return gr.update(value="")

    def reset_state(_chatbot, _task_history):
        _task_history.clear()
        _chatbot.clear()
        _gc()
        return _chatbot

    with gr.Blocks() as demo:
        gr.Markdown("""<center><font size=8>LLMUtils Chatbot</center>""")

        chatbot = gr.Chatbot(elem_classes="control-height", height=500)
        system = gr.Textbox(lines=2, label='System')
        query = gr.Textbox(lines=1, label='Input')
        task_history = gr.State([])

        with gr.Row():
            empty_btn = gr.Button("🧹 Clear History")
            submit_btn = gr.Button("🚀 Submit")
            regen_btn = gr.Button("🤔️ Regenerate")

        submit_btn.click(predict, [query, system, chatbot, task_history], [chatbot], show_progress=True)
        submit_btn.click(reset_user_input, [], [query])
        empty_btn.click(reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True)
        regen_btn.click(regenerate, [system, chatbot, task_history], [chatbot], show_progress=True)

    demo.queue().launch(
        server_port=args.server_port,
        server_name=args.server_name,
    )
def main():
    args = _get_args()
    model, tokenizer = _load_model_tokenizer(args)
    _launch_demo(args, model, tokenizer)


if __name__ == '__main__':
    main()