
Uniformize kwargs for chameleon processor #32181

Merged · 12 commits · Sep 26, 2024
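In short, this PR aligns `ChameleonProcessor.__call__` with the library-wide processor signature: `images` comes first, `text` is a keyword argument, and all remaining options are validated through typed `ProcessingKwargs`. A minimal before/after sketch of the calling convention (checkpoint and sample inputs taken from the model docs below):

```python
import requests
from PIL import Image
from transformers import ChameleonProcessor

processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "What do you see in this image?<image>"

# Before this PR: text was the first positional argument.
# inputs = processor(prompt, image, return_tensors="pt")

# After this PR: images first, text as a keyword; the old order is still
# accepted for backward compatibility via _validate_images_text_input_order.
inputs = processor(images=image, text=prompt, return_tensors="pt")
```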
12 changes: 6 additions & 6 deletions docs/source/en/model_doc/chameleon.md
@@ -19,7 +19,7 @@ rendered properly in your Markdown viewer.
## Overview

The Chameleon model was proposed in [Chameleon: Mixed-Modal Early-Fusion Foundation Models](https://arxiv.org/abs/2405.09818v1) by the META AI Chameleon Team. Chameleon is a Vision-Language Model that uses vector quantization to tokenize images, which enables the model to generate multimodal output. The model takes images and text as input, including in an interleaved format, and generates textual responses. The image generation module has not been released yet.


The abstract from the paper is the following:
@@ -61,7 +61,7 @@ The original code can be found [here](https://github.com/facebookresearch/chamel

### Single image inference

Chameleon is a gated model, so make sure you have access and are logged in to the Hugging Face Hub with a token.
Here's how to load the model and perform inference in half-precision (`torch.bfloat16`):

```python
@@ -78,7 +78,7 @@
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
prompt = "What do you see in this image?<image>"

-inputs = processor(prompt, image, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
+inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, dtype=torch.bfloat16)

# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=50)
Expand Down Expand Up @@ -117,7 +117,7 @@ prompts = [

# We can simply feed images in the order they should be used in the text prompt
# Each "<image>" token uses one image, leaving the rest for subsequent "<image>" tokens
-inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16)
+inputs = processor(images=[image_stop, image_cats, image_snowman], text=prompts, padding=True, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16)

# Generate
generate_ids = model.generate(**inputs, max_new_tokens=50)
@@ -152,8 +152,8 @@ from transformers import ChameleonForConditionalGeneration

model_id = "facebook/chameleon-7b"
model = ChameleonForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2"
).to(0)
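As a usage note, inference with the Flash-Attention-2 model loaded above follows the same uniformized call; a minimal sketch reusing `model` and `model_id` from the snippet (the prompt and image are the illustrative ones from the single-image example):

```python
import requests
import torch
from PIL import Image
from transformers import ChameleonProcessor

processor = ChameleonProcessor.from_pretrained(model_id)

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
prompt = "What do you see in this image?<image>"

# images first, text as a keyword, matching the processor change in this PR
inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
output = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True))
```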
src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py
@@ -24,7 +24,7 @@

from transformers import (
    ChameleonConfig,
-    ChameleonForCausalLM,
**Contributor (author) commented:** btw @zucchini-nlp we might need to increase prio for this PR because of this

I have this change in my other PR too, but I forgot we haven't merged it yet

**Member commented:** Sorry, I was out for a while. Yes, I think some other contributor also reported the issue and wanted to open a PR to fix the conversion script. Feel free to open a PR if there isn't any, as this issue isn't at all related to processor kwargs.
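For context, a hypothetical reproduction of the breakage the thread refers to, assuming the problem is the stale `ChameleonForCausalLM` name that the conversion script still imported (the exported class is `ChameleonForConditionalGeneration`):

```python
# Hypothetical reproduction; assumes the stale class name is the issue.
from transformers import ChameleonForConditionalGeneration  # the exported class

try:
    from transformers import ChameleonForCausalLM  # name the script still used
except ImportError as err:
    # transformers exposes no ChameleonForCausalLM, so the script crashed here
    print(f"import failed: {err}")
```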

+    ChameleonForConditionalGeneration,
    ChameleonImageProcessor,
    ChameleonProcessor,
)
@@ -49,10 +49,10 @@
Thereafter, models can be loaded via:

```py
-from transformers import ChameleonForCausalLM, LlamaTokenizer
+from transformers import ChameleonForConditionalGeneration, LlamaTokenizerFast

-model = ChameleonForCausalLM.from_pretrained("/output/path")
-tokenizer = LlamaTokenizer.from_pretrained("/output/path")
+model = ChameleonForConditionalGeneration.from_pretrained("/output/path")
+tokenizer = LlamaTokenizerFast.from_pretrained("/output/path")
```

Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
@@ -372,7 +372,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
        vocabulary_map=vocabulary_map,
    )
    with init_empty_weights():
-        model = ChameleonForCausalLM(config)
+        model = ChameleonForConditionalGeneration(config)

    model.load_state_dict(state_dict, assign=True, strict=False)
    model.save_pretrained(model_path, safe_serialization=True)
@@ -397,7 +397,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
    # taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl
    print("Loading the checkpoint in a Chameleon model...")
    print("*" * 100)
-    model = ChameleonForCausalLM.from_pretrained(
+    model = ChameleonForConditionalGeneration.from_pretrained(
        model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto"
    )
    processor = ChameleonProcessor.from_pretrained(model_path)
2 changes: 1 addition & 1 deletion src/transformers/models/chameleon/modeling_chameleon.py
@@ -1568,7 +1568,7 @@ def forward(
>>> image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw)
>>> image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw)

->>> inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, torch.bfloat16)
+>>> inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.bfloat16)

>>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
>>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
77 changes: 41 additions & 36 deletions src/transformers/models/chameleon/processing_chameleon.py
@@ -20,9 +20,25 @@

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack, _validate_images_text_input_order
+from ...tokenization_utils_base import PreTokenizedInput, TextInput


+class ChameleonTextKwargs(TextKwargs, total=False):
+    return_for_text_completion: bool
+
+
+class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: ChameleonTextKwargs
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+            "return_for_text_completion": False,
+        },
+        "common_kwargs": {
+            "return_tensors": "pt",
+        },
+    }


class ChameleonProcessor(ProcessorMixin):
@@ -57,13 +73,11 @@ def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, ima

    def __call__(
        self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        images: ImageInput = None,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: int = None,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
-        return_for_text_completion: bool = False,
+        images: Optional[ImageInput] = None,
+        text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[ChameleonProcessorKwargs],
    ) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
Expand All @@ -73,26 +87,13 @@ def __call__(
of the above two methods for more information.

Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:

@@ -110,10 +111,21 @@
            `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list) and not isinstance(text[0], str):
            raise TypeError("Invalid input text. Please provide a string, or a list of strings")
        if text is None and images is None:
            raise ValueError("You must provide either text or images")

+        output_kwargs = self._merge_kwargs(
+            ChameleonProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False)

        # Replace the image token with the expanded image token sequence
        prompt_strings = []
@@ -124,19 +136,12 @@
                sample += self.tokenizer.sep_token  # special Chameleon treatment to add sep for chat mode
            prompt_strings.append(sample)

-        data = self.tokenizer(
-            prompt_strings,
-            return_tensors=return_tensors,
-            padding=padding,
-            truncation=truncation,
-            max_length=max_length,
-        )
+        data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])

        if images is not None:
-            pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
-            data["pixel_values"] = pixel_values
+            data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]

-        return BatchFeature(data=data, tensor_type=return_tensors)
+        return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"])

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
    def batch_decode(self, *args, **kwargs):
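To make the new flow concrete, a minimal usage sketch of the merged kwargs (checkpoint from the docs above; the comments map each argument to the `ChameleonProcessorKwargs` groups defined in this diff):

```python
import requests
from PIL import Image
from transformers import ChameleonProcessor

processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

inputs = processor(
    images=image,
    text="What do you see?<image>",
    padding=True,                      # routed into text_kwargs, forwarded to the tokenizer
    return_for_text_completion=False,  # ChameleonTextKwargs; popped before the tokenizer call
)                                      # return_tensors defaults to "pt" via common_kwargs

print(inputs.keys())  # input_ids, attention_mask, pixel_values
```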
8 changes: 4 additions & 4 deletions tests/models/chameleon/test_modeling_chameleon.py
@@ -350,7 +350,7 @@ def test_flash_attn_2_generate_padding_right(self):

processor.tokenizer.padding_side = "right"

-inputs = processor(texts, return_tensors="pt", padding=True).to(0)
+inputs = processor(text=texts, return_tensors="pt", padding=True).to(0)

output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_native = processor.tokenizer.batch_decode(output_native)
@@ -392,7 +392,7 @@ def test_model_7b(self):
)
prompt = "<image>Describe what do you see here and tell me about the history behind it?"

-inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.float16)
+inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.float16)

# greedy generation outputs
EXPECTED_TEXT_COMPLETION = ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue line extending across the center of the image. The line is labeled "390 light years" and is accompanied by a small black and'] # fmt: skip
@@ -420,7 +420,7 @@ def test_model_7b_batched(self):
"What constellation is this image showing?<image>",
]

-inputs = processor(prompts, images=[image, image_2], padding=True, return_tensors="pt").to(
+inputs = processor(images=[image, image_2], text=prompts, padding=True, return_tensors="pt").to(
    model.device, torch.float16
)

@@ -450,7 +450,7 @@ def test_model_7b_multi_image(self):
)
prompt = "What do these two images have in common?<image><image>"

-inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, torch.float16)
+inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.float16)

# greedy generation outputs
EXPECTED_TEXT_COMPLETION = ['What do these two images have in common?The two images show a connection between two things that are not necessarily related. The first image shows a group of stars, while the second image shows a network of lines connecting two points. The connection between'] # fmt: skip
44 changes: 44 additions & 0 deletions tests/models/chameleon/test_processor_chameleon.py
**Collaborator commented:** To remove?

**Member commented:** There are no custom tests, but it still inherits the tests from `ProcessorTesterMixin`.

**Collaborator commented:** Sorry, I reviewed too quickly and thought this was a scrap file. We should keep it and:

  • Update the checkpoint
  • Add a copyright header
@@ -0,0 +1,44 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch chameleon model."""

import tempfile
import unittest

from transformers import ChameleonProcessor, LlamaTokenizer
from transformers.testing_utils import get_tests_dir
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
from transformers import ChameleonImageProcessor


SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")


class ChameleonProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = ChameleonProcessor

    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
        image_processor = ChameleonImageProcessor()
        tokenizer = LlamaTokenizer(vocab_file=SAMPLE_VOCAB)
        tokenizer.pad_token_id = 0
        tokenizer.sep_token_id = 1
        processor = self.processor_class(image_processor=image_processor, tokenizer=tokenizer)
        processor.save_pretrained(self.tmpdirname)