36 commits
bd40477
fix import error
TATAXIMU May 17, 2025
0f84e95
add sentencepiece into requirement
TATAXIMU May 17, 2025
5008cd3
add logs
TATAXIMU May 17, 2025
50de54d
merge llava into generate
TATAXIMU May 17, 2025
6dfd879
Merge remote-tracking branch 'origin/main' into dev
TATAXIMU May 17, 2025
522171a
fix readme and add interface for weight converter
TATAXIMU May 17, 2025
f42bf9f
fix latency calculation
TATAXIMU May 18, 2025
6801f53
refactor: folder and import bugs
May 19, 2025
5580283
fix: code structure
May 19, 2025
a5665fd
chore: update logs directory
May 19, 2025
5997bc9
fix: remove repeat file
May 20, 2025
4b83355
fix: logger directory can't create
May 20, 2025
652a2ca
Merge remote-tracking branch 'origin/main' into dev
TATAXIMU May 21, 2025
15acd74
fix conflict
TATAXIMU May 21, 2025
228d1e3
fix import
TATAXIMU May 23, 2025
2ef6619
gptq for llama
TATAXIMU May 23, 2025
af1f429
gptq for llava
TATAXIMU May 24, 2025
f4c1010
fix print problem
TATAXIMU May 24, 2025
57fba86
fix naming issue
TATAXIMU May 24, 2025
cabeb3e
fix gptq compress
TATAXIMU May 26, 2025
629bd98
1
TATAXIMU May 28, 2025
8becbae
fix missing weight keys
TATAXIMU May 31, 2025
5dbe60a
fix requirement.txt
TATAXIMU May 31, 2025
5daf07d
test for 2.2.0
TATAXIMU May 31, 2025
11ec094
fix inference with missing keys
TATAXIMU May 31, 2025
6cf54e6
fixed gptq missing int4 inference
TATAXIMU Jun 1, 2025
b5cd0d5
update gptq int4 kernel
TATAXIMU Jul 2, 2025
ebaeb0c
split int4 kernel and gptq
TATAXIMU Jul 3, 2025
f2cb589
add awq
TATAXIMU Jul 10, 2025
6343abe
Merge to main
TATAXIMU Jul 10, 2025
2dfb249
update sqlinear and sq quant, TODO: awqlinear
TATAXIMU Jul 13, 2025
f30154f
update sqlinear and awqlinear
TATAXIMU Jul 14, 2025
d96b8a0
remove unnecessary method
TATAXIMU Jul 18, 2025
ec28d65
add quant into generate stream
TATAXIMU Jul 22, 2025
de0b18e
refactor quant test
TATAXIMU Jul 24, 2025
67905ca
update awq test
TATAXIMU Jul 24, 2025
17 changes: 8 additions & 9 deletions README.md
@@ -77,8 +77,8 @@ conda activate lite_llama
git clone https://github.com/harleyszhang/lite_llama.git
cd lite_llama/
pip install -r requirement.txt
python test_weight_convert.py # model weight transformation
python generate.py --prompt "What is large language model" --checkpoint_path /path/to/model/Llama-3.2-1B-Instruct/ # Run on the basis that the model has been downloaded and placed in the specified directory
python apply_weight_convert.py --checkpoints_dir /path/to/model/Llama-3.2-1B-Instruct/ --model_type llama # model weight conversion
python generate.py -p "What is large language model" -m /path/to/model/Llama-3.2-1B-Instruct/ -f /path/to/figure # assumes the model has already been downloaded to the specified directory
```

ROCm version 5.7 and above is recommended.
@@ -95,21 +95,20 @@ conda activate lite_llama
git clone https://github.com/harleyszhang/lite_llama.git
cd lite_llama/
pip install -r requirement.txt
python test_weight_convert.py # model weight transformation
python generate.py --prompt "What is large language model" --checkpoint_path /path/to/model/Llama-3.2-1B-Instruct/ # Run on the basis that the model has been downloaded and placed in the specified directory
python apply_weight_convert.py --checkpoints_dir /path/to/model/Llama-3.2-1B-Instruct/ --model_type llama # model weight conversion
python generate.py -p "What is large language model" -m /path/to/model/Llama-3.2-1B-Instruct/ -f /path/to/figure # assumes the model has already been downloaded to the specified directory
```


## Evaluation

After `cli.py` runs successfully, the terminal displays the interface as shown below, and you can enter your question in the terminal.

![cli](./images/cli_stream.png)

After `generate.py` runs successfully, the terminal displays the interface as shown below, and you can enter your question in the terminal.

![generate](./images/generate_stream.png)

After `cli_llava.py` runs successfully, the terminal displays the interface as shown below; enter your image path and prompt in the terminal, then press Enter.

![llava model streaming output](./images/llava_output2.gif)
14 changes: 7 additions & 7 deletions apply_weight_convert.py
100644 → 100755
@@ -241,11 +241,11 @@ def convert(checkpoints_dir: Path,
new_sd: Dict[str, torch.Tensor] = {}

# ---------- 1. Remapping ----------
for k, v in tqdm(hf_state.items(), desc=f"[{model_type}] 权重重映射"):
for k, v in tqdm(hf_state.items(), desc=f"[{model_type}] Weight mapping"):
if (ck := mapping.get(k)) is not None:
new_sd[ck] = v
else:
logger.debug("忽略未映射参数 %s", k)
logger.debug("Ignore unmapped parameters %s", k)

# ---------- 2. Merge KV only for the Qwen / Llama families ----------
if model_type.startswith("qwen") or model_type.startswith("llama"):  # Qwen-2 / Qwen-3, Llama, etc.
@@ -259,7 +259,7 @@ def convert(checkpoints_dir: Path,
save_state_dict(out_dir, checkpoints_dir.name, new_sd)
copy_metadata(checkpoints_dir, out_dir)

logger.info("🎉 转换完成,共 %d 个参数", len(new_sd))
logger.info("🎉 Convert Complete,There are %d parameters in total", len(new_sd))
return new_sd
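For readers skimming the hunk above: the conversion is a dictionary-driven key rename, and anything outside the mapping is dropped. Below is a minimal sketch of that pattern; the `mapping` table and the `remap` helper are illustrative stand-ins, not the script's actual per-model tables or API.

```python
import torch
from typing import Dict

# Illustrative subset of an HF -> lite_llama key mapping; the real
# tables are generated per model type in apply_weight_convert.py.
mapping = {
    "model.embed_tokens.weight": "embed_tokens.weight",
    "lm_head.weight": "lm_head.weight",
}

def remap(hf_state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    new_sd: Dict[str, torch.Tensor] = {}
    for k, v in hf_state.items():
        if (ck := mapping.get(k)) is not None:  # rename mapped keys
            new_sd[ck] = v
        # unmapped keys are skipped (the real code logs them at debug level)
    # For Qwen/Llama the script additionally performs a KV merge at this point.
    return new_sd
```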


@@ -313,8 +313,8 @@ def get_num_layers(checkpoints_dir: Path, model_type: str) -> int:
def main() -> None:
parser = argparse.ArgumentParser(
description="Convert HF / bin checkpoints into Lite-LLaMA format.")
parser.add_argument("checkpoints_dir", type=Path, help="模型权重目录")
parser.add_argument("--model-type",
parser.add_argument("--checkpoints_dir", type=Path, help="模型权重目录")
parser.add_argument("--model_type",
choices=_SPEC.keys(),
help="显式指定模型类型;默认根据目录名猜测")
parser.add_argument("--device", default="cuda",
@@ -325,11 +325,11 @@ def main() -> None:

# 1️⃣ Read model_type directly from config.json
model_type = detect_model_type(ckpt_dir)
logger.info("检测到 model_type = %s", model_type)
logger.info("Model Type is: %s", model_type)

# 2️⃣ Get the number of layers
num_layers = get_num_layers(ckpt_dir, model_type)
logger.info("Transformer 层数 %d", num_layers)
logger.info("Transformer Number of layers %d", num_layers)

# 3️⃣ Load the weights and run the conversion
hf_sd = load_hf_state(ckpt_dir, model_type, device=args.device)
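Step 1 above relies on `detect_model_type`, which per the diffed comment reads `model_type` straight out of the checkpoint's `config.json`. A plausible minimal version, assuming the standard Hugging Face config layout (a sketch, not the script's actual implementation):

```python
import json
from pathlib import Path

def detect_model_type(checkpoints_dir: Path) -> str:
    # Hugging Face checkpoints ship a config.json whose "model_type"
    # field names the architecture, e.g. "llama", "qwen2", or "llava".
    config = json.loads((checkpoints_dir / "config.json").read_text())
    return config["model_type"]
```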