Skip to content

Commit e4d31a4

Browse files
authored
fix: bump clients test base url to llama (#1751)
This PR bumps the client tests from `google/flan-t5-xxl` to `meta-llama/Llama-2-7b-chat-hf` to resolve issues when calling the endpoint while `google/flan-t5-xxl` is not available. Run with ```bash make python-client-tests clients/python/tests/test_client.py .............. [ 43%] clients/python/tests/test_errors.py .......... [ 75%] clients/python/tests/test_inference_api.py ...... [ 93%] clients/python/tests/test_types.py .. [100%] ``` **Note:** the `google/flan-t5-xxl` fixture is currently unused but is still included in `conftest.py`.
1 parent 00f3653 commit e4d31a4

File tree

2 files changed

+45
-32
lines changed

2 files changed

+45
-32
lines changed

clients/python/tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ def flan_t5_xxl():
99
return "google/flan-t5-xxl"
1010

1111

12+
@pytest.fixture
13+
def llama_7b():
14+
return "meta-llama/Llama-2-7b-chat-hf"
15+
16+
1217
@pytest.fixture
1318
def fake_model():
1419
return "fake/model"
@@ -34,6 +39,11 @@ def flan_t5_xxl_url(base_url, flan_t5_xxl):
3439
return f"{base_url}/{flan_t5_xxl}"
3540

3641

42+
@pytest.fixture
43+
def llama_7b_url(base_url, llama_7b):
44+
return f"{base_url}/{llama_7b}"
45+
46+
3747
@pytest.fixture
3848
def fake_url(base_url, fake_model):
3949
return f"{base_url}/{fake_model}"

clients/python/tests/test_client.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,24 @@
55
from text_generation.types import FinishReason, InputToken
66

77

8-
def test_generate(flan_t5_xxl_url, hf_headers):
9-
client = Client(flan_t5_xxl_url, hf_headers)
8+
def test_generate(llama_7b_url, hf_headers):
9+
client = Client(llama_7b_url, hf_headers)
1010
response = client.generate("test", max_new_tokens=1, decoder_input_details=True)
1111

12-
assert response.generated_text == ""
12+
assert response.generated_text == "_"
1313
assert response.details.finish_reason == FinishReason.Length
1414
assert response.details.generated_tokens == 1
1515
assert response.details.seed is None
16-
assert len(response.details.prefill) == 1
17-
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
16+
assert len(response.details.prefill) == 2
17+
assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
1818
assert len(response.details.tokens) == 1
19-
assert response.details.tokens[0].id == 3
20-
assert response.details.tokens[0].text == " "
19+
assert response.details.tokens[0].id == 29918
20+
assert response.details.tokens[0].text == "_"
2121
assert not response.details.tokens[0].special
2222

2323

24-
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
25-
client = Client(flan_t5_xxl_url, hf_headers)
24+
def test_generate_best_of(llama_7b_url, hf_headers):
25+
client = Client(llama_7b_url, hf_headers)
2626
response = client.generate(
2727
"test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
2828
)
@@ -39,22 +39,22 @@ def test_generate_not_found(fake_url, hf_headers):
3939
client.generate("test")
4040

4141

42-
def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
43-
client = Client(flan_t5_xxl_url, hf_headers)
42+
def test_generate_validation_error(llama_7b_url, hf_headers):
43+
client = Client(llama_7b_url, hf_headers)
4444
with pytest.raises(ValidationError):
4545
client.generate("test", max_new_tokens=10_000)
4646

4747

48-
def test_generate_stream(flan_t5_xxl_url, hf_headers):
49-
client = Client(flan_t5_xxl_url, hf_headers)
48+
def test_generate_stream(llama_7b_url, hf_headers):
49+
client = Client(llama_7b_url, hf_headers)
5050
responses = [
5151
response for response in client.generate_stream("test", max_new_tokens=1)
5252
]
5353

5454
assert len(responses) == 1
5555
response = responses[0]
5656

57-
assert response.generated_text == ""
57+
assert response.generated_text == "_"
5858
assert response.details.finish_reason == FinishReason.Length
5959
assert response.details.generated_tokens == 1
6060
assert response.details.seed is None
@@ -66,34 +66,37 @@ def test_generate_stream_not_found(fake_url, hf_headers):
6666
list(client.generate_stream("test"))
6767

6868

69-
def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
70-
client = Client(flan_t5_xxl_url, hf_headers)
69+
def test_generate_stream_validation_error(llama_7b_url, hf_headers):
70+
client = Client(llama_7b_url, hf_headers)
7171
with pytest.raises(ValidationError):
7272
list(client.generate_stream("test", max_new_tokens=10_000))
7373

7474

7575
@pytest.mark.asyncio
76-
async def test_generate_async(flan_t5_xxl_url, hf_headers):
77-
client = AsyncClient(flan_t5_xxl_url, hf_headers)
76+
async def test_generate_async(llama_7b_url, hf_headers):
77+
client = AsyncClient(llama_7b_url, hf_headers)
7878
response = await client.generate(
7979
"test", max_new_tokens=1, decoder_input_details=True
8080
)
8181

82-
assert response.generated_text == ""
82+
assert response.generated_text == "_"
8383
assert response.details.finish_reason == FinishReason.Length
8484
assert response.details.generated_tokens == 1
8585
assert response.details.seed is None
86-
assert len(response.details.prefill) == 1
87-
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
86+
assert len(response.details.prefill) == 2
87+
assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
88+
assert response.details.prefill[1] == InputToken(
89+
id=1243, text="test", logprob=-10.96875
90+
)
8891
assert len(response.details.tokens) == 1
89-
assert response.details.tokens[0].id == 3
90-
assert response.details.tokens[0].text == " "
92+
assert response.details.tokens[0].id == 29918
93+
assert response.details.tokens[0].text == "_"
9194
assert not response.details.tokens[0].special
9295

9396

9497
@pytest.mark.asyncio
95-
async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
96-
client = AsyncClient(flan_t5_xxl_url, hf_headers)
98+
async def test_generate_async_best_of(llama_7b_url, hf_headers):
99+
client = AsyncClient(llama_7b_url, hf_headers)
97100
response = await client.generate(
98101
"test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
99102
)
@@ -112,23 +115,23 @@ async def test_generate_async_not_found(fake_url, hf_headers):
112115

113116

114117
@pytest.mark.asyncio
115-
async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
116-
client = AsyncClient(flan_t5_xxl_url, hf_headers)
118+
async def test_generate_async_validation_error(llama_7b_url, hf_headers):
119+
client = AsyncClient(llama_7b_url, hf_headers)
117120
with pytest.raises(ValidationError):
118121
await client.generate("test", max_new_tokens=10_000)
119122

120123

121124
@pytest.mark.asyncio
122-
async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
123-
client = AsyncClient(flan_t5_xxl_url, hf_headers)
125+
async def test_generate_stream_async(llama_7b_url, hf_headers):
126+
client = AsyncClient(llama_7b_url, hf_headers)
124127
responses = [
125128
response async for response in client.generate_stream("test", max_new_tokens=1)
126129
]
127130

128131
assert len(responses) == 1
129132
response = responses[0]
130133

131-
assert response.generated_text == ""
134+
assert response.generated_text == "_"
132135
assert response.details.finish_reason == FinishReason.Length
133136
assert response.details.generated_tokens == 1
134137
assert response.details.seed is None
@@ -143,8 +146,8 @@ async def test_generate_stream_async_not_found(fake_url, hf_headers):
143146

144147

145148
@pytest.mark.asyncio
146-
async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
147-
client = AsyncClient(flan_t5_xxl_url, hf_headers)
149+
async def test_generate_stream_async_validation_error(llama_7b_url, hf_headers):
150+
client = AsyncClient(llama_7b_url, hf_headers)
148151
with pytest.raises(ValidationError):
149152
async for _ in client.generate_stream("test", max_new_tokens=10_000):
150153
pass

0 commit comments

Comments
 (0)