eos_token_id can be a list in configs (#580)
* fix(decoder): default sequence length

* tests(tgi): fix style

* fix(decoder): use revision when fetching generation_config

* fix(generation): eos_token_id can be a list in configs

* tests(decoder): test multiple eos token ids

* Update optimum/neuron/generation/token_selector.py

Co-authored-by: Michael Benayoun <[email protected]>

---------

Co-authored-by: Michael Benayoun <[email protected]>
dacorvo and michaelbenayoun authored Apr 26, 2024
1 parent 1ea7a11 commit 43796c0
Showing 7 changed files with 68 additions and 23 deletions.
19 changes: 10 additions & 9 deletions optimum/neuron/generation/token_selector.py
@@ -1,6 +1,6 @@
import copy
import logging
-from typing import Optional
+from typing import List, Optional

import torch
from transformers.generation import (
@@ -41,15 +41,15 @@ def __init__(
mode: GenerationMode,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
-eos_token_id: int,
+eos_token_ids: List[int],
pad_token_id: int,
logits_warper: Optional[LogitsProcessorList] = None,
seed: Optional[int] = 0,
):
self.mode = mode
self.logits_processor = logits_processor
self.stopping_criteria = stopping_criteria
-self.eos_token_id = eos_token_id
+self.eos_token_ids = eos_token_ids
self.pad_token_id = pad_token_id
self.logits_warper = logits_warper
self.generator = torch.Generator()
@@ -130,13 +130,14 @@ def create(
stopping_criteria = StoppingCriteriaList()
stopping_criteria = model._get_stopping_criteria(generation_config, stopping_criteria=stopping_criteria)

-# The generation requires special tokens
-eos_token_id = generation_config.eos_token_id
-# This is not supposed to happen for any of the models we support
-assert eos_token_id is not None and not isinstance(eos_token_id, list)
+eos_token_id = generation_config.eos_token_id
+assert eos_token_id is not None
+# The generation requires special tokens
+eos_token_ids = eos_token_id if isinstance(eos_token_id, list) else [eos_token_id]
if generation_config.pad_token_id is None:
-logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-generation_config.pad_token_id = eos_token_id
+logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_ids[0]} for open-ended generation.")
+generation_config.pad_token_id = eos_token_ids[0]

generation_mode = model._get_generation_mode(generation_config, None)
if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]:
@@ -151,7 +152,7 @@ def create(
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
logits_warper=logits_warper,
-eos_token_id=eos_token_id,
+eos_token_ids=eos_token_ids,
pad_token_id=generation_config.pad_token_id,
seed=seed,
)
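For readers skimming the diff, a minimal standalone sketch of the normalization introduced above; the helper name is hypothetical and not part of the change. Configs may carry eos_token_id either as a single int or as a list of ints, and the selector now always works with a list.

from typing import List, Union

def normalize_eos_token_ids(eos_token_id: Union[int, List[int]]) -> List[int]:
    # Configs may store a single id or a list of ids; always return a list.
    assert eos_token_id is not None
    return eos_token_id if isinstance(eos_token_id, list) else [eos_token_id]

assert normalize_eos_token_ids(2) == [2]
assert normalize_eos_token_ids([2, 32000]) == [2, 32000]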
12 changes: 10 additions & 2 deletions optimum/neuron/modeling.py
@@ -842,7 +842,13 @@ def generate(
elif batch_size < self.batch_size and not self.continuous_batching:
logger.warning("Inputs will be padded to match the model static batch size. This will increase latency.")
padding_shape = [self.batch_size - batch_size, sequence_length]
-padding = torch.full(padding_shape, fill_value=self.config.eos_token_id, dtype=torch.int64)
+pad_token_id = generation_config.pad_token_id
+if pad_token_id is None:
+    if isinstance(self.config.eos_token_id, list):
+        pad_token_id = self.config.eos_token_id[0]
+    else:
+        pad_token_id = self.config.eos_token_id
+padding = torch.full(padding_shape, fill_value=pad_token_id, dtype=torch.int64)
padded_input_ids = torch.cat([padded_input_ids, padding])
padding = torch.zeros(padding_shape, dtype=torch.int64)
padded_attention_mask = torch.cat([padded_attention_mask, padding])
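As an illustration of the padding step above, a self-contained sketch; the shapes and token ids are made up and not taken from the diff.

import torch

# Pad a 2-row batch up to a compiled static batch size of 4, using the pad token id
# resolved above (falling back to the first eos token id when pad_token_id is None).
batch_size, static_batch_size, sequence_length = 2, 4, 6
pad_token_id = 2
input_ids = torch.randint(5, 100, (batch_size, sequence_length), dtype=torch.int64)
padding_shape = [static_batch_size - batch_size, sequence_length]
padding = torch.full(padding_shape, fill_value=pad_token_id, dtype=torch.int64)
padded_input_ids = torch.cat([input_ids, padding])
print(padded_input_ids.shape)  # torch.Size([4, 6])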
@@ -908,7 +914,9 @@ def generate_tokens(
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

# if eos_token was found in one sentence, set sentence to finished
-unfinished_sequences = unfinished_sequences * next_tokens.ne(selector.eos_token_id)
+unfinished_sequences = unfinished_sequences * torch.isin(
+    next_tokens, torch.tensor(selector.eos_token_ids), invert=True
+)

# stop when each sentence is finished
if unfinished_sequences.max() == 0:
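The switch from Tensor.ne to torch.isin is what allows several stop tokens; a quick standalone illustration with made-up ids, not code from the repository.

import torch

eos_token_ids = [2, 32000]
next_tokens = torch.tensor([5, 2, 32000, 7])
unfinished_sequences = torch.ones_like(next_tokens)

# A sequence stays unfinished only if its next token is not one of the eos ids.
unfinished_sequences = unfinished_sequences * torch.isin(
    next_tokens, torch.tensor(eos_token_ids), invert=True
)
print(unfinished_sequences)  # tensor([1, 0, 0, 1])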
13 changes: 9 additions & 4 deletions optimum/neuron/modeling_decoder.py
@@ -283,8 +283,13 @@ def get_export_config(
batch_size = 1
# If the sequence_length was not specified, deduce it from the model configuration
if sequence_length is None:
-# Note: for older models, max_position_embeddings is an alias for n_positions
-sequence_length = config.max_position_embeddings
+if hasattr(config, "n_positions"):
+    sequence_length = config.n_positions
+elif hasattr(config, "max_position_embeddings"):
+    sequence_length = config.max_position_embeddings
+else:
+    # Use transformers-neuronx default
+    sequence_length = 2048
if num_cores is None:
# Use all available cores
num_cores = get_available_cores()
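A minimal sketch of the new fallback chain; the helper function and the SimpleNamespace stand-in configs are illustrative, not part of the diff.

from types import SimpleNamespace

def deduce_sequence_length(config) -> int:
    # Prefer n_positions (older configs), then max_position_embeddings,
    # else fall back to the transformers-neuronx default of 2048.
    if hasattr(config, "n_positions"):
        return config.n_positions
    if hasattr(config, "max_position_embeddings"):
        return config.max_position_embeddings
    return 2048

print(deduce_sequence_length(SimpleNamespace(n_positions=1024)))              # 1024
print(deduce_sequence_length(SimpleNamespace(max_position_embeddings=4096)))  # 4096
print(deduce_sequence_length(SimpleNamespace()))                              # 2048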
@@ -357,7 +362,7 @@ def _export(
# Try to reload the generation config (if any)
generation_config = None
try:
-generation_config = GenerationConfig.from_pretrained(model_id)
+generation_config = GenerationConfig.from_pretrained(model_id, revision=revision)
except OSError:
pass

@@ -414,7 +419,7 @@ def _from_pretrained(
# Try to reload the generation config (if any)
generation_config = None
try:
-generation_config = GenerationConfig.from_pretrained(model_id)
+generation_config = GenerationConfig.from_pretrained(model_id, revision=revision)
except OSError:
pass

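Both hunks apply the same pattern; a hedged sketch of it in isolation, where model_id and revision are placeholders rather than values from the diff.

from transformers import GenerationConfig

model_id = "my-org/my-model"   # placeholder
revision = "refs/pr/1"         # placeholder: any branch, tag or commit sha

generation_config = None
try:
    # Without revision=..., the default branch would be queried even when the
    # model itself was fetched from another revision.
    generation_config = GenerationConfig.from_pretrained(model_id, revision=revision)
except OSError:
    # Some repositories simply do not ship a generation_config.json
    pass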
5 changes: 4 additions & 1 deletion optimum/neuron/pipelines/transformers/base.py
@@ -255,7 +255,10 @@ def pipeline(
batch_size = model.config.neuron[attr]
if batch_size > 1 and tokenizer is not None and tokenizer.pad_token_id is None:
# The pipeline needs a pad token to be able to batch
-tokenizer.pad_token_id = model.config.eos_token_id
+if isinstance(model.config.eos_token_id, list):
+    tokenizer.pad_token_id = model.config.eos_token_id[0]
+else:
+    tokenizer.pad_token_id = model.config.eos_token_id

return transformers_pipeline(
task,
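The same int-or-list handling shows up here on the tokenizer side; a compact sketch with stand-in objects, nothing below comes from the diff itself.

from types import SimpleNamespace

tokenizer = SimpleNamespace(pad_token_id=None)
config = SimpleNamespace(eos_token_id=[2, 32000])  # may also be a plain int

if tokenizer.pad_token_id is None:
    # Batched pipelines need a pad token; reuse the first eos id when a list is given.
    if isinstance(config.eos_token_id, list):
        tokenizer.pad_token_id = config.eos_token_id[0]
    else:
        tokenizer.pad_token_id = config.eos_token_id

print(tokenizer.pad_token_id)  # 2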
4 changes: 1 addition & 3 deletions tests/generation/conftest.py
@@ -72,9 +72,7 @@ def export_seq2seq_model_class(request):
@pytest.fixture(scope="session")
@requires_neuronx
def neuron_decoder_path(export_decoder_id):
-model = NeuronModelForCausalLM.from_pretrained(
-    export_decoder_id, export=True, batch_size=2, sequence_length=100, num_cores=2
-)
+model = NeuronModelForCausalLM.from_pretrained(export_decoder_id, export=True, batch_size=2, num_cores=2)
model_dir = TemporaryDirectory()
model_path = model_dir.name
model.save_pretrained(model_path)
37 changes: 34 additions & 3 deletions tests/generation/test_tnx_llama.py
@@ -13,20 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

import pytest
import torch
from transformers import AutoTokenizer

from optimum.neuron import NeuronModelForCausalLM
from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx


-@is_inferentia_test
-@requires_neuronx
-def test_generation_llama_padded_inputs():
+@pytest.fixture(scope="module")
+def neuron_model_config():
model_id = "NousResearch/Llama-2-7b-chat-hf"
model_kwargs = {"batch_size": 4, "sequence_length": 4096, "auto_cast_type": "f16", "num_cores": 2}
model = NeuronModelForCausalLM.from_pretrained(model_id, export=True, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_id)
yield (model, tokenizer)


@is_inferentia_test
@requires_neuronx
def test_generation_llama_padded_inputs(neuron_model_config):
model, tokenizer = neuron_model_config
prompt = "One of my fondest memory is of my grandmother making homemade bread"
first_input = tokenizer(prompt)
first_ids = first_input["input_ids"]
@@ -43,3 +52,25 @@ def test_generation_llama_padded_inputs():
)
# Verify we did not generate any unknown token
assert torch.all(outputs[:, -1] != 0)


@is_inferentia_test
@requires_neuronx
def test_decoder_generation_multiple_eos_token_ids(neuron_model_config):
model, tokenizer = neuron_model_config
prompt = "Name three fruits:"
tokens = tokenizer(prompt, return_tensors="pt")
generation_config = copy.deepcopy(model.generation_config)
if not isinstance(generation_config.eos_token_id, list):
generation_config.eos_token_id = [generation_config.eos_token_id]
generation_config.max_new_tokens = model.max_length - tokens["input_ids"].shape[-1]
# Generate and verify we stopped on an eos_token_id, and not on max_new_tokens
outputs = model.generate(**tokens, do_sample=True, generation_config=generation_config)
assert outputs.shape[-1] < model.max_length
assert outputs[0, -1].numpy() in generation_config.eos_token_id
# Extract the last non-eos generated token and use it as a fake eos_token_id
fake_eos_token_id = outputs[0, -2]
generation_config.eos_token_id.append(fake_eos_token_id)
# Generate again and verify we stopped on that id
outputs = model.generate(**tokens, do_sample=True, generation_config=generation_config)
assert outputs[0, -1] == fake_eos_token_id
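Beyond the test, the same capability can be exercised directly by callers; a hedged usage sketch in which the model path and the extra stop id are placeholders, and the model is assumed to be an already exported NeuronModelForCausalLM.

import copy

from transformers import AutoTokenizer

from optimum.neuron import NeuronModelForCausalLM

model = NeuronModelForCausalLM.from_pretrained("./my-neuron-llama")  # placeholder path
tokenizer = AutoTokenizer.from_pretrained("./my-neuron-llama")
inputs = tokenizer("Name three fruits:", return_tensors="pt")

generation_config = copy.deepcopy(model.generation_config)
if not isinstance(generation_config.eos_token_id, list):
    generation_config.eos_token_id = [generation_config.eos_token_id]
generation_config.eos_token_id.append(12345)  # placeholder extra stop token id
outputs = model.generate(**inputs, generation_config=generation_config)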
1 change: 0 additions & 1 deletion
@@ -1,4 +1,3 @@

import Levenshtein
import pytest

