Skip to content

Commit aa310e0

Browse files
dsv4: use fused hc head in nextn
1 parent 327ab15 commit aa310e0

2 files changed

Lines changed: 82 additions & 0 deletions

File tree

python/sglang/srt/models/deepseek_v4_nextn.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,17 @@ def hc_head(
117117
hc_scale: torch.Tensor,
118118
hc_base: torch.Tensor,
119119
):
120+
if x.numel() > 0:
121+
from sglang.srt.layers.mhc_head import fused_hc_head
122+
123+
return fused_hc_head(
124+
x.contiguous(),
125+
hc_fn,
126+
hc_scale,
127+
hc_base,
128+
norm_eps=self.rms_norm_eps,
129+
hc_eps=self.hc_eps,
130+
)
120131
shape, dtype = x.size(), x.dtype
121132
x = x.flatten(1).float()
122133
rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.rms_norm_eps)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import unittest
2+
3+
import torch
4+
5+
from sglang.srt.models.deepseek_v4_nextn import DeepseekV4ModelNextN
6+
from sglang.test.ci.ci_register import register_cuda_ci
7+
8+
9+
register_cuda_ci(est_time=5, stage="base-b", runner_config="1-gpu-large")
10+
11+
12+
def _reference_hc_head(
13+
x: torch.Tensor,
14+
hc_fn: torch.Tensor,
15+
hc_scale: torch.Tensor,
16+
hc_base: torch.Tensor,
17+
norm_eps: float,
18+
hc_eps: float,
19+
) -> torch.Tensor:
20+
shape, dtype = x.size(), x.dtype
21+
flat = x.flatten(1).float()
22+
rsqrt = torch.rsqrt(flat.square().mean(-1, keepdim=True) + norm_eps)
23+
mixes = torch.nn.functional.linear(flat, hc_fn) * rsqrt
24+
pre = torch.sigmoid(mixes * hc_scale + hc_base) + hc_eps
25+
out = torch.sum(pre.unsqueeze(-1) * flat.view(shape), dim=1)
26+
return out.to(dtype)
27+
28+
29+
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
30+
class TestDeepseekV4NextNHcHead(unittest.TestCase):
31+
def _run_case(self, tokens: int, hc_mult: int, hidden_size: int) -> None:
32+
torch.manual_seed(1234 + tokens + hidden_size)
33+
device = "cuda"
34+
dtype = torch.bfloat16
35+
norm_eps = 1.0e-6
36+
hc_eps = 1.0e-3
37+
38+
model = object.__new__(DeepseekV4ModelNextN)
39+
model.rms_norm_eps = norm_eps
40+
model.hc_eps = hc_eps
41+
42+
x = torch.randn(
43+
tokens, hc_mult, hidden_size, device=device, dtype=torch.float32
44+
).to(dtype)
45+
hc_fn = torch.randn(
46+
hc_mult,
47+
hc_mult * hidden_size,
48+
device=device,
49+
dtype=torch.float32,
50+
) * 0.02
51+
hc_scale = torch.randn(1, device=device, dtype=torch.float32)
52+
hc_base = torch.randn(hc_mult, device=device, dtype=torch.float32)
53+
54+
expected = _reference_hc_head(x, hc_fn, hc_scale, hc_base, norm_eps, hc_eps)
55+
actual = DeepseekV4ModelNextN.hc_head(model, x, hc_fn, hc_scale, hc_base)
56+
57+
self.assertEqual(actual.shape, (tokens, hidden_size))
58+
self.assertEqual(actual.dtype, dtype)
59+
torch.testing.assert_close(
60+
actual.float(), expected.float(), rtol=3.0e-2, atol=3.0e-2
61+
)
62+
63+
def test_nextn_hc_head_uses_fused_kernel_at_dsv4_shape(self):
64+
self._run_case(tokens=16, hc_mult=4, hidden_size=7168)
65+
66+
def test_nextn_hc_head_handles_empty_batch(self):
67+
self._run_case(tokens=0, hc_mult=4, hidden_size=256)
68+
69+
70+
if __name__ == "__main__":
71+
unittest.main()

0 commit comments

Comments
 (0)