
Transformers use logits processor#31

Open
lapp0 wants to merge 1 commit into main from transformers-use-logits-processor

Conversation


@lapp0 lapp0 commented Jun 12, 2024

Fixes dottxt-ai#806

Fixes dottxt-ai#789

Closes dottxt-ai#910

Problem

For outlines.models.transformers, SequenceGenerator manages the automata directly instead of delegating to a logits processor that encapsulates automata management. This divergent implementation is what caused the bug in dottxt-ai#789.

Solution

  • Implement Transformers.generate and Transformers.stream which use HF transformers logits_processor argument with outlines.processors.OutlinesLogitsProcessor
  • Use SequenceGeneratorAdapter for transformers instead of SequenceGenerator
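
The bullets above hinge on the callable contract behind HF transformers' logits_processor argument: each processor receives the token ids so far plus the raw scores and returns adjusted scores. A minimal, dependency-free sketch of that contract (names are illustrative, not the actual OutlinesLogitsProcessor implementation):

```python
# Minimal sketch of the contract that HF transformers' `logits_processor`
# argument expects: a callable mapping (input_ids, scores) -> scores.
# Names here are illustrative, not the actual OutlinesLogitsProcessor.
import math

class AllowedTokensProcessor:
    """Masks every vocabulary token not in `allowed` with -inf."""
    def __init__(self, allowed):
        self.allowed = set(allowed)

    def __call__(self, input_ids, scores):
        # scores: one float per vocabulary token
        return [
            s if i in self.allowed else -math.inf
            for i, s in enumerate(scores)
        ]

processor = AllowedTokensProcessor(allowed=[1, 3])
masked = processor(input_ids=[0], scores=[0.5, 1.0, 2.0, 0.25])
print(masked)  # [-inf, 1.0, -inf, 0.25]
```

Because the automata state lives entirely inside the processor, any backend that exposes this hook (transformers, and later llamacpp/vllm) can reuse the same structured-generation logic.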

TODO:

  • implement Transformers.generate and Transformers.stream
  • implement SequenceGeneratorAdapter version of outlines.models.transformers
  • unit tests
  • await mlx merge and rebase onto main
  • update transformers integration documentation
  • revert llamacpp and vllm changes, these will be in a separate PR
  • ~~logits processor profiling in benchmarks~~
    • will do in logits processor unification PR
  • ping people who've requested this

Bonus

Details

This new structure lets us easily integrate multi-modal models by subclassing models.Transformers. Additionally, we can make models.mamba a Transformers model simply by passing model_class=MambaLMHeadModel.

Multi-modal model example:

from outlines.processors import RegexLogitsProcessor
from outlines.models.transformers import TransformerTokenizer

from PIL import Image
import requests
from transformers import AutoProcessor, LlavaForConditionalGeneration, LogitsProcessorList, AutoTokenizer


model_uri = "llava-hf/llava-1.5-7b-hf"

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
output_pattern = r"This is like, totally an image of .*"


# Load the model in 4-bit to reduce memory usage
model = LlavaForConditionalGeneration.from_pretrained(model_uri, load_in_4bit=True)
llava_processor = AutoProcessor.from_pretrained(model_uri)

# Constrain generated text to match output_pattern
regex_logits_processor = RegexLogitsProcessor(
    output_pattern,
    TransformerTokenizer(AutoTokenizer.from_pretrained(model_uri)),
)

inputs = llava_processor(
    text=prompt,
    images=Image.open(requests.get(url, stream=True).raw),
    return_tensors="pt"
)

# Generate
generate_ids = model.generate(
    **inputs,
    logits_processor=LogitsProcessorList([
        regex_logits_processor
    ]),
    max_new_tokens=30
)

result = llava_processor.batch_decode(
    generate_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)[0]
print(result)
# USER:\nWhat's the content of the image? ASSISTANT:This is like, totally an image of a stop sign on a street.
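
The mamba idea mentioned in the Details above can be sketched with a stand-in model class, to show how a single model_class parameter swaps the loader without a separate code path. This is a hypothetical sketch, not outlines' actual constructor signature:

```python
# Hypothetical sketch (not outlines' actual constructor signature) of the
# `model_class` idea: the wrapper loads whatever class it is handed, so
# mamba needs no separate code path.
class Transformers:
    def __init__(self, model_uri, model_class):
        # The real integration would also wire up the tokenizer and the
        # logits-processor machinery; here we only show the loader swap.
        self.model = model_class.from_pretrained(model_uri)

class DummyModel:
    """Stand-in for e.g. MambaLMHeadModel, to keep the sketch runnable."""
    @classmethod
    def from_pretrained(cls, model_uri):
        instance = cls()
        instance.model_uri = model_uri
        return instance

m = Transformers("state-spaces/mamba-130m-hf", model_class=DummyModel)
print(m.model.model_uri)  # state-spaces/mamba-130m-hf
```

The same pattern would cover multi-modal models: subclass the wrapper and hand it a different model_class and processor.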

lapp0 force-pushed the transformers-use-logits-processor branch 24 times, most recently from 3f00ec7 to d9d650c on June 12, 2024 at 22:42
lapp0 force-pushed the transformers-use-logits-processor branch 5 times, most recently from 6ea3047 to b07ac99 on June 12, 2024 at 23:10
lapp0 force-pushed the transformers-use-logits-processor branch 29 times, most recently from b50d280 to 6467d26 on June 13, 2024 at 23:11


Development

Successfully merging this pull request may close these issues.

  • Update the transformers integration
  • RegexPrefixAllowedTokens does not work for batch
