diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index d19485e83e..bdc628e272 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -1042,7 +1042,8 @@ def sleep(self, level: int = 1):
         self.cache_engine = None
         self.reset_graph_runner()
         device = 'cpu' if level == 1 else 'meta'
-        self.patched_model.get_model().to(device=device)
+        self.patched_model.get_model().to(device=device, non_blocking=True)
+        torch.cuda.synchronize()
         torch.cuda.empty_cache()

     @torch.inference_mode()
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index eed7abe630..afa1c6be61 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -332,6 +332,8 @@ def __init__(self,
         else:
             raise ValueError(f'unsupported backend {backend}')
         self.backend_config = self.engine.engine_config
+        self.is_sleeping = backend_config.empty_init
+        self.sleeping_tags: set[str] = set() if not backend_config.empty_init else {'weights', 'kv_cache'}
         logger.info(f'updated backend_config={self.backend_config}')

         # parameters for member functions
@@ -477,6 +479,8 @@ def sleep(self, level: int = 1):
         discard both the model weights and the kv cache.
         """
         self.engine.sleep(level)
+        self.sleeping_tags = {'weights', 'kv_cache'}
+        self.is_sleeping = True

     def wakeup(self, tags: Optional[List[str]] = None):
         """Wake up the model.
@@ -488,11 +492,17 @@ def wakeup(self, tags: Optional[List[str]] = None):
         wake_up should be called with all tags (or None) before the
         engine is used again.
         """
+        tags = tags or list(self.sleeping_tags)
+        if any(tag not in self.sleeping_tags for tag in tags):
+            logger.warning(f'some tag in {tags} not in sleeping tags {self.sleeping_tags}')
+            return
         self.engine.wakeup(tags)
         # for TM backend, sleep/wakeup will reset gateway, therefore we need to rebuild instance
-        if self.backend == 'turbomind' and (tags is None or 'kv_cache' in tags):
+        if self.backend == 'turbomind' and 'kv_cache' in tags:
             self.instances = [self.engine.create_instance() for _ in range(self.instance_num)]
             self.free_insts = None
+        self.sleeping_tags = self.sleeping_tags - set(tags)
+        self.is_sleeping = bool(self.sleeping_tags)

     def _get_limiter(self):
         if not self.limiter:
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index 75472399fd..8ca922d362 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -1105,6 +1105,12 @@ async def wakeup(raw_request: Request = None):
     return Response(status_code=200)


+@router.get('/is_sleeping', dependencies=[Depends(check_api_key)])
+async def is_sleeping(raw_request: Request = None):
+    is_sleeping = VariableInterface.async_engine.is_sleeping
+    return JSONResponse(content={'is_sleeping': is_sleeping})
+
+
 """
 PD Disaggregation API Begin
 """
diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index 173a95e391..052a900778 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -213,6 +213,7 @@ def _get_model_params(self):
         model_comm = self.model_comm
         tm_params = self._tm_model.tm_params
+        tm_params.clear()

         def _get_params(device_id, que):
             rank = self.node_id * self.gpu_count + device_id
diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc
index 8514c92e2c..f46124798a 100644
--- a/src/turbomind/models/llama/LlamaWeight.cc
+++ b/src/turbomind/models/llama/LlamaWeight.cc
@@ -116,6 +116,7 @@ void LlamaWeight::release()
     }

     decoder_layer_weights.clear();
+    pinned_weights_.clear();

     // Wait for deallocations
     core::Context::stream().Sync();
@@ -127,21 +128,22 @@ void LlamaWeight::release()
 void LlamaWeight::to_device(const core::Device& device)
 {
-    core::ContextGuard guard = context();
-
-    auto to_device = [&](Tensor& x) -> Tensor {
-        auto tmp = std::exchange(x, empty_like(x, device));
-        Copy(tmp, x);
-        return tmp;
-    };
-
-    std::vector<Tensor> tmp_cpu_tensors;
+    TM_CHECK(device.type == kCPU || device.type == kDEVICE);
+    core::ContextGuard guard{stream_, alloca_, Allocator{kCPUpinned}};

     auto tensor_ptr_map = get_parameters();
     for (auto& [name, tensor_ptr] : tensor_ptr_map) {
-        auto tmp_tensor = to_device(*tensor_ptr);
-        if (tmp_tensor.device().type != kDEVICE) {
-            tmp_cpu_tensors.push_back(tmp_tensor);
+        if (device.type == kCPU) {
+            if (pinned_weights_.find(name) == pinned_weights_.end()) {
+                pinned_weights_[name] = empty_like(*tensor_ptr, kCPUpinned);
+                Copy(*tensor_ptr, pinned_weights_[name]);
+            }
+            *tensor_ptr = {};
+        }
+        else {
+            TM_CHECK(pinned_weights_.find(name) != pinned_weights_.end());
+            *tensor_ptr = empty_like(pinned_weights_[name], kDEVICE);
+            Copy(pinned_weights_[name], *tensor_ptr);
         }
     }

     core::Context::stream().Sync();
diff --git a/src/turbomind/models/llama/LlamaWeight.h b/src/turbomind/models/llama/LlamaWeight.h
index 824bc7e10a..22d71be138 100644
--- a/src/turbomind/models/llama/LlamaWeight.h
+++ b/src/turbomind/models/llama/LlamaWeight.h
@@ -20,6 +20,8 @@

 #pragma once

+#include <unordered_map>
+
 #include "src/turbomind/core/context.h"
 #include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
 #include "src/turbomind/models/llama/LlamaDenseWeight.h"
@@ -75,6 +77,8 @@
     DataType data_type_;
     DataType weight_type_;

+    std::unordered_map<std::string, Tensor> pinned_weights_;
+
     int tp_size_;  // this will follow attn tp param
     int tp_rank_;
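
Note: below is a minimal usage sketch of the sleep/wakeup bookkeeping added above, assuming an lmdeploy AsyncEngine instance bound to the name `engine` (the name is illustrative; `sleep`, `wakeup`, `is_sleeping` and `sleeping_tags` are the attributes introduced in this diff):

    engine.sleep(level=1)            # offload weights to host memory, drop the kv cache
    assert engine.is_sleeping
    assert engine.sleeping_tags == {'weights', 'kv_cache'}

    engine.wakeup(tags=['weights'])  # restore the weights only
    assert engine.sleeping_tags == {'kv_cache'}
    assert engine.is_sleeping        # still partially asleep

    engine.wakeup()                  # tags default to the remaining sleeping tags
    assert not engine.is_sleeping

The new `/is_sleeping` route in api_server.py exposes the same flag over HTTP as a JSON body of the form {'is_sleeping': <bool>}. On the TurboMind side, a level-1 sleep stages each weight tensor into the pinned host buffers kept in `pinned_weights_`, so a later wakeup can copy them back to the device instead of reloading from disk.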