Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 57 additions & 24 deletions src/transformers/integrations/hub_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,80 +89,108 @@ def use_kernel_func_from_hub(func_name: str):
"cuda": LayerRepository(
repo_id="kernels-community/deformable-detr",
layer_name="MultiScaleDeformableAttention",
version=1,
)
},
"Llama4TextMoe": {
"cuda": LayerRepository(
repo_id="kernels-community/moe",
layer_name="Llama4TextMoe",
version=1,
)
},
"SwiGLUMLP": {
"cuda": LayerRepository(
repo_id="kernels-community/liger-kernels",
layer_name="LigerSwiGLUMLP",
version=1,
),
},
"RMSNorm": {
# TODO: training but not compile should work...
"cuda": {
#Mode.TRAINING: LayerRepository(
# repo_id="kernels-community/liger-kernels",
# layer_name="LigerRMSNorm",
# version=1,
#),
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/liger_kernels",
repo_id="kernels-community/liger-kernels",
layer_name="LigerRMSNorm",
# revision="pure-layer-test",
version=1,
),
},
"rocm": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-community/liger-kernels",
layer_name="LigerRMSNorm",
version=1,
),
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/liger_kernels",
repo_id="kernels-community/liger-kernels",
layer_name="LigerRMSNorm",
)
version=1,
),
},
"xpu": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/rmsnorm",
layer_name="RMSNorm",
version=1,
)
},
"mps": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/mlx_rmsnorm",
layer_name="RMSNorm",
version=1,
)
},
"npu": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-community/liger-kernels",
layer_name="LigerRMSNorm",
version=1,
),
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/liger_kernels",
repo_id="kernels-community/liger-kernels",
layer_name="LigerRMSNorm",
)
version=1,
),
},
},
"MLP": {
"cuda": LayerRepository(
repo_id="medmekk/triton-llama-mlp",
layer_name="TritonLlamaMLP",
)
},
"MegaBlocksMoeMLP": {
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-community/megablocks",
layer_name="MegaBlocksMoeMLP",
version=1,
),
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/megablocks",
layer_name="MegaBlocksMoeMLP",
version=1,
),
},
"rocm": {
Mode.INFERENCE: LayerRepository(
repo_id="ahadnagy/megablocks",
layer_name="MegaBlocksMoeMLP",
version=1,
)
},
"xpu": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/megablocks",
layer_name="MegaBlocksMoeMLP",
version=1,
)
},
"cpu": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/megablocks",
layer_name="CPUMegaBlocksMoeMLP",
version=1,
)
},
},
Expand Down Expand Up @@ -218,21 +246,26 @@ def use_kernel_func_from_hub(func_name: str):

