Skip to content

Commit

Permalink
Fix sth in tgi_model
Browse files Browse the repository at this point in the history
  • Loading branch information
sadra-barikbin committed Aug 25, 2024
1 parent e1a5bc1 commit b44362b
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions tests/test_endpoint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from typing import Iterator, TypeAlias

import docker
import docker.errors
import pytest
import requests
from huggingface_hub import ChatCompletionInputMessage
Expand All @@ -49,20 +50,25 @@
@pytest.fixture(scope="module")
def tgi_model() -> Iterator[TGIModel]:
client = docker.from_env()
port = random.randint(8000, 9000)
container = client.containers.run(
"ghcr.io/huggingface/text-generation-inference:2.2.0",
command=[
"--model-id",
"hf-internal-testing/tiny-random-LlamaForCausalLM",
"--dtype",
"float16",
],
detach=True,
name="lighteval-tgi-model-test",
auto_remove=True,
ports={"80/tcp": port},
)

try:
container = client.containers.get("lighteval-tgi-model-test")
port = container.ports["80/tcp"][0]["HostPort"]
except docker.errors.NotFound:
port = random.randint(8000, 9000)
container = client.containers.run(
"ghcr.io/huggingface/text-generation-inference:2.2.0",
command=[
"--model-id",
"hf-internal-testing/tiny-random-LlamaForCausalLM",
"--dtype",
"float16",
],
detach=True,
name="lighteval-tgi-model-test",
auto_remove=False,
ports={"80/tcp": port},
)
address = f"http://localhost:{port}"
for _ in range(30):
try:
Expand All @@ -76,6 +82,7 @@ def tgi_model() -> Iterator[TGIModel]:
yield model
container.stop()
container.wait()
container.remove()
model.cleanup()


Expand Down

0 comments on commit b44362b

Please sign in to comment.