18 changes: 17 additions & 1 deletion include/infinicore_infer/models/jiuge.h
@@ -12,7 +12,7 @@ struct JiugeModel;
typedef struct
{
infiniDtype_t dt_logits;
size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc;
size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc, kvcache_block_size;
float epsilon, theta;
uint32_t end_token;
} JiugeMeta;
@@ -65,6 +65,10 @@ destroyJiugeModel(struct JiugeModel *);
__C __export struct KVCache *
createKVCache(const struct JiugeModel *);

/// @brief Create a paged KV cache
__C __export struct KVCache *
createPagedKVCache(const struct JiugeModel *, uint32_t max_kvcache_tokens);

/// @brief Duplicate a KV cache
__C __export struct KVCache *
duplicateKVCache(const struct JiugeModel *,
@@ -85,13 +89,18 @@ dropKVCache(const struct JiugeModel *,
/// @param temperature sampling temperature (0. means greedy sampling)
/// @param topk sampling top-k (1 means greedy sampling)
/// @param topp sampling top-p
/// @param is_prefill whether to run the prefill path: 0 for decode, 1 for prefill
/// @param enable_paged_attn whether to enable paged attention
/// @param output output token array, one output per request; length must be at least nreq
__C __export void
inferBatch(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
const int32_t *block_tables,
const int32_t *slot_mapping,
const float *temperature, const uint32_t *topk, const float *topp,
const uint32_t is_prefill, const bool enable_paged_attn,
uint32_t *output);

/// @brief Run one batched inference step and output the logits produced after the output embedding
@@ -101,12 +110,19 @@ inferBatch(struct JiugeModel *,
/// @param req_lens number of tokens in each request
/// @param req_pos starting position of each request
/// @param kv_caches KV cache of each request
/// @param block_tables block table of each request
/// @param slot_mapping slot mapping of each request
/// @param is_prefill whether to run the prefill path: 0 for decode, 1 for prefill
/// @param enable_paged_attn whether to enable paged attention
/// @param logits output logits array, one output per request; length must be at least nreq
__C __export void
forwardBatch(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
const int32_t *block_tables,
const int32_t *slot_mapping,
const uint32_t is_prefill, const bool enable_paged_attn,
void *logits);

#endif
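
The two batched entry points above share an argument list, with block_tables, slot_mapping, is_prefill and enable_paged_attn newly added. Below is a minimal ctypes sketch of how these signatures might be declared from Python; the shared-library name is an assumption (the project's real bindings live in icinfer.engine.libinfinicore_infer), and only the argument order and types are taken from the header.

# Sketch only: library name and loading path are assumptions; the real
# bindings ship with icinfer (icinfer.engine.libinfinicore_infer).
import ctypes

lib = ctypes.CDLL("libinfinicore_infer.so")  # assumed library name

# struct KVCache *createPagedKVCache(const struct JiugeModel *, uint32_t max_kvcache_tokens);
lib.createPagedKVCache.restype = ctypes.c_void_p
lib.createPagedKVCache.argtypes = [ctypes.c_void_p, ctypes.c_uint32]

# void inferBatch(...) -- argument order mirrors the declaration above.
lib.inferBatch.restype = None
lib.inferBatch.argtypes = [
    ctypes.c_void_p,                  # struct JiugeModel *
    ctypes.POINTER(ctypes.c_uint32),  # tokens
    ctypes.c_uint32,                  # ntok
    ctypes.POINTER(ctypes.c_uint32),  # req_lens
    ctypes.c_uint32,                  # nreq
    ctypes.POINTER(ctypes.c_uint32),  # req_pos
    ctypes.POINTER(ctypes.c_void_p),  # kv_caches (struct KVCache **)
    ctypes.POINTER(ctypes.c_int32),   # block_tables
    ctypes.POINTER(ctypes.c_int32),   # slot_mapping
    ctypes.POINTER(ctypes.c_float),   # temperature
    ctypes.POINTER(ctypes.c_uint32),  # topk
    ctypes.POINTER(ctypes.c_float),   # topp
    ctypes.c_uint32,                  # is_prefill
    ctypes.c_bool,                    # enable_paged_attn
    ctypes.POINTER(ctypes.c_uint32),  # output
]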
83 changes: 83 additions & 0 deletions python/bench.py
@@ -0,0 +1,83 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
import time
import sys
from random import randint, seed
# from nanovllm import LLM, SamplingParams
# from vllm import LLM, SamplingParams

from icinfer import LLM, SamplingParams
from icinfer.engine.libinfinicore_infer import DeviceType

import logging
logger = logging.getLogger(__name__)
import argparse

def parse_args():
parser = argparse.ArgumentParser()
# parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/Llama-2-7b-chat-hf")
# parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/FM9G_70B_SFT_MHA/")
parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/9G7B_MHA/")
parser.add_argument("--device-type", type=str, default="nvidia")
parser.add_argument("--ndev", type=int, default=4)
parser.add_argument("--max-kvcache-tokens", type=int, default=131072)
args = parser.parse_args()
return args

def main():
args = parse_args()
model_path = args.model_path
max_kvcache_tokens = args.max_kvcache_tokens
device_type = DeviceType.DEVICE_TYPE_CPU
if args.device_type == "cpu":
device_type = DeviceType.DEVICE_TYPE_CPU
elif args.device_type == "nvidia":
device_type = DeviceType.DEVICE_TYPE_NVIDIA
elif args.device_type == "cambricon":
device_type = DeviceType.DEVICE_TYPE_CAMBRICON
elif args.device_type == "ascend":
device_type = DeviceType.DEVICE_TYPE_ASCEND
elif args.device_type == "metax":
device_type = DeviceType.DEVICE_TYPE_METAX
elif args.device_type == "moore":
device_type = DeviceType.DEVICE_TYPE_MOORE
elif args.device_type == "iluvatar":
device_type = DeviceType.DEVICE_TYPE_ILUVATAR
    else:
        logger.error(
            "Unsupported device type %r; expected one of: cpu, nvidia, cambricon, ascend, metax, moore, iluvatar",
            args.device_type,
        )
        sys.exit(1)

seed(0)
# num_seqs = 128
num_seqs = 8
max_input_len = 1024
    max_output_len = 1024

    path = os.path.expanduser(model_path)
llm = LLM(path, device=device_type, enforce_eager=True,
tensor_parallel_size=args.ndev, trust_remote_code=True,
attention_bias=True, enable_paged_attn=True, max_kvcache_tokens=max_kvcache_tokens)


prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]

    sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)]
# uncomment the following line for vllm
# prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]

llm.generate(["Benchmark: "], SamplingParams())
t = time.time()
# llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
outputs = llm.generate(prompt_token_ids, sampling_params)
t = (time.time() - t)
total_tokens = sum(sp.max_tokens for sp in sampling_params)
throughput = total_tokens / t
print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")


if __name__ == "__main__":
main()
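
For parity with the trailing CLI note in example.py, a suggested invocation of this benchmark is sketched below; it uses only the flags defined in parse_args above, and the model path is the author's local default.

"""
CLI:
python bench.py --model-path /home/wanghaojie/vllm/huggingface/9G7B_MHA/ --device-type nvidia --ndev 4 --max-kvcache-tokens 131072
"""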
157 changes: 157 additions & 0 deletions python/example.py
@@ -0,0 +1,157 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

import sys
from transformers import AutoTokenizer
import argparse

from icinfer import LLM, SamplingParams
from icinfer.engine.libinfinicore_infer import DeviceType

import logging
logger = logging.getLogger(__name__)

def parse_args():
parser = argparse.ArgumentParser()
# parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/FM9G_70B_SFT_MHA/")
parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/9G7B_MHA/")
parser.add_argument("--device-type", type=str, default="nvidia")
parser.add_argument("--ndev", type=int, default=1)
parser.add_argument("--max-kvcache-tokens", type=int, default=10240)
# parser.add_argument("--max-kvcache-tokens", type=int, default=65536)
parser.add_argument("--enable-paged-attn", action="store_true")
# parser.add_argument("--enable-paged-attn", type=bool, default=True)
args = parser.parse_args()
return args

def main():
args = parse_args()
model_path = args.model_path
max_kvcache_tokens = args.max_kvcache_tokens
device_type = DeviceType.DEVICE_TYPE_CPU
if args.device_type == "cpu":
device_type = DeviceType.DEVICE_TYPE_CPU
elif args.device_type == "nvidia":
device_type = DeviceType.DEVICE_TYPE_NVIDIA
elif args.device_type == "cambricon":
device_type = DeviceType.DEVICE_TYPE_CAMBRICON
elif args.device_type == "ascend":
device_type = DeviceType.DEVICE_TYPE_ASCEND
elif args.device_type == "metax":
device_type = DeviceType.DEVICE_TYPE_METAX
elif args.device_type == "moore":
device_type = DeviceType.DEVICE_TYPE_MOORE
elif args.device_type == "iluvatar":
device_type = DeviceType.DEVICE_TYPE_ILUVATAR
    else:
        logger.error(
            "Unsupported device type %r; expected one of: cpu, nvidia, cambricon, ascend, metax, moore, iluvatar",
            args.device_type,
        )
        sys.exit(1)

# path = os.path.expanduser("~/vllm/huggingface/Qwen3-0.6B/")
# tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
# llm = LLM(path, enforce_eager=True, tensor_parallel_size=1, trust_remote_code=True)
# path = os.path.expanduser("/home/wanghaojie/vllm/huggingface/9G7B_MHA/")
path = args.model_path
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
llm = LLM(path, device=device_type, enforce_eager=True,
tensor_parallel_size=args.ndev, trust_remote_code=True,
attention_bias=True, enable_paged_attn=args.enable_paged_attn, max_kvcache_tokens=max_kvcache_tokens)

sampling_params = SamplingParams(temperature=0.6, max_tokens=128)
# prompts = [
# "introduce yourself",
# # "list all prime numbers within 100",
# "山东最高的山是?",
# "如果猫能写诗,它们会写些什么?",
# "描述一个没有重力的世界。",
# "如果地球停止自转,会发生什么?",
# "假设你是一只会飞的鲸鱼,描述你的日常生活。",
# "如果人类可以与植物沟通,世界会变成什么样?",
# "描述一个由糖果构成的城市。",
# "如果时间旅行成为可能,你最想去哪个时代?",
# "想象一下,如果地球上只有蓝色,其他颜色都消失了。",
# "如果动物能上网,它们会浏览什么网站?",
# "描述一个没有声音的世界。",
# "如果人类可以在水下呼吸,城市会如何变化?",
# "想象一下,如果天空是绿色的,云是紫色的。",
# "如果你能与任何历史人物共进晚餐,你会选择谁?",
# "描述一个没有夜晚的星球。",
# "如果地球上只有一种语言,世界会如何运作?",
# "想象一下,如果所有的书都变成了音乐。",
# "如果你可以变成任何一种动物,你会选择什么?",
# "描述一个由机器人统治的未来世界。",
# "如果你能与任何虚构角色成为朋友,你会选择谁?",
# "想象一下,如果每个人都能读懂他人的思想。"
# ] * 2
    prompts = [
        "如果人类可以与植物沟通,世界会变成什么样?",
        "描述一个由糖果构成的城市。",
        "如果时间旅行成为可能,你最想去哪个时代?",
        "想象一下,如果地球上只有蓝色,其他颜色都消失了。",
        "如果动物能上网,它们会浏览什么网站?",
        "描述一个没有声音的世界。",
        "如果人类可以在水下呼吸,城市会如何变化?",
        "想象一下,如果天空是绿色的,云是紫色的。",
    ]
prompts = [
tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=False
)
for prompt in prompts
]
outputs, avg_prefill_throughput, avg_decode_throughput, avg_ttft, avg_tbt, cache_efficiency = llm.generate(prompts, sampling_params)

for prompt, output in zip(prompts, outputs):
print("\n")
print(f"Prompt: {prompt!r}")
print(f"Completion: {output['text']!r}")
# print("\n")
# print(f"Prompt: {prompts[0]!r}")
# print(f"Completion: {outputs[0]['text']!r}")
print(f"batch_size: {len(prompts)}, n_dev: {args.ndev}, is_paged_attn: {args.enable_paged_attn}")
print(f"Avg Prefill Throughput: {avg_prefill_throughput:.2f} tok/s")
print(f"Avg Decode Throughput: {avg_decode_throughput:.2f} tok/s")
print(f"Avg TTFT: {avg_ttft*1000:.2f} ms")
print(f"Avg TBT: {avg_tbt*1000:.2f} ms")
print(f"Cache Efficiency: {cache_efficiency*100:.2f}%")

if __name__ == "__main__":
main()


"""
CLI:
python example.py --model-path /home/wanghaojie/vllm/huggingface/9G7B_MHA/ --device-type nvidia --ndev 4 --max-kvcache-tokens 10240 --enable-paged-attn
python example.py --model-path /home/wanghaojie/vllm/huggingface/9G7B_MHA/ --device-type nvidia --ndev 4

"""
13 changes: 13 additions & 0 deletions python/icinfer.egg-info/PKG-INFO
@@ -0,0 +1,13 @@
Metadata-Version: 2.4
Name: icinfer
Version: 0.1.0
Summary: a lightweight, hardware-agnostic, unified inference engine implementation built from scratch, based on InfiniCore
Author:
License-Expression: MIT
Project-URL: Homepage, https://github.com/InfiniTensor/InfiniLM
Requires-Python: <3.13,>=3.10
Description-Content-Type: text/markdown
Requires-Dist: torch>=2.4.0
Requires-Dist: triton>=3.0.0
Requires-Dist: transformers>=4.51.0
Requires-Dist: xxhash
2 changes: 2 additions & 0 deletions python/icinfer/__init__.py
@@ -0,0 +1,2 @@
from icinfer.llm import LLM
from icinfer.sampling_params import SamplingParams