
Commit cb747b3

Narsil and mht-sharma authored
Add deepseekv3 (#2968)
* Add fp8 support for MoE models; add DeepSeek-V3; format code; update Dockerfile; update docs
* Small modifications.
* Moe kernels 0.8.1
* Upgrade to 0.8.1
* Fixing moe import.
* Black.
* Apply suggestions from code review

Co-authored-by: Mohit Sharma <[email protected]>

* Fixing Mixtral + Nits.
* Put link to ref.
* Fix other call locations.
* Scoring func `softmax` is the only one that works.

---------

Co-authored-by: Mohit Sharma <[email protected]>
1 parent 80e7d98 commit cb747b3

16 files changed (+864, -26 lines)


Dockerfile_amd

Lines changed: 1 addition & 1 deletion
@@ -279,7 +279,7 @@ RUN git clone https://github.com/danieldk/marlin-kernels.git && \
 
 FROM kernel-builder AS moe-kernels
 WORKDIR /usr/src
-ENV MOE_KERNELS_BRANCH=a67b35841774b2056a73806c36661134b5054edd
+ENV MOE_KERNELS_BRANCH=d7e042bf9f7aff10c631212fc71b24895d66eb59
 ENV VLLM_TARGET_DEVICE=rocm
 RUN git clone https://github.com/danieldk/moe-kernels.git && \
     cd moe-kernels && \

docs/source/supported_models.md

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 Text Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported.
 
 - [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
+- [Deepseek V3](https://huggingface.co/deepseek-ai/DeepSeek-V3)
 - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
 - [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal)
 - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)

flake.lock

Lines changed: 4 additions & 4 deletions
Generated file; diff not rendered.

flake.nix

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:huggingface/text-generation-inference-nix/moe-kernels-0.8.0";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix/moe_0_8_1";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {

launcher/src/main.rs

Lines changed: 7 additions & 2 deletions
@@ -1635,6 +1635,7 @@ enum Gpu {
     A40,
     H100,
     A100,
+    H200,
     Unknown(String),
 }
 
@@ -1661,6 +1662,7 @@ impl From<&str> for Gpu {
             "nvidia-a100-sxm4-40gb" => Gpu::A100,
             "nvidia-a100-80gb-pcie" => Gpu::A100,
             "nvidia-a100" => Gpu::A100,
+            "nvidia-h200" => Gpu::H200,
             card => Gpu::Unknown(card.to_string()),
         }
     }
@@ -1678,6 +1680,7 @@ impl std::fmt::Display for Gpu {
             Gpu::A40 => write!(f, "nvidia-a40"),
             Gpu::H100 => write!(f, "nvidia-h100-80fb-hbm3"),
             Gpu::A100 => write!(f, "nvida-a100-sxm4-80gb"),
+            Gpu::H200 => write!(f, "nvida-h200"),
             Gpu::Unknown(card) => write!(f, "{}", card),
         }
     }
@@ -1702,11 +1705,13 @@ impl ComputeType {
             // https://www.nvidia.com/en-us/data-center/a40/
             // https://images.nvidia.com/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf
             Gpu::A40 => Some(149 * 10u64.pow(12)),
+            // https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
+            Gpu::A100 => Some(312 * 10u64.pow(12)),
             // https://www.nvidia.com/en-us/data-center/h100/
             // https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf
             Gpu::H100 => Some(900 * 10u64.pow(12)),
-            // https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
-            Gpu::A100 => Some(312 * 10u64.pow(12)),
+            // https://www.nvidia.com/en-us/data-center/h200/
+            Gpu::H200 => Some(989 * 10u64.pow(12)),
             Gpu::Unknown(card) => {
                 tracing::warn!("Unkown compute for card {card}");
                 None

router/src/config.rs

Lines changed: 2 additions & 0 deletions
@@ -224,6 +224,8 @@ pub enum Config {
     Qwen2,
     Opt,
     T5,
+    DeepseekV2,
+    DeepseekV3,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]

server/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ marlin-kernels = [
   { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl", marker = "python_version == '3.11'" },
   { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp312-cp312-linux_x86_64.whl", marker = "python_version == '3.12'" },
 ]
-moe-kernels.url = "https://github.com/danieldk/moe-kernels/releases/download/v0.8.0/moe_kernels-0.8.0+cu123torch2.5-cp39-abi3-linux_x86_64.whl"
+moe-kernels.url = "https://github.com/danieldk/moe-kernels/releases/download/v0.8.1/moe_kernels-0.8.1+cu123torch2.5-cp39-abi3-linux_x86_64.whl"
 
 [tool.pytest.ini_options]
 markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]

server/text_generation_server/layers/fp8.py

Lines changed: 74 additions & 2 deletions
@@ -19,6 +19,12 @@
 except ImportError:
     marlin_kernels = None
 
+try:
+    from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
+except ImportError:
+    w8a8_block_fp8_matmul = None
+    per_token_group_quant_fp8 = None
+
 quant_dtype: torch.dtype = (
     torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
 )
@@ -38,7 +44,6 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
     """
 
     if SYSTEM == "cuda":
-
         major, _ = torch.cuda.get_device_capability()
         # Marlin is W8A16, use it when:
         #
@@ -180,14 +185,29 @@ def fp8_quantize(
 class HybridFP8UnquantLoader(WeightsLoader):
     """Weight loader that loads FP8 and unquantized Torch tensors."""
 
-    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
+    def __init__(
+        self,
+        activation_scale_ub: Optional[float],
+        to_fp8: bool,
+        weight_block_size: Optional[List[int]] = None,
+    ):
         self.activation_scale_ub = activation_scale_ub
         self.to_fp8 = to_fp8
+        self.weight_block_size = weight_block_size
 
     def get_weights(self, weights: "Weights", prefix: str):
         w = weights.get_tensor(f"{prefix}.weight")
 
         if w.dtype == torch.float8_e4m3fn:
+            if self.weight_block_size is not None:
+                scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
+                return Fp8Weight(
+                    weight=w,
+                    weight_scale=scale,
+                    activation_scale_ub=self.activation_scale_ub,
+                    dtype=weights.dtype,
+                    weight_block_size=self.weight_block_size,
+                )
             # FP8 branch
             scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
 
@@ -276,6 +296,21 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in
 
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
+            if self.weight_block_size is not None:
+                scale = [
+                    weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
+                    for p in prefixes
+                ]
+                scale = torch.cat(scale, dim=dim)
+                scale = scale.to(weights.device)
+                return Fp8Weight(
+                    weight=w,
+                    weight_scale=scale,
+                    activation_scale_ub=self.activation_scale_ub,
+                    dtype=weights.dtype,
+                    weight_block_size=self.weight_block_size,
+                )
+
             scale = [
                 _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                 for p, shape in zip(prefixes, shapes)
@@ -321,6 +356,18 @@ def get_weights_row(self, weights: "Weights", prefix: str):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
+            if self.weight_block_size is not None:
+                # XXX: Yes the weights is named scale_inv, but corresponds to scale it seems.
+                scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
+
+                return Fp8Weight(
+                    weight=w,
+                    weight_scale=scale,
+                    activation_scale_ub=self.activation_scale_ub,
+                    dtype=weights.dtype,
+                    weight_block_size=self.weight_block_size,
+                )
+
             scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
 
             if SYSTEM == "cuda":
@@ -355,6 +402,7 @@ class Fp8Weight(Weight):
     input_scale: Optional[torch.Tensor] = None
     activation_scale_ub: Optional[float] = None
    force_w8a16: bool = False
+    weight_block_size: Optional[List[int]] = None
 
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
@@ -371,6 +419,7 @@ def get_linear(self, bias: torch.Tensor):
             bias=bias,
             input_scale=self.input_scale,
             scale_upper_bound=self.activation_scale_ub,
+            weight_block_size=self.weight_block_size,
         )
 
 
@@ -385,6 +434,7 @@ def __init__(
         bias: Optional[torch.Tensor] = None,
         input_scale: Optional[torch.Tensor] = None,
         scale_upper_bound: Optional[float] = None,
+        weight_block_size: Optional[List[int]] = None,
     ) -> None:
         super().__init__()
         if CUTLASS_FP8_AVAILABLE:
@@ -398,6 +448,7 @@ def __init__(
         self.qweight = qweight
         self.scale = scale.float()
         self.input_scale = input_scale.float() if input_scale is not None else None
+        self.weight_block_size = weight_block_size
 
         if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
             self.scale_upper_bound = torch.tensor(
@@ -431,6 +482,7 @@ def from_fp8(
     ) -> "Fp8Linear":
         input_scale = kwargs.get("input_scale", None)
         scale_upper_bound = kwargs.get("scale_upper_bound", None)
+        weight_block_size = kwargs.get("weight_block_size", None)
 
         return cls(
             qweight=weight,
@@ -439,6 +491,7 @@ def from_fp8(
             scale_upper_bound=scale_upper_bound,
             bias=bias,
             dtype=dtype,
+            weight_block_size=weight_block_size,
         )
 
     @classmethod
@@ -450,6 +503,25 @@ def get_shared_device_identity(cls, device):
         return cls._device_identity_cache[device]
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.weight_block_size is not None:
+            # https://arxiv.org/pdf/2412.19437
+            # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
+            # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
+            # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
+            # channels).
+            qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
+            output = w8a8_block_fp8_matmul(
+                qinput,
+                self.qweight,
+                scale,
+                self.scale,
+                self.weight_block_size,
+                output_dtype=input.dtype,
+            )
+
+            if self.bias is not None:
+                output = output + self.bias
+            return output.to(dtype=input.dtype)
         if CUTLASS_FP8_AVAILABLE:
             # cutlass FP8 supports per-token scales, so get non-scalar scales.
             qinput, scale = fp8_quantize(
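The block-quantized forward path above follows the DeepSeek-V3 recipe quoted in the comment: activations are quantized per token in groups of 128 channels, while each weight carries one scale per 128x128 block. As a rough illustration of what the fused per_token_group_quant_fp8 + w8a8_block_fp8_matmul call computes, here is a dequantize-then-matmul reference sketch in plain PyTorch; the helper names and the divisible-by-128 shapes are assumptions made for the sketch, not the moe-kernels API.

# Reference sketch only: approximates the block-wise FP8 matmul by dequantizing,
# whereas the real moe-kernels kernel fuses quantization, scaling, and the GEMM.
import torch

def per_token_group_quant_ref(x: torch.Tensor, group: int = 128):
    # One scale per (token, 128-channel group), mirroring the 1x128 activation tiles.
    tokens, in_features = x.shape
    xg = x.reshape(tokens, in_features // group, group)
    amax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax.float() / torch.finfo(torch.float8_e4m3fn).max
    qx = (xg / scale).to(torch.float8_e4m3fn).reshape(tokens, in_features)
    return qx, scale.squeeze(-1)  # scale: [tokens, in_features // group]

def w8a8_block_fp8_matmul_ref(qx, x_scale, qw, w_scale, block_size, out_dtype):
    # qw: [out_features, in_features] in FP8; w_scale: [out // bn, in // bk],
    # i.e. one scale per 128x128 weight block when block_size == [128, 128].
    bn, bk = block_size
    w = qw.float() * w_scale.repeat_interleave(bn, 0).repeat_interleave(bk, 1)
    x = qx.float() * x_scale.repeat_interleave(bk, 1)
    return (x @ w.T).to(out_dtype)

With weight_block_size = [128, 128], the activation scales broadcast over 128-channel groups and the weight scales over 128x128 blocks, which is exactly the granularity the loader reads from the checkpoint's weight_scale_inv tensors.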

server/text_generation_server/layers/moe/__init__.py

Lines changed: 22 additions & 10 deletions
@@ -52,6 +52,8 @@ def __init__(
         up_proj_name: str = "up_proj",
         down_proj_name: str = "down_proj",
         hidden_act: str = "silu",
+        scoring_func: Optional[str] = None,
+        e_score_correction_bias: Optional[float] = None,
     ): ...
 
     def forward(
@@ -81,9 +83,14 @@ def __init__(
         up_proj_name: str = "up_proj",
         down_proj_name: str = "down_proj",
         hidden_act: str = "silu",
+        scoring_func: Optional[str] = None,
+        e_score_correction_bias: Optional[float] = None,
     ):
         super().__init__()
 
+        assert scoring_func is None, "scoring func is not handled"
+        assert e_score_correction_bias is None, "scoring correction bias is not handled"
+
         log_once(
             logger.info,
             "No fused layers are available for this model type, using (slower) dense MoE layer",
@@ -199,21 +206,24 @@ def __init__(
         topk: int,
         topk_group: Optional[int],
         weights: Weights,
+        scoring_func: Optional[str] = "softmax",
+        e_score_correction_bias: Optional[float] = None,
         gate_proj_name: str = "gate_proj",
         up_proj_name: str = "up_proj",
         down_proj_name: str = "down_proj",
     ):
         super().__init__()
-        if isinstance(weights.loader, DefaultWeightsLoader) and isinstance(
-            weights.loader.weight_class, UnquantizedWeight
-        ):
-            cls = UnquantizedSparseMoELayer
-        elif isinstance(weights.loader, HybridFP8UnquantLoader):
-            cls = (
-                FP8SparseMoELayer
-                if weights.loader.to_fp8
-                else UnquantizedSparseMoELayer
-            )
+        if (
+            isinstance(weights.loader, DefaultWeightsLoader)
+            and isinstance(weights.loader.weight_class, UnquantizedWeight)
+        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
+            if (
+                isinstance(weights.loader, HybridFP8UnquantLoader)
+                and weights.loader.to_fp8
+            ):
+                cls = FP8SparseMoELayer
+            else:
+                cls = UnquantizedSparseMoELayer
         elif isinstance(
             weights.loader, GPTQMarlinWeightsLoader
         ) and can_use_marlin_moe_gemm(
@@ -240,6 +250,8 @@ def __init__(
             topk=topk,
             topk_group=topk_group,
             weights=weights,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
             gate_proj_name=gate_proj_name,
             up_proj_name=up_proj_name,
             down_proj_name=down_proj_name,
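The new scoring_func and e_score_correction_bias arguments are passed through to the sparse MoE layers because DeepSeek-V3's router scores experts with a non-softmax function and adds a per-expert correction bias that influences which experts are chosen but not how their outputs are weighted; as the commit message notes, only `softmax` scoring is actually working here, and the dense fallback layer asserts that neither option is set. The following standalone sketch only illustrates the routing idea behind these two parameters, under assumed shapes; the served path delegates to moe-kernels' fused kernels and differs in detail.

# Illustration of scoring_func + e_score_correction_bias in top-k expert routing.
import torch

def route(gating_logits: torch.Tensor, topk: int, scoring_func: str = "softmax",
          e_score_correction_bias: torch.Tensor | None = None):
    # gating_logits: [tokens, n_experts]
    if scoring_func == "softmax":
        scores = gating_logits.softmax(dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_logits.sigmoid()
    else:
        raise ValueError(f"unsupported scoring_func: {scoring_func}")

    # The correction bias shifts expert *selection* only; the weights applied to
    # the selected experts' outputs still come from the uncorrected scores.
    selection = scores if e_score_correction_bias is None else scores + e_score_correction_bias
    _, expert_ids = selection.topk(topk, dim=-1)
    expert_weights = scores.gather(-1, expert_ids)
    return expert_ids, expert_weights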

server/text_generation_server/layers/moe/fp8.py

Lines changed: 8 additions & 1 deletion
@@ -28,6 +28,8 @@ def __init__(
         topk: int,
         topk_group: Optional[int],
         weights: Weights,
+        scoring_func: Optional[str] = "softmax",
+        e_score_correction_bias: Optional[float] = None,
         gate_proj_name: str = "gate_proj",
         up_proj_name: str = "up_proj",
         down_proj_name: str = "down_proj",
@@ -42,6 +44,9 @@ def __init__(
         self.topk = topk
         self.topk_group = topk_group
         self.renormalize = renormalize
+        self.weight_block_size = weights.weights_loader.weight_block_size
+        self.scoring_func = scoring_func
+        self.e_score_correction_bias = e_score_correction_bias
 
         (
             self.gate_up_proj,
@@ -76,6 +81,8 @@ def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tens
             use_grouped_topk=self.n_expert_group is not None,
             num_expert_group=self.n_expert_group,
             topk_group=self.topk_group,
+            scoring_func=self.scoring_func,
+            e_score_correction_bias=self.e_score_correction_bias,
             use_fp8_w8a8=True,
             w1_scale=self.gate_up_proj_weight_scale,
             w2_scale=self.down_proj_weight_scale,
@@ -109,7 +116,7 @@ def _load_expert_weights(
         )
         if all_weight_scales is None:
             all_weight_scales = torch.empty(
-                (n_experts,),
+                (n_experts,) + weight.weight_scale.shape,
                 dtype=torch.float32,
                 device=weight.weight.device,
             )
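The change from `(n_experts,)` to `(n_experts,) + weight.weight_scale.shape` is what lets the per-expert scale buffer hold block-wise scales: with per-tensor FP8 each expert contributed a single scalar, while block-wise FP8 contributes a whole grid of scales per expert weight. A small shape check, using toy dimensions and assuming 128x128 blocks:

# Toy shapes only, assuming 128x128 weight blocks as used for DeepSeek-V3.
import torch

n_experts, out_features, in_features = 8, 512, 256
one_expert_scales = torch.rand(out_features // 128, in_features // 128)  # shape [4, 2]

all_weight_scales = torch.empty(
    (n_experts,) + one_expert_scales.shape, dtype=torch.float32
)
all_weight_scales[0] = one_expert_scales  # stacked per expert, as in _load_expert_weights
print(all_weight_scales.shape)  # torch.Size([8, 4, 2]); previously just (n_experts,)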
