Merged
3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -57,6 +57,7 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm # noqa: F401


# Check if 'transformers' is installed
@@ -120,6 +121,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_qwen3_moe",
"apply_liger_kernel_to_qwen3_next",
"apply_liger_kernel_to_smollm3",
"apply_liger_kernel_to_smolvlm",
}

if name in monkey_patch_symbols:
@@ -189,5 +191,6 @@ def __getattr__(name: str):
"apply_liger_kernel_to_qwen3_moe",
"apply_liger_kernel_to_qwen3_next",
"apply_liger_kernel_to_smollm3",
"apply_liger_kernel_to_smolvlm",
]
)
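
With the export wired up above, the new patch function can be imported straight from the top-level package. A quick import sketch (not part of the diff), relying only on the lazy `__getattr__` shown here:

```python
# The lazy __getattr__ above resolves the symbol from
# liger_kernel.transformers.monkey_patch on first access.
from liger_kernel.transformers import apply_liger_kernel_to_smolvlm

print(apply_liger_kernel_to_smolvlm.__module__)
# liger_kernel.transformers.monkey_patch
```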
158 changes: 158 additions & 0 deletions src/liger_kernel/transformers/model/smolvlm.py
@@ -0,0 +1,158 @@
from typing import TYPE_CHECKING
from typing import Optional
from typing import Union
from typing import Unpack

import torch

from transformers.models.smolvlm.modeling_smolvlm import SmolVLMCausalLMOutputWithPast
from transformers.utils.generic import can_return_tuple

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss

if TYPE_CHECKING:
from transformers.cache_utils import Cache
from transformers.utils.generic import TransformersKwargs


# Forward adapted to enable fused Linear + CE without materializing logits.
# Mirrors the pattern used for other multimodal models (e.g., InternVL, LLaVA).
@can_return_tuple
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional["Cache"] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_hidden_states: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None, # Added argument for liger-kernel
**lm_kwargs: Unpack["TransformersKwargs"], # renamed from kwargs
) -> Union[tuple, SmolVLMCausalLMOutputWithPast]:
r"""
pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
Mask to avoid performing attention on padding pixel indices.
image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The hidden states of the image encoder after modality projection.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or `model.image_token_id`. Tokens with indices set to `model.image_token_id` are
ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python
>>> import requests
>>> import torch
>>> from PIL import Image
>>> from io import BytesIO
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> from transformers.image_utils import load_image
>>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
>>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
>>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
>>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
>>> processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
>>> model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct", dtype=torch.bfloat16, device_map="auto")
>>> # Create inputs
>>> messages = [
... {
... "role": "user",
... "content": [
... {"type": "video", "path": "path/to/video"},
... {"type": "text", "text": "What is happening in this video?"},
... ]
... }
... ]
>>> inputs = processor.apply_chat_template([messages], add_generation_prompt=True)
>>> # Generate
>>> generated_ids = model.generate(**inputs, max_new_tokens=256)
>>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
>>> print(generated_texts)
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
pixel_values=pixel_values,
pixel_attention_mask=pixel_attention_mask,
image_hidden_states=image_hidden_states,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
return_dict=True,
**lm_kwargs,
)

# Copied from llava.py
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]

shift_labels = lm_kwargs.pop("shift_labels", None)
logits = None
loss = None

if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")

if skip_logits is None:
# By default, if in training mode, don't materialize logits
skip_logits = self.training and (labels is not None or shift_labels is not None)

if skip_logits:
loss = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.text_config.hidden_size,
**lm_kwargs,
)

else:
logits = self.lm_head(kept_hidden_states)
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **lm_kwargs
)

if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output

return SmolVLMCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
)
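
For orientation, here is a rough, unfused reference for what the `skip_logits` branch replaces. This is only an illustration of the equivalent computation under the usual causal-LM shift-by-one convention, not the Liger implementation: `LigerForCausalLMLoss` fuses the `lm_head` projection with the cross entropy (processing tokens in chunks) so the full `[batch, seq, vocab]` logits tensor is never materialized.

```python
import torch
import torch.nn.functional as F


def unfused_reference_loss(
    hidden_states: torch.Tensor,   # [B, T, H] from the decoder
    lm_head_weight: torch.Tensor,  # [V, H]
    labels: torch.Tensor,          # [B, T]
    ignore_index: int = -100,
) -> torch.Tensor:
    """Unfused reference: materialize logits, shift by one, apply cross entropy."""
    logits = hidden_states @ lm_head_weight.T     # [B, T, V]; the tensor the fused path avoids
    shift_logits = logits[:, :-1, :].float()      # position t predicts token t + 1
    shift_labels = labels[:, 1:]
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=ignore_index,
    )
```

The fused path returns only the scalar loss, which is why `logits` stays `None` in the `skip_logits` branch and why `skip_logits` defaults to True in training mode when labels are present.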
89 changes: 86 additions & 3 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -2055,7 +2055,6 @@ def apply_liger_kernel_to_internvl(
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
@@ -2068,8 +2067,11 @@
from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward

if cross_entropy:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_internvl.nn.CrossEntropyLoss = LigerCrossEntropyLoss
logger.info("Apply liger cross entropy")

from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
Collaborator:
unintended changes in internvl?

Contributor Author:
Sorry for the confusion, I'm the author of the PR that added InternVL3 (#878) to liger-kernel, and this change is intended. I checked that, on the transformers version in which InternVL3 and SmolVLM2 were implemented, they no longer use CrossEntropyLoss inside modeling_*.py and instead use loss_utils, so I fixed the InternVL3 implementation together with this change.

Contributor Author:
If you'd prefer, I can separate this into another PR.

Collaborator (@Tcc0403, Oct 26, 2025):
Let's separate it into another PR and open an issue to track if there's any other cross entropy not being patched correctly!

Collaborator:
Issue #920

if fused_linear_cross_entropy:
modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
if rms_norm:
@@ -2112,6 +2114,86 @@ def apply_liger_kernel_to_internvl(
logger.warning(f"{vision_model_name} is not supported by Liger kernel.")


def apply_liger_kernel_to_smolvlm(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
model: Optional[PreTrainedModel] = None,
**kwargs,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
Due to the structure of SmolVLM, the model instance must be passed in so that Liger-Kernel's patches can also be applied to the models connected to SmolVLM (its text and vision submodels).
However, if a language model that Liger-Kernel does not support is connected to SmolVLM, unexpected side effects may occur.
NOTE: SmolVLM is not available in transformers<4.50.0

Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)

from transformers.models.smolvlm import modeling_smolvlm

from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward

if cross_entropy:
logger.info("Apply liger cross entropy")

from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
if rms_norm:
modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm

if model is not None:
text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)

kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
if text_liger_fn:
accept_params = inspect.signature(text_liger_fn).parameters
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}

if remain_params:
logger.warning(
f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
)
text_kwargs["model"] = model.model.text_model
text_liger_fn(**text_kwargs)
elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.warning(f"{text_model_name} is not supported by Liger kernel.")

if vision_liger_fn:
accept_params = inspect.signature(vision_liger_fn).parameters
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}

if remain_params:
logger.warning(
f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
)
vision_kwargs["model"] = model.model.vision_model
vision_liger_fn(**vision_kwargs)
elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
Collaborator:
The default SmolVLMForConditionalGeneration.vision_model is SmolVLMVisionTransformer

WARNING  liger_kernel.transformers.monkey_patch:monkey_patch.py:2194 smolvlm_vision is not supported by Liger kernel.

Do you want to support it in this PR as well? (SmolVLMVisionMLP and LayerNorm)

Contributor Author:
SmolVLMVisionMLP doesn't seem easy to patch with a Liger kernel, but LayerNorm seems possible, so I patched it!
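
To tie the pieces together, a minimal usage sketch (not part of the diff; the checkpoint id is borrowed from the docstring example in smolvlm.py above). Passing the loaded model lets the function also dispatch to the Liger patch functions registered for the attached text and vision submodels:

```python
import torch
from transformers import AutoModelForImageTextToText

from liger_kernel.transformers import apply_liger_kernel_to_smolvlm

model = AutoModelForImageTextToText.from_pretrained(
    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
    dtype=torch.bfloat16,
    device_map="auto",
)

# Defaults: fused_linear_cross_entropy=True (installs the lce forward above),
# rms_norm=True. Passing model= additionally routes rms_norm to the connected
# text/vision models when they appear in MODEL_TYPE_TO_APPLY_LIGER_FN.
apply_liger_kernel_to_smolvlm(model=model)
```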



def apply_liger_kernel_to_falcon_h1(
rope: bool = True,
cross_entropy: bool = False,
@@ -2304,6 +2386,7 @@ def apply_liger_kernel_to_qwen3_next(
"phi3": apply_liger_kernel_to_phi3,
"paligemma": apply_liger_kernel_to_paligemma,
"falcon_h1": apply_liger_kernel_to_falcon_h1,
"smolvlm": apply_liger_kernel_to_smolvlm,
}
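
Registering "smolvlm" in this mapping is what lets the generic, model-type-based dispatch pick up the new patch function. A small sanity check (an illustration, not from the PR):

```python
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

# The new entry routes model_type == "smolvlm" to apply_liger_kernel_to_smolvlm.
assert MODEL_TYPE_TO_APPLY_LIGER_FN["smolvlm"].__name__ == "apply_liger_kernel_to_smolvlm"
```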

