Add general support for generation on TRN with NxD (#370)
* Add general support for generation on TRN with NxD

* Fix bugs and refactor generation utils

* Add unit tests and minor fixes

* Move unit tests from CPU to TRN

* Fix max length when max_new_tokens is used

* Fix styles

* Fix styles - imports

* Fix test neuron compile flags

* Add beam sample tests and remove GPT-2 tests

---------

Co-authored-by: Ubuntu <[email protected]>
aws-tianquaw and Ubuntu authored Jan 17, 2024
1 parent 9837efa commit 8fd86c1
Showing 5 changed files with 386 additions and 5 deletions.
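For orientation, here is a minimal usage sketch of the general generation path this commit adds; it is not part of the diff. It assumes a Trn1 instance with torch-xla installed and reuses the tiny T5 checkpoint from the test fixtures; the exact environment setup may differ.

import torch_xla.core.xla_model as xm
from transformers import AutoTokenizer, T5ForConditionalGeneration

from optimum.neuron.utils.training_utils import patch_generation_mixin_to_general_neuron_generation_mixin

# Load a seq2seq model and swap GenerationMixin for GeneralNeuronGenerationMixin in its class hierarchy.
model_id = "hf-internal-testing/tiny-random-t5"
model = T5ForConditionalGeneration.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
patch_generation_mixin_to_general_neuron_generation_mixin(model)

# Move the model to the XLA (Neuron) device; generation inputs can stay on CPU.
model = model.to(xm.xla_device())

# Any stock HuggingFace decoding strategy should now work: the forward pass runs on the
# Neuron device while padding, search, and sampling logic stays on CPU.
inputs = tokenizer("translate English to German: Hello, world!", return_tensors="pt")
output_ids = model.generate(**inputs, num_beams=4, max_new_tokens=16, use_cache=False)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))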
2 changes: 1 addition & 1 deletion optimum/neuron/generation/__init__.py
@@ -15,4 +15,4 @@

from .logits_process import FusedLogitsWarper
from .token_selector import TokenSelector
from .utils import NeuronGenerationMixin
from .utils import GeneralNeuronGenerationMixin, NeuronGenerationMixin
277 changes: 277 additions & 0 deletions optimum/neuron/generation/utils.py
@@ -17,13 +17,17 @@
import copy
import inspect
import warnings
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist

from optimum.neuron.utils.import_utils import is_torch_xla_available

from ..utils.import_utils import is_neuronx_distributed_available
from ..utils.misc import args_and_kwargs_to_kwargs_only


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
@@ -54,6 +58,279 @@
logger = logging.get_logger(__name__)


if is_neuronx_distributed_available():
    from neuronx_distributed.parallel_layers import parallel_state


def _move_dict_args_to_device(kwargs: Dict[str, Any], device: str = "cpu") -> Dict[str, Any]:
    """
    Takes keyword arguments which will be passed to a model's forward function and moves their values to `device`
    if they are of type `torch.Tensor`. If a value is a dictionary, the same is done to its values.

    Args:
        kwargs (`Dict[str, Any]`):
            The kwargs to be passed to the model's forward function.
        device (`str`, defaults to `cpu`):
            The target device to which tensors should be moved.

    Returns:
        `Dict[str, Any]`: The kwargs dict with its tensors moved to `device`.
    """

    def needs_move(src_device, tgt_device):
        return src_device != tgt_device

    for k, v in kwargs.items():
        # Handle nested dicts
        if isinstance(v, dict):
            for k_, v_ in v.items():
                if isinstance(v_, torch.Tensor):
                    if needs_move(v_.device, device):
                        v[k_] = v_.to(device=device)

        # Handle tensor types
        elif isinstance(v, torch.Tensor):
            if needs_move(v.device, device):
                kwargs[k] = v.to(device=device)

        # Handle past_key_value tuples
        elif k == "past_key_values":
            if v is not None:
                new_past_key_values = ()
                for layer_past in v:
                    new_layer_past = ()
                    for past_state in layer_past:
                        if needs_move(past_state.device, device):
                            new_layer_past += (past_state.to(device=device),)
                        else:
                            new_layer_past += (past_state,)
                    new_past_key_values += (new_layer_past,)
                kwargs[k] = new_past_key_values

    return kwargs


def _pad_input_ids_for_general_sampling(
    input_ids: torch.Tensor, num_padding_values: int, pad_token_id: int
) -> torch.Tensor:
    """
    Pads `input_ids` with `num_padding_values` padding tokens along the second dimension.

    Args:
        input_ids (`torch.Tensor`):
            Input ids to be padded.
        num_padding_values (`int`):
            Number of padding values to add.
        pad_token_id (`int`):
            Token ID of the padding token.

    Returns:
        `torch.Tensor`: Padded `input_ids`.
    """
    bsz = input_ids.size(0)
    input_ids = torch.cat(
        [input_ids, torch.ones((bsz, num_padding_values), device=input_ids.device, dtype=torch.long) * pad_token_id], 1
    )
    return input_ids


def _get_fwd_for_general_sampling(
    current_fwd: Callable,
    generation_config: GenerationConfig,
    is_encoder_decoder: bool,
    vocab_size: int,
    main_device: str,
    to_device: str = "cpu",
    output_dtype: torch.dtype = torch.float32,
) -> Callable:
    """
    Wraps the passed forward function and extends it such that, before each forward call, the `decoder_input_ids`
    are padded and all tensors are moved to `main_device` (e.g. XLA). The original forward is then called, followed
    by an `xm.mark_step`. Subsequently, an "unpadding" of the logits is performed. This way, all functions that
    process the logits can be called without making any changes.

    Args:
        current_fwd (`Callable`):
            The current forward function of the model.
        generation_config (`GenerationConfig`):
            The `GenerationConfig` of the model.
        is_encoder_decoder (`bool`):
            Whether this is an encoder-decoder model.
        vocab_size (`int`):
            The vocabulary size of the current model.
        main_device (`str`):
            The device on which the forward pass should be executed.
        to_device (`str`, defaults to `cpu`):
            The device on which all other processing should be executed.
        output_dtype (`torch.dtype`, defaults to `torch.float32`):
            The expected data type of the output logits.

    Returns:
        `Callable`: The extended forward function.
    """

    @wraps(current_fwd)
    def new_fwd(*args, **kwargs):
        # Pad input to max length
        cur_len = None
        input_ids_string = "decoder_input_ids" if is_encoder_decoder else "input_ids"
        if input_ids_string in kwargs:
            current_input_ids = kwargs[input_ids_string]
            batch_size, cur_len = current_input_ids.shape
            num_padding_values = generation_config.max_length - cur_len
            kwargs[input_ids_string] = _pad_input_ids_for_general_sampling(
                current_input_ids, num_padding_values, generation_config.pad_token_id
            )

            # For decoder-only models, pad the decoder attention mask in addition to the prompts
            if "attention_mask" in kwargs and not is_encoder_decoder and num_padding_values > 0:
                kwargs["attention_mask"] = torch.cat(
                    [
                        kwargs["attention_mask"],
                        torch.zeros((batch_size, (generation_config.max_length - cur_len)))
                        .long()
                        .to(kwargs["attention_mask"].device),
                    ],
                    1,
                )
                # Create position_ids on the fly for batch generation
                if "position_ids" in set(inspect.signature(current_fwd).parameters.keys()):
                    position_ids = kwargs["attention_mask"].long().cumsum(-1) - 1
                    position_ids.masked_fill_(kwargs["attention_mask"] == 0, 1)
                    kwargs["position_ids"] = position_ids

        # Move inputs to device
        _move_dict_args_to_device(kwargs, main_device)

        # Forward
        kwargs = args_and_kwargs_to_kwargs_only(current_fwd, args, kwargs)
        outputs = current_fwd(**kwargs)
        # Gather outputs if NxD tensor parallelism is applied and the output logits have not been gathered.
        if (
            is_neuronx_distributed_available()
            and parallel_state.model_parallel_is_initialized()
            and parallel_state.get_tensor_model_parallel_size() > 1
            and outputs["logits"].shape[-1] != vocab_size
        ):
            outputs["logits"] = xm.all_gather(
                outputs["logits"],
                dim=-1,
                groups=parallel_state.get_tensor_model_parallel_group(as_list=True),
            )
        xm.mark_step()

        # Move to CPU
        _move_dict_args_to_device(outputs, to_device)

        # Post-process output as a function of cur_len
        outputs["logits"] = outputs["logits"][:, :cur_len, ...].to(output_dtype)

        return outputs

    return new_fwd


class GeneralNeuronGenerationMixin(GenerationMixin):
    """
    A class containing all functions for auto-regressive text generation on Trn1, to be used as a mixin in
    [`PreTrainedModel`]. Generation is split between CPU and TRN1 in the following way:

    1. The model forward pass is executed on TRN1.
    2. All other logic, including padding, searching, and sampling, is handled on the general device (CPU).

    This implementation allows us to support general search and sampling methods with minimal code changes.
    """

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ):
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()

        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
        if generation_config is None:
            # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
            # two conditions must be met
            # 1) the generation config must have been created from the model config (`_from_model_config` field);
            # 2) the generation config must have seen no modification since its creation (the hash is the same).
            if self.generation_config._from_model_config:
                new_generation_config = GenerationConfig.from_model_config(self.config)
                if new_generation_config != self.generation_config:
                    warnings.warn(
                        "You have modified the pretrained model configuration to control generation. This is a"
                        " deprecated strategy to control generation and will be removed soon, in a future version."
                        " Please use and modify the model generation configuration (see"
                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
                    )
                    self.generation_config = new_generation_config
            generation_config = self.generation_config

        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

        # 2. Set generation parameters if not already defined
        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
            if model_kwargs.get("attention_mask", None) is None:
                logger.warning(
                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                )
            eos_token_id = generation_config.eos_token_id
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            generation_config.pad_token_id = eos_token_id

        # 3. Define model inputs and move them to CPU
        general_device = "cpu"
        if "input_ids" in kwargs and kwargs["input_ids"] is not None:
            kwargs["input_ids"] = kwargs["input_ids"].to(general_device)
        if inputs is not None:
            inputs = inputs.to(general_device)
        input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )

        # 4. Set Neuron-specific generation configuration
        original_forward = copy.deepcopy(self.forward)
        try:
            general_forward = _get_fwd_for_general_sampling(
                self.forward,
                generation_config,
                self.config.is_encoder_decoder,
                self.config.vocab_size,
                self.device,
            )
            self.forward = general_forward
            if generation_config.use_cache:
                warnings.warn(
                    "use_cache is not supported for generation on Neuron devices, switching to use_cache=False."
                )
            # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
            # generating the first new token or not, and we only want to use the embeddings for the first new token)
            if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
                raise ValueError("Decoder-only models with inputs_embeds forwarding must use `use_cache=True`")
            generation_config.use_cache = False
            if generation_config.max_new_tokens is not None:
                generation_config.max_length = generation_config.max_new_tokens + input_ids.shape[-1]

            # 5. Run the HuggingFace generate function
            return super().generate(inputs, generation_config, **kwargs)
        finally:
            self.forward = original_forward

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Move the input tensor to the XLA device and move the output tensors back to CPU."""
        output = super()._prepare_encoder_decoder_kwargs_for_generation(
            inputs_tensor.to(self.device), model_kwargs, model_input_name
        )
        _move_dict_args_to_device(output, "cpu")
        return output


class NeuronGenerationMixin(GenerationMixin):
    """
    A class containing all functions for auto-regressive text generation on Trn1, to be used as a mixin in [`PreTrainedModel`].
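A note on the design above: `_get_fwd_for_general_sampling` keeps the compiled graph shape static by padding the (decoder) input ids to `generation_config.max_length` before every forward call, running the padded batch on the Neuron/XLA device, and then slicing the logits back to the true sequence length on CPU so the stock logits processors and samplers never see the padding. A stripped-down sketch of that pattern, using hypothetical names that are not part of the actual helper:

import torch

def pad_run_unpad(forward, input_ids, max_length, pad_token_id, device):
    # Hypothetical illustration of the pad -> forward-on-device -> slice-back pattern.
    cur_len = input_ids.shape[1]
    padding = torch.full((input_ids.shape[0], max_length - cur_len), pad_token_id, dtype=torch.long)
    padded = torch.cat([input_ids, padding], dim=1).to(device)  # static shape for the compiled graph
    logits = forward(padded)                                    # executes on the Neuron/XLA device
    return logits[:, :cur_len, :].to("cpu")                     # drop padded positions, hand back to CPU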
28 changes: 27 additions & 1 deletion optimum/neuron/utils/training_utils.py
@@ -48,7 +48,7 @@
from transformers.utils.logging import set_verbosity as set_verbosity_transformers

from ...utils.logging import set_verbosity as set_verbosity_optimum
from ..generation import NeuronGenerationMixin
from ..generation import GeneralNeuronGenerationMixin, NeuronGenerationMixin
from . import is_torch_xla_available
from .require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla

@@ -263,6 +263,32 @@ def patch_generation_mixin_to_neuron_generation_mixin(model: "PreTrainedModel"):
        cls.__bases__ = tuple(new_bases)


def patch_generation_mixin_to_general_neuron_generation_mixin(model: "PreTrainedModel"):
    """
    Changes the vanilla `GenerationMixin` class from Transformers to `GeneralNeuronGenerationMixin` in the model's
    inheritance. This makes the model Neuron-compatible for generation without much hassle.
    """
    to_visit = [model.__class__]
    should_stop = False
    while to_visit and not should_stop:
        cls = to_visit.pop(0)
        if cls is object:
            continue
        bases = cls.__bases__
        new_bases = []
        for base in bases:
            to_visit.append(base)
            if base == GenerationMixin:
                new_bases.append(GeneralNeuronGenerationMixin)
                should_stop = True
            elif base == GeneralNeuronGenerationMixin:
                should_stop = True
                new_bases.append(base)
            else:
                new_bases.append(base)
        cls.__bases__ = tuple(new_bases)


def prepare_environment_for_neuron():
    """
    Prepares the system environment for Transformers models training on AWS Neuron.
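`patch_generation_mixin_to_general_neuron_generation_mixin` walks the model's class hierarchy breadth-first and swaps the first `GenerationMixin` base it finds for `GeneralNeuronGenerationMixin`, so existing instances resolve `generate` through the Neuron-aware mixin. A quick, hypothetical way to check the effect (assuming a standard Transformers install; not part of the diff):

from transformers import T5ForConditionalGeneration

from optimum.neuron.generation import GeneralNeuronGenerationMixin
from optimum.neuron.utils.training_utils import patch_generation_mixin_to_general_neuron_generation_mixin

model = T5ForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-t5")
patch_generation_mixin_to_general_neuron_generation_mixin(model)

# The base-class swap is visible both on the class hierarchy and on the instance itself.
assert isinstance(model, GeneralNeuronGenerationMixin)
assert any(GeneralNeuronGenerationMixin in cls.__bases__ for cls in type(model).__mro__)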
23 changes: 22 additions & 1 deletion tests/generation/conftest.py
@@ -15,7 +15,7 @@
from tempfile import TemporaryDirectory

import pytest
from transformers import AutoTokenizer
from transformers import AutoTokenizer, T5ForConditionalGeneration

from optimum.neuron import NeuronModelForCausalLM, NeuronModelForSeq2SeqLM
from optimum.neuron.utils.testing_utils import requires_neuronx
@@ -30,21 +30,42 @@
    "mistral": "dacorvo/tiny-random-MistralForCausalLM",
    "opt": "hf-internal-testing/tiny-random-OPTForCausalLM",
}
TRN_DECODER_MODEL_ARCHITECTURES = ["bloom", "llama", "opt"]
TRN_DECODER_MODEL_NAMES = {
    "bloom": "bigscience/bloom-560m",
    "llama": "dacorvo/tiny-random-llama",
    "opt": "facebook/opt-125m",
}
SEQ2SEQ_MODEL_NAMES = {
    "t5": "hf-internal-testing/tiny-random-t5",
}
SEQ2SEQ_MODEL_CLASSES = {
    "t5": T5ForConditionalGeneration,
}


@pytest.fixture(scope="module", params=[DECODER_MODEL_NAMES[model_arch] for model_arch in DECODER_MODEL_ARCHITECTURES])
def export_decoder_id(request):
    return request.param


@pytest.fixture(
    scope="module", params=[TRN_DECODER_MODEL_NAMES[model_arch] for model_arch in TRN_DECODER_MODEL_ARCHITECTURES]
)
def export_trn_decoder_id(request):
    return request.param


@pytest.fixture(scope="module", params=[SEQ2SEQ_MODEL_NAMES[model_arch] for model_arch in SEQ2SEQ_MODEL_NAMES])
def export_seq2seq_id(request):
    return request.param


@pytest.fixture(scope="module", params=[SEQ2SEQ_MODEL_CLASSES[model_arch] for model_arch in SEQ2SEQ_MODEL_NAMES])
def export_seq2seq_model_class(request):
    return request.param


@pytest.fixture(scope="module")
@requires_neuronx
def neuron_decoder_path(export_decoder_id):
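The fixtures above parametrize the generation tests over full-size TRN decoder checkpoints and the tiny T5 seq2seq model. The test module itself is the fifth changed file and is not shown in this view; a test consuming these fixtures might look roughly like the following hypothetical sketch (names and assertions are illustrative, not the commit's actual tests, and a Trn1 environment with torch-xla is assumed):

import pytest
import torch_xla.core.xla_model as xm
from transformers import AutoTokenizer

from optimum.neuron.utils.training_utils import patch_generation_mixin_to_general_neuron_generation_mixin


@pytest.mark.parametrize("num_beams", [1, 4])
def test_seq2seq_generation_on_trn(export_seq2seq_id, export_seq2seq_model_class, num_beams):
    # Hypothetical test body; the tests added by this commit may differ.
    model = export_seq2seq_model_class.from_pretrained(export_seq2seq_id)
    tokenizer = AutoTokenizer.from_pretrained(export_seq2seq_id)
    patch_generation_mixin_to_general_neuron_generation_mixin(model)
    model = model.to(xm.xla_device())

    inputs = tokenizer("translate English to German: hello", return_tensors="pt")
    output_ids = model.generate(**inputs, num_beams=num_beams, max_new_tokens=8, use_cache=False)
    assert output_ids.shape[0] == inputs["input_ids"].shape[0]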
(Diff for the fifth changed file was not loaded in this view.)
