Add support for transformers-neuronx continuous batching #488

Merged · 13 commits · Feb 19, 2024
2 changes: 1 addition & 1 deletion Makefile
@@ -40,7 +40,7 @@ PACKAGE_FILES = $(PACKAGE_PYTHON_FILES) \
$(PACKAGE_DIST) $(PACKAGE_WHEEL): $(PACKAGE_FILES)
python -m build

-TGI_VERSION ?= 1.4.0
+TGI_VERSION ?= 1.4.1

neuronx-tgi: $(PACKAGE_DIST)
docker build --rm -f text-generation-inference/Dockerfile \
7 changes: 6 additions & 1 deletion optimum/exporters/neuron/base.py
@@ -379,7 +379,7 @@ class NeuronDecoderConfig(NeuronConfig):
be passed to export the model,
- NEURONX_CLASS (`str`) -- the name of the transformers-neuronx class to instantiate for the model.
It is a full class name defined relatively to the transformers-neuronx module, e.g. `gpt2.model.GPT2ForSampling`
-[`~optimum.utils.DummyInputGenerator`] specifying how to create dummy inputs.
+- CONTINUOUS_BATCHING (`bool`, defaults to `False`) -- Whether the model supports continuous batching or not.

The NEURONX_CLASS must always be defined in each model configuration.

@@ -389,6 +389,7 @@ class NeuronDecoderConfig(NeuronConfig):

INPUT_ARGS = ("batch_size", "sequence_length")
NEURONX_CLASS = None
CONTINUOUS_BATCHING = False

def __init__(self, task: str):
if not is_transformers_neuronx_available():
@@ -404,3 +405,7 @@ def __init__(self, task: str):
@property
def neuronx_class(self):
return self._neuronx_class

@property
def continuous_batching(self):
return self.CONTINUOUS_BATCHING
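For orientation, this is how downstream export code is expected to consume the new property. The snippet is an editor's sketch based on the modeling_decoder.py changes further down; get_exporter, config, task and batch_size are assumed to be in scope:

# Sketch (assumed surrounding variables): querying the new continuous_batching flag.
exporter = get_exporter(config, task)  # returns an instance of a NeuronDecoderConfig subclass
if batch_size > 1 and exporter.continuous_batching:
    # Build a shared-cache ContinuousBatchingConfig, as done in modeling_decoder.py below.
    ...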
2 changes: 2 additions & 0 deletions optimum/exporters/neuron/model_configs.py
@@ -431,6 +431,7 @@ class GPT2NeuronConfig(TextNeuronDecoderConfig):
@register_in_tasks_manager("llama", "text-generation")
class LLamaNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "llama.model.LlamaForSampling"
CONTINUOUS_BATCHING = True


@register_in_tasks_manager("t5-encoder", "text2text-generation")
@@ -533,3 +534,4 @@ def generate_io_aliases(self, model):
@register_in_tasks_manager("mistral", "text-generation")
class MistralNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "mistral.model.MistralForSampling"
CONTINUOUS_BATCHING = True
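Enabling the feature for an architecture is therefore a one-line opt-in. Purely as an illustration (the model type and transformers-neuronx class path below are placeholders, not part of this pull request), a new decoder config would follow the same pattern as the Llama and Mistral configs above:

# Hypothetical sketch, assumed to live in optimum/exporters/neuron/model_configs.py,
# where register_in_tasks_manager and TextNeuronDecoderConfig are already available.
@register_in_tasks_manager("my-decoder", "text-generation")  # placeholder model type
class MyDecoderNeuronConfig(TextNeuronDecoderConfig):
    NEURONX_CLASS = "my_decoder.model.MyDecoderForSampling"  # placeholder transformers-neuronx class
    CONTINUOUS_BATCHING = True  # surfaced via NeuronDecoderConfig.continuous_batching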
135 changes: 90 additions & 45 deletions optimum/neuron/modeling.py
@@ -16,7 +16,7 @@

import copy
import logging
-from typing import TYPE_CHECKING, Dict, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union

import torch
from transformers import (
@@ -656,14 +656,14 @@ def __init__(
generation_config: Optional["GenerationConfig"] = None,
):
super().__init__(config, checkpoint_dir, compiled_dir=compiled_dir, generation_config=generation_config)
-self.cur_len = 0
-self.batch_size = self.model.config.batch_size
-self.max_length = self.model.config.n_positions
+self.batch_size = self.config.neuron["batch_size"]
+self.max_length = self.config.neuron["sequence_length"]
+self.continuous_batching = self.model.neuron_config and self.model.neuron_config.continuous_batching
# The generate method from GenerationMixin expects the device attribute to be set
self.device = torch.device("cpu")

def reset_generation(self):
-self.cur_len = 0
+pass

@add_start_docstrings_to_model_forward(
NEURON_CAUSALLM_MODEL_FORWARD_DOCSTRING
@@ -688,32 +688,78 @@ def forward(
return ModelOutput([("logits", out_logits)])
return (out_logits,)

-def prepare_inputs_for_generation(
-self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs
-) -> Dict[str, torch.Tensor]:
-# convert attention_mask to start_ids
def get_start_ids(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
seq_ids: Optional[torch.Tensor] = None,
):
# The start_ids parameter has different meanings:
# - for continuous (unpadded) batching it corresponds to the sequence id,
# - for static batching it corresponds to the start of the padded sequence.
if self.continuous_batching:
if seq_ids is None:
seq_ids = torch.arange(input_ids.shape[0])
else:
assert seq_ids.shape[0] == input_ids.shape[0]
[Review comment · Member] nit: maybe raise an explicit error with a message saying that the shapes should match in this case

[Reply · Collaborator, author] I am using an assert because these methods are always called internally, so it is rather to catch internal programming errors.

(JingyaHuang marked this conversation as resolved.)
return seq_ids
start_ids = None
if attention_mask is not None:
_, start_ids = attention_mask.max(axis=1)

-if self.cur_len > 0:
-# Only pass the last tokens of each sample
-input_ids = input_ids[:, -1:]
-# Specify the single index at which the new keys and values need to be stored
-cache_ids = torch.as_tensor([self.cur_len], dtype=torch.int32)
-else:
-# cache_ids will be set directly by the parallel context encoding code
-cache_ids = None

-# Increment the current cache index
-self.cur_len += input_ids.shape[-1]
-model_inputs = {
+return start_ids

def get_cache_ids(self, attention_mask: torch.tensor, prefill: bool):
cache_n, cache_len = attention_mask.shape
if self.continuous_batching:
# Evaluate the inputs that are not masked for each sequence
input_length = attention_mask.sum(axis=1)
if not prefill:
# When decoding, cache_ids contains a single value per sequence
return (input_length - 1).unsqueeze(1)
# When prefilling, cache_ids is an increasing range
cache_ids = torch.zeros_like(attention_mask)
for i in range(cache_n):
cur_length = input_length[i]
cache_ids[i, :cur_length] = torch.arange(cur_length)
return cache_ids
# Static batching
return None if prefill else torch.tensor([cache_len - 1], dtype=torch.int32)

def prepare_inputs_for_prefill(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, seq_ids: Optional[List[int]] = None
) -> Dict[str, torch.Tensor]:
start_ids = self.get_start_ids(input_ids, attention_mask, seq_ids=seq_ids)
cache_ids = self.get_cache_ids(attention_mask, prefill=True)
if self.continuous_batching and torch.any(attention_mask[:, 0] == 0):
# Inputs are left padded: we need to invert padding as continuous batching requires right-padding
batch_size, seq_len = input_ids.shape
input_length = attention_mask.sum(axis=1)
new_input_ids = torch.zeros_like(input_ids)
for i in range(batch_size):
cur_length = input_length[i]
new_input_ids[i, :cur_length] = input_ids[i, seq_len - cur_length :]
input_ids = new_input_ids
return {
"input_ids": input_ids,
"cache_ids": cache_ids,
"start_ids": start_ids,
}

-return model_inputs
def prepare_inputs_for_decode(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
seq_ids: Optional[List[int]] = None,
) -> Dict[str, torch.Tensor]:
start_ids = self.get_start_ids(input_ids, attention_mask, seq_ids=seq_ids)
cache_ids = self.get_cache_ids(attention_mask, prefill=False)
# Only pass the last tokens of each sample
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"cache_ids": cache_ids,
"start_ids": start_ids,
}

def can_generate(self) -> bool:
"""Returns True to validate the check made in `GenerationMixin.generate()`."""
@@ -775,7 +821,7 @@ def generate(
f"The input sequence length ({sequence_length}) exceeds the model static sequence length ({self.max_length})"
)
padded_input_ids = input_ids
-padded_attention_mask = attention_mask
+padded_attention_mask = torch.ones_like(input_ids) if attention_mask is None else attention_mask
if batch_size > self.batch_size:
raise ValueError(
f"The specified batch_size ({batch_size}) exceeds the model static batch size ({self.batch_size})"
@@ -784,18 +830,15 @@
logger.warning("Inputs will be padded to match the model static batch size. This will increase latency.")
padding_shape = [self.batch_size - batch_size, sequence_length]
padding = torch.full(padding_shape, fill_value=self.config.eos_token_id, dtype=torch.int64)
-padded_input_ids = torch.cat([input_ids, padding])
-if attention_mask is not None:
-padding = torch.zeros(padding_shape, dtype=torch.int64)
-padded_attention_mask = torch.cat([attention_mask, padding])
-# Drop the current generation context and clear the Key/Value cache
-self.reset_generation()
+padded_input_ids = torch.cat([padded_input_ids, padding])
+padding = torch.zeros(padding_shape, dtype=torch.int64)
+padded_attention_mask = torch.cat([padded_attention_mask, padding])

output_ids = self.generate_tokens(
padded_input_ids,
selector,
batch_size,
-attention_mask=padded_attention_mask,
+padded_attention_mask,
**model_kwargs,
)
return output_ids[:batch_size, :]
@@ -805,7 +848,7 @@ def generate_tokens(
input_ids: torch.LongTensor,
selector: TokenSelector,
batch_size: int,
-attention_mask: Optional[torch.Tensor] = None,
+attention_mask: torch.Tensor,
**model_kwargs,
) -> torch.LongTensor:
r"""
@@ -831,17 +874,15 @@
unfinished_sequences = torch.zeros(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
unfinished_sequences[:batch_size] = 1

# Prefill and obtain the first token
model_inputs = self.prepare_inputs_for_prefill(input_ids, attention_mask)
outputs = self(
**model_inputs,
return_dict=True,
)

# auto-regressive generation
while True:
-# prepare model inputs
-model_inputs = self.prepare_inputs_for_generation(input_ids, attention_mask, **model_kwargs)

-# forward pass to get next token
-outputs = self(
-**model_inputs,
-return_dict=True,
-)

next_token_logits = outputs.logits[:, -1, :]

next_tokens = selector.select(input_ids, next_token_logits)
@@ -851,10 +892,7 @@

# update inputs for the next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-if attention_mask is not None:
-attention_mask = torch.cat(
-[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
-)
+attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

# if eos_token was found in one sentence, set sentence to finished
unfinished_sequences = unfinished_sequences * next_tokens.ne(selector.eos_token_id)
@@ -867,4 +905,11 @@
if selector.stopping_criteria(input_ids, None):
break

# forward pass to get next token
model_inputs = self.prepare_inputs_for_decode(input_ids, attention_mask)
outputs = self(
**model_inputs,
return_dict=True,
)

return input_ids
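To make the new prefill/decode split easier to follow, here is a small standalone sketch of the continuous-batching branch of get_cache_ids (the helper name and the example tensors are illustrative, not part of the diff): prefill assigns each sequence an increasing range of KV-cache slots up to its unpadded length, while decode assigns each sequence the single slot that follows its last stored token.

import torch

def continuous_batching_cache_ids(attention_mask: torch.Tensor, prefill: bool) -> torch.Tensor:
    # Number of non-masked tokens for each sequence in the batch
    input_length = attention_mask.sum(axis=1)
    if not prefill:
        # Decoding: a single cache slot per sequence
        return (input_length - 1).unsqueeze(1)
    # Prefilling: an increasing range of cache slots per sequence, zero elsewhere
    cache_ids = torch.zeros_like(attention_mask)
    for i in range(attention_mask.shape[0]):
        cur_length = input_length[i]
        cache_ids[i, :cur_length] = torch.arange(cur_length)
    return cache_ids

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # two sequences of lengths 3 and 2
print(continuous_batching_cache_ids(mask, prefill=True))   # tensor([[0, 1, 2, 0], [0, 1, 0, 0]])
print(continuous_batching_cache_ids(mask, prefill=False))  # tensor([[2], [1]])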
29 changes: 20 additions & 9 deletions optimum/neuron/modeling_decoder.py
@@ -35,6 +35,7 @@


if is_transformers_neuronx_available():
from transformers_neuronx.config import ContinuousBatchingConfig, NeuronConfig
from transformers_neuronx.module import save_split


@@ -131,16 +132,26 @@ def __init__(

exporter = get_exporter(config, task)

-# transformers-neuronx uses f32/f16 instead of fp32/fp16
-auto_cast_type = auto_cast_type.replace("p", "")
tnx_kwargs = {
"batch_size": batch_size,
"tp_degree": num_cores,
[Review comment · Member] The TP degree is always exactly the number of neuron cores used?

[Reply · Collaborator, author] Yes. I use the num_cores name in optimum-neuron because I use it for two things: to set the TP degree (here) and also to restrict the number of cores I reserve at initialization (otherwise the TNX runtime takes them all).

# transformers-neuronx uses f32/f16 instead of fp32/fp16
"amp": auto_cast_type.replace("p", ""),
}
if batch_size > 1 and exporter.continuous_batching:
# Continuous batching is always enabled for models that support it because static batching
# is broken for these models: see https://github.com/aws-neuron/transformers-neuronx/issues/79
tnx_kwargs["neuron_config"] = NeuronConfig(
continuous_batching=ContinuousBatchingConfig(batch_size_for_shared_caches=batch_size)
)
tnx_kwargs["n_positions"] = [sequence_length]
tnx_kwargs["context_length_estimate"] = [sequence_length]
else:
tnx_kwargs["n_positions"] = sequence_length

# Instantiate neuronx model
checkpoint_path = checkpoint_dir.name if isinstance(checkpoint_dir, TemporaryDirectory) else checkpoint_dir
-neuronx_model = exporter.neuronx_class.from_pretrained(
-checkpoint_path,
-batch_size=batch_size,
-n_positions=sequence_length,
-tp_degree=num_cores,
-amp=auto_cast_type,
-)
+neuronx_model = exporter.neuronx_class.from_pretrained(checkpoint_path, **tnx_kwargs)

if compiled_dir is not None:
# Specify the path where compiled artifacts are stored before conversion
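For context, a typical user-side invocation that exercises this code path is sketched below; the checkpoint name and sizes are illustrative assumptions, not values from this pull request. Because LLamaNeuronConfig sets CONTINUOUS_BATCHING = True, exporting with a batch size greater than one creates the shared-cache ContinuousBatchingConfig shown above, and num_cores is forwarded as the transformers-neuronx tp_degree.

from optimum.neuron import NeuronModelForCausalLM

# Illustrative export of a Llama checkpoint (model id and sizes are assumptions).
model = NeuronModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    export=True,
    batch_size=4,          # > 1, so the continuous batching configuration is used
    sequence_length=2048,
    num_cores=2,           # also used as the tensor parallelism degree
)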
2 changes: 1 addition & 1 deletion tests/generation/conftest.py
@@ -72,7 +72,7 @@ def export_seq2seq_model_class(request):
@requires_neuronx
def neuron_decoder_path(export_decoder_id):
model = NeuronModelForCausalLM.from_pretrained(
-export_decoder_id, export=True, batch_size=1, sequence_length=100, num_cores=2
+export_decoder_id, export=True, batch_size=2, sequence_length=100, num_cores=2
)
model_dir = TemporaryDirectory()
model_path = model_dir.name
62 changes: 1 addition & 61 deletions tests/generation/test_generate.py
@@ -18,20 +18,12 @@
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import StoppingCriteria

-from optimum.neuron import NeuronModelForCausalLM, NeuronModelForSeq2SeqLM
+from optimum.neuron import NeuronModelForSeq2SeqLM
from optimum.neuron.utils.testing_utils import is_inferentia_test, is_trainium_test, requires_neuronx
from optimum.neuron.utils.training_utils import patch_generation_mixin_to_general_neuron_generation_mixin


def _test_model_generation(model, tokenizer, batch_size, input_length, **gen_kwargs):
input_ids = torch.ones((batch_size, input_length), dtype=torch.int64)
with torch.inference_mode():
sample_output = model.generate(input_ids, **gen_kwargs)
assert sample_output.shape[0] == batch_size


def _test_model_generation_trn(model, tokenizer, batch_size, input_length, **gen_kwargs):
import torch_xla.core.xla_model as xm

@@ -43,58 +35,6 @@ def _test_model_generation_trn(model, tokenizer, batch_size, input_length, **gen
assert sample_output.shape[0] == batch_size


@pytest.mark.parametrize(
"gen_kwargs",
[
{"do_sample": True},
{"do_sample": True, "temperature": 0.7},
{"do_sample": False},
{"do_sample": False, "repetition_penalty": 1.2},
],
ids=["sample", "sample-with-temp", "greedy", "greedy_no-repeat"],
)
@is_inferentia_test
@requires_neuronx
def test_decoder_generation(neuron_decoder_path, gen_kwargs):
model = NeuronModelForCausalLM.from_pretrained(neuron_decoder_path)
tokenizer = AutoTokenizer.from_pretrained(neuron_decoder_path)
_test_model_generation(model, tokenizer, model.batch_size, 10, **gen_kwargs)


@is_inferentia_test
@requires_neuronx
def test_model_generation_input_dimensions(neuron_decoder_path):
model = NeuronModelForCausalLM.from_pretrained(neuron_decoder_path)
tokenizer = AutoTokenizer.from_pretrained(neuron_decoder_path)
# Using valid input dimensions
_test_model_generation(model, tokenizer, model.batch_size, model.max_length // 2)
# Using an incompatible batch_size
with pytest.raises(ValueError, match="The specified batch_size"):
_test_model_generation(model, tokenizer, model.batch_size + 1, model.max_length)
# Using an incompatible input length
with pytest.raises(ValueError, match="The input sequence length"):
_test_model_generation(model, tokenizer, model.batch_size, input_length=model.max_length * 2)


@is_inferentia_test
@requires_neuronx
def test_decoder_generation_custom_stopping_criteria():
model_id = "hf-internal-testing/tiny-random-gpt2"
model = NeuronModelForCausalLM.from_pretrained(model_id, export=True, batch_size=1)

class CustomStoppingCriteria(StoppingCriteria):
def __init__(self):
self.called = False

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
self.called = True
return True

criteria = CustomStoppingCriteria()
model.generate(input_ids=torch.ones([1, 10], dtype=torch.int64), stopping_criteria=[criteria])
assert criteria.called, "Custom StoppingCriteria should have been called"


@is_inferentia_test
@requires_neuronx
def test_seq2seq_generation_beam(neuron_seq2seq_beam_path):