Skip to content

Commit a82fa59

Browse files
committed
streamlit web demo
1 parent 519e9e2 commit a82fa59

File tree

3 files changed

+81
-0
lines changed

3 files changed

+81
-0
lines changed

finetune.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# based on Baichuan2 code
2+
13
import os
24
import math
35
import pathlib

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ transformers
22
torch
33
accelerate
44
deepspeed
5+
streamlit

web_streamlit.py

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# based on Baichuan2 code
2+
3+
import fire
4+
import json
5+
import torch
6+
import streamlit as st
7+
from transformers import AutoModelForCausalLM, AutoTokenizer
8+
from transformers.generation.utils import GenerationConfig
9+
10+
11+
# Basic page chrome for the Streamlit app: tab title and on-page heading.
st.set_page_config(page_title="LLM utils")
st.title("LLM utils")
13+
14+
15+
@st.cache_resource
def init_model(base_model):
    """Load the causal LM and its tokenizer, cached across Streamlit reruns.

    Args:
        base_model: local path or hub id of the model checkpoint.

    Returns:
        (model, tokenizer) tuple ready for inference.
    """
    # fp16 weights, auto device placement; remote code is required for
    # custom model classes (e.g. Baichuan-style checkpoints).
    lm = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    # Attach the checkpoint's own generation defaults.
    lm.generation_config = GenerationConfig.from_pretrained(base_model)
    tok = AutoTokenizer.from_pretrained(
        base_model,
        use_fast=False,
        trust_remote_code=True,
    )
    return lm, tok
32+
33+
34+
def clear_chat_history():
    """Reset the conversation shown in the UI.

    Button callback: drops the stored message list from session state.
    Uses pop() with a default instead of `del` so the callback cannot
    raise if the key was never created (e.g. a rerun after an error).
    """
    st.session_state.pop("messages", None)
36+
37+
38+
def init_chat_history():
    """Replay any stored chat turns and return the message list.

    On the first run an empty history is created in session state;
    on later reruns every stored turn is re-rendered with the avatar
    matching its role.

    Returns:
        The (possibly empty) list of message dicts in session state.
    """
    if "messages" not in st.session_state:
        st.session_state.messages = []
    else:
        for msg in st.session_state.messages:
            icon = '🧑' if msg["role"] == "user" else '🤖'
            with st.chat_message(msg["role"], avatar=icon):
                st.markdown(msg["content"])

    return st.session_state.messages
48+
49+
50+
def main(
    base_model: str = ""
):
    """Run the Streamlit chat UI for a locally loaded causal LM.

    Args:
        base_model: path or hub id of the model to load; required.

    Raises:
        ValueError: if no base_model was supplied.
    """
    # Explicit raise instead of `assert` so validation survives `python -O`.
    if not base_model:
        raise ValueError("Please specify a --base_model")

    model, tokenizer = init_model(base_model)
    messages = init_chat_history()

    if prompt := st.chat_input("Shift + Enter 换行, Enter 发送"):
        with st.chat_message("user", avatar='🧑'):
            st.markdown(prompt)
        messages.append({"role": "user", "content": prompt})
        print(f"[user] {prompt}", flush=True)
        with st.chat_message("assistant", avatar='🤖'):
            placeholder = st.empty()
            # Guard: if the stream yields nothing, `response` must still
            # exist for the append below.
            response = ""
            # NOTE(review): assumes model.chat yields cumulative partial
            # responses (Baichuan-style streaming API) — confirm for other models.
            for response in model.chat(tokenizer, messages, stream=True):
                placeholder.markdown(response)
                if torch.backends.mps.is_available():
                    # Keep Apple-Silicon memory pressure down while streaming.
                    torch.mps.empty_cache()
        messages.append({"role": "assistant", "content": response})
        print(json.dumps(messages, ensure_ascii=False), flush=True)

    st.button("清空对话", on_click=clear_chat_history)
76+
77+
if __name__ == "__main__":
    # Expose `main` as a CLI (python web_streamlit.py --base_model=...) via python-fire.
    fire.Fire(main)

0 commit comments

Comments
 (0)