diff --git a/include/infinicore_infer/models/jiuge.h b/include/infinicore_infer/models/jiuge.h
index e89e171a..a6b1971a 100644
--- a/include/infinicore_infer/models/jiuge.h
+++ b/include/infinicore_infer/models/jiuge.h
@@ -35,6 +35,10 @@ typedef struct
     const void *const *attn_qkv;
     // nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh]
     const void *const *attn_qkv_b;
+    // nlayer * [dh]
+    const void *const *attn_q_norm;
+    // nlayer * [dh]
+    const void *const *attn_k_norm;
     // nlayer * [ndev, d, nkvh / ndev * dh]
     const void *const *attn_o;
     // nlayer * [d]
diff --git a/scripts/jiuge.py b/scripts/jiuge.py
index a2e591f8..0474031b 100644
--- a/scripts/jiuge.py
+++ b/scripts/jiuge.py
@@ -15,7 +15,7 @@
     forward_batch,
 )
 from infer_task import InferTask, KVCache
-
+from tokenizers import decoders as _dec
 from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
 import os
 from pathlib import Path
@@ -64,6 +64,12 @@ def attn_k_b(self, i):
     def attn_v_b(self, i):
         return f"model.layers.{i}.self_attn.v_proj.bias"
 
+    def attn_q_norm(self, i):
+        return f"model.layers.{i}.self_attn.q_norm.weight"
+
+    def attn_k_norm(self, i):
+        return f"model.layers.{i}.self_attn.k_norm.weight"
+
     def ffn_norm(self, i):
         return f"model.layers.{i}.post_attention_layernorm.weight"
 
@@ -123,7 +129,11 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None):
                 if "num_key_value_heads" in config
                 else config["num_attention_heads"]
             ),
-            dh=config["hidden_size"] // config["num_attention_heads"],
+            dh=(
+                config["head_dim"]
+                if "head_dim" in config
+                else config["hidden_size"] // config["num_attention_heads"]
+            ),
             di=config["intermediate_size"],
             dctx=(
                 config["max_position_embeddings"] if max_tokens is None else max_tokens
@@ -281,6 +291,35 @@ def qkv_b_slices(_i):
         else:
             self.attn_qkv_b = None
 
+        if naming.attn_q_norm(0) in state_dict:
+            self.attn_q_norm_tensors = [
+                state_dict[naming.attn_q_norm(i)]
+                .reshape([2, dh // 2])
+                .transpose(0, 1)
+                .contiguous()
+                .to(torch_dt_norm)
+                for i in range(nlayer)
+            ]
+            self.attn_q_norm_ptrs = [
+                self.attn_q_norm_tensors[i].data_ptr() for i in range(nlayer)
+            ]
+            self.attn_q_norm = (c_void_p * nlayer)(*self.attn_q_norm_ptrs)
+            self.attn_k_norm_tensors = [
+                state_dict[naming.attn_k_norm(i)]
+                .reshape([2, dh // 2])
+                .transpose(0, 1)
+                .contiguous()
+                .to(torch_dt_norm)
+                for i in range(nlayer)
+            ]
+            self.attn_k_norm_ptrs = [
+                self.attn_k_norm_tensors[i].data_ptr() for i in range(nlayer)
+            ]
+            self.attn_k_norm = (c_void_p * nlayer)(*self.attn_k_norm_ptrs)
+        else:
+            self.attn_q_norm = None
+            self.attn_k_norm = None
+
         self.attn_o_tensor = [
             (
                 state_dict[naming.attn_o(i)]
@@ -484,7 +523,7 @@ def load_all_safetensors_from_dir(dir_path_: str):
                 )
             else:
                 raise ValueError("Unsupported weight naming")
-        elif "qwen2" == config["model_type"]:
+        elif "qwen2" == config["model_type"] or "qwen3" == config["model_type"]:
             state_dict = load_all_safetensors_from_dir(model_dir_path)
             if LlamaWeightsNaming.match(state_dict):
                 self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
@@ -501,6 +540,20 @@ def load_all_safetensors_from_dir(dir_path_: str):
         else:
             raise ValueError("Unsupported model architecture")
 
+        backend = getattr(self.tokenizer, "backend_tokenizer", None)
+        target = getattr(backend, "_tokenizer", backend)
+        norm = getattr(target, "normalizer", None)
+        dec = getattr(target, "decoder", None)
+        sn = repr(norm)[:800] if norm is not None else ""
+        sd = repr(dec)[:800] if dec is not None else ""
+        has_prepend = "Prepend" in sn
+        has_strip = "Strip" in sd
+        if has_prepend and has_strip:
+            target.decoder = _dec.Sequence([
+                _dec.Replace("▁", " "),
+                _dec.ByteFallback(),
+                _dec.Fuse(),
+            ])
         load_end_time = time.time()
         print(f"Time used: {load_end_time - load_start_time:.3f}s")
@@ -564,11 +617,7 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.
             output_tokens = self.batch_infer_one_round([infer_task])
             end_time = time.time()
             steps += 1
-            output_str = (
-                self.tokenizer._tokenizer.id_to_token(output_tokens[0])
-                .replace("▁", " ")
-                .replace("<0x0A>", "\n")
-            )
+            output_str = self.tokenizer.decode(output_tokens[0])
             output_content += output_str
             print(output_str, end="", flush=True)
             if output_tokens[0] in self.eos_token_id:
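A note on the `.reshape([2, dh // 2]).transpose(0, 1)` in the loader above: the HF checkpoint stores each norm weight as a flat [dh] vector in RoPE half-split channel order, and the loader interleaves the two halves, presumably because the engine keeps head channels in adjacent-pair order for its RoPE kernel, so the per-channel gamma must follow the same permutation as the projection rows. A minimal sketch of that permutation (hypothetical names, assuming torch and an even dh):

```python
import torch

def permute_norm_weight(w: torch.Tensor) -> torch.Tensor:
    # [dh] -> [2, dh // 2] -> transpose -> [dh // 2, 2]; flattening that
    # yields w[0], w[dh//2], w[1], w[dh//2 + 1], ...: the two RoPE halves
    # interleaved into adjacent pairs.
    dh = w.shape[0]
    return w.reshape([2, dh // 2]).transpose(0, 1).contiguous().view(-1)

dh = 8
w = torch.arange(dh, dtype=torch.float32)
print(permute_norm_weight(w))  # tensor([0., 4., 1., 5., 2., 6., 3., 7.])
```

The loader keeps the result as a contiguous [dh // 2, 2] tensor rather than flattening it; the bytes behind data_ptr() are identical either way.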
diff --git a/scripts/launch_server.py b/scripts/launch_server.py
index 4847a477..63e0e82b 100644
--- a/scripts/launch_server.py
+++ b/scripts/launch_server.py
@@ -207,11 +207,7 @@ async def chat_stream(id_, request_data, request: Request):
                 break
             token = await infer_task.output_queue.async_q.get()
-            content = (
-                request.app.state.model.tokenizer._tokenizer.id_to_token(token)
-                .replace("▁", " ")
-                .replace("<0x0A>", "\n")
-            )
+            content = request.app.state.model.tokenizer.decode(token)
             chunk = json.dumps(chunk_json(id_, content=content), ensure_ascii=False)
             yield f"data: {chunk}\n\n"
 
@@ -236,11 +232,7 @@ async def chat(id_, request_data, request: Request):
                 break
             token = await infer_task.output_queue.async_q.get()
-            content = (
-                request.app.state.model.tokenizer._tokenizer.id_to_token(token)
-                .replace("▁", " ")
-                .replace("<0x0A>", "\n")
-            )
+            content = request.app.state.model.tokenizer.decode(token)
             output.append(content)
 
     output_text = "".join(output).strip()
diff --git a/scripts/libinfinicore_infer.py b/scripts/libinfinicore_infer.py
index a92382cd..af5d9624 100644
--- a/scripts/libinfinicore_infer.py
+++ b/scripts/libinfinicore_infer.py
@@ -66,6 +66,8 @@ class JiugeWeightsCStruct(ctypes.Structure):
         ("attn_norm", POINTER(c_void_p)),
         ("attn_qkv", POINTER(c_void_p)),
         ("attn_qkv_b", POINTER(c_void_p)),
+        ("attn_q_norm", POINTER(c_void_p)),
+        ("attn_k_norm", POINTER(c_void_p)),
         ("attn_o", POINTER(c_void_p)),
         ("ffn_norm", POINTER(c_void_p)),
         ("ffn_gate_up", POINTER(c_void_p)),
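The two new fields slot into JiugeWeightsCStruct as ordinary POINTER(c_void_p) arrays; when a checkpoint has no q/k norms the loader assigns None, which crosses the ABI as NULL, and the C++ side branches on exactly that. A toy sketch of the convention (hypothetical struct, not the real ABI):

```python
import ctypes
from ctypes import POINTER, c_void_p

class ToyWeights(ctypes.Structure):
    # Mirrors the shape of the new fields: one per-layer pointer array each.
    _fields_ = [
        ("attn_q_norm", POINTER(c_void_p)),
        ("attn_k_norm", POINTER(c_void_p)),
    ]

nlayer = 4
ptrs = (c_void_p * nlayer)()  # stand-in per-layer pointers

with_norm = ToyWeights(attn_q_norm=ptrs, attn_k_norm=ptrs)
without_norm = ToyWeights()  # unset pointer fields default to NULL

# Truthiness of a ctypes pointer mirrors the C side's `!= nullptr` check.
print(bool(with_norm.attn_q_norm))     # True
print(bool(without_norm.attn_q_norm))  # False
```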
diff --git a/src/models/jiuge/jiuge.cpp b/src/models/jiuge/jiuge.cpp
index bafe784e..bb19ec83 100644
--- a/src/models/jiuge/jiuge.cpp
+++ b/src/models/jiuge/jiuge.cpp
@@ -21,7 +21,7 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
     infinirtStream_t stream;
     infinirtStreamCreate(&stream);
 
-    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
+    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out,
         w_ffn_norm, w_ffn_gate_up, w_ffn_down;
     for (size_t layer = 0; layer < meta->nlayer; layer++) {
         w_attn_norm.push_back(
@@ -32,6 +32,12 @@
             b_attn_qkv.push_back(
                 getAttnQKVBias(meta, weights, layer, idev, ndev));
         }
+        if (weights->attn_q_norm != nullptr) {
+            w_attn_q_norm.push_back(
+                getAttnQNorm(meta, weights, layer));
+            w_attn_k_norm.push_back(
+                getAttnKNorm(meta, weights, layer));
+        }
         w_attn_out.push_back(
             getAttnO(meta, weights, layer, idev, ndev));
         w_ffn_norm.push_back(
@@ -56,6 +62,8 @@
         w_attn_norm,
         w_attn_qkv,
         b_attn_qkv,
+        w_attn_q_norm,
+        w_attn_k_norm,
         w_attn_out,
         w_ffn_norm,
         w_ffn_gate_up,
@@ -130,6 +138,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
     auto dvoc = meta.dvoc;
     auto stream = rsrc.stream;
     bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0;
+    bool has_qk_norm = rsrc.w_attn_q_norm.size() > 0 && rsrc.w_attn_k_norm.size() > 0;
 
     // Allocate buffers
     auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
@@ -141,7 +150,9 @@
     auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
     auto result_cpu = std::vector<int64_t>(nreq);
 
-    auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});
+    auto qkv_buf_view = qkv_buf->view({ntok, nh + nkvh * 2, dh});
+    auto q_buf = qkv_buf_view->slice(1, 0, nh);
+    auto k_buf = qkv_buf_view->slice(1, nh, nkvh);
 
     // Prepare inputs
     auto batch_pos_ids = std::vector<uint32_t>(ntok);
@@ -198,9 +209,13 @@
         rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
         // qkv_proj
         linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, nullptr, has_qkv_bias ? rsrc.b_attn_qkv[layer] : nullptr);
+        if (has_qk_norm) {
+            rmsnorm(q_buf, q_buf, rsrc.w_attn_q_norm[layer], meta.epsilon);
+            rmsnorm(k_buf, k_buf, rsrc.w_attn_k_norm[layer], meta.epsilon);
+        }
         // rope
-        rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
-        rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
+        rope(q_buf, q_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
+        rope(k_buf, k_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
 
         size_t token_offset = 0;
         for (uint32_t req = 0; req < nreq; req++) {
@@ -208,9 +223,9 @@
             auto seq_len = req_lens[req];
             auto total_len = past_len + seq_len;
             auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
-            auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
-            auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
-            auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
+            auto q = qkv_buf_view->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
+            auto k = qkv_buf_view->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
+            auto v = qkv_buf_view->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
 
             // self attention
             // concat
diff --git a/src/models/jiuge/jiuge_impl.hpp b/src/models/jiuge/jiuge_impl.hpp
index be05b0e8..ad2a2fd0 100644
--- a/src/models/jiuge/jiuge_impl.hpp
+++ b/src/models/jiuge/jiuge_impl.hpp
@@ -20,7 +20,7 @@ struct DeviceResource {
     // Weights
     std::shared_ptr<Tensor> w_in_embd, w_out_norm, w_out_embd, sin_table, cos_table;
-    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
+    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out,
         w_ffn_norm, w_ffn_gate_up, w_ffn_down;
     // Streams
    infinirtStream_t stream;
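On the compute path above, the new rmsnorm calls run on views of the fused QKV buffer, so each q and k head appears to be normalized in place over its dh channels after the projection and before RoPE, matching Qwen3's ordering. A rough torch reference of that ordering (rope_fn is a stand-in, not the engine's kernel) that a unit test could compare device output against:

```python
import torch

def rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize over the last dim (dh), then scale per channel.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w

def attn_pre_rope(qkv, q_w, k_w, nh, nkvh, dh, rope_fn, eps=1e-6):
    # qkv: [ntok, (nh + 2 * nkvh) * dh], the fused projection output.
    x = qkv.view(qkv.shape[0], nh + 2 * nkvh, dh)
    q, k, v = x[:, :nh], x[:, nh:nh + nkvh], x[:, nh + nkvh:]
    q = rope_fn(rmsnorm(q, q_w, eps))  # norm each head over dh, then rotate
    k = rope_fn(rmsnorm(k, k_w, eps))
    return q, k, v
```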
diff --git a/src/models/jiuge/jiuge_weight.hpp b/src/models/jiuge/jiuge_weight.hpp
index 6e8bc33e..7ee10155 100644
--- a/src/models/jiuge/jiuge_weight.hpp
+++ b/src/models/jiuge/jiuge_weight.hpp
@@ -70,6 +70,22 @@ inline std::shared_ptr<Tensor> getAttnQKVBias(
     return Tensor::weight((char *)(w->attn_qkv_b[layer]) + offset, w->dt_mat, shape);
 }
 
+inline std::shared_ptr<Tensor> getAttnQNorm(
+    JiugeMeta const *meta,
+    JiugeWeights const *w,
+    size_t layer) {
+    auto shape = std::vector<size_t>({meta->dh});
+    return Tensor::weight((char *)(w->attn_q_norm[layer]), w->dt_norm, shape);
+}
+
+inline std::shared_ptr<Tensor> getAttnKNorm(
+    JiugeMeta const *meta,
+    JiugeWeights const *w,
+    size_t layer) {
+    auto shape = std::vector<size_t>({meta->dh});
+    return Tensor::weight((char *)(w->attn_k_norm[layer]), w->dt_norm, shape);
+}
+
 inline std::shared_ptr<Tensor> getAttnO(JiugeMeta const *meta,
                                         JiugeWeights const *w,
                                         size_t layer, size_t idev, size_t ndev) {
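Since getAttnQNorm/getAttnKNorm wrap the raw pointers passed in from Python rather than taking ownership, the loader keeping self.attn_q_norm_tensors on self before calling data_ptr() is load-bearing. A small torch sketch of the pitfall this avoids (hypothetical names):

```python
import torch
from ctypes import c_void_p

nlayer, dh = 2, 8
weights = [torch.randn(dh) for _ in range(nlayer)]

# Dangling: each permuted copy is a temporary; once it is collected, the
# address captured in the c_void_p array points at reclaimed storage.
bad = (c_void_p * nlayer)(
    *[w.reshape([2, dh // 2]).transpose(0, 1).contiguous().data_ptr()
      for w in weights]
)

# Safe: keep the owning tensors referenced for as long as the C side may
# read them, as the loader does via self.attn_q_norm_tensors.
kept = [w.reshape([2, dh // 2]).transpose(0, 1).contiguous() for w in weights]
good = (c_void_p * nlayer)(*[t.data_ptr() for t in kept])
```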