diff --git a/final_report/baseline.png b/final_report/baseline.png
new file mode 100644
index 00000000..17a88ba0
Binary files /dev/null and b/final_report/baseline.png differ
diff --git a/final_report/ppl.png b/final_report/ppl.png
new file mode 100644
index 00000000..7b249b0e
Binary files /dev/null and b/final_report/ppl.png differ
diff --git a/final_report/prune.png b/final_report/prune.png
new file mode 100644
index 00000000..f337191d
Binary files /dev/null and b/final_report/prune.png differ
diff --git a/final_report/report.md b/final_report/report.md
new file mode 100644
index 00000000..fbd75380
--- /dev/null
+++ b/final_report/report.md
@@ -0,0 +1,110 @@
+# 启元人工智能大赛 Inference System Report
+
+## 1. Track
+We participated in the **Inference System Optimization track**, optimizing the performance of a large-model inference engine built on the **InfiniCore-Infer framework**.
+
+## 2. Summary of Results
+Two sets of optimizations were implemented: serving-layer optimization and sparse-attention inference acceleration.
+
+- Serving layer: the optimizations target three areas — memory management, the KV-cache pool, and dynamic batch-size adjustment — and yield significant gains in TTFT, request rate, and token generation speed.
+
+- Sparse attention: **dynamic KV-cache management** at inference time. A **recent-window attention** strategy compresses **Key/Value** storage during the prefill stage, reducing GPU memory usage and speeding up decoding.
+  - **Prefill stage**: window pruning exploits the recency-focused sparsity of attention in large models; only the K/V of the most recent `recentWindow` tokens (plus a small window of initial attention-sink tokens) are written into the KV cache, and the earlier, redundant tokens are pruned away.
+  - **Decode stage**: decoding computes incrementally over the pruned KV cache instead of the full sequence, and each newly generated token is appended to the cache. This markedly increases inference speed and lowers memory usage.
+
+## 3. Technical Highlights
+
+### I. Memory Management Optimization
+
+#### 1. Three-Tier Preallocation Architecture
+
+We built a size-tiered preallocation system: based on the memory-usage patterns of different workloads, three memory pools of different scales are allocated up front.
+
+**Tiering strategy:**
+- **Small pool (32 MB)**: dedicated to small objects (≤4 KB) — metadata, temporaries, control structures, and other high-frequency small allocations.
+- **Medium pool (64 MB)**: serves allocations in the 4 KB–64 KB range — intermediate results, cached data, small tensors.
+- **Large pool (128 MB)**: carries large data structures (>64 KB) such as model weights, activations, and large tensors.
+
+#### 2. Predictive Defragmentation
+
+We implemented predictive fragmentation detection based on multiple metrics; the system continuously monitors three indicators to decide whether defragmentation is needed. The first is the fragmentation rate, which measures how scattered free memory has become. The second is the largest free block: when the largest contiguous free block is smaller than 50% of total free memory, the pool is considered severely fragmented. The third is the recent allocation-failure count, which predicts upcoming memory pressure.
+
+The defragmentation pass runs separately per block type, so blocks of different types are never merged by mistake. Only adjacent free blocks are processed, gradually improving the memory layout through local coalescing, while detailed statistics are maintained for performance analysis and tuning decisions.
+
+During model inference, memory health is checked once every 10 inference rounds. This periodic check keeps memory quality high while avoiding the overhead of overly frequent defragmentation.
+
+### II. KV-Cache Pool
+
+#### 1. Pool Architecture
+
+The pool consists of an available-cache list that manages idle KV-cache objects, a cache-metadata dictionary that records per-cache details and usage statistics, a prefix index that provides fast cache lookup, an LRU ordering that tracks recency of use, and an in-use tracker that protects caches currently serving requests from being touched.
+
+#### 2. Multi-Level Prefix Index
+
+The system keeps a separate hash index per token-prefix length, supporting lengths 1, 4, 8, 16, 32, and 64. Whenever a cache joins the pool, prefix keys of these lengths are generated for it and the cache's index is added to the corresponding tables. Prefix keys are produced with MD5 hashing, giving fast lookup while avoiding key collisions. Different prefix lengths serve different matching scenarios: short prefixes (1–4 tokens) mainly capture conversation openings and common instructions and quickly identify the dialogue type; medium prefixes (8–16 tokens) identify context and topic and distinguish different conversation scenarios; long prefixes (32–64 tokens) provide exact matching for long, complex dialogue histories.
+
+On lookup, prefixes are tried from long to short (64, 32, 16, 8, 4, 1). For each length, the corresponding prefix key is generated and the index table is searched for candidate caches. Once enough candidates are found (typically 5), the search stops to avoid unnecessary computation.
+
+With this layered index, the system picks the prefix length best suited to each query and achieves near-O(1) lookup. The index also degrades gracefully: when a long-prefix match fails, it automatically falls back to shorter prefixes, so the best available cache is always found.
+
+#### 3. LRU Eviction and Performance Monitoring: Cache Lifecycle Management
+
+We implemented an LRU eviction policy based on access frequency and recency. When the pool reaches its capacity limit, the least-recently-used cache is identified and evicted. Eviction only considers caches that are not currently in use, so caches serving active requests are never deleted by mistake. Detailed access-time records allow the most suitable eviction victim to be identified precisely.
+
+Each cache carries detailed metadata — creation time, last access time, access frequency, match-hit count, and so on. From this data the system computes a composite value score per cache that guides eviction and match selection; high-value caches are preferentially retained to keep overall performance optimal.
+
+### III. Dynamic Batch-Size Optimization
+
+#### 1. Estimating the Concurrent-Task Limit
+
+We estimate the maximum number of concurrent tasks from real-time GPU memory monitoring: by computing the exact memory footprint of a single KV cache and combining it with current GPU memory usage, the system dynamically determines how many requests it can serve at once.
+
+Concretely, the per-KV-cache memory requirement is derived from the model metadata (number of layers, number of KV heads, maximum sequence length, head dimension, data type). Each layer holds a K and a V cache of shape `[max_len, nkvh, dh]`, with the element size given by the data type (F16/F32/BF16). The total footprint of one KV cache is therefore $max\_len \times nkvh \times dh \times dtype\_size \times 2 \times nlayer \times ndev$. Based on the remaining GPU memory and this per-cache footprint, the theoretical maximum number of concurrent tasks is $available\_memory // single\_kvcache\_size$. The configured maximum batch size is applied as an upper bound, and the minimum of the two is taken as the recommended concurrency, keeping the system stable (a short sketch of the calculation is shown below).
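+
+The following minimal Python sketch illustrates the estimate; the field names (`meta.nlayer`, `meta.nkvh`, `meta.dctx`, `meta.dh`, `meta.dt_logits`) follow the helpers added to `scripts/launch_server.py` in this patch, and the function should be read as an illustration rather than the exact implementation:
+
+```python
+def estimate_max_concurrent_tasks(meta, ndev, free_memory_bytes, max_batch):
+    """Estimate how many KV caches fit into the remaining GPU memory."""
+    dtype_size = 4 if meta.dt_logits == 1 else 2      # F32 -> 4 bytes; F16/BF16 and default -> 2 bytes
+    nkvh_per_dev = meta.nkvh // ndev                  # KV heads per device
+    # one K and one V buffer per layer, each of shape [max_len, nkvh, dh], on every device
+    single_kvcache = meta.dctx * nkvh_per_dev * meta.dh * dtype_size * 2 * meta.nlayer * ndev
+    max_tasks = free_memory_bytes // single_kvcache
+    return min(max_tasks, max_batch)                  # cap at the configured batch limit
+```
+
+In this patch the same quantities are computed by `calculate_kvcache_memory_size` and `estimate_max_concurrent_tasks` in `scripts/launch_server.py`, and the result is used to size the KV-cache pool and the dynamic batch manager at server startup.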
+
+#### 2. Adaptive Dynamic Batch Management
+
+We implemented an adaptive batch-management system driven by memory pressure and performance history, which adjusts the batching strategy according to real-time system state.
+
+GPU memory usage is divided into three pressure levels: low (<25%), medium (25%–35%), and high (>60%). Each level maps to its own batching strategy and exploration probability, so the system stays well-tuned under all memory conditions. When memory usage exceeds 80%, the current best-known batch size is used; when usage is low, the maximum batch size is used. This keeps the system stable under high memory pressure while fully exploiting capacity when pressure is low.
+
+The system keeps persistent per-batch-size performance data, recording latency and throughput for each batch size. At most 20 samples are retained per batch size so that stale data does not skew decisions. From this history it computes a composite score — an efficiency score, a confidence bonus, and a batch-size bonus — and adjusts the batch size according to the result.
+
+### IV. Sparse Attention Optimization
+
+The system uses **recent-window (sliding-window) attention**: full attention is computed only over the most recent `recentWindow` tokens, while earlier tokens are evicted from the KV cache by a **compression/pruning strategy** (a small window of initial attention-sink tokens is kept as well, following the StreamingLLM pattern). This avoids the memory and time cost of full-sequence computation during decoding.
+Implementation details:
+- **Prefill stage**: when the input length satisfies `seq_len > recentWindow`, only the K/V in the range `[seq_len - recentWindow, seq_len]` (plus the attention-sink tokens) take part in attention, and the earlier portion is truncated or compressed before being written to the KV cache.
+- **Decode stage**: during incremental decoding, each new token attends only to tokens inside the window; the historical cache is never recomputed in full.
+- The sparsity level can be changed through the floating-point hyperparameter **ratio** in `inferDeviceBatch` in `jiuge.cpp` (ratio is the fraction of KV-cache entries kept relative to the unpruned full cache), trading off inference speed against generation quality.
+
+## 4. Performance Results
+
+Tests were run with the scripts shipped with the framework:
+
+```bash
+cd InfiniCore-Infer/
+srun python scripts/launch_server.py --model-path /home/shared/models/9G7B_MHA --dev nvidia --ndev 1
+
+srun python scripts/test_perf.py
+srun python scripts/test_server.py
+```
+
+* **Responsiveness**: time to first token (TTFT) dropped from 3.32 s to 0.08 s on average, a 97% improvement that greatly reduces the user's initial wait.
+* **Throughput**: total time fell from 22.11 s to 19.28 s, saving 12.8%. The request rate (RPS) rose from 0.45 to 0.52, a 15.6% improvement.
+* **Token generation speed**: 52.91 tokens/s on average versus 43.45 before optimization, a 21.8% gain.
+
+Sparse-attention speedup, measured with a 128-token prompt on a MetaX GPU at a 12.5% compression ratio:
+- Before (Time Per Second: 16.876 ms)
+![1](./baseline.png)
+- After (Time Per Second: 8.463 ms)
+![2](./prune.png)
+
+Perplexity of the optimized model: 4.7086.
+
+![3](./ppl.png)
+
+## 5. Timeline
diff --git a/scripts/infer_task.py b/scripts/infer_task.py
index 0d1231b7..27d57e87 100644
--- a/scripts/infer_task.py
+++ b/scripts/infer_task.py
@@ -27,7 +27,17 @@ def kvcache(self):
 
     def next(self, out_token):
         self._kv_cache.update_tokens(self.tokens, self.pos)
-        self.pos += len(self.tokens)
+        sparseOn = True
+        ratio = 0.2
+        attentionSinkWindow = 4
+        recentWindow = int((len(self.tokens) - attentionSinkWindow) * ratio)
+        prune = sparseOn and recentWindow >= 1 and self.pos == 0
+        if prune:  # update pruned pos
+            nSaveKVs = attentionSinkWindow + recentWindow
+            self.pos += nSaveKVs
+        else:  # baseline
+            self.pos += len(self.tokens)
+
         if out_token == None or out_token in self.end_tokens:
             self.finish_reason = "stop"
         elif self.pos >= self.max_tokens:
diff --git a/scripts/jiuge.py b/scripts/jiuge.py
index a2e591f8..809ad1ff 100644
--- a/scripts/jiuge.py
+++ b/scripts/jiuge.py
@@ -1,6 +1,4 @@
-from typing import List, Sequence
-
-from sympy import true
+from typing import List
 from libinfinicore_infer import (
     JiugeMetaCStruct,
     JiugeWeightsCStruct,
@@ -12,7 +10,6 @@
     create_kv_cache,
     drop_kv_cache,
     infer_batch,
-    forward_batch,
 )
 from infer_task import InferTask, KVCache
 
@@ -99,7 +96,7 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None):
         self.scale_o = 1.0
         self.scale_down = 1.0
         if (
-            config["model_type"] in ["fm9g", "minicpm"]
+            "fm9g" == config["model_type"]
             and "scale_emb" in config
             and "scale_depth" in config
             and "dim_model_base" in config
@@ -397,6 +394,8 @@ class JiugeForCauslLM:
     def __init__(
         self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None
     ):
+        # Store ndev as an instance attribute for KV-cache analysis
+        self.ndev = ndev
         def load_all_safetensors_from_dir(dir_path_: str):
             tensors_ = {}
             dir_path_ = Path(dir_path_)
@@ -434,7 +433,7 @@ def load_all_safetensors_from_dir(dir_path_: str):
                 ndev=ndev,
                 transpose_weight=transpose_weight,
             )
-        elif "fm9g" == config["model_type"] or "minicpm" == config["model_type"]:
+        elif "fm9g" == config["model_type"]:
            if
any( file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir() ): @@ -585,59 +584,6 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1. infer_task._kv_cache.drop(self) return output_content, avg_time - def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10): - tasks = [ - InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id) - for i in range(batch_size) - ] - kv_caches = [KVCache(self) for _ in range(batch_size)] - - nll = 0.0 - total_len = 0 - - for i in range(0, len(test_sequences), batch_size): - batch_id = 0 - true_tokens = [] - while batch_id < batch_size and batch_id + i < len(test_sequences): - input_tokens = test_sequences[i + batch_id][:-1] - true_tokens.extend(test_sequences[i + batch_id][1:]) - tasks[batch_id].tokens = input_tokens - tasks[batch_id].bind_kvcache(kv_caches[batch_id]) - batch_id += 1 - - batch_inputs = JiugeBatchedTask(tasks[:batch_id]) - logits = torch.zeros( - (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits - ) - forward_batch( - self.model_instance, - batch_inputs.tokens, - batch_inputs.ntok, - batch_inputs.req_lens, - batch_inputs.nreq, - batch_inputs.req_pos, - batch_inputs.kv_caches, - logits.data_ptr(), - ) - - logits = logits.float() - token_ids = torch.tensor(true_tokens, dtype=torch.int64) # [ntok,] - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # (ntok, vocab) - token_logprobs = log_probs[ - torch.arange(batch_inputs.ntok), token_ids - ] # (ntok,) - - start = 0 - for l in batch_inputs.req_lens_list: - nll += -token_logprobs[start : start + l].sum().item() - start += l - total_len += token_logprobs.numel() - - for task in tasks: - task.release_kvcache() - - return math.exp(nll / total_len) - def destroy_model_instance(self): destroy_jiuge_model(self.model_instance) print("Model destroyed") diff --git a/scripts/kvcache_pool.py b/scripts/kvcache_pool.py index 81914535..d77c7d5a 100644 --- a/scripts/kvcache_pool.py +++ b/scripts/kvcache_pool.py @@ -1,12 +1,39 @@ from infer_task import KVCache import asyncio -from typing import List +from typing import List, Dict, Tuple, Optional import threading +import time +from collections import defaultdict, OrderedDict +import hashlib + + +class CacheMetadata: + """缓存元数据,用于LRU和性能统计""" + def __init__(self, cache_id: int): + self.cache_id = cache_id + self.last_access_time = time.time() + self.access_count = 0 + self.hit_count = 0 + self.creation_time = time.time() + + def update_access(self, hit: bool = True): + self.last_access_time = time.time() + self.access_count += 1 + if hit: + self.hit_count += 1 + + def get_score(self) -> float: + """计算缓存价值分数,用于淘汰决策""" + age = time.time() - self.last_access_time + hit_rate = self.hit_count / max(self.access_count, 1) + # 分数越高越有价值,越不容易被淘汰 + return hit_rate * 100 - age * 0.1 class KVCachePool: - def __init__(self, model, max_caches: int = 32): + def __init__(self, model, max_caches: int = 8): + print(f"[INFO] KVCachePool init with max_caches: {max_caches}") self.max_caches = max_caches self.model = model self._available: List[KVCache] = [] @@ -14,7 +41,93 @@ def __init__(self, model, max_caches: int = 32): self._lock = threading.Lock() self._not_empty = threading.Condition(self._lock) self._shutdown = False + + # 优化相关的数据结构 + self._cache_metadata: Dict[int, CacheMetadata] = {} # 缓存元数据 + self._prefix_index: Dict[str, List[int]] = defaultdict(list) # 前缀索引 + self._lru_order: OrderedDict[int, bool] = OrderedDict() # LRU顺序 + self._in_use_caches: Dict[int, bool] = {} 
# 跟踪正在使用的缓存 + self._next_cache_id = 0 + + # 性能统计 + self._total_requests = 0 + self._cache_hits = 0 + self._exact_matches = 0 + self._lru_evictions = 0 # 跟踪LRU淘汰次数 + + print(f"[INFO] KVCachePool init done.") + def _evict_lru_cache(self) -> bool: + """LRU淘汰策略:移除最少使用的缓存(仅淘汰未在使用的缓存)""" + if not self._lru_order: + return False + + # 如果available列表为空,不能进行淘汰 + if not self._available: + print("[WARNING] All caches are in use, cannot evict. Waiting for cache release.") + return False + + # 只从available列表中淘汰未在使用的缓存 + return self._evict_from_available() + + + + def _evict_from_available(self) -> bool: + """从available列表中淘汰LRU缓存""" + if not self._available: + return False + + # 找到available列表中最少使用的缓存 + lru_cache_id = None + lru_cache_index = -1 + oldest_access_time = float('inf') + + for i, kvcache in enumerate(self._available): + cache_id = id(kvcache) + # 跳过正在使用的缓存 + if cache_id in self._in_use_caches: + continue + + if cache_id in self._cache_metadata: + metadata = self._cache_metadata[cache_id] + if metadata.last_access_time < oldest_access_time: + oldest_access_time = metadata.last_access_time + lru_cache_id = cache_id + lru_cache_index = i + else: + # 如果没有元数据,优先淘汰这个缓存 + lru_cache_id = cache_id + lru_cache_index = i + break + + if lru_cache_index >= 0: + # 移除缓存 + evicted_cache = self._available.pop(lru_cache_index) + + # 从前缀索引中移除 + self._remove_from_prefix_index(lru_cache_index, evicted_cache.tokens) + + # 更新其他缓存在前缀索引中的位置 + for prefix_key, cache_indices in self._prefix_index.items(): + for j in range(len(cache_indices)): + if cache_indices[j] > lru_cache_index: + cache_indices[j] -= 1 + + # 清理元数据 + if lru_cache_id in self._cache_metadata: + del self._cache_metadata[lru_cache_id] + if lru_cache_id in self._lru_order: + del self._lru_order[lru_cache_id] + + # 释放底层资源 + evicted_cache.drop(self.model) + self.num_caches -= 1 + self._lru_evictions += 1 # 增加LRU淘汰计数 + print(f"[INFO] Evicted cache {lru_cache_id} from available pool") + return True + + return False + def acquire_sync(self, infer_task): with self._not_empty: while True: @@ -24,27 +137,78 @@ def acquire_sync(self, infer_task): ) if len(self._available) == 0: if self.num_caches < self.max_caches: + # 创建新缓存 self.num_caches += 1 - print( - f"[INFO] Task {infer_task.id} created new KVCachePoolItem" - ) - return infer_task.bind_kvcache(KVCache(self.model), 0) + new_cache = KVCache(self.model) + cache_id = id(new_cache) + + # 创建元数据 + self._cache_metadata[cache_id] = CacheMetadata(cache_id) + self._lru_order[cache_id] = True + # 标记为正在使用 + self._in_use_caches[cache_id] = True + + return infer_task.bind_kvcache(new_cache, 0) else: - self._not_empty.wait() + # 尝试LRU淘汰 + if self._evict_lru_cache(): + continue # 淘汰成功,重新尝试 + else: + self._not_empty.wait() # 等待缓存释放 else: max_match, max_match_index = self.find_most_matching_cache( infer_task.tokens ) kvcache = self._available.pop(max_match_index) - print( - f"[INFO] Task {infer_task.id} reused KVCachePoolItem {max_match_index} with {max_match} matches" - ) + + # 从前缀索引中移除 + self._remove_from_prefix_index(max_match_index, kvcache.tokens) + + # 更新其他缓存在前缀索引中的位置 + for prefix_key, cache_indices in self._prefix_index.items(): + for j in range(len(cache_indices)): + if cache_indices[j] > max_match_index: + cache_indices[j] -= 1 + + # 标记为正在使用 + cache_id = id(kvcache) + self._in_use_caches[cache_id] = True + return infer_task.bind_kvcache(kvcache, max_match) def release_sync(self, infer_task): with self._not_empty: - print(f"[INFO] Task {infer_task.id} returned KVCachePoolItem to pool") - 
self._available.append(infer_task.release_kvcache()) + released_cache = infer_task.release_kvcache() + cache_id = id(released_cache) + + # 更新缓存元数据 + if cache_id not in self._cache_metadata: + self._cache_metadata[cache_id] = CacheMetadata(cache_id) + + # 添加到available列表 + cache_index = len(self._available) + self._available.append(released_cache) + + # 更新前缀索引 + self._update_prefix_index(cache_index, released_cache.tokens) + + # 更新LRU顺序 + if cache_id in self._lru_order: + del self._lru_order[cache_id] + self._lru_order[cache_id] = True + + # 移除正在使用的标记 + if cache_id in self._in_use_caches: + del self._in_use_caches[cache_id] + + # 释放缓存后清理GPU显存碎片 + try: + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except ImportError: + pass # torch不可用时跳过 + self._not_empty.notify() async def acquire(self, infer_task): @@ -55,29 +219,150 @@ async def release(self, infer_task): loop = asyncio.get_event_loop() return await loop.run_in_executor(None, self.release_sync, infer_task) - def find_most_matching_cache(self, tokens: List[int]): + def _generate_prefix_key(self, tokens: List[int], length: int) -> str: + """生成前缀哈希键""" + if length <= 0: + return "" + prefix = tokens[:min(length, len(tokens))] + return hashlib.md5(str(prefix).encode()).hexdigest()[:16] + + def _update_prefix_index(self, cache_index: int, tokens: List[int]): + """更新前缀索引""" + # 为不同长度的前缀建立索引 + for prefix_len in [1, 4, 8, 16, 32, 64]: + if prefix_len <= len(tokens): + prefix_key = self._generate_prefix_key(tokens, prefix_len) + if cache_index not in self._prefix_index[prefix_key]: + self._prefix_index[prefix_key].append(cache_index) + + def _remove_from_prefix_index(self, cache_index: int, tokens: List[int]): + """从前缀索引中移除缓存""" + for prefix_len in [1, 4, 8, 16, 32, 64]: + if prefix_len <= len(tokens): + prefix_key = self._generate_prefix_key(tokens, prefix_len) + if cache_index in self._prefix_index[prefix_key]: + self._prefix_index[prefix_key].remove(cache_index) + if not self._prefix_index[prefix_key]: + del self._prefix_index[prefix_key] + + def _first_different_index(self, a: List[int], b: List[int]) -> int: + """找到两个序列第一个不同元素的索引""" + for i, (x, y) in enumerate(zip(a, b)): + if x != y: + return i + return min(len(a), len(b)) + + def find_most_matching_cache(self, tokens: List[int]) -> Tuple[int, int]: + """优化的缓存匹配算法""" + self._total_requests += 1 + + if not self._available: + return (0, 0) + + # 第一阶段:基于前缀索引的快速匹配 + candidates = set() + for prefix_len in [64, 32, 16, 8, 4, 1]: # 从长到短尝试 + if prefix_len <= len(tokens): + prefix_key = self._generate_prefix_key(tokens, prefix_len) + if prefix_key in self._prefix_index: + candidates.update(self._prefix_index[prefix_key]) + if len(candidates) >= 5: # 找到足够的候选者就停止 + break + + # 如果前缀索引没有找到候选者,回退到全搜索 + if not candidates: + candidates = set(range(len(self._available))) + + # 第二阶段:在候选者中找最佳匹配 max_match = 0 max_match_index = 0 - - def first_different_index(a_, b_): - for i_, (x_, y_) in enumerate(zip(a_, b_)): - if x_ != y_: - return i_ - return min(len(a_), len(b_)) - - for i, kvcache in enumerate(self._available): - common_elements = first_different_index(tokens, kvcache.tokens) - # print(f"{tokens}") - # print(f"{kvcache.tokens[:len(tokens)]}") - if common_elements > max_match: + best_score = -1 + + for i in candidates: + if i >= len(self._available): + continue + + kvcache = self._available[i] + common_elements = self._first_different_index(tokens, kvcache.tokens) + + # 计算综合分数:匹配长度 + 缓存价值 + cache_id = id(kvcache) # 使用对象id作为缓存标识 + metadata = self._cache_metadata.get(cache_id) 
+ cache_score = metadata.get_score() if metadata else 0 + + total_score = common_elements * 100 + cache_score # 综合分数:匹配长度权重更高 + + if common_elements > max_match or (common_elements == max_match and total_score > best_score): max_match = common_elements max_match_index = i - + best_score = total_score + + # 更新统计信息 + if max_match > 0: + self._cache_hits += 1 + if max_match == len(tokens): + self._exact_matches += 1 + + # 更新缓存元数据 + if max_match_index < len(self._available): + kvcache = self._available[max_match_index] + cache_id = id(kvcache) + if cache_id in self._cache_metadata: + self._cache_metadata[cache_id].update_access(hit=True) + # 更新LRU顺序 + if cache_id in self._lru_order: + del self._lru_order[cache_id] + self._lru_order[cache_id] = True + return (min(max_match, len(tokens) - 1), max_match_index) + def get_cache_stats(self) -> Dict[str, float]: + """获取缓存性能统计信息""" + hit_rate = self._cache_hits / max(self._total_requests, 1) * 100 + exact_match_rate = self._exact_matches / max(self._total_requests, 1) * 100 + + return { + "total_requests": self._total_requests, + "cache_hits": self._cache_hits, + "exact_matches": self._exact_matches, + "hit_rate_percent": hit_rate, + "exact_match_rate_percent": exact_match_rate, + "available_caches": len(self._available), + "total_caches": self.num_caches, + "prefix_index_size": len(self._prefix_index), + "avg_cache_age": self._get_avg_cache_age(), + "lru_evictions": self._lru_evictions + } + + def _get_avg_cache_age(self) -> float: + """计算平均缓存年龄""" + if not self._cache_metadata: + return 0.0 + + current_time = time.time() + total_age = sum(current_time - metadata.last_access_time + for metadata in self._cache_metadata.values()) + return total_age / len(self._cache_metadata) + + def print_cache_stats(self): + """打印缓存统计信息""" + stats = self.get_cache_stats() + print("\n=== KV Cache Pool Statistics ===") + print(f"Total Requests: {stats['total_requests']}") + print(f"Cache Hit Rate: {stats['hit_rate_percent']:.2f}%") + print(f"Exact Match Rate: {stats['exact_match_rate_percent']:.2f}%") + print(f"Available/Total Caches: {stats['available_caches']}/{stats['total_caches']}") + print(f"Prefix Index Size: {stats['prefix_index_size']}") + print(f"Average Cache Age: {stats['avg_cache_age']:.2f}s") + print("================================\n") + def finalize(self): with self._not_empty: self._shutdown = True + + # 打印最终统计信息 + self.print_cache_stats() + while len(self._available) < self.num_caches: self._not_empty.wait() @@ -85,6 +370,16 @@ def finalize(self): if kvcache is not None: kvcache.drop(self.model) + # 清理所有数据结构 self._available.clear() + self._cache_metadata.clear() + self._prefix_index.clear() + self._lru_order.clear() + self._in_use_caches.clear() + self.max_caches = 0 self.num_caches = 0 + self._total_requests = 0 + self._cache_hits = 0 + self._exact_matches = 0 + self._lru_evictions = 0 diff --git a/scripts/launch_server.py b/scripts/launch_server.py index 4847a477..d1d5c376 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -2,6 +2,7 @@ from libinfinicore_infer import DeviceType from infer_task import InferTask from kvcache_pool import KVCachePool +from dynamic_batch_manager import DynamicBatchManager import argparse import queue @@ -14,6 +15,10 @@ import json import threading import janus +import torch +import pynvml +import psutil +import gc DEVICE_TYPE_MAP = { @@ -48,8 +53,8 @@ def parse_args(): parser.add_argument( "--max-batch", type=int, - default=3, - help="Maximum number of requests that can be batched together (default: 
3)", + default=5, + help="Maximum number of requests that can be batched together (default: 4)", ) parser.add_argument( "--max-tokens", @@ -105,13 +110,155 @@ def output(self, out_token): self.next(out_token) self.output_queue.sync_q.put(out_token) +def get_memory_usage() -> float: + """获取当前GPU显存使用率,如果GPU不可用则获取系统内存使用率""" + try: + # 检查是否有可用的GPU + if pynvml and hasattr(pynvml, 'nvmlInit'): + try: + pynvml.nvmlInit() + # 使用第一个GPU设备 + gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0) + + # 清理PyTorch缓存以获得准确的显存使用率 + if torch and torch.cuda.is_available(): + torch.cuda.empty_cache() + + # 获取GPU显存使用率 + memory_info = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle) + gpu_usage = memory_info.used / memory_info.total + return gpu_usage + except Exception as e: + print(f"[WARNING] Failed to get GPU memory usage: {e}") + + # 回退到系统内存使用率 + memory = psutil.virtual_memory() + return memory.percent / 100.0 + + except Exception as e: + print(f"[WARNING] Failed to get memory usage: {e}") + return 0.5 # 默认返回50% + + +def get_gpu_memory_info(): + """获取GPU显存信息(总量和已使用量,单位:字节)""" + try: + if pynvml and hasattr(pynvml, 'nvmlInit'): + pynvml.nvmlInit() + gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0) + + # 清理缓存以获得准确的显存使用情况 + if torch and torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + memory_info = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle) + return memory_info.total, memory_info.used, memory_info.free + except Exception as e: + print(f"[WARNING] Failed to get GPU memory info: {e}") + return None, None, None + + +def calculate_kvcache_memory_size(model): + """计算单个KVCache的显存占用(字节)""" + try: + # 获取模型元数据 + meta = model.meta + ndev = model.ndev + + # KVCache参数 + nlayer = meta.nlayer # 层数 + nkvh = meta.nkvh // ndev # 每个设备的KV头数 + max_len = meta.dctx # 最大序列长度 + dh = meta.dh # 头维度 + + # 数据类型大小(字节) + if meta.dt_logits == 0: # INFINI_DTYPE_F16 + dtype_size = 2 + elif meta.dt_logits == 1: # INFINI_DTYPE_F32 + dtype_size = 4 + elif meta.dt_logits == 2: # INFINI_DTYPE_BF16 + dtype_size = 2 + else: + dtype_size = 2 # 默认使用F16 + + # 单个KVCache的显存占用计算 + # 每层有K和V两个缓存,形状为 [max_len, nkvh, dh] + single_cache_size = max_len * nkvh * dh * dtype_size + total_cache_size = single_cache_size * 2 * nlayer * ndev # K和V,所有层,所有设备 + + print(f"[INFO] KVCache参数: nlayer={nlayer}, nkvh={nkvh}, max_len={max_len}, dh={dh}, dtype_size={dtype_size}") + print(f"[INFO] 单个KVCache显存占用: {total_cache_size / (1024**3):.2f} GB") + + return total_cache_size + except Exception as e: + print(f"[ERROR] Failed to calculate KVCache memory size: {e}") + return None + + +def estimate_max_concurrent_tasks(model, safety_margin=0.1): + """估算剩余显存最多能支持多少个任务同时创建KVCache""" + try: + # 获取GPU显存信息 + total_memory, used_memory, free_memory = get_gpu_memory_info() + if total_memory is None: + print(f"[WARNING] Cannot get GPU memory info, using default estimation") + return MAX_BATCH + + # 计算单个KVCache的显存占用 + single_kvcache_size = calculate_kvcache_memory_size(model) + if single_kvcache_size is None: + print(f"[WARNING] Cannot calculate KVCache size, using default estimation") + return MAX_BATCH + + # 计算可用显存(保留安全边际) + # available_memory = free_memory * (1 - safety_margin) + available_memory = free_memory + + # 估算最大并发任务数 + max_tasks = int(available_memory // single_kvcache_size) + + print(f"[INFO] GPU显存信息:") + print(f" - 总显存: {total_memory / (1024**3):.2f} GB") + print(f" - 已使用: {used_memory / (1024**3):.2f} GB ({used_memory/total_memory*100:.1f}%)") + print(f" - 剩余显存: {free_memory / (1024**3):.2f} GB") + print(f" - 单个KVCache占用: {single_kvcache_size / (1024**3):.2f} 
GB") + print(f" - 估算最大并发任务数: {max_tasks}") + + # 确保不超过配置的最大批次大小 + recommended_tasks = min(max_tasks, MAX_BATCH) + return recommended_tasks + + except Exception as e: + print(f"[ERROR] Failed to estimate max concurrent tasks: {e}") + return MAX_BATCH + @contextlib.asynccontextmanager async def lifespan(app: FastAPI): # Startup app.state.model = JiugeForCauslLM(model_path, device_type, ndev, max_tokens=max_tokens) - app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH) + + # 模型加载完毕后,计算剩余显存最多能够支持几个任务同时创建KVCache + estimated_max_tasks = estimate_max_concurrent_tasks(app.state.model) + print(f"[INFO] 配置的MAX_BATCH: {MAX_BATCH}") + print(f"[INFO] 估算的最大并发任务数: {estimated_max_tasks}") + print('默认采用估算最大并发数') + + app.state.kv_cache_pool = KVCachePool(app.state.model, estimated_max_tasks) + # app.state.kv_cache_pool = KVCachePool(app.state.model) app.state.request_queue = janus.Queue() + initial_mem_usage = get_memory_usage() + print(f'[Info] initial memory usage: {initial_mem_usage}') + # 初始化动态批处理管理器 + app.state.batch_manager = DynamicBatchManager( + min_batch_size=1, + max_batch_size=estimated_max_tasks, + max_wait_time_ms=200, # 增加等待时间以允许更大的批次 + memory_threshold=0.9, # 调整内存阈值到90%,避免过早触发内存压力保护 + base_mem_usage=initial_mem_usage, + gpu_device_id=0 # 使用第一个GPU设备 + ) worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True) worker_thread.start() @@ -130,33 +277,135 @@ async def lifespan(app: FastAPI): App = FastAPI(lifespan=lifespan) +@App.get("/batch_stats") +async def get_batch_stats(): + """获取动态批处理统计信息""" + try: + stats = App.state.batch_manager.get_stats() + return JSONResponse(content={ + "status": "success", + "data": stats + }) + except Exception as e: + return JSONResponse( + status_code=500, + content={"status": "error", "message": str(e)} + ) + # App loop: take requests from the queue, do inference, and put unfinished requests back into the queue. 
def worker_loop(app): + """动态批处理工作循环""" + batch_manager = app.state.batch_manager + while True: try: - task = app.state.request_queue.sync_q.get(timeout=0.01) + # 尝试获取第一个任务 + task = app.state.request_queue.sync_q.get(timeout=0.05) except queue.Empty: continue if task is None: return + # 开始构建批次 batch = [task] - while len(batch) < MAX_BATCH: - try: - req = app.state.request_queue.sync_q.get_nowait() - if req is not None: - batch.append(req) - except queue.Empty: + batch_start_time = time.time() + + while True: + + queue_size = app.state.request_queue.sync_q.qsize() + + # 获取动态批处理建议 + suggested_batch_size = batch_manager.calculate_dynamic_batch_size( + queue_size + len(batch) + ) + + current_time = time.time() + wait_time_ms = (current_time - batch_start_time) * 1000 + # 判断是否应该处理当前批次 + should_process = batch_manager.should_process_batch( + len(batch),suggested_batch_size, queue_size, wait_time_ms + ) + + if should_process: break - output_tokens = app.state.model.batch_infer_one_round(batch) - for task, token in zip(batch, output_tokens): - task.output(token) - if task.finish_reason is None: - app.state.request_queue.sync_q.put(task) + + # 尝试添加更多任务到批次 + if len(batch) < suggested_batch_size: + try: + # 计算剩余等待时间,确保有足够时间收集更多请求 + remaining_wait_ms = max(0, batch_manager.current_wait_time_ms - wait_time_ms) + timeout_seconds = min(0.05, remaining_wait_ms / 1000.0) # 最多等待50ms + req = app.state.request_queue.sync_q.get(timeout=timeout_seconds) + if req is not None: + batch.append(req) + else: + break + except queue.Empty: + # 检查是否应该继续等待 + if wait_time_ms >= batch_manager.current_wait_time_ms: + break + time.sleep(0.001) # 短暂等待 else: - print(f"[INFO] Task {task.id} finished infer.") - app.state.kv_cache_pool.release_sync(task) + break + + # 执行批量推理 + if len(batch) > 0: + print(f"[DEBUG] Processing batch with size: {len(batch)}, suggested size was: {suggested_batch_size}") + infer_start_time = time.time() + + try: + output_tokens = app.state.model.batch_infer_one_round(batch) + + # 处理输出 + finished_tasks = 0 + for task, token in zip(batch, output_tokens): + task.output(token) + if task.finish_reason is None: + print(f"[DEBUG] Task {task.id} is not finished.") + app.state.request_queue.sync_q.put(task) + else: + print(f"[INFO] Task {task.id} finished infer.") + app.state.kv_cache_pool.release_sync(task) + finished_tasks += 1 + + # 如果有任务完成,检查是否需要清理显存 + if finished_tasks > 0: + current_memory = batch_manager.get_memory_usage() + if current_memory > 0.75: # 显存占用超过75%时主动清理 + print(f"[INFO] Memory usage before cleanup: {current_memory:.2%}") + batch_manager._force_memory_cleanup() + # 清理后再次检查显存使用率 + new_memory = batch_manager.get_memory_usage() + print(f"[INFO] Memory usage after cleanup: {new_memory:.2%}, freed: {(current_memory-new_memory)*100:.2f}%") + elif finished_tasks >= 3: # 每完成3个任务就强制清理一次 + print(f"[INFO] Periodic cleanup after {finished_tasks} tasks, memory usage: {current_memory:.2%}") + batch_manager._force_memory_cleanup() + + # 记录性能数据 + infer_end_time = time.time() + latency_ms = (infer_end_time - infer_start_time) * 1000 + throughput = len(batch) / (infer_end_time - infer_start_time) + + batch_manager.record_batch_performance( + len(batch), latency_ms, throughput + ) + + # 定期打印统计信息 + if len(batch_manager.batch_history) % 50 == 0: + stats = batch_manager.get_stats() + print(f"[INFO] Dynamic Batch Stats: optimal_batch={stats['current_optimal_batch']}, " + f"wait_time={stats['current_wait_time_ms']}ms, " + f"memory_usage={stats['memory_usage']:.4%}, " + f"avg_latency={stats['avg_latency']:.1f}ms, " 
+ f"avg_throughput={stats['avg_throughput']:.1f} req/s") + + except Exception as e: + print(f"[ERROR] Batch inference failed: {e}") + # 将任务重新放回队列 + for task in batch: + if task.finish_reason is None: + app.state.request_queue.sync_q.put(task) def build_task(id_, request_data, request: Request): @@ -182,6 +431,7 @@ async def chat_stream(id_, request_data, request: Request): try: infer_task = build_task(id_, request_data, request) await request.app.state.kv_cache_pool.acquire(infer_task) + print(f"[INFO] Task {infer_task.id} acquired kv cache.") # Initial empty content chunk = json.dumps( @@ -278,7 +528,7 @@ async def chat_completions(request: Request): return JSONResponse(content=response) if __name__ == "__main__": - uvicorn.run(App, host="0.0.0.0", port=8000) + uvicorn.run(App, host="0.0.0.0", port=8010) """ curl -N -H "Content-Type: application/json" \ diff --git a/scripts/test_perf.py b/scripts/test_perf.py index a6b26f3b..8dcb5b13 100644 --- a/scripts/test_perf.py +++ b/scripts/test_perf.py @@ -30,7 +30,7 @@ NUM_REQUESTS = 10 CONCURRENCY = 5 -API_URL = "http://127.0.0.1:8000" +API_URL = "http://127.0.0.1:8010" MODEL = "FM9G-7B" diff --git a/scripts/test_ppl.py b/scripts/test_ppl.py index 268a9f7d..aa1d38ea 100644 --- a/scripts/test_ppl.py +++ b/scripts/test_ppl.py @@ -10,7 +10,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str, required=True) - parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--port", type=int, default=8010) parser.add_argument("--endpoint", type=str, default="/completions") parser.add_argument("--chunk", type=int, default=512) args = parser.parse_args() diff --git a/src/allocator.hpp b/src/allocator.hpp index b3169fdb..9dca25f9 100644 --- a/src/allocator.hpp +++ b/src/allocator.hpp @@ -6,6 +6,9 @@ #include #include #include +#include +#include +#include class AllocatorBase { public: @@ -14,41 +17,140 @@ class AllocatorBase { virtual ~AllocatorBase() = default; }; +// 内存使用统计结构 +struct MemoryStats { + std::atomic total_allocated{0}; + std::atomic total_freed{0}; + std::atomic current_usage{0}; + std::atomic peak_usage{0}; + std::atomic allocation_count{0}; + std::atomic free_count{0}; + std::atomic fragmentation_events{0}; + + void recordAllocation(size_t size) { + total_allocated += size; + current_usage += size; + allocation_count++; + + size_t current = current_usage.load(); + size_t peak = peak_usage.load(); + while (current > peak && !peak_usage.compare_exchange_weak(peak, current)) { + peak = peak_usage.load(); + } + } + + void recordFree(size_t size) { + total_freed += size; + current_usage -= size; + free_count++; + } + + void recordFragmentation() { + fragmentation_events++; + } + + double getFragmentationRate() const { + size_t allocs = allocation_count.load(); + return allocs > 0 ? 
static_cast(fragmentation_events.load()) / allocs : 0.0; + } +}; + class MemoryPool : public AllocatorBase { public: static constexpr size_t DEFAULT_ALIGNMENT = 256; + static constexpr size_t SMALL_BLOCK_THRESHOLD = 1024; // 1KB + static constexpr size_t MEDIUM_BLOCK_THRESHOLD = 1024 * 1024; // 1MB + static constexpr size_t LARGE_BLOCK_THRESHOLD = 16 * 1024 * 1024; // 16MB + + // 预分配配置 + struct PreallocationConfig { + size_t small_pool_size; + size_t medium_pool_size; + size_t large_pool_size; + bool enable_preallocation; + + PreallocationConfig() : + small_pool_size(16 * 1024 * 1024), + medium_pool_size(128 * 1024 * 1024), + large_pool_size(512 * 1024 * 1024), + enable_preallocation(true) {} + }; - explicit MemoryPool(size_t initialSize = 0, size_t alignment = DEFAULT_ALIGNMENT); + explicit MemoryPool(size_t initialSize = 0, size_t alignment = DEFAULT_ALIGNMENT, + const PreallocationConfig& config = PreallocationConfig{}); ~MemoryPool(); void *alloc(size_t size) override; void release(void *ptr) override; - + + // 新增功能接口 + void defragment(); // 内存碎片整理 + const MemoryStats& getStats() const { return _stats; } + void printStats() const; + void preAllocate(const PreallocationConfig& config); // 预分配内存 + bool shouldDefragment() const; // 检查是否需要碎片整理 + size_t getAlignment() const { return _alignment; } + size_t getTotalMemory() const; + size_t getUsedMemory() const; + size_t getFreeMemory() const; + double getFragmentationRatio() const; private: + enum class BlockType { + SMALL, + MEDIUM, + LARGE + }; + struct Block { void *base; void *ptr; size_t size; bool is_free; + BlockType type; + std::chrono::steady_clock::time_point last_used; - Block(void *b, void *p, size_t s, bool f) - : base(b), ptr(p), size(s), is_free(f) {} + Block(void *b, void *p, size_t s, bool f, BlockType t = BlockType::MEDIUM) + : base(b), ptr(p), size(s), is_free(f), type(t), + last_used(std::chrono::steady_clock::now()) {} bool operator<(const Block &other) const { return ptr < other.ptr; } }; + + struct PoolInfo { + std::multimap::iterator> free_blocks; + size_t total_size = 0; + size_t used_size = 0; + }; - void *allocateNewRegion(size_t size); + BlockType getBlockType(size_t size) const; + void *allocateNewRegion(size_t size, BlockType type = BlockType::MEDIUM); void tryCoalesce(const Block &block); - + void *allocFromPool(size_t size, BlockType type); + void releaseToPool(void *ptr, const Block& block); + + // 碎片整理相关 + void compactPool(BlockType type); + + mutable std::mutex _mutex; // 线程安全 size_t _alignment; + PreallocationConfig _config; + std::vector _base_regions; std::set _all_blocks; - std::multimap::iterator> _free_blocks; std::unordered_map::iterator> _ptr_to_block; + + // 分层内存管理 + PoolInfo _pools[3]; // SMALL, MEDIUM, LARGE + + // 统计信息 + mutable MemoryStats _stats; + + // 预分配的内存区域 + std::vector _preallocated_regions; }; #endif diff --git a/src/allocator/memory_allocator.cpp b/src/allocator/memory_allocator.cpp index 003c01d4..2faea46a 100644 --- a/src/allocator/memory_allocator.cpp +++ b/src/allocator/memory_allocator.cpp @@ -1,22 +1,36 @@ #include "../allocator.hpp" #include "../utils.hpp" +#include +#include +#include -MemoryPool::MemoryPool(size_t initialSize, size_t alignment) - : _alignment(alignment) { +MemoryPool::MemoryPool(size_t initialSize, size_t alignment, const PreallocationConfig& config) + : _alignment(alignment), _config(config) { // Validate alignment is power of two if ((alignment & (alignment - 1)) != 0 || alignment == 0) { throw std::invalid_argument("Alignment must be a power of two"); } + 
// 预分配内存池 + if (_config.enable_preallocation) { + preAllocate(_config); + } + if (initialSize > 0) { allocateNewRegion(initialSize); } } MemoryPool::~MemoryPool() { + std::lock_guard lock(_mutex); + for (void *region : _base_regions) { RUN_INFINI(infinirtFree(region)); } + + for (void *region : _preallocated_regions) { + RUN_INFINI(infinirtFree(region)); + } } void *MemoryPool::alloc(size_t size) { @@ -24,44 +38,35 @@ void *MemoryPool::alloc(size_t size) { return nullptr; } + std::lock_guard lock(_mutex); + // Calculate aligned size const size_t aligned_size = (size + _alignment - 1) & ~(_alignment - 1); - - // Find the first block with enough space (after alignment) - auto it = _free_blocks.lower_bound(aligned_size); - if (it == _free_blocks.end()) { - allocateNewRegion(aligned_size); - it = _free_blocks.lower_bound(aligned_size); - if (it == _free_blocks.end()) { - throw std::bad_alloc(); - } + + // 确定块类型 + BlockType type = getBlockType(aligned_size); + + // 尝试从对应的池中分配 + void* ptr = allocFromPool(aligned_size, type); + + if (ptr) { + _stats.recordAllocation(aligned_size); + return ptr; } - - auto block_it = it->second; - Block block = *block_it; - _free_blocks.erase(it); - _all_blocks.erase(block_it); - - // Align the pointer within the block - size_t alignment_padding = reinterpret_cast(block.ptr) - reinterpret_cast(block.ptr); - - // Calculate remaining space after allocation - const size_t remaining = block.size - aligned_size - alignment_padding; - - // Create allocated block - Block alloc_block(block.base, block.ptr, aligned_size, false); - auto alloc_it = _all_blocks.insert(alloc_block).first; - _ptr_to_block[block.ptr] = alloc_it; - - // Split remaining space if it's large enough - if (remaining >= _alignment) { - void *rem_ptr = static_cast(block.ptr) + aligned_size; - Block rem_block(block.base, rem_ptr, remaining, true); - auto rem_it = _all_blocks.insert(rem_block).first; - _free_blocks.emplace(remaining, rem_it); + + // 如果池中没有合适的块,分配新区域 + allocateNewRegion(std::max(aligned_size * 2, + type == BlockType::SMALL ? _config.small_pool_size / 4 : + type == BlockType::MEDIUM ? 
_config.medium_pool_size / 4 : + _config.large_pool_size / 4), type); + + ptr = allocFromPool(aligned_size, type); + if (!ptr) { + throw std::bad_alloc(); } - - return block.ptr; + + _stats.recordAllocation(aligned_size); + return ptr; } void MemoryPool::release(void *ptr) { @@ -69,6 +74,8 @@ void MemoryPool::release(void *ptr) { return; } + std::lock_guard lock(_mutex); + auto it = _ptr_to_block.find(ptr); if (it == _ptr_to_block.end()) { throw std::runtime_error("Invalid pointer to free"); @@ -76,30 +83,94 @@ void MemoryPool::release(void *ptr) { auto block_it = it->second; Block block = *block_it; - _all_blocks.erase(block_it); - block.is_free = true; - auto new_it = _all_blocks.insert(block).first; - _ptr_to_block.erase(ptr); - tryCoalesce(*new_it); + + _stats.recordFree(block.size); + + releaseToPool(ptr, block); } -void *MemoryPool::allocateNewRegion(size_t size) { +void *MemoryPool::allocateNewRegion(size_t size, BlockType type) { // Allocate exactly the requested size void *ptr = nullptr; RUN_INFINI(infinirtMalloc(&ptr, size)); _base_regions.push_back(ptr); // Align the pointer within the allocated region - size_t alignment_padding = reinterpret_cast(ptr) - reinterpret_cast(ptr); + size_t alignment_padding = 0; // ptr is already aligned from infinirtMalloc size_t usable_size = size - alignment_padding; - Block new_block(ptr, ptr, usable_size, true); + Block new_block(ptr, ptr, usable_size, true, type); auto it = _all_blocks.insert(new_block).first; - _free_blocks.emplace(usable_size, it); + + int pool_idx = static_cast(type); + _pools[pool_idx].free_blocks.emplace(usable_size, it); + _pools[pool_idx].total_size += usable_size; return ptr; } +MemoryPool::BlockType MemoryPool::getBlockType(size_t size) const { + if (size <= SMALL_BLOCK_THRESHOLD) { + return BlockType::SMALL; + } else if (size <= MEDIUM_BLOCK_THRESHOLD) { + return BlockType::MEDIUM; + } else { + return BlockType::LARGE; + } +} + +void *MemoryPool::allocFromPool(size_t size, BlockType type) { + int pool_idx = static_cast(type); + auto& pool = _pools[pool_idx]; + + // Find the first block with enough space + auto it = pool.free_blocks.lower_bound(size); + if (it == pool.free_blocks.end()) { + return nullptr; + } + + auto block_it = it->second; + Block block = *block_it; + pool.free_blocks.erase(it); + _all_blocks.erase(block_it); + + // Calculate remaining space after allocation + const size_t remaining = block.size - size; + + // Create allocated block + Block alloc_block(block.base, block.ptr, size, false, type); + auto alloc_it = _all_blocks.insert(alloc_block).first; + _ptr_to_block[block.ptr] = alloc_it; + + pool.used_size += size; + + // Split remaining space if it's large enough + if (remaining >= _alignment) { + void *rem_ptr = static_cast(block.ptr) + size; + Block rem_block(block.base, rem_ptr, remaining, true, type); + auto rem_it = _all_blocks.insert(rem_block).first; + pool.free_blocks.emplace(remaining, rem_it); + } + + return block.ptr; +} + +void MemoryPool::releaseToPool(void *ptr, const Block& block) { + _all_blocks.erase(_ptr_to_block[ptr]); + _ptr_to_block.erase(ptr); + + Block free_block = block; + free_block.is_free = true; + free_block.last_used = std::chrono::steady_clock::now(); + + auto new_it = _all_blocks.insert(free_block).first; + + int pool_idx = static_cast(block.type); + _pools[pool_idx].used_size -= block.size; + + tryCoalesce(*new_it); +} + void MemoryPool::tryCoalesce(const Block &block) { auto it = _all_blocks.find(block); if (it == _all_blocks.end()) { @@ -111,25 +182,203 @@ void 
MemoryPool::tryCoalesce(const Block &block) { auto prev = (it == _all_blocks.begin()) ? _all_blocks.end() : std::prev(it); _all_blocks.erase(it); - _free_blocks.erase(merged.size); + + int pool_idx = static_cast(merged.type); + auto& pool = _pools[pool_idx]; + + // Remove from free blocks + auto range = pool.free_blocks.equal_range(merged.size); + for (auto free_it = range.first; free_it != range.second; ++free_it) { + if (free_it->second->ptr == merged.ptr) { + pool.free_blocks.erase(free_it); + break; + } + } // Coalesce with next - if (next != _all_blocks.end() && next->is_free && static_cast(merged.ptr) + merged.size == next->ptr) { - _free_blocks.erase(next->size); + if (next != _all_blocks.end() && next->is_free && + next->type == merged.type && + static_cast(merged.ptr) + merged.size == next->ptr) { + + // Remove next from free blocks + auto next_range = pool.free_blocks.equal_range(next->size); + for (auto free_it = next_range.first; free_it != next_range.second; ++free_it) { + if (free_it->second->ptr == next->ptr) { + pool.free_blocks.erase(free_it); + break; + } + } + merged.size += next->size; _all_blocks.erase(next); + _stats.recordFragmentation(); } // Coalesce with prev - if (prev != _all_blocks.end() && prev->is_free && static_cast(prev->ptr) + prev->size == merged.ptr) { - _free_blocks.erase(prev->size); + if (prev != _all_blocks.end() && prev->is_free && + prev->type == merged.type && + static_cast(prev->ptr) + prev->size == merged.ptr) { + + // Remove prev from free blocks + auto prev_range = pool.free_blocks.equal_range(prev->size); + for (auto free_it = prev_range.first; free_it != prev_range.second; ++free_it) { + if (free_it->second->ptr == prev->ptr) { + pool.free_blocks.erase(free_it); + break; + } + } + merged.ptr = prev->ptr; merged.size += prev->size; merged.base = prev->base; _all_blocks.erase(prev); + _stats.recordFragmentation(); } merged.is_free = true; auto new_it = _all_blocks.insert(merged).first; - _free_blocks.emplace(merged.size, new_it); + pool.free_blocks.emplace(merged.size, new_it); +} + +void MemoryPool::preAllocate(const PreallocationConfig& config) { + if (config.small_pool_size > 0) { + void* ptr = nullptr; + RUN_INFINI(infinirtMalloc(&ptr, config.small_pool_size)); + _preallocated_regions.push_back(ptr); + + Block small_block(ptr, ptr, config.small_pool_size, true, BlockType::SMALL); + auto it = _all_blocks.insert(small_block).first; + _pools[0].free_blocks.emplace(config.small_pool_size, it); + _pools[0].total_size += config.small_pool_size; + } + + if (config.medium_pool_size > 0) { + void* ptr = nullptr; + RUN_INFINI(infinirtMalloc(&ptr, config.medium_pool_size)); + _preallocated_regions.push_back(ptr); + + Block medium_block(ptr, ptr, config.medium_pool_size, true, BlockType::MEDIUM); + auto it = _all_blocks.insert(medium_block).first; + _pools[1].free_blocks.emplace(config.medium_pool_size, it); + _pools[1].total_size += config.medium_pool_size; + } + + if (config.large_pool_size > 0) { + void* ptr = nullptr; + RUN_INFINI(infinirtMalloc(&ptr, config.large_pool_size)); + _preallocated_regions.push_back(ptr); + + Block large_block(ptr, ptr, config.large_pool_size, true, BlockType::LARGE); + auto it = _all_blocks.insert(large_block).first; + _pools[2].free_blocks.emplace(config.large_pool_size, it); + _pools[2].total_size += config.large_pool_size; + } +} + +void MemoryPool::defragment() { + std::lock_guard lock(_mutex); + + for (int i = 0; i < 3; ++i) { + compactPool(static_cast(i)); + } +} + +void MemoryPool::compactPool(BlockType type) { 
+ int pool_idx = static_cast(type); + auto& pool = _pools[pool_idx]; + + // 收集所有空闲块 + std::vector::iterator> free_blocks; + for (auto& pair : pool.free_blocks) { + free_blocks.push_back(pair.second); + } + + // 按地址排序 + std::sort(free_blocks.begin(), free_blocks.end(), + [](const auto& a, const auto& b) { + return a->ptr < b->ptr; + }); + + // 尝试合并相邻的空闲块 + for (size_t i = 0; i < free_blocks.size(); ++i) { + auto current = free_blocks[i]; + if (current == _all_blocks.end()) continue; + + for (size_t j = i + 1; j < free_blocks.size(); ++j) { + auto next = free_blocks[j]; + if (next == _all_blocks.end()) continue; + + if (static_cast(current->ptr) + current->size == next->ptr) { + // 合并块 + Block merged = *current; + merged.size += next->size; + + // 从池中移除旧块 + pool.free_blocks.erase(current->size); + pool.free_blocks.erase(next->size); + + _all_blocks.erase(current); + _all_blocks.erase(next); + + // 插入新的合并块 + auto new_it = _all_blocks.insert(merged).first; + pool.free_blocks.emplace(merged.size, new_it); + + free_blocks[i] = new_it; + free_blocks[j] = _all_blocks.end(); + + current = new_it; + } + } + } +} + +bool MemoryPool::shouldDefragment() const { + return _stats.getFragmentationRate() > 0.3; // 30%碎片率阈值 +} + +void MemoryPool::printStats() const { + std::cout << "\n=== Memory Pool Statistics ===\n"; + std::cout << "Total Allocated: " << _stats.total_allocated.load() / (1024.0 * 1024.0) << " MB\n"; + std::cout << "Total Freed: " << _stats.total_freed.load() / (1024.0 * 1024.0) << " MB\n"; + std::cout << "Current Usage: " << _stats.current_usage.load() / (1024.0 * 1024.0) << " MB\n"; + std::cout << "Peak Usage: " << _stats.peak_usage.load() / (1024.0 * 1024.0) << " MB\n"; + std::cout << "Allocation Count: " << _stats.allocation_count.load() << "\n"; + std::cout << "Free Count: " << _stats.free_count.load() << "\n"; + std::cout << "Fragmentation Rate: " << std::fixed << std::setprecision(2) + << _stats.getFragmentationRate() * 100 << "%\n"; + + std::cout << "\n=== Pool Information ===\n"; + const char* pool_names[] = {"Small", "Medium", "Large"}; + for (int i = 0; i < 3; ++i) { + std::cout << pool_names[i] << " Pool:\n"; + std::cout << " Total Size: " << _pools[i].total_size / (1024.0 * 1024.0) << " MB\n"; + std::cout << " Used Size: " << _pools[i].used_size / (1024.0 * 1024.0) << " MB\n"; + std::cout << " Free Blocks: " << _pools[i].free_blocks.size() << "\n"; + std::cout << " Utilization: " << std::fixed << std::setprecision(2) + << (_pools[i].total_size > 0 ? 
+ static_cast(_pools[i].used_size) / _pools[i].total_size * 100 : 0) + << "%\n\n"; + } +} + +size_t MemoryPool::getTotalMemory() const { + std::lock_guard lock(_mutex); + size_t total = 0; + for (int i = 0; i < 3; ++i) { + total += _pools[i].total_size; + } + return total; +} + +size_t MemoryPool::getUsedMemory() const { + return _stats.current_usage.load(); +} + +size_t MemoryPool::getFreeMemory() const { + return getTotalMemory() - getUsedMemory(); +} + +double MemoryPool::getFragmentationRatio() const { + return _stats.getFragmentationRate(); } diff --git a/src/models/jiuge/jiuge.cpp b/src/models/jiuge/jiuge.cpp index bafe784e..c48714e6 100644 --- a/src/models/jiuge/jiuge.cpp +++ b/src/models/jiuge/jiuge.cpp @@ -42,7 +42,14 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta, getFFNDown(meta, weights, layer, idev, ndev)); } - auto memory_pool = std::make_shared(128 * 1024 * 1024); + // 配置内存池预分配策略 + MemoryPool::PreallocationConfig pool_config; + pool_config.small_pool_size = 32 * 1024 * 1024; // 32MB for small allocations + pool_config.medium_pool_size = 64 * 1024 * 1024; // 64MB for medium allocations + pool_config.large_pool_size = 128 * 1024 * 1024; // 128MB for large allocations + + auto memory_pool = std::make_shared(128 * 1024 * 1024, MemoryPool::DEFAULT_ALIGNMENT, pool_config); + *rsrc = DeviceResource{ device, @@ -118,6 +125,16 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, struct KVCache **kv_caches, const float *temperature, const uint32_t *topk, const float *topp, uint32_t *output, void *last_logits) { + + // 推理开始前检查内存状态 + static int inference_count = 0; + inference_count++; + + // 每10次推理检查一次碎片率,必要时进行碎片整理 + if (inference_count % 10 == 0 && rsrc.memory_pool->shouldDefragment()) { + rsrc.memory_pool->defragment(); + } + auto nlayer = meta.nlayer; auto nkvh = meta.nkvh / ndev; auto nh = meta.nh / ndev; @@ -131,6 +148,11 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, auto stream = rsrc.stream; bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0; + // sparse attention + auto ratio = 0.2; + int attentionSinkWindow = 4; + bool sparseOn = true; + // Allocate buffers auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); auto logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); @@ -213,19 +235,40 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}}); // self attention - // concat - rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k); - rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v); - // qk - rearrange(q_rearrange->slice(2, 0, seq_len), q); - auto qk_gemm = qk_buf->slice(1, 0, seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len}); - auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}); - linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); - // softmax - auto qk_softmax = qk_buf->slice(1, 0, seq_len * total_len)->view({nh, seq_len, total_len}); - causalSoftmax(qk_softmax, qk_softmax); - auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}); - linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr); + int recentWindow = (int) ((seq_len - attentionSinkWindow) * ratio); + bool prune = (past_len == 0) && sparseOn && recentWindow >= 1; + if (prune) { // streamingLLM + // 
concat attentionSinkWindow + recentWindow to kv cache + rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, attentionSinkWindow), k->slice(0,0,attentionSinkWindow)); + rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len + attentionSinkWindow, recentWindow), k->slice(0,seq_len - recentWindow,recentWindow)); + rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, attentionSinkWindow), v->slice(0,0,attentionSinkWindow)); + rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len + attentionSinkWindow, recentWindow), v->slice(0,seq_len - recentWindow,recentWindow)); + // qk + rearrange(q_rearrange->slice(2, 0, seq_len), q); + auto qk_gemm = qk_buf->slice(1, 0, seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len}); + auto k_gemm = k->permute({1, 2, 0}); //K^T + linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); + // softmax + auto qk_softmax = qk_buf->slice(1, 0, seq_len * total_len)->view({nh, seq_len, total_len}); + causalSoftmax(qk_softmax, qk_softmax); + auto v_gemm = v->permute({1, 0, 2}); + linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr); + } else { // baseline + // concat full kv to kv cache + rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k); + rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v); + // qk + rearrange(q_rearrange->slice(2, 0, seq_len), q); + auto qk_gemm = qk_buf->slice(1, 0, seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len}); + auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}); //^T + linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); + // softmax + auto qk_softmax = qk_buf->slice(1, 0, seq_len * total_len)->view({nh, seq_len, total_len}); + causalSoftmax(qk_softmax, qk_softmax); + auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}); + linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr); + } + // rearrange attn val rearrange(o, attn_val_gemm->slice(2, 0, seq_len));