
Commit

Merge branch 'main' into ci_amd2
Narsil authored Jun 7, 2024
2 parents c8128c7 + bf3c813 commit c73355b
Showing 15 changed files with 102 additions and 93 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/build.yaml
@@ -152,8 +152,7 @@ jobs:
group: ${{ github.workflow }}-${{ github.job }}-rocm-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on: [amd-gpu-tgi, multi-gpu, mi250]
needs:
- build-and-push-image
needs: build-and-push-image
steps:
- name: Checkout repository
uses: actions/checkout@v4
10 changes: 7 additions & 3 deletions Dockerfile_intel
@@ -48,7 +48,7 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list

RUN apt-get update && apt install -y intel-basekit xpu-smi
RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build

# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
@@ -57,8 +57,8 @@ ENV HUGGINGFACE_HUB_CACHE=/data \


WORKDIR /usr/src
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
RUN pip install intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope

# Install server
COPY proto proto
@@ -76,6 +76,10 @@ ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/l
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV CCL_ZE_IPC_EXCHANGE=sockets
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include

RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch

# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
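
Note: the Dockerfile_intel changes above stop installing the prebuilt intel_extension_for_pytorch wheel and instead clone IPEX, check out the dev/gqa_rope branch, and build it from source with USE_AOT_DEVLIST='pvc' and XeTLA enabled (hence the extra cmake, python3-dev, and ninja-build packages and the new oneAPI CMAKE_PREFIX_PATH/CPATH variables). A minimal sanity check one might run inside the resulting image, assuming the XPU build succeeded, is sketched below; it is not part of this commit.

# Hypothetical sanity check (not part of this commit): confirm that the
# source-built IPEX XPU extension loads and can see an Intel GPU.
import torch
import intel_extension_for_pytorch as ipex  # registers the torch.xpu backend

print(ipex.__version__)
print(torch.xpu.is_available())   # True when a device such as PVC is visible
print(torch.xpu.device_count())   # number of XPU devices exposed to the container
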
1 change: 1 addition & 0 deletions server/tests/models/test_bloom.py
@@ -29,6 +29,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
1 change: 1 addition & 0 deletions server/tests/models/test_causal_lm.py
@@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
8 changes: 8 additions & 0 deletions server/tests/models/test_santacoder.py
@@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="def",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="def")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
@@ -32,6 +33,13 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
input_chunks=generate_pb2.Input(
chunks=[
generate_pb2.InputChunk(
text="<fim-prefix>def<fim-suffix>world<fim-middle>"
)
]
),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
1 change: 1 addition & 0 deletions server/tests/models/test_seq2seq_lm.py
@@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True,
truncate=100,
parameters=default_pb_parameters,
4 changes: 3 additions & 1 deletion server/text_generation_server/models/causal_lm.py
@@ -7,6 +7,7 @@
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import (
Batch,
@@ -86,7 +87,8 @@ def from_pb(
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
inputs.append(concat_text_chunks(r.input_chunks.chunks))

next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
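
Note: CausalLMBatch.from_pb now rebuilds the prompt from r.input_chunks.chunks through the new concat_text_chunks helper instead of reading the plain r.inputs string. The helper's body is not part of the hunks shown here; based on how it is used, a plausible sketch (names and the error message are assumptions) is:

# Plausible sketch of concat_text_chunks; the real implementation lives in
# text_generation_server/utils/chunks.py and is not shown in this commit.
from typing import Iterable

from text_generation_server.pb import generate_pb2


def concat_text_chunks(chunks: Iterable[generate_pb2.InputChunk]) -> str:
    """Join the text chunks of a request into a single prompt string."""
    texts = []
    for chunk in chunks:
        kind = chunk.WhichOneof("chunk")
        if kind == "text":
            texts.append(chunk.text)
        else:
            raise NotImplementedError(f"Chunk type {kind} is not supported by text-only models")
    return "".join(texts)
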
9 changes: 6 additions & 3 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -11,9 +11,10 @@
from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict
from typing import Iterable, Optional, Tuple, List, Type, Dict

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
@@ -127,11 +128,13 @@ def to_pb(self) -> generate_pb2.CachedBatch:
)

@classmethod
def batch_tokenized_inputs(cls, requests, tokenizer):
def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer
):
batch_inputs = []
max_truncation = 0
for r in requests:
batch_inputs.append(r.inputs)
batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
max_truncation = max(max_truncation, r.truncate)

batch_tokenized_inputs = tokenizer(
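
Note: the hunk above is cut off in the middle of the tokenizer( call. For context, a typical completion of that tokenization step looks roughly like the following standalone sketch (the model name and argument choices are illustrative assumptions, not the exact code from this file):

# Standalone sketch of the tokenization step the hunk above truncates.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
batch_inputs = ["Test prompt one", "Another, much longer test prompt"]
max_truncation = 100  # largest r.truncate seen while looping over requests

batch_tokenized_inputs = tokenizer(
    batch_inputs,
    truncation=True,
    max_length=max_truncation,
)["input_ids"]
print(batch_tokenized_inputs)
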
5 changes: 4 additions & 1 deletion server/text_generation_server/models/galactica.py
@@ -20,6 +20,7 @@
weight_files,
Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks

# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py

@@ -91,7 +92,9 @@ def from_pb(
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
inputs.append(
escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks))
)
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
21 changes: 12 additions & 9 deletions server/text_generation_server/models/idefics_causal_lm.py
@@ -1,4 +1,5 @@
import torch
from io import BytesIO
from PIL import Image
import torch
import time

@@ -21,11 +22,6 @@
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.models.vlm_causal_lm import split

import re

IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")


tracer = trace.get_tracer(__name__)
@@ -109,7 +105,7 @@ def from_pb_processor(
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
inputs.append(r.input_chunks.chunks)
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
@@ -128,8 +124,15 @@ def from_pb_processor(
for inp in inputs:
# Each input is encoded into a list, where each element of this input list is either a string or a URL
prompt = []
for chunk in split(inp):
prompt.append(chunk["content"])
for chunk in inp:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
prompt.append(chunk.text)
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
prompt.append(image)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
prompts.append(prompt)

# The processor replaces the call to tokenizer, and
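
Note: with split() and the IMAGES regex removed, from_pb_processor now walks the request's protobuf chunks directly and dispatches on WhichOneof. A hypothetical request that exercises both branches could be built as follows; this assumes the generated proto exposes Input, InputChunk, and Image messages with the field names the diff accesses (chunk.text and chunk.image.data).

# Hypothetical mixed text+image request; message and field names are inferred
# from the accesses in the diff above, not confirmed by this commit.
from io import BytesIO

from PIL import Image
from text_generation_server.pb import generate_pb2

buf = BytesIO()
Image.new("RGB", (8, 8)).save(buf, format="PNG")  # tiny placeholder image

request = generate_pb2.Request(
    id=0,
    inputs="What is in this picture?",
    input_chunks=generate_pb2.Input(
        chunks=[
            generate_pb2.InputChunk(text="What is in this picture?"),
            generate_pb2.InputChunk(image=generate_pb2.Image(data=buf.getvalue())),
        ]
    ),
)

for chunk in request.input_chunks.chunks:
    print(chunk.WhichOneof("chunk"))  # prints "text", then "image"
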
3 changes: 2 additions & 1 deletion server/text_generation_server/models/mamba.py
@@ -27,6 +27,7 @@
Generation,
GeneratedText,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -139,7 +140,7 @@ def from_pb(
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
inputs.append(concat_text_chunks(r.input_chunks.chunks))
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
39 changes: 16 additions & 23 deletions server/text_generation_server/models/pali_gemma.py
@@ -1,55 +1,48 @@
from io import BytesIO
from PIL import Image
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional, Tuple
from typing import Iterable, Optional, Tuple
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLM,
VlmCausalLMBatch,
image_text_replacement,
load_data_uri,
split,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
from transformers import AutoProcessor, AutoConfig, AutoImageProcessor
from transformers import AutoProcessor, AutoConfig

from text_generation_server.pb.generate_pb2 import Request

tracer = trace.get_tracer(__name__)


class PaliGemmaBatch(VlmCausalLMBatch):
@classmethod
def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
def batch_tokenized_inputs(
cls, requests: Iterable[Request], tokenizer, processor, config
):
batch_inputs = []
image_inputs = []
max_truncation = 0
for r in requests:
chunks = split(r.inputs)
full_text = ""
image_id = 0
for chunk in chunks:
if chunk["type"] == "text":
full_text += "<bos>" + chunk["content"] + "\n"
elif chunk["type"] == "image":
image = chunk["content"]
# Should never receive URLs anymore, processing should be done
# On the rust layer.
# This avoid making n queries per TP
# if image.startswith("https://") or image.startswith("http://"):
# image = processor.image_processor.fetch_images(image)
if image.startswith("data:"):
image = load_data_uri(image)
else:
raise RuntimeError(
"Cannot process input image not starting with data:"
)
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
full_text += "<bos>" + chunk.text + "\n"
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
# TODO do_convert_RGB should be on by default ?
image = image.convert("RGB")
image_input = processor.image_processor(image, return_tensors="pt")
full_text += image_text_replacement(image_input, config, image_id)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
raise RuntimeError(f"Invalid chunk type {chunk_type}")

batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
3 changes: 2 additions & 1 deletion server/text_generation_server/models/seq2seq_lm.py
@@ -6,6 +6,7 @@
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
@@ -93,7 +94,7 @@ def from_pb(
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
inputs.append(r.inputs)
inputs.append(concat_text_chunks(r.input_chunks.chunks))
requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1)
next_token_choosers.append(
