 from typing import Iterator, TypeAlias
 
 import docker
+import docker.errors
 import pytest
 import requests
 from huggingface_hub import ChatCompletionInputMessage
@@ -49,20 +50,27 @@
 @pytest.fixture(scope="module")
 def tgi_model() -> Iterator[TGIModel]:
     client = docker.from_env()
-    port = random.randint(8000, 9000)
-    container = client.containers.run(
-        "ghcr.io/huggingface/text-generation-inference:2.2.0",
-        command=[
-            "--model-id",
-            "hf-internal-testing/tiny-random-LlamaForCausalLM",
-            "--dtype",
-            "float16",
-        ],
-        detach=True,
-        name="lighteval-tgi-model-test",
-        auto_remove=True,
-        ports={"80/tcp": port},
-    )
+
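+    # Reuse the TGI container left over from a previous run, if any.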
+    try:
+        container = client.containers.get("lighteval-tgi-model-test")
+        port = container.ports["80/tcp"][0]["HostPort"]
+    except docker.errors.NotFound:
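+        # No existing container: start a fresh TGI server on a random host port.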
+        port = random.randint(8000, 9000)
+        container = client.containers.run(
+            "ghcr.io/huggingface/text-generation-inference:2.2.0",
+            command=[
+                "--model-id",
+                "hf-internal-testing/tiny-random-LlamaForCausalLM",
+                "--dtype",
+                "float16",
+            ],
+            detach=True,
+            name="lighteval-tgi-model-test",
+            auto_remove=False,
+            ports={"80/tcp": port},
+        )
     address = f"http://localhost:{port}"
     for _ in range(30):
         try:
@@ -76,6 +82,7 @@ def tgi_model() -> Iterator[TGIModel]:
     yield model
     container.stop()
     container.wait()
+    container.remove()
     model.cleanup()
 
 