Commit 9390c55

wip vertex ai

1 parent 5674947 commit 9390c55

8 files changed: +211 -76 lines

README.md

+28 -1

@@ -20,7 +20,7 @@ HF_MODEL_ID=hf-internal-testing/tiny-random-distilbert HF_MODEL_DIR=tmp2 HF_TASK
 ### Container
 
 
-1. build the preferred container for either CPU or GPU for PyTorch or TensorFlow.
+1. build the preferred container for either CPU or GPU for PyTorch.
 
 _cpu images_
 ```bash
@@ -58,6 +58,32 @@ curl --request POST \
 }'
 ```
 
+### Vertex AI Support
+
+The Hugging Face Inference Toolkit is also supported on Vertex AI, based on [Custom container requirements for prediction](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements). [Environment variables set by Vertex AI](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables) are automatically detected and used by the toolkit.
+
+#### Local run with HF_MODEL_ID and HF_TASK
+
+Start the Hugging Face Inference Toolkit with the following environment variables.
+
+```bash
+mkdir tmp2/
+AIP_MODE=PREDICTION AIP_PORT=8080 AIP_PREDICT_ROUTE=/pred AIP_HEALTH_ROUTE=/h HF_MODEL_DIR=tmp2 HF_MODEL_ID=distilbert/distilbert-base-uncased-finetuned-sst-2-english HF_TASK=text-classification uvicorn src.huggingface_inference_toolkit.webservice_starlette:app --port 8080
+```
+
+Send a request. The API schema is the same as for the [inference API](https://huggingface.co/docs/api-inference/detailed_parameters).
+
+```bash
+curl --request POST \
+  --url http://localhost:8080/pred \
+  --header 'Content-Type: application/json' \
+  --data '{
+    "instances": ["I love this product", "I hate this product"],
+    "parameters": { "top_k": 2 }
+}'
+```
+
 
 ---
 
@@ -176,6 +202,7 @@ Below you'll find a list of supported and tested transformers and sentence trans
 ## ⚙ Supported Frontend
 
 - [x] Starlette (HF Endpoints)
+- [ ] Starlette (Vertex AI)
 - [ ] Starlette (Azure ML)
 - [ ] Starlette (SageMaker)
 
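For reference, the same Vertex AI style request from Python instead of curl — a minimal sketch that assumes the local server from the README snippet above is listening on port 8080, and uses the external `requests` package, which is not a dependency of this repository:

```python
import requests

# Vertex AI schema: one prediction per entry in "instances";
# "parameters" are shared across all instances.
response = requests.post(
    "http://localhost:8080/pred",
    json={
        "instances": ["I love this product", "I hate this product"],
        "parameters": {"top_k": 2},
    },
)
response.raise_for_status()
# the toolkit wraps the per-instance outputs in a "predictions" list
print(response.json()["predictions"])
```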

+48

@@ -0,0 +1,48 @@
+ARG BASE_IMAGE=nvidia/cuda:12.1.0-devel-ubuntu22.04
+
+FROM $BASE_IMAGE
+SHELL ["/bin/bash", "-c"]
+
+LABEL maintainer="Hugging Face"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+WORKDIR /app
+
+RUN apt-get update && \
+    apt-get install software-properties-common -y && \
+    add-apt-repository ppa:deadsnakes/ppa && \
+    apt-get -y upgrade --only-upgrade systemd openssl cryptsetup && \
+    apt-get install -y \
+    build-essential \
+    bzip2 \
+    curl \
+    git \
+    git-lfs \
+    tar \
+    gcc \
+    g++ \
+    cmake \
+    libprotobuf-dev \
+    protobuf-compiler \
+    python3-dev \
+    python3-pip \
+    python3.11 \
+    libsndfile1-dev \
+    ffmpeg \
+    && apt-get clean autoremove --yes \
+    && rm -rf /var/lib/{apt,dpkg,cache,log}
+# Copying only necessary files as filtered by .dockerignore
+COPY . .
+
+# upgrade pip and install the toolkit with its torch, st, and diffusers extras
+RUN pip install --no-cache-dir -U pip ".[torch, st, diffusers]"
+
+# copy application
+COPY src/huggingface_inference_toolkit huggingface_inference_toolkit
+COPY src/huggingface_inference_toolkit/webservice_starlette.py webservice_starlette.py
+
+# copy entrypoint and change permissions
+COPY --chmod=0755 scripts/entrypoint.sh entrypoint.sh
+
+ENTRYPOINT ["bash", "-c", "./entrypoint.sh"]

requirements.txt

Whitespace-only changes.

scripts/entrypoint.sh

+13 -5

@@ -1,13 +1,21 @@
-# /bin/bash
+#!/bin/bash
 
-# check if HF_MODEL_DIR is set and if not skip installing custom dependencies
+# Define the default port
+PORT=5000
+
+# Check if AIP_MODE is set and adjust the port for Vertex AI
+if [[ ! -z "${AIP_MODE}" ]]; then
+    PORT=${AIP_HTTP_PORT}
+fi
+
+# Check if HF_MODEL_DIR is set and if not skip installing custom dependencies
 if [[ ! -z "${HF_MODEL_DIR}" ]]; then
-    # check if requirements.txt exists and if so install dependencies
+    # Check if requirements.txt exists and if so install dependencies
     if [ -f "${HF_MODEL_DIR}/requirements.txt" ]; then
         echo "Installing custom dependencies from ${HF_MODEL_DIR}/requirements.txt"
         pip install -r ${HF_MODEL_DIR}/requirements.txt --no-cache-dir;
     fi
 fi
 
-# start the server
-uvicorn webservice_starlette:app --host 0.0.0.0 --port 5000
+# Start the server
+uvicorn webservice_starlette:app --host 0.0.0.0 --port ${PORT}

src/huggingface_inference_toolkit/handler.py

+39 -1

@@ -1,4 +1,5 @@
 import logging
+import os
 from pathlib import Path
 from typing import Optional, Union
 
@@ -40,15 +41,52 @@ def __call__(self, data):
         return prediction
 
 
+class VertexAIHandler(HuggingFaceHandler):
+    """
+    A default Vertex AI Hugging Face Inference Handler which abstracts the
+    Vertex AI specific logic for inference.
+    """
+    def __init__(self, model_dir: Union[str, Path], task=None, framework="pt"):
+        super().__init__(model_dir, task, framework)
+
+    def __call__(self, data):
+        """
+        Handles an inference request with input data and makes a prediction.
+        Args:
+            :data: (obj): the raw request body data.
+        :return: prediction output
+        """
+        if "instances" not in data:
+            raise ValueError("The request body must contain a key 'instances' with a list of instances.")
+        parameters = data.pop("parameters", None)
+
+        predictions = []
+        # iterate over all instances and make predictions
+        for inputs in data["instances"]:
+            payload = {"inputs": inputs, "parameters": parameters}
+            predictions.append(super().__call__(payload))
+
+        # return predictions
+        return {"predictions": predictions}
+
 def get_inference_handler_either_custom_or_default_handler(
     model_dir: Path,
     task: Optional[str] = None
 ):
     """
-    get inference handler either custom or default Handler
+    Returns the appropriate inference handler based on the given model directory and task.
+
+    Args:
+        model_dir (Path): The directory path where the model is stored.
+        task (Optional[str]): The task for which the inference handler is required. Defaults to None.
+
+    Returns:
+        InferenceHandler: The appropriate inference handler based on the given model directory and task.
     """
     custom_pipeline = check_and_register_custom_pipeline_from_directory(model_dir)
     if custom_pipeline:
         return custom_pipeline
+    elif os.environ.get("AIP_MODE", None) == "PREDICTION":
+        return VertexAIHandler(model_dir=model_dir, task=task)
     else:
         return HuggingFaceHandler(model_dir=model_dir, task=task)
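To make the new request flow concrete, here is a hypothetical direct invocation of `VertexAIHandler` — a sketch that assumes a model has already been downloaded to `tmp2/` (for example via the README snippet above); the directory and task are example values, not toolkit defaults:

```python
from huggingface_inference_toolkit.handler import VertexAIHandler

# example model directory and task, assumed to exist locally
handler = VertexAIHandler(model_dir="tmp2", task="text-classification")

# each entry in "instances" becomes one {"inputs": ..., "parameters": ...}
# payload for the underlying HuggingFaceHandler
body = {
    "instances": ["I love this product", "I hate this product"],
    "parameters": {"top_k": 2},
}
result = handler(body)
print(result["predictions"])  # one prediction per instance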
src/huggingface_inference_toolkit/vertex_ai_utils.py

+46

@@ -0,0 +1,46 @@
+import logging
+from pathlib import Path
+import re
+from typing import Union
+
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", level=logging.INFO)
+
+from google.cloud import storage
+
+_logger = logging.getLogger(__name__)
+
+
+GCS_URI_PREFIX = "gs://"
+
+
+# copied from https://github.com/googleapis/python-aiplatform/blob/94d838d8cfe1599bc2d706e66080c05108821986/google/cloud/aiplatform/utils/prediction_utils.py#L121
+def _load_repository_from_gcs(artifact_uri: str, target_dir: Union[str, Path] = "/tmp"):
+    """
+    Load files from a GCS path into target_dir.
+    """
+    _logger.info(f"Loading model artifacts from {artifact_uri} to {target_dir}")
+    target_dir = Path(target_dir)
+
+    if artifact_uri.startswith(GCS_URI_PREFIX):
+        matches = re.match(f"{GCS_URI_PREFIX}(.*?)/(.*)", artifact_uri)
+        bucket_name, prefix = matches.groups()
+
+        gcs_client = storage.Client()
+        blobs = gcs_client.list_blobs(bucket_name, prefix=prefix)
+        for blob in blobs:
+            name_without_prefix = blob.name[len(prefix) :]
+            name_without_prefix = (
+                name_without_prefix[1:]
+                if name_without_prefix.startswith("/")
+                else name_without_prefix
+            )
+            file_split = name_without_prefix.split("/")
+            # Path.joinpath builds the destination directory (Path has no list-taking .join)
+            directory = target_dir.joinpath(*file_split[0:-1])
+            directory.mkdir(parents=True, exist_ok=True)
+            if name_without_prefix and not name_without_prefix.endswith("/"):
+                # download into target_dir rather than the current working directory
+                blob.download_to_filename(str(target_dir / name_without_prefix))
+
+    return str(target_dir.absolute())
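Usage of the helper is straightforward; a sketch with a placeholder bucket URI (authenticated `google-cloud-storage` credentials are assumed):

```python
from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs

# gs://my-example-bucket/... is a made-up placeholder, not a real bucket
model_dir = _load_repository_from_gcs(
    "gs://my-example-bucket/models/distilbert-sst2",
    target_dir="/tmp/model",
)
print(model_dir)  # absolute path containing the downloaded artifacts
```

On Vertex AI the URI comes from the `AIP_STORAGE_URI` environment variable, as wired up in webservice_starlette.py below.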

src/huggingface_inference_toolkit/webservice_robyn.py

-57
This file was deleted.

src/huggingface_inference_toolkit/webservice_starlette.py

+37 -12

@@ -1,4 +1,5 @@
 import logging
+import os
 from pathlib import Path
 from time import perf_counter
 
@@ -20,6 +21,7 @@
 from huggingface_inference_toolkit.serialization.base import ContentType
 from huggingface_inference_toolkit.serialization.json_utils import Jsoner
 from huggingface_inference_toolkit.utils import _load_repository_from_hf, convert_params_to_int_or_bool
+from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs
 
 
 def config_logging(level=logging.INFO):
@@ -35,10 +37,11 @@ def config_logging(level=logging.INFO):
 logger = logging.getLogger(__name__)
 
 
-async def some_startup_task():
+async def prepare_model_artifacts():
     global inference_handler
     # 1. check if model artifacts available in HF_MODEL_DIR
     if len(list(Path(HF_MODEL_DIR).glob("**/*"))) <= 0:
+        # 2. if not available, try to load from HF_MODEL_ID
         if HF_MODEL_ID is not None:
             _load_repository_from_hf(
                 repository_id=HF_MODEL_ID,
@@ -47,6 +50,11 @@ async def some_startup_task():
                 revision=HF_REVISION,
                 hf_hub_token=HF_HUB_TOKEN,
             )
+        # 3. check if in Vertex AI environment and load from GCS
+        # if artifactUri was not set at model creation, AIP_STORAGE_URI is an empty string
+        elif len(os.environ.get("AIP_STORAGE_URI", "")) > 0:
+            _load_repository_from_gcs(os.environ["AIP_STORAGE_URI"], target_dir=HF_MODEL_DIR)
+        # 4. if not available, raise error
         else:
             raise ValueError(
                 f"""Can't initialize model.
@@ -72,7 +80,7 @@ async def predict(request):
     # try to deserialize payload
     deserialized_body = ContentType.get_deserializer(content_type).deserialize(await request.body())
     # checks if input schema is correct
-    if "inputs" not in deserialized_body:
+    if "inputs" not in deserialized_body and "instances" not in deserialized_body:
         raise ValueError(f"Body needs to provide an inputs key, received: {orjson.dumps(deserialized_body)}")
 
     # check for query parameter and add them to the body
@@ -97,14 +105,31 @@ async def predict(request):
         logger.error(e)
         return Response(Jsoner.serialize({"error": str(e)}), status_code=400, media_type="application/json")
 
-
-app = Starlette(
-    debug=True,
-    routes=[
-        Route("/", health, methods=["GET"]),
-        Route("/health", health, methods=["GET"]),
-        Route("/", predict, methods=["POST"]),
-        Route("/predict", predict, methods=["POST"]),
-    ],
-    on_startup=[some_startup_task],
+# Create app based on which cloud environment is used
+if os.getenv("AIP_MODE", None) == "PREDICTION":
+    logger.info("Running in Vertex AI environment")
+    # extract routes from environment variables
+    _predict_route = os.getenv("AIP_PREDICT_ROUTE", None)
+    _health_route = os.getenv("AIP_HEALTH_ROUTE", None)
+    if _predict_route is None or _health_route is None:
+        raise ValueError("AIP_PREDICT_ROUTE and AIP_HEALTH_ROUTE need to be set in Vertex AI environment")
+
+    app = Starlette(
+        debug=False,
+        routes=[
+            Route(_health_route, health, methods=["GET"]),
+            Route(_predict_route, predict, methods=["POST"]),
+        ],
+        on_startup=[prepare_model_artifacts],
+    )
+else:
+    app = Starlette(
+        debug=False,
+        routes=[
+            Route("/", health, methods=["GET"]),
+            Route("/health", health, methods=["GET"]),
+            Route("/", predict, methods=["POST"]),
+            Route("/predict", predict, methods=["POST"]),
+        ],
+        on_startup=[prepare_model_artifacts],
 )
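A rough way to exercise the Vertex AI branch locally is Starlette's TestClient. This is a sketch only: it assumes the package is importable as `huggingface_inference_toolkit`, that the `AIP_*` variables are exported before the module import (the app is constructed at import time), and that `httpx` is installed for the test client; the startup hook downloads the model, so the first run is slow:

```python
import os

# must be set before the import below, since the app is built at import time
os.environ["AIP_MODE"] = "PREDICTION"
os.environ["AIP_PREDICT_ROUTE"] = "/pred"
os.environ["AIP_HEALTH_ROUTE"] = "/h"
os.environ["HF_MODEL_DIR"] = "tmp2"  # example values, as in the README
os.environ["HF_MODEL_ID"] = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
os.environ["HF_TASK"] = "text-classification"

from starlette.testclient import TestClient
from huggingface_inference_toolkit.webservice_starlette import app

# entering the context manager runs the prepare_model_artifacts startup hook
with TestClient(app) as client:
    assert client.get("/h").status_code == 200
    response = client.post("/pred", json={"instances": ["I love this product"]})
    print(response.json())
```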
