
Commit d9ae3d9

Add HF_TRUST_REMOTE_CODE environment variable (#78)
* Propagate `**kwargs` to `sentence-transformers` and `diffusers` pipelines
* Add `HF_TRUST_REMOTE_CODE` env var
* Fix `HF_TRUST_REMOTE_CODE` bool-handling via `strtobool`. The `strtobool` had to be defined within `huggingface_inference_toolkit` since it's deprecated and removed from `distutils` from Python 3.10 onwards.
* Fix some typos with `codespell`
* Update `README.md`
* Bump version to `0.4.2`
* Move `strtobool` to `env_utils` module to avoid circular import
* Revert enforce of `trust_remote_code=True`
* Remove `logging` messages for debug
* Fix `diffusers` propagation of `trust_remote_code=True`
1 parent 4e3877f commit d9ae3d9
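
In short, the new flag is read from the environment once, converted to a boolean with the vendored `strtobool`, and forwarded to the underlying pipelines as a regular `trust_remote_code` keyword argument. A minimal sketch of that flow; the `transformers.pipeline` call and the fallback model id below are illustrative, not part of this commit:

```python
import os

from huggingface_inference_toolkit.env_utils import strtobool  # added by this commit
from transformers import pipeline

# Defaults to "0" (False); operators opt in with HF_TRUST_REMOTE_CODE=1.
trust_remote_code = strtobool(os.environ.get("HF_TRUST_REMOTE_CODE", "0"))

# Illustrative consumer: the parsed boolean travels as an ordinary kwarg.
pipe = pipeline(
    task=os.environ.get("HF_TASK", "text-classification"),
    model=os.environ.get("HF_MODEL_ID", "hf-internal-testing/tiny-random-distilbert"),
    trust_remote_code=trust_remote_code,
)
```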

9 files changed: +114 −66 lines


README.md

+65-55
@@ -1,23 +1,22 @@
 
 <div style="display:flex; text-align:center; justify-content:center;">
-  <img src="https://huggingface.co/front/assets/huggingface_logo.svg" width="100"/>
+  <img src="https://huggingface.co/front/assets/huggingface_logo.svg" width="100"/>
   <h1 style="margin-top:auto;"> Hugging Face Inference Toolkit <h1>
 </div>
 
-
 Hugging Face Inference Toolkit is for serving 🤗 Transformers models in containers. This library provides default pre-processing, predict and postprocessing for Transformers, Sentence Tranfsformers. It is also possible to define custom `handler.py` for customization. The Toolkit is build to work with the [Hugging Face Hub](https://huggingface.co/models).
 
 ---
 
 ## 💻 Getting Started with Hugging Face Inference Toolkit
 
-* Clone the repository `git clone https://github.com/huggingface/huggingface-inference-toolkit``
-* Install the dependencies in dev mode `pip install -e ".[torch, st, diffusers, test,quality]"`
-* If you develop on AWS inferentia2 install with `pip install -e ".[test,quality]" optimum-neuron[neuronx] --upgrade`
+* Clone the repository `git clone <https://github.com/huggingface/huggingface-inference-toolkit``>
+* Install the dependencies in dev mode `pip install -e ".[torch,st,diffusers,test,quality]"`
+* If you develop on AWS inferentia2 install with `pip install -e ".[test,quality]" optimum-neuron[neuronx] --upgrade`
+* If you develop on Google Cloud install with `pip install -e ".[torch,st,diffusers,google,test,quality]"`
 * Unit Testing: `make unit-test`
 * Integration testing: `make integ-test`
 
-
 ### Local run
 
 ```bash
@@ -27,22 +26,22 @@ HF_MODEL_ID=hf-internal-testing/tiny-random-distilbert HF_MODEL_DIR=tmp2 HF_TASK
 
 ### Container
 
-
 1. build the preferred container for either CPU or GPU for PyTorch.
 
-_cpu images_
+_CPU Images_
+
 ```bash
 make inference-pytorch-cpu
 ```
 
-_gpu images_
+_GPU Images_
+
 ```bash
 make inference-pytorch-gpu
 ```
 
 2. Run the container and provide either environment variables to the HUB model you want to use or mount a volume to the container, where your model is stored.
 
-
 ```bash
 docker run -ti -p 5000:5000 -e HF_MODEL_ID=distilbert-base-uncased-distilled-squad -e HF_TASK=question-answering integration-test-pytorch:cpu
 docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=nlpconnect/vit-gpt2-image-captioning -e HF_TASK=image-to-text integration-test-pytorch:gpu
@@ -51,43 +50,44 @@ docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=stabilityai/stable-diffusi
 docker run -ti -p 5000:5000 -e HF_MODEL_DIR=/repository -v $(pwd)/distilbert-base-uncased-emotion:/repository integration-test-pytorch:cpu
 ```
 
-
 3. Send request. The API schema is the same as from the [inference API](https://huggingface.co/docs/api-inference/detailed_parameters)
 
 ```bash
 curl --request POST \
   --url http://localhost:5000 \
   --header 'Content-Type: application/json' \
   --data '{
-  "inputs": {
-    "question": "What is used for inference?",
-    "context": "My Name is Philipp and I live in Nuremberg. This model is used with sagemaker for inference."
-  }
+  "inputs": {
+    "question": "What is used for inference?",
+    "context": "My Name is Philipp and I live in Nuremberg. This model is used with sagemaker for inference."
+  }
 }'
 ```
 
 ### Custom Handler and dependency support
 
-The Hugging Face Inference Toolkit allows user to provide a custom inference through a `handler.py` file which is located in the repository.
-For an example check [https://huggingface.co/philschmid/custom-pipeline-text-classification](https://huggingface.co/philschmid/custom-pipeline-text-classification):
+The Hugging Face Inference Toolkit allows user to provide a custom inference through a `handler.py` file which is located in the repository.
+
+For an example check [philschmid/custom-pipeline-text-classification](https://huggingface.co/philschmid/custom-pipeline-text-classification):
+
 ```bash
 model.tar.gz/
 |- pytorch_model.bin
 |- ....
 |- handler.py
 |- requirements.txt
 ```
+
 In this example, `pytroch_model.bin` is the model file saved from training, `handler.py` is the custom inference handler, and `requirements.txt` is a requirements file to add additional dependencies.
 The custom module can override the following methods:
 
-
 ### Vertex AI Support
 
-The Hugging Face Inference Toolkit is also supported on Vertex AI, based on [Custom container requirements for prediction](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements). [Environment variables set by Vertex AI](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables) are automatically detected and used by the toolkit.
+The Hugging Face Inference Toolkit is also supported on Vertex AI, based on [Custom container requirements for prediction](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements). [Environment variables set by Vertex AI](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables) are automatically detected and used by the toolkit.
 
 #### Local run with HF_MODEL_ID and HF_TASK
 
-Start Hugging Face Inference Toolkit with the following environment variables.
+Start Hugging Face Inference Toolkit with the following environment variables.
 
 ```bash
 mkdir tmp2/
@@ -101,8 +101,8 @@ curl --request POST \
   --url http://localhost:8080/pred \
   --header 'Content-Type: application/json' \
   --data '{
-  "instances": ["I love this product", "I hate this product"],
-  "parameters": { "top_k": 2 }
+  "instances": ["I love this product", "I hate this product"],
+  "parameters": { "top_k": 2 }
 }'
 ```
 
@@ -124,35 +124,39 @@ docker run -ti -p 8080:8080 -e AIP_MODE=PREDICTION -e AIP_HTTP_PORT=8080 -e AIP_
 
 ```bash
 curl --request POST \
-  --url http://localhost:8080/pred \
-  --header 'Content-Type: application/json' \
-  --data '{
-  "instances": ["I love this product", "I hate this product"],
-  "parameters": { "top_k": 2 }
+  --url http://localhost:8080/pred \
+  --header 'Content-Type: application/json' \
+  --data '{
+  "instances": ["I love this product", "I hate this product"],
+  "parameters": { "top_k": 2 }
 }'
 ```
 
-### AWS Inferentia2 Support
+### AWS Inferentia2 Support
 
 The Hugging Face Inference Toolkit provides support for deploying Hugging Face on AWS Inferentia2. To deploy a model on Inferentia2 you have 3 options:
-* Provide `HF_MODEL_ID`, the model repo id on huggingface.co which contains the compiled model under `.neuron` format. e.g. `optimum/bge-base-en-v1.5-neuronx`
+
+* Provide `HF_MODEL_ID`, the model repo id on huggingface.co which contains the compiled model under `.neuron` format e.g. `optimum/bge-base-en-v1.5-neuronx`
 * Provide the `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH` environment variables to compile the model on the fly, e.g. `HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128`
 * Include `neuron` dictionary in the [config.json](https://huggingface.co/optimum/tiny_random_bert_neuron/blob/main/config.json) file in the model archive, e.g. `neuron: {"static_batch_size": 1, "static_sequence_length": 128}`
 
 The currently supported tasks can be found [here](https://huggingface.co/docs/optimum-neuron/en/package_reference/supported_models). If you plan to deploy an LLM, we recommend taking a look at [Neuronx TGI](https://huggingface.co/blog/text-generation-inference-on-inferentia2), which is purposly build for LLMs.
 
 #### Local run with HF_MODEL_ID and HF_TASK
 
-Start Hugging Face Inference Toolkit with the following environment variables.
+Start Hugging Face Inference Toolkit with the following environment variables.
 
 _Note: You need to run this on an Inferentia2 instance._
 
-- transformers `text-classification` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH`
+* transformers `text-classification` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH`
+
 ```bash
 mkdir tmp2/
 HF_MODEL_ID="distilbert/distilbert-base-uncased-finetuned-sst-2-english" HF_TASK="text-classification" HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128 HF_MODEL_DIR=tmp2 uvicorn src.huggingface_inference_toolkit.webservice_starlette:app --port 5000
 ```
-- sentence transformers `feature-extration` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH`
+
+* sentence transformers `feature-extration` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH`
+
 ```bash
 HF_MODEL_ID="sentence-transformers/all-MiniLM-L6-v2" HF_TASK="feature-extraction" HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128 HF_MODEL_DIR=tmp2 uvicorn src.huggingface_inference_toolkit.webservice_starlette:app --port 5000
 ```
@@ -161,16 +165,15 @@ Send request
 
 ```bash
 curl --request POST \
-  --url http://localhost:5000 \
-  --header 'Content-Type: application/json' \
-  --data '{
-  "inputs": "Wow, this is such a great product. I love it!"
+  --url http://localhost:5000 \
+  --header 'Content-Type: application/json' \
+  --data '{
+  "inputs": "Wow, this is such a great product. I love it!"
 }'
 ```
 
 #### Container run with HF_MODEL_ID and HF_TASK
 
-
 1. build the preferred container for either CPU or GPU for PyTorch o.
 
 ```bash
@@ -187,26 +190,25 @@ docker run -ti -p 5000:5000 -e HF_MODEL_ID="distilbert/distilbert-base-uncased-f
 
 ```bash
 curl --request POST \
-  --url http://localhost:5000 \
-  --header 'Content-Type: application/json' \
-  --data '{
-  "inputs": "Wow, this is such a great product. I love it!",
-  "parameters": { "top_k": 2 }
+  --url http://localhost:5000 \
+  --header 'Content-Type: application/json' \
+  --data '{
+  "inputs": "Wow, this is such a great product. I love it!",
+  "parameters": { "top_k": 2 }
 }'
 ```
 
-
 ---
 
 ## 🛠️ Environment variables
 
-The Hugging Face Inference Toolkit implements various additional environment variables to simplify your deployment experience. A full list of environment variables is given below. All potential environment varialbes can be found in [const.py](src/huggingface_inference_toolkit/const.py)
+The Hugging Face Inference Toolkit implements various additional environment variables to simplify your deployment experience. A full list of environment variables is given below. All potential environment variables can be found in [const.py](src/huggingface_inference_toolkit/const.py)
 
 ### `HF_MODEL_DIR`
 
-The `HF_MODEL_DIR` environment variable defines the directory where your model is stored or will be stored.
-If `HF_MODEL_ID` is not set the toolkit expects a the model artifact at this directory. This value should be set to the value where you mount your model artifacts.
-If `HF_MODEL_ID` is set the toolkit and the directory where `HF_MODEL_DIR` is pointing to is empty. The toolkit will download the model from the Hub to this directory.
+The `HF_MODEL_DIR` environment variable defines the directory where your model is stored or will be stored.
+If `HF_MODEL_ID` is not set the toolkit expects a the model artifact at this directory. This value should be set to the value where you mount your model artifacts.
+If `HF_MODEL_ID` is set the toolkit and the directory where `HF_MODEL_DIR` is pointing to is empty. The toolkit will download the model from the Hub to this directory.
 
 The default value is `/opt/huggingface/model`
 
@@ -246,6 +248,14 @@ The `HF_HUB_TOKEN` environment variable defines the your Hugging Face authorizat
 HF_HUB_TOKEN="api_XXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
 ```
 
+### `HF_TRUST_REMOTE_CODE`
+
+The `HF_TRUST_REMOTE_CODE` environment variable defines whether to trust remote code. This flag is already used for community defined inference code, and is therefore quite representative of the level of confidence you are giving the model providers when loading models from the Hugging Face Hub. The default value is `"0"`; set it to `"1"` to trust remote code.
+
+```bash
+HF_TRUST_REMOTE_CODE="0"
+```
+
 ### `HF_FRAMEWORK`
 
 The `HF_FRAMEWORK` environment variable defines the base deep learning framework used in the container. This is important when loading large models from the Hugguing Face Hub to avoid extra file downloads.
@@ -256,28 +266,28 @@ HF_FRAMEWORK="pytorch"
 
 #### `HF_OPTIMUM_BATCH_SIZE`
 
-The `HF_OPTIMUM_BATCH_SIZE` environment variable defines the batch size, which is used when compiling the model to Neuron. The default value is `1`. Not required when model is already converted.
+The `HF_OPTIMUM_BATCH_SIZE` environment variable defines the batch size, which is used when compiling the model to Neuron. The default value is `1`. Not required when model is already converted.
 
 ```bash
 HF_OPTIMUM_BATCH_SIZE="1"
 ```
 
 #### `HF_OPTIMUM_SEQUENCE_LENGTH`
 
-The `HF_OPTIMUM_SEQUENCE_LENGTH` environment variable defines the sequence length, which is used when compiling the model to Neuron. There is no default value. Not required when model is already converted.
+The `HF_OPTIMUM_SEQUENCE_LENGTH` environment variable defines the sequence length, which is used when compiling the model to Neuron. There is no default value. Not required when model is already converted.
 
 ```bash
 HF_OPTIMUM_SEQUENCE_LENGTH="128"
 ```
 
 ---
 
-## ⚙ Supported Frontend
+## ⚙ Supported Front-Ends
 
-- [x] Starlette (HF Endpoints)
-- [x] Starlette (Vertex AI)
-- [ ] Starlette (Azure ML)
-- [ ] Starlette (SageMaker)
+* [x] Starlette (HF Endpoints)
+* [x] Starlette (Vertex AI)
+* [ ] Starlette (Azure ML)
+* [ ] Starlette (SageMaker)
 
 ---
 
@@ -287,6 +297,6 @@ HF_OPTIMUM_SEQUENCE_LENGTH="128"
 
 ## 📜 License
 
-TBD.
+TBD.
 
 ---
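
The "Custom Handler and dependency support" section above stops at "The custom module can override the following methods:" without listing them. For orientation only, a custom `handler.py` is generally expected to expose an `EndpointHandler` class (that class name matches the `HF_MODULE_NAME` default visible in the `const.py` diff further down); the constructor argument name and the hooks sketched below are assumptions, not something this commit defines:

```python
# handler.py -- minimal sketch of a custom handler; treat the exact signatures
# the toolkit expects as an assumption, since they are not shown in this diff.
from typing import Any, Dict, List

from transformers import pipeline


class EndpointHandler:
    def __init__(self, model_dir: str) -> None:
        # Load everything once at startup; model_dir is the local path the
        # toolkit resolved (HF_MODEL_DIR, or the downloaded HF_MODEL_ID).
        self.pipeline = pipeline("text-classification", model=model_dir)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Same request schema as the curl examples above: "inputs" plus
        # optional "parameters".
        inputs = data["inputs"]
        parameters = data.get("parameters", {})
        return self.pipeline(inputs, **parameters)
```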

setup.py

+2-2
@@ -5,12 +5,12 @@
 # We don't declare our dependency on transformers here because we build with
 # different packages for different variants
 
-VERSION = "0.4.1.dev0"
+VERSION = "0.4.2"
 
 # Ubuntu packages
 # libsndfile1-dev: torchaudio requires the development version of the libsndfile package which can be installed via a system package manager. On Ubuntu it can be installed as follows: apt install libsndfile1-dev
 # ffmpeg: ffmpeg is required for audio processing. On Ubuntu it can be installed as follows: apt install ffmpeg
-# libavcodec-extra : libavcodec-extra inculdes additional codecs for ffmpeg
+# libavcodec-extra : libavcodec-extra includes additional codecs for ffmpeg
 
 install_requires = [
     "transformers[sklearn,sentencepiece,audio,vision]==4.41.1",
src/huggingface_inference_toolkit/const.py

+6-1

@@ -1,13 +1,18 @@
 import os
 from pathlib import Path
 
+from huggingface_inference_toolkit.env_utils import strtobool
+
 HF_MODEL_DIR = os.environ.get("HF_MODEL_DIR", "/opt/huggingface/model")
 HF_MODEL_ID = os.environ.get("HF_MODEL_ID", None)
 HF_TASK = os.environ.get("HF_TASK", None)
 HF_FRAMEWORK = os.environ.get("HF_FRAMEWORK", None)
 HF_REVISION = os.environ.get("HF_REVISION", None)
 HF_HUB_TOKEN = os.environ.get("HF_HUB_TOKEN", None)
+HF_TRUST_REMOTE_CODE = strtobool(os.environ.get("HF_TRUST_REMOTE_CODE", "0"))
 # custom handler consts
 HF_DEFAULT_PIPELINE_NAME = os.environ.get("HF_DEFAULT_PIPELINE_NAME", "handler.py")
 # default is pipeline.PreTrainedPipeline
-HF_MODULE_NAME = os.environ.get("HF_MODULE_NAME", f"{Path(HF_DEFAULT_PIPELINE_NAME).stem}.EndpointHandler")
+HF_MODULE_NAME = os.environ.get(
+    "HF_MODULE_NAME", f"{Path(HF_DEFAULT_PIPELINE_NAME).stem}.EndpointHandler"
+)
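
One practical note on the new constant above: it is evaluated once, at module import time, so `HF_TRUST_REMOTE_CODE` has to be set in the environment before the toolkit modules are imported. A small illustrative check:

```python
import os

# Export the flag before importing huggingface_inference_toolkit.const;
# afterwards the constant is a plain Python bool.
os.environ["HF_TRUST_REMOTE_CODE"] = "1"

from huggingface_inference_toolkit.env_utils import strtobool

assert strtobool(os.environ.get("HF_TRUST_REMOTE_CODE", "0")) is True
```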

src/huggingface_inference_toolkit/diffusers_utils.py

+6-3
@@ -1,4 +1,5 @@
 import importlib.util
+from typing import Union
 
 from transformers.utils.import_utils import is_torch_bf16_gpu_available
 
@@ -21,14 +22,16 @@ def is_diffusers_available():
 
 
 class IEAutoPipelineForText2Image:
-    def __init__(self, model_dir: str, device: str = None):  # needs "cuda" for GPU
+    def __init__(
+        self, model_dir: str, device: Union[str, None] = None, **kwargs
+    ):  # needs "cuda" for GPU
         dtype = torch.float32
         if device == "cuda":
             dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
         device_map = "auto" if device == "cuda" else None
 
         self.pipeline = AutoPipelineForText2Image.from_pretrained(
-            model_dir, torch_dtype=dtype, device_map=device_map
+            model_dir, torch_dtype=dtype, device_map=device_map, **kwargs
         )
         # try to use DPMSolverMultistepScheduler
         if isinstance(self.pipeline, StableDiffusionPipeline):
@@ -66,5 +69,5 @@ def __call__(
 def get_diffusers_pipeline(task=None, model_dir=None, device=-1, **kwargs):
     """Get a pipeline for Diffusers models."""
     device = "cuda" if device == 0 else "cpu"
-    pipeline = DIFFUSERS_TASKS[task](model_dir=model_dir, device=device)
+    pipeline = DIFFUSERS_TASKS[task](model_dir=model_dir, device=device, **kwargs)
     return pipeline
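
With `**kwargs` threaded through `get_diffusers_pipeline`, loader options such as `trust_remote_code` now reach `AutoPipelineForText2Image.from_pretrained`. A hypothetical call for illustration, assuming `"text-to-image"` is the task key registered in `DIFFUSERS_TASKS`:

```python
# Illustrative only: device=0 maps to "cuda" in the function above, and any
# extra keyword argument is forwarded to from_pretrained().
pipe = get_diffusers_pipeline(
    task="text-to-image",
    model_dir="/opt/huggingface/model",  # the toolkit's default HF_MODEL_DIR
    device=0,
    trust_remote_code=True,
)
```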
src/huggingface_inference_toolkit/env_utils.py

+22-0

@@ -0,0 +1,22 @@
+def strtobool(val: str) -> bool:
+    """Convert a string representation of truth to True or False booleans.
+    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
+    are 'n', 'no', 'f', 'false', 'off', and '0'.
+
+    Raises:
+        ValueError: if 'val' is anything else.
+
+    Note:
+        Function `strtobool` copied and adapted from `distutils`, as it's deprecated from Python 3.10 onwards.
+
+    References:
+        - https://github.com/python/cpython/blob/48f9d3e3faec5faaa4f7c9849fecd27eae4da213/Lib/distutils/util.py#L308-L321
+    """
+    val = val.lower()
+    if val in ("y", "yes", "t", "true", "on", "1"):
+        return True
+    if val in ("n", "no", "f", "false", "off", "0"):
+        return False
+    raise ValueError(
+        f"Invalid truth value, it should be a string but {val} was provided instead."
+    )
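
A quick usage check of the helper above; any value outside the two accepted lists raises `ValueError`:

```python
from huggingface_inference_toolkit.env_utils import strtobool

assert strtobool("1") is True     # also "y", "yes", "t", "true", "on"
assert strtobool("OFF") is False  # matching is case-insensitive
try:
    strtobool("maybe")
except ValueError:
    print("unrecognized truth value")  # anything else raises
```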
