
Commit cfaa858

feat(server): support fp16 for t5 (#360)
Fixes #349
1 parent: 94377ef

File tree

6 files changed: +357 −6 lines
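The commit lets the server run T5 checkpoints in half precision, roughly halving GPU memory. As a minimal sketch of the behaviour being enabled (plain transformers code, not this commit's server changes; the model id is taken from the new tests):

# A minimal sketch, assuming plain transformers; not the server code in this commit.
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

model_id = "google/flan-t5-xxl"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = T5ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16  # fp16 weights instead of fp32
).to("cuda")

prompt = "Please answer the following question. What is the boiling point of Nitrogen?"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))  # "-196 °C" per the snapshots below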
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+{
+  "details": {
+    "best_of_sequences": null,
+    "finish_reason": "eos_token",
+    "generated_tokens": 7,
+    "prefill": [
+      {
+        "id": 0,
+        "logprob": null,
+        "text": "<pad>"
+      }
+    ],
+    "seed": null,
+    "tokens": [
+      {
+        "id": 3,
+        "logprob": -0.7001953,
+        "special": false,
+        "text": " "
+      },
+      {
+        "id": 18,
+        "logprob": -1.1943359,
+        "special": false,
+        "text": "-"
+      },
+      {
+        "id": 26937,
+        "logprob": -1.2099609,
+        "special": false,
+        "text": "196"
+      },
+      {
+        "id": 3,
+        "logprob": -1.2451172,
+        "special": false,
+        "text": " "
+      },
+      {
+        "id": 1956,
+        "logprob": -0.3322754,
+        "special": false,
+        "text": "°"
+      },
+      {
+        "id": 254,
+        "logprob": -0.19213867,
+        "special": false,
+        "text": "C"
+      },
+      {
+        "id": 1,
+        "logprob": -0.030151367,
+        "special": true,
+        "text": "</s>"
+      }
+    ]
+  },
+  "generated_text": "-196 °C"
+}
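A note on this snapshot: the single prefill entry is `<pad>` with a null logprob because T5 uses the pad token as its decoder start token; generation then runs for 7 tokens and stops at `</s>` (finish_reason "eos_token"). A quick way to confirm the start-token convention (a sketch, assuming transformers is installed):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("google/flan-t5-xxl")
# For T5 both ids are 0, which is why the decoder prefill is a lone "<pad>".
print(config.decoder_start_token_id, config.pad_token_id)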
Lines changed: 242 additions & 0 deletions
@@ -0,0 +1,242 @@
+[
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "eos_token",
+      "generated_tokens": 7,
+      "prefill": [
+        {
+          "id": 0,
+          "logprob": null,
+          "text": "<pad>"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 3,
+          "logprob": -0.7001953,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 18,
+          "logprob": -1.1943359,
+          "special": false,
+          "text": "-"
+        },
+        {
+          "id": 26937,
+          "logprob": -1.2119141,
+          "special": false,
+          "text": "196"
+        },
+        {
+          "id": 3,
+          "logprob": -1.2480469,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 1956,
+          "logprob": -0.33203125,
+          "special": false,
+          "text": "°"
+        },
+        {
+          "id": 254,
+          "logprob": -0.19250488,
+          "special": false,
+          "text": "C"
+        },
+        {
+          "id": 1,
+          "logprob": -0.030166626,
+          "special": true,
+          "text": "</s>"
+        }
+      ]
+    },
+    "generated_text": "-196 °C"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "eos_token",
+      "generated_tokens": 7,
+      "prefill": [
+        {
+          "id": 0,
+          "logprob": null,
+          "text": "<pad>"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 3,
+          "logprob": -0.7001953,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 18,
+          "logprob": -1.1943359,
+          "special": false,
+          "text": "-"
+        },
+        {
+          "id": 26937,
+          "logprob": -1.2119141,
+          "special": false,
+          "text": "196"
+        },
+        {
+          "id": 3,
+          "logprob": -1.2480469,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 1956,
+          "logprob": -0.33203125,
+          "special": false,
+          "text": "°"
+        },
+        {
+          "id": 254,
+          "logprob": -0.19250488,
+          "special": false,
+          "text": "C"
+        },
+        {
+          "id": 1,
+          "logprob": -0.030166626,
+          "special": true,
+          "text": "</s>"
+        }
+      ]
+    },
+    "generated_text": "-196 °C"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "eos_token",
+      "generated_tokens": 7,
+      "prefill": [
+        {
+          "id": 0,
+          "logprob": null,
+          "text": "<pad>"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 3,
+          "logprob": -0.7001953,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 18,
+          "logprob": -1.1943359,
+          "special": false,
+          "text": "-"
+        },
+        {
+          "id": 26937,
+          "logprob": -1.2119141,
+          "special": false,
+          "text": "196"
+        },
+        {
+          "id": 3,
+          "logprob": -1.2480469,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 1956,
+          "logprob": -0.33203125,
+          "special": false,
+          "text": "°"
+        },
+        {
+          "id": 254,
+          "logprob": -0.19250488,
+          "special": false,
+          "text": "C"
+        },
+        {
+          "id": 1,
+          "logprob": -0.030166626,
+          "special": true,
+          "text": "</s>"
+        }
+      ]
+    },
+    "generated_text": "-196 °C"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "eos_token",
+      "generated_tokens": 7,
+      "prefill": [
+        {
+          "id": 0,
+          "logprob": null,
+          "text": "<pad>"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 3,
+          "logprob": -0.7001953,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 18,
+          "logprob": -1.1943359,
+          "special": false,
+          "text": "-"
+        },
+        {
+          "id": 26937,
+          "logprob": -1.2099609,
+          "special": false,
+          "text": "196"
+        },
+        {
+          "id": 3,
+          "logprob": -1.2451172,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 1956,
+          "logprob": -0.3322754,
+          "special": false,
+          "text": "°"
+        },
+        {
+          "id": 254,
+          "logprob": -0.19213867,
+          "special": false,
+          "text": "C"
+        },
+        {
+          "id": 1,
+          "logprob": -0.030151367,
+          "special": true,
+          "text": "</s>"
+        }
+      ]
+    },
+    "generated_text": "-196 °C"
+  }
+]

integration-tests/models/test_flash_neox.py

Lines changed: 3 additions & 1 deletion
@@ -36,6 +36,8 @@ async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
     generated_texts = [r.generated_text for r in responses]

     assert len(generated_texts) == 4
-    assert generated_texts, all([text == generated_texts[0] for text in generated_texts])
+    assert generated_texts, all(
+        [text == generated_texts[0] for text in generated_texts]
+    )

     assert responses == response_snapshot
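An aside on the reflowed assertion: in Python, `assert expr, msg` evaluates the second operand only as the failure message, so the `all(...)` here never gates the test; the statement only checks that `generated_texts` is non-empty. A stricter form (a suggestion, not part of this commit) would be:

assert len(generated_texts) == 4
assert all(text == generated_texts[0] for text in generated_texts)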
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def t5_sharded_handle(launcher):
+    with launcher("google/flan-t5-xxl", num_shard=2) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def t5_sharded(t5_sharded_handle):
+    await t5_sharded_handle.health(240)
+    return t5_sharded_handle.client
+
+
+@pytest.mark.asyncio
+async def test_t5_sharded(t5_sharded, response_snapshot):
+    response = await t5_sharded.generate(
+        "Please answer the following question. What is the boiling point of Nitrogen?",
+        max_new_tokens=10,
+    )
+
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot):
+    responses = await generate_load(
+        t5_sharded,
+        "Please answer the following question. What is the boiling point of Nitrogen?",
+        max_new_tokens=10,
+        n=4,
+    )
+
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
+
+    assert responses == response_snapshot
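The `launcher`, `generate_load`, and `response_snapshot` fixtures come from the repo's integration-test conftest (not shown in this view): `launcher` starts a server for `google/flan-t5-xxl` with two shards, and `handle.client` talks to it. Roughly what the first test exercises, as a standalone sketch (assuming a server already listening on localhost:8080 and the `text-generation` client package):

import asyncio
from text_generation import AsyncClient

async def main():
    client = AsyncClient("http://localhost:8080")
    response = await client.generate(
        "Please answer the following question. What is the boiling point of Nitrogen?",
        max_new_tokens=10,
    )
    print(response.generated_text)  # "-196 °C" in the snapshots above

asyncio.run(main())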

server/text_generation_server/models/bloom.py

Lines changed: 4 additions & 1 deletion
@@ -160,7 +160,10 @@ def load_weights(
                 # XXX: Hack for Rowlinear to add the bias only once.
                 if rank != 0:
                     tensor = torch.zeros_like(tensor)
-            elif isinstance(module, TensorParallelEmbedding) or name == "lm_head.weight":
+            elif (
+                isinstance(module, TensorParallelEmbedding)
+                or name == "lm_head.weight"
+            ):
                 size = slice_.get_shape()[0]
                 block_size = size // world_size
                 start = rank * block_size
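This branch shards embeddings and the lm_head by rows: each rank takes a contiguous block of `size // world_size` rows starting at `rank * block_size`. A toy illustration of the arithmetic (the `stop` bound is assumed; the hunk ends before it):

import torch

world_size, rank = 2, 1
full = torch.arange(12).reshape(6, 2)  # stand-in for a full embedding weight

size = full.shape[0]
block_size = size // world_size
start = rank * block_size
stop = start + block_size  # assumed continuation, not shown in the hunk
print(full[start:stop])    # rank 1 gets rows 3..5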
