diff --git a/Makefile b/Makefile
index e9ec19103..036eaca83 100644
--- a/Makefile
+++ b/Makefile
@@ -40,7 +40,7 @@ PACKAGE_FILES = $(PACKAGE_PYTHON_FILES) \
 $(PACKAGE_DIST) $(PACKAGE_WHEEL): $(PACKAGE_FILES)
 	python -m build
 
-TGI_VERSION ?= 1.4.0
+TGI_VERSION ?= 1.4.1
 
 neuronx-tgi: $(PACKAGE_DIST)
 	docker build --rm -f text-generation-inference/Dockerfile \
diff --git a/optimum/exporters/neuron/base.py b/optimum/exporters/neuron/base.py
index 240859e69..303e56793 100644
--- a/optimum/exporters/neuron/base.py
+++ b/optimum/exporters/neuron/base.py
@@ -379,7 +379,7 @@ class NeuronDecoderConfig(NeuronConfig):
       be passed to export the model,
     - NEURONX_CLASS (`str`) -- the name of the transformers-neuronx class to instantiate for the model.
       It is a full class name defined relatively to the transformers-neuronx module, e.g. `gpt2.model.GPT2ForSampling`
-    [`~optimum.utils.DummyInputGenerator`] specifying how to create dummy inputs.
+    - CONTINUOUS_BATCHING (`bool`, defaults to `False`) -- Whether the model supports continuous batching or not.
 
     The NEURONX_CLASS must always be defined in each model configuration.
 
@@ -389,6 +389,7 @@ class NeuronDecoderConfig(NeuronConfig):
 
     INPUT_ARGS = ("batch_size", "sequence_length")
     NEURONX_CLASS = None
+    CONTINUOUS_BATCHING = False
 
     def __init__(self, task: str):
         if not is_transformers_neuronx_available():
@@ -404,3 +405,7 @@ def __init__(self, task: str):
     @property
     def neuronx_class(self):
         return self._neuronx_class
+
+    @property
+    def continuous_batching(self):
+        return self.CONTINUOUS_BATCHING
diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py
index 1b6ce4b2e..045589f3b 100644
--- a/optimum/exporters/neuron/model_configs.py
+++ b/optimum/exporters/neuron/model_configs.py
@@ -431,6 +431,7 @@ class GPT2NeuronConfig(TextNeuronDecoderConfig):
 @register_in_tasks_manager("llama", "text-generation")
 class LLamaNeuronConfig(TextNeuronDecoderConfig):
     NEURONX_CLASS = "llama.model.LlamaForSampling"
+    CONTINUOUS_BATCHING = True
 
 
 @register_in_tasks_manager("t5-encoder", "text2text-generation")
@@ -533,3 +534,4 @@ def generate_io_aliases(self, model):
 @register_in_tasks_manager("mistral", "text-generation")
 class MistralNeuronConfig(TextNeuronDecoderConfig):
     NEURONX_CLASS = "mistral.model.MistralForSampling"
+    CONTINUOUS_BATCHING = True
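For reference, opting another decoder architecture into continuous batching only requires setting the new class attribute on its config. A hypothetical sketch (the `opt` registration below is illustrative and not part of this change; the import path is assumed from the module layout shown above):

```python
# Hypothetical sketch: opting another text decoder into continuous batching.
# The "opt" architecture and "opt.model.OPTForSampling" class are illustrative;
# set CONTINUOUS_BATCHING = True only if the corresponding transformers-neuronx
# implementation supports shared KV caches.
from optimum.exporters.neuron.model_configs import TextNeuronDecoderConfig, register_in_tasks_manager


@register_in_tasks_manager("opt", "text-generation")
class OPTNeuronConfig(TextNeuronDecoderConfig):
    NEURONX_CLASS = "opt.model.OPTForSampling"
    CONTINUOUS_BATCHING = True
```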
diff --git a/optimum/neuron/modeling.py b/optimum/neuron/modeling.py
index fa2fdb574..8d26c83fc 100644
--- a/optimum/neuron/modeling.py
+++ b/optimum/neuron/modeling.py
@@ -16,7 +16,7 @@
 import copy
 import logging
-from typing import TYPE_CHECKING, Dict, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import torch
 from transformers import (
@@ -656,14 +656,14 @@ def __init__(
         generation_config: Optional["GenerationConfig"] = None,
     ):
         super().__init__(config, checkpoint_dir, compiled_dir=compiled_dir, generation_config=generation_config)
-        self.cur_len = 0
-        self.batch_size = self.model.config.batch_size
-        self.max_length = self.model.config.n_positions
+        self.batch_size = self.config.neuron["batch_size"]
+        self.max_length = self.config.neuron["sequence_length"]
+        self.continuous_batching = self.model.neuron_config and self.model.neuron_config.continuous_batching
         # The generate method from GenerationMixin expects the device attribute to be set
         self.device = torch.device("cpu")
 
     def reset_generation(self):
-        self.cur_len = 0
+        pass
 
     @add_start_docstrings_to_model_forward(
         NEURON_CAUSALLM_MODEL_FORWARD_DOCSTRING
@@ -688,32 +688,78 @@ def forward(
             return ModelOutput([("logits", out_logits)])
         return (out_logits,)
 
-    def prepare_inputs_for_generation(
-        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs
-    ) -> Dict[str, torch.Tensor]:
-        # convert attention_mask to start_ids
+    def get_start_ids(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        seq_ids: Optional[torch.Tensor] = None,
+    ):
+        # The start_ids parameter has different meanings:
+        # - for continuous (unpadded) batching it corresponds to the sequence id,
+        # - for static batching it corresponds to the start of the padded sequence.
+        if self.continuous_batching:
+            if seq_ids is None:
+                seq_ids = torch.arange(input_ids.shape[0])
+            else:
+                assert seq_ids.shape[0] == input_ids.shape[0]
+            return seq_ids
         start_ids = None
         if attention_mask is not None:
             _, start_ids = attention_mask.max(axis=1)
-
-        if self.cur_len > 0:
-            # Only pass the last tokens of each sample
-            input_ids = input_ids[:, -1:]
-            # Specify the single index at which the new keys and values need to be stored
-            cache_ids = torch.as_tensor([self.cur_len], dtype=torch.int32)
-        else:
-            # cache_ids will be set directly by the parallel context encoding code
-            cache_ids = None
-
-        # Increment the current cache index
-        self.cur_len += input_ids.shape[-1]
-        model_inputs = {
+        return start_ids
+
+    def get_cache_ids(self, attention_mask: torch.tensor, prefill: bool):
+        cache_n, cache_len = attention_mask.shape
+        if self.continuous_batching:
+            # Evaluate the inputs that are not masked for each sequence
+            input_length = attention_mask.sum(axis=1)
+            if not prefill:
+                # When decoding, cache_ids contains a single value per sequence
+                return (input_length - 1).unsqueeze(1)
+            # When prefilling, cache_ids is an increasing range
+            cache_ids = torch.zeros_like(attention_mask)
+            for i in range(cache_n):
+                cur_length = input_length[i]
+                cache_ids[i, :cur_length] = torch.arange(cur_length)
+            return cache_ids
+        # Static batching
+        return None if prefill else torch.tensor([cache_len - 1], dtype=torch.int32)
+
+    def prepare_inputs_for_prefill(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, seq_ids: Optional[List[int]] = None
+    ) -> Dict[str, torch.Tensor]:
+        start_ids = self.get_start_ids(input_ids, attention_mask, seq_ids=seq_ids)
+        cache_ids = self.get_cache_ids(attention_mask, prefill=True)
+        if self.continuous_batching and torch.any(attention_mask[:, 0] == 0):
+            # Inputs are left padded: we need to invert padding as continuous batching requires right-padding
+            batch_size, seq_len = input_ids.shape
+            input_length = attention_mask.sum(axis=1)
+            new_input_ids = torch.zeros_like(input_ids)
+            for i in range(batch_size):
+                cur_length = input_length[i]
+                new_input_ids[i, :cur_length] = input_ids[i, seq_len - cur_length :]
+            input_ids = new_input_ids
+        return {
             "input_ids": input_ids,
             "cache_ids": cache_ids,
             "start_ids": start_ids,
         }
-        return model_inputs
 
+    def prepare_inputs_for_decode(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        seq_ids: Optional[List[int]] = None,
+    ) -> Dict[str, torch.Tensor]:
+        start_ids = self.get_start_ids(input_ids, attention_mask, seq_ids=seq_ids)
+        cache_ids = self.get_cache_ids(attention_mask, prefill=False)
+        # Only pass the last tokens of each sample
+        input_ids = input_ids[:, -1:]
+        return {
+            "input_ids": input_ids,
+            "cache_ids": cache_ids,
+            "start_ids": start_ids,
+        }
+
     def can_generate(self) -> bool:
         """Returns True to validate the check made in `GenerationMixin.generate()`."""
@@ -775,7 +821,7 @@ def generate(
                 f"The input sequence length ({sequence_length}) exceeds the model static sequence length ({self.max_length})"
             )
         padded_input_ids = input_ids
-        padded_attention_mask = attention_mask
+        padded_attention_mask = torch.ones_like(input_ids) if attention_mask is None else attention_mask
         if batch_size > self.batch_size:
             raise ValueError(
                 f"The specified batch_size ({batch_size}) exceeds the model static batch size ({self.batch_size})"
@@ -784,18 +830,15 @@ def generate(
             logger.warning("Inputs will be padded to match the model static batch size. This will increase latency.")
             padding_shape = [self.batch_size - batch_size, sequence_length]
             padding = torch.full(padding_shape, fill_value=self.config.eos_token_id, dtype=torch.int64)
-            padded_input_ids = torch.cat([input_ids, padding])
-            if attention_mask is not None:
-                padding = torch.zeros(padding_shape, dtype=torch.int64)
-                padded_attention_mask = torch.cat([attention_mask, padding])
-        # Drop the current generation context and clear the Key/Value cache
-        self.reset_generation()
+            padded_input_ids = torch.cat([padded_input_ids, padding])
+            padding = torch.zeros(padding_shape, dtype=torch.int64)
+            padded_attention_mask = torch.cat([padded_attention_mask, padding])
 
         output_ids = self.generate_tokens(
             padded_input_ids,
             selector,
             batch_size,
-            attention_mask=padded_attention_mask,
+            padded_attention_mask,
             **model_kwargs,
         )
         return output_ids[:batch_size, :]
@@ -805,7 +848,7 @@ def generate_tokens(
         input_ids: torch.LongTensor,
         selector: TokenSelector,
         batch_size: int,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: torch.Tensor,
         **model_kwargs,
     ) -> torch.LongTensor:
         r"""
@@ -831,17 +874,15 @@ def generate_tokens(
         unfinished_sequences = torch.zeros(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
         unfinished_sequences[:batch_size] = 1
 
+        # Prefill and obtain the first token
+        model_inputs = self.prepare_inputs_for_prefill(input_ids, attention_mask)
+        outputs = self(
+            **model_inputs,
+            return_dict=True,
+        )
+
         # auto-regressive generation
         while True:
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, attention_mask, **model_kwargs)
-
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-            )
-
             next_token_logits = outputs.logits[:, -1, :]
             next_tokens = selector.select(input_ids, next_token_logits)
@@ -851,10 +892,7 @@ def generate_tokens(
 
             # update inputs for the next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-            if attention_mask is not None:
-                attention_mask = torch.cat(
-                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
-                )
+            attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
 
             # if eos_token was found in one sentence, set sentence to finished
             unfinished_sequences = unfinished_sequences * next_tokens.ne(selector.eos_token_id)
@@ -867,4 +905,11 @@ def generate_tokens(
             if selector.stopping_criteria(input_ids, None):
                 break
 
+            # forward pass to get next token
+            model_inputs = self.prepare_inputs_for_decode(input_ids, attention_mask)
+            outputs = self(
+                **model_inputs,
+                return_dict=True,
+            )
+
         return input_ids
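To make the new `cache_ids` semantics concrete, here is a minimal standalone sketch that mirrors the continuous-batching branch of `get_cache_ids` on a toy right-padded batch; it is independent of any compiled model and only assumes PyTorch:

```python
import torch

# Two sequences of length 3 and 5 in a batch padded to length 5
# (right padding, as required by continuous batching).
attention_mask = torch.tensor(
    [
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1],
    ]
)


def cache_ids_for(attention_mask: torch.Tensor, prefill: bool) -> torch.Tensor:
    # Mirrors the continuous-batching branch of NeuronModelForCausalLM.get_cache_ids.
    input_length = attention_mask.sum(axis=1)
    if not prefill:
        # One cache slot per sequence: the position of the last (newly appended) token.
        return (input_length - 1).unsqueeze(1)
    # During prefill, each sequence gets an increasing range over its real tokens.
    cache_ids = torch.zeros_like(attention_mask)
    for i, cur_length in enumerate(input_length):
        cache_ids[i, :cur_length] = torch.arange(cur_length)
    return cache_ids


print(cache_ids_for(attention_mask, prefill=True))
# tensor([[0, 1, 2, 0, 0],
#         [0, 1, 2, 3, 4]])
print(cache_ids_for(attention_mask, prefill=False))
# tensor([[2],
#         [4]])
```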
diff --git a/optimum/neuron/modeling_decoder.py b/optimum/neuron/modeling_decoder.py
index fdf6fbaa7..e09b0a18e 100644
--- a/optimum/neuron/modeling_decoder.py
+++ b/optimum/neuron/modeling_decoder.py
@@ -35,6 +35,7 @@ if is_transformers_neuronx_available():
+    from transformers_neuronx.config import ContinuousBatchingConfig, NeuronConfig
     from transformers_neuronx.module import save_split
 
 
@@ -131,16 +132,26 @@ def __init__(
 
         exporter = get_exporter(config, task)
 
-        # transformers-neuronx uses f32/f16 instead of fp32/fp16
-        auto_cast_type = auto_cast_type.replace("p", "")
+        tnx_kwargs = {
+            "batch_size": batch_size,
+            "tp_degree": num_cores,
+            # transformers-neuronx uses f32/f16 instead of fp32/fp16
+            "amp": auto_cast_type.replace("p", ""),
+        }
+        if batch_size > 1 and exporter.continuous_batching:
+            # Continuous batching is always enabled for models that support it because static batching
+            # is broken for these models: see https://github.com/aws-neuron/transformers-neuronx/issues/79
+            tnx_kwargs["neuron_config"] = NeuronConfig(
+                continuous_batching=ContinuousBatchingConfig(batch_size_for_shared_caches=batch_size)
+            )
+            tnx_kwargs["n_positions"] = [sequence_length]
+            tnx_kwargs["context_length_estimate"] = [sequence_length]
+        else:
+            tnx_kwargs["n_positions"] = sequence_length
+
+        # Instantiate neuronx model
         checkpoint_path = checkpoint_dir.name if isinstance(checkpoint_dir, TemporaryDirectory) else checkpoint_dir
-        neuronx_model = exporter.neuronx_class.from_pretrained(
-            checkpoint_path,
-            batch_size=batch_size,
-            n_positions=sequence_length,
-            tp_degree=num_cores,
-            amp=auto_cast_type,
-        )
+        neuronx_model = exporter.neuronx_class.from_pretrained(checkpoint_path, **tnx_kwargs)
 
         if compiled_dir is not None:
             # Specify the path where compiled artifacts are stored before conversion
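For reference, a sketch of the direct transformers-neuronx call that the kwargs assembled above amount to when continuous batching is enabled; the checkpoint path, dimensions and Llama class choice are placeholders, not part of this change:

```python
# Sketch of the equivalent direct transformers-neuronx instantiation when
# continuous batching is enabled (batch_size > 1 and the exporter supports it).
# "/path/to/split/checkpoint" is a placeholder for a split checkpoint directory.
from transformers_neuronx.config import ContinuousBatchingConfig, NeuronConfig
from transformers_neuronx.llama.model import LlamaForSampling

batch_size, sequence_length, num_cores = 2, 2048, 2
neuron_config = NeuronConfig(
    continuous_batching=ContinuousBatchingConfig(batch_size_for_shared_caches=batch_size)
)
neuronx_model = LlamaForSampling.from_pretrained(
    "/path/to/split/checkpoint",
    batch_size=batch_size,
    n_positions=[sequence_length],
    context_length_estimate=[sequence_length],
    tp_degree=num_cores,
    amp="f16",
    neuron_config=neuron_config,
)
neuronx_model.to_neuron()  # triggers compilation for the neuron cores
```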
diff --git a/tests/generation/conftest.py b/tests/generation/conftest.py
index b9d70505f..9d69c5579 100644
--- a/tests/generation/conftest.py
+++ b/tests/generation/conftest.py
@@ -72,7 +72,7 @@ def export_seq2seq_model_class(request):
 @requires_neuronx
 def neuron_decoder_path(export_decoder_id):
     model = NeuronModelForCausalLM.from_pretrained(
-        export_decoder_id, export=True, batch_size=1, sequence_length=100, num_cores=2
+        export_decoder_id, export=True, batch_size=2, sequence_length=100, num_cores=2
     )
     model_dir = TemporaryDirectory()
     model_path = model_dir.name
diff --git a/tests/generation/test_generate.py b/tests/generation/test_generate.py
index 92170b40a..41eb4bc08 100644
--- a/tests/generation/test_generate.py
+++ b/tests/generation/test_generate.py
@@ -18,20 +18,12 @@
 import pytest
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.generation import StoppingCriteria
 
-from optimum.neuron import NeuronModelForCausalLM, NeuronModelForSeq2SeqLM
+from optimum.neuron import NeuronModelForSeq2SeqLM
 from optimum.neuron.utils.testing_utils import is_inferentia_test, is_trainium_test, requires_neuronx
 from optimum.neuron.utils.training_utils import patch_generation_mixin_to_general_neuron_generation_mixin
 
 
-def _test_model_generation(model, tokenizer, batch_size, input_length, **gen_kwargs):
-    input_ids = torch.ones((batch_size, input_length), dtype=torch.int64)
-    with torch.inference_mode():
-        sample_output = model.generate(input_ids, **gen_kwargs)
-        assert sample_output.shape[0] == batch_size
-
-
 def _test_model_generation_trn(model, tokenizer, batch_size, input_length, **gen_kwargs):
     import torch_xla.core.xla_model as xm
 
@@ -43,58 +35,6 @@ def _test_model_generation_trn(model, tokenizer, batch_size, input_length, **gen
         assert sample_output.shape[0] == batch_size
 
 
-@pytest.mark.parametrize(
-    "gen_kwargs",
-    [
-        {"do_sample": True},
-        {"do_sample": True, "temperature": 0.7},
-        {"do_sample": False},
-        {"do_sample": False, "repetition_penalty": 1.2},
-    ],
-    ids=["sample", "sample-with-temp", "greedy", "greedy_no-repeat"],
-)
-@is_inferentia_test
-@requires_neuronx
-def test_decoder_generation(neuron_decoder_path, gen_kwargs):
-    model = NeuronModelForCausalLM.from_pretrained(neuron_decoder_path)
-    tokenizer = AutoTokenizer.from_pretrained(neuron_decoder_path)
-    _test_model_generation(model, tokenizer, model.batch_size, 10, **gen_kwargs)
-
-
-@is_inferentia_test
-@requires_neuronx
-def test_model_generation_input_dimensions(neuron_decoder_path):
-    model = NeuronModelForCausalLM.from_pretrained(neuron_decoder_path)
-    tokenizer = AutoTokenizer.from_pretrained(neuron_decoder_path)
-    # Using valid input dimensions
-    _test_model_generation(model, tokenizer, model.batch_size, model.max_length // 2)
-    # Using an incompatible batch_size
-    with pytest.raises(ValueError, match="The specified batch_size"):
-        _test_model_generation(model, tokenizer, model.batch_size + 1, model.max_length)
-    # Using an incompatible input length
-    with pytest.raises(ValueError, match="The input sequence length"):
-        _test_model_generation(model, tokenizer, model.batch_size, input_length=model.max_length * 2)
-
-
-@is_inferentia_test
-@requires_neuronx
-def test_decoder_generation_custom_stopping_criteria():
-    model_id = "hf-internal-testing/tiny-random-gpt2"
-    model = NeuronModelForCausalLM.from_pretrained(model_id, export=True, batch_size=1)
-
-    class CustomStoppingCriteria(StoppingCriteria):
-        def __init__(self):
-            self.called = False
-
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-            self.called = True
-            return True
-
-    criteria = CustomStoppingCriteria()
-    model.generate(input_ids=torch.ones([1, 10], dtype=torch.int64), stopping_criteria=[criteria])
-    assert criteria.called, "Custom StoppingCriteria should have been called"
-
-
 @is_inferentia_test
 @requires_neuronx
 def test_seq2seq_generation_beam(neuron_seq2seq_beam_path):
"greedy_no-repeat"], -) -@is_inferentia_test -@requires_neuronx -def test_decoder_generation(neuron_decoder_path, gen_kwargs): - model = NeuronModelForCausalLM.from_pretrained(neuron_decoder_path) - tokenizer = AutoTokenizer.from_pretrained(neuron_decoder_path) - _test_model_generation(model, tokenizer, model.batch_size, 10, **gen_kwargs) - - -@is_inferentia_test -@requires_neuronx -def test_model_generation_input_dimensions(neuron_decoder_path): - model = NeuronModelForCausalLM.from_pretrained(neuron_decoder_path) - tokenizer = AutoTokenizer.from_pretrained(neuron_decoder_path) - # Using valid input dimensions - _test_model_generation(model, tokenizer, model.batch_size, model.max_length // 2) - # Using an incompatible batch_size - with pytest.raises(ValueError, match="The specified batch_size"): - _test_model_generation(model, tokenizer, model.batch_size + 1, model.max_length) - # Using an incompatible input length - with pytest.raises(ValueError, match="The input sequence length"): - _test_model_generation(model, tokenizer, model.batch_size, input_length=model.max_length * 2) - - -@is_inferentia_test -@requires_neuronx -def test_decoder_generation_custom_stopping_criteria(): - model_id = "hf-internal-testing/tiny-random-gpt2" - model = NeuronModelForCausalLM.from_pretrained(model_id, export=True, batch_size=1) - - class CustomStoppingCriteria(StoppingCriteria): - def __init__(self): - self.called = False - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: - self.called = True - return True - - criteria = CustomStoppingCriteria() - model.generate(input_ids=torch.ones([1, 10], dtype=torch.int64), stopping_criteria=[criteria]) - assert criteria.called, "Custom StoppingCriteria should have been called" - - @is_inferentia_test @requires_neuronx def test_seq2seq_generation_beam(neuron_seq2seq_beam_path): diff --git a/tests/generation/test_tnx_generate.py b/tests/generation/test_tnx_generate.py new file mode 100644 index 000000000..94b0f06c8 --- /dev/null +++ b/tests/generation/test_tnx_generate.py @@ -0,0 +1,105 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/generation/test_tnx_llama.py b/tests/generation/test_tnx_llama.py
new file mode 100644
index 000000000..b2aff36b0
--- /dev/null
+++ b/tests/generation/test_tnx_llama.py
@@ -0,0 +1,45 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from transformers import AutoTokenizer
+
+from optimum.neuron import NeuronModelForCausalLM
+from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx
+
+
+@is_inferentia_test
+@requires_neuronx
+def test_generation_llama_padded_inputs():
+    model_id = "NousResearch/Llama-2-7b-chat-hf"
+    model_kwargs = {"batch_size": 2, "sequence_length": 2048, "auto_cast_type": "f16", "num_cores": 2}
+    model = NeuronModelForCausalLM.from_pretrained(model_id, export=True, **model_kwargs)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    prompt = "One of my fondest memory is of my grandmother making homemade bread"
+    first_input = tokenizer(prompt)
+    first_ids = first_input["input_ids"]
+    first_mask = first_input["attention_mask"]
+    max_padding = 12
+    input_len = len(first_ids)
+    for i in range(max_padding):
+        second_ids = [tokenizer.eos_token_id] * i + first_ids[: input_len - i]
+        second_mask = [0] * i + [1] * (input_len - i)
+        input_ids = torch.tensor([first_ids, second_ids], dtype=torch.int64)
+        attention_mask = torch.tensor([first_mask, second_mask], dtype=torch.int64)
+        outputs = model.generate(
+            input_ids=input_ids, attention_mask=attention_mask, do_sample=False, max_new_tokens=10
+        )
+        # Verify we did not generate any unknown token
+        assert torch.all(outputs[:, -1] != 0)
+ " Winston Smith, his chin nuzzled into his breast in an effort to escape the" + " vile wind, slipped quickly through the glass doors of Victory Mansions," + ) + first_input = tokenizer(prompt) + first_ids = first_input["input_ids"] + first_mask = first_input["attention_mask"] + max_padding = 12 + input_len = len(first_ids) + for i in range(max_padding): + second_ids = [tokenizer.eos_token_id] * i + first_ids[: input_len - i] + second_mask = [0] * i + [1] * (input_len - i) + input_ids = torch.tensor([first_ids, second_ids], dtype=torch.int64) + attention_mask = torch.tensor([first_mask, second_mask], dtype=torch.int64) + outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, do_sample=False) + # Verify we did not generate any unknown token + assert torch.all(outputs[:, -1] != 0) diff --git a/tests/generation/test_tnx_llama.py b/tests/generation/test_tnx_llama.py new file mode 100644 index 000000000..b2aff36b0 --- /dev/null +++ b/tests/generation/test_tnx_llama.py @@ -0,0 +1,45 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from transformers import AutoTokenizer + +from optimum.neuron import NeuronModelForCausalLM +from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx + + +@is_inferentia_test +@requires_neuronx +def test_generation_llama_padded_inputs(): + model_id = "NousResearch/Llama-2-7b-chat-hf" + model_kwargs = {"batch_size": 2, "sequence_length": 2048, "auto_cast_type": "f16", "num_cores": 2} + model = NeuronModelForCausalLM.from_pretrained(model_id, export=True, **model_kwargs) + tokenizer = AutoTokenizer.from_pretrained(model_id) + prompt = "One of my fondest memory is of my grandmother making homemade bread" + first_input = tokenizer(prompt) + first_ids = first_input["input_ids"] + first_mask = first_input["attention_mask"] + max_padding = 12 + input_len = len(first_ids) + for i in range(max_padding): + second_ids = [tokenizer.eos_token_id] * i + first_ids[: input_len - i] + second_mask = [0] * i + [1] * (input_len - i) + input_ids = torch.tensor([first_ids, second_ids], dtype=torch.int64) + attention_mask = torch.tensor([first_mask, second_mask], dtype=torch.int64) + outputs = model.generate( + input_ids=input_ids, attention_mask=attention_mask, do_sample=False, max_new_tokens=10 + ) + # Verify we did not generate any unknown token + assert torch.all(outputs[:, -1] != 0) diff --git a/text-generation-inference/Dockerfile b/text-generation-inference/Dockerfile index 28568edc2..c5c65e152 100644 --- a/text-generation-inference/Dockerfile +++ b/text-generation-inference/Dockerfile @@ -92,8 +92,8 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU RUN apt-get update -y \ && apt-get install -y --no-install-recommends \ aws-neuronx-dkms=2.15.9.0 \ - aws-neuronx-collectives=2.20.11.0 \ - aws-neuronx-runtime-lib=2.20.11.0 \ + aws-neuronx-collectives=2.20.11.0-c101c322e \ + 
diff --git a/text-generation-inference/README.md b/text-generation-inference/README.md
index c0c2fd949..adfe8b983 100644
--- a/text-generation-inference/README.md
+++ b/text-generation-inference/README.md
@@ -100,11 +100,9 @@ docker run -p 8080:80 \
        -e HF_TOKEN=${HF_TOKEN} \
        ghcr.io/huggingface/neuronx-tgi:latest \
        --model-id aws-neuron/Llama-2-7b-hf-neuron-budget \
-       --max-concurrent-requests 1 \
+       --max-batch-size 1 \
        --max-input-length 1024 \
-       --max-total-tokens 2048 \
-       --max-batch-prefill-tokens 1024 \
-       --max-batch-total-tokens 2048
+       --max-total-tokens 2048
 ```
 
 ### Using a standard model from the 🤗 [HuggingFace Hub](https://huggingface.co/aws-neuron)
@@ -130,11 +128,9 @@ docker run -p 8080:80 \
        -e HF_NUM_CORES=2 \
        ghcr.io/huggingface/neuronx-tgi:latest \
        --model-id aws-neuron/Llama-2-7b-hf-neuron-budget \
-       --max-concurrent-requests 1 \
+       --max-batch-size 1 \
        --max-input-length 512 \
-       --max-total-tokens 1024 \
-       --max-batch-prefill-tokens 512 \
-       --max-batch-total-tokens 1024
+       --max-total-tokens 1024
 ```
 
 ### Using a model exported to a local path
@@ -162,15 +158,11 @@ The configuration of an inference endpoint is always a compromise between throug
 
 The neuron models have static input dimensions `[batch_size, max_length]`.
 
-It leads to a maximum number of tokens of `max_tokens = batch_size * max_length`.
-
 This adds several restrictions to the following parameters:
 
-- `--max-concurrent-requests` must be set to `batch size`,
+- `--max-batch-size` must be set to `batch size`,
 - `--max-input-length` must be lower than `max_length`,
-- `--max-total-tokens` must be set to `max_length` (it is per-request),
-- `--max-batch-prefill-tokens` must be set to `batch_size * max_input_length`,
-- `--max-batch-total-tokens` must be set to `max_tokens`.
+- `--max-total-tokens` must be set to `max_length` (it is per-request).
 
 ### Choosing the correct batch size
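A small worked example of the README constraints above, for a model exported with `batch_size=4` and `sequence_length=2048` (the values are illustrative; the snippet only prints the corresponding launcher flags):

```python
# Illustrative only: deriving consistent TGI launch arguments from the static
# dimensions the neuron model was exported with.
batch_size, max_length = 4, 2048
tgi_args = {
    "max-batch-size": batch_size,         # one request slot per static batch entry
    "max-input-length": max_length // 2,  # any value strictly lower than max_length
    "max-total-tokens": max_length,       # per-request prompt + generated tokens
}
print(" ".join(f"--{name} {value}" for name, value in tgi_args.items()))
# --max-batch-size 4 --max-input-length 1024 --max-total-tokens 2048
```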
diff --git a/text-generation-inference/server/Makefile b/text-generation-inference/server/Makefile
index e16ab6585..dc8bf3290 100644
--- a/text-generation-inference/server/Makefile
+++ b/text-generation-inference/server/Makefile
@@ -2,7 +2,7 @@
 pkg_name := text_generation_server
 BUILDDIR ?= $(CURDIR)/build
 VERSION ?= 0.0.1
-TGI_VERSION ?= 1.4.0
+TGI_VERSION ?= 1.4.1
 mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
 mkfile_dir := $(dir $(mkfile_path))
 pkg_dir := $(BUILDDIR)/$(pkg_name)
diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index a568d1049..a867a4583 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -381,12 +381,12 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
             slot_attention_mask = attention_mask[i]
             slot.reset(slot_input_ids, slot_attention_mask, selector)
-        # Clear KV cache
-        self.model.reset_generation()
         # Pause previously active slots during generation.
         # The KV cache of paused slots will be prefilled during generation but new tokens
         # will be ignored, as they have already been generated and sent back in the last decode.
-        generation, next_batch = self._generate_token(batch.id, input_ids, attention_mask)
+        model_inputs = self.model.prepare_inputs_for_prefill(input_ids, attention_mask)
+        logits = self.model(**model_inputs)[0]
+        generation, next_batch = self._generate_token(batch.id, logits, input_ids)
         # Reactivate previously active slots for the next decode, and append
         # back their next token.
         for slot, next_token in zip(active_slots, next_tokens):
@@ -433,23 +433,20 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
             attention_mask[i, :] = slot.attention_mask
         if input_ids is None:
             raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
-        return self._generate_token(next_batch_id, input_ids, attention_mask)
+        model_inputs = self.model.prepare_inputs_for_decode(input_ids, attention_mask)
+        logits = self.model(**model_inputs)[0]
+        return self._generate_token(next_batch_id, logits, input_ids)
 
     def _generate_token(
-        self, next_batch_id: int, input_ids: torch.LongTensor, attention_mask: Optional[torch.LongTensor] = None
+        self, next_batch_id: int, logits: torch.Tensor, input_ids: torch.LongTensor
     ) -> Tuple[List[Generation], CachedBatch]:
-        model_inputs = self.model.prepare_inputs_for_generation(input_ids, attention_mask)
-        outputs = self.model(
-            **model_inputs,
-            return_dict=True,
-        )
         generations = []
         active_slots = False
         for i, slot in enumerate(self.slots):
             if slot.state != Slot.State.READY:
                 continue
             request_id = slot.request_id
-            next_token_logits = outputs.logits[i : i + 1, -1, :]
+            next_token_logits = logits[i : i + 1, -1, :]
             slot_input_ids = input_ids[i : i + 1, :]
             next_token = slot.select(slot_input_ids, next_token_logits)
             next_token_text = slot.append(next_token)
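Putting the server-side changes together, a condensed sketch of the prefill/decode pattern the generator now drives explicitly through the new helpers (the model path is a placeholder, and a greedy argmax stands in for the per-slot TokenSelector used by the real generator):

```python
# Sketch of the explicit prefill/decode loop enabled by the new helpers.
# "/path/to/neuron/model" is a placeholder for an exported neuron model.
import torch

from optimum.neuron import NeuronModelForCausalLM

model = NeuronModelForCausalLM.from_pretrained("/path/to/neuron/model")
input_ids = torch.ones((model.batch_size, 10), dtype=torch.int64)
attention_mask = torch.ones_like(input_ids)

# Prefill: fill the KV cache with the prompts and get the first logits
model_inputs = model.prepare_inputs_for_prefill(input_ids, attention_mask)
logits = model(**model_inputs)[0]

for _ in range(5):
    next_tokens = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy stand-in
    input_ids = torch.cat([input_ids, next_tokens], dim=-1)
    attention_mask = torch.cat([attention_mask, torch.ones_like(next_tokens)], dim=-1)
    # Decode: only the last token of each sequence is passed to the model
    model_inputs = model.prepare_inputs_for_decode(input_ids, attention_mask)
    logits = model(**model_inputs)[0]
```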