Commit c5f8dfa

Documentation for the kernel API (#46754)
* doc: initial doc commit
* doc: add missing content
* feat: infer forward pass automatically
* doc: improve doc
* doc: address comments from Steven
1 parent 612c371 commit c5f8dfa

4 files changed

Lines changed: 204 additions & 0 deletions

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -117,6 +117,8 @@
117117
title: Kernels
118118
- local: kernel_doc/loading_kernels
119119
title: Loading kernels
120+
- local: kernel_doc/writing_kernels
121+
title: Writing kernels
120122
title: Kernels
121123
- local: perf_torch_compile
122124
title: torch.compile

docs/source/en/kernel_doc/loading_kernels.md

Lines changed: 25 additions & 0 deletions
@@ -171,6 +171,31 @@ kernel_config = KernelConfig(
171171
)
172172
```
173173

174+
## Module fusion
175+
176+
Fuse adjacent modules into a single kernel by passing a tuple of `(class_name, path_pattern)` pairs as the key in [`KernelConfig`]. All patterns must share the same parent module. `*` matches any single path segment.
177+
178+
```python
from transformers import AutoModelForCausalLM, KernelConfig

kernel_config = KernelConfig(
    {
        (
            ("RMSNorm", "model.layers.*.post_attention_layernorm"),
            ("MLP", "model.layers.*.mlp"),
        ): "owner/fused-rmsnorm-mlp:RMSNormMLP",
    }
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    use_kernels=True,
    kernel_config=kernel_config,
    device_map="cuda",
)
```
196+
197+
Fusion requires the kernel repo to provide a companion `KernelNameLayout` class alongside the `KernelName` class. See the [Writing kernels](./writing_kernels) guide for how to implement one.
198+
174199
## Local kernels
175200

176201
Load kernels from local file paths with `use_local_kernel=True` in [`KernelConfig`]. This loads from a local filesystem path instead of a Hub repository.

docs/source/en/kernel_doc/writing_kernels.md

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
1+
<!--Copyright 2026 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
12+
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
13+
rendered properly in your Markdown viewer.
14+
15+
-->
16+
17+
# Writing kernels
18+
19+
This guide explains how to write kernels that go beyond a stateless `forward` replacement. It covers two capabilities the extended `KernelConfig` API supports:
20+
21+
1. Parameter transformation: the kernel expects weights in a different layout than the original model (for example, renamed or merged parameters).
22+
2. Module fusion: the kernel replaces multiple adjacent modules with a single fused implementation.
23+
24+
For basic kernels (stateless `forward` replacements with no parameter changes), see the [kernels](https://github.com/huggingface/kernels) library documentation.
25+
26+
## Two-class pattern
27+
28+
Any kernel that carries its own parameters follows a two-class pattern.
29+
30+
- `KernelName`: contains only the `forward` pass. The `kernels` library uses this class to kernelize the model because it does not allow stateful kernel classes.
31+
- `KernelNameLayout`: an `nn.Module` that holds the parameters and monkey-patches the original module before the checkpoint is loaded. At runtime, `kernelize` replaces its `forward` with the `forward` from `KernelName`. You do not need to define `forward` on the layout class: Transformers injects a placeholder automatically with the same signature as `KernelName.forward`.
32+
33+
> [!IMPORTANT]
> The naming convention is strict. The layout class must be named `{KernelName}Layout` and defined in the same module as `KernelName`.
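
The lookup and the injected placeholder can be illustrated with a small, torch-free sketch. It follows the shape of the check this commit adds to `hub_kernels.py` (a `getattr` on `f"{kernel_cls.__name__}Layout"` followed by a `functools.wraps` placeholder). The `find_layout` helper, the toy classes, and the `SimpleNamespace` standing in for the kernel's module are hypothetical names used only for illustration, not part of Transformers or `kernels`.

```python
import functools
import inspect
from types import SimpleNamespace


class ToyKernel:
    """Stateless class: only defines the forward pass."""

    def forward(self, hidden_states, scale_factor=1.0):
        return [h * scale_factor for h in hidden_states]


class ToyKernelLayout:
    """Companion class: holds state and deliberately defines no forward."""

    def __init__(self):
        self.weight = [1.0, 1.0]


# Assumption: this namespace stands in for the Python module of the kernel repo.
kernel_mod = SimpleNamespace(ToyKernel=ToyKernel, ToyKernelLayout=ToyKernelLayout)


def find_layout(kernel_cls, module):
    """Hypothetical helper: resolve `{KernelName}Layout` next to the kernel class."""
    layout_cls = getattr(module, f"{kernel_cls.__name__}Layout", None)
    if layout_cls is not None and "forward" not in layout_cls.__dict__:

        # Placeholder that copies the name, docstring and signature of the real forward.
        @functools.wraps(kernel_cls.forward)
        def _noop_forward(self, *args, **kwargs):
            pass

        layout_cls.forward = _noop_forward
    return layout_cls


layout_cls = find_layout(ToyKernel, kernel_mod)
same_signature = inspect.signature(layout_cls.forward) == inspect.signature(ToyKernel.forward)
```

Because `functools.wraps` sets `__wrapped__`, `inspect.signature` reports the kernel's signature for the placeholder, which is presumably what lets later tooling treat the layout class as if it already had the real `forward`.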
36+
37+
## Parameter transformation
38+
39+
Use this pattern when the kernel expects weights under different names or in a different shape than the original model checkpoint.
40+
41+
The `KernelNameLayout` class has the same `__init__` signature as the module it replaces and declares a `conversion_mapping` class attribute that tells Transformers how to remap checkpoint keys to the new parameter names (see [Dynamic weight loading](../weightconverter) for more details).
42+
43+
```python
import torch
import torch.nn as nn

class CustomRMSNormLayout(nn.Module):
    conversion_mapping = [...]  # rules that remap checkpoint keys to the new parameter names

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps


class CustomRMSNorm(nn.Module):
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.scale * hidden_states.to(input_dtype)


class layers:
    CustomRMSNorm = CustomRMSNorm
```
68+
69+
> [!NOTE]
70+
> The `layers` class is required by the `kernels` library to expose the kernel entry point.
71+
72+
Load this kernel by passing the repo and class name to [`KernelConfig`]. The key is the original module class name from the model. The value points to the `KernelName` class (not the `Layout`) in the repo.
73+
74+
```python
from transformers import AutoModelForCausalLM, KernelConfig

kernel_config = KernelConfig({"RMSNorm": "owner/my-kernel:CustomRMSNorm"})
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    use_kernels=True,
    kernel_config=kernel_config,
    device_map="cuda",
)
```
85+
86+
When the model loads, Transformers:
87+
1. Loads `CustomRMSNorm` from the repo and looks for `CustomRMSNormLayout` in the same module.
88+
2. Monkey-patches every `RMSNorm` in the model with `CustomRMSNormLayout`.
89+
3. Remaps checkpoint weights using `conversion_mapping` so they load into the new parameter names.
90+
4. Calls `kernelize`, which replaces `CustomRMSNormLayout.forward` with `CustomRMSNorm.forward`.
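
The last step is why a `forward` written on a stateless class can read `self.scale` and `self.variance_epsilon`: once rebound, it runs on the layout instance that owns them. A minimal, torch-free sketch of that idea follows. It assumes a plain class-attribute assignment as a stand-in for what `kernelize` does, which is a simplification, and the `Toy*` classes are hypothetical.

```python
import math


class ToyRMSNormLayout:
    """Stands in for a layout class: owns the parameters."""

    def __init__(self, hidden_size, eps=1e-6):
        self.scale = [1.0] * hidden_size
        self.variance_epsilon = eps


class ToyRMSNorm:
    """Stands in for the kernel class: forward only, no state of its own."""

    def forward(self, hidden_states):
        variance = sum(h * h for h in hidden_states) / len(hidden_states)
        inv_rms = 1.0 / math.sqrt(variance + self.variance_epsilon)
        return [s * h * inv_rms for s, h in zip(self.scale, hidden_states)]


# Assumption: a simple attribute assignment models the rebinding step.
ToyRMSNormLayout.forward = ToyRMSNorm.forward

module = ToyRMSNormLayout(hidden_size=2)
output = module.forward([3.0, 4.0])  # `self` is the layout instance here
```

Calling `ToyRMSNorm().forward(...)` directly would fail with an `AttributeError`, since the kernel class on its own has no `scale`; the state only exists on the layout.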
91+
92+
## Module fusion
93+
94+
Use this pattern when a kernel replaces multiple adjacent modules with a single fused implementation. Because the fused module combines parameters from several original modules, `KernelNameLayout.__init__` receives the instantiated child modules rather than their constructor arguments.
95+
96+
```python
import torch
import torch.nn as nn

class RMSNormMLPLayout(nn.Module):
    conversion_mapping = [...]  # rules that remap checkpoint keys to the fused parameter names

    def __init__(self, norm, mlp):
        super().__init__()
        self.variance_epsilon = norm.variance_epsilon
        self.scale = nn.Parameter(torch.empty_like(norm.weight))
        self.gate_up_proj = nn.Linear(
            mlp.gate_proj.in_features,
            mlp.gate_proj.out_features + mlp.up_proj.out_features,
            bias=mlp.gate_proj.bias is not None,
            device=mlp.gate_proj.weight.device,
            dtype=mlp.gate_proj.weight.dtype,
        )
        self.down_proj = nn.Linear(
            mlp.down_proj.in_features,
            mlp.down_proj.out_features,
            bias=mlp.down_proj.bias is not None,
            device=mlp.down_proj.weight.device,
            dtype=mlp.down_proj.weight.dtype,
        )
        self.act_fn = mlp.act_fn


class RMSNormMLP(nn.Module):
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.scale * hidden_states.to(input_dtype)
        gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
        return self.down_proj(self.act_fn(gate) * up)


class layers:
    RMSNormMLP = RMSNormMLP
```
138+
139+
To fuse modules, pass a tuple of `(class_name, path_pattern)` pairs as the key in `KernelConfig` instead of a plain string. All patterns must share the same parent module (Transformers fuses the children in that parent). The `*` wildcard matches any single path segment.
140+
141+
```python
from transformers import AutoModelForCausalLM, KernelConfig

kernel_config = KernelConfig(
    {
        (
            ("RMSNorm", "model.layers.*.post_attention_layernorm"),
            ("MLP", "model.layers.*.mlp"),
        ): "owner/my-kernel:RMSNormMLP",
    }
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    use_kernels=True,
    kernel_config=kernel_config,
    device_map="cuda",
)
```
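
To make the wildcard and shared-parent rules above concrete, here is a rough sketch of how a `*` that matches exactly one path segment could behave. The regex translation and the `pattern_to_regex` and `parent_of` helpers are assumptions made for illustration; the actual matching logic in Transformers may be implemented differently.

```python
import re


def pattern_to_regex(path_pattern):
    """Hypothetical helper: `*` matches exactly one dot-separated path segment."""
    parts = [r"[^.]+" if part == "*" else re.escape(part) for part in path_pattern.split(".")]
    return re.compile(r"\.".join(parts))


def parent_of(path):
    """Everything before the last dot is treated as the parent module path."""
    return path.rpartition(".")[0]


patterns = ["model.layers.*.post_attention_layernorm", "model.layers.*.mlp"]
module_paths = [
    "model.layers.0.post_attention_layernorm",
    "model.layers.0.mlp",
    "model.layers.0.mlp.gate_proj",  # one segment too deep, so it should not match
    "model.layers.1.mlp",
    "model.norm",
]

matches = {
    pattern: [path for path in module_paths if pattern_to_regex(pattern).fullmatch(path)]
    for pattern in patterns
}

# Both patterns reduce to the same parent pattern, so their targets are siblings.
shared_parents = {parent_of(pattern) for pattern in patterns}
```

With these inputs, `model.layers.0.mlp.gate_proj` is rejected because `*` does not span multiple segments, and both patterns share the parent `model.layers.*`.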
159+
160+
When the model loads, Transformers:
161+
1. Loads `RMSNormMLP` from the repo and finds `RMSNormMLPLayout` in the same module.
162+
2. Matches every decoder layer at `model.layers.*` and builds a fused parent class whose `__init__` calls `RMSNormMLPLayout(post_attention_layernorm, mlp)`.
163+
3. Replaces the remaining child (`mlp`) with `nn.Identity()` to preserve the parent module's interface.
164+
4. Remaps checkpoint weights using `conversion_mapping`.
165+
5. Calls `kernelize`, which replaces `RMSNormMLPLayout.forward` with `RMSNormMLP.forward`.
166+
167+
> [!TIP]
168+
> The order of pairs in the fusion tuple determines the argument order passed to `KernelNameLayout.__init__`.
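
The effect of pair order can be shown with a toy stand-in. In this sketch, `build_fused` is a hypothetical helper that looks up each child by the last segment of its path pattern and passes the children positionally; it only illustrates the ordering rule and is not how Transformers constructs the fused parent.

```python
from types import SimpleNamespace


class ToyFusedLayout:
    """Hypothetical layout: records which child arrived in which position."""

    def __init__(self, norm, mlp):
        self.norm_name = norm.name
        self.mlp_name = mlp.name


def build_fused(parent, fusion_key, layout_cls):
    """Hypothetical helper: pass children to the layout in the order the tuple lists them."""
    children = [getattr(parent, path_pattern.rsplit(".", 1)[-1]) for _, path_pattern in fusion_key]
    return layout_cls(*children)


# A toy decoder layer with two named children.
decoder_layer = SimpleNamespace(
    post_attention_layernorm=SimpleNamespace(name="post_attention_layernorm"),
    mlp=SimpleNamespace(name="mlp"),
)
fusion_key = (
    ("RMSNorm", "model.layers.*.post_attention_layernorm"),
    ("MLP", "model.layers.*.mlp"),
)

fused = build_fused(decoder_layer, fusion_key, ToyFusedLayout)
swapped = build_fused(decoder_layer, tuple(reversed(fusion_key)), ToyFusedLayout)
```

In the swapped case the MLP lands in the `norm` parameter, which is the kind of silent mix-up the tip warns about: the tuple order has to agree with the layout's `__init__` signature.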

src/transformers/integrations/hub_kernels.py

Lines changed: 9 additions & 0 deletions
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import functools
1415
import os
1516
import re
1617
import sys
@@ -765,6 +766,14 @@ def register_kernel_replacements_and_fusions(
         kernel_mod = sys.modules.get(kernel_cls.__module__)
         layout_cls = getattr(kernel_mod, f"{kernel_cls.__name__}Layout", None) if kernel_mod else None
 
+        if layout_cls is not None and "forward" not in layout_cls.__dict__:
+
+            @functools.wraps(kernel_cls.forward)
+            def _noop_forward(self, *args, **kwargs):
+                pass
+
+            layout_cls.forward = _noop_forward
+
         # Case 1: no fusion.
         if isinstance(layer_name, str):
             # No layout class: stateless kernel, leave for kernels.kernelize.
