Skip to content

Commit

Permalink
Clarify FP8-Marlin use on capability 8.9 (#2940)
Browse files Browse the repository at this point in the history
The log message stated that the GPU does not support FP8 on capability
8.9. However, we do use FP8-Marlin on that capability because it is faster.
  • Loading branch information
danieldk authored Jan 22, 2025
1 parent 1d3c9be commit 1dd3466
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
10 changes: 10 additions & 0 deletions server/text_generation_server/layers/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
# gives better decoding throughput on L4 and L40.
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

if major == 8 and minor == 9:
log_once(
logger.info,
"GPU supports FP8, but using Marlin FP8 kernel for better performance",
)
else:
log_once(
logger.info, "GPU does not support FP8, using Marlin FP8 kernel"
)

return GPTQMarlinFP8Linear

# On other systems let Torch decide if the hardware supports FP8.
Expand Down
4 changes: 0 additions & 4 deletions server/text_generation_server/layers/marlin/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.layers.marlin.gptq import _check_valid_shape
from text_generation_server.layers.marlin.util import (
_check_marlin_kernels,
permute_scales,
)
from text_generation_server.utils.log import log_once

try:
import marlin_kernels
Expand All @@ -36,8 +34,6 @@ def __init__(
_check_marlin_kernels()
assert marlin_kernels is not None

log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")

scales = scales.unsqueeze(0)
if scales.shape[1] == 1:
out_features, in_features = qweight.shape
Expand Down

0 comments on commit 1dd3466

Please sign in to comment.