1 change: 1 addition & 0 deletions README.md
@@ -263,6 +263,7 @@ loss.backward()
| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| InternVL3 | `liger_kernel.transformers.apply_liger_kernel_to_internvl` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |


## Low-level APIs
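As a quick orientation for reviewers, here is a minimal usage sketch of the new entry point documented in the README row above. It assumes the `OpenGVLab/InternVL3-1B-hf` checkpoint from the docstring example further down in this PR and a CUDA device; it is an illustration, not part of the diff.

```python
import torch
from transformers import AutoModelForImageTextToText

from liger_kernel.transformers import apply_liger_kernel_to_internvl

model = AutoModelForImageTextToText.from_pretrained(
    "OpenGVLab/InternVL3-1B-hf", dtype=torch.bfloat16, device_map="cuda"
)

# Passing the loaded model lets the patcher also patch the language and
# vision sub-models it wraps (see apply_liger_kernel_to_internvl below).
apply_liger_kernel_to_internvl(model=model)
```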
3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -38,6 +38,7 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
@@ -98,6 +99,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_glm4v",
"apply_liger_kernel_to_glm4v_moe",
"apply_liger_kernel_to_granite",
"apply_liger_kernel_to_internvl",
"apply_liger_kernel_to_llama",
"apply_liger_kernel_to_llava",
"apply_liger_kernel_to_llama4",
@@ -163,6 +165,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_glm4v",
"apply_liger_kernel_to_glm4v_moe",
"apply_liger_kernel_to_granite",
"apply_liger_kernel_to_internvl",
"apply_liger_kernel_to_llama",
"apply_liger_kernel_to_llava",
"apply_liger_kernel_to_llama4",
150 changes: 150 additions & 0 deletions src/liger_kernel/transformers/model/internvl.py
@@ -0,0 +1,150 @@
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from transformers.models.internvl.modeling_internvl import InternVLCausalLMOutputWithPast
from transformers.utils import can_return_tuple

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


# Copied from https://github.com/huggingface/transformers/blob/d888bd435d0c0eaabaabad5b33d52af518c7187c/src/transformers/models/internvl/modeling_internvl.py#L862
@can_return_tuple
def lce_forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: Optional[torch.Tensor] = None,
skip_logits: Optional[bool] = None, # Added argument for liger-kernel
**lm_kwargs, # renamed from kwargs
) -> Union[Tuple, InternVLCausalLMOutputWithPast]:
r"""
Example:

```python
>>> import torch
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> torch_device = "cuda"
>>> processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf")
>>> model = AutoModelForImageTextToText.from_pretrained(
... "OpenGVLab/InternVL3-1B-hf", dtype=torch.bfloat16, device_map=torch_device
... )

>>> messages = [
... {
... "role": "user",
... "content": [
... {
... "type": "image",
... "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
... },
... {
... "type": "image",
... "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
... },
... {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
... ],
... },
... ]

>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device)
>>> generate_ids = model.generate(**inputs, max_new_tokens=200)
>>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
The images depict the Statue of Liberty and the Golden Gate Bridge.
```"""

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)

outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
image_sizes=image_sizes,
**lm_kwargs,
)

# Copied from llava.py
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]

shift_labels = lm_kwargs.pop("shift_labels", None)
logits = None
loss = None

if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")

if skip_logits is None:
# By default, if in training mode, don't materialize logits
skip_logits = self.training and (labels is not None or shift_labels is not None)

if skip_logits:
loss = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.text_config.hidden_size,
**lm_kwargs,
)

else:
logits = self.lm_head(kept_hidden_states)
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **lm_kwargs
)

if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output

return InternVLCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
)
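The `skip_logits` branch above is the point of this file: during training with labels, the full `(batch, seq_len, vocab)` logits tensor is never materialized, and `LigerForCausalLMLoss` fuses the `lm_head` projection into the loss computation. A back-of-envelope sketch of what that saves, using assumed shapes (batch size and sequence length here are hypothetical; the vocab and hidden sizes are the full-size values noted in the test config comments below):

```python
# Assumed training shapes: batch 8, seq 4096 (hypothetical), bf16;
# vocab 151936 and hidden 1024 are the full-size values from the test config comments.
batch, seq_len, vocab, hidden = 8, 4096, 151936, 1024
bytes_bf16 = 2

logits_bytes = batch * seq_len * vocab * bytes_bf16   # what lm_head would materialize
hidden_bytes = batch * seq_len * hidden * bytes_bf16  # what the fused path keeps

print(f"materialized logits: {logits_bytes / 2**30:.2f} GiB")  # ~9.27 GiB
print(f"kept hidden states:  {hidden_bytes / 2**30:.2f} GiB")  # ~0.06 GiB
```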
81 changes: 80 additions & 1 deletion src/liger_kernel/transformers/monkey_patch.py
@@ -4,6 +4,7 @@
from functools import partial
from types import MethodType
from typing import Callable
from typing import Optional

import transformers

@@ -1334,7 +1335,6 @@ def apply_liger_kernel_to_qwen2(
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
print("Applied Liger kernels to Qwen2")


def apply_liger_kernel_to_qwen3(
@@ -2029,6 +2029,84 @@ def apply_liger_kernel_to_glm4v_moe(
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)


def apply_liger_kernel_to_internvl(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
model: Optional[PreTrainedModel] = None,
**kwargs,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
Because InternVL wraps separate language and vision sub-models, the loaded model instance must be passed in so that Liger-Kernel can also patch those sub-models.
If the attached language model is not supported by Liger-Kernel, unexpected side effects may occur.
NOTE: InternVL is not available in transformers<4.52.1

Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory-efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
kwargs: Additional keyword arguments (e.g. `rope`, `swiglu`) forwarded to the Liger patchers of the
attached language and vision sub-models when `model` is passed.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)

from transformers.models.internvl import modeling_internvl

from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward

if cross_entropy:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_internvl.nn.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
if rms_norm:
modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm

if model is not None:
text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)

kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
if text_liger_fn:
accept_params = inspect.signature(text_liger_fn).parameters
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}

if remain_params:
logger.warning(
f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
)
text_kwargs["model"] = model.language_model
text_liger_fn(**text_kwargs)
elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.warning(f"{text_model_name} is not supported by Liger kernel.")

if vision_liger_fn:
accept_params = inspect.signature(vision_liger_fn).parameters
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}

if remain_params:
logger.warning(
f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
)
vision_kwargs["model"] = model.vision_tower
vision_liger_fn(**vision_kwargs)
elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.warning(f"{vision_model_name} is not supported by Liger kernel.")


# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
MODEL_TYPE_TO_APPLY_LIGER_FN = {
"gemma": apply_liger_kernel_to_gemma,
@@ -2038,6 +2116,7 @@ def apply_liger_kernel_to_glm4v_moe(
"glm4": apply_liger_kernel_to_glm4,
"glm4v": apply_liger_kernel_to_glm4v,
"glm4v_moe": apply_liger_kernel_to_glm4v_moe,
"internvl": apply_liger_kernel_to_internvl,
"llama": apply_liger_kernel_to_llama,
"llama4_text": apply_liger_kernel_to_llama4,
"llama4": apply_liger_kernel_to_llama4,
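The sub-model dispatch in `apply_liger_kernel_to_internvl` filters the incoming kwargs against each sub-patcher's signature with `inspect.signature` before forwarding them. A self-contained toy sketch of that filtering pattern (hypothetical function and parameter names, for illustration only):

```python
import inspect

def apply_to_toy_text_model(rms_norm: bool = True, swiglu: bool = True, model=None):
    # Hypothetical stand-in for a per-model patcher such as apply_liger_kernel_to_qwen2.
    print(f"patching with rms_norm={rms_norm}, swiglu={swiglu}")

kwargs = {"rms_norm": True, "rope": True, "layer_norm": False}

accept_params = inspect.signature(apply_to_toy_text_model).parameters
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))  # unsupported keys
filtered = {k: v for k, v in kwargs.items() if k not in remain_params}

print(f"dropping unsupported kwargs: {sorted(remain_params)}")  # ['layer_norm', 'rope']
apply_to_toy_text_model(**filtered)  # -> patching with rms_norm=True, swiglu=True
```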
64 changes: 64 additions & 0 deletions test/convergence/bf16/test_mini_models.py
@@ -25,6 +25,7 @@
from liger_kernel.transformers import apply_liger_kernel_to_glm4v
from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe
from liger_kernel.transformers import apply_liger_kernel_to_granite
from liger_kernel.transformers import apply_liger_kernel_to_internvl
from liger_kernel.transformers import apply_liger_kernel_to_llama
from liger_kernel.transformers import apply_liger_kernel_to_llama4
from liger_kernel.transformers import apply_liger_kernel_to_llava
@@ -51,6 +52,7 @@
from test.utils import revert_liger_kernel_to_glm4v
from test.utils import revert_liger_kernel_to_glm4v_moe
from test.utils import revert_liger_kernel_to_granite
from test.utils import revert_liger_kernel_to_internvl
from test.utils import revert_liger_kernel_to_llama
from test.utils import revert_liger_kernel_to_llama4
from test.utils import revert_liger_kernel_to_llava
@@ -190,6 +192,15 @@
except ImportError:
SMOLLM3_AVAILABLE = False

try:
# InternVL is only available in transformers>=4.52.1
from transformers.models.internvl.configuration_internvl import InternVLConfig
from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration

INTERNVL_AVAILABLE = True
except ImportError:
INTERNVL_AVAILABLE = False

from liger_kernel.utils import infer_device

device = infer_device()
@@ -1022,6 +1033,38 @@
),
)

if INTERNVL_AVAILABLE:
MINI_MODEL_SETUPS["mini_internvl"] = MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_internvl,
liger_kernel_patch_revert_func=revert_liger_kernel_to_internvl,
model_class=InternVLForConditionalGeneration,
mini_model_config=InternVLConfig(
text_config=Qwen2Config(
rms_norm_eps=1e-5,
hidden_size=256, # 1024
intermediate_size=1024, # 4096
hidden_act="silu",
num_hidden_layers=4, # 24
num_attention_heads=4, # 16
num_key_value_heads=2, # 16
max_position_embeddings=4096, # 8192
vocab_size=32000, # 151936
bos_token_id=1,
eos_token_id=2,
pad_token_id=2,
tie_word_embeddings=False,
),
vision_config={
"hidden_size": 256, # 1024
"intermediate_size": 1024, # 4096
"num_hidden_layers": 4, # 24
"num_attention_heads": 4, # 16
},
image_token_id=10,
attn_implementation="sdpa", # default value, pytorch native attention
),
)


def create_model(model_name="mini_llama4"):
"""
@@ -1080,6 +1123,8 @@ def run_mini_model(
else:
MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)

print(MINI_MODEL_SETUPS[model_name])

model = create_model(model_name).to(dtype).to(device)

train_dataset = load_from_disk(DEFAULT_DATASET_PATH)
@@ -1432,6 +1477,25 @@ def run_mini_model(
),
],
),
pytest.param(
"mini_internvl",
32,
1e-4,
torch.bfloat16,
1e-3,
1e-2,
1e-1,
1e-2,
1e-2,
1e-2,
marks=[
pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
pytest.mark.skipif(
not INTERNVL_AVAILABLE,
reason="InternVL not available in this version of transformers",
),
],
),
# TODO: mixtral is flaky so disable the test for now
# pytest.param(
# "mini_mixtral",