4 changes: 4 additions & 0 deletions include/infinicore_infer/models/jiuge.h
@@ -35,6 +35,10 @@ typedef struct
     const void *const *attn_qkv;
     // nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh]
     const void *const *attn_qkv_b;
+    // nlayer * [dh]
+    const void *const *attn_q_norm;
+    // nlayer * [dh]
+    const void *const *attn_k_norm;
     // nlayer * [ndev, d, nkvh / ndev * dh]
     const void *const *attn_o;
     // nlayer * [d]
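Reviewer note: the two new per-layer [dh] weights implement Qwen3-style QK-norm, i.e. an RMSNorm applied to each attention head's dh elements (the C++ changes below call rmsnorm with exactly these weights). A minimal sketch of what a [dh]-shaped q_norm/k_norm weight is applied to, assuming the shapes in the comments above; this is illustrative, not the engine's actual kernel:

import torch

def rmsnorm_per_head(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x: (ntok, nheads, dh); w: (dh,) -- one scale per element of a head.
    # Each head vector is normalized over its own dh elements, then scaled.
    var = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(var + eps) * w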
65 changes: 57 additions & 8 deletions scripts/jiuge.py
@@ -15,7 +15,7 @@
     forward_batch,
 )
 from infer_task import InferTask, KVCache
-
+from tokenizers import decoders as _dec
 from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
 import os
 from pathlib import Path
@@ -64,6 +64,12 @@ def attn_k_b(self, i):
     def attn_v_b(self, i):
         return f"model.layers.{i}.self_attn.v_proj.bias"
 
+    def attn_q_norm(self, i):
+        return f"model.layers.{i}.self_attn.q_norm.weight"
+
+    def attn_k_norm(self, i):
+        return f"model.layers.{i}.self_attn.k_norm.weight"
+
     def ffn_norm(self, i):
         return f"model.layers.{i}.post_attention_layernorm.weight"
 
@@ -123,7 +129,11 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None):
if "num_key_value_heads" in config
else config["num_attention_heads"]
),
dh=config["hidden_size"] // config["num_attention_heads"],
dh=(
config["head_dim"]
if "head_dim" in config
else config["hidden_size"] // config["num_attention_heads"]
),
di=config["intermediate_size"],
dctx=(
config["max_position_embeddings"] if max_tokens is None else max_tokens
@@ -281,6 +291,35 @@ def qkv_b_slices(_i):
         else:
             self.attn_qkv_b = None
 
+        if naming.attn_q_norm(0) in state_dict:
+            self.attn_q_norm_tensors = [
+                state_dict[naming.attn_q_norm(i)]
+                .reshape([2, dh // 2])
+                .transpose(0, 1)
+                .contiguous()
+                .to(torch_dt_norm)
+                for i in range(nlayer)
+            ]
+            self.attn_q_norm_ptrs = [
+                self.attn_q_norm_tensors[i].data_ptr() for i in range(nlayer)
+            ]
+            self.attn_q_norm = (c_void_p * nlayer)(*self.attn_q_norm_ptrs)
+            self.attn_k_norm_tensors = [
+                state_dict[naming.attn_k_norm(i)]
+                .reshape([2, dh // 2])
+                .transpose(0, 1)
+                .contiguous()
+                .to(torch_dt_norm)
+                for i in range(nlayer)
+            ]
+            self.attn_k_norm_ptrs = [
+                self.attn_k_norm_tensors[i].data_ptr() for i in range(nlayer)
+            ]
+            self.attn_k_norm = (c_void_p * nlayer)(*self.attn_k_norm_ptrs)
+        else:
+            self.attn_q_norm = None
+            self.attn_k_norm = None
+
         self.attn_o_tensor = [
             (
                 state_dict[naming.attn_o(i)]
@@ -484,7 +523,7 @@ def load_all_safetensors_from_dir(dir_path_: str):
                 )
             else:
                 raise ValueError("Unsupported weight naming")
-        elif "qwen2" == config["model_type"]:
+        elif "qwen2" == config["model_type"] or "qwen3" == config["model_type"]:
             state_dict = load_all_safetensors_from_dir(model_dir_path)
             if LlamaWeightsNaming.match(state_dict):
                 self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
@@ -501,6 +540,20 @@ def load_all_safetensors_from_dir(dir_path_: str):
         else:
             raise ValueError("Unsupported model architecture")
 
+        backend = getattr(self.tokenizer, "backend_tokenizer", None)
+        target = getattr(backend, "_tokenizer", backend)
+        norm = getattr(target, "normalizer", None)
+        dec = getattr(target, "decoder", None)
+        sn = repr(norm)[:800] if norm is not None else ""
+        sd = repr(dec)[:800] if dec is not None else ""
+        has_prepend = "Prepend" in sn
+        has_strip = "Strip" in sd
+        if has_prepend and has_strip:
+            target.decoder = _dec.Sequence([
+                _dec.Replace("▁", " "),
+                _dec.ByteFallback(),
+                _dec.Fuse(),
+            ])
         load_end_time = time.time()
         print(f"Time used: {load_end_time - load_start_time:.3f}s")
 
@@ -564,11 +617,7 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.
             output_tokens = self.batch_infer_one_round([infer_task])
             end_time = time.time()
             steps += 1
-            output_str = (
-                self.tokenizer._tokenizer.id_to_token(output_tokens[0])
-                .replace("▁", " ")
-                .replace("<0x0A>", "\n")
-            )
+            output_str = self.tokenizer.decode(output_tokens[0])
             output_content += output_str
             print(output_str, end="", flush=True)
             if output_tokens[0] in self.eos_token_id:
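Reviewer note: the q_norm/k_norm weights are loaded through reshape([2, dh // 2]).transpose(0, 1) before being handed to C. A toy illustration of that shuffle; the assumption is that it converts the checkpoint's half-split RoPE layout into the interleaved per-head element order this engine's rope kernel uses:

import torch

dh = 8
w = torch.arange(dh)  # checkpoint order: [0 1 2 3 | 4 5 6 7]
w_shuffled = w.reshape(2, dh // 2).transpose(0, 1).contiguous()
print(w_shuffled.reshape(-1).tolist())  # interleaved: [0, 4, 1, 5, 2, 6, 3, 7]

Also worth noting: dh now prefers config["head_dim"] when present, since Qwen3 checkpoints set head_dim independently of hidden_size // num_attention_heads.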
12 changes: 2 additions & 10 deletions scripts/launch_server.py
@@ -207,11 +207,7 @@ async def chat_stream(id_, request_data, request: Request):
                 break
 
             token = await infer_task.output_queue.async_q.get()
-            content = (
-                request.app.state.model.tokenizer._tokenizer.id_to_token(token)
-                .replace("▁", " ")
-                .replace("<0x0A>", "\n")
-            )
+            content = request.app.state.model.tokenizer.decode(token)
             chunk = json.dumps(chunk_json(id_, content=content), ensure_ascii=False)
             yield f"data: {chunk}\n\n"
 
@@ -236,11 +232,7 @@ async def chat(id_, request_data, request: Request):
                 break
 
             token = await infer_task.output_queue.async_q.get()
-            content = (
-                request.app.state.model.tokenizer._tokenizer.id_to_token(token)
-                .replace("▁", " ")
-                .replace("<0x0A>", "\n")
-            )
+            content = request.app.state.model.tokenizer.decode(token)
             output.append(content)
 
         output_text = "".join(output).strip()
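Reviewer note: both endpoints now call tokenizer.decode(token) instead of id_to_token(...) plus manual "▁"/"<0x0A>" replacements, which only handled two SentencePiece artifacts. A quick sanity check of the new behavior, assuming the model's tokenizer is a Hugging Face tokenizer (the backend_tokenizer/decode usage in these scripts suggests it); the checkpoint id is illustrative:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")  # illustrative checkpoint id
ids = tok.encode("hello\nworld")
# decode() runs the tokenizer's own decoder pipeline, so word-boundary
# markers and byte tokens such as <0x0A> are resolved without manual replaces.
print([tok.decode(i) for i in ids])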
2 changes: 2 additions & 0 deletions scripts/libinfinicore_infer.py
@@ -66,6 +66,8 @@ class JiugeWeightsCStruct(ctypes.Structure):
("attn_norm", POINTER(c_void_p)),
("attn_qkv", POINTER(c_void_p)),
("attn_qkv_b", POINTER(c_void_p)),
("attn_q_norm", POINTER(c_void_p)),
("attn_k_norm", POINTER(c_void_p)),
("attn_o", POINTER(c_void_p)),
("ffn_norm", POINTER(c_void_p)),
("ffn_gate_up", POINTER(c_void_p)),
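Reviewer note: these two fields line up with the (c_void_p * nlayer) arrays built in scripts/jiuge.py, and None leaves them NULL, which the C++ side checks before loading the norms. A toy sketch of that ctypes wiring; ToyWeights and its field are simplified stand-ins, not the real JiugeWeightsCStruct:

import ctypes
from ctypes import POINTER, c_void_p

import torch

class ToyWeights(ctypes.Structure):
    _fields_ = [("attn_q_norm", POINTER(c_void_p))]

nlayer = 2
tensors = [torch.zeros(128) for _ in range(nlayer)]  # must stay referenced Python-side
ptrs = (c_void_p * nlayer)(*(t.data_ptr() for t in tensors))
w = ToyWeights()
w.attn_q_norm = ctypes.cast(ptrs, POINTER(c_void_p))  # or None, which maps to NULL in C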
29 changes: 22 additions & 7 deletions src/models/jiuge/jiuge.cpp
@@ -21,7 +21,7 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
     infinirtStream_t stream;
     infinirtStreamCreate(&stream);
 
-    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
+    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out,
         w_ffn_norm, w_ffn_gate_up, w_ffn_down;
     for (size_t layer = 0; layer < meta->nlayer; layer++) {
         w_attn_norm.push_back(
@@ -32,6 +32,12 @@
             b_attn_qkv.push_back(
                 getAttnQKVBias(meta, weights, layer, idev, ndev));
         }
+        if (weights->attn_q_norm != nullptr) {
+            w_attn_q_norm.push_back(
+                getAttnQNorm(meta, weights, layer));
+            w_attn_k_norm.push_back(
+                getAttnKNorm(meta, weights, layer));
+        }
         w_attn_out.push_back(
             getAttnO(meta, weights, layer, idev, ndev));
         w_ffn_norm.push_back(
@@ -56,6 +62,8 @@
         w_attn_norm,
         w_attn_qkv,
         b_attn_qkv,
+        w_attn_q_norm,
+        w_attn_k_norm,
         w_attn_out,
         w_ffn_norm,
         w_ffn_gate_up,
@@ -130,6 +138,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
     auto dvoc = meta.dvoc;
     auto stream = rsrc.stream;
     bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0;
+    bool has_qk_norm = rsrc.w_attn_q_norm.size() > 0 && rsrc.w_attn_k_norm.size() > 0;
 
     // Allocate buffers
     auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
@@ -141,7 +150,9 @@
     auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
     auto result_cpu = std::vector<int64_t>(nreq);
 
-    auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});
+    auto qkv_buf_view = qkv_buf->view({ntok, nh + nkvh * 2, dh});
+    auto q_buf = qkv_buf_view->slice(1, 0, nh);
+    auto k_buf = qkv_buf_view->slice(1, nh, nkvh);
 
     // Prepare inputs
     auto batch_pos_ids = std::vector<uint32_t>(ntok);
@@ -198,19 +209,23 @@
         rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
         // qkv_proj
         linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, nullptr, has_qkv_bias ? rsrc.b_attn_qkv[layer] : nullptr);
+        if (has_qk_norm) {
+            rmsnorm(q_buf, q_buf, rsrc.w_attn_q_norm[layer], meta.epsilon);
+            rmsnorm(k_buf, k_buf, rsrc.w_attn_k_norm[layer], meta.epsilon);
+        }
         // rope
-        rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
-        rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
+        rope(q_buf, q_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
+        rope(k_buf, k_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
 
         size_t token_offset = 0;
         for (uint32_t req = 0; req < nreq; req++) {
             auto past_len = req_pos[req];
             auto seq_len = req_lens[req];
             auto total_len = past_len + seq_len;
             auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
-            auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
-            auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
-            auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
+            auto q = qkv_buf_view->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
+            auto k = qkv_buf_view->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
+            auto v = qkv_buf_view->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
 
             // self attention
             // concat
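Reviewer note: the key behavioral change is the ordering — when QK-norm weights are present, q and k are RMSNorm-ed per head after the qkv projection and before RoPE, while v is untouched. A condensed sketch of that pipeline, assuming qkv has shape (ntok, nh + 2*nkvh, dh) as in the view above; rope_fn stands in for the engine's rope kernel:

import torch

def rmsnorm(x, w, eps=1e-6):
    # per-head RMSNorm over the trailing dh elements
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w

def qkv_to_attention_inputs(qkv, w_q, w_k, nh, nkvh, rope_fn):
    q, k, v = qkv.split([nh, nkvh, nkvh], dim=1)
    q, k = rmsnorm(q, w_q), rmsnorm(k, w_k)  # QK-norm first
    return rope_fn(q), rope_fn(k), v         # positions applied after the norm

Because q_buf/k_buf are in-place views into qkv_buf, normalizing them before the existing rope calls lets the downstream attention code run unchanged.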
2 changes: 1 addition & 1 deletion src/models/jiuge/jiuge_impl.hpp
@@ -20,7 +20,7 @@ struct DeviceResource {
     // Weights
     std::shared_ptr<Tensor> w_in_embd, w_out_norm, w_out_embd, sin_table,
         cos_table;
-    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
+    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out,
        w_ffn_norm, w_ffn_gate_up, w_ffn_down;
     // Streams
     infinirtStream_t stream;
16 changes: 16 additions & 0 deletions src/models/jiuge/jiuge_weight.hpp
@@ -70,6 +70,22 @@ inline std::shared_ptr<Tensor> getAttnQKVBias(
     return Tensor::weight((char *)(w->attn_qkv_b[layer]) + offset, w->dt_mat, shape);
 }
 
+inline std::shared_ptr<Tensor> getAttnQNorm(
+    JiugeMeta const *meta,
+    JiugeWeights const *w,
+    size_t layer) {
+    auto shape = std::vector<size_t>({meta->dh});
+    return Tensor::weight((char *)(w->attn_q_norm[layer]), w->dt_norm, shape);
+}
+
+inline std::shared_ptr<Tensor> getAttnKNorm(
+    JiugeMeta const *meta,
+    JiugeWeights const *w,
+    size_t layer) {
+    auto shape = std::vector<size_t>({meta->dh});
+    return Tensor::weight((char *)(w->attn_k_norm[layer]), w->dt_norm, shape);
+}
+
 inline std::shared_ptr<Tensor> getAttnO(JiugeMeta const *meta,
                                         JiugeWeights const *w, size_t layer,
                                         size_t idev, size_t ndev) {
inline std::shared_ptr<Tensor> getAttnO(JiugeMeta const *meta,
JiugeWeights const *w, size_t layer,
size_t idev, size_t ndev) {
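Reviewer note: unlike getAttnQKVBias and getAttnO, the new loaders take no idev/ndev arguments — the [dh] norm vector is the same for every head, so each device presumably gets the full copy rather than a shard. An illustrative Python analogue of wrapping such a raw [dh] buffer without copying, as Tensor::weight does; dtype is assumed fp32 here for the sketch:

import ctypes

import numpy as np

def view_norm_weight(ptr: int, dh: int) -> np.ndarray:
    # View an existing [dh] float buffer in place; no offset arithmetic needed
    # because the weight is replicated, not sharded across devices.
    buf = (ctypes.c_float * dh).from_address(ptr)
    return np.frombuffer(buf, dtype=np.float32, count=dh)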