
Commit 84cabf7

Fix image-text-to-text provided kwargs to skip tokenizer (#100)
* Skip `tokenizer` if `HF_TASK='image-text-to-text'`
* Add unit tests for 'image-text-to-text' pipelines
* Fix `pyproject.toml` warnings on `tool.ruff.lint`
* Run `make style` to fix CI
* Add `validate_image_text_to_text` function in `utils.py`
1 parent e0abd4b commit 84cabf7

File tree

5 files changed: +58 −75 lines changed


pyproject.toml

+6 −7

@@ -4,6 +4,12 @@ no_implicit_optional = true
 scripts_are_modules = true
 
 [tool.ruff]
+# Same as Black.
+line-length = 119
+# Assume Python 3.11
+target-version = "py311"
+
+[tool.ruff.lint]
 select = [
     "E", # pycodestyle errors
     "W", # pycodestyle warnings
@@ -17,15 +23,8 @@ ignore = [
     "B008", # do not perform function calls in argument defaults
     "C901", # too complex
 ]
-# Same as Black.
-line-length = 119
-
 # Allow unused variables when underscore-prefixed.
 dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
-
-# Assume Python 3.11
-target-version = "py311"
-
 per-file-ignores = { "__init__.py" = ["F401"] }
 
 [tool.isort]

src/huggingface_inference_toolkit/utils.py

+1 −1

@@ -237,7 +237,7 @@ def get_pipeline(
         "zero-shot-image-classification",
     }:
         kwargs["feature_extractor"] = model_dir
-    elif task not in {"image-to-text", "text-to-image"}:
+    elif task not in {"image-text-to-text", "image-to-text", "text-to-image"}:
         kwargs["tokenizer"] = model_dir
 
     if is_sentence_transformers_available() and task in [
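
The gist of the fix: multimodal pipelines resolve their own processor (tokenizer plus image processor) from the model repository, so injecting a separate `tokenizer=model_dir` kwarg for "image-text-to-text" breaks pipeline construction. A minimal sketch of the behavior this change enables, using the model from the integration tests below; the call itself is illustrative, not part of the commit:

from transformers import pipeline

# With the fix, the toolkit no longer passes `tokenizer=model_dir` for
# "image-text-to-text"; the pipeline loads its processor from the model itself.
pipe = pipeline(task="image-text-to-text", model="Salesforce/blip-image-captioning-base")

output = pipe(
    images="https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    text="A photo of",
)
print(output)  # expected shape: [{"input_text": "A photo of", "generated_text": "..."}]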

tests/integ/config.py

+23 −30

@@ -7,6 +7,7 @@
     validate_custom,
     validate_feature_extraction,
     validate_fill_mask,
+    validate_image_text_to_text,
     validate_ner,
     validate_object_detection,
     validate_question_answering,
@@ -108,6 +109,10 @@
         "pytorch": "hf-internal-testing/tiny-random-beit-pipeline",
         "tensorflow": None,
     },
+    "image-text-to-text": {
+        "pytorch": "Salesforce/blip-image-captioning-base",
+        "tensorflow": None,
+    },
 }
 
 
@@ -134,24 +139,12 @@
         "inputs": "question: What is 42 context: 42 is the answer to life, the universe and everything."
     },
     "text-generation": {"inputs": "My name is philipp and I am"},
-    "image-classification": open(
-        os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb"
-    ).read(),
-    "zero-shot-image-classification": open(
-        os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb"
-    ).read(),
-    "object-detection": open(
-        os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb"
-    ).read(),
-    "image-segmentation": open(
-        os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb"
-    ).read(),
-    "automatic-speech-recognition": open(
-        os.path.join(os.getcwd(), "tests/resources/audio/sample1.flac"), "rb"
-    ).read(),
-    "audio-classification": open(
-        os.path.join(os.getcwd(), "tests/resources/audio/sample1.flac"), "rb"
-    ).read(),
+    "image-classification": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(),
+    "zero-shot-image-classification": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(),
+    "object-detection": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(),
+    "image-segmentation": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(),
+    "automatic-speech-recognition": open(os.path.join(os.getcwd(), "tests/resources/audio/sample1.flac"), "rb").read(),
+    "audio-classification": open(os.path.join(os.getcwd(), "tests/resources/audio/sample1.flac"), "rb").read(),
     "table-question-answering": {
         "inputs": {
             "query": "How many stars does the transformers repository have?",
@@ -175,11 +168,15 @@
         }
     },
     "sentence-embeddings": {"inputs": "Lets create an embedding"},
-    "sentence-ranking": {
-        "inputs": ["Lets create an embedding", "Lets create an embedding"]
-    },
+    "sentence-ranking": {"inputs": ["Lets create an embedding", "Lets create an embedding"]},
     "text-to-image": {"inputs": "a man on a horse jumps over a broken down airplane."},
     "custom": {"inputs": "this is a test"},
+    "image-text-to-text": {
+        "inputs": {
+            "images": "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
+            "text": "A photo of",
+        }
+    },
 }
 
 task2output = {
@@ -213,15 +210,9 @@
         "end": 77,
         "answer": "sagemaker",
     },
-    "summarization": [
-        {"summary_text": " The A The The ANew York City has been installed in the US."}
-    ],
-    "translation_xx_to_yy": [
-        {"translation_text": "Mein Name ist Sarah und ich lebe in London"}
-    ],
-    "text2text-generation": [
-        {"generated_text": "42 is the answer to life, the universe and everything"}
-    ],
+    "summarization": [{"summary_text": " The A The The ANew York City has been installed in the US."}],
+    "translation_xx_to_yy": [{"translation_text": "Mein Name ist Sarah und ich lebe in London"}],
+    "text2text-generation": [{"generated_text": "42 is the answer to life, the universe and everything"}],
     "feature-extraction": None,
     "fill-mask": None,
     "text-generation": None,
@@ -269,6 +260,7 @@
     "sentence-embeddings": {"embeddings": ""},
    "sentence-ranking": {"scores": ""},
     "text-to-image": bytes,
+    "image-text-to-text": [{"input_text": "A photo of", "generated_text": "..."}],
     "custom": {"inputs": "this is a test"},
 }
 
@@ -296,5 +288,6 @@
     "sentence-embeddings": validate_zero_shot_classification,
     "sentence-ranking": validate_zero_shot_classification,
     "text-to-image": validate_text_to_image,
+    "image-text-to-text": validate_image_text_to_text,
     "custom": validate_custom,
 }
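
For reference, the new `task2input` entry corresponds to a request like the following against a running test container. This is a hedged sketch: the port is assumed to be 5000, while the tests actually pick a random one in [5000, 6000).

import requests

# Hypothetical request mirroring task2input["image-text-to-text"];
# the container URL and port are assumptions for illustration.
payload = {
    "inputs": {
        "images": "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
        "text": "A photo of",
    }
}
prediction = requests.post("http://localhost:5000", json=payload).json()
# task2expectation documents the shape:
# [{"input_text": "A photo of", "generated_text": "..."}]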

tests/integ/helpers.py

+20 −37

@@ -34,7 +34,6 @@ def make_sure_other_containers_are_stopped(client: DockerClient, container_name:
 # reraise = True
 # )
 def wait_for_container_to_be_ready(base_url, time_between_retries=3, max_retries=30):
-
     retries = 0
     error = None
 
@@ -46,9 +45,7 @@ def wait_for_container_to_be_ready(base_url, time_between_retries=3, max_retries
                 logging.info("Container ready!")
                 return True
             else:
-                raise ConnectionError(
-                    f"Couldn'start container, Error: {response.status_code}"
-                )
+                raise ConnectionError(f"Couldn'start container, Error: {response.status_code}")
         except Exception as exception:
             error = exception
             logging.warning(f"Container at {base_url} not ready, trying again...")
@@ -62,7 +59,6 @@ def verify_task(
     # container: DockerClient,
     task: str,
     port: int = 5000,
-    framework: str = "pytorch",
 ):
     BASE_URL = f"http://localhost:{port}"
     logging.info(f"Base URL: {BASE_URL}")
@@ -90,10 +86,7 @@ def verify_task(
                 headers={"content-type": "audio/x-audio"},
             ).json()
         elif task == "text-to-image":
-            prediction = requests.post(
-                f"{BASE_URL}", json=input, headers={"accept": "image/png"}
-            ).content
-
+            prediction = requests.post(f"{BASE_URL}", json=input, headers={"accept": "image/png"}).content
         else:
             prediction = requests.post(f"{BASE_URL}", json=input).json()
 
@@ -119,6 +112,8 @@ def verify_task(
 @pytest.mark.parametrize(
     "task",
     [
+        # transformers
+        # TODO: "visual-question-answering" and "zero-shot-image-classification" not supported yet due to multimodality input
         "text-classification",
         "zero-shot-classification",
         "token-classification",
@@ -136,25 +131,22 @@ def verify_task(
         "image-segmentation",
         "table-question-answering",
         "conversational",
-        # TODO currently not supported due to multimodality input
-        # "visual-question-answering",
-        # "zero-shot-image-classification",
+        "image-text-to-text",
+        # sentence-transformers
         "sentence-similarity",
         "sentence-embeddings",
         "sentence-ranking",
         # diffusers
         "text-to-image",
     ],
 )
-def test_pt_container_remote_model(task) -> None:
+def test_pt_container_remote_model(task: str) -> None:
     container_name = f"integration-test-{task}"
     container_image = f"starlette-transformers:{DEVICE}"
     framework = "pytorch"
     model = task2model[task][framework]
     port = random.randint(5000, 6000)
-    device_request = (
-        [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else []
-    )
+    device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else []
 
     make_sure_other_containers_are_stopped(client, container_name)
     container = client.containers.run(
@@ -177,6 +169,8 @@ def test_pt_container_remote_model(task) -> None:
 @pytest.mark.parametrize(
     "task",
     [
+        # transformers
+        # TODO: "visual-question-answering" and "zero-shot-image-classification" not supported yet due to multimodality input
         "text-classification",
         "zero-shot-classification",
         "token-classification",
@@ -194,29 +188,26 @@ def test_pt_container_remote_model(task) -> None:
         "image-segmentation",
         "table-question-answering",
         "conversational",
-        # TODO currently not supported due to multimodality input
-        # "visual-question-answering",
-        # "zero-shot-image-classification",
+        "image-text-to-text",
+        # sentence-transformers
         "sentence-similarity",
         "sentence-embeddings",
         "sentence-ranking",
         # diffusers
         "text-to-image",
     ],
 )
-def test_pt_container_local_model(task) -> None:
+def test_pt_container_local_model(task: str) -> None:
     container_name = f"integration-test-{task}"
     container_image = f"starlette-transformers:{DEVICE}"
     framework = "pytorch"
     model = task2model[task][framework]
     port = random.randint(5000, 6000)
-    device_request = (
-        [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else []
-    )
+    device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else []
     make_sure_other_containers_are_stopped(client, container_name)
     with tempfile.TemporaryDirectory() as tmpdirname:
         # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py
-        _storage_dir = _load_repository_from_hf(model, tmpdirname, framework="pytorch")
+        _load_repository_from_hf(model, tmpdirname, framework="pytorch")
         container = client.containers.run(
             container_image,
             name=container_name,
@@ -241,9 +232,7 @@ def test_pt_container_local_model(task) -> None:
 def test_pt_container_custom_handler(repository_id) -> None:
     container_name = "integration-test-custom"
     container_image = f"starlette-transformers:{DEVICE}"
-    device_request = (
-        [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else []
-    )
+    device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else []
     port = random.randint(5000, 6000)
 
     make_sure_other_containers_are_stopped(client, container_name)
@@ -277,12 +266,10 @@ def test_pt_container_custom_handler(repository_id) -> None:
     "repository_id",
     ["philschmid/custom-pipeline-text-classification"],
 )
-def test_pt_container_legacy_custom_pipeline(repository_id) -> None:
+def test_pt_container_legacy_custom_pipeline(repository_id: str) -> None:
     container_name = "integration-test-custom"
     container_image = f"starlette-transformers:{DEVICE}"
-    device_request = (
-        [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else []
-    )
+    device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else []
     port = random.randint(5000, 6000)
 
     make_sure_other_containers_are_stopped(client, container_name)
@@ -345,9 +332,7 @@ def test_tf_container_remote_model(task) -> None:
     container_image = f"starlette-transformers:{DEVICE}"
     framework = "tensorflow"
     model = task2model[task][framework]
-    device_request = (
-        [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else []
-    )
+    device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else []
    if model is None:
        pytest.skip("no supported TF model")
    port = random.randint(5000, 6000)
@@ -401,9 +386,7 @@ def test_tf_container_local_model(task) -> None:
     container_image = f"starlette-transformers:{DEVICE}"
     framework = "tensorflow"
     model = task2model[task][framework]
-    device_request = (
-        [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else []
-    )
+    device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else []
     if model is None:
         pytest.skip("no supported TF model")
     port = random.randint(5000, 6000)
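
With the task added to both parametrize lists, the new case can be exercised in isolation with something like `pytest tests/integ/helpers.py -k image-text-to-text`, assuming a `starlette-transformers:{DEVICE}` image has already been built and a recent pytest that accepts hyphenated `-k` expressions.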

tests/integ/utils.py

+8

@@ -6,6 +6,7 @@ def validate_classification(result=None, snapshot=None):
         assert result[idx].keys() == snapshot[idx].keys()
     return True
 
+
 def validate_conversational(result=None, snapshot=None):
     assert len(result[0]["generated_text"]) >= len(snapshot)
 
@@ -82,6 +83,13 @@ def validate_text_to_image(result=None, snapshot=None):
     assert isinstance(result, snapshot)
     return True
 
+
+def validate_image_text_to_text(result=None, snapshot=None):
+    assert isinstance(result, list)
+    assert all(isinstance(d, dict) and d.keys() == {"input_text", "generated_text"} for d in result)
+    return True
+
+
 def validate_custom(result=None, snapshot=None):
     logging.info(f"Validate custom task - result: {result}, snapshot: {snapshot}")
     assert result == snapshot
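
As a quick usage sketch, the new validator accepts exactly the result shape listed in `task2expectation`; the sample result below is made up for illustration:

# Hypothetical pipeline result in the expected shape; any list of dicts with
# exactly the keys {"input_text", "generated_text"} passes the assertions.
result = [{"input_text": "A photo of", "generated_text": "A photo of two parrots"}]
assert validate_image_text_to_text(result=result) is True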
