Commit 06c3d4b
feat: accept list as prompt and use first string (#1702)
This PR allows `CompletionRequest.prompt` to be sent as either a string or an array of strings. When an array is sent, the first value is used if it is a string; otherwise the corresponding error is returned.

Fixes: #1690
Similar to: https://github.com/vllm-project/vllm/pull/323/files
1 parent e4d31a4 commit 06c3d4b
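To make the two accepted request shapes concrete, here is a hedged sketch of calling the OpenAI-compatible completions route with `prompt` as a string and as an array. The host, port, route, and model name are assumptions for a locally running text-generation-inference server, not details taken from this commit.

```python
# Hedged sketch: URL, port, and model name are assumptions for a local
# server; only the `prompt` shapes come from this commit.
import requests

URL = "http://localhost:8080/v1/completions"  # assumed local endpoint

# `prompt` as a single string
resp = requests.post(
    URL,
    json={"model": "tgi", "prompt": "What is Deep Learning?", "max_tokens": 10},
)
print(resp.json()["choices"][0]["text"])

# `prompt` as an array of strings (at most --max-client-batch-size entries)
resp = requests.post(
    URL,
    json={"model": "tgi", "prompt": ["Say", "this is", "a test", "again"], "max_tokens": 10},
)
for choice in resp.json()["choices"]:
    print(choice["index"], choice["text"])
```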

File tree

11 files changed: +1188 −107 lines

clients/python/text_generation/types.py

Lines changed: 21 additions & 0 deletions
```diff
@@ -59,6 +59,17 @@ class ChatCompletionComplete(BaseModel):
     usage: Optional[Any] = None
 
 
+class CompletionComplete(BaseModel):
+    # Index of the chat completion
+    index: int
+    # Message associated with the chat completion
+    text: str
+    # Log probabilities for the chat completion
+    logprobs: Optional[Any]
+    # Reason for completion
+    finish_reason: str
+
+
 class Function(BaseModel):
     name: Optional[str]
     arguments: str
@@ -104,6 +115,16 @@ class ChatComplete(BaseModel):
     usage: Any
 
 
+class Completion(BaseModel):
+    # Completion details
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[CompletionComplete]
+
+
 class ChatRequest(BaseModel):
     # Model identifier
     model: str
```
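A minimal sketch of how the two new client models compose, using a payload abbreviated from the snapshot added later in this commit:

```python
# Minimal sketch: the payload is abbreviated from this commit's snapshot.
from text_generation.types import Completion, CompletionComplete

payload = {
    "id": "",
    "object": "text_completion",
    "created": 1713284455,
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "system_fingerprint": "2.0.0-native",
    "choices": [
        {
            "index": 0,
            "text": " PR for more information?",
            "logprobs": None,
            "finish_reason": "eos_token",
        }
    ],
}

completion = Completion(**payload)  # pydantic validates the field types
assert isinstance(completion.choices[0], CompletionComplete)
print(completion.choices[0].text)
```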

docs/source/basic_tutorials/launcher.md

Lines changed: 9 additions & 0 deletions
````diff
@@ -398,6 +398,15 @@ Options:
   -e, --env
           Display a lot of information about your runtime environment
 
+```
+## MAX_CLIENT_BATCH_SIZE
+```shell
+      --max-client-batch-size <MAX_CLIENT_BATCH_SIZE>
+          Control the maximum number of inputs that a client can send in a single request
+
+          [env: MAX_CLIENT_BATCH_SIZE=]
+          [default: 4]
+
 ```
 ## HELP
 ```shell
````
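Since the new `--max-client-batch-size` flag caps how many inputs one request may carry (4 by default), a client holding more prompts has to split them across requests. A hedged sketch, reusing the assumed endpoint from above; `complete_all` is a hypothetical helper, not part of this commit:

```python
# Hypothetical client-side chunking helper; endpoint and model name are
# assumptions, the limit mirrors the server's documented default of 4.
import requests

MAX_CLIENT_BATCH_SIZE = 4
URL = "http://localhost:8080/v1/completions"  # assumed local endpoint

def complete_all(prompts, **params):
    """Send prompts in chunks no larger than the server's client batch limit."""
    texts = []
    for i in range(0, len(prompts), MAX_CLIENT_BATCH_SIZE):
        chunk = prompts[i : i + MAX_CLIENT_BATCH_SIZE]
        resp = requests.post(URL, json={"model": "tgi", "prompt": chunk, **params})
        resp.raise_for_status()
        # choices can come back out of index order (see the snapshot below),
        # so sort before realigning with the input chunk
        choices = sorted(resp.json()["choices"], key=lambda c: c["index"])
        texts.extend(c["text"] for c in choices)
    return texts

print(complete_all(["a", "b", "c", "d", "e"], max_tokens=10))
```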

integration-tests/conftest.py

Lines changed: 21 additions & 6 deletions
```diff
@@ -9,6 +9,7 @@
 import math
 import time
 import random
+import re
 
 from docker.errors import NotFound
 from typing import Optional, List, Dict
@@ -26,6 +27,7 @@
     ChatComplete,
     ChatCompletionChunk,
     ChatCompletionComplete,
+    Completion,
 )
 
 DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@@ -69,17 +71,22 @@ def convert_data(data):
             data = json.loads(data)
             if isinstance(data, Dict) and "choices" in data:
                 choices = data["choices"]
-                if (
-                    isinstance(choices, List)
-                    and len(choices) >= 1
-                    and "delta" in choices[0]
-                ):
-                    return ChatCompletionChunk(**data)
+                if isinstance(choices, List) and len(choices) >= 1:
+                    if "delta" in choices[0]:
+                        return ChatCompletionChunk(**data)
+                    if "text" in choices[0]:
+                        return Completion(**data)
                 return ChatComplete(**data)
 
             if isinstance(data, Dict):
                 return Response(**data)
             if isinstance(data, List):
+                if (
+                    len(data) > 0
+                    and "object" in data[0]
+                    and data[0]["object"] == "text_completion"
+                ):
+                    return [Completion(**d) for d in data]
                 return [Response(**d) for d in data]
             raise NotImplementedError
 
@@ -161,6 +168,9 @@ def eq_details(details: Details, other: Details) -> bool:
             )
         )
 
+        def eq_completion(response: Completion, other: Completion) -> bool:
+            return response.choices[0].text == other.choices[0].text
+
         def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
             return (
                 response.choices[0].message.content == other.choices[0].message.content
@@ -184,6 +194,11 @@ def eq_response(response: Response, other: Response) -> bool:
         if not isinstance(snapshot_data, List):
             snapshot_data = [snapshot_data]
 
+        if isinstance(serialized_data[0], Completion):
+            return len(snapshot_data) == len(serialized_data) and all(
+                [eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]
+            )
+
         if isinstance(serialized_data[0], ChatComplete):
             return len(snapshot_data) == len(serialized_data) and all(
                 [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
```
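To illustrate the dispatch rule `convert_data` now applies: a dict whose first choice carries `"text"` deserializes as `Completion`, `"delta"` still marks a streaming chunk, and a list of `"text_completion"` objects maps element-wise. A small standalone sketch with a hand-built payload; the field values are placeholders, only the discriminating keys matter:

```python
# Sketch of the dispatch convert_data performs; payload values are
# placeholders, only the discriminating keys matter.
import json
from text_generation.types import Completion

raw = json.dumps({
    "id": "",
    "object": "text_completion",
    "created": 0,
    "model": "m",
    "system_fingerprint": "",
    "choices": [
        {"index": 0, "text": "hi", "logprobs": None, "finish_reason": "length"}
    ],
})

data = json.loads(raw)
assert "delta" not in data["choices"][0]    # not a streaming chunk
assert "text" in data["choices"][0]         # -> Completion, not ChatComplete
completion = Completion(**data)
print(completion.choices[0].finish_reason)  # "length"
```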
Lines changed: 38 additions & 0 deletions
```diff
@@ -0,0 +1,38 @@
+{
+  "choices": [
+    {
+      "finish_reason": "eos_token",
+      "index": 1,
+      "logprobs": null,
+      "text": " PR for more information?"
+    },
+    {
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null,
+      "text": "le Business Incubator is providing a workspace"
+    },
+    {
+      "finish_reason": "length",
+      "index": 2,
+      "logprobs": null,
+      "text": " severely flawed and often has a substandard"
+    },
+    {
+      "finish_reason": "length",
+      "index": 3,
+      "logprobs": null,
+      "text": "hd20220811-"
+    }
+  ],
+  "created": 1713284455,
+  "id": "",
+  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "object": "text_completion",
+  "system_fingerprint": "2.0.0-native",
+  "usage": {
+    "completion_tokens": 36,
+    "prompt_tokens": 8,
+    "total_tokens": 44
+  }
+}
```
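Note that the snapshot stores its choices out of index order (1, 0, 2, 3). A hedged sketch that loads the snapshot with the new model and re-sorts by index; the filename is hypothetical, since this diff does not show the new file's path:

```python
# Hedged sketch: "snapshot.json" is a hypothetical path; the diff above
# does not show the snapshot file's name.
import json
from text_generation.types import Completion

with open("snapshot.json") as f:
    snap = Completion(**json.load(f))  # the extra "usage" key is ignored by pydantic

for choice in sorted(snap.choices, key=lambda c: c.index):
    print(choice.index, choice.finish_reason, choice.text)
```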
