Bump transformers and optimum version (#615)
* tests(tgi): use InferenceClient

* chore: bump transformers and optimum versions

* test(tnx): modify multiple eos_token_ids

* feat(decoder): add support for stop strings

* feat(TGI): add support for stop sequences

* fix(generation): adapt to new GenerationMixin
dacorvo authored May 30, 2024
1 parent 2dcdf38 commit 5e772d7
Showing 9 changed files with 98 additions and 52 deletions.
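
For orientation, here is a minimal usage sketch of the stop-string support described in the commit message above. The model id, compilation parameters and the stop string are illustrative assumptions rather than values taken from this commit.

# Hypothetical end-to-end sketch of the new stop-string support (concrete values are assumptions).
from transformers import AutoTokenizer
from optimum.neuron import NeuronModelForCausalLM

model_id = "princeton-nlp/Sheared-LLaMA-1.3B"  # same model as the updated TnX test below
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = NeuronModelForCausalLM.from_pretrained(
    model_id, export=True, batch_size=1, sequence_length=2048, auto_cast_type="f16", num_cores=2
)

inputs = tokenizer("Name three fruits:", return_tensors="pt")
# Stop strings are set on the generation configuration; the tokenizer must be passed
# alongside them so the generated text can be matched against the stop strings.
outputs = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=64,
    stop_strings=["3."],   # generation halts once this substring appears in the output
    tokenizer=tokenizer,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])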
14 changes: 11 additions & 3 deletions optimum/neuron/generation/token_selector.py
@@ -1,6 +1,6 @@
import copy
import logging
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional

import torch
from transformers.generation import (
@@ -14,6 +14,9 @@
from .logits_process import FusedLogitsWarper


if TYPE_CHECKING:
from transformers import PreTrainedTokenizer

logger = logging.getLogger(__name__)


@@ -63,6 +66,7 @@ def create(
model: GenerationMixin,
max_seq_length: int,
stopping_criteria: Optional[StoppingCriteriaList] = None,
tokenizer: Optional["PreTrainedTokenizer"] = None,
seed: Optional[int] = 0,
) -> "TokenSelector":
r"""Creates the `TokenSelector` for a specific generation configuration.
@@ -78,7 +82,9 @@ def create(
The maximum number of input + generated tokens for this model. It depends on the model compilation parameters.
stopping_criteria (`Optional[transformers.generation.StoppingCriteriaList]`, defaults to `None`):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config.
generation config
tokenizer (`Optional[transformers.PreTrainedTokenizer]`, defaults to `None`):
A tokenizer used when stop strings are passed to generate.
seed(`Optional[int]`):
The optional seed for sampling. Defaults to zero.
Return:
@@ -128,7 +134,9 @@ def create(
)
if stopping_criteria is None:
stopping_criteria = StoppingCriteriaList()
stopping_criteria = model._get_stopping_criteria(generation_config, stopping_criteria=stopping_criteria)
stopping_criteria = model._get_stopping_criteria(
generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer
)

# This is not supposed to happen for any of the models we support
eos_token_id = generation_config.eos_token_id
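
The `tokenizer` forwarded to `model._get_stopping_criteria` above is what lets `transformers` build a stop-string criterion when `generation_config.stop_strings` is set. A rough sketch of the equivalent construction through the public API (an approximation, not the private implementation):

# Sketch: approximate public-API equivalent of what passing a tokenizer enables.
from transformers import AutoTokenizer
from transformers.generation import StoppingCriteriaList, StopStringCriteria

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer, purely for illustration
stopping_criteria = StoppingCriteriaList(
    [StopStringCriteria(tokenizer=tokenizer, stop_strings=["###", "\n\n"])]
)
# TokenSelector later evaluates these criteria at each step, together with the
# max-length and eos-token criteria derived from the generation config.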
18 changes: 6 additions & 12 deletions optimum/neuron/generation/utils.py
@@ -661,17 +661,9 @@ def generate(
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

# 3. Define model inputs
# inputs_tensor has to be defined
@@ -700,6 +692,9 @@
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
)

device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

# decoder-only models should use left-padding for generation
if not self.config.is_encoder_decoder:
if (
@@ -725,7 +720,6 @@
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
device=inputs_tensor.device,
)
else:
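
For readers comparing the two versions of `generate`: the deleted block performed the pad-token fallback inline, while the new code relies on the upstream `GenerationMixin._prepare_special_tokens` helper. A condensed sketch of the behaviour that moved, assuming the upstream helper has the same net effect:

# Condensed sketch of the removed inline fallback (assumed to now be handled upstream).
import logging

logger = logging.getLogger(__name__)

def default_pad_to_eos(generation_config, has_attention_mask: bool) -> None:
    if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
        if not has_attention_mask:
            logger.warning("The attention mask and the pad token id were not set; results may be unreliable.")
        eos_token_id = generation_config.eos_token_id
        if isinstance(eos_token_id, list):
            eos_token_id = eos_token_id[0]
        logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
        generation_config.pad_token_id = eos_token_id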
9 changes: 8 additions & 1 deletion optimum/neuron/modeling.py
@@ -818,13 +818,20 @@ def generate(
"""
# The actual generation configuration is a combination of config and parameters
generation_config = copy.deepcopy(self.generation_config if generation_config is None else generation_config)
# Extract tokenizer if any (used only for stop strings)
tokenizer = kwargs.pop("tokenizer", None)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
# Check model kwargs are actually used by either prepare_inputs_for_generation or forward
self._validate_model_kwargs(model_kwargs)

# Instantiate a TokenSelector for the specified configuration
selector = TokenSelector.create(
input_ids, generation_config, self, self.max_length, stopping_criteria=stopping_criteria
input_ids,
generation_config,
self,
self.max_length,
stopping_criteria=stopping_criteria,
tokenizer=tokenizer,
)

# Verify that the inputs are compatible with the model static input dimensions
4 changes: 2 additions & 2 deletions setup.py
@@ -13,9 +13,9 @@


INSTALL_REQUIRES = [
"transformers == 4.40.2",
"transformers == 4.41.1",
"accelerate == 0.29.2",
"optimum ~= 1.19.1",
"optimum ~= 1.20.0",
"huggingface_hub >= 0.20.1",
"numpy>=1.22.2, <=1.25.2",
"protobuf<4",
31 changes: 26 additions & 5 deletions tests/generation/test_tnx_llama.py
@@ -25,7 +25,7 @@

@pytest.fixture(scope="module")
def neuron_model_config():
model_id = "HuggingFaceTB/cosmo-1b"
model_id = "princeton-nlp/Sheared-LLaMA-1.3B"
model_kwargs = {"batch_size": 4, "sequence_length": 4096, "auto_cast_type": "f16", "num_cores": 2}
model = NeuronModelForCausalLM.from_pretrained(model_id, export=True, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -63,14 +63,35 @@ def test_decoder_generation_multiple_eos_token_ids(neuron_model_config):
generation_config = copy.deepcopy(model.generation_config)
if not isinstance(generation_config.eos_token_id, list):
generation_config.eos_token_id = [generation_config.eos_token_id]
generation_config.max_new_tokens = model.max_length - tokens["input_ids"].shape[-1]
# Generate and verify we stopped on an eos_token_id, and not on max_new_tokens
generation_config.max_new_tokens = 256
outputs = model.generate(**tokens, do_sample=True, generation_config=generation_config)
assert outputs.shape[-1] < model.max_length
assert outputs[0, -1].numpy() in generation_config.eos_token_id
# Extract the last non-eos generated token and use it as a fake eos_token_id
fake_eos_token_id = outputs[0, -2]
generation_config.eos_token_id.append(fake_eos_token_id)
# Generate again and verify we stopped on that id
outputs = model.generate(**tokens, do_sample=True, generation_config=generation_config)
assert outputs[0, -1] == fake_eos_token_id


@is_inferentia_test
@requires_neuronx
def test_decoder_generation_stop_strings(neuron_model_config):
model, tokenizer = neuron_model_config
prompt = "Name three fruits:"
tokens = tokenizer(prompt, return_tensors="pt")
generation_config = copy.deepcopy(model.generation_config)
generation_config.max_new_tokens = model.max_length - tokens["input_ids"].shape[-1]
# Generate once
outputs = model.generate(**tokens, do_sample=False, generation_config=generation_config)
output_string = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# Now create a generation_config with stop_strings corresponding to the beginning of the outputs
sos = len(prompt)
stop_string = output_string[sos : sos + 10]
generation_config.stop_strings = [stop_string]
# Generate and verify we stopped on the stop string
outputs = model.generate(**tokens, do_sample=False, generation_config=generation_config, tokenizer=tokenizer)
new_output_string = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# Verify we stopped on the stop string
assert len(new_output_string) < len(output_string)
# Verify the stop string is in the generated string (but not necessarily exactly at the end because of tokenization)
assert stop_string in new_output_string
@@ -173,10 +173,11 @@ def assign(self, batch_id: int, request: Request, generation_config: GenerationC
if request.parameters.repetition_penalty != 0:
self._generation_config.repetition_penalty = request.parameters.repetition_penalty
self.seed = request.parameters.seed
# TODO: watermark
self._generation_config.max_new_tokens = request.stopping_parameters.max_new_tokens
self._max_new_tokens = self._generation_config.max_new_tokens
# TODO: stop_sequences, ignore_eos_token
stop_strings = request.stopping_parameters.stop_sequences
if stop_strings:
self._generation_config.stop_strings = stop_strings

def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, selector: TokenSelector):
"""Reset the slot for the next generation.
@@ -413,7 +414,12 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
slot_input_ids = input_ids[i : i + 1, :]
# Padded input ids are also required to set logits processors and stopping criteria
selector = TokenSelector.create(
slot_input_ids, slot.generation_config, self.model, self.model.max_length, seed=slot.seed
slot_input_ids,
slot.generation_config,
self.model,
self.model.max_length,
tokenizer=self.tokenizer,
seed=slot.seed,
)
slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
slot_attention_mask = attention_mask[i]
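
On the TGI serving side, the mapping from `request.stopping_parameters.stop_sequences` to `generation_config.stop_strings` above means clients can now request stop sequences directly. A hedged client-side sketch, assuming a TGI Neuron endpoint is already running at `http://localhost:8080`:

# Hypothetical client call; the endpoint URL and stop sequence are assumptions.
from huggingface_hub import InferenceClient

client = InferenceClient(model="http://localhost:8080")
text = client.text_generation(
    "What is Deep Learning?",
    max_new_tokens=128,
    stop_sequences=["\n\n"],  # forwarded to the server and applied as stop strings
)
print(text)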
18 changes: 10 additions & 8 deletions text-generation-inference/tests/fixtures/service.py
@@ -14,8 +14,7 @@
import pytest
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound
from text_generation import AsyncClient
from text_generation.types import Response
from huggingface_hub import AsyncInferenceClient, TextGenerationOutput


OPTIMUM_CACHE_REPO_ID = "optimum/neuron-testing-cache"
@@ -30,10 +29,10 @@
logger = logging.getLogger(__file__)


class TestClient(AsyncClient):
class TestClient(AsyncInferenceClient):

def __init__(self, service_name: str, base_url: str):
super().__init__(base_url)
super().__init__(model=base_url)
self.service_name = service_name


@@ -51,7 +50,7 @@ async def health(self, timeout: int = 60):
raise RuntimeError(f"Service crashed after {i} seconds.")

try:
await self.client.generate("test", max_new_tokens=1)
await self.client.text_generation("test", max_new_tokens=1)
logger.info(f"Service started after {i} seconds")
return
except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
@@ -226,12 +225,15 @@ def generate_load():
The number of requests
Returns:
A list of `text_generation.Response`.
A list of `huggingface_hub.TextGenerationOutput`.
"""

async def generate_load_inner(client: AsyncClient, prompt: str, max_new_tokens: int, n: int) -> List[Response]:
async def generate_load_inner(
client: AsyncInferenceClient, prompt: str, max_new_tokens: int, n: int
) -> List[TextGenerationOutput]:
futures = [
client.generate(prompt, max_new_tokens=max_new_tokens, decoder_input_details=True) for _ in range(n)
client.text_generation(prompt, max_new_tokens=max_new_tokens, details=True, decoder_input_details=True)
for _ in range(n)
]

return await asyncio.gather(*futures)
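
For test authors migrating with this fixture, the old `text_generation.AsyncClient.generate` calls map onto `AsyncInferenceClient.text_generation`; passing `details=True` is what keeps `response.details.generated_tokens` available. A small sketch of the new call shape (the base URL is an assumption):

# Sketch of the new client usage (async; details requested to mirror the old Response object).
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main(base_url: str) -> None:
    client = AsyncInferenceClient(model=base_url)
    response = await client.text_generation(
        "What is Deep Learning?", max_new_tokens=17, details=True, decoder_input_details=True
    )
    assert response.details.generated_tokens == 17

asyncio.run(main("http://localhost:8080"))  # endpoint URL is a placeholder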
32 changes: 20 additions & 12 deletions text-generation-inference/tests/integration/test_generate.py
@@ -16,10 +16,8 @@ async def test_model_single_request(tgi_service):
service_name = tgi_service.client.service_name
prompt = "What is Deep Learning?"
# Greedy bounded without input
response = await tgi_service.client.generate(
prompt,
max_new_tokens=17,
decoder_input_details=True,
response = await tgi_service.client.text_generation(
prompt, max_new_tokens=17, details=True, decoder_input_details=True
)
assert response.details.generated_tokens == 17
greedy_expectations = {
@@ -30,32 +28,42 @@
assert response.generated_text == greedy_expectations[service_name]

# Greedy bounded with input
response = await tgi_service.client.generate(
"What is Deep Learning?",
max_new_tokens=17,
return_full_text=True,
decoder_input_details=True,
response = await tgi_service.client.text_generation(
"What is Deep Learning?", max_new_tokens=17, return_full_text=True, details=True, decoder_input_details=True
)
assert response.details.generated_tokens == 17
assert response.generated_text == prompt + greedy_expectations[service_name]

# Sampling
response = await tgi_service.client.generate(
response = await tgi_service.client.text_generation(
"What is Deep Learning?",
do_sample=True,
top_k=50,
top_p=0.9,
repetition_penalty=1.2,
max_new_tokens=128,
seed=42,
decoder_input_details=True,
)
sample_expectations = {
"gpt2": "A lot of researchers have tried to make a broad, intuitive definition of Deep Learning",
"llama": "Deep Learning is a technique for training artificial neural networks",
"mistral": "Why is deep learning important?",
}
assert sample_expectations[service_name] in response.generated_text
assert sample_expectations[service_name] in response

# Sampling with stop sequence
stop_sequence = sample_expectations[service_name][-5:]
response = await tgi_service.client.text_generation(
"What is Deep Learning?",
do_sample=True,
top_k=50,
top_p=0.9,
repetition_penalty=1.2,
max_new_tokens=128,
seed=42,
stop_sequences=[stop_sequence],
)
assert response.endswith(stop_sequence)


@pytest.mark.asyncio
Expand Down
12 changes: 6 additions & 6 deletions text-generation-inference/tests/integration/test_implicit_env.py
@@ -1,7 +1,7 @@
import os

import pytest
from text_generation.errors import ValidationError
from huggingface_hub.errors import ValidationError


@pytest.fixture(scope="module", params=["hub-neuron", "hub", "local-neuron"])
@@ -42,20 +42,21 @@ async def test_model_single_request(tgi_service):
# Just verify that the generation works, and nothing is raised, with several set of params

# No params
await tgi_service.client.generate(
await tgi_service.client.text_generation(
"What is Deep Learning?",
)

response = await tgi_service.client.generate(
response = await tgi_service.client.text_generation(
"How to cook beans ?",
max_new_tokens=17,
details=True,
decoder_input_details=True,
)
assert response.details.generated_tokens == 17

# check error
try:
await tgi_service.client.generate("What is Deep Learning?", max_new_tokens=170000)
await tgi_service.client.text_generation("What is Deep Learning?", max_new_tokens=170000)
except ValidationError:
pass
else:
@@ -65,13 +66,12 @@ async def test_model_single_request(tgi_service):
)

# Sampling
await tgi_service.client.generate(
await tgi_service.client.text_generation(
"What is Deep Learning?",
do_sample=True,
top_k=50,
top_p=0.9,
repetition_penalty=1.2,
max_new_tokens=128,
seed=42,
decoder_input_details=True,
)
