
Commit 00dbb6f

Improve the state in gradio web server (lm-sys#1348)
1 parent 7329f94 commit 00dbb6f

15 files changed, +244 −127 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -111,13 +111,13 @@ The following models are tested:
 - [BlinkDL/RWKV-4-Raven](https://huggingface.co/BlinkDL/rwkv-4-raven)
 - [databricks/dolly-v2-12b](https://huggingface.co/databricks/dolly-v2-12b)
 - [FreedomIntelligence/phoenix-inst-chat-7b](https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b)
+- [h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2](https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2)
 - [mosaicml/mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat)
 - [OpenAssistant/oasst-sft-1-pythia-12b](https://huggingface.co/OpenAssistant/oasst-sft-1-pythia-12b)
 - [project-baize/baize-lora-7B](https://huggingface.co/project-baize/baize-lora-7B)
 - [StabilityAI/stablelm-tuned-alpha-7b](https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b)
 - [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
 - [Neutralzz/BiLLa-7B-SFT](https://huggingface.co/Neutralzz/BiLLa-7B-SFT)
-- [h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2](https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2)
 
 Help us [add more](https://github.com/lm-sys/FastChat/blob/main/docs/arena.md#how-to-add-a-new-model).
 

docs/arena.md

Lines changed: 4 additions & 2 deletions
@@ -17,8 +17,10 @@ If you want to see a specific model in the arena, you can follow the steps below
 ```
 
 Some major files you need to modify include
-- Implement a conversation template for the new model at https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py. You can follow existing examples and use `register_conv_template` to add a new one.
-- Implement a model adapter for the new model at https://github.com/lm-sys/FastChat/blob/main/fastchat/model/model_adapter.py. You can follow existing examples and use `register_model_adapter` to add a new one.
+- Implement a conversation template for the new model at [fastchat/conversation.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py). You can follow existing examples and use `register_conv_template` to add a new one.
+- Implement a model adapter for the new model at [fastchat/model/model_adapter.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/model/model_adapter.py). You can follow existing examples and use `register_model_adapter` to add a new one.
+- (Optional) add the model name to the "Supported Models" section in [README.md](https://github.com/lm-sys/FastChat#supported-models) and add more information in [fastchat/model/model_registry.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/model/model_registry.py).
+
 2. After the model is supported, we will try to schedule some computing resources to host the model in the arena.
 However, due to the limited resources we have, we may not be able to serve every model.
 We will select the models based on popularity, quality, diversity, and other factors.
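For reference, the template-registration step reads roughly like this. This is a minimal sketch: the field names follow the `Conversation` dataclass shown in the fastchat/conversation.py diff below, but the template name "my-model", system prompt, roles, and separators are illustrative placeholders, not part of this commit.

```python
from fastchat.conversation import (
    Conversation,
    SeparatorStyle,
    register_conv_template,
)

# Register a template under a hypothetical name "my-model"; the prompt,
# roles, and stop string are placeholders for a real model's values.
register_conv_template(
    Conversation(
        name="my-model",
        system="A chat between a curious user and an assistant.",
        roles=["Human", "Assistant"],
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n",
        stop_str="Human:",
    )
)
```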

docs/openai_api.md

Lines changed: 2 additions & 2 deletions
@@ -102,8 +102,8 @@ curl http://localhost:8000/v1/embeddings \
 }'
 ```
 
-## Tunning
-Runner should answer within 20 seconds. If your model/hardware is slower, you wil get Timeout errors. You can change this timeout through ENV variables : "export WORKER_API_TIMEOUT=<larger timeout in seconds>"
+## Adjusting Timeout
+By default, a timeout error will occur if a model worker does not respond within 20 seconds. If your model/hardware is slower, you can change this timeout through an environment variable: `export FASTCHAT_WORKER_API_TIMEOUT=<larger timeout in seconds>`
 
 ## Todos
 Some features to be implemented:

fastchat/constants.py

Lines changed: 5 additions & 3 deletions
@@ -12,9 +12,11 @@
 LOGDIR = "."
 
 # For the controller and workers(could be overwritten through ENV variables.)
-CONTROLLER_HEART_BEAT_EXPIRATION = int(os.getenv("CONTROLLER_HEART_BEAT_EXPIRATION", 90))
-WORKER_HEART_BEAT_INTERVAL = int(os.getenv("WORKER_HEART_BEAT_INTERVAL", 30))
-WORKER_API_TIMEOUT = int(os.getenv("WORKER_API_TIMEOUT", 20))
+CONTROLLER_HEART_BEAT_EXPIRATION = int(
+    os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90)
+)
+WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 30))
+WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 20))
 
 
 class ErrorCode(IntEnum):
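Note that these constants are evaluated once at import time (module-level `int(os.getenv(...))`), so the renamed `FASTCHAT_*` variables must be set before any fastchat module is imported. A quick sketch:

```python
import os

# Must happen before importing fastchat.constants, since the module reads
# the environment exactly once, at import time.
os.environ["FASTCHAT_WORKER_API_TIMEOUT"] = "120"

from fastchat.constants import WORKER_API_TIMEOUT

print(WORKER_API_TIMEOUT)  # -> 120
```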

fastchat/conversation.py

Lines changed: 4 additions & 14 deletions
@@ -27,7 +27,7 @@ class Conversation:
 
     # The name of this template
     name: str
-    # System prompts
+    # The System prompt
     system: str
     # Two roles
     roles: List[str]
@@ -44,12 +44,6 @@ class Conversation:
     # Stops generation if meeting any token in this list
     stop_token_ids: List[int] = None
 
-    # Used for the state in the gradio servers.
-    # TODO(lmzheng): move this out of this class.
-    conv_id: Any = None
-    skip_next: bool = False
-    model_name: str = None
-
     def get_prompt(self) -> str:
         """Get the prompt for generation."""
         if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
@@ -174,8 +168,6 @@ def copy(self):
             sep2=self.sep2,
             stop_str=self.stop_str,
             stop_token_ids=self.stop_token_ids,
-            conv_id=self.conv_id,
-            model_name=self.model_name,
         )
 
     def dict(self):
@@ -185,8 +177,6 @@ def dict(self):
             "roles": self.roles,
             "messages": self.messages,
             "offset": self.offset,
-            "conv_id": self.conv_id,
-            "model_name": self.model_name,
         }
 
 
@@ -479,8 +469,8 @@ def get_conv_template(name: str) -> Conversation:
         offset=0,
         sep_style=SeparatorStyle.ADD_COLON_SINGLE,
         sep="\n",
-        stop_str="<human>:",
-    )
+        stop_str="<human>",
+    )
 )
 
 # h2oGPT default template
@@ -493,10 +483,10 @@ def get_conv_template(name: str) -> Conversation:
         offset=0,
         sep_style=SeparatorStyle.NO_COLON_SINGLE,
         sep="</s>",
-        stop_str="</s>",
     )
 )
 
+
 if __name__ == "__main__":
     conv = get_conv_template("vicuna_v1.1")
     conv.append_message(conv.roles[0], "Hello!")
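The three fields removed here (`conv_id`, `skip_next`, `model_name`) move into a `State` wrapper in fastchat/serve/gradio_web_server.py, whose diff is not part of this excerpt. A plausible minimal shape, inferred only from how `State` is constructed and used in the hunks below (`State(model_name)`, `state.conv`, `state.skip_next`, `state.to_gradio_chatbot()`):

```python
import uuid

from fastchat.model.model_adapter import get_conversation_template


class State:
    def __init__(self, model_name):
        # The conversation itself, now free of UI bookkeeping.
        self.conv = get_conversation_template(model_name)
        self.conv_id = uuid.uuid4().hex
        self.skip_next = False
        self.model_name = model_name

    def to_gradio_chatbot(self):
        # The arena code calls this on State, so it presumably delegates
        # to the underlying conversation's renderer.
        return self.conv.to_gradio_chatbot()
```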

fastchat/model/model_adapter.py

Lines changed: 2 additions & 2 deletions
@@ -445,7 +445,7 @@ class ClaudeAdapter(BaseAdapter):
     """The model adapter for Claude."""
 
     def match(self, model_path: str):
-        return model_path in ["claude-v1", "claude-instant-v1.1"]
+        return model_path in ["claude-v1", "claude-instant-v1"]
 
     def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         raise NotImplementedError()
@@ -495,7 +495,7 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
     def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("redpajama-incite")
 
-
+
 class H2OGPTAdapter(BaseAdapter):
     """The model adapter for h2oGPT."""
 
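The adapter-registration step from docs/arena.md follows the same `BaseAdapter` pattern visible above. A minimal sketch; the class, the "my-model" path match, and the template name are illustrative, and `load_model` is assumed to fall back to `BaseAdapter`'s default Hugging Face loading:

```python
from fastchat.conversation import Conversation, get_conv_template
from fastchat.model.model_adapter import BaseAdapter, register_model_adapter


class MyModelAdapter(BaseAdapter):
    """A hypothetical adapter for models served from a 'my-model' path."""

    def match(self, model_path: str):
        return "my-model" in model_path

    def get_default_conv_template(self, model_path: str) -> Conversation:
        # Reuse the template registered earlier in fastchat/conversation.py.
        return get_conv_template("my-model")


register_model_adapter(MyModelAdapter)
```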
fastchat/model/model_registry.py

Lines changed: 6 additions & 5 deletions
@@ -1,3 +1,4 @@
+"""Additional information of the models."""
 from collections import namedtuple
 from typing import List
 
@@ -37,7 +38,7 @@ def get_model_info(name: str) -> ModelInfo:
     "Claude by Anthropic",
 )
 register_model_info(
-    ["claude-instant-v1.1"],
+    ["claude-instant-v1"],
     "Claude Instant",
     "https://www.anthropic.com/index/introducing-claude",
     "Claude Instant by Anthropic",
@@ -49,7 +50,7 @@ def get_model_info(name: str) -> ModelInfo:
     "Bard based on the PaLM 2 Chat API by Google",
 )
 register_model_info(
-    ["vicuna-13b"],
+    ["vicuna-13b", "vicuna-7b"],
     "Vicuna",
     "https://lmsys.org/blog/2023-03-30-vicuna/",
     "a chat assistant fine-tuned from LLaMA on user-shared conversations by LMSYS",
@@ -62,7 +63,7 @@ def get_model_info(name: str) -> ModelInfo:
 )
 register_model_info(
     ["oasst-pythia-12b"],
-    "OpenAssistant",
+    "OpenAssistant (oasst)",
     "https://open-assistant.io",
     "an Open Assistant for everyone by LAION",
 )
@@ -124,11 +125,11 @@ def get_model_info(name: str) -> ModelInfo:
     ["billa-7b-sft"],
     "BiLLa-7B-SFT",
     "https://huggingface.co/Neutralzz/BiLLa-7B-SFT",
-    "an instruction-tuned bilingual llama with enhanced reasoning ability by an independent researcher",
+    "an instruction-tuned bilingual LLaMA with enhanced reasoning ability by an independent researcher",
 )
 register_model_info(
     ["h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2"],
     "h2oGPT-GM-7b",
     "https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2",
-    "an instruction-tuned Apache 2.0 licensed llama with enhanced conversational ability by H2O.ai",
+    "an instruction-tuned OpenLLaMA with enhanced conversational ability by H2O.ai",
 )
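The optional third step from docs/arena.md uses the same four-argument pattern shown in this diff. A sketch completing the hypothetical "my-model" example from the earlier sketches; the link and description are placeholders:

```python
from fastchat.model.model_registry import register_model_info

# Hypothetical entry matching the "my-model" template and adapter above.
register_model_info(
    ["my-model"],
    "MyModel",
    "https://example.com/my-model",
    "a hypothetical chat model used here only for illustration",
)
```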

fastchat/serve/api_provider.py

Lines changed: 4 additions & 3 deletions
@@ -70,6 +70,7 @@ def bard_api_stream_iter(state):
     # TODO: we will use the official PaLM 2 API sooner or later,
     # and we will update this function accordingly. So here we just hard code the
     # Bard worker address. It is going to be deprecated anyway.
+    conv = state.conv
 
     # Make requests
     gen_params = {
@@ -81,14 +82,14 @@ def bard_api_stream_iter(state):
     response = requests.post(
         "http://localhost:18900/chat",
         json={
-            "content": state.messages[-2][-1],
-            "state": state.session_state,
+            "content": conv.messages[-2][-1],
+            "state": state.bard_session_state,
         },
         stream=False,
         timeout=WORKER_API_TIMEOUT,
     )
     resp_json = response.json()
-    state.session_state = resp_json["state"]
+    state.bard_session_state = resp_json["state"]
     content = resp_json["content"]
     # The Bard Web API does not support streaming yet. Here we have to simulate
     # the streaming behavior by adding some time.sleep().
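As the closing comment notes, the Bard backend returns a complete response, so streaming is simulated. A minimal sketch of that pattern; the chunk size and delay here are illustrative, not the commit's values:

```python
import time


def simulate_stream(content, chunk_size=8, delay=0.02):
    """Yield progressively longer prefixes so the UI appears to stream."""
    for i in range(chunk_size, len(content) + chunk_size, chunk_size):
        time.sleep(delay)  # pace the output
        yield content[:i]
```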

fastchat/serve/cli.py

Lines changed: 5 additions & 1 deletion
@@ -150,6 +150,10 @@ def main(args):
         choices=["simple", "rich"],
         help="Display style.",
     )
-    parser.add_argument("--debug", action="store_true", help="Print debug information")
+    parser.add_argument(
+        "--debug",
+        action="store_true",
+        help="Print useful debug information (e.g., prompts)",
+    )
     args = parser.parse_args()
     main(args)

fastchat/serve/gradio_block_arena_anony.py

Lines changed: 25 additions & 22 deletions
@@ -18,6 +18,7 @@
 from fastchat.model.model_adapter import get_conversation_template
 from fastchat.serve.gradio_patch import Chatbot as grChatbot
 from fastchat.serve.gradio_web_server import (
+    State,
     http_bot,
     get_conv_log_filename,
     no_change_btn,
@@ -138,8 +139,7 @@ def regenerate(state0, state1, request: gr.Request):
     logger.info(f"regenerate (anony). ip: {request.client.host}")
     states = [state0, state1]
     for i in range(num_models):
-        states[i].messages[-1][-1] = None
-        states[i].skip_next = False
+        states[i].conv.messages[-1][-1] = None
     return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6
 
 
@@ -166,13 +166,14 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
     "gpt-4": 1.5,
     "gpt-3.5-turbo": 1.5,
     "claude-v1": 1.5,
-    "claude-instant-v1.1": 1.5,
+    "claude-instant-v1": 1.5,
     "bard": 1.5,
     "vicuna-13b": 1.5,
     "koala-13b": 1.5,
-    "RWKV-4-Raven-14B": 1.2,
-    "oasst-pythia-12b": 1.2,
+    "vicuna-7b": 1.2,
     "mpt-7b-chat": 1.2,
+    "oasst-pythia-12b": 1.2,
+    "RWKV-4-Raven-14B": 1.2,
     "fastchat-t5-3b": 1,
     "alpaca-13b": 1,
     "chatglm-6b": 1,
@@ -182,9 +183,12 @@ share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
 }
 
 
-def add_text(state0, state1, text, request: gr.Request):
+def add_text(
+    state0, state1, model_selector0, model_selector1, text, request: gr.Request
+):
     logger.info(f"add_text (anony). ip: {request.client.host}. len: {len(text)}")
     states = [state0, state1]
+    model_selectors = [model_selector0, model_selector1]
 
     if states[0] is None:
         assert states[1] is None
@@ -198,11 +202,9 @@ def add_text(state0, state1, text, request: gr.Request):
         model_left = model_right = models[0]
 
     states = [
-        get_conversation_template("vicuna"),
-        get_conversation_template("vicuna"),
+        State(model_left),
+        State(model_right),
     ]
-    states[0].model_name = model_left
-    states[1].model_name = model_right
 
     if len(text) <= 0:
         for i in range(num_models):
@@ -235,7 +237,8 @@ def add_text(state0, state1, text, request: gr.Request):
         * 6
     )
 
-    if (len(states[0].messages) - states[0].offset) // 2 >= CONVERSATION_LEN_LIMIT:
+    conv = states[0].conv
+    if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_LEN_LIMIT:
         logger.info(
             f"hit conversation length limit. ip: {request.client.host}. text: {text}"
         )
@@ -253,8 +256,8 @@ def add_text(state0, state1, text, request: gr.Request):
 
     text = text[:INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
     for i in range(num_models):
-        states[i].append_message(states[i].roles[0], text)
-        states[i].append_message(states[i].roles[1], None)
+        states[i].conv.append_message(states[i].conv.roles[0], text)
+        states[i].conv.append_message(states[i].conv.roles[1], None)
         states[i].skip_next = False
 
     return (
@@ -271,8 +274,6 @@ def add_text(state0, state1, text, request: gr.Request):
 def http_bot_all(
     state0,
     state1,
-    model_selector0,
-    model_selector1,
     temperature,
     top_p,
     max_new_tokens,
@@ -291,13 +292,11 @@ def http_bot_all(
         return
 
     states = [state0, state1]
-    model_selector = [state0.model_name, state1.model_name]
     gen = []
     for i in range(num_models):
        gen.append(
             http_bot(
                 states[i],
-                model_selector[i],
                 temperature,
                 top_p,
                 max_new_tokens,
@@ -447,7 +446,7 @@ def build_side_by_side_ui_anony(models):
         regenerate, states, states + chatbots + [textbox] + btn_list
     ).then(
         http_bot_all,
-        states + model_selectors + [temperature, top_p, max_output_tokens],
+        states + [temperature, top_p, max_output_tokens],
         states + chatbots + btn_list,
     )
     clear_btn.click(
@@ -477,17 +476,21 @@ def build_side_by_side_ui_anony(models):
     share_btn.click(share_click, states + model_selectors, [], _js=share_js)
 
     textbox.submit(
-        add_text, states + [textbox], states + chatbots + [textbox] + btn_list
+        add_text,
+        states + model_selectors + [textbox],
+        states + chatbots + [textbox] + btn_list,
     ).then(
         http_bot_all,
-        states + model_selectors + [temperature, top_p, max_output_tokens],
+        states + [temperature, top_p, max_output_tokens],
         states + chatbots + btn_list,
     )
     send_btn.click(
-        add_text, states + [textbox], states + chatbots + [textbox] + btn_list
+        add_text,
+        states + model_selectors + [textbox],
+        states + chatbots + [textbox] + btn_list,
    ).then(
         http_bot_all,
-        states + model_selectors + [temperature, top_p, max_output_tokens],
+        states + [temperature, top_p, max_output_tokens],
         states + chatbots + btn_list,
     )
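The per-model weights in the dict above presumably bias which models are paired in the anonymous arena; the actual selection logic is not shown in this excerpt. A minimal sketch of weighted pair sampling under standard `random.choices` semantics (the truncated dict and the function are illustrative, not part of the commit):

```python
import random

# An excerpt of the weights from the hunk above, for illustration only.
SAMPLING_WEIGHTS = {"gpt-4": 1.5, "vicuna-13b": 1.5, "vicuna-7b": 1.2}


def sample_model_pair(models):
    weights = [SAMPLING_WEIGHTS.get(m, 1.0) for m in models]
    model_left = random.choices(models, weights=weights, k=1)[0]
    # Resample until the two sides differ, so a model never battles itself.
    model_right = model_left
    while model_right == model_left and len(models) > 1:
        model_right = random.choices(models, weights=weights, k=1)[0]
    return model_left, model_right
```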
