diff --git a/recipes/configs/qwen2_5_vision/32B_full.yaml b/recipes/configs/qwen2_5_vision/32B_full.yaml new file mode 100644 index 0000000000..8957511d75 --- /dev/null +++ b/recipes/configs/qwen2_5_vision/32B_full.yaml @@ -0,0 +1,109 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Qwen2.5-VL 32B +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2.5-VL-32B-Instruct --output-dir /tmp/Qwen2.5-VL-32B-Instruct +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config qwen2_5_vision/32B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config qwen2_5_vision/32B_full checkpointer.checkpoint_dir= +# +# This config was only tested on a 4xH100 machine. + +output_dir: /tmp/torchtune/qwen2_5_32B/full # /tmp may be deleted by your system. Change it to your preference. + +# Tokenizer +tokenizer: + _component_: torchtune.models.qwen2_5_vision.Qwen25VLTransform + path: /tmp/Qwen2.5-VL-32B-Instruct/vocab.json + merges_file: /tmp/Qwen2.5-VL-32B-Instruct/merges.txt + max_seq_len: null + +# Dataset +dataset: + _component_: torchtune.datasets.multimodal.the_cauldron_dataset + packed: False # True increases speed + subset: ocrvqa +seed: null +shuffle: True +collate_fn: torchtune.models.qwen2_5_vision.qwen2_5_vl_padded_collate_images + + +# Model Arguments +model: + _component_: torchtune.models.qwen2_5_vision.qwen2_5_vl_32b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2.5-VL-32B-Instruct + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00018" + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: QWEN2_5_VL +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + lr: 5e-6 +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 +loss: + _component_: torchtune.modules.loss.LinearCrossEntropyLoss +max_steps_per_epoch: 100 +gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null +compile: False # torch.compile the model + loss, True increases speed + decreases memory + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory +custom_sharded_layers: ['decoder.tok_embeddings', 'decoder.output'] + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir}/logs +log_every_n_steps: 1 +log_peak_memory_stats: False +log_level: INFO # DEBUG, WARN, etc. 
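These configs are consumed the same way as other torchtune recipes: each `_component_` node names a callable that the recipe instantiates, with the sibling keys passed as kwargs. A minimal sketch of that resolution (assuming the checkpoint has been downloaded to the /tmp path above; this is an illustration, not part of the recipe code):

```python
# Hedged sketch of how a recipe resolves this YAML; not part of the recipe itself.
from omegaconf import OmegaConf
from torchtune import config

cfg = OmegaConf.load("recipes/configs/qwen2_5_vision/32B_full.yaml")

# Each `_component_` entry names a callable; the remaining keys become its kwargs.
transform = config.instantiate(cfg.tokenizer)          # Qwen25VLTransform(path=..., merges_file=..., max_seq_len=None)
dataset = config.instantiate(cfg.dataset, transform)   # the_cauldron_dataset(transform, subset="ocrvqa", packed=False)
```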
+ + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2_5_vision/3B_full_single_device.yaml b/recipes/configs/qwen2_5_vision/3B_full_single_device.yaml new file mode 100644 index 0000000000..acd06ea540 --- /dev/null +++ b/recipes/configs/qwen2_5_vision/3B_full_single_device.yaml @@ -0,0 +1,115 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Qwen2.5 VL 3B +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2.5-VL-3B-Instruct --output-dir /tmp/Qwen2.5-VL-3B-Instruct +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config qwen2_5_vision/3B_full_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config qwen2_5_vision/3B_full_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +output_dir: /tmp/torchtune/qwen2_5_3B_VL/full_single_device # /tmp may be deleted by your system. Change it to your preference. + +# Tokenizer +tokenizer: + _component_: torchtune.models.qwen2_5_vision.Qwen25VLTransform + path: /tmp/Qwen2.5-VL-3B-Instruct/vocab.json + merges_file: /tmp/Qwen2.5-VL-3B-Instruct/merges.txt + max_seq_len: null + +# Dataset +dataset: + _component_: torchtune.datasets.multimodal.the_cauldron_dataset + packed: False # True increases speed + subset: ocrvqa +seed: null +shuffle: True +collate_fn: torchtune.models.qwen2_5_vision.qwen2_5_vl_padded_collate_images + + +# Model Arguments +model: + _component_: torchtune.models.qwen2_5_vision.qwen2_5_vl_3b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2.5-VL-3B-Instruct + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors, + ] + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: QWEN2_5_VL +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 1 +epochs: 1 +optimizer: + _component_: bitsandbytes.optim.PagedAdamW + lr: 5e-6 +optimizer_in_bwd: True # True saves memory. 
Requires gradient_accumulation_steps=1 +loss: + _component_: torchtune.modules.loss.LinearCrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null +compile: False # torch.compile the model + loss, True increases speed + decreases memory + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir}/logs +log_every_n_steps: 1 +log_peak_memory_stats: False +log_level: INFO # DEBUG, WARN, etc. + + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2_5_vision/72B_full.yaml b/recipes/configs/qwen2_5_vision/72B_full.yaml new file mode 100644 index 0000000000..2833402e67 --- /dev/null +++ b/recipes/configs/qwen2_5_vision/72B_full.yaml @@ -0,0 +1,109 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Qwen2.5-VL 72B +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2.5-VL-72B-Instruct --output-dir /tmp/Qwen2.5-VL-72B-Instruct +# +# To launch on 8 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 8 full_finetune_distributed --config qwen2_5_vision/72B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 8 full_finetune_distributed --config qwen2_5_vision/72B_full checkpointer.checkpoint_dir= +# +# This config was only tested on an 8xH100 machine. + +output_dir: /tmp/torchtune/qwen2_5_72B/full # /tmp may be deleted by your system. Change it to your preference. + +# Tokenizer +tokenizer: + _component_: torchtune.models.qwen2_5_vision.Qwen25VLTransform + path: /tmp/Qwen2.5-VL-72B-Instruct/vocab.json + merges_file: /tmp/Qwen2.5-VL-72B-Instruct/merges.txt + max_seq_len: null + +# Dataset +dataset: + _component_: torchtune.datasets.multimodal.the_cauldron_dataset + packed: False # True increases speed + subset: ocrvqa +seed: null +shuffle: True +collate_fn: torchtune.models.qwen2_5_vision.qwen2_5_vl_padded_collate_images + + +# Model Arguments +model: + _component_: torchtune.models.qwen2_5_vision.qwen2_5_vl_72b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2.5-VL-72B-Instruct + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00018" + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: QWEN2_5_VL +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + lr: 5e-6 +optimizer_in_bwd: True # True saves memory. 
Requires gradient_accumulation_steps=1 +loss: + _component_: torchtune.modules.loss.LinearCrossEntropyLoss +max_steps_per_epoch: 100 +gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null +compile: False # torch.compile the model + loss, True increases speed + decreases memory + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory +custom_sharded_layers: ['decoder.tok_embeddings', 'decoder.output'] + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir}/logs +log_every_n_steps: 1 +log_peak_memory_stats: False +log_level: INFO # DEBUG, WARN, etc. + + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2_5_vision/7B_full_single_device.yaml b/recipes/configs/qwen2_5_vision/7B_full_single_device.yaml new file mode 100644 index 0000000000..b37aa071c8 --- /dev/null +++ b/recipes/configs/qwen2_5_vision/7B_full_single_device.yaml @@ -0,0 +1,115 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Qwen2.5 VL 7B +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2.5-VL-7B-Instruct --output-dir /tmp/Qwen2.5-VL-7B-Instruct +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config qwen2_5_vision/7B_full_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config qwen2_5_vision/7B_full_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +output_dir: /tmp/torchtune/qwen2_5_7B_VL/full_single_device # /tmp may be deleted by your system. Change it to your preference. 
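All of these configs set `optimizer_in_bwd: True`, which requires `gradient_accumulation_steps=1`: each parameter is stepped as soon as its gradient is ready, so gradients are never held across micro-batches. A minimal sketch of the underlying idea with plain PyTorch hooks (an illustration, not torchtune's implementation):

```python
# Fused optimizer-in-backward, sketched with per-parameter optimizers and
# post-accumulate-grad hooks; gradients are consumed during backward itself.
import torch

model = torch.nn.Linear(16, 16)
optim_per_param = {p: torch.optim.AdamW([p], lr=5e-6) for p in model.parameters()}

def _step_and_free(param: torch.Tensor) -> None:
    optim_per_param[param].step()
    optim_per_param[param].zero_grad()  # grad freed immediately -> no accumulation possible

for p in model.parameters():
    p.register_post_accumulate_grad_hook(_step_and_free)

model(torch.randn(2, 16)).sum().backward()  # parameters are updated during this call
```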
+ +# Tokenizer +tokenizer: + _component_: torchtune.models.qwen2_5_vision.Qwen25VLTransform + path: /tmp/Qwen2.5-VL-7B-Instruct/vocab.json + merges_file: /tmp/Qwen2.5-VL-7B-Instruct/merges.txt + max_seq_len: null + +# Dataset +dataset: + _component_: torchtune.datasets.multimodal.the_cauldron_dataset + packed: False # True increases speed + subset: ocrvqa +seed: null +shuffle: True +collate_fn: torchtune.models.qwen2_5_vision.qwen2_5_vl_padded_collate_images + + +# Model Arguments +model: + _component_: torchtune.models.qwen2_5_vision.qwen2_5_vl_7b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2.5-VL-7B-Instruct + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors, + ] + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: QWEN2_5_VL +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 1 +epochs: 1 +optimizer: + _component_: bitsandbytes.optim.PagedAdamW + lr: 5e-6 +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 +loss: + _component_: torchtune.modules.loss.LinearCrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null +compile: False # torch.compile the model + loss, True increases speed + decreases memory + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir}/logs +log_every_n_steps: 1 +log_peak_memory_stats: False +log_level: INFO # DEBUG, WARN, etc. + + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/tests/torchtune/models/qwen2_5_vision/test_qwen2_5_vl_rotary.py b/tests/torchtune/models/qwen2_5_vision/test_qwen2_5_vl_rotary.py new file mode 100644 index 0000000000..4e7f16f7e2 --- /dev/null +++ b/tests/torchtune/models/qwen2_5_vision/test_qwen2_5_vl_rotary.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Simplified tests for Qwen2.5-VL Rotary Positional Embeddings (M-RoPE). + +These tests validate the torchtune implementation against reference values +that were computed using a HuggingFace-style reference implementation. 
+""" + +import pytest +import torch +from torchtune.models.qwen2_5_vision import Qwen25VLRotaryPositionalEmbeddings +from torchtune.training.seed import set_seed + + +# Test constants +BATCH_SIZE = 2 +SEQ_LEN = 4 +NUM_HEADS = 1 +HEAD_DIM = 8 +MROPE_SECTION = [1, 1, 2] # sums to 4 pairs → 8 dims +BASE = 1e6 +MAX_SEQ_LEN = 32 +MAX_HEIGHT = 16 +MAX_WIDTH = 16 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(0) + + +class TestQwen25VLRotaryEmbeddings: + @pytest.fixture + def rope(self): + return Qwen25VLRotaryPositionalEmbeddings( + head_dim=HEAD_DIM, + max_seq_len=MAX_SEQ_LEN, + max_height=MAX_HEIGHT, + max_width=MAX_WIDTH, + base=BASE, + mrope_section=MROPE_SECTION, + ) + + @pytest.fixture + def inputs(self): + return torch.randn(BATCH_SIZE, SEQ_LEN, NUM_HEADS, HEAD_DIM) + + @pytest.fixture + def position_ids(self): + # Create simple position IDs: time=[0,1,2,3], height=[1,1,1,1], width=[2,2,2,2] + pos_time = torch.arange(SEQ_LEN).unsqueeze(0).repeat(BATCH_SIZE, 1) + pos_height = torch.ones(BATCH_SIZE, SEQ_LEN, dtype=torch.long) + pos_width = torch.full((BATCH_SIZE, SEQ_LEN), 2, dtype=torch.long) + return torch.stack([pos_time, pos_height, pos_width], dim=0) + + def test_forward_shape(self, rope, inputs, position_ids): + """Test basic forward pass shape.""" + output = rope(inputs, position_ids) + assert output.shape == inputs.shape + + def test_forward_values(self, rope, inputs, position_ids): + """Test forward pass produces expected values.""" + output = rope(inputs, position_ids) + + # Reference values computed using HF-style reference implementation + # These values were validated against the reference M-RoPE implementation + # to ensure correctness (max difference: 0.00e+00) + expected_mean = torch.tensor(0.077044) + expected_std = torch.tensor(1.051715) + + torch.testing.assert_close(output.mean(), expected_mean, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(output.std(), expected_std, atol=1e-3, rtol=1e-3) + + def test_no_nan_inf(self, rope, inputs, position_ids): + """Test output contains no NaN or infinite values.""" + output = rope(inputs, position_ids) + assert not torch.isnan(output).any() + assert torch.isfinite(output).all() + + def test_different_positions(self, rope): + """Test with different position values.""" + inputs = torch.randn(1, 3, 1, HEAD_DIM) + + # Test with varying positions + pos_time = torch.tensor([[0, 5, 10]]) + pos_height = torch.tensor([[1, 3, 7]]) + pos_width = torch.tensor([[2, 4, 8]]) + position_ids = torch.stack([pos_time, pos_height, pos_width], dim=0) + + output = rope(inputs, position_ids) + assert output.shape == inputs.shape + assert not torch.isnan(output).any() + + def test_gradient_flow(self, rope, position_ids): + """Test gradients flow through the module.""" + inputs = torch.randn( + BATCH_SIZE, SEQ_LEN, NUM_HEADS, HEAD_DIM, requires_grad=True + ) + + output = rope(inputs, position_ids) + loss = output.sum() + loss.backward() + + assert inputs.grad is not None + assert not torch.isnan(inputs.grad).any() + + def test_different_mrope_config(self): + """Test with different mrope_section configuration.""" + rope = Qwen25VLRotaryPositionalEmbeddings( + head_dim=12, # 2+4+6 = 12 + max_seq_len=MAX_SEQ_LEN, + max_height=MAX_HEIGHT, + max_width=MAX_WIDTH, + base=BASE, + mrope_section=[1, 2, 3], # Different configuration + ) + + inputs = torch.randn(1, 2, 1, 12) + pos_time = torch.tensor([[0, 1]]) + pos_height = torch.tensor([[1, 2]]) + pos_width = torch.tensor([[1, 3]]) + position_ids = torch.stack([pos_time, pos_height, pos_width], 
dim=0) + + output = rope(inputs, position_ids) + assert output.shape == inputs.shape + assert not torch.isnan(output).any() diff --git a/tests/torchtune/models/qwen2_5_vision/test_qwen2_5_vl_vision_encoder.py b/tests/torchtune/models/qwen2_5_vision/test_qwen2_5_vl_vision_encoder.py new file mode 100644 index 0000000000..c9b14bf66f --- /dev/null +++ b/tests/torchtune/models/qwen2_5_vision/test_qwen2_5_vl_vision_encoder.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Simplified tests for Qwen2.5-VL Vision Encoder using standard configuration. + +These tests validate the torchtune vision encoder implementation using +fixed initialization and deterministic inputs. Reference values are extracted +from HuggingFace model with identical weights (using fixed_init_model) +to ensure correctness against ground truth. + +Does require a GPU to run. +""" + +import pytest +import torch +from tests.test_utils import fixed_init_model, gpu_test +from torch import nn +from torchtune.models.qwen2_5_vision import qwen2_5_vision_encoder +from torchtune.training.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(42) + + +def create_deterministic_input(): + """Create the same deterministic input as used in the extract script.""" + set_seed(42) + + num_patches = 256 + patch_dim = 1176 + + input_tensor = torch.randn(num_patches, patch_dim) + grid_thw = torch.tensor([[1, 16, 16]]) + + return input_tensor, grid_thw + + +def get_vision_encoder(): + """Create vision encoder with exact same parameters as extract script.""" + vision_encoder = qwen2_5_vision_encoder( + embed_dim=1280, + num_layers=32, + activation=nn.SiLU(), + intermediate_size=3420, + num_heads=16, + in_channels=3, + out_hidden_size=3584, + patch_size=14, + spatial_merge_size=2, + window_size=112, + full_att_block_indexes=[7, 15, 23, 31], + temporal_patch_size=2, + ) + set_seed(123) + fixed_init_model(vision_encoder, min_val=-0.02, max_val=0.02) + return vision_encoder + + +@gpu_test(gpu_count=1) +def test_vision_encoder_forward(): + """Test vision encoder forward pass with fixed initialization.""" + vision_encoder = get_vision_encoder().cuda() + + image_tensor, grid_thw = create_deterministic_input() + image_tensor = image_tensor.cuda() + grid_thw = grid_thw.cuda() + + output = vision_encoder(image_tensor, grid_thw) + + expected_patches = 256 // (2 * 2) + + assert output.shape == (expected_patches, 3584) + assert not torch.isnan(output).any() + assert torch.isfinite(output).all() + + expected_mean = torch.tensor(0.005719).cuda() + expected_std = torch.tensor(9.958812).cuda() + expected_max_abs = torch.tensor(17.250065).cuda() + + torch.testing.assert_close(output.mean(), expected_mean, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(output.std(), expected_std, atol=1e-3, rtol=1e-3) + torch.testing.assert_close( + output.abs().max(), expected_max_abs, atol=1e-3, rtol=1e-3 + ) + + +@gpu_test(gpu_count=1) +def test_vision_encoder_no_nan(): + """Test that vision encoder doesn't produce NaN values.""" + vision_encoder = get_vision_encoder().cuda() + + image_tensor, grid_thw = create_deterministic_input() + image_tensor = image_tensor.cuda() + grid_thw = grid_thw.cuda() + + output = vision_encoder(image_tensor, grid_thw) + + assert not torch.isnan(output).any() + assert torch.isfinite(output).all() + + +@gpu_test(gpu_count=1) +def 
test_vision_encoder_deterministic(): + """Test that vision encoder produces deterministic outputs.""" + vision_encoder = get_vision_encoder().cuda() + + image_tensor, grid_thw = create_deterministic_input() + image_tensor = image_tensor.cuda() + grid_thw = grid_thw.cuda() + + output1 = vision_encoder(image_tensor, grid_thw) + output2 = vision_encoder(image_tensor, grid_thw) + + torch.testing.assert_close(output1, output2) + + +@gpu_test(gpu_count=1) +def test_vision_encoder_different_grid_sizes(): + """Test vision encoder with different grid sizes.""" + vision_encoder = get_vision_encoder().cuda() + + test_configs = [ + (64, [1, 8, 8]), # 8x8 grid + (36, [1, 6, 6]), # 6x6 grid + (16, [1, 4, 4]), # 4x4 grid + ] + + for num_patches, grid_shape in test_configs: + set_seed(42) + image_tensor = torch.randn(num_patches, 1176).cuda() + grid_thw = torch.tensor([grid_shape]).cuda() + output = vision_encoder(image_tensor, grid_thw) + + expected_patches = num_patches // 4 + assert output.shape == (expected_patches, 3584) + assert not torch.isnan(output).any() + + +@gpu_test(gpu_count=1) +def test_vision_encoder_gradient_flow(): + """Test that gradients flow through the vision encoder.""" + vision_encoder = get_vision_encoder().cuda() + + image_tensor, grid_thw = create_deterministic_input() + image_tensor = image_tensor.cuda().requires_grad_(True) + grid_thw = grid_thw.cuda() + + output = vision_encoder(image_tensor, grid_thw) + loss = output.sum() + loss.backward() + + assert image_tensor.grad is not None + assert image_tensor.grad.shape == image_tensor.shape + assert not torch.isnan(image_tensor.grad).any() diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index ee3ae60dd2..ff118781bd 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -125,6 +125,14 @@ class Recipe: name="qwen3/8B_full_single_device", file_path="qwen3/8B_full_single_device.yaml", ), + Config( + name="qwen2_5_vision/3B_full_single_device", + file_path="qwen2_5_vision/3B_full_single_device.yaml", + ), + Config( + name="qwen2_5_vision/7B_full_single_device", + file_path="qwen2_5_vision/7B_full_single_device.yaml", + ), ], supports_distributed=False, ), @@ -181,6 +189,14 @@ class Recipe: Config(name="qwen3/1.7B_full", file_path="qwen3/1.7B_full.yaml"), Config(name="qwen3/4B_full", file_path="qwen3/4B_full.yaml"), Config(name="qwen3/8B_full", file_path="qwen3/8B_full.yaml"), + Config( + name="qwen2_5_vision/32B_full", + file_path="qwen2_5_vision/32B_full.yaml", + ), + Config( + name="qwen2_5_vision/72B_full", + file_path="qwen2_5_vision/72B_full.yaml", + ), ], supports_distributed=True, ), diff --git a/torchtune/data/_collate.py b/torchtune/data/_collate.py index 33aec33dde..b4ea3c3fd0 100644 --- a/torchtune/data/_collate.py +++ b/torchtune/data/_collate.py @@ -715,4 +715,4 @@ def _stack_encoder_input(batch: list[dict[str, Any]], new_dim=False) -> dict[str stacked_batch[k] = new_dict else: raise ValueError(f"Unsupported type {type(v)} for key {k}") - return stacked_batch + return stacked_batch diff --git a/torchtune/models/qwen2_5_vision/__init__.py b/torchtune/models/qwen2_5_vision/__init__.py new file mode 100644 index 0000000000..62c65d077e --- /dev/null +++ b/torchtune/models/qwen2_5_vision/__init__.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
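To make the shapes in the vision-encoder tests above concrete: each input row is one flattened patch of `in_channels * temporal_patch_size * patch_size * patch_size` values, and the 2x2 spatial merge divides the number of grid positions by four before the merger projects to `out_hidden_size`. A small arithmetic sketch using the numbers from the tests (nothing here is torchtune API):

```python
# Patch/grid bookkeeping exercised by the vision-encoder tests above.
patch_size, temporal_patch_size, in_channels, spatial_merge_size = 14, 2, 3, 2

patch_dim = in_channels * temporal_patch_size * patch_size * patch_size
assert patch_dim == 1176  # width of each row fed to the encoder

for t, h, w in [(1, 16, 16), (1, 8, 8), (1, 6, 6), (1, 4, 4)]:
    num_patches = t * h * w                                # rows in the input tensor
    out_tokens = num_patches // (spatial_merge_size ** 2)  # rows after the 2x2 patch merger
    print(f"grid_thw=({t},{h},{w}): {num_patches} patches -> {out_tokens} output tokens")
# A 16x16 grid gives 256 patches and 64 output rows, matching expected_patches above.
```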
+ +from ._collate import qwen2_5_vl_padded_collate_images + +from ._component_builders import qwen2_5_vision_encoder, qwen2_5_vl_decoder + +from ._convert_weights import qwen2_5_vl_hf_to_tune +from ._model_builders import ( + qwen2_5_vl_32b, + qwen2_5_vl_3b, + qwen2_5_vl_72b, + qwen2_5_vl_7b, +) + +from ._positional_embeddings import ( + Qwen25VisionRotaryPositionalEmbeddings, + Qwen25VLRotaryPositionalEmbeddings, +) + +from ._transform import Qwen25VLTransform + +__all__ = [ + "qwen2_5_vl_decoder", + "qwen2_5_vision_encoder", + "qwen2_5_vl_72b", + "qwen2_5_vl_32b", + "qwen2_5_vl_7b", + "qwen2_5_vl_3b", + "Qwen25VLRotaryPositionalEmbeddings", + "Qwen25VisionRotaryPositionalEmbeddings", + "Qwen25VLTransform", + "qwen2_5_vl_padded_collate_images", + "qwen2_5_vl_hf_to_tune", +] diff --git a/torchtune/models/qwen2_5_vision/_collate.py b/torchtune/models/qwen2_5_vision/_collate.py new file mode 100644 index 0000000000..531b474553 --- /dev/null +++ b/torchtune/models/qwen2_5_vision/_collate.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any + +import torch +from torchtune.data import ( + CROSS_ENTROPY_IGNORE_IDX, + left_pad_sequence, + padded_collate_sft, +) + + +def qwen2_5_vl_padded_collate_images( + batch: list[dict[str, Any]], + padding_idx: int = 151655, + ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX, + pad_direction: str = "right", + pad_to_multiple_of: int = 1, +) -> dict[str, torch.Tensor]: + """ + Collate a batch of samples into a single dictionary. + This is a modified version of padded_collate_tiled_images_and_mask that + compresses images and grid_thw into single batch, due to encoder input + signature. 
+ """ + + if pad_direction not in ["left", "right"]: + raise ValueError( + f"pad_direction should be one of 'left' or 'right' but found {pad_direction}" + ) + + # Text tokens can be handled independently by existing collaters + if pad_direction == "right": + text_only = [ + {"tokens": sample["tokens"], "labels": sample["labels"]} for sample in batch + ] + collated_text = padded_collate_sft( + text_only, padding_idx, ignore_idx, pad_to_multiple_of=pad_to_multiple_of + ) + # For inference, we don't need to handle labels + elif pad_direction == "left": + if pad_to_multiple_of > 1: + raise ValueError( + f"pad_to_multiple_of={pad_to_multiple_of} is not supported for pad_direction='left'" + ) + collated_text = { + "tokens": left_pad_sequence( + [torch.tensor(x["tokens"]) for x in batch], + batch_first=True, + padding_value=padding_idx, + ) + } + + batch_dict = { + "tokens": collated_text["tokens"], + } + if "labels" in collated_text: + batch_dict["labels"] = collated_text["labels"] + + # compress images and grid_thw into single batch + batch_images = [] + batch_grid_thw = [] + for sample in batch: + sample_images = sample["encoder_input"]["image"]["hidden_states"] + i, n, p = sample_images.shape + sample_images = sample_images.reshape(i * n, p) + + # Stack multiple images per sample in num_images dimension + batch_images.append(sample_images) + batch_grid_thw.append(sample["encoder_input"]["image"]["grid_thw"]) + + if "image" in batch[0]["encoder_input"]: + batch_dict["encoder_input"] = { + "image": { + "hidden_states": torch.cat(batch_images), + "grid_thw": torch.cat(batch_grid_thw), + } + } + + return batch_dict diff --git a/torchtune/models/qwen2_5_vision/_component_builders.py b/torchtune/models/qwen2_5_vision/_component_builders.py new file mode 100644 index 0000000000..25f11c3396 --- /dev/null +++ b/torchtune/models/qwen2_5_vision/_component_builders.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable +from torch import nn + +from torchtune.models.qwen2_5_vision._encoder import ( + Qwen25VisionPatchEmbed, + Qwen25VLPatchMerger, + Qwen25VisionTransformer, +) +from torchtune.modules import ( + MultiHeadAttention, + RMSNorm, + TransformerSelfAttentionLayer, + FeedForward, + TransformerDecoder, + TiedLinear, +) +from torchtune.models.qwen2_5_vision._positional_embeddings import ( + Qwen25VLRotaryPositionalEmbeddings, + Qwen25VisionRotaryPositionalEmbeddings, +) + +""" +Component builders for the Qwen 2.5 VL model and its constituent models. +torchtune provides composable building blocks. Builder functions help +stitch these building blocks into higher-level components. +""" + + +def qwen2_5_vl_decoder( + vocab_size: int = 152064, + num_layers: int = 28, + num_heads: int = 28, + num_kv_heads: int = 4, + embed_dim: int = 3584, + intermediate_dim: int = 18944, + max_seq_len: int = 32768, + attn_dropout: float = 0.0, + rope_base: float = 1000000.0, + norm_eps: float = 1e-6, + mrope_section: list[int] = [16, 24, 24], + tie_word_embeddings: bool = False, +) -> TransformerDecoder: + """ + same architecture as Qwen 2.5 text decoder, just with multimodal RoPE (M-RoPE) + for handling 3D position embeddings in vision-language sequences. + + Args: + vocab_size (int): Size of vocabulary. Default: 152064 + num_layers (int): Number of transformer layers. 
Default: 28 + num_heads (int): Number of query heads. Default: 28 + num_kv_heads (int): Number of key/value heads (GQA). Default: 4 + embed_dim (int): Embedding dimension. Default: 3584 + intermediate_dim (int): MLP intermediate dimension. Default: 18944 + max_seq_len (int): Maximum sequence length. Default: 32768 + attn_dropout (float): Attention dropout rate. Default: 0.0 + rope_base (float): RoPE base frequency. Default: 1000000.0 + norm_eps (float): RMS norm epsilon. Default: 1e-6 + mrope_section (list[int]): MRoPE sections [temporal, height, width]. Default: [16, 24, 24] + tie_word_embeddings (bool): Whether to tie word embeddings. Default: False + + Returns: + TransformerDecoder: Text decoder with multimodal RoPE support. + """ + head_dim = embed_dim // num_heads + + rope = Qwen25VLRotaryPositionalEmbeddings( + head_dim=head_dim, + mrope_section=mrope_section, + base=rope_base, + max_seq_len=max_seq_len, + ) + # Create layers + layers = nn.ModuleList() + for _ in range(num_layers): + self_attn = MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=True), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=True), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=True), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + pos_embeddings=rope, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + is_causal=True, + ) + + mlp = FeedForward( + gate_proj=nn.Linear(embed_dim, intermediate_dim, bias=False), + up_proj=nn.Linear(embed_dim, intermediate_dim, bias=False), + down_proj=nn.Linear(intermediate_dim, embed_dim, bias=False), + activation=nn.SiLU(), + ) + + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps), + ) + + layers.append(layer) + + # Create embeddings and output projection + tok_embeddings = nn.Embedding(vocab_size, embed_dim) + if tie_word_embeddings: + output_proj = TiedLinear(tok_embeddings) + else: + output_proj = nn.Linear(embed_dim, vocab_size, bias=False) + + return TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(embed_dim, eps=norm_eps), + output=output_proj, + ) + + + +def qwen2_5_vision_encoder( + embed_dim: int, + num_layers: int, + activation: Callable, + intermediate_size: int, + num_heads: int, + in_channels: int, + out_hidden_size: int, + patch_size: int, + spatial_merge_size: int, + window_size: int, + full_att_block_indexes: list[int], + temporal_patch_size: int, +) -> Qwen25VisionTransformer: + """ + Build the vision encoder for Qwen2.5-VL model, including vision-language merger. + + Args: + embed_dim (int): Embedding dimension. + num_layers (int): Number of transformer layers. + activation (Callable): Activation function. + intermediate_size (int): Intermediate size. + num_heads (int): Number of attention heads. + in_channels (int): Number of input channels. + out_hidden_size (int): Output hidden size. + patch_size (int): Patch size. + spatial_merge_size (int): Spatial merge size. + window_size (int): Window size. + full_att_block_indexes (list[int]): Full attention block indexes. + temporal_patch_size (int): Temporal patch size. + + Returns: + Qwen25VisionTransformer: Instantiation of Qwen2.5-VL vision transformer. 
+ """ + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim must be divisible by num_heads, got {embed_dim} and {num_heads}" + ) + + head_dim = embed_dim // num_heads + + rope = Qwen25VisionRotaryPositionalEmbeddings(head_dim // 2, spatial_merge_unit=spatial_merge_size**2) + attn_bias = True + + self_attn = MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), + k_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), + v_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), + output_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias), + pos_embeddings=rope, + attn_dropout=0.0, + is_causal=False, + ) + mlp = FeedForward( + gate_proj=nn.Linear(embed_dim, intermediate_size, bias=True), + down_proj=nn.Linear(intermediate_size, embed_dim, bias=True), + up_proj=nn.Linear(embed_dim, intermediate_size, bias=True), + activation=activation, + ) + transformer_layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(embed_dim, eps=1e-6), + mlp_norm=RMSNorm(embed_dim, eps=1e-6), + sa_scale=None, + mlp_scale=None, + ) + + patch_embed = Qwen25VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + ) + + merger = Qwen25VLPatchMerger( + dim=out_hidden_size, + context_dim=embed_dim, + spatial_merge_size=spatial_merge_size, + ) + + return Qwen25VisionTransformer( + patch_size=patch_size, + num_layers=num_layers, + layer=transformer_layer, + patch_embed=patch_embed, + patch_merger=merger, + full_att_block_indexes=full_att_block_indexes, + spatial_merge_size=spatial_merge_size, + window_size=window_size, + ) diff --git a/torchtune/models/qwen2_5_vision/_convert_weights.py b/torchtune/models/qwen2_5_vision/_convert_weights.py new file mode 100644 index 0000000000..18c4ecf5b9 --- /dev/null +++ b/torchtune/models/qwen2_5_vision/_convert_weights.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch + +from torchtune.models.convert_weights import get_mapped_key +from torchtune.models.qwen2._convert_weights import _FROM_HF as _FROM_HF_QWEN2 + +# state dict key mappings from HF's format to torchtune's format +_FROM_HF = { + "visual.blocks.{}.attn.proj.bias": "encoders.image.layers.{}.attn.output_proj.bias", + "visual.blocks.{}.attn.proj.weight": "encoders.image.layers.{}.attn.output_proj.weight", + "visual.blocks.{}.attn.qkv.bias": "encoders.image.layers.{}.attn.q_proj.bias", + "visual.blocks.{}.attn.qkv.weight": "encoders.image.layers.{}.attn.q_proj.weight", + "visual.blocks.{}.mlp.down_proj.bias": "encoders.image.layers.{}.mlp.w2.bias", + "visual.blocks.{}.mlp.down_proj.weight": "encoders.image.layers.{}.mlp.w2.weight", + "visual.blocks.{}.mlp.gate_proj.bias": "encoders.image.layers.{}.mlp.w1.bias", + "visual.blocks.{}.mlp.gate_proj.weight": "encoders.image.layers.{}.mlp.w1.weight", + "visual.blocks.{}.mlp.up_proj.bias": "encoders.image.layers.{}.mlp.w3.bias", + "visual.blocks.{}.mlp.up_proj.weight": "encoders.image.layers.{}.mlp.w3.weight", + "visual.blocks.{}.norm1.weight": "encoders.image.layers.{}.sa_norm.scale", + "visual.blocks.{}.norm2.weight": "encoders.image.layers.{}.mlp_norm.scale", + "visual.merger.ln_q.weight": "encoders.image.merger.ln_q.scale", + "visual.merger.mlp.{}.bias": "encoders.image.merger.mlp.{}.bias", + "visual.merger.mlp.{}.weight": "encoders.image.merger.mlp.{}.weight", + "visual.patch_embed.proj.weight": "encoders.image.patch_embed.proj.weight", +} +_FROM_HF_QWEN2 = {k: "decoder." + str(v) for k, v in _FROM_HF_QWEN2.items()} + +_FROM_HF.update(_FROM_HF_QWEN2) + +QWEN2_TIED_KEY = "lm_head.weight" + + +def qwen2_5_vl_hf_to_tune( + state_dict: dict[str, torch.Tensor], + tie_word_embeddings: bool = False, +) -> dict[str, torch.Tensor]: + """ + Convert a state dict from HF's format to TorchTune's format, which contains the weights + of a Qwen2 model. + State dicts from multiple checkpoint files should be consolidated into a single state dict + before calling this function. + The logic is identical to :func:`~torchtune.models.convert_weights.hf_to_tune`, but may not load + output projection weights. + + Args: + state_dict (dict[str, torch.Tensor]): State dict in HF's format. + tie_word_embeddings (bool): Whether the model's input and output word embeddings should be tied. + + Returns: + dict[str, torch.Tensor]: State dict in torchtune's format. + """ + converted_state_dict = {} + + for key, value in state_dict.items(): + new_key = get_mapped_key(key, _FROM_HF) + if "qkv" in key: + ( + q, + k, + v, + ) = value.chunk(3, dim=0) + converted_state_dict[new_key] = q + converted_state_dict[new_key.replace("q_proj", "k_proj")] = k + converted_state_dict[new_key.replace("q_proj", "v_proj")] = v + elif ( + tie_word_embeddings and QWEN2_TIED_KEY in key + ): # Skip loading the output projection weights + continue + elif "rotary_emb.inv_freq" in key: # Skip loading the position embeddings + continue + else: + converted_state_dict[new_key] = value + return converted_state_dict + + +def qwen2_5_vl_tune_to_hf( + state_dict: dict[str, torch.Tensor], +): + """ + Convert a state dict from torchtune's format to HF's format. This function + doesn't handle any sharding or splitting of state dicts. It follows the + state_dict IN -> state_dict OUT pattern. + + Args: + state_dict (dict[str, torch.Tensor]): State dict in torchtune's format. + + Returns: + dict[str, torch.Tensor]: State dict in HF's format. 
+ """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _FROM_HF.items()} + + for key, value in state_dict.items(): + if "k_proj" in key or "v_proj" in key: + continue + + new_key = get_mapped_key(key, inverted_mapping_dict) + if "q_proj" in key: + q = value + k = state_dict[key.replace("q_proj", "k_proj")] + v = state_dict[key.replace("q_proj", "v_proj")] + qkv = torch.cat([q, k, v], dim=0) + # q_proj maps to qkv_proj; no need to string replace + converted_state_dict[new_key] = qkv + else: + converted_state_dict[new_key] = value + + return converted_state_dict diff --git a/torchtune/models/qwen2_5_vision/_encoder.py b/torchtune/models/qwen2_5_vision/_encoder.py new file mode 100644 index 0000000000..c10066b6d1 --- /dev/null +++ b/torchtune/models/qwen2_5_vision/_encoder.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from torch import nn +from torchtune.modules.model_fusion import register_fusion_module +from torchtune.modules.rms_norm import RMSNorm + +from torchtune.modules.transformer import _get_clones + + +class Qwen25VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d( + in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, + self.in_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view( + -1, self.embed_dim + ) + return hidden_states + + +class Qwen25VLPatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = RMSNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +class Qwen25VisionTransformer(nn.Module): + def __init__( + self, + patch_size: int, + num_layers: int, + layer: nn.Module, + patch_embed: nn.Module, + patch_merger: nn.Module, + full_att_block_indexes: list[int], + spatial_merge_size: int = 2, + window_size: int = 14, + ) -> None: + super().__init__() + self.spatial_merge_size = spatial_merge_size + self.patch_size = patch_size + self.fullatt_block_indexes = full_att_block_indexes + self.window_size = window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = patch_embed + self.layers = _get_clones(layer, num_layers) + self.merger = patch_merger + register_fusion_module(self.merger) + + def get_rope_index(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = 
hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + return pos_ids + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = ( + self.window_size // self.spatial_merge_size // self.patch_size + ) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w + ) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = ( + seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + ) + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward( + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor + ) -> torch.Tensor: + """ + Args: + hidden_states (torch.Tensor): The final hidden states of the model. + grid_thw (torch.Tensor): The temporal, height and width of feature shape of each image in LLM. + + Returns: + torch.Tensor: hidden_states. 
+ """ + hidden_states = self.patch_embed(hidden_states) + + rope_index = self.get_rope_index(grid_thw) + + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + hidden_states = hidden_states.unsqueeze(0) + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum( + dim=0, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for layer_num, blk in enumerate(self.layers): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + + attention_mask = torch.full( + [1, seq_len, seq_len], + torch.finfo(hidden_states.dtype).min, + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + for i in range(1, len(cu_seqlens_now)): + attention_mask[ + ..., + cu_seqlens_now[i - 1] : cu_seqlens_now[i], + cu_seqlens_now[i - 1] : cu_seqlens_now[i], + ] = 0 + + hidden_states = blk( + hidden_states, + input_pos=rope_index, + mask=attention_mask, + window_index=window_index, + ) + + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states diff --git a/torchtune/models/qwen2_5_vision/_fusion.py b/torchtune/models/qwen2_5_vision/_fusion.py new file mode 100644 index 0000000000..2a063c3810 --- /dev/null +++ b/torchtune/models/qwen2_5_vision/_fusion.py @@ -0,0 +1,276 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional, Union + +import torch +from torch import nn +from torchtune.modules import TransformerDecoder +from torchtune.modules.model_fusion._early_fusion import EarlyFusionModel + + +class Qwen25VL(EarlyFusionModel): + """ + Extended EarlyFusionModel for Qwen2.5-VL that handles multimodal position encoding. + Integrates the get_rope_index() functionality to compute 3D position IDs for + multimodal RoPE (temporal, height, width dimensions). 
+ """ + + def __init__( + self, + decoder: TransformerDecoder, + encoders: dict[str, nn.Module], + encoder_tokens: dict[str, int], + image_token_id: int = 151655, + video_token_id: int = 151656, + vision_start_token_id: int = 151652, + spatial_merge_size: int = 2, + tokens_per_second: int = 2, + decoder_trainable: bool = True, + encoders_trainable: Union[bool, dict[str, bool]] = False, + fusion_trainable: bool = True, + ): + super().__init__( + decoder=decoder, + encoders=encoders, + encoder_tokens=encoder_tokens, + decoder_trainable=decoder_trainable, + encoders_trainable=encoders_trainable, + fusion_trainable=fusion_trainable, + ) + + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.spatial_merge_size = spatial_merge_size + self.tokens_per_second = tokens_per_second + self.rope_deltas = None + + def _get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + Adapted from HuggingFace's Qwen2.5-VL implementation. + """ + mrope_position_deltas = [] + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None + ): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere( + input_ids == self.vision_start_token_id + ).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == self.image_token_id).sum() + video_nums = (vision_tokens == self.video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + + for _ in range(image_nums + video_nums): + if self.image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(self.image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if self.video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(self.video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // self.spatial_merge_size, + w.item() // self.spatial_merge_size, + ) + text_len = ed - st + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + 
llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + second_per_grid_t = torch.as_tensor( + second_per_grid_t, + dtype=range_tensor.dtype, + device=range_tensor.device, + ) + + time_tensor = ( + expanded_range * second_per_grid_t * self.tokens_per_second + ) + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + position_ids.device + ) + mrope_position_deltas.append( + llm_positions.max() + 1 - len(total_input_ids[i]) + ) + + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + # Fall back to standard position encoding for text-only inputs + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = ( + position_ids.unsqueeze(0) + .expand(3, -1, -1) + .to(attention_mask.device) + ) + max_position_ids = position_ids.max(0, keepdim=False)[0].max( + -1, keepdim=True + )[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def forward( + self, + tokens: torch.Tensor, + *, + mask: Optional[torch.Tensor] = None, + encoder_input: Optional[dict[str, dict[str, Any]]] = None, + input_pos: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: dict[str, Any], + ) -> torch.Tensor: + """ + Extended forward pass that computes multimodal position encoding for Qwen2.5-VL. 
+ + Args: + tokens (torch.Tensor): input tensor with shape ``[b x s]`` + mask (Optional[torch.Tensor]): attention mask + encoder_input (Optional[dict[str, dict[str, Any]]]): encoder inputs + input_pos (Optional[torch.Tensor]): position ids (will be computed if None) + image_grid_thw (Optional[torch.LongTensor]): image grid dimensions + video_grid_thw (Optional[torch.LongTensor]): video grid dimensions + second_per_grid_ts (Optional[torch.Tensor]): time intervals for video grids + attention_mask (Optional[torch.Tensor]): attention mask for computing positions + **kwargs (dict[str, Any]): additional arguments + + Returns: + torch.Tensor: output tensor + """ + if input_pos is None: + position_ids, rope_deltas = self._get_rope_index( + input_ids=tokens, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + attention_mask=attention_mask, + ) + self.rope_deltas = rope_deltas + + input_pos = position_ids # [3, B, L] + + return super().forward( + tokens=tokens, + mask=mask, + encoder_input=encoder_input, + input_pos=input_pos, + **kwargs, + ) diff --git a/torchtune/models/qwen2_5_vision/_model_builders.py b/torchtune/models/qwen2_5_vision/_model_builders.py new file mode 100644 index 0000000000..c3c51c2927 --- /dev/null +++ b/torchtune/models/qwen2_5_vision/_model_builders.py @@ -0,0 +1,285 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn + +from torchtune.data._prompt_templates import _TemplateType + +from torchtune.models.qwen2_5_vision._component_builders import ( + qwen2_5_vl_decoder, + qwen2_5_vision_encoder, +) + +from torchtune.models.qwen2_5._tokenizer import QWEN2_5_SPECIAL_TOKENS +from torchtune.models.qwen2_5_vision._fusion import Qwen25VL + +""" +Model builders build specific instantiations using component builders. +""" + +def qwen2_5_vl_3b( + *, + decoder_trainable: bool = True, + encoder_trainable: bool = False, + fusion_trainable: bool = True, + image_size: int = 336, +) -> Qwen25VL: + """ + Builder for creating a Qwen2.5-VL 3B instruct model with vision capabilities. + + Args: + decoder_trainable (bool): Whether the language model decoder should be trainable. Default: False + encoder_trainable (bool): Whether the vision encoder should be trainable. Default: False + fusion_trainable (bool): Whether the fusion layers should be trainable. Default: False + image_size (int): Input image size for the vision encoder. 
Default: 336 + """ + + encoder = qwen2_5_vision_encoder( + embed_dim=1280, + num_layers=32, + activation=nn.SiLU(), + intermediate_size=3420, + num_heads=16, + in_channels=3, + out_hidden_size=2048, + patch_size=14, + spatial_merge_size=2, + window_size=112, + full_att_block_indexes=[7, 15, 23, 31], + temporal_patch_size=2, + ) + + decoder = qwen2_5_vl_decoder( + vocab_size=151936, + num_layers=36, + num_heads=16, + num_kv_heads=2, + embed_dim=2048, + intermediate_dim=11008, + max_seq_len=32768, + attn_dropout=0.0, + rope_base=1000000.0, + norm_eps=1e-6, + mrope_section=[16, 24, 24], + tie_word_embeddings=True, + ) + + return Qwen25VL( + decoder=decoder, + encoders={"image": encoder}, + encoder_tokens={ + "image": QWEN2_5_SPECIAL_TOKENS["<|image_pad|>"], + }, + image_token_id=QWEN2_5_SPECIAL_TOKENS["<|image_pad|>"], + vision_start_token_id=QWEN2_5_SPECIAL_TOKENS["<|vision_start|>"], + spatial_merge_size=2, + tokens_per_second=2, + encoders_trainable={ + "image": encoder_trainable, + }, + decoder_trainable=decoder_trainable, + fusion_trainable=fusion_trainable, + ) + +def qwen2_5_vl_7b( + *, + decoder_trainable: bool = True, + encoder_trainable: bool = False, + fusion_trainable: bool = True, + image_size: int = 336, +) -> Qwen25VL: + """ + Builder for creating a Qwen2.5-VL 7B instruct model with vision capabilities. + + Args: + decoder_trainable (bool): Whether the language model decoder should be trainable. Default: False + encoder_trainable (bool): Whether the vision encoder should be trainable. Default: False + fusion_trainable (bool): Whether the fusion layers should be trainable. Default: False + image_size (int): Input image size for the vision encoder. Default: 336 + + Returns: + Qwen25VLEarlyFusionModel: Qwen2.5-VL 7B model instance + """ + + encoder = qwen2_5_vision_encoder( + embed_dim=1280, + num_layers=32, + activation=nn.SiLU(), + intermediate_size=3420, + num_heads=16, + in_channels=3, + out_hidden_size=3584, + patch_size=14, + spatial_merge_size=2, + window_size=112, + full_att_block_indexes=[7, 15, 23, 31], + temporal_patch_size=2, + ) + + decoder = qwen2_5_vl_decoder( + vocab_size=152064, + num_layers=28, + num_kv_heads=4, + embed_dim=3584, + intermediate_dim=18944, + max_seq_len=32768, + attn_dropout=0.0, + rope_base=1000000.0, + norm_eps=1e-6, + mrope_section=[16, 24, 24], + tie_word_embeddings=False, + ) + + return Qwen25VL( + decoder=decoder, + encoders={"image": encoder}, + encoder_tokens={ + "image": QWEN2_5_SPECIAL_TOKENS["<|image_pad|>"], + }, + image_token_id=QWEN2_5_SPECIAL_TOKENS["<|image_pad|>"], + vision_start_token_id=QWEN2_5_SPECIAL_TOKENS["<|vision_start|>"], + spatial_merge_size=2, + tokens_per_second=2, + encoders_trainable={ + "image": encoder_trainable, + }, + decoder_trainable=decoder_trainable, + fusion_trainable=fusion_trainable, + ) + +def qwen2_5_vl_32b( + *, + decoder_trainable: bool = True, + encoder_trainable: bool = False, + fusion_trainable: bool = True, + image_size: int = 336, +) -> Qwen25VL: + """ + Builder for creating a Qwen2.5-VL 32B instruct model with vision capabilities. + + Args: + decoder_trainable (bool): Whether the language model decoder should be trainable. Default: False + encoder_trainable (bool): Whether the vision encoder should be trainable. Default: False + fusion_trainable (bool): Whether the fusion layers should be trainable. Default: False + image_size (int): Input image size for the vision encoder. 
Default: 336 + + Returns: + Qwen25VLEarlyFusionModel: Qwen2.5-VL 72B model instance + """ + + encoder = qwen2_5_vision_encoder( + embed_dim=1280, + num_layers=32, + activation=nn.SiLU(), + intermediate_size=3456, + num_heads=16, + in_channels=3, + out_hidden_size=5120, + patch_size=14, + spatial_merge_size=2, + window_size=112, + full_att_block_indexes=[7, 15, 23, 31], + temporal_patch_size=2, + ) + + decoder = qwen2_5_vl_decoder( + vocab_size=152064, + num_layers=64, + num_heads=40, + num_kv_heads=8, + embed_dim=5120, + intermediate_dim=27648, + max_seq_len=32768, + attn_dropout=0.0, + rope_base=1000000.0, + norm_eps=1e-6, + mrope_section=[16, 24, 24], + tie_word_embeddings=False, + ) + + return Qwen25VL( + decoder=decoder, + encoders={"image": encoder}, + encoder_tokens={ + "image": QWEN2_5_SPECIAL_TOKENS["<|image_pad|>"], + }, + image_token_id=QWEN2_5_SPECIAL_TOKENS["<|image_pad|>"], + vision_start_token_id=QWEN2_5_SPECIAL_TOKENS["<|vision_start|>"], + spatial_merge_size=2, + tokens_per_second=2, + encoders_trainable={ + "image": encoder_trainable, + }, + decoder_trainable=decoder_trainable, + fusion_trainable=fusion_trainable, + ) + +def qwen2_5_vl_72b( + *, + decoder_trainable: bool = True, + encoder_trainable: bool = False, + fusion_trainable: bool = True, + image_size: int = 336, +) -> Qwen25VL: + """ + Builder for creating a Qwen2.5-VL 72B instruct model with vision capabilities. + + Args: + decoder_trainable (bool): Whether the language model decoder should be trainable. Default: False + encoder_trainable (bool): Whether the vision encoder should be trainable. Default: False + fusion_trainable (bool): Whether the fusion layers should be trainable. Default: False + image_size (int): Input image size for the vision encoder. Default: 336 + + Returns: + Qwen25VLEarlyFusionModel: Qwen2.5-VL 72B model instance + """ + + encoder = qwen2_5_vision_encoder( + embed_dim=1280, + num_layers=32, + activation=nn.SiLU(), + intermediate_size=3456, + num_heads=16, + in_channels=3, + out_hidden_size=8192, + patch_size=14, + spatial_merge_size=2, + window_size=112, + full_att_block_indexes=[7, 15, 23, 31], + temporal_patch_size=2, + ) + + decoder = qwen2_5_vl_decoder( + vocab_size=152064, + num_layers=80, + num_heads=64, + num_kv_heads=8, + embed_dim=8192, + intermediate_dim=29568, + max_seq_len=32768, + attn_dropout=0.0, + rope_base=1000000.0, + norm_eps=1e-6, + mrope_section=[16, 24, 24], + tie_word_embeddings=False, + ) + + return Qwen25VL( + decoder=decoder, + encoders={"image": encoder}, + encoder_tokens={ + "image": QWEN2_5_SPECIAL_TOKENS["<|image_pad|>"], + }, + image_token_id=QWEN2_5_SPECIAL_TOKENS["<|image_pad|>"], + vision_start_token_id=QWEN2_5_SPECIAL_TOKENS["<|vision_start|>"], + spatial_merge_size=2, + tokens_per_second=2, + encoders_trainable={ + "image": encoder_trainable, + }, + decoder_trainable=decoder_trainable, + fusion_trainable=fusion_trainable, + ) \ No newline at end of file diff --git a/torchtune/models/qwen2_5_vision/_positional_embeddings.py b/torchtune/models/qwen2_5_vision/_positional_embeddings.py new file mode 100644 index 0000000000..7c63b84883 --- /dev/null +++ b/torchtune/models/qwen2_5_vision/_positional_embeddings.py @@ -0,0 +1,271 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
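# NOTE: illustrative usage of the builders in _model_builders.py above (a sketch,
# not part of this module). The builders differ mainly in decoder depth/width and
# the encoder's output hidden size; all of them take the same trainability flags:
#
#     from torchtune.models.qwen2_5_vision import qwen2_5_vl_3b
#
#     model = qwen2_5_vl_3b(
#         decoder_trainable=True,   # fine-tune the language decoder
#         encoder_trainable=False,  # keep the vision tower frozen
#         fusion_trainable=True,    # train the fusion/projection layers
#     )
#     n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)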
+ +from typing import Optional + +import torch +from torch import nn + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] + x1, x2 = x[..., : d // 2], x[..., d // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class Qwen25VLRotaryPositionalEmbeddings(nn.Module): + """ + M-RoPE (Multimodal Rotary Embeddings) for Qwen2.5-VL. + + Initially described in https://arxiv.org/pdf/2409.12191. + + Extends standard 1D RoPE to three axes: time, height, width. + + Unlike the huggingface implementation, this implementation caches the RoPE tables + for each position and each of the three dimensions. + Args: + head_dim (int): dimensionality per head (e.g. 128) + max_seq_len (int): maximum temporal length to expect (default 128000) + max_height (int): maximum height to expect (default 4096) + max_width (int): maximum width to expect (default 4096) + base (float): geometric base for theta (default 1e6) + mrope_section (list[int]): number of frequency-pairs for [time, height, width] (default [16, 24, 24]) + """ + + def __init__( + self, + head_dim: int, + max_seq_len: int = 128000, + max_height: int = 4096, + max_width: int = 4096, + base: float = 1000000.0, + mrope_section: Optional[list[int]] = None, + ) -> None: + super().__init__() + + if mrope_section is None: + mrope_section = [16, 24, 24] + + if sum(mrope_section) * 2 != head_dim: + raise ValueError( + f"mrope_section pairs {mrope_section} must satisfy 2*sum = head_dim ({head_dim})" + ) + + self.head_dim = head_dim + + self.max_seq_len = max_seq_len + self.max_height = max_height + self.max_width = max_width + + self.base = base + self.mrope_section = mrope_section + + self.rope_init() + + def rope_init(self) -> None: + theta = 1.0 / ( + self.base + ** (torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim) + ) + attention_scaling = 1.0 + self.register_buffer("theta", theta, persistent=False) + self.attention_scaling = attention_scaling + + self.build_rope_cache("time", self.max_seq_len) + self.build_rope_cache("height", self.max_height) + self.build_rope_cache("width", self.max_width) + + def build_rope_cache(self, name: str, length: int): + # positions 0…length-1 + p = torch.arange(length, device=self.theta.device, dtype=self.theta.dtype) + # [length, head_dim/2] + angles = torch.einsum("p,f->pf", p, self.theta).float() + # [length, head_dim] + freqs = torch.cat([angles, angles], dim=-1) + # [length, 2*head_dim] + cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) + self.register_buffer(f"{name}_cache", cache, persistent=False) + + def forward( + self, + x: torch.Tensor, + input_pos: torch.LongTensor, + *, + window_index: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute M-RoPE cos/sin tables for a batch of queries/keys. 
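        The per-axis caches are gathered at the (t, h, w) indices in ``input_pos``, interleaved across the head dimension according to ``mrope_section``, and applied as ``x * cos + rotate_half(x) * sin``.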
+ + Args: + x (torch.Tensor): input tensor with shape ``[B, s_x, n_heads, head_dim]`` + input_pos (torch.LongTensor): the time, height, width indices with shape ``[3, B, L]`` + window_index (Optional[torch.Tensor]): Optional tensor for window indexing (not used in M-RoPE) + + Returns: + q_out (torch.Tensor): output tensor with shape ``[B, s_x, n_heads, head_dim]`` + + Notation used for tensor shapes: + - B: batch size + - s_x: sequence length + - n_heads: number of attention heads + - head_dim: dimension of each head + - L: sequence length + - D: head dimension + """ + sections = self.mrope_section * 2 + + # unpack input_pos into three tensors of shape [B, L] + t_ids, h_ids, w_ids = input_pos + + # retrieve caches at position index, returns tensor of shape [] + cache_t = self.time_cache[t_ids] + cache_h = self.height_cache[h_ids] + cache_w = self.width_cache[w_ids] + + # [3, B, L, 2*D] + stacked = torch.stack([cache_t, cache_h, cache_w], dim=0) + + cos3 = stacked[..., : self.head_dim] * self.attention_scaling + sin3 = stacked[..., self.head_dim :] * self.attention_scaling + + # split into chunks of size self.mrope_section + cos_chunks = cos3.split(sections, dim=-1) + sin_chunks = sin3.split(sections, dim=-1) + + # for each block, pick the modality slice + cos_parts = [cos_chunks[i][i % 3] for i in range(len(cos_chunks))] + sin_parts = [sin_chunks[i][i % 3] for i in range(len(sin_chunks))] + + # concat back to [B, L, D] and unsqueeze heads-axis → [B,1,L,D] + # NOTE: the head dimension is the axis 2 + cos = torch.cat(cos_parts, dim=-1).unsqueeze(2) + sin = torch.cat(sin_parts, dim=-1).unsqueeze(2) + + x_out = (x * cos) + (rotate_half(x) * sin) + return x_out.to(x.dtype) + + +class Qwen25VisionRotaryPositionalEmbeddings(nn.Module): + """ + 2D Rope for Qwen 2.5 VL's Vision Transformer + + Args: + dim (int): Embedding dimension. 
This is usually set to the dim of each + head in the attention module computed as ``embed_dim // num_heads`` + max_seq_len (int): Maximum expected sequence length for the + model, if exceeded the cached freqs will be recomputed + base (int): The base for the geometric progression used to compute + the rotation angles + spatial_merge_unit (int): size of a spatial merge unit, + aka the number of patches that share the same position index + """ + + def __init__( + self, + dim: int, + max_seq_len: int = 4096, + base: int = 10_000, + spatial_merge_unit: int = 4, + ) -> None: + super().__init__() + self.dim = dim + self.base = base + self.max_seq_len = max_seq_len + self.spatial_merge_unit = spatial_merge_unit + self.rope_init() + + def rope_init(self): + theta = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + ) + self.register_buffer("theta", theta, persistent=False) + self.build_rope_cache(self.max_seq_len) + + def build_rope_cache(self, max_seq_len: int = 4096) -> None: + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + seq_idx = torch.arange( + max_seq_len, dtype=self.theta.dtype, device=self.theta.device + ) + + # Outer product of theta and position index; output tensor has + # a shape of [max_seq_len, dim // 2] + idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float() + + # cache includes both the cos and sin components and so the output shape is + # [max_seq_len, dim // 2, 2] + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + self.register_buffer("cache", cache, persistent=False) + + def forward( + self, + x: torch.Tensor, + *, + input_pos: Optional[torch.Tensor] = None, + window_index: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor with shape + ``[b, s, n_h, h_d]`` + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b, s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + window_index (Optional[torch.Tensor]): Optional tensor which contains the window index + of each token. During training, this is used to indicate the window index + of each token when packed, shape [b, s]. + + Returns: + torch.Tensor: output tensor with shape ``[b, s, n_h, h_d]`` + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - n_h: num heads + - h_d: head dim + """ + # input tensor has shape [b, s, n_h, h_d] + seq_len = x.size(1) + + # extract the values based on whether input_pos is set or not + rope_cache = ( + self.cache[:seq_len] if input_pos is None else self.cache[input_pos] + ) + # merge height and width into one dimension + rope_cache = rope_cache.flatten(1) # [s, h_d] + + # rearrange indices to match window index + rope_cache = rope_cache.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) + rope_cache = rope_cache[window_index, :, :] + rope_cache = rope_cache.reshape(seq_len, -1) + + # reshape input; the last dimension is used for computing the output. 
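        # split the head dim into two halves (x1, x2); each pair is rotated by the
        # cached angle: out1 = x1 * cos - x2 * sin, out2 = x2 * cos + x1 * sin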
+ x_float = x.float() + half_dim = x_float.shape[-1] // 2 + x1 = x_float[..., :half_dim] + x2 = x_float[..., half_dim:] + xshaped = torch.stack([x1, x2], dim=-1) + + # reshape the cache for broadcasting + rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2) + + x_out = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] + - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + + # tensor has shape [b, s, n_h, h_d] + x_out = x_out.flatten(3) + return x_out.type_as(x) diff --git a/torchtune/models/qwen2_5_vision/_tokenizer.py b/torchtune/models/qwen2_5_vision/_tokenizer.py new file mode 100644 index 0000000000..c7082f40a7 --- /dev/null +++ b/torchtune/models/qwen2_5_vision/_tokenizer.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +from torchtune.data import ChatMLTemplate, Message, PromptTemplate, truncate +from torchtune.models.qwen2._tokenizer import ( + DEFAULT_QWEN2_TOKENIZER_BPE_CACHE_SIZE, + ENDOFTEXT, + IM_END, +) + +from torchtune.models.qwen2_5._tokenizer import QWEN2_5_SPECIAL_TOKENS, Qwen2_5Tokenizer + + +class Qwen25VLTokenizer(Qwen2_5Tokenizer): + """ + This class constructs a Qwen2.5-VL tokenizer, inheriting from Qwen2_5Tokenizer. + + This class overrides the tokenize_messages method to support vision tokens. + + See Qwen2_5Tokenizer for more details. + """ + + def __init__( + self, + path: str, + merges_file: str, + special_tokens: dict[str, int] = QWEN2_5_SPECIAL_TOKENS, + max_seq_len: Optional[int] = None, + *, + prompt_template: Optional[PromptTemplate] = None, + errors: str = "replace", + unk_token: Optional[str] = None, + bos_token: Optional[str] = None, + eos_token: str = IM_END, + pad_token: Optional[str] = ENDOFTEXT, + bpe_cache_size: int = DEFAULT_QWEN2_TOKENIZER_BPE_CACHE_SIZE, + truncation_type: str = "right", + ): + super().__init__( + path=path, + merges_file=merges_file, + special_tokens=special_tokens, + max_seq_len=max_seq_len, + prompt_template=prompt_template, + errors=errors, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + bpe_cache_size=bpe_cache_size, + truncation_type=truncation_type, + ) + + self.im_start_id = self.special_tokens["<|im_start|>"] + self.im_end_id = self.special_tokens["<|im_end|>"] + self.image_pad_id = self.special_tokens["<|image_pad|>"] + self.video_pad_id = self.special_tokens["<|video_pad|>"] + self.vision_start_token_id = self.special_tokens["<|vision_start|>"] + self.vision_end_token_id = self.special_tokens["<|vision_end|>"] + + def tokenize_messages( + self, + messages: list[Message], + *, + add_eos: bool = True, + ) -> tuple[list[int], list[bool]]: + """ + Given a list of messages, return a list of tokens for the concatenated + and formatted messages. + + Args: + messages (list[Message]): The message list to tokenize. + add_eos (bool): Wether to add the tokenizer's eos_id at the end of the + sequence of messages. Default is True. + + Returns: + tuple[list[int], list[bool]]: The list of token ids and the list of masks. + + Raises: + RuntimeError: If a message contains non-text content + """ + assert not isinstance(self.prompt_template, ChatMLTemplate), ( + "Using ChatMLTemplate with tokenize_messages will result in multiple <|im_*|> tokens wrapping each message." 
+ "Please use a different template or set to None." + ) + templated_messages = ( + self.prompt_template(messages) + if self.prompt_template is not None + else messages + ) + + tokenized_messages = [] + mask = [] + for i, message in enumerate(templated_messages): + # message header + tokens = self._tokenize_header(templated_messages, i) + + # message content + for item in message.content: + if item["type"] == "text": + tokens.extend( + self.encode( + item["content"], + add_bos=False, + add_eos=False, + ) + ) + elif item["type"] == "image": + num_image_tokens = item.get("num_image_tokens") + + tokens.append(self.vision_start_token_id) + tokens.extend([self.image_pad_id] * num_image_tokens) + tokens.append(self.vision_end_token_id) + elif item["type"] == "video": + num_video_tokens = item.get("num_video_tokens") + + tokens.append(self.vision_start_token_id) + tokens.extend([self.video_pad_id] * num_video_tokens) + tokens.append(self.vision_end_token_id) + else: + raise RuntimeError( + f"Unsupported message content type: {item['type']}" + ) + + # message footer + tokens.extend(self._tokenize_footer(templated_messages, i)) + + tokenized_messages.extend(tokens) + mask.extend([message.masked] * len(tokens)) + + # Break out early if we reach max_seq_len + if self.max_seq_len and len(tokenized_messages) >= self.max_seq_len: + break + + # Add the End-Of-Sequence token + if add_eos: + tokenized_messages.append(self.eos_id) + mask.append(mask[-1]) + + # Finally, truncate if necessary + if self.max_seq_len: + tokenized_messages = truncate( + tokens=tokenized_messages, + max_seq_len=self.max_seq_len, + eos_id=self.eos_id if add_eos else None, + truncation_type=self.truncation_type, + ) + mask = truncate( + tokens=mask, + max_seq_len=self.max_seq_len, + eos_id=True if add_eos else None, + truncation_type=self.truncation_type, + ) + + return tokenized_messages, mask diff --git a/torchtune/models/qwen2_5_vision/_transform.py b/torchtune/models/qwen2_5_vision/_transform.py new file mode 100644 index 0000000000..8f16334164 --- /dev/null +++ b/torchtune/models/qwen2_5_vision/_transform.py @@ -0,0 +1,468 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from typing import Any, Mapping, Optional + +import torch +from PIL import Image + +from torchtune.data import Message +from torchtune.data._prompt_templates import _get_prompt_template, _TemplateType +from torchtune.models.qwen2_5_vision._tokenizer import Qwen25VLTokenizer +from torchtune.modules.transforms import Transform +from torchtune.modules.transforms.tokenizers import ( + ModelTokenizer, + parse_hf_tokenizer_json, +) +from torchvision.transforms import InterpolationMode +from torchvision.transforms.v2 import functional as F + +logger = logging.getLogger(__name__) + +# HuggingFace OPENAI_CLIP constants to match their normalization +OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073] +OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711] + + +def smart_resize( + height: int, + width: int, + factor: int = 28, + min_pixels: int = 56 * 56, + max_pixels: int = 12845056, +): + """Rescales the image so that the following conditions are met: + 1. Both dimensions (height and width) are divisible by 'factor'. + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + 3. The aspect ratio of the image is maintained as closely as possible. 
+ """ + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +class Qwen25VLImageTransform: + """ + This class accepts images of any size and dynamically resizes, normalizes and patches it + based on the image size constraints and patch size. + + Args: + image_mean (Optional[list[float]]): Mean values of each channel, used for normalization. + Should be the same used for the pre-trained model. If None, uses OPENAI_CLIP_MEAN. Default None. + image_std (Optional[list[float]]): Standard deviation values of each channel, used for normalization. + Should be the same used for the pre-trained model. If None, uses OPENAI_CLIP_STD. Default None. + patch_size (int): Size of the patches to divide the image into. Default 14. + merge_size (int): Size of the patch merging factor. Default 2. + temporal_patch_size (int): Size of the temporal patch merging factor. Default 2. + size (Optional[dict[str, int]]): Size configuration with 'shortest_edge' and 'longest_edge' keys. + min_pixels (Optional[int]): Minimum number of pixels for the shorter edge. Default 3136 (56 * 56). + max_pixels (Optional[int]): Maximum number of pixels for the longer edge. Default 1003520 (28 * 28 * 1280). + dtype (torch.dtype): Data type of the output image. Default torch.float32. + resample (str): Resampling method used when resizing images. Supports any enum of + ``torchvision.transforms.InterpolationMode``, e.g. "nearest", "nearest_exact", "bilinear", "bicubic". + Default 'bicubic'. + + Raises: + ValueError: If size is provided but does not contain 'shortest_edge' and 'longest_edge' keys. + """ + + def __init__( + self, + *, + image_mean: Optional[list[float]] = None, + image_std: Optional[list[float]] = None, + patch_size: int = 14, + merge_size: int = 2, + temporal_patch_size: int = 2, + size: Optional[dict[str, int]] = None, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + dtype: torch.dtype = torch.float32, + resample: str = "bicubic", + ) -> None: + self.patch_size = patch_size + self.merge_size = merge_size + self.temporal_patch_size = temporal_patch_size + + # Handle size configuration - prioritize size dict over individual params + if size is not None: + if "shortest_edge" not in size or "longest_edge" not in size: + raise ValueError( + "size must contain 'shortest_edge' and 'longest_edge' keys." 
+ ) + self.size = size.copy() + else: + self.size = {"shortest_edge": 56 * 56, "longest_edge": 12845056} + + # Override with individual parameters if provided + if min_pixels is not None: + self.size["shortest_edge"] = min_pixels + if max_pixels is not None: + self.size["longest_edge"] = max_pixels + + self.min_pixels = self.size["shortest_edge"] + self.max_pixels = self.size["longest_edge"] + + self.dtype = dtype + self.resample = getattr(InterpolationMode, resample.upper()) + + # Use OPENAI_CLIP defaults if not provided (matches HuggingFace) + self.mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.std = image_std if image_std is not None else OPENAI_CLIP_STD + + def __call__( + self, sample: Mapping[str, Any], inference: bool = False + ) -> Mapping[str, Any]: + """ + Apply image decoding and transformations to the "image" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with an "image" field containing + a PIL Image or torch.Tensor + inference (bool): Whether the template is being used for inference or not. + + Returns: + Mapping[str, Any]: The sample with updated fields: + - "pixel_values": Flattened patches tensor + - "image_grid_thw": Grid dimensions (temporal, height, width) + - "num_patches": Number of patches calculated + """ + image = sample["image"] + assert isinstance( + image, (Image.Image, torch.Tensor) + ), "Input image must be a PIL image or a torch.Tensor." + + # Convert to RGB and tensor + if isinstance(image, Image.Image) and image.mode != "RGB": + image = image.convert("RGB") + image = F.to_image(image) + + # Convert to float and rescale to [0, 1] - this matches HF's rescaling step + image = F.to_dtype(image, dtype=torch.float32, scale=True) + + # Get image dimensions + height, width = image.shape[-2:] + + # Calculate resize dimensions + resized_height, resized_width = smart_resize( + height, + width, + factor=self.patch_size * self.merge_size, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + + # Resize image + image = F.resize( + image, size=(resized_height, resized_width), interpolation=self.resample + ) + + # Normalize with OPENAI_CLIP values + image = F.normalize(image, mean=self.mean, std=self.std) + + image = image.to(dtype=self.dtype) + + patches = image.unsqueeze(0) + + if patches.shape[0] % self.temporal_patch_size != 0: + repeats_needed = self.temporal_patch_size - ( + patches.shape[0] % self.temporal_patch_size + ) + last_frame = patches[-1:].repeat(repeats_needed, 1, 1, 1) + patches = torch.cat([patches, last_frame], dim=0) + + # Calculate grid dimensions + grid_t = patches.shape[0] // self.temporal_patch_size + grid_h, grid_w = ( + resized_height // self.patch_size, + resized_width // self.patch_size, + ) + channels = patches.shape[1] + + patches = patches.reshape( + grid_t, + self.temporal_patch_size, + channels, + grid_h // self.merge_size, + self.merge_size, + self.patch_size, + grid_w // self.merge_size, + self.merge_size, + self.patch_size, + ) + + patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8) + + flatten_patches = patches.reshape( + grid_t * grid_h * grid_w, + channels * self.temporal_patch_size * self.patch_size * self.patch_size, + ) + + num_patches = grid_h * grid_w + num_image_tokens = num_patches // self.merge_size**2 + + sample.update( + { + "pixel_values": flatten_patches, + "image_grid_thw": torch.tensor([[grid_t, grid_h, grid_w]]), + "num_image_tokens": num_image_tokens, + } + ) + + return sample + + +class Qwen25VLTransform(ModelTokenizer, Transform): + """ + Transform for 
Qwen 2.5 Vision model that handles both text tokenization and image processing. + + Args: + path (str): Path to the tokenizer vocab.json file. + merges_file (str): Path to the tokenizer merges.txt file. + patch_size (int): Size of the patches used in vision processing. Default 14. + special_tokens_path (Optional[str]): Path to ``tokenizer.json`` from Hugging Face + model files that contains all registered special tokens, or a local json file + structured similarly. Default is None to use the canonical Qwen 2.5 special tokens. + max_seq_len (Optional[int]): maximum sequence length for tokenizing a single list of messages, + after which the input will be truncated. Default is None. + image_mean (Optional[list[float]]): Mean values of each channel, used for normalization. + Default None to use OPENAI_CLIP_MEAN. + image_std (Optional[list[float]]): Standard deviations for each channel, used for normalization. + Default None to use OPENAI_CLIP_STD. + dtype (torch.dtype): Data type of transformed image. Default torch.float32. + prompt_template (Optional[_TemplateType]): template used to format the messages based on their role. + """ + + def __init__( + self, + path: str, + merges_file: str, + *, + patch_size: int = 14, + special_tokens_path: Optional[str] = None, + max_seq_len: Optional[int] = None, + image_mean: Optional[list[float]] = None, + image_std: Optional[list[float]] = None, + dtype: torch.dtype = torch.float32, + prompt_template: Optional[_TemplateType] = None, + ): + special_tokens = ( + parse_hf_tokenizer_json(special_tokens_path) + if special_tokens_path is not None + else None + ) + template = ( + _get_prompt_template(prompt_template) + if prompt_template is not None + else None + ) + self.tokenizer = Qwen25VLTokenizer( + path=path, + merges_file=merges_file, + max_seq_len=max_seq_len, + prompt_template=template, + ) + + # Initialize the Qwen2.5 VL image transform + self.image_transform = Qwen25VLImageTransform( + image_mean=image_mean, + image_std=image_std, + patch_size=patch_size, + merge_size=2, # Default merge size for Qwen2.5-VL + temporal_patch_size=2, # Default temporal patch size + dtype=dtype, + resample="bicubic", + ) + + self.stop_tokens = self.tokenizer.stop_tokens + self.special_tokens = self.tokenizer.special_tokens + self.max_seq_len = max_seq_len + self.patch_size = patch_size + self.prompt_template = prompt_template + self.pad_id = self.tokenizer.pad_id + + @property + def base_vocab_size(self) -> int: + return len(self.tokenizer.encoder) + + @property + def vocab_size(self) -> int: + # Total vocab size includes base vocab + special tokens + return len(self.tokenizer.encoder) + len(self.tokenizer.special_tokens) + + def encode( + self, + text: str, + add_bos: bool = True, + add_eos: bool = True, + ) -> list[int]: + """ + Encode a string into a list of token ids. + + Args: + text (str): The string to encode. + add_bos (bool): Whether to add the tokenizer's bos_id. Default is True. + add_eos (bool): Whether to add the tokenizer's eos_id. Default is True. + + Returns: + list[int]: The list of token ids. + """ + return self.tokenizer.encode(text=text, add_bos=add_bos, add_eos=add_eos) + + def decode( + self, + token_ids: list[int], + truncate_at_eos: bool = True, + skip_special_tokens: bool = True, + ) -> str: + """ + Decode a list of token ids into a string. + + Args: + token_ids (list[int]): The list of token ids. + truncate_at_eos (bool): Whether to truncate the string at the end of + sequence token. Default is True. 
+ skip_special_tokens (bool): Whether to show or skip special tokens in the decoded string. + Default is True. + + Returns: + str: The decoded string. + """ + if truncate_at_eos and self.tokenizer.eos_id in token_ids: + eos_index = token_ids.index(self.tokenizer.eos_id) + token_ids = token_ids[:eos_index] + + return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) + + def transform_image( + self, image: Image.Image, inference: bool = False + ) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Transform an image into flattened patches for the vision encoder. + This method applies the transformations defined in `Qwen25VLImageTransform`. + + Args: + image (Image.Image): The input image. + inference (bool): Whether to run in inference mode. This is passed to the + underlying image transform. Default is False. + + Returns: + tuple[torch.Tensor, torch.Tensor, int]: A tuple containing: + - The transformed image patches as a tensor. + - The image grid dimensions (t, h, w) as a tensor. + - The number of patches calculated. + """ + sample = {"image": image} + transformed = self.image_transform(sample, inference=inference) + return ( + transformed["pixel_values"], + transformed["image_grid_thw"], + transformed["num_image_tokens"], + ) + + def tokenize_message( + self, + message: Message, + *, + add_start_tokens: bool = True, + add_end_tokens: bool = True, + ) -> list[int]: + """ + Tokenize a single message into a list of token ids. + + Args: + message (Message): The message to tokenize. + add_start_tokens (bool): Whether to add the tokenizer's bos_id. Default True. + add_end_tokens (bool): Whether to add the tokenizer's eos_id. Default True. + + Returns: + list[int]: The list of token ids. + """ + return self.tokenizer.tokenize_message( + message=message, + add_start_tokens=add_start_tokens, + add_end_tokens=add_end_tokens, + ) + + def tokenize_messages( + self, + messages: list[Message], + *, + add_end_tokens: bool = True, + ) -> tuple[list[int], list[bool]]: + """ + Tokenize a list of messages into a list of token ids and masks. + + Args: + messages (list[Message]): The list of messages to tokenize. + add_end_tokens (bool): Whether to add the tokenizer's eos_id. Default True. + + Returns: + tuple[list[int], list[bool]]: The list of token ids and the list of masks. + """ + return self.tokenizer.tokenize_messages( + messages=messages, + add_end_tokens=add_end_tokens, + ) + + def __call__( + self, sample: Mapping[str, Any], inference: bool = False + ) -> Mapping[str, Any]: + """ + Apply image decoding, transformations and tokenization to messages in the sample. + + Args: + sample (Mapping[str, Any]): A sample with a "messages" field. + inference (bool): Whether to run in inference mode. Default is False. 
+ + Returns: + Mapping[str, Any]: The transformed sample with the following fields: + - tokens: list[int] of tokenized messages + - mask: list[bool] of masks for the tokenized messages + - encoder_input: dict[str, Any] of transformed images + """ + encoder_input = {"image": {"hidden_states": [], "grid_thw": []}} + messages = sample["messages"] + for message in messages: + for content in message.content: + if content["type"] == "image": + image = content["content"] + + ( + pixel_values, + image_grid_thw, + num_image_tokens, + ) = self.transform_image(image, inference=inference) + + content["num_image_tokens"] = num_image_tokens + + encoder_input["image"]["hidden_states"].append(pixel_values) + encoder_input["image"]["grid_thw"].append(image_grid_thw) + + encoder_input["image"]["hidden_states"] = torch.stack( + encoder_input["image"]["hidden_states"], dim=0 + ) + encoder_input["image"]["grid_thw"] = torch.cat( + encoder_input["image"]["grid_thw"], dim=0 + ) + + sample["encoder_input"] = encoder_input + sample = self.tokenizer(sample, inference=inference) + return sample diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index 62e4227b57..0b9c58ec11 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import inspect import logging from typing import Optional @@ -15,6 +16,36 @@ logger = logging.getLogger(__name__) +def _call_pos_embedding_safely( + pos_embedding: nn.Module, + x: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + window_index: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Call positional embedding with only the parameters it accepts. + + Args: + pos_embedding (nn.Module): The positional embedding module + x (torch.Tensor): Input tensor + input_pos (Optional[torch.Tensor]): Optional input position tensor + window_index (Optional[torch.Tensor]): Optional window index tensor + + Returns: + Output tensor from positional embedding + """ + sig = inspect.signature(pos_embedding.forward) + kwargs = {} + + # Only add parameters that the method accepts + if "input_pos" in sig.parameters: + kwargs["input_pos"] = input_pos + if "window_index" in sig.parameters: + kwargs["window_index"] = window_index + + return pos_embedding(x, **kwargs) + + class MultiHeadAttention(nn.Module): """Multi-headed attention layer with support for grouped query attention (GQA) introduced in https://arxiv.org/abs/2305.13245v1. @@ -185,6 +216,7 @@ def forward( *, mask: Optional[_MaskType] = None, input_pos: Optional[torch.Tensor] = None, + window_index: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -209,6 +241,8 @@ def forward( of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. If none, assume the index of the token is its position id. Default is None. + window_index (Optional[torch.Tensor]): Optional tensor which contains the window index + of each token. Default is None. Raises: ValueError: If no ``y`` input and ``kv_cache`` is not enabled. 
@@ -239,7 +273,9 @@ def forward( # Apply positional embeddings if self.pos_embeddings is not None: - q = self.pos_embeddings(q, input_pos=input_pos) + q = _call_pos_embedding_safely( + self.pos_embeddings, q, input_pos, window_index + ) # [b, n_h, s_x, h_d] q = q.transpose(1, 2) @@ -267,7 +303,9 @@ def forward( k = k.view(b, s_y, -1, self.head_dim) v = v.view(b, s_y, -1, self.head_dim) if self.pos_embeddings is not None: - k = self.pos_embeddings(k, input_pos=input_pos) + k = _call_pos_embedding_safely( + self.pos_embeddings, k, input_pos, window_index + ) # k,v shape: [b, n_kv, s_y, h_d] k = k.transpose(1, 2) diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 724138b14e..76e90820dc 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -91,6 +91,7 @@ def forward( *, mask: Optional[_MaskType] = None, input_pos: Optional[torch.Tensor] = None, + window_index: Optional[torch.Tensor] = None, **kwargs: dict, ) -> torch.Tensor: """ @@ -115,6 +116,8 @@ def forward( of each token relative to its sample when packed, shape [b x s]. During inference, this indicates the position of the current token. If none, assume the index of the token is its position id. Default is None. + window_index (Optional[torch.Tensor]): Optional tensor which contains the window index + of each token. Default is None. **kwargs (dict): transformer layer inputs not relevant to self attention. Returns: @@ -129,7 +132,9 @@ def forward( # With TP we need to use a replicated tensor here bsz, seq_len, *_ = h.shape mask = self.mask_mod(mask=mask, bsz=bsz, seq_len=seq_len) - attn_out = self.attn(h, h, mask=mask, input_pos=input_pos) + attn_out = self.attn( + h, h, mask=mask, input_pos=input_pos, window_index=window_index + ) # Residual connection; shape: [batch_size, seq_length, embed_dim] h = self.sa_scale(attn_out) + x diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index ce5ccf5963..35e5ec3770 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -609,6 +609,14 @@ def load_checkpoint(self) -> dict[str, Any]: dim=self._config["hidden_size"], tie_word_embeddings=self._config["tie_word_embeddings"], ) + elif self._model_type == ModelType.QWEN2_5_VL: + from torchtune.models.qwen2_5_vision._convert_weights import ( + qwen2_5_vl_hf_to_tune, + ) + + converted_state_dict[training.MODEL_KEY] = qwen2_5_vl_hf_to_tune( + merged_state_dict, + ) elif self._model_type == ModelType.QWEN3: from torchtune.models.qwen3._convert_weights import qwen3_hf_to_tune @@ -748,6 +756,14 @@ def save_checkpoint( dim=self._config["hidden_size"], tie_word_embeddings=self._config["tie_word_embeddings"], ) + elif self._model_type == ModelType.QWEN2_5_VL: + from torchtune.models.qwen2_5_vision._convert_weights import ( + qwen2_5_vl_tune_to_hf, + ) + + state_dict[training.MODEL_KEY] = qwen2_5_vl_tune_to_hf( + state_dict[training.MODEL_KEY], + ) elif self._model_type == ModelType.QWEN3: from torchtune.models.qwen3._convert_weights import qwen3_tune_to_hf diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 1dde03a121..0623569c99 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -98,6 +98,7 @@ class ModelType(Enum): to a single class for reward modelling. 
See :func:`~torchtune.models.mistral.mistral_reward_7b` or :func:`~torchtune.models.llama2.llama2_reward_7b` QWEN2 (str): Qwen2 family of models. See :func:`~torchtune.models.qwen2.qwen2` + QWEN2_5_VL (str): Qwen2.5-VL family of models. See :func:`~torchtune.models.qwen2_5_vision.qwen2_5_vl_32b` CLIP_TEXT (str): CLIP text encoder. See :func:`~torchtune.models.clip.clip_text_encoder_large` T5_ENCODER (str): T5 text encoder. See :func:`~torchtune.models.t5.t5_v1_1_xxl_encoder` QWEN3 (str): Qwen3 family of models. See :func:`~torchtune.models.qwen3.qwen3` @@ -122,6 +123,7 @@ class ModelType(Enum): PHI4: str = "phi4" REWARD: str = "reward" QWEN2: str = "qwen2" + QWEN2_5_VL: str = "qwen2_5_vl" CLIP_TEXT: str = "clip_text" T5_ENCODER: str = "t5_encoder" QWEN3: str = "qwen3"
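For reference, a minimal sketch of how the new transform and model builder are meant to fit together. Paths and the image file are placeholders, and the sample schema follows Qwen25VLTransform.__call__ above; treat this as illustrative, not as a tested recipe:

    from PIL import Image

    from torchtune.data import Message
    from torchtune.models.qwen2_5_vision import Qwen25VLTransform, qwen2_5_vl_3b

    # Tokenizer + image transform; vocab/merges come from the downloaded HF checkpoint
    transform = Qwen25VLTransform(
        path="<checkpoint_dir>/vocab.json",
        merges_file="<checkpoint_dir>/merges.txt",
        max_seq_len=None,
    )

    # One user turn with an image followed by a text question
    sample = {
        "messages": [
            Message(
                role="user",
                content=[
                    {"type": "image", "content": Image.open("example.jpg")},
                    {"type": "text", "content": "What is shown in this image?"},
                ],
            ),
            Message(role="assistant", content="A short answer."),
        ]
    }

    out = transform(sample)
    # out["tokens"]: token ids with <|vision_start|>/<|image_pad|>/<|vision_end|> inserted
    # out["mask"]: per-token loss mask
    # out["encoder_input"]["image"]["hidden_states"]: flattened image patches
    # out["encoder_input"]["image"]["grid_thw"]: per-image (t, h, w) grid sizes

    # Model matching the 3B single-device config
    model = qwen2_5_vl_3b(decoder_trainable=True, encoder_trainable=False)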