@@ -27,6 +27,7 @@
 from typing import Iterator, TypeAlias
 
 import docker
+import docker.errors
 import pytest
 import requests
 from huggingface_hub import ChatCompletionInputMessage
@@ -49,20 +50,25 @@
 @pytest.fixture(scope="module")
 def tgi_model() -> Iterator[TGIModel]:
     client = docker.from_env()
-    port = random.randint(8000, 9000)
-    container = client.containers.run(
-        "ghcr.io/huggingface/text-generation-inference:2.2.0",
-        command=[
-            "--model-id",
-            "hf-internal-testing/tiny-random-LlamaForCausalLM",
-            "--dtype",
-            "float16",
-        ],
-        detach=True,
-        name="lighteval-tgi-model-test",
-        auto_remove=True,
-        ports={"80/tcp": port},
-    )
+
+    try:
+        container = client.containers.get("lighteval-tgi-model-test")
+        port = container.ports["80/tcp"][0]["HostPort"]
+    except docker.errors.NotFound:
+        port = random.randint(8000, 9000)
+        container = client.containers.run(
+            "ghcr.io/huggingface/text-generation-inference:2.2.0",
+            command=[
+                "--model-id",
+                "hf-internal-testing/tiny-random-LlamaForCausalLM",
+                "--dtype",
+                "float16",
+            ],
+            detach=True,
+            name="lighteval-tgi-model-test",
+            auto_remove=False,
+            ports={"80/tcp": port},
+        )
     address = f"http://localhost:{port}"
     for _ in range(30):
         try:
@@ -76,6 +82,7 @@ def tgi_model() -> Iterator[TGIModel]:
     yield model
     container.stop()
    container.wait()
+    container.remove()
     model.cleanup()
 
 
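The heart of the change is a get-or-reuse pattern: look up the named container first and only launch a new one if the lookup fails. A minimal self-contained sketch of that pattern with docker-py follows; the helper name get_or_run_tgi is illustrative and not part of the commit, but the image, container name, and API calls mirror the diff above.

import random

import docker
import docker.errors
from docker.models.containers import Container


def get_or_run_tgi(client: docker.DockerClient) -> tuple[Container, str]:
    """Return the TGI test container and its mapped host port, reusing a leftover one."""
    name = "lighteval-tgi-model-test"  # same fixed name the fixture uses
    try:
        # A container left behind by an earlier run is picked up as-is.
        container = client.containers.get(name)
        # For a running container, .ports maps "80/tcp" to a list of
        # {"HostIp": ..., "HostPort": ...} bindings; take the host port.
        port = container.ports["80/tcp"][0]["HostPort"]
    except docker.errors.NotFound:
        port = str(random.randint(8000, 9000))
        container = client.containers.run(
            "ghcr.io/huggingface/text-generation-inference:2.2.0",
            command=[
                "--model-id",
                "hf-internal-testing/tiny-random-LlamaForCausalLM",
                "--dtype",
                "float16",
            ],
            detach=True,
            name=name,
            auto_remove=False,  # a stopped container must persist to be reusable
            ports={"80/tcp": port},
        )
    return container, port

This also explains the two flag changes in the diff: because the container has a fixed name, a leftover from an interrupted run would make a fresh containers.run() fail with a name conflict, so reusing it via containers.get() is the robust path. Flipping to auto_remove=False plus an explicit container.remove() after stop() and wait() keeps normal teardown tidy while still allowing an interrupted run's container to survive for the next session.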