
Commit 10b6cde

Authored by Manan17 (Manan Shah), vaibhavjindal, and lancerts
Llama4 rope implementation (#843)
## Summary

This PR implements a fused Triton kernel for Llama4's complex-polar RoPE, providing significant performance and memory improvements over the original PyTorch implementation.

- Prepares RoPE frequencies into `[seq_len, head_dim_half]` real/imag tables from complex or split inputs, robust to extra leading dims (text/vision).
- Launches a 2D grid over `(batch*seq, head)` and tiles over `head_dim/2`, maximizing parallelism and cache reuse.
- Operates on interleaved real/imag pairs directly, avoiding complex views and extra buffers.
- Applies the complex rotation with FMAs per tile; loads each frequency tile once and reuses it across heads in the program.
- Uses conditional dtype/contiguity handling (fp32 only when inputs are fp32; otherwise bf16/fp16) to cut copies and autograd overhead.
- Backward reuses the same kernel with `imag_sign = -1` (conjugation via a sign flip), avoiding building conjugated tensors.
- Fused operations:
  - Fuses the Q and K rotations into a single kernel launch.
  - Fuses load → complex multiply → store for each tile (no separate ops or temporary complex tensors).
  - Fuses per-head processing into the grid mapping (no serial head loops).
  - Fuses conjugation handling into the kernel via `imag_sign` instead of a separate conj step.

We cannot reuse the existing RoPE kernel here because the Llama4 RoPE implementation performs its rotation with complex arithmetic on real/imaginary pairs (i.e. via `torch.view_as_complex`). A minimal sketch of the core rotation is shown after this description.

## Testing Done

- `python -m pytest test/convergence/bf16/test_mini_models.py -k llama4`
- `python -m pytest test/convergence/fp32/test_mini_models.py -k llama4`
- `python -m pytest test/convergence/bf16/test_mini_models_multimodal.py -k llama4`

Benchmarks:

- Memory: ~60% reduction in memory (plot: `llama4_rope_memory_full`)
- Speed, forward (plot: `llama4_rope_speed_forward`)
- Speed, backward (plot: `llama4_rope_speed_backward`)
- Speed, full (plot: `llama4_rope_speed_full`)

- Hardware Type: H100
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Manan Shah <[email protected]>
Co-authored-by: Vaibhav Jindal <[email protected]>
Co-authored-by: Shao Tang <[email protected]>
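The core idea can be illustrated with a small Triton sketch. This is not the PR's kernel: the names `_rope_complex_kernel` and `rotate_inplace`, the cos/sin table layout, and the in-place update are all illustrative assumptions, and the real implementation additionally fuses Q and K into one launch and handles dtype/contiguity as described above. The sketch shows one program per `(token, head)` loading a single frequency tile, rotating interleaved `(real, imag)` pairs with FMAs, and reusing the same code path for backward via `imag_sign`.

```python
# Minimal sketch of the fused complex-polar rotation (illustrative, not the PR's kernel).
import torch
import triton
import triton.language as tl


@triton.jit
def _rope_complex_kernel(
    x_ptr,      # [n_tokens, n_heads, head_dim]; last dim = interleaved (re, im) pairs
    cos_ptr,    # [n_tokens, head_dim // 2], real part of freqs_cis
    sin_ptr,    # [n_tokens, head_dim // 2], imag part of freqs_cis
    n_heads,
    imag_sign,  # +1.0 forward, -1.0 backward (conjugate rotation)
    HEAD_DIM: tl.constexpr,
    BLOCK: tl.constexpr,
):
    token = tl.program_id(0)  # flattened batch * seq position
    head = tl.program_id(1)
    half = HEAD_DIM // 2
    offs = tl.arange(0, BLOCK)
    mask = offs < half

    # One frequency tile per program, shared by every pair of this (token, head).
    cos = tl.load(cos_ptr + token * half + offs, mask=mask, other=0.0).to(tl.float32)
    sin = tl.load(sin_ptr + token * half + offs, mask=mask, other=0.0).to(tl.float32) * imag_sign

    base = x_ptr + (token * n_heads + head) * HEAD_DIM
    re = tl.load(base + 2 * offs, mask=mask, other=0.0).to(tl.float32)
    im = tl.load(base + 2 * offs + 1, mask=mask, other=0.0).to(tl.float32)

    # (re + i*im) * (cos + i*sin): load -> complex multiply -> store, fused per tile.
    out_re = re * cos - im * sin
    out_im = re * sin + im * cos
    tl.store(base + 2 * offs, out_re.to(x_ptr.dtype.element_ty), mask=mask)
    tl.store(base + 2 * offs + 1, out_im.to(x_ptr.dtype.element_ty), mask=mask)


def rotate_inplace(x: torch.Tensor, freqs_cis: torch.Tensor, imag_sign: float = 1.0) -> torch.Tensor:
    """Rotate x of shape [batch, seq, heads, head_dim] in place by complex freqs_cis.

    Assumes x is contiguous and freqs_cis holds one complex value per (token, pair).
    """
    b, s, h, d = x.shape
    fc = torch.view_as_real(freqs_cis.reshape(b * s, d // 2))  # [..., 2] = (cos, sin)
    xf = x.view(b * s, h, d)
    grid = (b * s, h)  # 2D grid over (batch*seq, head); no serial head loop
    _rope_complex_kernel[grid](
        xf,
        fc[..., 0].contiguous(),
        fc[..., 1].contiguous(),
        h,
        imag_sign,
        HEAD_DIM=d,
        BLOCK=triton.next_power_of_2(d // 2),
    )
    return x
```

In this sketch the backward pass would simply call the same launcher on the incoming gradients with `imag_sign=-1.0`, which multiplies by the conjugate `cos − i·sin` without ever materializing a conjugated frequency tensor.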
1 parent 65c0ad1 commit 10b6cde

File tree: 9 files changed, +643 −7 lines


benchmark/data/all_benchmark_data.csv — 65 additions, 1 deletion (large diff not rendered)
Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
import torch
import triton

from transformers.models.llama4.configuration_llama4 import Llama4TextConfig
from transformers.models.llama4.modeling_llama4 import Llama4TextRotaryEmbedding
from transformers.models.llama4.modeling_llama4 import apply_rotary_emb
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb
from liger_kernel.utils import infer_device
from liger_kernel.utils import transformers_version_dispatch

device = infer_device()


def bench_speed_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len
    hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
    seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x

    head_dim = hidden_size // num_q_heads

    # Create Llama4TextConfig for the rotary embedding
    config = Llama4TextConfig(
        hidden_size=hidden_size,
        num_attention_heads=num_q_heads,
        num_key_value_heads=num_kv_heads,
        head_dim=head_dim,
        max_position_embeddings=seq_len,
        rope_theta=10000.0,
        rope_scaling=None,  # Use default rope type
    )

    rotary_emb = transformers_version_dispatch(
        "4.48.0",
        Llama4TextRotaryEmbedding,
        Llama4TextRotaryEmbedding,
        before_kwargs={"config": config, "device": device},
        after_kwargs={"config": config, "device": device},
    )

    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    )
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    )
    dq, dk = (
        torch.randn_like(q, device=device, dtype=dtype),
        torch.randn_like(k, device=device),
    )
    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    freqs_cis = rotary_emb(q, pos_ids)

    def fwd():
        if provider == "liger":
            return liger_llama4_text_rotary_pos_emb(q, k, freqs_cis)
        elif provider == "huggingface":
            return apply_rotary_emb(q, k, freqs_cis)
        else:
            raise ValueError(f"Invalid provider: {provider} for Llama4 RoPE embedding")

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        q_out, k_out = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True),
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            q_out, k_out = fwd()
            torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    provider = input.kernel_provider

    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len
    hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
    seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x

    head_dim = hidden_size // num_q_heads

    # Create Llama4TextConfig for the rotary embedding
    config = Llama4TextConfig(
        hidden_size=hidden_size,
        num_attention_heads=num_q_heads,
        num_key_value_heads=num_kv_heads,
        head_dim=head_dim,
        max_position_embeddings=seq_len,
        rope_theta=10000.0,
        rope_scaling=None,  # Use default rope type
    )

    rotary_emb = transformers_version_dispatch(
        "4.48.0",
        Llama4TextRotaryEmbedding,
        Llama4TextRotaryEmbedding,
        before_kwargs={"config": config, "device": device},
        after_kwargs={"config": config, "device": device},
    )

    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    )
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    )
    dq, dk = (
        torch.randn_like(q, device=device, dtype=dtype),
        torch.randn_like(k, device=device),
    )
    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    freqs_cis = rotary_emb(q, pos_ids)

    def full():
        if provider == "liger":
            q_out, k_out = liger_llama4_text_rotary_pos_emb(q, k, freqs_cis)
        else:
            q_out, k_out = apply_rotary_emb(q, k, freqs_cis)
        torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(
        full,
        quantiles=QUANTILES,
    )
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs_varying_hidden_size = {
        "kernel_name": "llama4_rope",
        "x_name": "H",
        "x_label": "hidden size",
        "x_values": [32 * (2**i) for i in range(4, 10, 2)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "seq_len": 2048,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_llama4_rope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_hidden_size,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_llama4_rope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_hidden_size,
    )

    common_configs_varying_seq_len = {
        "kernel_name": "llama4_rope",
        "x_name": "T",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(10, 15)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "hidden_size": 8192,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_llama4_rope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_seq_len,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_llama4_rope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_seq_len,
    )
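For context, the `huggingface` provider above goes through `apply_rotary_emb`, which performs the same rotation in complex arithmetic. The sketch below is only an approximation of that baseline (the function name `apply_rotary_emb_reference` is illustrative, not the exact `transformers` source): it upcasts to fp32, builds complex views, multiplies by `freqs_cis`, and flattens back, which is where the temporary buffers and copies that the fused kernel avoids come from.

```python
import torch


def apply_rotary_emb_reference(xq, xk, freqs_cis):
    # Approximate baseline: rotate interleaved (re, im) pairs via complex views.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    fc = freqs_cis[..., None, :]  # broadcast the complex frequencies over heads
    xq_out = torch.view_as_real(xq_ * fc).flatten(-2)
    xk_out = torch.view_as_real(xk_ * fc).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```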