# Add function kernel mappings if FuncRepository is available
if FuncRepository is not None:
_KERNEL_MAPPING["rotary_pos_emb"] = {
"xpu": {
Mode.INFERENCE: FuncRepository(
repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
)
},
"cuda": {
Mode.TRAINING: FuncRepository(
repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
),
Mode.INFERENCE: FuncRepository(
repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
_FUNCTION_KERNEL_MAPPING = {
"rotary_pos_emb": {
"xpu": {
Mode.INFERENCE: FuncRepository(
repo_id="kernels-community/rotary", func_name="apply_rotary_transformers", version=1
)
},
"cuda": FuncRepository(
repo_id="kernels-community/rotary", func_name="apply_rotary_transformers", version=1
),
},
"ForCausalLMLoss": {
"cuda": {
Mode.TRAINING | Mode.TORCH_COMPILE: FuncRepository(
repo_id="kernels-community/liger-kernels", func_name="LigerForCausalLMLoss", version=1
),
},
},
}
_KERNEL_MAPPING = _KERNEL_MAPPING | _FUNCTION_KERNEL_MAPPING

def has_key(d, key):
    """Return whether `key` is a key of `d` or of any dict nested in its values.

    The search is recursive, but only descends into values that are themselves
    `dict` instances; other containers (lists, tuples, ...) are not inspected.

    Args:
        d: The (possibly nested) dictionary to search.
        key: The key to look for at any nesting depth.

    Returns:
        `True` if `key` is found at the top level or in a nested dict, else `False`.
    """
    # Fast path: the key is present at this level.
    if key in d:
        return True
    # Otherwise walk the values and recurse into nested dictionaries only.
    for value in d.values():
        if isinstance(value, dict) and has_key(value, key):
            return True
    return False
Expand Down
13 changes: 10 additions & 3 deletions src/transformers/loss/loss_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, MSELoss

from ..integrations import use_kernel_func_from_hub
from .loss_d_fine import DFineForObjectDetectionLoss
from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss
from .loss_deimv2 import Deimv2ForObjectDetectionLoss
Expand Down Expand Up @@ -44,23 +45,29 @@ def fixed_cross_entropy(
return loss


@use_kernel_func_from_hub("ForCausalLMLoss")
def ForCausalLMLoss(
logits,
labels,
vocab_size: int,
num_items_in_batch: torch.Tensor | None = None,
ignore_index: int = -100,
shift_labels: torch.Tensor | None = None,
hidden_states: torch.Tensor | None = None,
lm_head_weight: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()

if shift_labels is None:
# Shift so that tokens < n predict n
labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
shift_labels = labels[..., 1:].contiguous()

if logits is None and hidden_states is not None and lm_head_weight is not None:
logits = torch.nn.functional.linear(hidden_states, lm_head_weight)

# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()

# Flatten the tokens
logits = logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
Expand Down
17 changes: 11 additions & 6 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@
from accelerate.utils import extract_model_from_parallel

if TYPE_CHECKING:
from kernels.layer.mode import Mode

from ._typing import DeviceMeshLike


Expand Down Expand Up @@ -3746,19 +3748,22 @@ def _get_dtype_plan(self, dtype: torch.dtype) -> dict:

return dtype_plan

def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None):
def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None, mode: "Mode | None" = None):
"""
Set whether or not to use the `kernels` library to kernelize some layers of the model.
Args:
use_kernels (`bool`):
Whether or not to use the `kernels` library to kernelize some layers of the model.
kernel_config (`KernelConfig`, *optional*):
The kernel configuration to use to kernelize the model. If `None`, the default kernel mapping will be used.
mode (`Mode`, *optional*):
The mode that should be applied during `kernelize`. Optional, defaults to either training or inference mode
based on the internal `training` flag.
"""
if use_kernels:
if not is_kernels_available():
raise ValueError(
"`use_kernels=True` requires kernels>=0.9.0. Please install the latest version with `pip install -U kernels`"
"Kernels are not available. To use kernels, please install kernels using `pip install -U kernels`"
)
from kernels import use_kernel_mapping

Expand All @@ -3778,10 +3783,10 @@ def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None
# Param inherit_mapping should be False to avoid still loading kernel from remote
inherit_mapping = not kernel_config.use_local_kernel
with use_kernel_mapping(kernel_config.kernel_mapping, inherit_mapping=inherit_mapping):
self.use_kernels = True
self.kernelize(mode=mode)
# We use the default kernel mapping in .integrations.hub_kernels
else:
self.use_kernels = True
self.kernelize(mode=mode)
else:
self.use_kernels = False

Expand Down Expand Up @@ -4592,7 +4597,7 @@ def loss_function(self):
def loss_function(self, value):
self._loss_function = value

def kernelize(self, mode=None):
def kernelize(self, mode: "Mode | None" = None):
"""Temporarily register hidden kernel wrappers so `kernelize` can discover and replace them."""
if not is_kernels_available():
raise ValueError(
Expand Down Expand Up @@ -4621,7 +4626,7 @@ def detach_hidden_kernels(module):
try:
self.apply(attach_hidden_kernels)

mode = Mode.INFERENCE if not self.training else Mode.TRAINING if mode is None else mode
mode = Mode.TRAINING if self.training else Mode.INFERENCE if mode is None else mode
kernelize(self, device=Device(type=self.device.type), mode=mode)
self._use_kernels = True

Expand Down
18 changes: 16 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
from ...loss.loss_utils import ForCausalLMLoss
from ...masking_utils import create_causal_mask
from ...modeling_layers import (
GenericForQuestionAnswering,
Expand Down Expand Up @@ -168,6 +169,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
return q_embed, k_embed


@use_kernel_forward_from_hub("SwiGLUMLP")
class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
Expand Down Expand Up @@ -425,6 +427,7 @@ def forward(
)


@use_kernelized_func(ForCausalLMLoss)
@auto_docstring
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
Expand Down Expand Up @@ -484,11 +487,22 @@ def forward(
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
hidden_states = hidden_states[:, slice_indices, :]

# We only compute logits during inference (no labels given) or when we do not use any kernel (materialization of logits needed)
logits = self.lm_head(hidden_states) if labels is None or not self.use_kernels else None

loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
loss = self.loss_function(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
hidden_size=hidden_states.shape[-1],
**kwargs,
)

return CausalLMOutputWithPast(
loss=loss,
Expand Down
Loading