55from text_generation .types import FinishReason , InputToken
66
77
8- def test_generate (flan_t5_xxl_url , hf_headers ):
9- client = Client (flan_t5_xxl_url , hf_headers )
8+ def test_generate (llama_7b_url , hf_headers ):
9+ client = Client (llama_7b_url , hf_headers )
1010 response = client .generate ("test" , max_new_tokens = 1 , decoder_input_details = True )
1111
12- assert response .generated_text == ""
12+ assert response .generated_text == "_ "
1313 assert response .details .finish_reason == FinishReason .Length
1414 assert response .details .generated_tokens == 1
1515 assert response .details .seed is None
16- assert len (response .details .prefill ) == 1
17- assert response .details .prefill [0 ] == InputToken (id = 0 , text = "<pad>" , logprob = None )
16+ assert len (response .details .prefill ) == 2
17+ assert response .details .prefill [0 ] == InputToken (id = 1 , text = "<s>" , logprob = None )
1818 assert len (response .details .tokens ) == 1
19- assert response .details .tokens [0 ].id == 3
20- assert response .details .tokens [0 ].text == " "
19+ assert response .details .tokens [0 ].id == 29918
20+ assert response .details .tokens [0 ].text == "_ "
2121 assert not response .details .tokens [0 ].special
2222
2323
24- def test_generate_best_of (flan_t5_xxl_url , hf_headers ):
25- client = Client (flan_t5_xxl_url , hf_headers )
24+ def test_generate_best_of (llama_7b_url , hf_headers ):
25+ client = Client (llama_7b_url , hf_headers )
2626 response = client .generate (
2727 "test" , max_new_tokens = 1 , best_of = 2 , do_sample = True , decoder_input_details = True
2828 )
@@ -39,22 +39,22 @@ def test_generate_not_found(fake_url, hf_headers):
3939 client .generate ("test" )
4040
4141
42- def test_generate_validation_error (flan_t5_xxl_url , hf_headers ):
43- client = Client (flan_t5_xxl_url , hf_headers )
42+ def test_generate_validation_error (llama_7b_url , hf_headers ):
43+ client = Client (llama_7b_url , hf_headers )
4444 with pytest .raises (ValidationError ):
4545 client .generate ("test" , max_new_tokens = 10_000 )
4646
4747
48- def test_generate_stream (flan_t5_xxl_url , hf_headers ):
49- client = Client (flan_t5_xxl_url , hf_headers )
48+ def test_generate_stream (llama_7b_url , hf_headers ):
49+ client = Client (llama_7b_url , hf_headers )
5050 responses = [
5151 response for response in client .generate_stream ("test" , max_new_tokens = 1 )
5252 ]
5353
5454 assert len (responses ) == 1
5555 response = responses [0 ]
5656
57- assert response .generated_text == ""
57+ assert response .generated_text == "_ "
5858 assert response .details .finish_reason == FinishReason .Length
5959 assert response .details .generated_tokens == 1
6060 assert response .details .seed is None
@@ -66,34 +66,37 @@ def test_generate_stream_not_found(fake_url, hf_headers):
6666 list (client .generate_stream ("test" ))
6767
6868
69- def test_generate_stream_validation_error (flan_t5_xxl_url , hf_headers ):
70- client = Client (flan_t5_xxl_url , hf_headers )
69+ def test_generate_stream_validation_error (llama_7b_url , hf_headers ):
70+ client = Client (llama_7b_url , hf_headers )
7171 with pytest .raises (ValidationError ):
7272 list (client .generate_stream ("test" , max_new_tokens = 10_000 ))
7373
7474
7575@pytest .mark .asyncio
76- async def test_generate_async (flan_t5_xxl_url , hf_headers ):
77- client = AsyncClient (flan_t5_xxl_url , hf_headers )
76+ async def test_generate_async (llama_7b_url , hf_headers ):
77+ client = AsyncClient (llama_7b_url , hf_headers )
7878 response = await client .generate (
7979 "test" , max_new_tokens = 1 , decoder_input_details = True
8080 )
8181
82- assert response .generated_text == ""
82+ assert response .generated_text == "_ "
8383 assert response .details .finish_reason == FinishReason .Length
8484 assert response .details .generated_tokens == 1
8585 assert response .details .seed is None
86- assert len (response .details .prefill ) == 1
87- assert response .details .prefill [0 ] == InputToken (id = 0 , text = "<pad>" , logprob = None )
86+ assert len (response .details .prefill ) == 2
87+ assert response .details .prefill [0 ] == InputToken (id = 1 , text = "<s>" , logprob = None )
88+ assert response .details .prefill [1 ] == InputToken (
89+ id = 1243 , text = "test" , logprob = - 10.96875
90+ )
8891 assert len (response .details .tokens ) == 1
89- assert response .details .tokens [0 ].id == 3
90- assert response .details .tokens [0 ].text == " "
92+ assert response .details .tokens [0 ].id == 29918
93+ assert response .details .tokens [0 ].text == "_ "
9194 assert not response .details .tokens [0 ].special
9295
9396
9497@pytest .mark .asyncio
95- async def test_generate_async_best_of (flan_t5_xxl_url , hf_headers ):
96- client = AsyncClient (flan_t5_xxl_url , hf_headers )
98+ async def test_generate_async_best_of (llama_7b_url , hf_headers ):
99+ client = AsyncClient (llama_7b_url , hf_headers )
97100 response = await client .generate (
98101 "test" , max_new_tokens = 1 , best_of = 2 , do_sample = True , decoder_input_details = True
99102 )
@@ -112,23 +115,23 @@ async def test_generate_async_not_found(fake_url, hf_headers):
112115
113116
114117@pytest .mark .asyncio
115- async def test_generate_async_validation_error (flan_t5_xxl_url , hf_headers ):
116- client = AsyncClient (flan_t5_xxl_url , hf_headers )
118+ async def test_generate_async_validation_error (llama_7b_url , hf_headers ):
119+ client = AsyncClient (llama_7b_url , hf_headers )
117120 with pytest .raises (ValidationError ):
118121 await client .generate ("test" , max_new_tokens = 10_000 )
119122
120123
121124@pytest .mark .asyncio
122- async def test_generate_stream_async (flan_t5_xxl_url , hf_headers ):
123- client = AsyncClient (flan_t5_xxl_url , hf_headers )
125+ async def test_generate_stream_async (llama_7b_url , hf_headers ):
126+ client = AsyncClient (llama_7b_url , hf_headers )
124127 responses = [
125128 response async for response in client .generate_stream ("test" , max_new_tokens = 1 )
126129 ]
127130
128131 assert len (responses ) == 1
129132 response = responses [0 ]
130133
131- assert response .generated_text == ""
134+ assert response .generated_text == "_ "
132135 assert response .details .finish_reason == FinishReason .Length
133136 assert response .details .generated_tokens == 1
134137 assert response .details .seed is None
@@ -143,8 +146,8 @@ async def test_generate_stream_async_not_found(fake_url, hf_headers):
143146
144147
145148@pytest .mark .asyncio
146- async def test_generate_stream_async_validation_error (flan_t5_xxl_url , hf_headers ):
147- client = AsyncClient (flan_t5_xxl_url , hf_headers )
149+ async def test_generate_stream_async_validation_error (llama_7b_url , hf_headers ):
150+ client = AsyncClient (llama_7b_url , hf_headers )
148151 with pytest .raises (ValidationError ):
149152 async for _ in client .generate_stream ("test" , max_new_tokens = 10_000 ):
150153 pass
0 commit comments