diff --git a/.github/workflows/build-container.yaml b/.github/workflows/build-container.yaml
index fe12fbf6..3ad8ad5f 100644
--- a/.github/workflows/build-container.yaml
+++ b/.github/workflows/build-container.yaml
@@ -34,6 +34,16 @@ jobs:
       TAILSCALE_AUTHKEY: ${{ secrets.TAILSCALE_AUTHKEY }}
       REGISTRY_USERNAME: ${{ secrets.REGISTRY_USERNAME }}
       REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }}
+  starlette-pytorch-neuron:
+    uses: ./.github/workflows/docker-build-action.yaml
+    with:
+      image: inference-pytorch-neuron
+      dockerfile: dockerfiles/pytorch/Dockerfile
+      build_args: "BASE_IMAGE=ubuntu:22.04,NEURONX=1"
+    secrets:
+      TAILSCALE_AUTHKEY: ${{ secrets.TAILSCALE_AUTHKEY }}
+      REGISTRY_USERNAME: ${{ secrets.REGISTRY_USERNAME }}
+      REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }}
   starlette-tensorflow-cpu:
     uses: ./.github/workflows/docker-build-action.yaml
     with:
diff --git a/.github/workflows/docker-build-action.yaml b/.github/workflows/docker-build-action.yaml
index 62cba961..37889934 100644
--- a/.github/workflows/docker-build-action.yaml
+++ b/.github/workflows/docker-build-action.yaml
@@ -64,8 +64,8 @@ jobs:
           context: ${{ inputs.context }}
           build-args: ${{ inputs.build_args }}
           file: ${{ inputs.context }}/${{ inputs.dockerfile }}
-          tags: ${{ inputs.repository }}/${{ inputs.image }}:sha-${{ env.GITHUB_SHA_SHORT }},${{ inputs.repository }}/${{ inputs.image }}:latest
-
+          # tags: ${{ inputs.repository }}/${{ inputs.image }}:sha-${{ env.GITHUB_SHA_SHORT }},${{ inputs.repository }}/${{ inputs.image }}:latest
+          tags: ${{ inputs.repository }}/${{ inputs.image }}:testraph
       - name: Tailscale Wait
         if: ${{ failure() || runner.debug == '1' }}
         uses: huggingface/tailscale-action@v1
diff --git a/dockerfiles/pytorch/Dockerfile b/dockerfiles/pytorch/Dockerfile
index 8e4c4d35..d01f0df2 100644
--- a/dockerfiles/pytorch/Dockerfile
+++ b/dockerfiles/pytorch/Dockerfile
@@ -1,6 +1,9 @@
 ARG BASE_IMAGE=nvidia/cuda:12.1.0-devel-ubuntu22.04
 FROM $BASE_IMAGE
+
+ARG NEURONX=0
+
 SHELL ["/bin/bash", "-c"]
 
 LABEL maintainer="Hugging Face"
@@ -31,12 +34,12 @@ RUN apt-get update && \
     libsndfile1-dev \
     ffmpeg \
     && apt-get clean autoremove --yes \
-    && rm -rf /var/lib/{apt,dpkg,cache,log}
+    && rm -rf /var/lib/{apt,cache,log}
+
 # Copying only necessary files as filtered by .dockerignore
 COPY . .
 
-# install wheel and setuptools
-RUN pip install --no-cache-dir -U pip ".[torch, st, diffusers]"
+RUN if [[ "$NEURONX" == "1" ]]; then /bin/bash -c "./dockerfiles/pytorch/neuronx.sh"; else pip install --no-cache-dir -U pip ".[torch, st, diffusers]"; fi
 
 # copy application
 COPY src/huggingface_inference_toolkit huggingface_inference_toolkit
@@ -45,4 +48,4 @@ COPY src/huggingface_inference_toolkit/webservice_starlette.py webservice_starle
 # copy entrypoint and change permissions
 COPY --chmod=0755 scripts/entrypoint.sh entrypoint.sh
 
-ENTRYPOINT ["bash", "-c", "./entrypoint.sh"]
\ No newline at end of file
+ENTRYPOINT ["bash", "-c", "./entrypoint.sh"]
diff --git a/dockerfiles/pytorch/neuronx.sh b/dockerfiles/pytorch/neuronx.sh
new file mode 100755
index 00000000..9b8f35f4
--- /dev/null
+++ b/dockerfiles/pytorch/neuronx.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+set -e
+
+# Install system prerequisites
+apt-get update -y \
+    && apt-get install -y --no-install-recommends \
+    gnupg2 \
+    wget
+
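+# Register the AWS Neuron apt repository for this Ubuntu release, following
+# the standard setup from the AWS Neuron documentation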
+. /etc/os-release
+tee /etc/apt/sources.list.d/neuron.list > /dev/null <<EOF
+deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main
+EOF
+wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB \
+    | apt-key add -
+
+# Install the Neuron tools
+apt-get update -y \
+    && apt-get install -y --no-install-recommends \
+    aws-neuronx-tools \
+    && apt-get clean autoremove --yes
+
+# Install the toolkit with the torch-neuronx extra, pulling the Neuron wheels
+# from the AWS pip repository
+pip install --no-cache-dir -U pip
+pip install --no-cache-dir \
+    --extra-index-url=https://pip.repos.neuron.amazonaws.com \
+    ".[torch-neuronx, st, diffusers]" \
+    optimum-neuron
+
+# Expose the Neuron tools on the PATH
+echo "export PATH=/opt/aws/neuron/bin:\$PATH" >> /root/.bashrc
diff --git a/makefile b/makefile
index a9490428..d6c8a53e 100644
--- a/makefile
+++ b/makefile
@@ -26,5 +26,8 @@ inference-pytorch-gpu:
 inference-pytorch-cpu:
 	docker build --build-arg="BASE_IMAGE=ubuntu:22.04" -f dockerfiles/pytorch/Dockerfile -t integration-test-pytorch:cpu .
 
+inference-pytorch-neuron:
+	docker build --build-arg=BASE_IMAGE=ubuntu:22.04 --build-arg=NEURONX=1 -f dockerfiles/pytorch/Dockerfile -t integration-test-pytorch:neuron .
+
 stop-all:
-	docker stop $$(docker ps -a -q) && docker container prune --force
\ No newline at end of file
+	docker stop $$(docker ps -a -q) && docker container prune --force
diff --git a/setup.py b/setup.py
index 5e99df02..98543268 100644
--- a/setup.py
+++ b/setup.py
@@ -17,8 +17,8 @@
     "wheel==0.42.0",
     "setuptools==69.1.0",
     "cmake==3.28.3",
-    "transformers[sklearn,sentencepiece, audio, vision]==4.38.2",
-    "huggingface_hub==0.20.3",
+    "transformers[sklearn,sentencepiece, audio, vision]>=4.38.2",
+    "huggingface_hub==0.23.0",
     "orjson",
     # vision
     "Pillow",
@@ -39,6 +39,8 @@
 extras["st"] = ["sentence_transformers==2.4.0"]
 extras["diffusers"] = ["diffusers==0.26.3", "accelerate==0.27.2"]
 extras["torch"] = ["torch==2.2.0", "torchvision", "torchaudio"]
+# For neuronx
+extras["torch-neuronx"] = ["torch-neuronx", "torchvision", "torchaudio"]
 extras["tensorflow"] = ["tensorflow"]
 extras["test"] = [
     "pytest==7.2.1",
diff --git a/src/huggingface_inference_toolkit/diffusers_utils.py b/src/huggingface_inference_toolkit/diffusers_utils.py
index 521a85df..21795e7a 100644
--- a/src/huggingface_inference_toolkit/diffusers_utils.py
+++ b/src/huggingface_inference_toolkit/diffusers_utils.py
@@ -1,5 +1,7 @@
 import importlib.util
+import json
 import logging
+import os
 
 from transformers.utils.import_utils import is_torch_bf16_gpu_available
 
@@ -7,6 +9,11 @@
 logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", level=logging.INFO)
 
 _diffusers = importlib.util.find_spec("diffusers") is not None
+_optimum = importlib.util.find_spec("optimum") is not None
+if _optimum:
+    _optimum_neuron = importlib.util.find_spec("optimum.neuron") is not None
+else:
+    _optimum_neuron = False
 
 
 def is_diffusers_available():
@@ -18,6 +25,10 @@ def is_diffusers_available():
     from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionPipeline
 
 
+if _optimum_neuron:
+    from optimum import neuron
+
+
 class IEAutoPipelineForText2Image:
     def __init__(self, model_dir: str, device: str = None):  # needs "cuda" for GPU
         dtype = torch.float32
@@ -55,8 +66,95 @@ def __call__(
         }
 
 
-def get_diffusers_pipeline(task=None, model_dir=None, device=-1, **kwargs):
+def _is_neuron_model(model_dir):
+    for root, _, files in os.walk(model_dir):
+        for f in files:
+            if f == "config.json":
+                filename = os.path.join(root, f)
+                with open(filename, 'r') as fh:
+                    try:
+                        config = json.load(fh)
+                    except Exception as e:
+                        logger.warning("Unable to load config %s properly, skipping", filename)
+                        logger.exception(e)
+                        continue
+                if 'neuron' in config.keys():
+                    return True
+    return False
+
+
+def neuron_diffusion_pipeline(task: str, model_dir: str):
+    # Step 1: load config and look for _class_name
+    try:
+        config = StableDiffusionPipeline.load_config(pretrained_model_name_or_path=model_dir)
+    except OSError as e:
+        logger.error("Unable to load config file for repository %s", model_dir)
+        logger.exception(e)
+        raise
+
+    pipeline_class_name = config['_class_name']
+    logger.debug("Repository pipeline class name %s", pipeline_class_name)
+
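+    # Pick the Neuron pipeline class matching the class name advertised in the
+    # repository's model_index.json (SD vs. SDXL, text-to-image vs. image-to-image)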
name %s", pipeline_class_name) + if "Diffusion" in pipeline_class_name and "XL" in pipeline_class_name: + if task == "image-to-image": + pipeline_class = neuron.NeuronStableDiffusionXLImg2ImgPipeline + else: + pipeline_class = neuron.NeuronStableDiffusionXLPipeline + else: + if task == "image-to-image": + pipeline_class = neuron.NeuronStableDiffusionImg2ImgPipeline + else: + pipeline_class = neuron.NeuronStableDiffusionPipeline + + logger.debug("Pipeline class %s", pipeline_class.__class__) + + compiler_args = { + "auto_cast": "matmul", + "auto_cast_type": "bf16", + "inline_weights_to_neff": os.environ.get("INLINE_WEIGHTS_TO_NEFF", + "false").lower() in ["false", "no", "0"], + "data_parallel_mode": os.environ.get("DATA_PARALLEL_MODE", "unet") + } + input_shapes = {"batch_size": 1, + "height": int(os.environ.get("IMAGE_HEIGHT", 512)), + "width": int(os.environ.get("IMAGE_WIDTH", 512))} + export_kwargs = {**compiler_args, **input_shapes, "export": True} + + # if is neuron model, no need for additional kwargs, any info lies within the repo + is_neuron_m = _is_neuron_model(model_dir) + if is_neuron_m: + kwargs = {} + fallback_kwargs = export_kwargs + else: + kwargs = export_kwargs + fallback_kwargs = {} + + # In the second case, exporting can take a huge amount of time, which makes endpoints not a really suited solution + # at least as long as the cache is not really an option for diffusion + try: + logger.info("Loading model %s with kwargs %s", model_dir, kwargs) + return pipeline_class.from_pretrained(model_dir, **kwargs) + except Exception as e: + logger.error("Unable to load model %s properly falling back to kwargs %s", model_dir, fallback_kwargs) + logger.exception(e) + return pipeline_class.from_pretrained(model_dir, **fallback_kwargs) + + +def get_diffusers_pipeline(task=None, model_dir=None, device=-1, **_kwargs): """Get a pipeline for Diffusers models.""" - device = "cuda" if device == 0 else "cpu" - pipeline = DIFFUSERS_TASKS[task](model_dir=model_dir, device=device) + if device == 0: + device = "cuda" + elif device is not None: + device = "cpu" + # None case: neuronx, no need to specify device + + if device is not None: + pipeline = DIFFUSERS_TASKS[task](model_dir=model_dir, device=device) + else: + pipeline = neuron_diffusion_pipeline(task=task, model_dir=model_dir) return pipeline diff --git a/src/huggingface_inference_toolkit/utils.py b/src/huggingface_inference_toolkit/utils.py index 1570317b..65e3a6b4 100644 --- a/src/huggingface_inference_toolkit/utils.py +++ b/src/huggingface_inference_toolkit/utils.py @@ -30,6 +30,14 @@ import torch _optimum_available = importlib.util.find_spec("optimum") is not None +if _optimum_available: + _optimum_neuron = importlib.util.find_spec("optimum.neuron") is not None + from optimum.neuron.modeling_decoder import get_available_cores as get_neuron_cores +else: + _optimum_neuron = False + + def get_neuron_cores(): + return 0 def is_optimum_available(): @@ -38,6 +46,10 @@ def is_optimum_available(): # return _optimum_available +def is_optimum_neuron_available(): + return _optimum_neuron + + framework2weight = { "pytorch": "pytorch*", "tensorflow": "tf*", @@ -215,6 +227,8 @@ def get_device(): if gpu: return 0 + elif get_neuron_cores() > 0: + return None else: return -1 @@ -229,7 +243,10 @@ def get_pipeline( create pipeline class for a specific task based on local saved model """ device = get_device() - logger.info(f"Using device { 'GPU' if device == 0 else 'CPU'}") + logger.info(f"Using device { 'GPU' if device == 0 else 'Neuron' if device is None 
else 'CPU'}") + + if device is None and task != "text-to-image": + raise Exception("This container only supports text-to-image task with neurons") if task is None: raise EnvironmentError( diff --git a/src/huggingface_inference_toolkit/webservice_starlette.py b/src/huggingface_inference_toolkit/webservice_starlette.py index 8bc68b2e..e06ab447 100644 --- a/src/huggingface_inference_toolkit/webservice_starlette.py +++ b/src/huggingface_inference_toolkit/webservice_starlette.py @@ -1,4 +1,5 @@ import logging +import os from pathlib import Path from time import perf_counter @@ -23,7 +24,7 @@ def config_logging(level=logging.INFO): - logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", datefmt="", level=level) + logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", datefmt="", level=level, force=True) # disable uvicorn access logs to hide /health uvicorn_access = logging.getLogger("uvicorn.access") uvicorn_access.disabled = True @@ -31,7 +32,7 @@ def config_logging(level=logging.INFO): logging.getLogger("uvicorn").removeHandler(logging.getLogger("uvicorn").handlers[0]) -config_logging() +config_logging(os.environ.get("LOG_LEVEL", logging.getLevelName(logging.INFO))) logger = logging.getLogger(__name__) @@ -50,7 +51,7 @@ async def some_startup_task(): else: raise ValueError( f"""Can't initialize model. - Please set env HF_MODEL_DIR or provider a HF_MODEL_ID. + Please set env HF_MODEL_DIR or provide a HF_MODEL_ID. Provided values are: HF_MODEL_DIR: {HF_MODEL_DIR} and HF_MODEL_ID:{HF_MODEL_ID}""" )