97 changes: 97 additions & 0 deletions inference/bf16_cast_fp8.py
@@ -0,0 +1,97 @@
import os
import json
import re
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm

import torch
from safetensors.torch import load_file, save_file

from kernel import weight_quant

# Layers that should not be quantized (remain in BF16)
SKIP_QUANT_PATTERNS = [
r".*\.layernorm\.weight$",
r".*\.norm\.weight$",
r".*input_layernorm\.weight$",
r".*post_attention_layernorm\.weight$",
r".*\.kv_a_layernorm\.weight$",
r".*\.q_a_layernorm\.weight$",
r".*\.embed_tokens\.weight$",
r".*\.head\.weight$",
r".*lm_head\.weight$",
r".*\.eh_proj\.weight$",
r".*\.gate\.e_score_correction_bias$",
r".*\.gate\.weight$"
]

def should_skip_quantization(weight_name):
"""Check if weight name matches any pattern in the skip list"""
return any(re.match(pattern, weight_name) for pattern in SKIP_QUANT_PATTERNS)

def main(bf16_path, fp8_path):
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(fp8_path, exist_ok=True)

    # Get list of safetensor files
    safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors")))
    safetensor_files.sort()

    # Load model index if it exists
    model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
    if os.path.exists(model_index_file):
        with open(model_index_file, "r") as f:
            model_index = json.load(f)
        weight_map = model_index["weight_map"]
    else:
        # Create a new weight map if there's no index file
        weight_map = {}

    # Cache for loaded safetensor files
    loaded_files = {}
    fp8_weight_names = []

    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)
        current_state_dict = load_file(safetensor_file, device="cuda")
        loaded_files[file_name] = current_state_dict

        new_state_dict = {}
        for weight_name, weight in current_state_dict.items():
            # Skip weights that should not be quantized
            if should_skip_quantization(weight_name) or weight.dim() != 2:
                new_state_dict[weight_name] = weight
            else:
                # Quantize weights to FP8
                fp8_weight, scale_inv = weight_quant(weight)
                new_state_dict[weight_name] = fp8_weight
                scale_inv_name = f"{weight_name}_scale_inv"
                new_state_dict[scale_inv_name] = scale_inv
                fp8_weight_names.append(weight_name)

                # Update weight map
                if weight_name in weight_map:
                    weight_map[scale_inv_name] = file_name

        new_safetensor_file = os.path.join(fp8_path, file_name)
        save_file(new_state_dict, new_safetensor_file)

        # Memory management: keep only the 2 most recently used files
        if len(loaded_files) > 2:
            oldest_file = next(iter(loaded_files))
            del loaded_files[oldest_file]
            torch.cuda.empty_cache()

    # Update model index
    new_model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
    with open(new_model_index_file, "w") as f:
        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input-bf16-hf-path", type=str, required=True)
    parser.add_argument("--output-fp8-hf-path", type=str, required=True)
    args = parser.parse_args()
    main(args.input_bf16_hf_path, args.output_fp8_hf_path)
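
For a quick check of what the script will and will not quantize, and of how it is meant to be driven, the following minimal sketch can be run from the inference/ directory. The layer names and checkpoint paths are placeholders chosen only for illustration, and a CUDA device is required because the quantization kernel is a Triton kernel.

from bf16_cast_fp8 import main, should_skip_quantization

# Norms, embeddings, the output head and MoE gate parameters stay in BF16 ...
assert should_skip_quantization("model.layers.0.input_layernorm.weight")
assert should_skip_quantization("lm_head.weight")
# ... while 2-D projection weights are cast to FP8 with per-block scales.
assert not should_skip_quantization("model.layers.0.self_attn.q_b_proj.weight")

# Full conversion, equivalent to the CLI entry point below; both paths are placeholders:
#   python bf16_cast_fp8.py --input-bf16-hf-path <bf16_dir> --output-fp8-hf-path <fp8_dir>
main("/path/to/DeepSeek-V3-bf16", "/path/to/DeepSeek-V3-fp8")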
58 changes: 58 additions & 0 deletions inference/kernel.py
@@ -105,6 +105,64 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
return y


@triton.jit
def weight_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
"""
Quantizes weights in blocks and computes scaling factors for each block.

Args:
x_ptr (tl.pointer): Pointer to the input weights tensor.
y_ptr (tl.pointer): Pointer to the output buffer for quantized weights.
s_ptr (tl.pointer): Pointer to the output buffer for scaling factors.
M (int): Number of rows in the weight matrix.
N (int): Number of columns in the weight matrix.
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

Returns:
None
"""
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # Read masked-out (out-of-bounds) elements as 0 so they cannot affect the block max
    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    max_val = tl.max(tl.abs(x))
    s = max_val / 448.0  # 448 is the max representable value of float8_e4m3fn; same scaling as in act_quant
    y = x / s
    y = y.to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y, mask=mask)
    tl.store(s_ptr + pid_m * n + pid_n, s)


def weight_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes a weight tensor using block-wise quantization.

Args:
x (torch.Tensor): The input weight tensor of shape (M, N) to be quantized.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.

Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.

Raises:
AssertionError: If `x` is not contiguous or if its dimensions are not 2.
"""
assert x.is_contiguous()
assert x.dim() == 2
M, N = x.size()
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(triton.cdiv(M, block_size), triton.cdiv(N, block_size), dtype=torch.float32)
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
weight_quant_kernel[grid](x, y, s, M, N, BLOCK_SIZE=block_size)
return y, s


fp8_gemm_configs = [
    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
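
A quantize/dequantize round trip against the pre-existing weight_dequant is a quick sanity check for the new kernel. This is a minimal sketch, assuming a CUDA-capable GPU (the kernels are Triton kernels) and an arbitrary 512x1024 test matrix.

import torch
from kernel import weight_quant, weight_dequant

# Arbitrary BF16 weight matrix; with block_size=128 it splits into a 4x8 grid of blocks.
w = torch.randn(512, 1024, dtype=torch.bfloat16, device="cuda")

# Block-wise FP8 quantization: y holds float8_e4m3fn values, s one float32 scale per 128x128 block.
y, s = weight_quant(w)
assert y.dtype == torch.float8_e4m3fn and s.shape == (4, 8)

# Dequantize with the existing kernel and check the reconstruction error.
w_rec = weight_dequant(y, s)
rel_err = ((w_rec - w.float()).abs().max() / w.float().abs().max()).item()
print(f"max relative error: {rel_err:.4f}")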