diff --git a/README.md b/README.md
index 44f9288..ee4ebc8 100644
--- a/README.md
+++ b/README.md
@@ -77,8 +77,8 @@ conda activate lite_llama
 git clone https://github.com/harleyszhang/lite_llama.git
 cd lite_llama/
 pip install -r requirement.txt
-python test_weight_convert.py # model weight transformation
-python generate.py --prompt "What is large language model" --checkpoint_path /path/to/model/Llama-3.2-1B-Instruct/ # Run on the basis that the model has been downloaded and placed in the specified directory
+python apply_weight_convert.py --checkpoints_dir /path/to/model/Llama-3.2-1B-Instruct/ --model_type llama # model weight conversion
+python generate.py -p "What is large language model" -m /path/to/model/Llama-3.2-1B-Instruct/ -f /path/to/figure # assumes the model has already been downloaded to the specified directory
 ```
 ROCm version 5.7 and above is recommended.
@@ -95,21 +95,20 @@ conda activate lite_llama
 git clone https://github.com/harleyszhang/lite_llama.git
 cd lite_llama/
 pip install -r requirement.txt
-python test_weight_convert.py # model weight transformation
-python generate.py --prompt "What is large language model" --checkpoint_path /path/to/model/Llama-3.2-1B-Instruct/ # Run on the basis that the model has been downloaded and placed in the specified directory
+python apply_weight_convert.py --checkpoints_dir /path/to/model/Llama-3.2-1B-Instruct/ --model_type llama # model weight conversion
+python generate.py -p "What is large language model" -m /path/to/model/Llama-3.2-1B-Instruct/ -f /path/to/figure # assumes the model has already been downloaded to the specified directory
 ```
-
 ## Evaluation
-After `cli.py` runs successfully, the terminal displays the interface as shown below, and you can enter your question in the terminal.
-
-![cli](./images/cli_stream.png)
-
 After `generate.py` runs successfully, the terminal displays the interface as shown below, and you can enter your question in the terminal.
 
 ![generate](./images/generate_stream.png)
 
+After `cli.py` runs successfully, the terminal displays the interface as shown below, and you can enter your question in the terminal.
+
+![cli](./images/cli_stream.png)
+
 After `cli_llava.py` runs successfully, the terminal displays the interface as shown below, enter your picture and prompt word in the terminal, and then enter.
 
 ![llava model streaming output](./images/llava_output2.gif)
diff --git a/apply_weight_convert.py b/apply_weight_convert.py
old mode 100644
new mode 100755
index d84bf58..e1e51ee
--- a/apply_weight_convert.py
+++ b/apply_weight_convert.py
@@ -241,11 +241,11 @@ def convert(checkpoints_dir: Path,
     new_sd: Dict[str, torch.Tensor] = {}
 
     # ---------- 1. 重映射 ----------
-    for k, v in tqdm(hf_state.items(), desc=f"[{model_type}] 权重重映射"):
+    for k, v in tqdm(hf_state.items(), desc=f"[{model_type}] Weight remapping"):
         if (ck := mapping.get(k)) is not None:
             new_sd[ck] = v
         else:
-            logger.debug("忽略未映射参数 %s", k)
+            logger.debug("Ignoring unmapped parameter %s", k)
 
    # ---------- 2. 
仅对 *Qwen* 系列执行 KV 合并 ---------- if model_type.startswith("qwen") or model_type.startswith("llama"): # 只处理 Qwen-2 / Qwen-3 等 @@ -259,7 +259,7 @@ def convert(checkpoints_dir: Path, save_state_dict(out_dir, checkpoints_dir.name, new_sd) copy_metadata(checkpoints_dir, out_dir) - logger.info("🎉 转换完成,共 %d 个参数", len(new_sd)) + logger.info("🎉 Convert Complete,There are %d parameters in total", len(new_sd)) return new_sd @@ -313,8 +313,8 @@ def get_num_layers(checkpoints_dir: Path, model_type: str) -> int: def main() -> None: parser = argparse.ArgumentParser( description="Convert HF / bin checkpoints into Lite-LLaMA format.") - parser.add_argument("checkpoints_dir", type=Path, help="模型权重目录") - parser.add_argument("--model-type", + parser.add_argument("--checkpoints_dir", type=Path, help="模型权重目录") + parser.add_argument("--model_type", choices=_SPEC.keys(), help="显式指定模型类型;默认根据目录名猜测") parser.add_argument("--device", default="cuda", @@ -325,11 +325,11 @@ def main() -> None: # 1️⃣ **直接从 config.json 读取 model_type** ↓ model_type = detect_model_type(ckpt_dir) - logger.info("检测到 model_type = %s", model_type) + logger.info("Model Type is: %s", model_type) # 2️⃣ 获取层数 num_layers = get_num_layers(ckpt_dir, model_type) - logger.info("Transformer 层数 %d", num_layers) + logger.info("Transformer Number of layers %d", num_layers) # 3️⃣ 加载权重并执行转换 hf_sd = load_hf_state(ckpt_dir, model_type, device=args.device) diff --git a/generate.py b/generate.py index 6d359e0..4b26115 100644 --- a/generate.py +++ b/generate.py @@ -1,28 +1,36 @@ +# 对原有的generate.py进行修改,添加量化支持 + import torch -from typing import Optional +from typing import Optional, List +from lite_llama.utils.prompt_templates import get_prompter, get_image_token +from lite_llama.generate_stream import GenerateStreamText +from lite_llama.utils.image_process import vis_images + +import warnings + warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils") from lite_llama.utils.common import get_gpu_memory, detect_device, count_tokens, get_model_type -from lite_llama.utils.prompt_templates import get_prompter -from lite_llama.generate_stream import GenerateStreamText # 导入 GenerateText 类 -import warnings +from lite_llama.llava_generate_stream import LlavaGeneratorStream import sys, os, time from pathlib import Path - # support running without installing as a package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) import psutil +from lite_llama.utils.logger import log -process = psutil.Process(os.getpid()) +# 新增导入 +from lite_llama.quantization.quant_manager import quantization_manager, QuantizationType +process = psutil.Process(os.getpid()) -def report_resource_usage(ram_before, vram_before, gpu_type) -> None: +def report_resource_usage(ram_before, vram_before) -> None: end_time = time.time() ram_after = process.memory_info().rss - vram_after = get_gpu_memory(gpu_type) + vram_after = get_gpu_memory(detect_device()) - ram_used = (ram_after - ram_before) / (1024**3) # Bytes to GB + ram_used = (ram_after - ram_before) / (1024 ** 3) # Bytes to GB if vram_before is not None and vram_after is not None: vram_used = vram_after - vram_before @@ -30,57 +38,57 @@ def report_resource_usage(ram_before, vram_before, gpu_type) -> None: else: vram_text = "Unavailable" - print(f"CPU RAM Used: {ram_used:.2f} GB") - print(f"GPU VRAM Used: {vram_text}") - - -def main( - prompt: str = "Hello, my name is", - *, - temperature: float = 0.6, - top_p: float = 0.9, - max_seq_len: int = 2048, - max_gpu_num_blocks=40960, - max_gen_len: Optional[int] = 1024, - 
load_model: bool = True, - compiled_model: bool = False, - triton_weight: bool = True, - gpu_type: str = "nvidia", - checkpoint_path: Path = Path("checkpoints/lit-llama/7B/"), - quantize: Optional[str] = None, + log.info(f"CPU RAM Used: {ram_used:.2f} GB") + log.info(f"GPU VRAM Used: {vram_text}") + + +def generate_llama( + prompt: str = "Hello, my name is", + quantization: Optional[str] = None, # 新增参数 + *, + temperature: float = 0.6, + top_p: float = 0.9, + max_seq_len: int = 2048, + max_gpu_num_blocks=40960, + max_gen_len: Optional[int] = 1024, + compiled_model: bool = False, + gpu_type: str = "nvidia", + checkpoint_path: Path = Path("checkpoints/lit-llama/7B/"), ): - device = "cuda" if torch.cuda.is_available() else "cpu" + device = 'cuda' if torch.cuda.is_available() else 'cpu' assert checkpoint_path.is_dir(), checkpoint_path checkpoint_path = str(checkpoint_path) + # 检测量化类型 + if quantization is None: + quantization = quantization_manager.detect_quantization_type(checkpoint_path) + if quantization != QuantizationType.NONE: + log.info(f"Automatically detect the quantization type: {quantization}") + if max_seq_len <= 1024: short_prompt = True else: short_prompt = False - model_prompter = get_prompter( - get_model_type(checkpoint_path), checkpoint_path, short_prompt - ) + model_prompter = get_prompter(get_model_type(checkpoint_path), checkpoint_path, short_prompt) + # Start resource tracking ram_before = process.memory_info().rss - - gpu_type = detect_device() vram_before = get_gpu_memory(gpu_type) - # Init LLM generator - start = time.perf_counter() + # 创建生成器,传入量化参数 generator = GenerateStreamText( checkpoints_dir=checkpoint_path, tokenizer_path=checkpoint_path, max_gpu_num_blocks=max_gpu_num_blocks, max_seq_len=max_seq_len, - load_model=load_model, compiled_model=compiled_model, - triton_weight=triton_weight, device=device, + quantization=quantization, # 新增参数 ) model_prompter.insert_prompt(prompt) prompts = [model_prompter.model_input] + # Call the generation function and start the stream generation stream = generator.text_completion_stream( prompts, @@ -88,28 +96,297 @@ def main( top_p=top_p, max_gen_len=max_gen_len, ) - end = time.perf_counter() - completion = "" # Initialize to generate the result - # NOTE: After creating a generator, it can be iterated through a for loop + completion = '' text_msg = "" + start = time.perf_counter() for batch_completions in stream: - new_text = batch_completions[0]["generation"][len(completion) :] - completion = batch_completions[0]["generation"] - print(new_text, end="", flush=True) + new_text = batch_completions[0]['generation'][len(completion):] + completion = batch_completions[0]['generation'] + print(new_text, end='', flush=True) text_msg += new_text + end = time.perf_counter() print("\n\n==================================\n") - print( - f"Time for inference: {(end - start):.2f} sec, {count_tokens(text_msg, generator.tokenizer) / (end - start):.2f} tokens/sec" - ) + log.info( + f"Time for inference: {(end - start):.2f} sec, {count_tokens(text_msg, generator.tokenizer) / (end - start):.2f} tokens/sec") # Report resource usage - report_resource_usage(ram_before, vram_before, gpu_type) + report_resource_usage(ram_before, vram_before) + + +def generate_llava( + prompt: str = "Hello, my name is", + checkpoint_path: Path = Path("checkpoints/lit-llama/7B/"), + figure_path: Path = Path("figures/lit-llama/"), + gpu_type: str = "nvidia", + quantization: Optional[str] = None, # 新增参数 + temperature: float = 0.6, + top_p: float = 0.9, + max_seq_len: int = 
2048, + max_gpu_num_blocks=None, + max_gen_len: Optional[int] = 512, + compiled_model: bool = False, +): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # 检测量化类型 + if quantization is None: + quantization = quantization_manager.detect_quantization_type(str(checkpoint_path)) + if quantization != QuantizationType.NONE: + log.info(f"Automatically detect the quantization type: {quantization}") + + if max_seq_len <= 1024: + short_prompt = True + else: + short_prompt = False + + if not os.path.isfile(figure_path): + log.error(f"'{figure_path}' Not a valid file path!") + else: + image_input = str(figure_path).strip() + image_items = [image_input] + image_num = len(image_items) + vis_images(image_items) + assert checkpoint_path.is_dir(), checkpoint_path + checkpoint_path = str(checkpoint_path) + model_prompter = get_prompter("llama", checkpoint_path, short_prompt) + + # Start resource tracking + ram_before = process.memory_info().rss + vram_before = get_gpu_memory(gpu_type) + + try: + generator = LlavaGeneratorStream( + checkpoints_dir=checkpoint_path, + tokenizer_path=checkpoint_path, + max_gpu_num_blocks=max_gpu_num_blocks, + max_seq_len=max_seq_len, + compiled_model=compiled_model, + device=device, + quantization=quantization, # 新增参数 + ) + except Exception as e: + log.error(f"Model loading failure: {e}") + sys.exit(1) + + image_token = get_image_token() + model_prompter.insert_prompt(image_token * image_num + prompt) + prompts = [model_prompter.model_input] + + try: + stream = generator.text_completion_stream( + prompts, + image_items, + temperature=temperature, + top_p=top_p, + max_gen_len=max_gen_len, + ) + except Exception as e: + log.error(f"Text Generation Failure: {e}") + + completion = '' + text_msg = "" + start = time.perf_counter() + + for batch_completions in stream: + next_text = batch_completions[0]['generation'][len(completion):] + completion = batch_completions[0]['generation'] + print(f"\033[91m{next_text}\033[0m", end='', flush=True) + text_msg += next_text + end = time.perf_counter() + + print("\n\n==================================\n") + log.info( + f"Time for inference: {(end - start):.2f} sec, {count_tokens(text_msg, generator.tokenizer) / (end - start):.2f} tokens/sec") + + # Report resource usage + report_resource_usage(ram_before, vram_before) if __name__ == "__main__": from jsonargparse import CLI torch.set_float32_matmul_precision("high") + + + def main( + prompt: str = "Hello, my name is", + checkpoint_path: Path = Path("checkpoints/lite-llama/7B/"), + figure_path: Optional[Path] = None, + quantization: Optional[str] = None, # 新增参数 + ): + """ + Generate text using lite_llama with optional quantization support + + Args: + prompt: Input prompt text + checkpoint_path: Path to model checkpoint directory + figure_path: Path to Image file for LLaVA generation, optional + quantization: Quantization method ('gptq', 'awq', 'smoothquant', or None for auto-detection) + """ + gpu_type = detect_device() + model_path = os.path.abspath(checkpoint_path) + + # 验证量化参数 + if quantization and quantization not in ['gptq', 'awq', 'smoothquant']: + log.error(f"不支持的量化方法: {quantization}") + log.info("支持的量化方法: gptq, awq, smoothquant") + return + + if figure_path: + generate_llava( + prompt=prompt, + checkpoint_path=Path(model_path), + figure_path=Path(figure_path), + gpu_type=gpu_type, + quantization=quantization, + ) + else: + generate_llama( + prompt=prompt, + checkpoint_path=Path(model_path), + gpu_type=gpu_type, + quantization=quantization + ) + + CLI(main) + + +# 新增量化推理的便捷函数 
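A hedged usage sketch for the `run_quantized_inference()` helper that the diff defines immediately below; the model path and prompt are placeholders, and passing `quantization_method=None` falls back to `quantization_manager.detect_quantization_type()` as in the code:

```python
# Hypothetical call to the run_quantized_inference() helper added below.
# The model path is a placeholder; extra keyword arguments (e.g. max_gen_len)
# are merged into the default generation parameters before dispatching to
# generate_llama() / generate_llava().
run_quantized_inference(
    model_path="/path/to/model/Llama-3.2-1B-Instruct",
    prompt="Explain KV-cache paging in one paragraph.",
    quantization_method="gptq",   # or "awq", "smoothquant", or None for auto-detection
    max_gen_len=128,
)
```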
+def run_quantized_inference( + model_path: str, + prompt: str, + quantization_method: Optional[str] = None, + **kwargs +): + """ + 运行量化推理的便捷函数 + + Args: + model_path: 模型路径 + prompt: 输入提示 + quantization_method: 量化方法,None为自动检测 + **kwargs: 其他推理参数 + """ + + # 检查模型是否存在 + if not os.path.exists(model_path): + raise FileNotFoundError(f"模型路径不存在: {model_path}") + + # 获取模型类型 + model_type = get_model_type(model_path) + + # 设置默认参数 + default_params = { + 'temperature': 0.6, + 'top_p': 0.9, + 'max_seq_len': 2048, + 'max_gen_len': 1024, + 'compiled_model': False, + } + default_params.update(kwargs) + + if model_type == 'llava': + # LLaVA模型需要图像输入 + figure_path = kwargs.get('figure_path') + if not figure_path: + log.warning("LLaVA模型需要图像输入,将使用默认图像") + # 这里可以设置一个默认图像路径 + + generate_llava( + prompt=prompt, + checkpoint_path=Path(model_path), + figure_path=Path(figure_path) if figure_path else None, + quantization=quantization_method, + **default_params + ) + else: + generate_llama( + prompt=prompt, + checkpoint_path=Path(model_path), + quantization=quantization_method, + **default_params + ) + + +# 量化性能测试函数 +def benchmark_quantized_model( + model_path: str, + quantization_methods: Optional[List[str]] = None, + test_prompts: Optional[List[str]] = None, + num_runs: int = 3 +): + """ + 对量化模型进行性能基准测试 + + Args: + model_path: 模型路径 + quantization_methods: 要测试的量化方法列表 + test_prompts: 测试提示列表 + num_runs: 每个配置的运行次数 + """ + + if quantization_methods is None: + quantization_methods = ['gptq', 'awq', 'smoothquant', None] # None代表无量化 + + if test_prompts is None: + test_prompts = [ + "What is artificial intelligence?", + "Explain quantum computing in simple terms.", + "Write a short story about a robot." + ] + + results = {} + + for method in quantization_methods: + method_name = method or "no_quantization" + log.info(f"测试量化方法: {method_name}") + + method_results = [] + + for prompt in test_prompts: + prompt_results = [] + + for run in range(num_runs): + log.info(f"运行 {run + 1}/{num_runs}: {prompt[:50]}...") + + start_time = time.time() + try: + run_quantized_inference( + model_path=model_path, + prompt=prompt, + quantization_method=method, + max_gen_len=256 # 限制生成长度以便快速测试 + ) + end_time = time.time() + prompt_results.append(end_time - start_time) + + except Exception as e: + log.error(f"测试失败 ({method_name}, run {run + 1}): {e}") + prompt_results.append(float('inf')) + + method_results.append(prompt_results) + + results[method_name] = method_results + + # 打印结果摘要 + log.info("=" * 60) + log.info("基准测试结果摘要") + log.info("=" * 60) + + for method_name, method_results in results.items(): + avg_times = [] + for prompt_results in method_results: + valid_times = [t for t in prompt_results if t != float('inf')] + if valid_times: + avg_times.append(sum(valid_times) / len(valid_times)) + + if avg_times: + overall_avg = sum(avg_times) / len(avg_times) + log.info(f"{method_name:15}: {overall_avg:.2f}s 平均响应时间") + else: + log.info(f"{method_name:15}: 测试失败") + + return results \ No newline at end of file diff --git a/lite_llama/executor/model_executor.py b/lite_llama/executor/model_executor.py index 92c6ce7..d933b17 100644 --- a/lite_llama/executor/model_executor.py +++ b/lite_llama/executor/model_executor.py @@ -1,9 +1,11 @@ +# 在原有的model_executor.py基础上添加以下修改 + import torch import torch.nn as nn import json, time from pathlib import Path -from typing import Callable, Type +from typing import Callable, Type, Optional from transformers import LlavaConfig from accelerate import init_empty_weights, load_checkpoint_and_dispatch @@ -17,10 +19,9 @@ 
from ..kernels import update_kv_index from ..utils.logger import log - -# ----------------------------------------------------------------------------- -# Registry helpers (avoid long if/elif chains) -# ----------------------------------------------------------------------------- +# 新增导入 +from ..quantization.quant_manager import quantization_manager, QuantizationType +from ..models.quantized_models import create_quantized_model CONFIG_CLASS_MAP: dict[str, Type] = { "llama": LlamaConfig, @@ -29,6 +30,7 @@ "llava": LlavaConfig, } + class ModelExecutor: # 定义类属性 model_config = None @@ -38,26 +40,39 @@ class ModelExecutor: # 通过静态方法 build 将类属性当作默认配置使用 @staticmethod def build( - checkpoints_dir: str, - max_seq_len: int, - max_gpu_num_blocks: None, - compiled_model: bool = False, - device: str = "cuda", + checkpoints_dir: str, + max_seq_len: int, + max_gpu_num_blocks: None, + compiled_model: bool = False, + device: str = "cuda", + quantization: Optional[str] = None, # 新增参数 ): """ 构建 ModelExecutor 实例, 加载模型、分词器和初始化推理信息结构体 atten_info。 参数: checkpoints_dir (str): 模型检查点目录路径。 - load_model (bool): 是否加载模型权重。 max_seq_len (int): 最大序列长度。 + max_gpu_num_blocks: GPU块数量限制。 + compiled_model (bool): 是否使用编译模型。 device (str): 设备类型('cuda'或'cpu')。 + quantization (str, optional): 量化类型,如'gptq', 'awq', 'smoothquant'等。 返回: ModelExecutor: 初始化后的 ModelExecutor 实例。 """ model_config = ModelExecutor._load_model_config(checkpoints_dir, max_seq_len) - model = ModelExecutor._load_model_weight(model_config, checkpoints_dir, device=device) + + # 检测或使用指定的量化类型 + if quantization is None: + quantization = quantization_manager.detect_quantization_type(checkpoints_dir) + log.info(f"Automatically detect the quantization type: {quantization}") + else: + log.info(f"Use the specified quantization type: {quantization}") + + model = ModelExecutor._load_model_weight( + model_config, checkpoints_dir, device=device, quantization=quantization + ) return ModelExecutor( model_config, model, max_gpu_num_blocks, compiled_model, device @@ -71,43 +86,36 @@ def _load_model_config(checkpoints_dir: str, max_seq_len: int): params = json.loads(cfg_path.read_text()) cfg_cls = CONFIG_CLASS_MAP.get(params["model_type"].lower()) - + if cfg_cls is None: raise ValueError(f"Unsupported model_type {params['model_type']!r}") - - return cfg_cls.from_dict(params) - - @staticmethod - def _accelerate_load_weight( - model_config, - checkpoints_dir, - device="cuda", - ): - with init_empty_weights(): - model = ModelExecutor._initialize_model(model_config, device=device) - # 假设 model 是使用 init_empty_weights 初始化的空模型 - model = load_checkpoint_and_dispatch( - model, checkpoints_dir, device_map="auto", dtype=torch.float16 - ) - - # 将模型转换为半精度, 并验证抓换 - model.to(device) - model.half() - for param in model.parameters(): - assert param.dtype == torch.float16, "Model parameters are not in FP16" - log.info("Converted model to half precision (FP16)") - - return model + config = cfg_cls.from_dict(params) + config.max_seq_len = max_seq_len # 确保设置正确的max_seq_len + return config @staticmethod def _load_model_weight( - model_config, - checkpoints_dir, - device="cuda", + model_config, + checkpoints_dir, + device="cuda", + quantization: Optional[str] = None, ): start_time = time.time() + if quantization and quantization != QuantizationType.NONE: + # 加载量化模型 + log.info(f"Load quantitative model: {quantization}") + model = quantization_manager.load_quantized_model( + model_path=checkpoints_dir, + model_config=model_config, + device=device + ) + + log.info(f"The quantitative model has been loaded 
successfully, taking {time.time() - start_time:.2f}s") + return model + + # 原有的非量化模型加载逻辑 # 初始化模型 with init_empty_weights(): model = ModelExecutor._initialize_model(model_config, device=device) @@ -174,12 +182,12 @@ def _initialize_model(model_config, device: str) -> nn.Module: return model def __init__( - self, - model_config, - model, - max_gpu_num_blocks=None, - compiled_model=False, - device="cuda", + self, + model_config, + model, + max_gpu_num_blocks=None, + compiled_model=False, + device="cuda", ): self.model_config = model_config self.device = device @@ -219,6 +227,8 @@ def __init__( if self.compiled_model: self.apply_cuda_graph() # 调用 cuda graph 优化 + # ... 其余方法保持不变 ... + def _get_max_avaliable_tokens(self, gpu_memory_utilization=0.9, block_size=1): avaliable_blocks = ComputeMaxAvailableBlocks( num_layers=self.llm_config.num_layers, @@ -234,7 +244,7 @@ def _get_max_avaliable_tokens(self, gpu_memory_utilization=0.9, block_size=1): return max_gpu_num_blocks, max_gpu_num_tokens def _init_mem_manager( - self, gpu_num_blocks, block_size=1, dtype=torch.float16, device="cuda" + self, gpu_num_blocks, block_size=1, dtype=torch.float16, device="cuda" ): kv_mem_manager = KVCacheMemoryManager( num_layers=self.llm_config.num_layers, @@ -249,7 +259,7 @@ def _init_mem_manager( return kv_mem_manager def apply_cuda_graph( - self, + self, ): """应用 cuda graph 优化 参数: @@ -266,7 +276,7 @@ def apply_cuda_graph( self.model_runner.capture_decode_graph() def init_req_to_tokens_table( - self, b_req_tokens_table, b_req_idx, b_seq_len, alloc_mem_index + self, b_req_tokens_table, b_req_idx, b_seq_len, alloc_mem_index ): """ 初始化 prefill 阶段已分配的批次请求项的 kv cache 所用 tokens 索引 @@ -282,19 +292,19 @@ def init_req_to_tokens_table( b_start_loc[i] = start_index cur_seq_len = b_seq_len_numpy[i] b_req_tokens_table[b_req_idx_numpy[i], :cur_seq_len] = alloc_mem_index[ - start_index : start_index + cur_seq_len - ] + start_index: start_index + cur_seq_len + ] start_index += cur_seq_len return b_start_loc def prefill_alloc_kv_cache( - self, - max_prompt_len, - actual_prompt_lens, - b_req_idx, - image_batch_size=None, - debug_mode=False, + self, + max_prompt_len, + actual_prompt_lens, + b_req_idx, + image_batch_size=None, + debug_mode=False, ): """ start_index: tensor([ 0, 270, 540, 810], device='cuda:0', dtype=torch.int32) diff --git a/lite_llama/generate.py b/lite_llama/generate.py index 4575e51..0fdb74c 100644 --- a/lite_llama/generate.py +++ b/lite_llama/generate.py @@ -241,4 +241,4 @@ def process_output_tokens( out_tokens.append(generated_toks) - return out_tokens + return out_tokens \ No newline at end of file diff --git a/lite_llama/generate_stream.py b/lite_llama/generate_stream.py index e60220a..0d318ef 100644 --- a/lite_llama/generate_stream.py +++ b/lite_llama/generate_stream.py @@ -7,6 +7,9 @@ from transformers import AutoTokenizer +# 新增导入 +from .quantization.quant_manager import QuantizationType + # 设置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -18,44 +21,27 @@ class CompletionPrediction(TypedDict, total=False): logprobs: list[float] # not required +# 保持采样函数不变 @torch.inference_mode() def sample_top_p(probs, p): """ 执行 Top-p (Nucleus) 采样, 从概率分布中采样下一个词。 - - 参数: - probs (torch.Tensor): 概率分布张量,形状为 `[batch_size, vocab_size]`。 - p (float): 累积概率阈值,取值范围在 0 到 1 之间。 - 返回: - torch.Tensor: 采样得到的词索引,形状为 `[batch_size, 1]`。 - - 说明: - Top-p 采样算法: 选择概率累积和超过阈值 p 的最小集合,将这些词的概率重新归一化后进行采样。 """ - # 对概率分布进行降序排序。probs_sort: 排序后的概率值,形状与 probs 相同。probs_idx: 排序后的索引,用于映射回原始词汇表。 probs_sort, probs_idx = 
torch.sort(probs, dim=-1, descending=True) - # 计算排序后概率的累积和. 返回的 probs_sum 是累积概率分布。 probs_sum = torch.cumsum(probs_sort, dim=-1) - # 保留累积概率未超过阈值 p 的词汇的概率,其余词汇的概率被置为 0.0。 - mask = ( - probs_sum - probs_sort > p - ) # 创建掩码,对于每个位置,计算累积概率(不包括当前词)是否超过阈值 p。 - probs_sort[mask] = 0.0 # 将累积概率超过阈值 p 的词的概率置零。 + mask = (probs_sum - probs_sort > p) + probs_sort[mask] = 0.0 - # 对剩余的概率重新归一化, 确保总和为 1。 probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - # 从重新归一化的概率分布中采样下一个词. 返回的 next_token 是采样得到的词在排序后概率分布中的索引。 next_token_sorted_idx = torch.multinomial(probs_sort, num_samples=1) - # 在 probs_idx 的最后一维(dim=-1)中,使用 next_token_sorted_idx 作为索引,提取对应的值。沿着 dim=1(列)进行索引提取 - # NOTE: torch.gather 函数按照给定的索引张量 index,从输入张量中收集 (获取) 数据,并返回一个与索引张量形状一致的张量。 next_token = torch.gather(probs_idx, -1, index=next_token_sorted_idx) - return next_token # 返回采样得到的下一个词的索引 + return next_token class GenerateStreamText: """ - GenerateText 类用于加载LLaMA模型并执行迭代式生成式推理 (文本生成)。 + 支持量化的GenerateStreamText类 """ def __init__( @@ -66,20 +52,28 @@ def __init__( max_seq_len=1024, compiled_model=False, device="cuda", + quantization: Optional[str] = None, # 新增参数 ): self.checkpoints_dir = checkpoints_dir + self.quantization = quantization # 存储量化类型 + # 创建ModelExecutor时传入量化参数 self.model_executor = ModelExecutor.build( checkpoints_dir=checkpoints_dir, max_gpu_num_blocks=max_gpu_num_blocks, max_seq_len=max_seq_len, compiled_model=compiled_model, device=device, + quantization=quantization, # 新增参数 ) self.tokenizer = self.load_tokenizer(tokenizer_path) self.model_config = self.model_executor.model_config self.device = device + # 记录量化信息 + if self.quantization and self.quantization != QuantizationType.NONE: + logger.info(f"使用量化推理: {self.quantization}") + def load_tokenizer(self, pretrained_model_name_or_path): model_name = get_model_name_from_path(pretrained_model_name_or_path) @@ -106,22 +100,9 @@ def generate_stream( ) -> Generator[tuple[list[str], Optional[list[float]]], None, None]: """ 基于提供的 prompt_tokens, 使用语言生成模型逐个生成 token, 并在生成时立即输出。 - - 参数: - prompt_tokens (list[list[int]]): 已经进行分词的 prompt, 每个 prompt 是一个整数列表。 - max_gen_len (int): 生成的最大长度。 - temperature (float, optional): 控制采样随机性的温度值。默认为 0.6。 - top_p (float, optional): 用于 nucleus sampling 的概率阈值。默认为 0.9。 - logprobs (bool, optional): 是否计算生成 token 的对数概率。默认为 False。 - echo (bool, optional): 是否在输出中包含 prompt_tokens。默认为 False。 - - generator 输出: - tuple[list[str], Optional[list[float]]]: 包含生成的文本和对应的对数概率(如果 logprobs 为 True)。 - 说明: - 该方法在生成循环中,每生成一个新 token, 就立即输出对应的文本和概率(如果需要)。 + 支持量化模型推理。 """ bsz = len(prompt_tokens) - # min_prompt_len = min(len(t) for t in prompt_tokens) max_prompt_len = max(len(t) for t in prompt_tokens) assert max_prompt_len <= self.model_config.max_seq_len total_len = min(self.model_config.max_seq_len, max_gen_len + max_prompt_len) @@ -141,7 +122,7 @@ def generate_stream( prev_pos = 0 last_yielded_pos = [ len(prompt_tokens[i]) if not echo else 0 for i in range(bsz) - ] # 初始化每个样本已输出的位置 + ] # 填充提示词到 tokens 张量 for k, t in enumerate(prompt_tokens): @@ -164,24 +145,19 @@ def generate_stream( .repeat(batch_size, 1) # shape: [batch_size, seq_len], 不分配额外内存 ) + # 使用量化模型进行前向推理 logits = self.model_executor.forward(input_ids, position_ids) decode_select_index = self.model_executor.decode_alloc_kv_cache(bsz) all_select_index_list.append(decode_select_index) if temperature > 0: - # NOTE: logits[:, -1] 表示选择的是最后一个位置(seq_len 维度的最后一项)对应的 logits。 - # NOTE: 在生成模型中的 prefill 阶段,我们只关心当前生成的最后一个 token 的分布。 probs = softmax_split(logits[:, -1] / temperature) - # NOTE: 使用核采样方法,从高概率的候选 token 中选择下一个 token 索引. 
top_p 控制采样范围(候选 token 的概率累积值)。 next_token = sample_top_p(probs, top_p) else: next_token = torch.argmax(logits[:, -1], dim=-1) input_ids = next_token # [batch_size, 1] - # 仅在需要生成的情况下替换 token - # NOTE: input_text_mask[:, cur_pos]:获取掩码中当前列的布尔值,表示每个序列在当前位置是否为实际输入词元。 - # NOTE: tokens[:, cur_pos]:获取 tokens 中当前列的值。next_token:包含当前生成的词元 ID。 mask = ~input_text_mask[:, cur_pos] # [batch_size] tokens[:, cur_pos] = torch.where( mask, next_token.reshape(-1), tokens[:, cur_pos] @@ -192,12 +168,6 @@ def generate_stream( ) prev_pos = cur_pos - # eos_reached 是一个布尔张量,记录每个序列是否到达了终止状态, 形状为 [batch_size, 1]。 - # NOTE: ~input_text_mask[:, cur_pos] 标记当前生成位置是否是模型生成的部分(非输入部分)。True 表示当前列是待生成的部分。False 表示当前列是输入部分。 - # NOTE: next_token == self.tokenizer.eos_token_id 表示检测当前生成的 next_token 是否等于 eos_token_id,即模型生成了终止标记。 - # NOTE: & 表示按位与操作,确保当前位置是非输入部分且生成了终止标记。 - # NOTE: 使用 |= 按位或更新,表示如果某个序列已经到达 eos_token_id,则保持 True 状态,不会被后续重置为 False。 - # 为整个批次收集输出 batch_outputs = [] for i in range(bsz): @@ -207,11 +177,11 @@ def generate_stream( token = tokens[i, start:end].tolist() text = self.tokenizer.decode( token, skip_special_tokens=True - ) # 解码时跳过特殊标记。 + ) batch_outputs.append(text) last_yielded_pos[i] = end else: - batch_outputs.append("") # 如果没有新生成的内容,添加空字符串 + batch_outputs.append("") # 将整个批次的输出一次性 yield yield batch_outputs @@ -251,4 +221,4 @@ def text_completion_stream( for batch_outputs in stream: for i, text in enumerate(batch_outputs): completions[i]["generation"] += text - yield completions.copy() + yield completions.copy() \ No newline at end of file diff --git a/lite_llama/kernels/awq_linear.py b/lite_llama/kernels/awq_linear.py new file mode 100644 index 0000000..7a1363b --- /dev/null +++ b/lite_llama/kernels/awq_linear.py @@ -0,0 +1,290 @@ +import torch +import triton +import triton.language as tl +from typing import Optional +import psutil, os, sys + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +from lite_llama.quantization.utils import pack_weight + + +@triton.jit +def awq_linear_kernel( + input_ptr, qweight_ptr, qscales_ptr, qzeros_ptr, output_ptr, bias_ptr, + M, N, K, group_size, + stride_input_m, stride_input_k, + stride_qweight_n, stride_qweight_k, + stride_qscales_n, stride_qscales_g, + stride_qzeros_n, stride_qzeros_g, + stride_output_m, stride_output_n, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_SIZE: tl.constexpr, HAS_BIAS: tl.constexpr, +): + """Ultra-simplified AWQ linear kernel""" + + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # Only process one element per thread for maximum compatibility + m_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + n_idx = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + m_mask = m_idx < M + n_mask = n_idx < N + + # Initialize accumulator + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + # Simple loop over all K without chunking + for k in range(K): + # Load input values [BLOCK_M] + input_ptrs = input_ptr + m_idx * stride_input_m + k * stride_input_k + input_vals = tl.load(input_ptrs, mask=m_mask, other=0.0) + + # Load and process weights [BLOCK_N] + packed_k = k // 2 + is_high = k % 2 + + qweight_ptrs = qweight_ptr + n_idx * stride_qweight_n + packed_k * stride_qweight_k + packed_weights = tl.load(qweight_ptrs, mask=n_mask, other=0) + + # Unpack 4-bit weights + if is_high == 1: + weights_int4 = (packed_weights >> 4) & 0xF + else: + weights_int4 = packed_weights & 0xF + + # Get quantization parameters + group_idx = k // GROUP_SIZE + + qscales_ptrs = qscales_ptr + n_idx * stride_qscales_n + 
group_idx * stride_qscales_g + qzeros_ptrs = qzeros_ptr + n_idx * stride_qzeros_n + group_idx * stride_qzeros_g + + scales = tl.load(qscales_ptrs, mask=n_mask, other=1.0) + zeros = tl.load(qzeros_ptrs, mask=n_mask, other=0.0) + + # Dequantize + weights_fp = (weights_int4.to(tl.float32) - zeros) * scales + + # Accumulate outer product: input[m] * weight[n] -> acc[m, n] + acc += input_vals[:, None] * weights_fp[None, :] + + # Add bias + if HAS_BIAS: + bias_ptrs = bias_ptr + n_idx + bias_vals = tl.load(bias_ptrs, mask=n_mask, other=0.0) + acc += bias_vals[None, :] + + # Store result + output_ptrs = output_ptr + m_idx[:, None] * stride_output_m + n_idx[None, :] * stride_output_n + tl.store(output_ptrs, acc.to(tl.float16), mask=m_mask[:, None] & n_mask[None, :]) + + +def awq_linear_triton( + input: torch.Tensor, + qweight: torch.Tensor, + qscales: torch.Tensor, + qzeros: torch.Tensor, + bias: Optional[torch.Tensor] = None, + group_size: int = 128 +) -> torch.Tensor: + """ + AWQ quantized linear layer using Triton + + Args: + input: Input tensor [*, in_features] in fp16 + qweight: Packed quantized weights [out_features, in_features//2] in int8 + qscales: Quantization scales [out_features, in_features//group_size] + qzeros: Quantization zeros [out_features, in_features//group_size] + bias: Optional bias [out_features] + group_size: Group size for quantization + + Returns: + Output tensor [*, out_features] in fp16 + """ + + # Reshape input to 2D + input_shape = input.shape + input_2d = input.view(-1, input.shape[-1]) + M, K = input_2d.shape + N = qweight.shape[0] + + # Ensure input is fp16 + if input_2d.dtype != torch.float16: + input_2d = input_2d.to(torch.float16) + + # Create output tensor + output = torch.empty((M, N), dtype=torch.float16, device=input.device) + + # Block sizes - smaller for better compatibility + BLOCK_M = 16 + BLOCK_N = 16 + BLOCK_K = 16 + + # Grid configuration + grid = ( + triton.cdiv(M, BLOCK_M), + triton.cdiv(N, BLOCK_N), + ) + + # Launch kernel + awq_linear_kernel[grid]( + input_2d, qweight, qscales, qzeros, output, bias, + M, N, K, group_size, + # Input strides + input_2d.stride(0), input_2d.stride(1), + # QWeight strides + qweight.stride(0), qweight.stride(1), + # QScales strides + qscales.stride(0), qscales.stride(1), + # QZeros strides + qzeros.stride(0), qzeros.stride(1), + # Output strides + output.stride(0), output.stride(1), + # Block sizes + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + GROUP_SIZE=group_size, + HAS_BIAS=bias is not None, + ) + + # Reshape output back to original shape + output_shape = input_shape[:-1] + (N,) + return output.view(output_shape) + + +class AWQLinear(torch.nn.Module): + """ + AWQ Quantized Linear Layer using Triton + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + group_size: int = 128, + wbits: int = 4, + ): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.group_size = group_size + self.wbits = wbits + + # Calculate number of groups + self.num_groups = (in_features + group_size - 1) // group_size + + # Register quantized weight parameters + # Packed weights: 2 int4 values per int8 + packed_width = (in_features + 1) // 2 + self.register_buffer('qweight', torch.zeros((out_features, packed_width), dtype=torch.uint8)) + self.register_buffer('qscales', torch.zeros((out_features, self.num_groups), dtype=torch.float16)) + self.register_buffer('qzeros', torch.zeros((out_features, self.num_groups), dtype=torch.float16)) + + if bias: + 
self.register_parameter('bias', torch.nn.Parameter(torch.zeros(out_features, dtype=torch.float16))) + else: + self.register_buffer('bias', None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass using Triton kernel""" + return awq_linear_triton( + input=x, + qweight=self.qweight, + qscales=self.qscales, + qzeros=self.qzeros, + bias=self.bias, + group_size=self.group_size + ) + + @classmethod + def from_float( + cls, + linear: torch.nn.Linear, + qweight: torch.Tensor, + qscales: torch.Tensor, + qzeros: torch.Tensor, + group_size: int = 128, + ): + """ + Create AWQLinear from a regular Linear layer and quantization parameters + + Args: + linear: Original torch.nn.Linear layer + qweight: Packed quantized weights + qscales: Quantization scales + qzeros: Quantization zeros + group_size: Group size used for quantization + """ + + awq_linear = cls( + in_features=linear.in_features, + out_features=linear.out_features, + bias=linear.bias is not None, + group_size=group_size, + ) + + # Copy quantized parameters + with torch.no_grad(): + awq_linear.qweight.copy_(qweight) + awq_linear.qscales.copy_(qscales) + awq_linear.qzeros.copy_(qzeros) + + if linear.bias is not None: + awq_linear.bias.copy_(linear.bias.to(torch.float16)) + + return awq_linear + + +def demo_awq_triton(): + """Demo function for AWQ Triton linear layer""" + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device.type == "cpu": + print("CUDA not available, demo will run on CPU (Triton requires CUDA)") + return + + # Create test data + batch_size, seq_len = 2, 128 + in_features, out_features = 768, 768 + group_size = 128 + + # Create original linear layer + linear = torch.nn.Linear(in_features, out_features, bias=True).to(device) + + + # Mock quantized weights (in practice, these come from AWQ.quantize()) + weight_int4 = torch.randint(0, 16, (out_features, in_features), dtype=torch.uint8, device=device) + qweight = pack_weight(weight_int4) + + num_groups = (in_features + group_size - 1) // group_size + qscales = torch.randn((out_features, num_groups), dtype=torch.float16, device=device).abs() + 0.1 + qzeros = torch.randint(0, 16, (out_features, num_groups), dtype=torch.float16, device=device) + + # Create AWQ linear layer + awq_linear = AWQLinear.from_float(linear, qweight, qscales, qzeros, group_size) + awq_linear = awq_linear.to(device) + + # Test input + x = torch.randn(batch_size, seq_len, in_features, dtype=torch.float16, device=device) + + print(f"Input shape: {x.shape}") + print(f"QWeight shape: {qweight.shape}") + print(f"QScales shape: {qscales.shape}") + print(f"QZeros shape: {qzeros.shape}") + + # Forward pass + with torch.no_grad(): + output = awq_linear(x) + + print(f"Output shape: {output.shape}") + print(f"Output dtype: {output.dtype}") + print("AWQ Triton linear demo completed successfully!") + + +if __name__ == "__main__": + demo_awq_triton() \ No newline at end of file diff --git a/lite_llama/kernels/gptq_linear.py b/lite_llama/kernels/gptq_linear.py new file mode 100644 index 0000000..e23ecef --- /dev/null +++ b/lite_llama/kernels/gptq_linear.py @@ -0,0 +1,302 @@ +import triton +import triton.language as tl +import torch +import torch.nn as nn +import numpy as np +import sys, os + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +from lite_llama.quantization.utils import pack_weight + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + 
"GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + ], + key=["M", "N", "K"], +) +@triton.jit +def int4_gemm_kernel( + a_ptr, b_ptr, c_ptr, + bscales_ptr, bzeros_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + GROUP_SIZE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_mask = offs_am[:, None] < M + b_mask = offs_bn[None, :] < N + + a_ptrs = a_ptr + stride_am * offs_am[:, None] + stride_ak * offs_k[None, :] + b_ptrs = b_ptr + stride_bn * offs_bn[None, :] + stride_bk * (offs_k[:, None] // 2) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float16) + + for k in range(0, K, BLOCK_SIZE_K): + b_q = tl.load(b_ptrs, mask=b_mask) + a = tl.load(a_ptrs, mask=a_mask).to(tl.float16) + + # Compute per-group index + k_offset = k + offs_k + group_idx = k_offset // GROUP_SIZE + + # Load scale and zero for each [N, G] + scale = tl.load(bscales_ptr + stride_bn * offs_bn[None, :] + group_idx[:, None]).to(tl.float16) + zero = tl.load(bzeros_ptr + stride_bn * offs_bn[None, :] + group_idx[:, None]).to(tl.float16) + + # Extract int4 values from uint8 + shift = (k_offset[:, None] % 2) * 4 + q = (b_q.to(tl.uint8) >> shift) & 0xF + b_deq = (q.to(tl.float16) - zero) * scale + + accumulator += tl.dot(a, b_deq, out_dtype=tl.float16) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def triton_int4_gemm( + inp: torch.Tensor, + weight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + group_size: int = 64 +) -> torch.Tensor: + weight = weight.t().contiguous() + c_shape = inp.shape[:-1] + weight.shape[-1:] + inp = inp.view(-1, inp.shape[-1]).contiguous() + + PAD_TO = 256 + if inp.shape[0] % PAD_TO != 0: + c_crop = inp.shape[0] + new_inp = inp.new_zeros(((inp.shape[0] + PAD_TO - 1) // PAD_TO * PAD_TO, inp.shape[1])) + new_inp[:c_crop] = inp + inp = new_inp + else: + c_crop = None + + M, K = inp.shape + N = weight.shape[1] + + c = torch.empty((M, N), device=inp.device, dtype=torch.float16) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + 
int4_gemm_kernel[grid]( + inp, weight, c, + scales, zeros, + M, N, K, + inp.stride(0), inp.stride(1), + weight.stride(0), weight.stride(1), + c.stride(0), c.stride(1), + GROUP_SIZE=group_size, + ) + + return c[:c_crop] if c_crop is not None else c.view(c_shape) + + +class GPTQLinear(nn.Module): + """ + 4-bit quantized linear layer using Triton kernels (修复版本) + """ + + def __init__(self, in_features, out_features, bias=True, dtype=torch.float16, bits=4, groupsize=64, device="cuda", + tile_cols=None): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.groupsize = groupsize + self.device = device + self.dtype = dtype + self.bits = bits + self.tile_cols = groupsize + self.original_out_features = out_features + + # 计算量化参数的形状 + self.num_groups = (in_features + groupsize - 1) // groupsize + packed_width = (in_features + 1) // 2 # 2个int4打包成1个uint8 + + # 注册量化参数缓冲区 - 修复:确保所有缓冲区都被正确初始化 + self.register_buffer("packed_weight", torch.zeros(out_features, packed_width, dtype=torch.uint8)) + self.register_buffer("scales", torch.ones(out_features, self.num_groups, dtype=torch.float16)) + self.register_buffer("zeros", torch.zeros(out_features, self.num_groups, dtype=torch.float16)) + + # Bias参数 + if bias: + self.register_parameter("bias", nn.Parameter(torch.zeros(out_features, dtype=torch.float16))) + else: + self.register_buffer("bias", None) + + + def set_quantized_params(self, packed_weight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor): + """设置量化参数 - 新增方法""" + with torch.no_grad(): + if packed_weight is not None: + self.packed_weight.copy_(packed_weight) + if scales is not None: + self.scales.copy_(scales) + if zeros is not None: + self.zeros.copy_(zeros) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_flat = x.view(-1, self.in_features) + + # 确保所有参数都已正确设置 + if self.packed_weight is None or self.scales is None or self.zeros is None: + raise RuntimeError("Quantized parameters not properly initialized. 
Call set_quantized_params() first.") + + # 使用Triton优化的int4 GEMM + output = triton_int4_gemm( + x_flat.float(), + self.packed_weight, + self.scales, + self.zeros, + group_size=self.groupsize, + ) + + if self.bias is not None: + output += self.bias + + return output.view(*x.shape[:-1], self.out_features) + + @classmethod + def from_linear(cls, linear_layer: nn.Linear, groupsize: int = 128, bits: int = 4): + """从标准线性层创建GPTQ层""" + gptq_layer = cls( + in_features=linear_layer.in_features, + out_features=linear_layer.out_features, + bias=linear_layer.bias is not None, + groupsize=groupsize, + bits=bits, + device=linear_layer.weight.device + ) + + # 复制bias + if linear_layer.bias is not None: + gptq_layer.bias.data.copy_(linear_layer.bias.data) + + return gptq_layer + + def __repr__(self): + return f"GPTQLinear(in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}, bits={self.bits}, groupsize={self.groupsize})" + + +def test_gptqlinear_vs_nnlinear( + in_features=2048, + out_features=4096, + groupsize=64, + wbits=4, + device="cuda" +): + torch.manual_seed(42) + np.random.seed(42) + + # ---- Create single input vector ---- + x = torch.randn(in_features, device=device, dtype=torch.float16) + linear = nn.Linear(in_features, out_features, bias=True, device=device, dtype=torch.float16).eval() + + weight = linear.weight.detach().to(device).float() + bias = linear.bias.detach().to(device).float() if linear.bias is not None else None + + # --- 创建GPTQ层并模拟量化参数 --- + gptqlinear = GPTQLinear(in_features, out_features, bias=True, groupsize=groupsize, device=device).to(device) + + # 模拟量化参数(实际使用中这些来自GPTQ量化算法) + num_groups = (in_features + groupsize - 1) // groupsize + packed_width = (in_features + 1) // 2 + + # 创建模拟的量化参数 + mock_packed_weight = torch.randint(0, 255, (out_features, packed_width), dtype=torch.uint8, device=device) + mock_scales = torch.randn(out_features, num_groups, dtype=torch.float16, device=device).abs() + 0.1 + mock_zeros = torch.randint(0, 15, (out_features, num_groups), dtype=torch.float16, device=device) + + # 设置量化参数 + gptqlinear.set_quantized_params(mock_packed_weight, mock_scales, mock_zeros) + + if bias is not None: + gptqlinear.bias.data.copy_(bias) + + gptqlinear.eval() + + print("\n== Latency ==") + time_fp = triton.testing.do_bench(lambda: linear(x)) + time_q = triton.testing.do_bench(lambda: gptqlinear(x)) + print(f"nn.Linear (fp16): {time_fp:.3f} ms") + print(f"GPTQLinear: {time_q:.3f} ms") + + # 测试输出形状 + a = linear(x) + b = gptqlinear(x) + print(f"Output shapes - Linear: {a.shape}, GPTQ: {b.shape}") + print("GPTQ layer test completed successfully!") + + +if __name__ == "__main__": + test_gptqlinear_vs_nnlinear() \ No newline at end of file diff --git a/lite_llama/kernels/sq_linear.py b/lite_llama/kernels/sq_linear.py new file mode 100644 index 0000000..a9ef948 --- /dev/null +++ b/lite_llama/kernels/sq_linear.py @@ -0,0 +1,297 @@ +import torch +import triton +import triton.language as tl +from typing import Optional +import math + + +def pack_weight(weight): + """Pack two INT4 values into one UINT8 value""" + rows, cols = weight.shape + if cols % 2 != 0: + weight = torch.nn.functional.pad(weight, (0, 1), value=0) + cols += 1 + + # Ensure weights are in INT4 range [-8, 7] + weight = torch.clamp(weight, min=-8, max=7) + + # Convert to unsigned representation [0, 15] + weight_unsigned = weight + 8 + + # Pack two INT4 values into one UINT8 + packed = (weight_unsigned[:, 0::2] & 0xF) | ((weight_unsigned[:, 1::2] & 0xF) << 4) + return 
packed.contiguous().to(torch.uint8) + + +def unpack_weight(packed_weight, original_cols): + """Unpack UINT8 values back to two INT4 values""" + rows, packed_cols = packed_weight.shape + unpacked = torch.zeros((rows, packed_cols * 2), dtype=torch.uint8, device=packed_weight.device) + unpacked[:, 0::2] = packed_weight & 0xF + unpacked[:, 1::2] = (packed_weight >> 4) & 0xF + # Convert back to signed INT4 range [-8, 7] + unpacked = unpacked.to(torch.int8) - 8 + return unpacked[:, :original_cols].contiguous() + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=4), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def smoothquant_int4_kernel( + # Pointers to matrices + x_ptr, w_ptr, bias_ptr, output_ptr, + # Quantization parameters + scale_ptr, zp_ptr, smooth_ptr, + # Matrix dimensions + M, N, K, + # Strides + stride_xm, stride_xk, + stride_wn, stride_wk, + stride_om, stride_on, + stride_sm, stride_sn, + stride_zpm, stride_zpn, + # Group size for quantization + group_size, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Optimized SmoothQuant linear kernel with INT4 weights and FP16 activations. + """ + # Program ID + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # Block offsets + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Initialize accumulator + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float16) + + # Main loop over K dimension + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Current k indices + k_indices = k * BLOCK_SIZE_K + offs_k + + # Load input activations + x_ptrs = x_ptr + offs_am[:, None] * stride_xm + k_indices[None, :] * stride_xk + x_mask = (offs_am[:, None] < M) & (k_indices[None, :] < K) + x = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float16) + + # Apply SmoothQuant inverse smoothing + if smooth_ptr is not None: + smooth_ptrs = smooth_ptr + k_indices + smooth_mask = k_indices < K + smooth_scale = tl.load(smooth_ptrs, mask=smooth_mask, other=1.0).to(tl.float16) + # FIX: Ensure the division result stays as fp16 + x = (x / smooth_scale[None, :]).to(tl.float16) + + # Load packed weights (each packed value contains 2 INT4 weights) + k_pack = k_indices // 2 # Packed dimension + w_pack_ptrs = w_ptr + offs_bn[None, :] * stride_wn + k_pack[:, None] * stride_wk + w_packed = tl.load(w_pack_ptrs, mask=(k_pack[:, None] < tl.cdiv(K, 2)) & (offs_bn[None, :] < N), other=0) + + # Unpack INT4 weights + # For even k_indices, use lower 4 bits; for odd k_indices, use upper 4 bits + even_mask = (k_indices % 2 == 0) + w_vals = 
tl.where(even_mask[:, None], w_packed & 0xF, (w_packed >> 4) & 0xF) + w_vals = w_vals.to(tl.int8) - 8 # Convert to signed INT4 range [-8, 7] + + # Load quantization parameters (group-wise) + group_idx = k_indices // group_size + scale_ptrs = scale_ptr + group_idx[:, None] * stride_sm + offs_bn[None, :] * stride_sn + zp_ptrs = zp_ptr + group_idx[:, None] * stride_zpm + offs_bn[None, :] * stride_zpn + + scale_mask = (group_idx[:, None] < tl.cdiv(K, group_size)) & (offs_bn[None, :] < N) + scale = tl.load(scale_ptrs, mask=scale_mask, other=1.0).to(tl.float16) + zp = tl.load(zp_ptrs, mask=scale_mask, other=0.0).to(tl.float16) + + # Dequantize weights: (w_vals - zero_point) * scale + w_vals = (w_vals.to(tl.float16) - zp) * scale + + # Matrix multiplication + # Ensure we only process valid k values + valid_mask = k_indices[:, None] < K + # FIX: Use fp16 zero value to prevent dtype promotion + w_vals = tl.where(valid_mask, w_vals, tl.zeros_like(w_vals)) + + accumulator += tl.dot(x, w_vals, out_dtype=tl.float16) + + # Convert to output precision + c = accumulator.to(tl.float16) + + # Add bias if provided + if bias_ptr is not None: + bias_ptrs = bias_ptr + offs_bn + bias_mask = offs_bn < N + bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0).to(tl.float16) + c += bias[None, :] + + # Store output + output_ptrs = output_ptr + offs_am[:, None] * stride_om + offs_bn[None, :] * stride_on + output_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(output_ptrs, c, mask=output_mask) + + +class SmoothQuantLinear(torch.nn.Module): + """ + PyTorch module wrapper for SmoothQuant INT4 linear layer. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + group_size: int = 128, + alpha: float = 0.5 + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.group_size = group_size + self.alpha = alpha + + # Ensure in_features is compatible with packing + self.packed_in_features = (in_features + 1) // 2 # For packing 2 INT4 into 1 UINT8 + + # Initialize quantized weight storage (packed) + self.register_buffer('packed_weight', + torch.zeros(out_features, self.packed_in_features, dtype=torch.uint8)) + + # Quantization parameters (group-wise) + self.num_groups = (in_features + group_size - 1) // group_size + self.register_buffer('weight_scale', + torch.ones(self.num_groups, out_features, dtype=torch.float16)) + self.register_buffer('weight_zero', + torch.zeros(self.num_groups, out_features, dtype=torch.float16)) + + # SmoothQuant scaling factor + self.register_buffer('smooth_scale', + torch.ones(in_features, dtype=torch.float16)) + + # Bias + if bias: + self.register_buffer('bias', torch.zeros(out_features, dtype=torch.float16)) + else: + self.register_buffer('bias', None) + + def quantize_weight(self, weight: torch.Tensor, act_scales: Optional[torch.Tensor] = None): + """ + Quantize FP16/FP32 weight to INT4 with SmoothQuant. 
+ """ + assert weight.shape == (self.out_features, self.in_features) + + # Compute smoothing scale + if act_scales is not None: + # SmoothQuant formula: s_j = (max|X_j|)^α / (max|W_j|)^(1-α) + weight_scales = weight.abs().max(dim=0)[0] + # Avoid division by zero + weight_scales = torch.clamp(weight_scales, min=1e-5) + act_scales = torch.clamp(act_scales, min=1e-5) + + smooth_scale = (act_scales.pow(self.alpha) / + weight_scales.pow(1 - self.alpha)) + smooth_scale = torch.clamp(smooth_scale, min=0.01, max=100.0) + self.smooth_scale.copy_(smooth_scale.to(torch.float16)) + + # Apply smoothing to weights + weight = weight * self.smooth_scale.unsqueeze(0) + + # Group-wise quantization + weight_groups = [] + scales = [] + zeros = [] + + for i in range(self.num_groups): + start_idx = i * self.group_size + end_idx = min((i + 1) * self.group_size, self.in_features) + + # Extract group + w_group = weight[:, start_idx:end_idx] + + # Compute scale and zero point for this group + w_max = w_group.max(dim=1, keepdim=True)[0] + w_min = w_group.min(dim=1, keepdim=True)[0] + + # Symmetric quantization for INT4 [-8, 7] + scale = (w_max - w_min) / 15.0 + scale = torch.clamp(scale, min=1e-5) + zero = torch.round((w_max + w_min) / 2.0 / scale) * scale + + # Quantize to INT4 range + w_quant = torch.round((w_group - zero) / scale).clamp(-8, 7) + + weight_groups.append(w_quant) + scales.append(scale.squeeze(1)) # [out_features] + zeros.append((zero / scale).squeeze(1)) # [out_features] + + # Concatenate groups back + weight_quantized = torch.cat(weight_groups, dim=1).to(torch.int8) + + # Store quantization parameters + self.weight_scale.copy_(torch.stack(scales, dim=0).to(torch.float16)) + self.weight_zero.copy_(torch.stack(zeros, dim=0).to(torch.float16)) + + # Pack weights + self.packed_weight.copy_(pack_weight(weight_quantized)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass with INT4 quantized weights and FP16 activations. 
+ """ + assert x.shape[-1] == self.in_features + assert x.dtype == torch.float16, "Input must be FP16" + + # Flatten input for matrix multiplication + x_shape = x.shape + x = x.view(-1, self.in_features) + M, K = x.shape + N = self.out_features + + # Allocate output + output = torch.empty(M, N, dtype=torch.float16, device=x.device) + + # Launch kernel + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + + smoothquant_int4_kernel[grid]( + x, self.packed_weight, self.bias, output, + self.weight_scale, self.weight_zero, self.smooth_scale, + M, N, K, + x.stride(0), x.stride(1), + self.packed_weight.stride(0), self.packed_weight.stride(1), + output.stride(0), output.stride(1), + self.weight_scale.stride(0), self.weight_scale.stride(1), + self.weight_zero.stride(0), self.weight_zero.stride(1), + self.group_size + ) + + return output.view(*x_shape[:-1], N) + + + + diff --git a/lite_llama/llava_generate_stream.py b/lite_llama/llava_generate_stream.py index 91b664c..ba5f292 100644 --- a/lite_llama/llava_generate_stream.py +++ b/lite_llama/llava_generate_stream.py @@ -9,6 +9,9 @@ from transformers import AutoTokenizer, AutoProcessor +# 新增导入 +from .quantization.quant_manager import QuantizationType + # 设置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -21,85 +24,73 @@ class CompletionPrediction(TypedDict, total=False): def tokenizer_image_token( - prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None + prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None ): """ 处理包含特殊标记 的文本提示, 将其转换为相应的 token 序列,并在 位置插入指定的图像 token 索引。 - - "A cat is sitting on the mat." - [65,32,99,97,116,32000,32,105,115,32,115,105,116,116,105,110,103,32000,32,111,110,32,116,104,101,32,109,97,116,46] - - 参数: - prompt (str): 包含 标记的文本。 - tokenizer: 分词器对象,需支持调用 tokenizer(chunk).input_ids。 - image_token_index (int): 用于替换 标记的图像 token 索引。 - return_tensors (str, optional): 指定返回的张量类型,例如 'pt' 表示 PyTorch 张量。 - - 返回: - list 或 torch.Tensor: 生成的 token 序列。 """ - # 使用正则表达式分割,移除 '' 前的空格,但保留后的空格 prompt_chunks = re.split(r"\s?", prompt) - # 不过滤空片段,以处理多个连续的 '' 标记 token_chunks = [tokenizer(chunk).input_ids for chunk in prompt_chunks] input_ids = [] offset = 0 - # 检查第一个片段是否以 BOS token 开始 if ( - len(token_chunks) > 0 - and len(token_chunks[0]) > 0 - and token_chunks[0][0] == tokenizer.bos_token_id + len(token_chunks) > 0 + and len(token_chunks[0]) > 0 + and token_chunks[0][0] == tokenizer.bos_token_id ): offset = 1 input_ids.append(token_chunks[0][0]) - # 插入图像 token for i, chunk in enumerate(token_chunks): - input_ids.extend( - chunk[offset:] - ) # 添加当前片段的 token,跳过 BOS token(如果已添加) - offset = 0 # 仅适用于第一个片段 - if i < len(token_chunks) - 1: # 如果不是最后一个片段,插入图像 token + input_ids.extend(chunk[offset:]) + offset = 0 + if i < len(token_chunks) - 1: input_ids.append(image_token_index) if return_tensors is not None: if return_tensors == "pt": return torch.tensor(input_ids, dtype=torch.long) raise ValueError(f"Unsupported tensor type: {return_tensors}") - """ - [1, 3148, 1001, 29901, 32000, 1, 29871, 13, 5618, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 319, 1799, 9047, 13566, 29901] - """ + return input_ids class LlavaGeneratorStream: """ - GenerateText 类用于加载LLaMA模型并执行迭代式生成式推理 (文本生成)。 + 支持量化的LlavaGeneratorStream类 """ def __init__( - self, - checkpoints_dir: str, - tokenizer_path: str, - max_gpu_num_blocks=None, - max_seq_len=2048, - compiled_model=False, - device="cuda", + self, + checkpoints_dir: str, + tokenizer_path: str, + 
max_gpu_num_blocks=None, + max_seq_len=2048, + compiled_model=False, + device="cuda", + quantization: Optional[str] = None, # 新增参数 ): self.checkpoints_dir = checkpoints_dir self.compiled_model = compiled_model self.max_seq_len = max_seq_len self.device = device + self.quantization = quantization # 存储量化类型 + # 创建ModelExecutor时传入量化参数 self.model_executor = ModelExecutor.build( checkpoints_dir=checkpoints_dir, max_gpu_num_blocks=max_gpu_num_blocks, max_seq_len=max_seq_len, device=device, + quantization=quantization, # 新增参数 ) self.tokenizer = self.load_tokenizer(tokenizer_path) + # 记录量化信息 + if self.quantization and self.quantization != QuantizationType.NONE: + logger.info(f"使用量化推理 (LLaVA): {self.quantization}") + def load_tokenizer(self, pretrained_model_name_or_path): model_name = get_model_name_from_path(pretrained_model_name_or_path) @@ -143,32 +134,18 @@ def encode_images(self, image_items: list[Union[str, Image.Image]]): @torch.inference_mode() def generate_stream( - self, - prompt_tokens: list[list[int]], - image_tensors: Optional[torch.FloatTensor] = None, - max_gen_len: int = 2048, - temperature: float = 0.6, - top_p: float = 0.9, - echo: bool = False, + self, + prompt_tokens: list[list[int]], + image_tensors: Optional[torch.FloatTensor] = None, + max_gen_len: int = 2048, + temperature: float = 0.6, + top_p: float = 0.9, + echo: bool = False, ) -> Generator[tuple[list[str], Optional[list[float]]], None, None]: """ - 基于提供的 prompt_tokens, 使用语言生成模型逐个生成 token, 并在生成时立即输出。 - - 参数: - prompt_tokens (list[list[int]]): 已经进行分词的 prompt, 每个 prompt 是一个整数列表。 - max_gen_len (int): 生成的最大长度。 - temperature (float, optional): 控制采样随机性的温度值。默认为 0.6。 - top_p (float, optional): 用于 nucleus sampling 的概率阈值。默认为 0.9。 - logprobs (bool, optional): 是否计算生成 token 的对数概率。默认为 False。 - echo (bool, optional): 是否在输出中包含 prompt_tokens。默认为 False。 - - generator 输出: - tuple[list[str], Optional[list[float]]]: 包含生成的文本和对应的对数概率(如果 logprobs 为 True)。 - 说明: - 该方法在生成循环中,每生成一个新 token, 就立即输出对应的文本和概率(如果需要)。 + 基于提供的 prompt_tokens, 使用量化的LLaVA模型逐个生成 token, 并在生成时立即输出。 """ bsz = len(prompt_tokens) - # min_prompt_len = min(len(t) for t in prompt_tokens) max_prompt_len = max(len(t) for t in prompt_tokens) assert max_prompt_len <= self.max_seq_len total_seq_len = min(self.max_seq_len, max_gen_len + max_prompt_len) @@ -185,16 +162,14 @@ def generate_stream( tokens = torch.full( (bsz, total_seq_len), pad_id, dtype=torch.long, device=self.device ) - # 生成一个布尔张量,它的值为 True 的位置表示输入序列的实际内容(即非填充部分), 形状为 (batch_size, total_seq_len) input_text_mask = tokens != pad_id eos_reached = torch.tensor([False] * bsz, device=self.device) last_yielded_pos = [ len(prompt_tokens[i]) if not echo else 0 for i in range(bsz) - ] # 初始化每个样本已输出的位置 + ] # 填充提示词到 tokens 张量 for seq_id, token_ids in enumerate(prompt_tokens): - # NOTE: torch.long 等同于 torch.int64 tokens[seq_id, : len(token_ids)] = ( token_ids.clone().detach().to(dtype=torch.long, device=self.device) ) @@ -213,9 +188,11 @@ def generate_stream( input_ids = tokens[:, :max_prompt_len] # [batch_size, seq_len] for cur_pos in range(max_prompt_len, total_seq_len): batch_size, _ = input_ids.shape + + # 使用量化模型进行前向推理 logits = self.model_executor.forward( input_ids, position_ids, image_tensors - ) # step 0: position_ids 由 llava 模型类给出 + ) start_pos += bsz position_ids = ( @@ -240,7 +217,7 @@ def generate_stream( ) eos_reached = eos_reached | ( - mask & (next_token == self.tokenizer.eos_token_id) + mask & (next_token == self.tokenizer.eos_token_id) ) # 为整个批次收集输出 @@ -254,7 +231,7 @@ def generate_stream( batch_outputs.append(text) 
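+                    # Consumers receive one list of strings per decode step, e.g.
+                    #     for step_texts in generator.generate_stream(...): print(step_texts[0], end="")
+                    # Advancing last_yielded_pos below keeps each yielded chunk incremental
+                    # instead of repeating the already-emitted prefix.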
last_yielded_pos[i] = end else: - batch_outputs.append("") # 如果没有新生成的内容,添加空字符串 + batch_outputs.append("") # 将整个批次的输出一次性 yield yield batch_outputs @@ -267,13 +244,13 @@ def generate_stream( self.model_executor.kv_mem_manager.release_ref(all_select_indexs) def text_completion_stream( - self, - prompts: list[str], - image_items: list[Union[str, Image.Image]], - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: Optional[int] = None, - echo: bool = False, + self, + prompts: list[str], + image_items: list[Union[str, Image.Image]], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + echo: bool = False, ) -> Generator[list[CompletionPrediction], None, None]: """每次迭代时,生成器返回一个包含多个 CompletionPrediction 字典的列表。""" @@ -285,11 +262,8 @@ def text_completion_stream( x, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" ) for x in prompts - ] # torch.Size([1, 22]) - image_tensors = self.encode_images( - image_items - ) # image_tensors shape is torch.Size([1, 3, 336, 336]) - # print(f"prompt 0 shape: {prompt_tokens[0].shape}, image_tensors shape: {image_tensors.shape}") + ] + image_tensors = self.encode_images(image_items) stream = self.generate_stream( prompt_tokens=prompt_tokens, @@ -311,32 +285,14 @@ def text_completion_stream( def sample_top_p(probs, p): """ 执行 Top-p (Nucleus) 采样, 从概率分布中采样下一个词。 - - 参数: - probs (torch.Tensor): 概率分布张量,形状为 `[batch_size, vocab_size]`。 - p (float): 累积概率阈值,取值范围在 0 到 1 之间。 - 返回: - torch.Tensor: 采样得到的词索引,形状为 `[batch_size, 1]`。 - - 说明: - Top-p 采样算法: 选择概率累积和超过阈值 p 的最小集合,将这些词的概率重新归一化后进行采样。 """ - # 对概率分布进行降序排序。probs_sort: 排序后的概率值,形状与 probs 相同。probs_idx: 排序后的索引,用于映射回原始词汇表。 probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - # 计算排序后概率的累积和. 返回的 probs_sum 是累积概率分布。 probs_sum = torch.cumsum(probs_sort, dim=-1) - # 保留累积概率未超过阈值 p 的词汇的概率,其余词汇的概率被置为 0.0。 - mask = ( - probs_sum - probs_sort > p - ) # 创建掩码,对于每个位置,计算累积概率(不包括当前词)是否超过阈值 p。 - probs_sort[mask] = 0.0 # 将累积概率超过阈值 p 的词的概率置零。 + mask = (probs_sum - probs_sort > p) + probs_sort[mask] = 0.0 - # 对剩余的概率重新归一化, 确保总和为 1。 probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - # 从重新归一化的概率分布中采样下一个词. 
返回的 next_token 是采样得到的词在排序后概率分布中的索引。 next_token_sorted_idx = torch.multinomial(probs_sort, num_samples=1) - # 在 probs_idx 的最后一维(dim=-1)中,使用 next_token_sorted_idx 作为索引,提取对应的值。沿着 dim=1(列)进行索引提取 - # NOTE: torch.gather 函数按照给定的索引张量 index,从输入张量中收集 (获取) 数据,并返回一个与索引张量形状一致的张量。 next_token = torch.gather(probs_idx, -1, index=next_token_sorted_idx) - return next_token # 返回采样得到的下一个词的索引 + return next_token \ No newline at end of file diff --git a/lite_llama/models/__init__.py b/lite_llama/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lite_llama/models/quantized_models.py b/lite_llama/models/quantized_models.py new file mode 100644 index 0000000..8139409 --- /dev/null +++ b/lite_llama/models/quantized_models.py @@ -0,0 +1,353 @@ +""" +Quantized Model Builder for lite_llama +Creates quantized versions of supported models +""" +import torch +import torch.nn as nn +from typing import Dict, Any, Optional, Union +import copy + +from .llama import LlamaModel, FusedAttention as LlamaAttention, FusedMLP as LlamaMLP +from .qwen2 import Qwen2Model, Qwen2Attention, FusedMLP as Qwen2MLP +from .qwen3 import Qwen3Model, Qwen3Attention, FusedMLP as Qwen3MLP +from .llava import LlavaLlama +from .model_config import LlamaConfig, Qwen2Config, Qwen3Config + +# Import quantized layers +from lite_llama.kernels.awq_linear import AWQLinear +from lite_llama.kernels.gptq_linear import GPTQLinear +from lite_llama.kernels.sq_linear import SmoothQuantLinear + +from ..quantization.quant_manager import QuantizationType + + +class QuantizedAttentionMixin: + """量化Attention层的Mixin""" + + def replace_linear_with_quantized(self, quantization_method: str, config: Dict[str, Any]): + """替换线性层为量化层""" + + if quantization_method == QuantizationType.GPTQ: + # 替换投影层为GPTQ量化层 + if hasattr(self, 'q_proj'): + self.q_proj = self._create_gptq_linear(self.q_proj, config) + if hasattr(self, 'k_proj'): + self.k_proj = self._create_gptq_linear(self.k_proj, config) + if hasattr(self, 'v_proj'): + self.v_proj = self._create_gptq_linear(self.v_proj, config) + if hasattr(self, 'o_proj'): + self.o_proj = self._create_gptq_linear(self.o_proj, config) + # 处理融合的kv_proj权重 + if hasattr(self, 'kv_proj_weight'): + # 需要特殊处理融合权重 + pass + + elif quantization_method == QuantizationType.AWQ: + # 替换为AWQ量化层 + if hasattr(self, 'q_proj'): + self.q_proj = self._create_awq_linear(self.q_proj, config) + if hasattr(self, 'k_proj'): + self.k_proj = self._create_awq_linear(self.k_proj, config) + if hasattr(self, 'v_proj'): + self.v_proj = self._create_awq_linear(self.v_proj, config) + if hasattr(self, 'o_proj'): + self.o_proj = self._create_awq_linear(self.o_proj, config) + + elif quantization_method == QuantizationType.SMOOTHQUANT: + # 替换为SmoothQuant量化层 + if hasattr(self, 'q_proj'): + self.q_proj = self._create_sq_linear(self.q_proj, config) + if hasattr(self, 'k_proj'): + self.k_proj = self._create_sq_linear(self.k_proj, config) + if hasattr(self, 'v_proj'): + self.v_proj = self._create_sq_linear(self.v_proj, config) + if hasattr(self, 'o_proj'): + self.o_proj = self._create_sq_linear(self.o_proj, config) + + def _create_gptq_linear(self, original_layer: nn.Linear, config: Dict[str, Any]) -> GPTQLinear: + """创建GPTQ量化线性层""" + gptq_layer = GPTQLinear( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + bias=original_layer.bias is not None, + dtype=torch.float16, + bits=config.get('w_bit', 4), + groupsize=config.get('group_size', 128), + device=config.get('device', 'cuda') + ) + return gptq_layer + + def 
_create_awq_linear(self, original_layer: nn.Linear, config: Dict[str, Any]) -> AWQLinear: + """创建AWQ量化线性层""" + awq_layer = AWQLinear( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + bias=original_layer.bias is not None, + group_size=config.get('group_size', 128), + wbits=config.get('w_bit', 4) + ) + return awq_layer + + def _create_sq_linear(self, original_layer: nn.Linear, config: Dict[str, Any]) -> SmoothQuantLinear: + """创建SmoothQuant量化线性层""" + from ..quantization.quant_config import SmoothQuantConfig + sq_config = SmoothQuantConfig(**config) + + sq_layer = SmoothQuantLinear( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + bias=original_layer.bias is not None, + ) + return sq_layer + + +class QuantizedMLPMixin: + """量化MLP层的Mixin""" + + def replace_linear_with_quantized(self, quantization_method: str, config: Dict[str, Any]): + """替换线性层为量化层""" + + if quantization_method == QuantizationType.GPTQ: + self.gate_proj = self._create_gptq_linear(self.gate_proj, config) + self.up_proj = self._create_gptq_linear(self.up_proj, config) + self.down_proj = self._create_gptq_linear(self.down_proj, config) + + elif quantization_method == QuantizationType.AWQ: + self.gate_proj = self._create_awq_linear(self.gate_proj, config) + self.up_proj = self._create_awq_linear(self.up_proj, config) + self.down_proj = self._create_awq_linear(self.down_proj, config) + + elif quantization_method == QuantizationType.SMOOTHQUANT: + self.gate_proj = self._create_sq_linear(self.gate_proj, config) + self.up_proj = self._create_sq_linear(self.up_proj, config) + self.down_proj = self._create_sq_linear(self.down_proj, config) + + def _create_gptq_linear(self, original_layer: nn.Linear, config: Dict[str, Any]) -> GPTQLinear: + """创建GPTQ量化线性层""" + gptq_layer = GPTQLinear( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + bias=original_layer.bias is not None, + dtype=torch.float16, + bits=config.get('w_bit', 4), + groupsize=config.get('group_size', 128), + device=config.get('device', 'cuda') + ) + return gptq_layer + + def _create_awq_linear(self, original_layer: nn.Linear, config: Dict[str, Any]) -> AWQLinear: + """创建AWQ量化线性层""" + awq_layer = AWQLinear( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + bias=original_layer.bias is not None, + group_size=config.get('group_size', 128), + wbits=config.get('w_bit', 4) + ) + return awq_layer + + def _create_sq_linear(self, original_layer: nn.Linear, config: Dict[str, Any]) -> SmoothQuantLinear: + """创建SmoothQuant量化线性层""" + from ..quantization.quant_config import SmoothQuantConfig + + sq_layer = SmoothQuantLinear( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + bias=original_layer.bias is not None, + ) + return sq_layer + + +# 创建量化版本的Attention层 +class QuantizedLlamaAttention(LlamaAttention, QuantizedAttentionMixin): + def __init__(self, config: LlamaConfig, quantization_method: str, quantization_config: Dict[str, Any]): + super().__init__(config) + self.replace_linear_with_quantized(quantization_method, quantization_config) + + +class QuantizedQwen2Attention(Qwen2Attention, QuantizedAttentionMixin): + def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int, + quantization_method: str, quantization_config: Dict[str, Any], dtype=torch.float16): + super().__init__(hidden_size, num_heads, num_kv_heads, dtype) + self.replace_linear_with_quantized(quantization_method, 
quantization_config) + + +class QuantizedQwen3Attention(Qwen3Attention, QuantizedAttentionMixin): + def __init__(self, config: Qwen3Config, quantization_method: str, quantization_config: Dict[str, Any]): + super().__init__(config) + self.replace_linear_with_quantized(quantization_method, quantization_config) + + +# 创建量化版本的MLP层 +class QuantizedLlamaMLP(LlamaMLP, QuantizedMLPMixin): + def __init__(self, config: LlamaConfig, quantization_method: str, quantization_config: Dict[str, Any]): + super().__init__(config) + self.replace_linear_with_quantized(quantization_method, quantization_config) + + +class QuantizedQwen2MLP(Qwen2MLP, QuantizedMLPMixin): + def __init__(self, config: Qwen2Config, quantization_method: str, quantization_config: Dict[str, Any]): + super().__init__(config) + self.replace_linear_with_quantized(quantization_method, quantization_config) + + +class QuantizedQwen3MLP(Qwen3MLP, QuantizedMLPMixin): + def __init__(self, config: Qwen3Config, quantization_method: str, quantization_config: Dict[str, Any]): + super().__init__(config) + self.replace_linear_with_quantized(quantization_method, quantization_config) + + +def create_quantized_model( + model_config: Union[LlamaConfig, Qwen2Config, Qwen3Config], + quantization_method: str, + quantization_config: Dict[str, Any], + device: str = "cuda" +) -> torch.nn.Module: + """创建量化模型""" + + model_type = model_config.model_type.lower() + + if model_type == "llama": + model = create_quantized_llama(model_config, quantization_method, quantization_config, device) + elif model_type == "qwen2": + model = create_quantized_qwen2(model_config, quantization_method, quantization_config, device) + elif model_type == "qwen3": + model = create_quantized_qwen3(model_config, quantization_method, quantization_config, device) + elif model_type == "llava": + model = create_quantized_llava(model_config, quantization_method, quantization_config, device) + else: + raise ValueError(f"不支持的模型类型: {model_type}") + + return model.to(device) + + +def create_quantized_llama( + config: LlamaConfig, + quantization_method: str, + quantization_config: Dict[str, Any], + device: str +) -> LlamaModel: + """创建量化的Llama模型""" + + # 创建基础模型 + model = LlamaModel(config) + + # 替换层为量化版本 + for i, layer in enumerate(model.layers): + # 替换attention + quantized_attention = QuantizedLlamaAttention( + config, quantization_method, quantization_config + ) + + # 复制权重信息(在实际加载时会被覆盖) + layer.self_attn = quantized_attention + + # 替换MLP + quantized_mlp = QuantizedLlamaMLP( + config, quantization_method, quantization_config + ) + layer.mlp = quantized_mlp + + # 替换lm_head如果需要 + if quantization_method in [QuantizationType.GPTQ, QuantizationType.AWQ]: + if quantization_method == QuantizationType.GPTQ: + quantized_lm_head = GPTQLinear( + in_features=model.lm_head.in_features, + out_features=model.lm_head.out_features, + bias=model.lm_head.bias is not None, + dtype=torch.float16, + bits=quantization_config.get('w_bit', 4), + groupsize=quantization_config.get('group_size', 128), + device=device + ) + else: # AWQ + quantized_lm_head = AWQLinear( + in_features=model.lm_head.in_features, + out_features=model.lm_head.out_features, + bias=model.lm_head.bias is not None, + group_size=quantization_config.get('group_size', 128), + wbits=quantization_config.get('w_bit', 4) + ) + + model.lm_head = quantized_lm_head + + return model + + +def create_quantized_qwen2( + config: Qwen2Config, + quantization_method: str, + quantization_config: Dict[str, Any], + device: str +) -> Qwen2Model: + """创建量化的Qwen2模型""" + + # 
创建基础模型 + model = Qwen2Model(config) + + # 替换层为量化版本 + for i, layer in enumerate(model.layers): + # 替换attention + quantized_attention = QuantizedQwen2Attention( + config.hidden_size, config.num_heads, config.num_kv_heads, + quantization_method, quantization_config + ) + layer.self_attn = quantized_attention + + # 替换MLP + quantized_mlp = QuantizedQwen2MLP( + config, quantization_method, quantization_config + ) + layer.mlp = quantized_mlp + + return model + + +def create_quantized_qwen3( + config: Qwen3Config, + quantization_method: str, + quantization_config: Dict[str, Any], + device: str +) -> Qwen3Model: + """创建量化的Qwen3模型""" + + # 创建基础模型 + model = Qwen3Model(config) + + # 替换层为量化版本 + for i, layer in enumerate(model.layers): + # 替换attention + quantized_attention = QuantizedQwen3Attention( + config, quantization_method, quantization_config + ) + layer.self_attn = quantized_attention + + # 替换MLP + quantized_mlp = QuantizedQwen3MLP( + config, quantization_method, quantization_config + ) + layer.mlp = quantized_mlp + + return model + + +def create_quantized_llava( + config: Any, # LlavaConfig + quantization_method: str, + quantization_config: Dict[str, Any], + device: str +) -> LlavaLlama: + """创建量化的LLaVA模型""" + + # 创建基础模型 + model = LlavaLlama(config) + + # 量化language_model部分 + llama_config = model.llama_config + quantized_language_model = create_quantized_llama( + llama_config, quantization_method, quantization_config, device + ) + + model.language_model = quantized_language_model + + return model \ No newline at end of file diff --git a/lite_llama/quantization/__init__.py b/lite_llama/quantization/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/lite_llama/quantization/awq.py b/lite_llama/quantization/awq.py new file mode 100644 index 0000000..a28dad7 --- /dev/null +++ b/lite_llama/quantization/awq.py @@ -0,0 +1,328 @@ +import torch +import torch.nn as nn +import numpy as np +from typing import Dict, Tuple, Optional, Any, List +from tqdm.auto import tqdm +from lite_llama.quantization.utils import pack_weight +from lite_llama.quantization.quant_config import AWQConfig + + +class AWQ: + def __init__(self, config: AWQConfig): + self.config = config + self.wbits = self.config.w_bit + self.groupsize = self.config.group_size if self.config.group_size != -1 else float('inf') + self.device = self.config.device + self.maxq = 2 ** self.wbits - 1 # For 4-bit: 0-15 + self.zero_point = config.zero_point + self.alpha = self.config.alpha + self.search_scale = self.config.search_scale + self.auto_scale = self.config.auto_scale + + # Store activation statistics + self.activation_stats = {} + + def collect_activations(self, layer_name: str, input_tensor: torch.Tensor): + """Collect activation statistics for AWQ calibration""" + if layer_name not in self.activation_stats: + self.activation_stats[layer_name] = { + 'mean': [], + 'max': [], + 'inputs': [] + } + + # Store input activations (limit storage to prevent OOM) + if len(self.activation_stats[layer_name]['inputs']) < 32: + self.activation_stats[layer_name]['inputs'].append(input_tensor.detach().cpu()) + + # Compute per-channel statistics + if input_tensor.dim() == 3: # [batch, seq, hidden] + channel_means = input_tensor.abs().mean(dim=(0, 1)) + channel_maxs = input_tensor.abs().max(dim=1)[0].max(dim=0)[0] + elif input_tensor.dim() == 2: # [batch, hidden] + channel_means = input_tensor.abs().mean(dim=0) + channel_maxs = input_tensor.abs().max(dim=0)[0] + else: + channel_means = input_tensor.abs().view(-1, input_tensor.shape[-1]).mean(dim=0) 
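+            # Fallback for any other rank: flatten all leading dimensions so the
+            # statistics stay per-hidden-channel, matching the 2D/3D branches above.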
+ channel_maxs = input_tensor.abs().view(-1, input_tensor.shape[-1]).max(dim=0)[0] + + self.activation_stats[layer_name]['mean'].append(channel_means.cpu()) + self.activation_stats[layer_name]['max'].append(channel_maxs.cpu()) + + def get_salient_channels(self, layer_name: str, top_k: float = 0.01) -> torch.Tensor: + """Identify salient channels based on activation statistics""" + if layer_name not in self.activation_stats: + return None + + stats = self.activation_stats[layer_name] + if stats['mean']: + mean_activations = torch.stack(stats['mean']).mean(dim=0) + max_activations = torch.stack(stats['max']).mean(dim=0) + + # Combine mean and max for saliency score + saliency_score = mean_activations * 0.7 + max_activations * 0.3 + + # Select top-k% most salient channels + num_salient = max(1, int(len(saliency_score) * top_k)) + _, salient_indices = torch.topk(saliency_score, num_salient) + return salient_indices + + return None + + def search_best_scale(self, layer_name: str, weight: torch.Tensor, input_feat: torch.Tensor) -> torch.Tensor: + """Search for the best per-channel scaling factors""" + device = weight.device + org_out = torch.matmul(input_feat, weight.t()) + + if org_out.abs().max() < 0.01: # Very small activations + return torch.ones(weight.shape[0], device=device, dtype=weight.dtype) + + # Get salient channels + salient_channels = self.get_salient_channels(layer_name) + + best_error = float('inf') + best_scales = torch.ones(weight.shape[0], device=device, dtype=weight.dtype) + + # Grid search for best alpha + alpha_candidates = [0.0, 0.1, 0.25, 0.5, 0.75, 1.0] if self.search_scale else [self.alpha] + + for alpha in alpha_candidates: + # Compute channel-wise scaling factors + if layer_name in self.activation_stats and self.activation_stats[layer_name]['mean']: + stats = self.activation_stats[layer_name] + mean_activations = torch.stack(stats['mean']).mean(dim=0).to(device) + + # AWQ scaling: s_j = (max|X_j|^alpha) / (max|W_j|^(1-alpha)) + weight_max = weight.abs().max(dim=0)[0].clamp(min=1e-5) + act_max = mean_activations.clamp(min=1e-5) + + scales = act_max.pow(alpha) / weight_max.pow(1 - alpha) + scales = scales.clamp(min=0.1, max=10.0) # Prevent extreme values + + # Protect salient channels with higher scale values + if salient_channels is not None: + scales[salient_channels] = scales[salient_channels].clamp(min=0.5) + else: + # Fallback to weight-based scaling + weight_max = weight.abs().max(dim=0)[0] + scales = weight_max.pow(alpha).clamp(min=0.1, max=10.0) + scales = scales / scales.max() # Normalize + + # Test this scaling + weight_scaled = weight * scales.view(-1, 1) + qweight, qzeros, qscales = self.quantize_with_scales(weight_scaled, torch.ones_like(scales)) + + # Dequantize to test reconstruction quality + weight_sim = self.dequantize_weight(qweight, qzeros, qscales, weight.shape[1]) + + # Compute reconstruction error + out_sim = torch.matmul(input_feat, weight_sim.t()) + loss = (org_out - out_sim).float().pow(2).mean().item() + + if loss < best_error: + best_error = loss + best_scales = scales.clone() + + return best_scales + + def quantize_with_scales(self, weight: torch.Tensor, scales: torch.Tensor) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor]: + """quantization with proper scaling""" + device = weight.device + rows, cols = weight.shape + + # Apply per-output-channel scaling + weight_scaled = weight * scales.view(-1, 1) + + # Group-wise quantization + if self.groupsize == float('inf'): + groupsize = cols + else: + groupsize = min(int(self.groupsize), cols) 
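+        # Worked example: in_features = 4096 with group_size = 128 gives 32 groups;
+        # the ceiling division below also covers a ragged final group when the
+        # column count is not an exact multiple of the group size.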
+ + num_groups = (cols + groupsize - 1) // groupsize + + qweight = torch.zeros_like(weight_scaled, dtype=torch.uint8) + qzeros = torch.zeros((rows, num_groups), dtype=torch.float16, device=device) + qscales = torch.zeros((rows, num_groups), dtype=torch.float16, device=device) + + for g in range(num_groups): + start_col = g * groupsize + end_col = min((g + 1) * groupsize, cols) + w_group = weight_scaled[:, start_col:end_col] + + if self.zero_point: + # Asymmetric quantization: map [w_min, w_max] to [0, 2^bits-1] + w_min = w_group.min(dim=1, keepdim=True)[0] + w_max = w_group.max(dim=1, keepdim=True)[0] + + # Compute scale and zero point + range_val = (w_max - w_min).clamp(min=1e-5) + scale = range_val / self.maxq + zero = (-w_min / scale).round().clamp(0, self.maxq) + + # Quantize: q = round((w - w_min) / scale) = round(w/scale + zero) + q = torch.round(w_group / scale + zero).clamp(0, self.maxq) + else: + # Symmetric quantization around zero + w_max = w_group.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-5) + scale = w_max / (self.maxq // 2) # Use half range for signed values + zero = torch.full_like(scale, self.maxq // 2) # Midpoint as zero + + # Quantize: q = round(w / scale) + zero_point + q = torch.round(w_group / scale + zero).clamp(0, self.maxq) + + qweight[:, start_col:end_col] = q.to(torch.uint8) + qscales[:, g] = scale.squeeze(-1) + qzeros[:, g] = zero.squeeze(-1) + + return qweight, qzeros.to(torch.float16), qscales.to(torch.float16) + + def dequantize_weight(self, qweight: torch.Tensor, qzeros: torch.Tensor, + qscales: torch.Tensor, original_cols: int) -> torch.Tensor: + """Dequantize weights for testing""" + rows, _ = qweight.shape + num_groups = qzeros.shape[1] + groupsize = (original_cols + num_groups - 1) // num_groups + + weight = torch.zeros((rows, original_cols), dtype=torch.float16, device=qweight.device) + + for g in range(num_groups): + start_col = g * groupsize + end_col = min((g + 1) * groupsize, original_cols) + + q = qweight[:, start_col:end_col].float() + scale = qscales[:, g:g + 1] + zero = qzeros[:, g:g + 1] + + # Dequantize: w = (q - zero) * scale + weight[:, start_col:end_col] = ((q - zero) * scale).to(torch.float16) + + return weight + + def quantize(self, weight: torch.Tensor, layer_name: str = "") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Main AWQ quantization function with fixes""" + assert weight.ndim == 2 + device = weight.device + + # Get representative input if available + input_feat = None + if layer_name in self.activation_stats and self.activation_stats[layer_name]['inputs']: + inputs = self.activation_stats[layer_name]['inputs'][:3] # Use first few + input_feat = torch.cat([inp.to(device) for inp in inputs], dim=0) + if input_feat.dim() == 3: + input_feat = input_feat.view(-1, input_feat.shape[-1]) + + # Search for best scales if we have input data + if input_feat is not None and self.search_scale: + scales = self.search_best_scale(layer_name, weight, input_feat) + else: + # Use activation statistics for scaling + if self.auto_scale and layer_name in self.activation_stats: + stats = self.activation_stats[layer_name] + if stats['mean']: + mean_activations = torch.stack(stats['mean']).mean(dim=0).to(device) + weight_max = weight.abs().max(dim=0)[0].clamp(min=1e-5) + act_max = mean_activations.clamp(min=1e-5) + + scales = act_max.pow(self.alpha) / weight_max.pow(1 - self.alpha) + scales = scales.clamp(min=0.1, max=10.0) + scales = scales / scales.max() # Normalize + else: + scales = torch.ones(weight.shape[0], device=device, 
dtype=weight.dtype) + else: + scales = torch.ones(weight.shape[0], device=device, dtype=weight.dtype) + + # Quantize with computed scales + qweight, qzeros, qscales = self.quantize_with_scales(weight, scales) + + # Pack weights consistently + packed_qweight = pack_weight(qweight) + + return packed_qweight, qzeros, qscales + + +def quantize_awq( + model_state_dict: Dict[str, torch.Tensor], + calibration_loader: Optional[Any] = None, + model: Optional[torch.nn.Module] = None, + target_layers: Optional[List[str]] = None, + config: AWQConfig = None, + device: str = "cuda" +) -> Dict[str, torch.Tensor]: + """AWQ quantization function""" + + awq = AWQ(config) + quantized_state_dict = {} + + # Default target layers + if target_layers is None: + target_layers = [] + for name in model_state_dict.keys(): + if any(pattern in name for pattern in [ + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", + "kv_proj", "lm_head" + ]): + target_layers.append(name) + + # Collect activation statistics if calibration data is provided + if calibration_loader is not None and model is not None: + print("Collecting activation statistics for AWQ...") + + hooks = [] + + def make_hook(layer_name): + def hook_fn(module, input, output): + if isinstance(input, tuple) and len(input) > 0: + awq.collect_activations(layer_name, input[0]) + else: + awq.collect_activations(layer_name, input) + + return hook_fn + + # Register hooks + for name, module in model.named_modules(): + if name in target_layers and isinstance(module, torch.nn.Linear): + hook = module.register_forward_hook(make_hook(name)) + hooks.append(hook) + + # Run calibration + model.eval() + with torch.no_grad(): + for i, batch in enumerate(tqdm(calibration_loader, desc="Calibrating")): + if i >= 32: # Limit calibration samples + break + try: + if hasattr(model, 'forward'): + _ = model(batch) + except Exception as e: + print(f"Calibration batch {i} failed: {e}") + continue + + # Remove hooks + for hook in hooks: + hook.remove() + + print(f"Quantizing {len(target_layers)} layers to {config.w_bit} bits with AWQ...") + + # Quantize each target layer + for name, param in tqdm(model_state_dict.items(), desc="Quantizing layers"): + if name in target_layers and param.dim() == 2: + weight = param.to(device) + layer_name = name.replace(".weight", "").replace("_weight", "") + + # Quantize using AWQ + qweight, qzeros, qscales = awq.quantize(weight, layer_name) + + # Store quantized parameters + base_name = layer_name + quantized_state_dict[f"{base_name}.qweight"] = qweight.cpu() + quantized_state_dict[f"{base_name}.qzeros"] = qzeros.cpu() + quantized_state_dict[f"{base_name}.qscales"] = qscales.cpu() + else: + # Keep non-quantized parameters + quantized_state_dict[name] = param.cpu() + + print("AWQ quantization completed!") + return quantized_state_dict \ No newline at end of file diff --git a/lite_llama/quantization/gptq.py b/lite_llama/quantization/gptq.py new file mode 100755 index 0000000..d2273bc --- /dev/null +++ b/lite_llama/quantization/gptq.py @@ -0,0 +1,185 @@ +import torch +import torch.nn as nn +import numpy as np +from typing import Dict, Tuple, Optional, Any +from tqdm.auto import tqdm +import time, gc, psutil, os, sys + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))) +from lite_llama.quantization.quant_config import GPTQConfig +from lite_llama.utils.common import get_gpu_memory +from lite_llama.quantization.utils import pack_weight, unpack_weight + + +class GPTQ: + def __init__(self, config: 
GPTQConfig): + self.wbits = config.w_bit + self.groupsize = config.group_size if config.group_size != -1 else float('inf') + self.device = config.device + self.maxq = 2 ** self.wbits - 1 + + def find_params(self, x, weight): + """Standard min-max quantization parameter calculation""" + self.maxq = torch.tensor(2 ** self.wbits - 1) + + shape = weight.shape + if self.groupsize != float('inf'): + groupsize = min(int(self.groupsize), shape[1]) + else: + groupsize = shape[1] + + weight = weight.float() + weight = weight.reshape((-1, groupsize)) + + # Calculate min/max for each group + tmp = torch.zeros(weight.shape[0], device=self.device) + xmin = torch.minimum(weight.min(1)[0], tmp) + xmax = torch.maximum(weight.max(1)[0], tmp) + + # Symmetric quantization around zero + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + # Calculate scale and zero point + scale = (xmax - xmin) / self.maxq + zero = torch.round(-xmin / scale) + + # Clamp zero point to valid range + zero = torch.clamp(zero, 0, self.maxq) + + # Handle edge cases + scale = torch.clamp(scale, min=1e-8) + + return scale.reshape(shape[0], -1), zero.reshape(shape[0], -1) + + def quantize(self, W: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + """ + Improved GPTQ quantization with better numerical stability + """ + assert W.ndim == 2 + rows, cols = W.shape + device = W.device + original_cols = cols + + # Determine groupsize + if self.groupsize == float('inf'): + groupsize = cols + else: + groupsize = min(int(self.groupsize), cols) + + num_groups = (cols + groupsize - 1) // groupsize + + # Initialize output tensors + qweight = torch.zeros((rows, cols), dtype=torch.uint8, device=device) + scales = torch.zeros((rows, num_groups), dtype=torch.float32, device=device) + zeros = torch.zeros((rows, num_groups), dtype=torch.float32, device=device) + + # Process each group + for g in range(num_groups): + start_col = g * groupsize + end_col = min((g + 1) * groupsize, cols) + + W_group = W[:, start_col:end_col].clone() + + # Calculate quantization parameters for this group + scale, zero = self.find_params(None, W_group) + + # Store parameters + scales[:, g] = scale.squeeze(-1) + zeros[:, g] = zero.squeeze(-1) + + # Quantize the group + q = torch.clamp( + torch.round(W_group / scale + zero), + 0, self.maxq + ) + qweight[:, start_col:end_col] = q.to(torch.uint8) + + # Pack the weights + packed_qweight = pack_weight(qweight) + + return ( + packed_qweight, + zeros.to(torch.float16), + scales.to(torch.float16), + original_cols + ) + + def dequantize(self, qweight: torch.Tensor, zeros: torch.Tensor, + scales: torch.Tensor) -> torch.Tensor: + """Dequantize packed weights""" + # Unpack weights first + original_cols = qweight.shape[1] * 2 # Assuming 4-bit packing + weight = unpack_weight(qweight, original_cols) + + rows, cols = weight.shape + groupsize = min(int(self.groupsize), cols) if self.groupsize != float('inf') else cols + num_groups = (cols + groupsize - 1) // groupsize + + dequantized = torch.zeros_like(weight, dtype=torch.float16) + + for g in range(num_groups): + start_col = g * groupsize + end_col = min((g + 1) * groupsize, cols) + + group_weight = weight[:, start_col:end_col].float() + group_scale = scales[:, g].unsqueeze(-1) + group_zero = zeros[:, g].unsqueeze(-1) + + # Dequantize: (q - zero) * scale + dequantized[:, start_col:end_col] = ((group_weight - group_zero) * group_scale).to(torch.float16) + + return dequantized + + +def quantize_gptq( + model_state_dict: Dict[str, torch.Tensor], + 
target_layers: Optional[list] = None, + device: str = "cuda" +) -> Dict[str, torch.Tensor]: + """ + Improved GPTQ quantization function + """ + quantized_state_dict = {} + config = GPTQConfig() + + # Default target layers if not specified + if target_layers is None: + target_layers = [] + for name in model_state_dict.keys(): + if any(pattern in name for pattern in [ + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", + "kv_proj", "lm_head" + ]): + target_layers.append(name) + + print(f"Quantizing {len(target_layers)} layers...") + + for name, param in tqdm(model_state_dict.items(), desc="Processing layers"): + if name in target_layers and param.dim() == 2: + # Create GPTQ quantizer for this layer + gptq = GPTQ(config) + + # Move weight to device and ensure float32 for quantization + weight = param.to(device).float() + # Quantize the weight + qweight, qzeros, scales, original_cols = gptq.quantize(weight) + + # Store quantized parameters + base_name = name.replace(".weight", "").replace("_weight", "") + quantized_state_dict[f"{base_name}.qweight"] = qweight.cpu() + quantized_state_dict[f"{base_name}.qzeros"] = qzeros.cpu() + quantized_state_dict[f"{base_name}.scales"] = scales.cpu() + quantized_state_dict[f"{base_name}.original_cols"] = torch.tensor(original_cols) + + # Verify quantization quality + dequantized = gptq.dequantize(qweight, qzeros, scales) + error = (weight.half() - dequantized).abs().mean().item() + + else: + # Keep non-quantized parameters as is + quantized_state_dict[name] = param.cpu() + + return quantized_state_dict \ No newline at end of file diff --git a/lite_llama/quantization/quant_config.py b/lite_llama/quantization/quant_config.py new file mode 100644 index 0000000..cb35004 --- /dev/null +++ b/lite_llama/quantization/quant_config.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass, field +from typing import List + + +@dataclass +class AWQConfig: + + """Configuration for AWQ quantization""" + w_bit: int = 4 # Weight quantization bits + group_size: int = 128 # Group size for quantization + zero_point: bool = True # Whether to use zero point + version: str = "GEMM" # GEMM or GEMV + calib_data_size: int = 128 # Calibration dataset size + search_scale: bool = False # Whether to search for optimal scales + auto_scale: bool = True # Automatic scaling + device: str = "cuda" + alpha: float = 0.5 + + +@dataclass +class GPTQConfig: + """GPTQ量化配置""" + w_bit: int = 4 + group_size: int = 64 # 减少组大小提高压缩率 + device: str = "cuda" + quantize_embedding: bool = True + quantize_lm_head: bool = True + adaptive_group_size: bool = True # 自适应组大小 + optimize_for_compression: bool = True # 优化压缩率 + + + +@dataclass +class SmoothQuantConfig: + """Configuration for SmoothQuant""" + alpha: float = 0.5 # Smoothing factor balance between act and weight + w_bit: int = 8 # Weight quantization bits + a_bit: int = 8 # Activation quantization bits + device: str = "cuda" + symmetric_weight: bool = True # Use symmetric quantization for weights + symmetric_activation: bool = False # Use asymmetric quantization for activations + per_channel_weight: bool = True # Per-channel quantization for weights + per_token_activation: bool = True # Per-token quantization for activations + calibration_samples: int = 128 # Number of calibration samples + smooth_layers: List[str] = field(default_factory=lambda: [ + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj" + ]) + +@dataclass +class QuantLayerConfig: + """Configuration for QuantLayer""" + quant_layers: 
List[str] = field(default_factory=lambda: [ + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", + "kv_proj", "lm_head" + ]) diff --git a/lite_llama/quantization/quant_manager.py b/lite_llama/quantization/quant_manager.py new file mode 100644 index 0000000..65af049 --- /dev/null +++ b/lite_llama/quantization/quant_manager.py @@ -0,0 +1,266 @@ +""" +Quantization Manager for lite_llama +Provides unified interface for GPTQ, AWQ, and SmoothQuant +""" +import os +import json +import torch +import torch.nn as nn +from typing import Dict, Optional, Union, Any, List +from pathlib import Path +from tqdm import tqdm + +from .awq import AWQ, quantize_awq +from .gptq import GPTQ, quantize_gptq +from .sq import SmoothQuantizer, apply_smoothquant +from .quant_config import AWQConfig, GPTQConfig, SmoothQuantConfig, QuantLayerConfig + +# Import quantized linear layers +from ..kernels.awq_linear import AWQLinear +from ..kernels.gptq_linear import GPTQLinear +from ..kernels.sq_linear import SmoothQuantLinear +from ..utils.logger import log + +class QuantizationType: + NONE = "none" + GPTQ = "gptq" + AWQ = "awq" + SMOOTHQUANT = "smoothquant" + INT4 = "int4" + INT8 = "int8" + + +class QuantizationManager: + """统一的量化管理器""" + + def __init__(self): + self.supported_methods = { + QuantizationType.GPTQ: self._load_gptq, + QuantizationType.AWQ: self._load_awq, + QuantizationType.SMOOTHQUANT: self._load_smoothquant, + } + + def detect_quantization_type(self, model_path: str) -> str: + """Automatically detect the quantization type of the model""" + model_path = Path(model_path) + + # 检查量化配置文件 + quant_config_path = model_path / "quantization_config.json" + if quant_config_path.exists(): + with open(quant_config_path, 'r') as f: + config = json.load(f) + return config.get("quantization_method", QuantizationType.NONE) + + # 通过权重文件名检测 + weight_files = list(model_path.glob("*.pth")) + if weight_files: + state_dict = torch.load(weight_files[0], map_location="cpu") + + # 检查是否有量化相关的键 + for key in state_dict.keys(): + if "qweight" in key and "qzeros" in key: + if "qscales" in key: + return QuantizationType.AWQ + elif "scales" in key: + return QuantizationType.GPTQ + elif "weight_scale" in key and "smoothing_factor" in key: + return QuantizationType.SMOOTHQUANT + + return QuantizationType.NONE + + def quantize_model( + self, + model_path: str, + output_path: str, + method: str, + config: Optional[Dict] = None, + calibration_data: Optional[Any] = None, + model: Optional[torch.nn.Module] = None + ) -> str: + """Quantitative model""" + log.info(f"Using the {method} method to quantify the model...") + + # 加载原始模型状态字典 + model_path = Path(model_path) + weight_files = list(model_path.glob("*.pth")) + if not weight_files: + raise ValueError(f"The weight file was not found in {model_path}") + + state_dict = torch.load(weight_files[0], map_location="cpu") + + # 根据方法进行量化 + if method == QuantizationType.GPTQ: + config = config or {} + gptq_config = GPTQConfig(**config) + quantized_state_dict = quantize_gptq( + model_state_dict=state_dict, + target_layers=self._get_target_layers(state_dict), + device=gptq_config.device + ) + + elif method == QuantizationType.AWQ: + config = config or {} + awq_config = AWQConfig(**config) + quantized_state_dict = quantize_awq( + model_state_dict=state_dict, + model=model, + config=awq_config, + target_layers=self._get_target_layers(state_dict), + device=awq_config.device + ) + + elif method == QuantizationType.SMOOTHQUANT: + config = config or {} + config.smooth_layers = 
self._get_target_layers(state_dict) + sq_config = SmoothQuantConfig(**config) + quantized_state_dict = apply_smoothquant( + model_state_dict=state_dict, + config=sq_config, + ) + + else: + raise ValueError(f"Unsupported quantitative methods: {method}") + + # 保存量化后的模型 + output_path = Path(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + # 保存权重 + torch.save( + quantized_state_dict, + output_path / f"{model_path.name}.pth", + _use_new_zipfile_serialization=True + ) + + # 复制其他文件 + for file in model_path.glob("*.json"): + if file.name != "quantization_config.json": + import shutil + shutil.copy2(file, output_path) + + # 复制tokenizer文件 + for file in model_path.glob("tokenizer*"): + import shutil + shutil.copy2(file, output_path) + + # 保存量化配置 + quant_config = { + "quantization_method": method, + "config": config, + "quantized_at": torch.cuda.get_device_name() if torch.cuda.is_available() else "cpu" + } + + with open(output_path / "quantization_config.json", 'w') as f: + json.dump(quant_config, f, indent=2) + + log.info(f"Quantification completed! Saved to: {output_path}") + return str(output_path) + + def load_quantized_model( + self, + model_path: str, + model_config: Any, + device: str = "cuda" + ) -> torch.nn.Module: + """Load the quantized model""" + quant_type = self.detect_quantization_type(model_path) + + if quant_type == QuantizationType.NONE: + # 正常加载非量化模型 + return self._load_normal_model(model_path, model_config, device) + + if quant_type in self.supported_methods: + return self.supported_methods[quant_type](model_path, model_config, device) + else: + raise ValueError(f"Unsupported quantization types: {quant_type}") + + def _load_gptq(self, model_path: str, model_config: Any, device: str) -> torch.nn.Module: + """Load the GPTQ quantitative model""" + from ..models.quantized_models import create_quantized_model + + # 读取量化配置 + quant_config_path = Path(model_path) / "quantization_config.json" + with open(quant_config_path, 'r') as f: + quant_config = json.load(f) + + # 创建量化模型 + model = create_quantized_model( + model_config=model_config, + quantization_method=QuantizationType.GPTQ, + quantization_config=quant_config.get("config", {}), + device=device + ) + + # 加载量化权重 + weight_files = list(Path(model_path).glob("*.pth")) + state_dict = torch.load(weight_files[0], map_location=device) + model.load_state_dict(state_dict, strict=False) + + return model + + def _load_awq(self, model_path: str, model_config: Any, device: str) -> torch.nn.Module: + """Load the AWQ quantification model""" + from ..models.quantized_models import create_quantized_model + + # 读取量化配置 + quant_config_path = Path(model_path) / "quantization_config.json" + with open(quant_config_path, 'r') as f: + quant_config = json.load(f) + + # 创建量化模型 + model = create_quantized_model( + model_config=model_config, + quantization_method=QuantizationType.AWQ, + quantization_config=quant_config.get("config", {}), + device=device + ) + + # 加载量化权重 + weight_files = list(Path(model_path).glob("*.pth")) + state_dict = torch.load(weight_files[0], map_location=device) + model.load_state_dict(state_dict, strict=False) + + return model + + def _load_smoothquant(self, model_path: str, model_config: Any, device: str) -> torch.nn.Module: + """Load the SmoothQuant quantitative model""" + from ..models.quantized_models import create_quantized_model + + # 读取量化配置 + quant_config_path = Path(model_path) / "quantization_config.json" + with open(quant_config_path, 'r') as f: + quant_config = json.load(f) + + # 创建量化模型 + model = 
create_quantized_model( + model_config=model_config, + quantization_method=QuantizationType.SMOOTHQUANT, + quantization_config=quant_config.get("config", {}), + device=device + ) + + # 加载量化权重 + weight_files = list(Path(model_path).glob("*.pth")) + state_dict = torch.load(weight_files[0], map_location=device) + model.load_state_dict(state_dict, strict=False) + + return model + + def _load_normal_model(self, model_path: str, model_config: Any, device: str) -> torch.nn.Module: + """加载非量化模型 - 这里需要调用原有的模型加载逻辑""" + # 这里应该调用现有的模型加载逻辑 + # 需要根据具体的模型架构来实现 + pass + + def _get_target_layers(self, state_dict: Dict[str, torch.Tensor]) -> List[str]: + """Obtain the layers that need to be quantified""" + target_layers = [] + quant_layer = QuantLayerConfig() + for name in state_dict.keys(): + if any(pattern in name for pattern in quant_layer.quant_layers): + target_layers.append(name) + return target_layers + + +# 全局量化管理器实例 +quantization_manager = QuantizationManager() \ No newline at end of file diff --git a/lite_llama/quantization/sq.py b/lite_llama/quantization/sq.py new file mode 100644 index 0000000..e823ea8 --- /dev/null +++ b/lite_llama/quantization/sq.py @@ -0,0 +1,404 @@ +from dataclasses import field + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from typing import Dict, Tuple, Optional, Any, List +from tqdm.auto import tqdm +import time, os, sys, gc +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))) +from lite_llama.quantization.quant_config import SmoothQuantConfig + + +class SmoothQuantizer: + def __init__(self, config: SmoothQuantConfig = field(default_factory=SmoothQuantConfig)): + self.config = config + self.alpha = self.config.alpha + self.w_bit = self.config.w_bit + self.a_bit = self.config.a_bit + self.device = self.config.device + + # Quantization ranges + self.w_qmax = 2 ** (self.w_bit - 1) - 1 + self.w_qmin = -2 ** (self.w_bit - 1) + self.a_qmax = 2 ** self.a_bit - 1 + self.a_qmin = 0 + + # Statistics storage + self.activation_stats = {} + self.weight_stats = {} + self.smoothing_factors = {} + + def collect_activation_stats(self, model, calibration_dataloader): + """Collect activation statistics for smoothing factor calculation""" + print("Collecting activation statistics...") + + # Register hooks to collect activations + activation_dict = {} + hooks = [] + + def make_hook(name): + def hook(module, input, output): + if isinstance(input, tuple): + x = input[0] + else: + x = input + + if name not in activation_dict: + activation_dict[name] = [] + + # Store input activations (before linear layer) + if x.dim() == 3: # [batch, seq, hidden] + x_flat = x.view(-1, x.size(-1)) # [batch*seq, hidden] + activation_dict[name].append(x_flat.detach().cpu()) + + return hook + + # Register hooks for target layers + for name, module in model.named_modules(): + if any(layer in name for layer in self.config.smooth_layers): + if isinstance(module, nn.Linear): + hook = module.register_forward_hook(make_hook(name)) + hooks.append(hook) + + # Collect activations + model.eval() + with torch.no_grad(): + for i, batch in enumerate(tqdm(calibration_dataloader, desc="Collecting stats")): + if i >= self.config.calibration_samples: + break + + # Move batch to device + if isinstance(batch, dict): + batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v + for k, v in batch.items()} + _ = model(**batch) + else: + batch = batch.to(self.device) + _ = model(batch) + + # Remove hooks + for hook in hooks: + hook.remove() + + # 
Compute statistics + for name, activations in activation_dict.items(): + if activations: + all_acts = torch.cat(activations, dim=0) # [total_tokens, hidden] + + # Compute per-channel statistics + act_max = all_acts.abs().max(dim=0)[0] # [hidden] + act_mean = all_acts.abs().mean(dim=0) # [hidden] + + self.activation_stats[name] = { + 'max': act_max, + 'mean': act_mean, + 'std': all_acts.std(dim=0) + } + + print(f"Collected stats for {len(self.activation_stats)} layers") + + def compute_smoothing_factors(self, model): + """Compute per-channel smoothing factors""" + print("Computing smoothing factors...") + + for name, module in model.named_modules(): + if any(layer in name for layer in self.config.smooth_layers): + if isinstance(module, nn.Linear) and name in self.activation_stats: + weight = module.weight.data # [out_features, in_features] + + # Get activation statistics + act_stats = self.activation_stats[name] + act_max = act_stats['max'] # [in_features] + + # Compute weight statistics (per input channel) + weight_max = weight.abs().max(dim=0)[0] # [in_features] + + # Compute smoothing factor s = (act_max^alpha / weight_max^(1-alpha)) + # To avoid division by zero + weight_max = torch.clamp(weight_max, min=1e-5) + act_max = torch.clamp(act_max, min=1e-5) + + smoothing_factor = (act_max.pow(self.alpha) / + weight_max.pow(1 - self.alpha)) + + # Normalize to prevent extreme values + smoothing_factor = torch.clamp(smoothing_factor, min=0.01, max=100.0) + + self.smoothing_factors[name] = smoothing_factor.to(self.device) + + print(f"Layer {name}: smoothing range [{smoothing_factor.min():.3f}, {smoothing_factor.max():.3f}]") + + def apply_smoothing(self, model): + """Apply smoothing factors to model weights""" + print("Applying smoothing to model...") + + for name, module in model.named_modules(): + if name in self.smoothing_factors: + smoothing_factor = self.smoothing_factors[name] + + # Apply smoothing: W' = W * diag(s), where s is smoothing factor + # Weight: [out_features, in_features] + # Smoothing: [in_features] + module.weight.data = module.weight.data * smoothing_factor.unsqueeze(0) + + print(f"Applied smoothing to {name}") + + def quantize_weight(self, weight: torch.Tensor, per_channel: bool = True) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor]: + """Quantize weights to INT8""" + if per_channel: + # Per output channel quantization + dim = 0 # Quantize along output dimension + w_max = weight.abs().max(dim=1, keepdim=True)[0] # [out_features, 1] + else: + # Per tensor quantization + w_max = weight.abs().max() + + # Compute scale + scale = w_max / self.w_qmax + scale = torch.clamp(scale, min=1e-5) + + if self.config.symmetric_weight: + # Symmetric quantization + zero_point = torch.zeros_like(scale) + qweight = torch.round(weight / scale).clamp(self.w_qmin, self.w_qmax) + else: + # Asymmetric quantization + w_min = weight.min(dim=1, keepdim=True)[0] if per_channel else weight.min() + zero_point = torch.round(-w_min / scale).clamp(self.w_qmin, self.w_qmax) + qweight = torch.round(weight / scale + zero_point).clamp(self.w_qmin, self.w_qmax) + + return qweight.to(torch.int8), scale, zero_point + + def dequantize_weight(self, qweight: torch.Tensor, scale: torch.Tensor, + zero_point: torch.Tensor) -> torch.Tensor: + """Dequantize weights from INT8""" + if self.config.symmetric_weight: + return qweight.float() * scale + else: + return (qweight.float() - zero_point) * scale + + def quantize_activation(self, activation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + 
"""Quantize activations to INT8""" + original_shape = activation.shape + + if self.config.per_token_activation and activation.dim() == 3: + # Per-token quantization: [batch, seq, hidden] -> quantize per token + batch_size, seq_len, hidden_size = activation.shape + activation_flat = activation.view(-1, hidden_size) # [batch*seq, hidden] + + # Compute per-token statistics + act_max = activation_flat.abs().max(dim=-1, keepdim=True)[0] # [batch*seq, 1] + act_min = activation_flat.min(dim=-1, keepdim=True)[0] # [batch*seq, 1] + + # Compute scale and zero point + if self.config.symmetric_activation: + scale = act_max / self.a_qmax # [batch*seq, 1] + zero_point = torch.zeros_like(scale) # [batch*seq, 1] + else: + scale = (act_max - act_min) / self.a_qmax # [batch*seq, 1] + zero_point = torch.round(-act_min / scale).clamp(self.a_qmin, self.a_qmax) # [batch*seq, 1] + + scale = torch.clamp(scale, min=1e-5) + + # Quantize + if self.config.symmetric_activation: + qactivation_flat = torch.round(activation_flat / scale).clamp(-self.a_qmax, self.a_qmax) + else: + qactivation_flat = torch.round(activation_flat / scale + zero_point).clamp(self.a_qmin, self.a_qmax) + + # Reshape everything back to original shape + qactivation = qactivation_flat.view(original_shape).to(torch.int8) + scale = scale.view(batch_size, seq_len, 1) # [batch, seq, 1] + zero_point = zero_point.view(batch_size, seq_len, 1) # [batch, seq, 1] + + else: + # Per-tensor quantization + act_max = activation.abs().max() + act_min = activation.min() + + # Compute scale and zero point (scalars) + if self.config.symmetric_activation: + scale = act_max / self.a_qmax + zero_point = torch.zeros_like(scale) + else: + scale = (act_max - act_min) / self.a_qmax + zero_point = torch.round(-act_min / scale).clamp(self.a_qmin, self.a_qmax) + + scale = torch.clamp(scale, min=1e-5) + + # Quantize + if self.config.symmetric_activation: + qactivation = torch.round(activation / scale).clamp(-self.a_qmax, self.a_qmax) + else: + qactivation = torch.round(activation / scale + zero_point).clamp(self.a_qmin, self.a_qmax) + + qactivation = qactivation.to(torch.int8) + + return qactivation, scale, zero_point + + def dequantize_activation(self, qactivation: torch.Tensor, scale: torch.Tensor, + zero_point: torch.Tensor) -> torch.Tensor: + """Dequantize activations from INT8""" + if self.config.symmetric_activation: + return qactivation.float() * scale + else: + return (qactivation.float() - zero_point) * scale + + +def convert_to_smoothquant(model, calibration_dataloader, config: SmoothQuantConfig = None): + """Convert a model to use SmoothQuant""" + config = config or SmoothQuantConfig() + quantizer = SmoothQuantizer(config) + + # Step 1: Collect activation statistics + quantizer.collect_activation_stats(model, calibration_dataloader) + + # Step 2: Compute smoothing factors + quantizer.compute_smoothing_factors(model) + + # Step 3: Apply smoothing to weights + quantizer.apply_smoothing(model) + + # Step 4: Convert linear layers to quantized versions + quantized_state_dict = {} + + for name, module in model.named_modules(): + if any(layer in name for layer in config.smooth_layers): + if isinstance(module, nn.Linear): + # Quantize the smoothed weights + qweight, weight_scale, weight_zero_point = quantizer.quantize_weight( + module.weight.data, per_channel=config.per_channel_weight + ) + + # Get smoothing factor for this layer + smoothing_factor = quantizer.smoothing_factors.get(name, + torch.ones(module.in_features)) + + + # Store in state dict + base_name = 
name.replace(".weight", "").replace("_weight", "") + quantized_state_dict[f"{base_name}.qweight"] = qweight.cpu() + quantized_state_dict[f"{base_name}.weight_scale"] = weight_scale.cpu() + quantized_state_dict[f"{base_name}.weight_zero_point"] = weight_zero_point.cpu() + quantized_state_dict[f"{base_name}.smoothing_factor"] = smoothing_factor.cpu() + + if module.bias is not None: + quantized_state_dict[f"{name}"] = module.bias.cpu() + + print(f"Quantized layer: {name}") + + return quantized_state_dict, quantizer + + +def apply_smoothquant(model_state_dict: Dict[str, torch.Tensor], + config: SmoothQuantConfig = None) -> Dict[str, torch.Tensor]: + """ + Apply SmoothQuant to a model state dictionary + + Args: + model_state_dict: Original model state dictionary + calibration_dataloader: DataLoader for calibration data + config: SmoothQuant configuration + + Returns: + Dictionary containing quantized weights and parameters + """ + print("Starting SmoothQuant quantization...") + + config = config or SmoothQuantConfig() + + # Note: This is a simplified version. In practice, you'd need to: + # 1. Load the model from state_dict + # 2. Run calibration + # 3. Apply smoothing and quantization + # 4. Return the quantized state dict + + # For demonstration, we'll show the structure: + quantized_state_dict = {} + + # Process each layer + for name, param in tqdm(model_state_dict.items(), desc="Processing layers"): + if any(layer in name for layer in config.smooth_layers) and param.dim() == 2: + # This would be where you apply the full SmoothQuant pipeline + quantizer = SmoothQuantizer(config) + + # Simulate quantization (in practice, you'd use actual calibration data) + weight = param.float() + qweight, scale, zero_point = quantizer.quantize_weight(weight) + + base_name = name.replace(".weight", "") + quantized_state_dict[f"{base_name}.qweight"] = qweight.cpu() + quantized_state_dict[f"{base_name}.weight_scale"] = scale.cpu() + quantized_state_dict[f"{base_name}.weight_zero_point"] = zero_point.cpu() + + else: + # Keep non-quantized parameters + quantized_state_dict[name] = param.cpu() + + print("SmoothQuant quantization completed!") + return quantized_state_dict + + +# Example usage and testing +if __name__ == "__main__": + # Example configuration + config = SmoothQuantConfig( + alpha=0.5, + w_bit=8, + a_bit=8, + symmetric_weight=True, + symmetric_activation=False, + per_channel_weight=True, + per_token_activation=True + ) + + # Test quantization + print("Testing SmoothQuant implementation...") + + # Create a simple test case + test_weight = torch.randn(1024, 512) * 0.1 + quantizer = SmoothQuantizer(config) + + # Test weight quantization + print("Testing weight quantization...") + qweight, scale, zero_point = quantizer.quantize_weight(test_weight) + reconstructed = quantizer.dequantize_weight(qweight, scale, zero_point) + + # Compute error + error = (test_weight - reconstructed).abs().mean() + print(f"Weight quantization error: {error:.6f}") + print(f"Weight shapes - original: {test_weight.shape}, quantized: {qweight.shape}") + print(f"Scale shape: {scale.shape}, Zero point shape: {zero_point.shape}") + + # Test activation quantization with different shapes + print("\nTesting activation quantization...") + + # Test 3D tensor (typical transformer input) + test_activation_3d = torch.randn(8, 128, 512) * 2.0 + print(f"Input activation shape: {test_activation_3d.shape}") + + qact, act_scale, act_zero_point = quantizer.quantize_activation(test_activation_3d) + print(f"Quantized activation shape: 
{qact.shape}") + print(f"Activation scale shape: {act_scale.shape}") + print(f"Activation zero point shape: {act_zero_point.shape}") + + reconstructed_act = quantizer.dequantize_activation(qact, act_scale, act_zero_point) + print(f"Reconstructed activation shape: {reconstructed_act.shape}") + + act_error = (test_activation_3d - reconstructed_act).abs().mean() + print(f"Activation quantization error: {act_error:.6f}") + + # Test 2D tensor + print("\nTesting 2D activation...") + test_activation_2d = torch.randn(64, 512) * 2.0 + qact_2d, act_scale_2d, act_zero_point_2d = quantizer.quantize_activation(test_activation_2d) + reconstructed_act_2d = quantizer.dequantize_activation(qact_2d, act_scale_2d, act_zero_point_2d) + + act_error_2d = (test_activation_2d - reconstructed_act_2d).abs().mean() + print(f"2D Activation quantization error: {act_error_2d:.6f}") + print(f"2D shapes - scale: {act_scale_2d.shape}, zero_point: {act_zero_point_2d.shape}") + + print("SmoothQuant implementation test completed!") \ No newline at end of file diff --git a/lite_llama/quantization/utils.py b/lite_llama/quantization/utils.py new file mode 100644 index 0000000..8b3f65f --- /dev/null +++ b/lite_llama/quantization/utils.py @@ -0,0 +1,49 @@ +import torch + + +def pack_weight(weight): + """ + Pack two 4-bit values into one uint8 value consistently + + Args: + weight: Tensor of shape [out_features, in_features] with values in [0, 15] + + Returns: + packed: Tensor of shape [out_features, in_features//2] with packed values + """ + rows, cols = weight.shape + + # Ensure even number of columns for packing + if cols % 2 != 0: + weight = torch.cat([weight, torch.zeros(rows, 1, dtype=weight.dtype, device=weight.device)], dim=1) + cols += 1 + + # Pack: lower 4 bits from even indices, upper 4 bits from odd indices + # Format: [odd_value << 4] | even_value + packed = (weight[:, 0::2] & 0xF) | ((weight[:, 1::2] & 0xF) << 4) + + return packed.contiguous().to(torch.uint8) + + +def unpack_weight(packed_weight, original_cols): + """ + Unpack uint8 values back to two 4-bit values consistently + + Args: + packed_weight: Packed tensor of shape [out_features, packed_cols] + original_cols: Original number of columns before packing + + Returns: + unpacked: Tensor of shape [out_features, original_cols] with unpacked values + """ + rows, packed_cols = packed_weight.shape + + # Allocate unpacked tensor + unpacked = torch.zeros((rows, packed_cols * 2), dtype=torch.uint8, device=packed_weight.device) + + # Unpack: even positions get lower 4 bits, odd positions get upper 4 bits + unpacked[:, 0::2] = packed_weight & 0xF # Lower 4 bits + unpacked[:, 1::2] = (packed_weight >> 4) & 0xF # Upper 4 bits + + # Trim to original size + return unpacked[:, :original_cols].contiguous() diff --git a/lite_llama/utils/common.py b/lite_llama/utils/common.py index 55dbbc9..f271290 100644 --- a/lite_llama/utils/common.py +++ b/lite_llama/utils/common.py @@ -2,7 +2,9 @@ import time, os import subprocess from typing import List, Optional - +import torch +from contextlib import contextmanager +import functools def read_json(json_path): with open(json_path, "r") as json_file: @@ -37,7 +39,7 @@ def getProjectPath(): return os.path.abspath(os.path.join(script_path, "..")) -def get_gpu_memory(gpu_type="amd", device_id="0"): +def get_gpu_memory(gpu_type, device_id="0"): try: if gpu_type == "amd": result = subprocess.run( @@ -67,7 +69,7 @@ def get_gpu_memory(gpu_type="amd", device_id="0"): elif gpu_type == "cpu": return None except Exception as e: - from utils.logger 
import log + from lite_llama.utils.logger import log log.warning(f"Unable to fetch GPU memory: {e}") return None @@ -82,7 +84,7 @@ def count_tokens(texts: List[str], tokenizer) -> int: def get_model_type(checkpoint_path: str) -> str | None: - from utils.logger import log + from .logger import log model_type = ["llama", "falcon", "mpt", "qwen2", "llava"] @@ -94,3 +96,126 @@ def get_model_type(checkpoint_path: str) -> str | None: return m log.error(f"No model type found: {checkpoint_path}") return None + + +def check_model_compatibility(model_path): + """Check if the model is compatible for quantization""" + # Check if model path exists and contains .pth files + if not os.path.exists(model_path): + return False, f"Model path does not exist: {model_path}" + + pth_files = [f for f in os.listdir(model_path) if f.endswith('.pth')] + if not pth_files: + return False, f"No .pth files found in {model_path}" + + # Check if required config files exist + config_files = ["config.json", "tokenizer_config.json"] + missing_configs = [f for f in config_files if not os.path.exists(os.path.join(model_path, f))] + if missing_configs: + print(f"Warning: Missing config files: {missing_configs}") + + return True, "Model is compatible" + + +def get_model_info(model_path): + """Get basic information about the model""" + model_info = { + "model_name": os.path.basename(model_path), + "model_type": "unknown", + "size": 0.0 + } + + # Detect model type from path or config + model_name_lower = model_info["model_name"].lower() + if "llava" in model_name_lower: + model_info["model_type"] = "llava" + elif "qwen2" in model_name_lower: + model_info["model_type"] = "qwen2" + elif "llama" in model_name_lower: + model_info["model_type"] = "llama" + + # Try to read from config.json + config_path = os.path.join(model_path, "config.json") + if os.path.exists(config_path): + try: + with open(config_path, 'r') as f: + config = json.load(f) + if "architectures" in config: + arch = config["architectures"][0].lower() + if "llava" in arch: + model_info["model_type"] = "llava" + elif "qwen2" in arch: + model_info["model_type"] = "qwen2" + elif "llama" in arch: + model_info["model_type"] = "llama" + except: + pass + + # Calculate total size + total_size = 0 + for f in os.listdir(model_path): + if f.endswith('.pth'): + file_path = os.path.join(model_path, f) + total_size += os.path.getsize(file_path) + + model_info["size"] = total_size / (1024 ** 3) # Convert to GB + + return model_info + + +def get_model_dtype(checkpoints_dir: str): + """ + Get the model dtype from config.json + + Args: + checkpoints_dir: Path to model checkpoint directory + + Returns: + torch.dtype or str: The dtype specified in config.json + """ + config_path = os.path.join(checkpoints_dir, "config.json") + + try: + with open(config_path, 'r') as f: + config = json.load(f) + + torch_dtype_str = config.get("torch_dtype", "float16").lower() + + # Map string to torch dtype or string identifiers for quantized formats + dtype_mapping = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + "float": torch.float32, + "int8": torch.int8, + "int4": "int4", # Placeholder, since PyTorch doesn't natively support int4 + } + + dtype = dtype_mapping.get(torch_dtype_str, torch.float16) + print(f"Detected model dtype from config: {torch_dtype_str} -> {dtype}") + + return dtype + + except Exception as e: + print(f"Warning: Could not read dtype from config.json: {e}") + print("Defaulting to torch.float16") + return torch.float16 + + +@contextmanager +def 
quantization(mode: str = None): + quantized_linear_cls = None + if mode == 'gptq.int4': + from ..kernels.gptq_linear import GPTQLinear + quantized_linear_cls = functools.partial(GPTQLinear, bits=4, tile_cols=-1) + elif mode is not None: + raise ValueError(f"Unknown quantization mode: {mode}") + + enabled = mode is not None + torch_linear_cls = torch.nn.Linear + if enabled: + torch.nn.Linear = quantized_linear_cls + yield + if enabled: + torch.nn.Linear = torch_linear_cls + diff --git a/quantize_lite_llama.py b/quantize_lite_llama.py new file mode 100644 index 0000000..30b7248 --- /dev/null +++ b/quantize_lite_llama.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +quantize_lite_llama.py +~~~~~~~~~~~~~~~~~~~~ +用于量化lite_llama格式模型的工具脚本 + +支持GPTQ、AWQ、SmoothQuant三种量化方法 + +Usage +----- +# GPTQ量化 +python quantize_lite_llama.py --model-path /path/to/model --output-path /path/to/output --method gptq --bits 4 --group-size 128 + +# AWQ量化 +python quantize_lite_llama.py --model-path /path/to/model --output-path /path/to/output --method awq --bits 4 --group-size 128 --calib-data /path/to/calib.txt + +# SmoothQuant量化 +python quantize_lite_llama.py --model-path /path/to/model --output-path /path/to/output --method smoothquant --alpha 0.5 + +Author: AI Assistant (2025-01-22) +""" + +import argparse +import json +import os +import sys +import time +from pathlib import Path +from typing import Optional, Dict, Any, List +import torch +from tqdm import tqdm + +# Add lite_llama to Python path +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from lite_llama.quantization.quant_manager import quantization_manager, QuantizationType +from lite_llama.quantization.quant_config import AWQConfig, GPTQConfig, SmoothQuantConfig +from lite_llama.utils.common import get_model_info, check_model_compatibility +from lite_llama.utils.logger import log +from lite_llama.executor.model_executor import ModelExecutor +from transformers import AutoTokenizer + + +class CalibrationDataLoader: + """校准数据加载器""" + + def __init__(self, data_path: str, tokenizer_path: str, max_samples: int = 128, max_length: int = 512): + self.data_path = data_path + self.max_samples = max_samples + self.max_length = max_length + + # 加载分词器 + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + # 加载校准数据 + self.texts = self._load_calibration_data() + + def _load_calibration_data(self) -> List[str]: + """加载校准数据""" + texts = [] + + if self.data_path.endswith('.txt'): + # 纯文本文件,每行一个样本 + with open(self.data_path, 'r', encoding='utf-8') as f: + texts = [line.strip() for line in f if line.strip()] + + elif self.data_path.endswith('.json'): + # JSON文件 + with open(self.data_path, 'r', encoding='utf-8') as f: + data = json.load(f) + if isinstance(data, list): + # 假设是文本列表 + texts = [item if isinstance(item, str) else item.get('text', '') for item in data] + else: + # 假设是包含文本字段的对象 + texts = [data.get('text', '')] + + elif self.data_path.endswith('.jsonl'): + # JSONL文件 + with open(self.data_path, 'r', encoding='utf-8') as f: + for line in f: + item = json.loads(line.strip()) + texts.append(item.get('text', '')) + + else: + raise ValueError(f"Unsupported file formats: {self.data_path}") + + # 限制样本数量 + texts = texts[:self.max_samples] + log.info(f"{len(texts)} calibration samples were loaded") + + return texts + + def __len__(self): + return len(self.texts) + + def __iter__(self): + """返回批次数据的迭代器""" + for 
text in self.texts: + # 编码文本 + encoding = self.tokenizer( + text, + return_tensors='pt', + max_length=self.max_length, + truncation=True, + padding=True + ) + + yield encoding + + +def create_default_calibration_data(tokenizer_path: str, num_samples: int = 32) -> List[str]: + """创建默认的校准数据""" + default_texts = [ + "The quick brown fox jumps over the lazy dog.", + "Artificial intelligence is transforming the world.", + "Machine learning models require careful optimization.", + "Deep neural networks can learn complex patterns.", + "Natural language processing enables human-computer interaction.", + "Computer vision systems can understand visual content.", + "Quantization reduces model size while maintaining accuracy.", + "Large language models demonstrate emergent capabilities.", + "Transformer architectures have revolutionized AI.", + "Self-attention mechanisms capture long-range dependencies." + ] + + # 重复样本以达到所需数量 + texts = (default_texts * ((num_samples // len(default_texts)) + 1))[:num_samples] + log.info(f"Using the default calibration data, there are a total of {len(texts)} samples") + + return texts + + +def validate_quantization_config(method: str, config: Dict[str, Any]) -> Dict[str, Any]: + """验证和标准化量化配置""" + + if method == QuantizationType.GPTQ: + validated_config = { + 'w_bit': config.get('bits', 4), + 'group_size': config.get('group_size', 128), + 'device': config.get('device', 'cuda') + } + + # 验证参数范围 + if validated_config['w_bit'] not in [2, 3, 4, 8]: + raise ValueError(f"The number of bits not supported by GPTQ: {validated_config['w_bit']}") + + elif method == QuantizationType.AWQ: + validated_config = { + 'w_bit': config.get('bits', 4), + 'group_size': config.get('group_size', 128), + 'zero_point': config.get('zero_point', True), + 'search_scale': config.get('search_scale', False), + 'auto_scale': config.get('auto_scale', True), + 'alpha': config.get('alpha', 0.5), + 'device': config.get('device', 'cuda') + } + + if validated_config['w_bit'] not in [4, 8]: + raise ValueError(f"The number of bits not supported by AWQ: {validated_config['w_bit']}") + + elif method == QuantizationType.SMOOTHQUANT: + validated_config = { + 'alpha': config.get('alpha', 0.5), + 'w_bit': config.get('w_bits', 8), + 'a_bit': config.get('a_bits', 8), + 'symmetric_weight': config.get('symmetric_weight', True), + 'symmetric_activation': config.get('symmetric_activation', False), + 'per_channel_weight': config.get('per_channel_weight', True), + 'per_token_activation': config.get('per_token_activation', True), + 'calibration_samples': config.get('calibration_samples', 128), + 'device': config.get('device', 'cuda') + } + + if not (0.0 <= validated_config['alpha'] <= 1.0): + raise ValueError(f"The alpha parameter of SmoothQuant must be between 0 and 1: {validated_config['alpha']}") + + else: + raise ValueError(f"Unsupported quantitative methods: {method}") + + return validated_config + + +def main(): + parser = argparse.ArgumentParser(description="Quantify the model in lite_llama format") + + # 基本参数 + parser.add_argument("--model-path", type=str, required=True, + help="Input model path") + parser.add_argument("--output-path", type=str, required=True, + help="Output model path") + parser.add_argument("--method", type=str, required=True, + choices=['gptq', 'awq', 'smoothquant'], + help="Quantitative method") + + # 量化参数 + parser.add_argument("--bits", type=int, default=4, + help="Quantification bit number (default: 4)") + parser.add_argument("--group-size", type=int, default=128, + help="Group size (default: 
128)") + + # AWQ特有参数 + parser.add_argument("--alpha", type=float, default=0.5, + help="The alpha parameter of AWQ/SmoothQuant (default: 0.5)") + parser.add_argument("--search-scale", action='store_true', + help="Does AWQ search for the optimal scaling factor") + parser.add_argument("--auto-scale", action='store_true', default=True, + help="Does AWQ scale automatically") + + # SmoothQuant特有参数 + parser.add_argument("--w-bits", type=int, default=8, + help="Weighted quantification number of bits (SmoothQuant, default: 8)") + parser.add_argument("--a-bits", type=int, default=8, + help="Activation quantization bit number (SmoothQuant, default: 8)") + + # 校准数据 + parser.add_argument("--calib-data", type=str, default=None, + help="Calibrate the data file path (.txt/.json/.jsonl)") + parser.add_argument("--calib-samples", type=int, default=128, + help="Calibration sample quantity (default: 128)") + parser.add_argument("--max-length", type=int, default=512, + help="The maximum length of the calibration data (default: 512)") + + # 其他参数 + parser.add_argument("--device", type=str, default="cuda", + choices=['cuda', 'cpu'], + help="device (default: cuda)") + parser.add_argument("--no-verify", action='store_true', + help="Skip quantitative validation") + + args = parser.parse_args() + + # 检查模型兼容性 + is_compatible, message = check_model_compatibility(args.model_path) + if not is_compatible: + log.error(f"The model compatibility check failed: {message}") + return 1 + + # 获取模型信息 + model_info = get_model_info(args.model_path) + log.info(f"Model information: {model_info}") + + # 准备量化配置 + config = { + 'bits': args.bits, + 'group_size': args.group_size, + 'alpha': args.alpha, + 'search_scale': args.search_scale, + 'auto_scale': args.auto_scale, + 'w_bits': args.w_bits, + 'a_bits': args.a_bits, + 'device': args.device, + 'calibration_samples': args.calib_samples + } + + # 验证配置 + try: + validated_config = validate_quantization_config(args.method, config) + log.info(f"Quantitative configuration: {validated_config}") + except ValueError as e: + log.error(f"Configuration verification failed: {e}") + return 1 + + # 准备校准数据 + calibration_data = None + model = None + + if args.method in ['awq', 'smoothquant']: + log.info("Prepare calibration data...") + + if args.calib_data: + # 使用用户提供的校准数据 + try: + calibration_data = CalibrationDataLoader( + args.calib_data, + args.model_path, + args.calib_samples, + args.max_length + ) + log.info(f"Load calibration data: {len(calibration_data)} samples") + except Exception as e: + log.error(f"Failed to load calibration data: {e}") + log.info("The default calibration data will be used") + calibration_data = create_default_calibration_data( + args.model_path, args.calib_samples + ) + else: + # 使用默认校准数据 + calibration_data = create_default_calibration_data( + args.model_path, args.calib_samples + ) + + # 如果需要,加载原始模型用于校准 + if args.method == 'awq': + log.info("Load the original model for AWQ calibration...") + try: + model_executor = ModelExecutor.build( + checkpoints_dir=args.model_path, + max_seq_len=2048, + max_gpu_num_blocks=None, + compiled_model=False, + device=args.device + ) + model = model_executor.model + log.info("The model has been loaded successfully.") + except Exception as e: + log.error(f"Model loading failed: {e}") + return 1 + + # 执行量化 + log.info(f"Quantifying the model using the {args.method.upper()} method...") + start_time = time.time() + + try: + output_path = quantization_manager.quantize_model( + model_path=args.model_path, + output_path=args.output_path, + 
method=args.method, + config=validated_config, + calibration_data=calibration_data, + model=model + ) + + quantization_time = time.time() - start_time + log.info(f"Quantification completed! Time consumption: {quantization_time:.2f}s") + log.info(f"The quantitative model saved to: {output_path}") + + except Exception as e: + log.error(f"Quantitative failure: {e}") + return 1 + + # 验证量化结果 + if not args.no_verify: + log.info("Verify the quantification results...") + try: + # 检测量化类型 + detected_type = quantization_manager.detect_quantization_type(output_path) + if detected_type == args.method: + log.info(f"The quantitative type verification has been passed: {detected_type}") + else: + log.warning(f"Quantization type mismatch: expected {args.method}, detected {detected_type}") + + # 检查文件大小 + original_size = sum(f.stat().st_size for f in Path(args.model_path).glob("*.pth")) + quantized_size = sum(f.stat().st_size for f in Path(output_path).glob("*.pth")) + compression_ratio = original_size / quantized_size if quantized_size > 0 else 1.0 + + log.info(f"Original model size: {original_size / (1024 ** 3):.2f} GB") + log.info(f"Quantitative model size: {quantized_size / (1024 ** 3):.2f} GB") + log.info(f"Compression ratio: {compression_ratio:.2f}x") + + except Exception as e: + log.warning(f"Quantitative verification failed: {e}") + + log.info("Quantitative task completion!") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/tests/kernels/__init__.py b/tests/kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/quant/__init__.py b/tests/quant/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/quant/test_AWQLinear.py b/tests/quant/test_AWQLinear.py new file mode 100644 index 0000000..e2efd5c --- /dev/null +++ b/tests/quant/test_AWQLinear.py @@ -0,0 +1,493 @@ +import torch +import torch.nn as nn +import time +import psutil +import os +import numpy as np +from typing import Dict, List, Tuple +from dataclasses import dataclass + +import sys, os, time +import torch +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + +try: + from lite_llama.quantization.utils import pack_weight, unpack_weight + from lite_llama.quantization.quant_config import AWQConfig + from lite_llama.quantization.awq import AWQ, quantize_awq + from lite_llama.kernels.awq_linear import AWQLinear + + UTILS_AVAILABLE = True +except ImportError: + UTILS_AVAILABLE = False + + +@dataclass +class BenchmarkResults: + """Store benchmark results for comparison""" + layer_size: Tuple[int, int] + batch_size: int + sequence_length: int + + # Speed metrics (milliseconds) + fp16_time: float + awq_time: float + speedup: float + + # Accuracy metrics + max_error: float + mean_error: float + rmse: float + cosine_similarity: float + + # Memory metrics (MB) + fp16_memory: float + awq_memory: float + memory_saving: float + + +class AWQBenchmark: + """Comprehensive benchmark for AWQ vs nn.Linear""" + + def __init__(self, device: str = "cuda"): + self.device = torch.device(device if torch.cuda.is_available() else "cpu") + self.results: List[BenchmarkResults] = [] + + def get_memory_usage(self) -> float: + """Get current GPU memory usage in MB""" + if self.device.type == "cuda": + return torch.cuda.memory_allocated() / 1024 / 1024 + else: + process = psutil.Process(os.getpid()) + return process.memory_info().rss / 1024 / 1024 + + def measure_model_memory(self, model: nn.Module) -> float: + """Measure memory footprint of a 
model in MB""" + param_size = 0 + buffer_size = 0 + + for param in model.parameters(): + param_size += param.nelement() * param.element_size() + + for buffer in model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + + return (param_size + buffer_size) / 1024 / 1024 + + def create_test_data(self, batch_size: int, seq_len: int, hidden_dim: int) -> torch.Tensor: + """Create test input data""" + return torch.randn(batch_size, seq_len, hidden_dim, dtype=torch.float16, device=self.device) + + def quantize_linear_layer(self, linear_layer: nn.Linear, group_size: int = 128) -> AWQLinear: + """Quantize a linear layer using AWQ""" + # Create state dict for the layer + state_dict = {"test_layer.weight": linear_layer.weight.data} + if linear_layer.bias is not None: + state_dict["test_layer.bias"] = linear_layer.bias.data + + # Quantize using AWQ + from lite_llama.quantization.quant_config import AWQConfig + awq_config = AWQConfig() + quantized_dict = quantize_awq( + model_state_dict=state_dict, + config=awq_config, + target_layers=["test_layer.weight"], + device=str(self.device) + ) + + # Extract quantized parameters + qweight = quantized_dict["test_layer.qweight"] + qscales = quantized_dict["test_layer.qscales"] + qzeros = quantized_dict["test_layer.qzeros"] + + # Create AWQ linear layer + awq_layer = AWQLinear.from_float( + linear_layer, qweight, qscales, qzeros, group_size + ) + + return awq_layer.to(self.device) + + def warmup_triton_kernel(self, model: nn.Module, input_data: torch.Tensor, warmup_runs: int = 50): + """Extensive warmup for Triton kernels to ensure compilation and caching""" + model.eval() + print(f" Warming up Triton kernels ({warmup_runs} runs)...") + + with torch.no_grad(): + # First few runs trigger compilation + for i in range(warmup_runs): + _ = model(input_data) + if i < 10 and self.device.type == "cuda": + # Extra synchronization for first few runs to handle compilation + torch.cuda.synchronize() + + def measure_inference_time(self, model: nn.Module, input_data: torch.Tensor, + warmup_runs: int = 20, test_runs: int = 100, + is_triton: bool = False) -> float: + """Measure average inference time in milliseconds""" + model.eval() + + # Extended warmup for Triton kernels + if is_triton: + self.warmup_triton_kernel(model, input_data, warmup_runs=50) + else: + # Regular warmup for PyTorch + with torch.no_grad(): + for _ in range(warmup_runs): + _ = model(input_data) + + # Clear cache and synchronize + if self.device.type == "cuda": + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Measure time + start_time = time.perf_counter() + + with torch.no_grad(): + for _ in range(test_runs): + _ = model(input_data) + + if self.device.type == "cuda": + torch.cuda.synchronize() + + end_time = time.perf_counter() + + avg_time_ms = (end_time - start_time) * 1000 / test_runs + return avg_time_ms + + def compute_accuracy_metrics(self, fp16_output: torch.Tensor, + awq_output: torch.Tensor) -> Dict[str, float]: + """Compute accuracy metrics between FP16 and AWQ outputs""" + # Flatten tensors for easier computation + fp16_flat = fp16_output.view(-1).float() + awq_flat = awq_output.view(-1).float() + + # Error metrics + error = (fp16_flat - awq_flat).abs() + max_error = error.max().item() + mean_error = error.mean().item() + rmse = torch.sqrt(((fp16_flat - awq_flat) ** 2).mean()).item() + + # Cosine similarity + cos_sim = torch.nn.functional.cosine_similarity( + fp16_flat.unsqueeze(0), awq_flat.unsqueeze(0) + ).item() + + return { + "max_error": max_error, + 
"mean_error": mean_error, + "rmse": rmse, + "cosine_similarity": cos_sim + } + + def benchmark_layer_size(self, in_features: int, out_features: int, + batch_size: int = 16, seq_len: int = 128) -> BenchmarkResults: + """Benchmark a specific layer configuration""" + print(f"\nBenchmarking layer [{in_features} -> {out_features}], " + f"batch_size={batch_size}, seq_len={seq_len}") + + # Create original FP16 linear layer + fp16_layer = nn.Linear(in_features, out_features, bias=True) + fp16_layer = fp16_layer.to(self.device).half() + + # Quantize to AWQ + print(" Quantizing layer...") + awq_layer = self.quantize_linear_layer(fp16_layer) + + # Create test input + input_data = self.create_test_data(batch_size, seq_len, in_features) + + # Measure memory usage + fp16_memory = self.measure_model_memory(fp16_layer) + awq_memory = self.measure_model_memory(awq_layer) + memory_saving = (fp16_memory - awq_memory) / fp16_memory * 100 + + print(f" Memory: FP16={fp16_memory:.2f}MB, AWQ={awq_memory:.2f}MB, " + f"Saving={memory_saving:.1f}%") + + # Measure inference speed with proper warmup + print(" Measuring FP16 speed...") + fp16_time = self.measure_inference_time(fp16_layer, input_data, is_triton=False) + + print(" Measuring AWQ speed (with Triton warmup)...") + awq_time = self.measure_inference_time(awq_layer, input_data, is_triton=True) + + speedup = fp16_time / awq_time if awq_time > 0 else 0.0 + + print(f" Speed: FP16={fp16_time:.3f}ms, AWQ={awq_time:.3f}ms, " + f"Speedup={speedup:.2f}x") + + # Measure accuracy + with torch.no_grad(): + fp16_output = fp16_layer(input_data) + awq_output = awq_layer(input_data) + + accuracy_metrics = self.compute_accuracy_metrics(fp16_output, awq_output) + + print(f" Accuracy: RMSE={accuracy_metrics['rmse']:.6f}, " + f"CosSim={accuracy_metrics['cosine_similarity']:.6f}") + + # Store results + result = BenchmarkResults( + layer_size=(in_features, out_features), + batch_size=batch_size, + sequence_length=seq_len, + fp16_time=fp16_time, + awq_time=awq_time, + speedup=speedup, + max_error=accuracy_metrics["max_error"], + mean_error=accuracy_metrics["mean_error"], + rmse=accuracy_metrics["rmse"], + cosine_similarity=accuracy_metrics["cosine_similarity"], + fp16_memory=fp16_memory, + awq_memory=awq_memory, + memory_saving=memory_saving + ) + + self.results.append(result) + return result + + def run_comprehensive_benchmark(self): + """Run benchmark across different layer sizes and configurations""" + print("Starting comprehensive AWQ vs FP16 benchmark...") + print(f"Device: {self.device}") + + # Test configurations + layer_configs = [ + (768, 768), # Small transformer layer + (768, 3072), # FFN up projection + (3072, 768), # FFN down projection + (1024, 1024), # Medium layer + (2048, 2048), # Large layer + (4096, 4096), # Very large layer + ] + + batch_configs = [ + (1, 128), # Single sequence + (8, 128), # Small batch + (16, 512), # Medium batch + longer sequence + (32, 128), # Large batch + ] + + # Run benchmarks + for in_features, out_features in layer_configs: + for batch_size, seq_len in batch_configs: + try: + self.benchmark_layer_size( + in_features, out_features, batch_size, seq_len + ) + except Exception as e: + print(f" Error: {e}") + continue + + print(f"\nCompleted {len(self.results)} benchmark tests") + + def analyze_results(self): + """Analyze and summarize benchmark results""" + if not self.results: + print("No results to analyze") + return + + print("\n" + "=" * 80) + print("BENCHMARK SUMMARY") + print("=" * 80) + + # Speed analysis + avg_speedup = 
np.mean([r.speedup for r in self.results]) + max_speedup = max([r.speedup for r in self.results]) + min_speedup = min([r.speedup for r in self.results]) + + print(f"\nSPEED ANALYSIS:") + print(f" Average speedup: {avg_speedup:.2f}x") + print(f" Max speedup: {max_speedup:.2f}x") + print(f" Min speedup: {min_speedup:.2f}x") + + # Memory analysis + avg_memory_saving = np.mean([r.memory_saving for r in self.results]) + max_memory_saving = max([r.memory_saving for r in self.results]) + min_memory_saving = min([r.memory_saving for r in self.results]) + + print(f"\nMEMORY ANALYSIS:") + print(f" Average memory saving: {avg_memory_saving:.1f}%") + print(f" Max memory saving: {max_memory_saving:.1f}%") + print(f" Min memory saving: {min_memory_saving:.1f}%") + + # Accuracy analysis + avg_rmse = np.mean([r.rmse for r in self.results]) + max_rmse = max([r.rmse for r in self.results]) + avg_cosine_sim = np.mean([r.cosine_similarity for r in self.results]) + min_cosine_sim = min([r.cosine_similarity for r in self.results]) + + print(f"\nACCURACY ANALYSIS:") + print(f" Average RMSE: {avg_rmse:.6f}") + print(f" Max RMSE: {max_rmse:.6f}") + print(f" Average cosine similarity: {avg_cosine_sim:.6f}") + print(f" Min cosine similarity: {min_cosine_sim:.6f}") + + # Find best and worst cases + best_speedup_idx = np.argmax([r.speedup for r in self.results]) + worst_accuracy_idx = np.argmax([r.rmse for r in self.results]) + + print(f"\nBEST SPEEDUP:") + best_result = self.results[best_speedup_idx] + print(f" Layer size: {best_result.layer_size}") + print(f" Batch config: {best_result.batch_size}x{best_result.sequence_length}") + print(f" Speedup: {best_result.speedup:.2f}x") + + print(f"\nWORST ACCURACY:") + worst_result = self.results[worst_accuracy_idx] + print(f" Layer size: {worst_result.layer_size}") + print(f" Batch config: {worst_result.batch_size}x{worst_result.sequence_length}") + print(f" RMSE: {worst_result.rmse:.6f}") + print(f" Cosine similarity: {worst_result.cosine_similarity:.6f}") + + def export_results_csv(self, filename: str = "awq_benchmark_results.csv"): + """Export results to CSV file""" + if not self.results: + print("No results to export") + return + + import csv + + with open(filename, 'w', newline='') as csvfile: + fieldnames = [ + 'layer_size', 'batch_size', 'sequence_length', + 'fp16_time_ms', 'awq_time_ms', 'speedup', + 'max_error', 'mean_error', 'rmse', 'cosine_similarity', + 'fp16_memory_mb', 'awq_memory_mb', 'memory_saving_percent' + ] + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + writer.writeheader() + for result in self.results: + writer.writerow({ + 'layer_size': f"{result.layer_size[0]}x{result.layer_size[1]}", + 'batch_size': result.batch_size, + 'sequence_length': result.sequence_length, + 'fp16_time_ms': result.fp16_time, + 'awq_time_ms': result.awq_time, + 'speedup': result.speedup, + 'max_error': result.max_error, + 'mean_error': result.mean_error, + 'rmse': result.rmse, + 'cosine_similarity': result.cosine_similarity, + 'fp16_memory_mb': result.fp16_memory, + 'awq_memory_mb': result.awq_memory, + 'memory_saving_percent': result.memory_saving + }) + + print(f"Results exported to {filename}") + + +def quick_demo(): + """Quick demonstration of AWQ vs FP16 comparison""" + print("Quick AWQ vs FP16 Demo") + print("=" * 50) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + + # Create a simple test case + in_features, out_features = 768, 768 + batch_size, seq_len = 16, 128 + + print(f"Testing 
{in_features}→{out_features} layer with batch={batch_size}, seq_len={seq_len}") + + # Create FP16 linear layer + fp16_layer = nn.Linear(in_features, out_features, bias=True) + fp16_layer = fp16_layer.to(device).half() + + # Create test input + input_data = torch.randn(batch_size, seq_len, in_features, dtype=torch.float16, device=device) + + # Create benchmark instance + benchmark = AWQBenchmark(device=str(device)) + + # Get FP16 timing + print("Measuring FP16 performance...") + fp16_time = benchmark.measure_inference_time(fp16_layer, input_data, is_triton=False) + + # Create quantized version + print("Quantizing layer with AWQ...") + awq_layer = benchmark.quantize_linear_layer(fp16_layer) + + # Get AWQ timing with proper Triton warmup + print("Measuring AWQ performance (with Triton warmup)...") + awq_time = benchmark.measure_inference_time(awq_layer, input_data, is_triton=True) + + # Get outputs for accuracy measurement + with torch.no_grad(): + fp16_output = fp16_layer(input_data) + awq_output = awq_layer(input_data) + + # Calculate metrics + fp16_memory = benchmark.measure_model_memory(fp16_layer) + awq_memory = benchmark.measure_model_memory(awq_layer) + memory_saving = (fp16_memory - awq_memory) / fp16_memory * 100 + speedup = fp16_time / awq_time if awq_time > 0 else 0 + + # Accuracy metrics + accuracy_metrics = benchmark.compute_accuracy_metrics(fp16_output, awq_output) + + # Print results + print(f"\nResults:") + print(f" Speed:") + print(f" FP16: {fp16_time:.3f}ms") + print(f" AWQ: {awq_time:.3f}ms") + print(f" Speedup: {speedup:.2f}x") + print(f" Memory:") + print(f" FP16: {fp16_memory:.2f}MB") + print(f" AWQ: {awq_memory:.2f}MB") + print(f" Saving: {memory_saving:.1f}%") + print(f" Accuracy:") + print(f" RMSE: {accuracy_metrics['rmse']:.6f}") + print(f" Cosine Similarity: {accuracy_metrics['cosine_similarity']:.6f}") + print(f" Max Error: {accuracy_metrics['max_error']:.6f}") + + +def main(): + """Main function to run the comprehensive benchmark""" + + # Check if CUDA is available + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Running benchmark on: {device}") + + if device == "cpu": + print("Warning: Running on CPU. Triton kernels require CUDA for optimal performance.") + + # Ask user which test to run + print("\nChoose test type:") + print("1. Quick demo (single layer test)") + print("2. 
Comprehensive benchmark (multiple configurations)") + + try: + choice = input("Enter choice (1 or 2, default=1): ").strip() + if choice == "2": + # Create benchmark instance + benchmark = AWQBenchmark(device=device) + + # Run comprehensive benchmark + benchmark.run_comprehensive_benchmark() + + # Analyze results + benchmark.analyze_results() + + # Export results + benchmark.export_results_csv() + + print("\nBenchmark completed!") + else: + # Run quick demo + quick_demo() + + except KeyboardInterrupt: + print("\nBenchmark interrupted by user") + except Exception as e: + print(f"Error running benchmark: {e}") + # Fallback to quick demo + print("Running quick demo instead...") + quick_demo() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/quant/test_GPTQLinear.py b/tests/quant/test_GPTQLinear.py new file mode 100644 index 0000000..397e719 --- /dev/null +++ b/tests/quant/test_GPTQLinear.py @@ -0,0 +1,263 @@ +import torch +import triton +import triton.language as tl +import torch.nn as nn +from typing import Optional + + +@triton.jit +def _int4_linear_kernel( + input_ptr, # [M, K] + qweight_ptr, # [N, K//2] + scales_ptr, # [N, K//groupsize] + zeros_ptr, # [N, K//groupsize] + output_ptr, # [M, N] + bias_ptr, # [N] or dummy + M: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + groupsize: tl.constexpr, + stride_im, stride_ik, + stride_wn, stride_wk, + stride_sn, stride_sg, + stride_zn, stride_zg, + stride_om, stride_on, + HAS_BIAS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + mask_m = offs_m < M + mask_n = offs_n < N + mask_k = offs_k < K + + # Input block: [BLOCK_SIZE_M, BLOCK_SIZE_K] + input_ptrs = input_ptr + offs_m[:, None] * stride_im + offs_k[None, :] * stride_ik + input_block = tl.load(input_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0) + + # ---- Load and unpack packed int4 weights ---- + packed_k = offs_k // 2 # [BLOCK_SIZE_K] → K//2 indices + is_high = offs_k % 2 # [BLOCK_SIZE_K] + + weight_ptrs = qweight_ptr + offs_n[:, None] * stride_wn + packed_k[None, :] * stride_wk + packed_vals = tl.load(weight_ptrs, mask=mask_n[:, None] & (packed_k[None, :] < (K // 2)), other=0) + + low = packed_vals & 0xF + high = (packed_vals >> 4) & 0xF + unpacked = tl.where(is_high[None, :] == 1, high, low).to(tl.float32) # [N, K] + + # ---- Dequantization ---- + group_id = (offs_k // groupsize)[None, :] # [1, K] + scale_ptrs = scales_ptr + offs_n[:, None] * stride_sn + group_id * stride_sg + zero_ptrs = zeros_ptr + offs_n[:, None] * stride_zn + group_id * stride_zg + + scale_vals = tl.load(scale_ptrs, mask=mask_n[:, None], other=1.0) + zero_vals = tl.load(zero_ptrs, mask=mask_n[:, None], other=0.0) + + dequant = (unpacked - zero_vals) * scale_vals # [BLOCK_SIZE_N, BLOCK_SIZE_K] + + # ---- GEMM ---- + acc = tl.dot(input_block, tl.trans(dequant)) # [BLOCK_SIZE_M, BLOCK_SIZE_N] + + # ---- Add bias if present ---- + if HAS_BIAS: + bias_ptrs = bias_ptr + offs_n + bias_vals = tl.load(bias_ptrs, mask=mask_n, other=0.0) + acc += bias_vals[None, :] + + # ---- Write output ---- + output_ptrs = output_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + tl.store(output_ptrs, acc, mask=mask_m[:, None] & mask_n[None, :]) + + + +class Int4Linear(nn.Module): + """ + A linear 
layer that uses int4 quantized weights with Triton kernel. + """ + + def __init__(self, qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, + bias: Optional[torch.Tensor] = None, groupsize: int = 128): + super().__init__() + + # Validate inputs + assert qweight.dtype == torch.uint8, "qweight must be uint8" + assert scales.dtype in [torch.float16, torch.float32], "scales must be float16 or float32" + assert zeros.dtype in [torch.float16, torch.float32], "zeros must be float16 or float32" + + self.out_features, packed_in_features = qweight.shape + self.in_features = packed_in_features * 2 # Each uint8 contains 2 int4 values + self.groupsize = groupsize + + # Register quantized parameters + self.register_buffer('qweight', qweight) + self.register_buffer('scales', scales.to(torch.float32)) # Always use fp32 for scales + self.register_buffer('zeros', zeros.to(torch.float32)) # Always use fp32 for zeros + + if bias is not None: + self.register_buffer('bias', bias) + else: + self.bias = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass using int4 Triton kernel. + + Args: + x: Input tensor of shape [..., in_features] + + Returns: + Output tensor of shape [..., out_features] + """ + # Reshape input to 2D + input_shape = x.shape + x_2d = x.view(-1, self.in_features) + M, K = x_2d.shape + N = self.out_features + + # Allocate output + output = torch.empty((M, N), dtype=x.dtype, device=x.device) + + # Calculate grid dimensions + BLOCK_SIZE_M = min(64, triton.next_power_of_2(M)) + BLOCK_SIZE_N = min(64, triton.next_power_of_2(N)) + BLOCK_SIZE_K = min(64, triton.next_power_of_2(K)) + + grid = lambda meta: ( + triton.cdiv(M, meta['BLOCK_SIZE_M']), + triton.cdiv(N, meta['BLOCK_SIZE_N']) + ) + + # Launch kernel + _int4_linear_kernel[grid]( + x_2d, self.qweight, self.scales, self.zeros, output, + self.bias if self.bias is not None else x_2d, # Dummy pointer if no bias + M, N, K, self.groupsize, + x_2d.stride(0), x_2d.stride(1), + self.qweight.stride(0), self.qweight.stride(1), + self.scales.stride(0), self.scales.stride(1), + self.zeros.stride(0), self.zeros.stride(1), + output.stride(0), output.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + HAS_BIAS=self.bias is not None, + num_warps=4, + num_stages=2, + ) + + # Reshape output back to original shape + return output.view(*input_shape[:-1], self.out_features) + + +def pack_int4_weights(qweight: torch.Tensor) -> torch.Tensor: + """ + Pack int4 weights from [N, K] uint8 to [N, K//2] uint8. + Each output uint8 contains two int4 values. + """ + N, K = qweight.shape + assert K % 2 == 0, "K must be even for packing" + + # Pack two int4 values into one uint8 + packed = torch.zeros((N, K // 2), dtype=torch.uint8, device=qweight.device) + packed = (qweight[:, 0::2] & 0xF) | ((qweight[:, 1::2] & 0xF) << 4) + + return packed + + +def create_int4_linear_from_quantized(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, bias: Optional[torch.Tensor] = None, + groupsize: int = 128) -> Int4Linear: + """ + Create an Int4Linear layer from quantized parameters. 
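+    Note: if the weight tensor has an even number of columns it is treated as
+    unpacked [out_features, in_features] uint8 data and packed here via
+    pack_int4_weights; otherwise it is assumed to be packed already.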
+ + Args: + qweight: Quantized weights [out_features, in_features] as uint8 + scales: Dequantization scales [out_features, num_groups] + zeros: Dequantization zeros [out_features, num_groups] + bias: Optional bias term [out_features] + groupsize: Group size for quantization + + Returns: + Int4Linear layer ready for inference + """ + # Pack int4 weights if needed + if qweight.shape[1] % 2 == 0: + # Assume weights are not packed yet + packed_qweight = pack_int4_weights(qweight) + else: + # Assume weights are already packed + packed_qweight = qweight + + return Int4Linear(packed_qweight, scales, zeros, bias, groupsize) + + +# Example usage and testing +if __name__ == "__main__": + def test_int4_linear(): + # Test parameters + batch_size = 32 + in_features = 512 + out_features = 256 + groupsize = 128 + + # Create random quantized weights + qweight = torch.randint(0, 16, (out_features, in_features), dtype=torch.uint8, device='cuda') + scales = torch.randn(out_features, in_features // groupsize, dtype=torch.float16, device='cuda') + zeros = torch.randint(0, 16, (out_features, in_features // groupsize), dtype=torch.float16, device='cuda') + bias = torch.randn(out_features, dtype=torch.float16, device='cuda') + + # Create Int4Linear layer + int4_layer = create_int4_linear_from_quantized(qweight, scales, zeros, bias, groupsize) + + # Test forward pass + x = torch.randn(batch_size, in_features, dtype=torch.float16, device='cuda') + + # Warm up + for _ in range(10): + _ = int4_layer(x) + + torch.cuda.synchronize() + + # Benchmark + import time + start_time = time.time() + for _ in range(100): + output = int4_layer(x) + torch.cuda.synchronize() + end_time = time.time() + + print(f"Int4Linear forward time: {(end_time - start_time) * 10:.2f} ms per call") + print(f"Output shape: {output.shape}") + print(f"Output dtype: {output.dtype}") + + # Compare with standard linear (using dequantized weights) + from lite_llama.quantization.gptq import GPTQ + gptq = GPTQ(wbits=4, groupsize=groupsize) + dequant_weight = gptq.dequantize(qweight, zeros, scales) + + std_layer = nn.Linear(in_features, out_features, bias=True, device='cuda', dtype=torch.float16) + std_layer.weight.data = dequant_weight.T.to(torch.float16) + std_layer.bias.data = bias + + start_time = time.time() + for _ in range(100): + std_output = std_layer(x) + torch.cuda.synchronize() + end_time = time.time() + + print(f"Standard Linear forward time: {(end_time - start_time) * 10:.2f} ms per call") + + # Check numerical accuracy (should be close due to quantization) + diff = torch.abs(output - std_output).max() + print(f"Max difference between Int4Linear and Standard Linear: {diff:.6f}") + + + test_int4_linear() \ No newline at end of file diff --git a/tests/quant/test_SQLinear.py b/tests/quant/test_SQLinear.py new file mode 100644 index 0000000..d2e6531 --- /dev/null +++ b/tests/quant/test_SQLinear.py @@ -0,0 +1,488 @@ +import torch +import torch.nn as nn + +import numpy as np +from typing import Dict, List, Tuple +import time, os, sys, gc +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))) +from lite_llama.kernels.sq_linear import SmoothQuantLinear + + +# Import the SmoothQuant implementation +# from smoothquant_int4 import SmoothQuantLinear + +def get_memory_usage(): + """Get current GPU memory usage in MB""" + if torch.cuda.is_available(): + return torch.cuda.memory_allocated() / 1024 / 1024 + return 0 + + +def clear_memory(): + """Clear GPU memory cache""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() 
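+        # Also run Python's garbage collector so tensors that are no longer
+        # referenced are actually released before the next measurement.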
+ gc.collect() + + +def format_table(headers: List[str], rows: List[List[str]], title: str = "") -> str: + """Simple table formatter without external dependencies""" + if not rows: + return "" + + # Calculate column widths + widths = [len(header) for header in headers] + for row in rows: + for i, cell in enumerate(row): + if i < len(widths): + widths[i] = max(widths[i], len(str(cell))) + + # Create format string + fmt = " | ".join(f"{{:<{w}}}" for w in widths) + separator = "-+-".join("-" * w for w in widths) + + # Build table + result = [] + if title: + total_width = sum(widths) + 3 * (len(widths) - 1) + result.append(f"\n{title}") + result.append("=" * max(len(title), total_width)) + + result.append(fmt.format(*headers)) + result.append(separator) + + for row in rows: + result.append(fmt.format(*[str(cell) for cell in row])) + + return "\n".join(result) + + +class PerformanceComparison: + """Class to compare SmoothQuant INT4 with nn.Linear""" + + def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'): + self.device = device + self.results = [] + + def create_layers(self, in_features: int, out_features: int, group_size: int = 128): + """Create both SmoothQuant and nn.Linear layers with same weights""" + # Create standard linear layer + linear_layer = nn.Linear(in_features, out_features, bias=True, dtype=torch.float16) + linear_layer = linear_layer.to(self.device) + + # Create SmoothQuant layer + sq_layer = SmoothQuantLinear(in_features, out_features, bias=True, group_size=group_size) + sq_layer = sq_layer.to(self.device) + + # Copy weights and bias from linear to SmoothQuant + with torch.no_grad(): + weight = linear_layer.weight.data.clone() + bias = linear_layer.bias.data.clone() if linear_layer.bias is not None else None + + # Generate some sample activations to compute scales for SmoothQuant + sample_input = torch.randn(32, 128, in_features, dtype=torch.float16, device=self.device) + act_scales = sample_input.abs().amax(dim=(0, 1)) + + # Quantize the weights + sq_layer.quantize_weight(weight, act_scales) + if bias is not None: + sq_layer.bias.copy_(bias) + + return linear_layer, sq_layer + + def measure_memory(self, layer, input_tensor, layer_name: str): + """Measure memory usage of a layer""" + clear_memory() + + # Baseline memory + baseline_memory = get_memory_usage() + + # Model memory (parameters) - handle both regular and buffer parameters + model_memory = 0 + for param in layer.parameters(): + model_memory += param.numel() * param.element_size() + + # Also count registered buffers (important for SmoothQuant) + for buffer in layer.buffers(): + model_memory += buffer.numel() * buffer.element_size() + + model_memory = model_memory / 1024 / 1024 # Convert to MB + + # Ensure we have a minimum memory value to avoid division by zero + model_memory = max(model_memory, 0.001) # At least 1KB + + # Forward pass memory + with torch.no_grad(): + _ = layer(input_tensor) + peak_memory = get_memory_usage() + + activation_memory = peak_memory - baseline_memory - model_memory + + return { + 'layer_name': layer_name, + 'model_memory_mb': model_memory, + 'activation_memory_mb': max(0, activation_memory), + 'total_memory_mb': peak_memory - baseline_memory + } + + def measure_speed(self, layer, input_tensor, num_warmup: int = 10, num_runs: int = 100): + """Measure inference speed of a layer""" + # Warmup + with torch.no_grad(): + for _ in range(num_warmup): + _ = layer(input_tensor) + if torch.cuda.is_available(): + torch.cuda.synchronize() + + # Timing + times = [] + with 
torch.no_grad(): + for _ in range(num_runs): + if torch.cuda.is_available(): + torch.cuda.synchronize() + start = time.time() + + _ = layer(input_tensor) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + end = time.time() + + times.append((end - start) * 1000) # Convert to ms + + return { + 'mean_time_ms': np.mean(times), + 'std_time_ms': np.std(times), + 'min_time_ms': np.min(times), + 'max_time_ms': np.max(times) + } + + def measure_accuracy(self, linear_layer, sq_layer, input_tensor): + """Compare numerical accuracy between layers""" + with torch.no_grad(): + # Get outputs + linear_output = linear_layer(input_tensor) + sq_output = sq_layer(input_tensor) + + # Calculate differences + abs_diff = (linear_output - sq_output).abs() + rel_diff = abs_diff / (linear_output.abs() + 1e-8) + + return { + 'mean_abs_error': abs_diff.mean().item(), + 'max_abs_error': abs_diff.max().item(), + 'mean_rel_error': rel_diff.mean().item(), + 'max_rel_error': rel_diff.max().item(), + 'mse': ((linear_output - sq_output) ** 2).mean().item(), + 'cosine_similarity': torch.nn.functional.cosine_similarity( + linear_output.flatten(), sq_output.flatten(), dim=0 + ).item() + } + + def run_comparison(self, test_configs: List[Dict]): + """Run comparison across multiple configurations""" + print(f"Running comparison on {self.device}") + print("=" * 80) + + for config in test_configs: + print(f"\nTesting: {config['name']}") + print("-" * 40) + + # Create input + input_tensor = torch.randn( + config['batch_size'], + config['seq_len'], + config['in_features'], + dtype=torch.float16, + device=self.device + ) + + # Create layers + linear_layer, sq_layer = self.create_layers( + config['in_features'], + config['out_features'], + config.get('group_size', 128) + ) + + # Measure memory + linear_memory = self.measure_memory(linear_layer, input_tensor, 'nn.Linear') + sq_memory = self.measure_memory(sq_layer, input_tensor, 'SmoothQuant') + + # Measure speed + linear_speed = self.measure_speed(linear_layer, input_tensor) + sq_speed = self.measure_speed(sq_layer, input_tensor) + + # Measure accuracy + accuracy = self.measure_accuracy(linear_layer, sq_layer, input_tensor) + + # Calculate throughput (tokens/sec) + total_tokens = config['batch_size'] * config['seq_len'] + linear_throughput = total_tokens / (linear_speed['mean_time_ms'] / 1000) + sq_throughput = total_tokens / (sq_speed['mean_time_ms'] / 1000) + + # Store results + result = { + 'config': config, + 'linear_memory': linear_memory, + 'sq_memory': sq_memory, + 'linear_speed': linear_speed, + 'sq_speed': sq_speed, + 'accuracy': accuracy, + 'linear_throughput': linear_throughput, + 'sq_throughput': sq_throughput, + 'speedup': linear_speed['mean_time_ms'] / sq_speed['mean_time_ms'], + 'memory_reduction': linear_memory['model_memory_mb'] / sq_memory['model_memory_mb'] + } + + self.results.append(result) + + # Print summary for this config + self.print_config_summary(result) + + def print_config_summary(self, result): + """Print summary for a single configuration""" + config = result['config'] + + print(f"Input shape: [{config['batch_size']}, {config['seq_len']}, {config['in_features']}]") + print(f"Output features: {config['out_features']}") + + # Speed comparison + print(f"\n🕐 Speed Comparison:") + print( + f" nn.Linear: {result['linear_speed']['mean_time_ms']:.3f} ± {result['linear_speed']['std_time_ms']:.3f} ms") + print(f" SmoothQuant: {result['sq_speed']['mean_time_ms']:.3f} ± {result['sq_speed']['std_time_ms']:.3f} ms") + print(f" Speedup: 
{result['speedup']:.2f}x") + + # Throughput + print(f"\n🚀 Throughput:") + print(f" nn.Linear: {result['linear_throughput']:.0f} tokens/sec") + print(f" SmoothQuant: {result['sq_throughput']:.0f} tokens/sec") + + # Memory comparison + print(f"\n💾 Memory Usage (Model Parameters):") + print(f" nn.Linear: {result['linear_memory']['model_memory_mb']:.2f} MB") + print(f" SmoothQuant: {result['sq_memory']['model_memory_mb']:.2f} MB") + print(f" Reduction: {result['memory_reduction']:.2f}x") + + # Accuracy + print(f"\n📊 Accuracy:") + print(f" Mean Abs Error: {result['accuracy']['mean_abs_error']:.6f}") + print(f" Max Abs Error: {result['accuracy']['max_abs_error']:.6f}") + print(f" Mean Rel Error: {result['accuracy']['mean_rel_error']:.4f}") + print(f" Cosine Similarity: {result['accuracy']['cosine_similarity']:.6f}") + + def generate_report(self): + """Generate comprehensive report with tables""" + if not self.results: + print("No results to report!") + return + + print("\n" + "=" * 80) + print("COMPREHENSIVE COMPARISON REPORT") + print("=" * 80) + + # Summary table + self.print_summary_table() + + # Additional detailed analysis + self.print_detailed_analysis() + + def print_summary_table(self): + """Print summary table of all results""" + headers = [ + "Config", "Batch×Seq", "Features", + "Linear (ms)", "SmoothQ (ms)", "Speedup", + "Linear (MB)", "SmoothQ (MB)", "Mem Reduction", + "Mean Abs Err", "Cos Sim" + ] + + rows = [] + for result in self.results: + config = result['config'] + rows.append([ + config['name'], + f"{config['batch_size']}×{config['seq_len']}", + f"{config['in_features']}→{config['out_features']}", + f"{result['linear_speed']['mean_time_ms']:.2f}", + f"{result['sq_speed']['mean_time_ms']:.2f}", + f"{result['speedup']:.2f}x", + f"{result['linear_memory']['model_memory_mb']:.1f}", + f"{result['sq_memory']['model_memory_mb']:.1f}", + f"{result['memory_reduction']:.2f}x", + f"{result['accuracy']['mean_abs_error']:.4f}", + f"{result['accuracy']['cosine_similarity']:.4f}" + ]) + + print(format_table(headers, rows, "📋 Summary Table")) + + # Overall statistics + speedups = [r['speedup'] for r in self.results] + memory_reductions = [r['memory_reduction'] for r in self.results] + accuracies = [r['accuracy']['cosine_similarity'] for r in self.results] + + print(f"\n📈 Overall Statistics:") + print(f" Average Speedup: {np.mean(speedups):.2f}x (±{np.std(speedups):.2f})") + print(f" Average Memory Reduction: {np.mean(memory_reductions):.2f}x (±{np.std(memory_reductions):.2f})") + print(f" Average Cosine Similarity: {np.mean(accuracies):.4f} (±{np.std(accuracies):.4f})") + + def print_detailed_analysis(self): + """Print detailed analysis without plots""" + if not self.results: + return + + print("\n" + "=" * 60) + print("DETAILED ANALYSIS") + print("=" * 60) + + # Performance analysis + print("\n🚀 Performance Analysis:") + print("-" * 30) + + best_speedup = max(self.results, key=lambda x: x['speedup']) + worst_speedup = min(self.results, key=lambda x: x['speedup']) + + print(f"Best speedup: {best_speedup['speedup']:.2f}x ({best_speedup['config']['name']})") + print(f"Worst speedup: {worst_speedup['speedup']:.2f}x ({worst_speedup['config']['name']})") + + # Memory analysis + print("\n💾 Memory Analysis:") + print("-" * 30) + + best_memory = max(self.results, key=lambda x: x['memory_reduction']) + worst_memory = min(self.results, key=lambda x: x['memory_reduction']) + + print(f"Best memory reduction: {best_memory['memory_reduction']:.2f}x ({best_memory['config']['name']})") + print(f"Worst memory 
reduction: {worst_memory['memory_reduction']:.2f}x ({worst_memory['config']['name']})") + + # Accuracy analysis + print("\n📊 Accuracy Analysis:") + print("-" * 30) + + best_accuracy = max(self.results, key=lambda x: x['accuracy']['cosine_similarity']) + worst_accuracy = min(self.results, key=lambda x: x['accuracy']['cosine_similarity']) + + print( + f"Best accuracy: {best_accuracy['accuracy']['cosine_similarity']:.6f} ({best_accuracy['config']['name']})") + print( + f"Worst accuracy: {worst_accuracy['accuracy']['cosine_similarity']:.6f} ({worst_accuracy['config']['name']})") + + # Efficiency analysis + print("\n⚡ Efficiency Analysis:") + print("-" * 30) + + for result in self.results: + config = result['config'] + efficiency_score = result['speedup'] * result['memory_reduction'] * result['accuracy']['cosine_similarity'] + print(f"{config['name']:12}: Efficiency Score = {efficiency_score:.3f}") + + # Scaling analysis + print("\n📈 Scaling Analysis:") + print("-" * 30) + + # Sort by problem size (total parameters) + sorted_results = sorted(self.results, key=lambda x: x['config']['in_features'] * x['config']['out_features']) + + for result in sorted_results: + config = result['config'] + total_params = config['in_features'] * config['out_features'] + print( + f"{config['name']:12}: {total_params:>10,} params, {result['speedup']:.2f}x speedup, {result['memory_reduction']:.2f}x memory") + + # Recommendations + print("\n💡 Recommendations:") + print("-" * 30) + + avg_speedup = np.mean([r['speedup'] for r in self.results]) + avg_memory = np.mean([r['memory_reduction'] for r in self.results]) + avg_accuracy = np.mean([r['accuracy']['cosine_similarity'] for r in self.results]) + + if avg_speedup > 1.5: + print("✅ SmoothQuant shows significant speed improvements") + else: + print("⚠️ SmoothQuant speed improvements are marginal") + + if avg_memory > 3.0: + print("✅ SmoothQuant provides excellent memory savings") + else: + print("⚠️ SmoothQuant memory savings are lower than expected") + + if avg_accuracy > 0.99: + print("✅ SmoothQuant maintains high numerical accuracy") + elif avg_accuracy > 0.95: + print("⚠️ SmoothQuant shows moderate accuracy degradation") + else: + print("❌ SmoothQuant shows significant accuracy degradation") + + +def run_comprehensive_test(): + """Run comprehensive comparison test""" + + # Test configurations + test_configs = [ + { + 'name': 'Small', + 'batch_size': 1, + 'seq_len': 128, + 'in_features': 512, + 'out_features': 512, + 'group_size': 128 + }, + { + 'name': 'Medium', + 'batch_size': 4, + 'seq_len': 256, + 'in_features': 1024, + 'out_features': 1024, + 'group_size': 128 + }, + { + 'name': 'Large', + 'batch_size': 8, + 'seq_len': 512, + 'in_features': 2048, + 'out_features': 2048, + 'group_size': 128 + }, + { + 'name': 'Very Large', + 'batch_size': 16, + 'seq_len': 1024, + 'in_features': 4096, + 'out_features': 4096, + 'group_size': 128 + }, + { + 'name': 'LLaMA-like', + 'batch_size': 1, + 'seq_len': 2048, + 'in_features': 4096, + 'out_features': 11008, # Typical MLP dimension + 'group_size': 128 + } + ] + + # Run comparison + comparison = PerformanceComparison() + comparison.run_comparison(test_configs) + comparison.generate_report() + + return comparison + + +if __name__ == "__main__": + print("SmoothQuant INT4 vs nn.Linear Comprehensive Comparison") + print("=" * 60) + + # Check if CUDA is available + if torch.cuda.is_available(): + print(f"Using GPU: {torch.cuda.get_device_name()}") + print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 
3:.1f} GB") + else: + print("Using CPU (GPU not available)") + + print() + + # Run the comprehensive test + comparison = run_comprehensive_test() + + print(f"\n🎉 Comparison complete!") \ No newline at end of file diff --git a/tests/quant/test_gptq.py b/tests/quant/test_gptq.py new file mode 100644 index 0000000..f99867f --- /dev/null +++ b/tests/quant/test_gptq.py @@ -0,0 +1,232 @@ +import torch +import numpy as np +import pytest +from typing import Dict, Tuple, Optional, Any +import time, os, sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) + +from gptq import GPTQ + +# Assuming the GPTQ class and helper functions are imported +# from your_gptq_module import GPTQ, pack_weight, unpack_weight, GPTQConfig + +def pack_weight(weight): + """Pack two 4-bit values into one uint8 value""" + rows, cols = weight.shape + original_cols = cols + if cols % 2 != 0: + weight = torch.nn.functional.pad(weight, (0, 1), value=0) + cols += 1 + packed = (weight[:, 0::2] & 0xF) | ((weight[:, 1::2] & 0xF) << 4) + return packed.contiguous(), original_cols + + +def unpack_weight(packed_weight, original_cols): + """Unpack uint8 values back to two 4-bit values""" + rows, packed_cols = packed_weight.shape + unpacked = torch.zeros((rows, packed_cols * 2), dtype=torch.uint8, device=packed_weight.device) + unpacked[:, 0::2] = packed_weight & 0xF + unpacked[:, 1::2] = (packed_weight >> 4) & 0xF + return unpacked[:, :original_cols].contiguous() + + +# Mock GPTQConfig for testing +class GPTQConfig: + def __init__(self, w_bit=4, group_size=128, device="cuda"): + self.w_bit = w_bit + self.group_size = group_size + self.device = device + + +def test_pack_unpack_weights(): + """Test that pack/unpack operations are lossless""" + print("Testing pack/unpack operations...") + + # Test with even columns + weight_even = torch.randint(0, 16, (4, 8), dtype=torch.uint8) + packed, original_cols = pack_weight(weight_even) + unpacked = unpack_weight(packed, original_cols) + + assert torch.equal(weight_even, unpacked), "Pack/unpack failed for even columns" + assert packed.shape[1] == weight_even.shape[1] // 2, "Packed size incorrect" + + # Test with odd columns + weight_odd = torch.randint(0, 16, (4, 7), dtype=torch.uint8) + packed, original_cols = pack_weight(weight_odd) + unpacked = unpack_weight(packed, original_cols) + + assert torch.equal(weight_odd, unpacked), "Pack/unpack failed for odd columns" + assert original_cols == 7, "Original columns not preserved" + + print("✓ Pack/unpack tests passed") + + +def test_quantize_dequantize_cycle(): + """Test the complete quantize -> dequantize cycle""" + print("Testing quantize -> dequantize cycle...") + + device = "cuda" if torch.cuda.is_available() else "cpu" + config = GPTQConfig(w_bit=4, group_size=32, device=device) + + gptq = GPTQ(config) + + # Test cases with different weight characteristics + test_cases = [ + { + "name": "Normal weights", + "weight": torch.randn(128, 256, device=device, dtype=torch.float32) * 0.1 + }, + { + "name": "Small weights", + "weight": torch.randn(64, 128, device=device, dtype=torch.float32) * 0.001 + }, + { + "name": "Large weights", + "weight": torch.randn(32, 64, device=device, dtype=torch.float32) * 10.0 + }, + { + "name": "Mixed scale weights", + "weight": torch.cat([ + torch.randn(16, 32, device=device) * 0.001, + torch.randn(16, 32, device=device) * 1.0 + ], dim=0) + }, + { + "name": "Odd columns", + "weight": torch.randn(32, 63, device=device, dtype=torch.float32) * 0.1 + } + ] + + results = [] + + for test_case 
in test_cases: + print(f"\n Testing: {test_case['name']}") + weight = test_case["weight"] + + # Quantize + start_time = time.time() + packed_qweight, qzeros, scales, original_cols = gptq.quantize(weight) + quantize_time = time.time() - start_time + + # Dequantize + start_time = time.time() + reconstructed = gptq.dequantize_packed(packed_qweight, qzeros, scales, original_cols) + dequantize_time = time.time() - start_time + + # Check shapes + assert reconstructed.shape == weight.shape, f"Shape mismatch: {reconstructed.shape} vs {weight.shape}" + + # Calculate errors + mse = torch.mean((weight - reconstructed) ** 2).item() + relative_error = gptq.relative_error_loss(weight, reconstructed).item() + max_abs_error = torch.max(torch.abs(weight - reconstructed)).item() + + # Memory efficiency check + original_memory = weight.numel() * 4 # float32 + quantized_memory = (packed_qweight.numel() + qzeros.numel() * 2 + scales.numel() * 2) # uint8 + float16 + compression_ratio = original_memory / quantized_memory + + result = { + "name": test_case["name"], + "shape": weight.shape, + "mse": mse, + "relative_error": relative_error, + "max_abs_error": max_abs_error, + "compression_ratio": compression_ratio, + "quantize_time": quantize_time, + "dequantize_time": dequantize_time, + "original_cols": original_cols + } + results.append(result) + + print(f" Shape: {weight.shape}") + print(f" MSE: {mse:.6f}") + print(f" Relative Error: {relative_error:.6f}") + print(f" Max Abs Error: {max_abs_error:.6f}") + print(f" Compression Ratio: {compression_ratio:.2f}x") + print(f" Original Cols: {original_cols}") + + # Assertions for quality + assert relative_error < 1.0, f"Relative error too high: {relative_error}" + assert compression_ratio > 1.5, f"Compression ratio too low: {compression_ratio}" + + # Test that packing actually reduced size + original_qweight_size = weight.shape[0] * weight.shape[1] # unpacked size + packed_qweight_size = packed_qweight.numel() + expected_packed_size = (weight.shape[1] + 1) // 2 * weight.shape[0] # ceil(cols/2) * rows + + assert packed_qweight_size <= expected_packed_size, "Packing didn't reduce size as expected" + + print("\n✓ All quantize -> dequantize tests passed") + return results + + + +def test_consistency_across_devices(): + """Test that quantization is consistent across CPU and GPU""" + print("Testing device consistency...") + + if not torch.cuda.is_available(): + print(" CUDA not available, skipping device consistency test") + return + + weight_cpu = torch.randn(32, 64, dtype=torch.float16) * 0.1 + weight_gpu = weight_cpu.cuda() + + # Note: This would require the actual GPTQ implementation + # For now, just test that weights can be moved between devices + + assert weight_cpu.device.type == "cpu" + assert weight_gpu.device.type == "cuda" + assert torch.allclose(weight_cpu, weight_gpu.cpu(), atol=1e-6) + + print("✓ Device consistency tests passed") + + +def run_performance_benchmark(): + """Benchmark quantization performance""" + print("Running performance benchmark...") + + device = "cuda" if torch.cuda.is_available() else "cpu" + + sizes = [(512, 512), (1024, 1024), (2048, 2048), (4096, 4096)] + config = GPTQConfig(w_bit=4, group_size=32, device=device) + + gptq = GPTQ(config) + for rows, cols in sizes: + weight = torch.randn(rows, cols, device=device, dtype=torch.float16) * 0.1 + + # Time quantization (would need actual implementation) + start_time = time.time() + quantized = gptq.quantize(weight) + quantize_time = time.time() - start_time + + print(f" Size {rows}x{cols}: 
Quantization took {quantize_time:.3f}s") + + +def main(): + """Run all tests""" + print("=" * 60) + print("GPTQ Quantization Test Suite") + print("=" * 60) + + try: + # Basic functionality tests + test_pack_unpack_weights() + test_quantize_dequantize_cycle() + test_consistency_across_devices() + + # Performance benchmark + run_performance_benchmark() + + print("\n" + "=" * 60) + print("ALL TESTS PASSED!") + print("=" * 60) + + except Exception as e: + print(f"\nTEST FAILED: {e}") + raise + + +if __name__ == "__main__": + main() \ No newline at end of file
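As a quick sanity check of the nibble layout that pack_weight/unpack_weight in tests/quant/test_gptq.py rely on, here is a minimal illustrative sketch (plain torch ops, arbitrary values, not part of the patch above) showing one packed byte round-tripping losslessly:

    import torch

    # Low nibble holds the even column, high nibble the odd column, so [0x3, 0xA] packs to 0xA3.
    w = torch.tensor([[0x3, 0xA]], dtype=torch.uint8)
    packed = (w[:, 0::2] & 0xF) | ((w[:, 1::2] & 0xF) << 4)              # tensor([[0xA3]])
    restored = torch.stack((packed & 0xF, (packed >> 4) & 0xF), dim=-1).reshape(1, -1)
    assert torch.equal(restored, w)                                      # round trip is lossless

This is the same round trip that test_pack_unpack_weights asserts for full 4-bit weight matrices, including the odd-column case that requires one column of padding.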