diff --git a/benchmark/benchmark_serving.py b/benchmark/benchmark_serving.py
index 58e8c70054..2527507d33 100644
--- a/benchmark/benchmark_serving.py
+++ b/benchmark/benchmark_serving.py
@@ -13,14 +13,15 @@ def get_launching_server_cmd(model_path, backend, server_config):
     elif backend == 'sglang':
         cmd = ['python3', '-m', 'sglang.launch_server', '--model-path', model_path]
     elif backend == 'vllm':
-        cmd = ['vllm', 'serve', '--model', model_path]
+        cmd = ['vllm', 'serve', model_path]
     else:
         raise ValueError(f'unknown backend: {backend}')
     for key, value in server_config.items():
         # Convert snake_case to kebab-case for command line args
         key = key.replace('_', '-')
         cmd.append(f'--{key}')
-        cmd.append(str(value))
+        if str(value):
+            cmd.append(str(value))
     # Special handling for proxy server case
     if server_config.get('proxy_url') and server_config.get('dp'):
         cmd.append('--allow-terminate-by-client')
@@ -66,9 +67,9 @@ def get_server_ip_port(backend: str, server_config: Dict) -> Tuple[str, int]:
         server_ip = server_config.get('server_ip', '0.0.0.0')
         server_port = server_config.get('server_port', 23333)
     elif backend == 'sglang':
-        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('server_port', 30000))
+        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('port', 30000))
     elif backend == 'vllm':
-        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('server_port', 8000))
+        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('port', 8000))
     else:
         raise ValueError(f'unknown backend: {backend}')
     return server_ip, server_port
@@ -131,7 +132,7 @@ def benchmark(model_path: str, backend: str, server_config: Dict, data_config: D
 
     try:
-        print(f"Starting api_server: {' '.join(server_cmd)}")
+        print(f"Starting api_server: {' '.join(server_cmd)}", flush=True)
         proc = subprocess.Popen(server_cmd)
         # Wait for the server to be ready
         wait_server_ready(server_ip, server_port)
diff --git a/docs/en/advance/spec_decoding.md b/docs/en/advance/spec_decoding.md
new file mode 100644
index 0000000000..49d813d775
--- /dev/null
+++ b/docs/en/advance/spec_decoding.md
@@ -0,0 +1,104 @@
+# Speculative Decoding
+
+Speculative decoding is an optimization technique that introduces a lightweight draft model to propose several candidate next tokens; the main model then verifies them in a single forward pass and accepts the longest matching prefix. Compared with standard auto-regressive decoding, this method lets the system generate multiple tokens per step.
+
+> \[!NOTE\]
+> This is an experimental feature in lmdeploy.
+
+## Examples
+
+Here are some examples.
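+
+Conceptually, each decoding step is a propose-and-verify loop. The sketch below is only an illustration of that idea under greedy decoding; `draft_model` and `target_model` are hypothetical placeholder callables, not lmdeploy APIs, and it does not mirror the engine's internal implementation:
+
+```python
+def speculative_step(target_model, draft_model, tokens, num_speculative_tokens):
+    """One propose-and-verify step (toy sketch, greedy decoding assumed)."""
+    # 1. The draft model proposes `num_speculative_tokens` tokens auto-regressively.
+    draft = list(tokens)
+    for _ in range(num_speculative_tokens):
+        draft.append(draft_model(draft))
+    proposed = draft[len(tokens):]
+
+    # 2. The main model scores every position in one forward pass and accepts the
+    #    longest prefix of proposed tokens that matches its own predictions.
+    verified = target_model(draft)  # verified[i]: next-token prediction after draft[:i + 1]
+    accepted = []
+    for i, token in enumerate(proposed):
+        expected = verified[len(tokens) + i - 1]
+        if token != expected:
+            accepted.append(expected)  # first mismatch: keep the main model's token instead
+            break
+        accepted.append(token)
+    else:
+        accepted.append(verified[-1])  # every draft token accepted, plus one bonus token
+    return list(tokens) + accepted
+```
+
+In lmdeploy this loop runs inside the PyTorch engine; you only select a draft method and model through `SpeculativeConfig` (or the corresponding `lmdeploy serve` options), as the examples below show.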
+
+### Eagle 3
+
+#### Prepare
+
+Install [FlashAttention-3](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release)
+
+```shell
+git clone --depth=1 https://github.com/Dao-AILab/flash-attention.git
+cd flash-attention/hopper
+python setup.py install
+```
+
+#### Pipeline
+
+```python
+from lmdeploy import pipeline, PytorchEngineConfig
+from lmdeploy.messages import SpeculativeConfig
+
+
+if __name__ == '__main__':
+
+    model_path = 'meta-llama/Llama-3.1-8B-Instruct'
+    spec_cfg = SpeculativeConfig(method='eagle3',
+                                 num_speculative_tokens=3,
+                                 model='yuhuili/EAGLE3-LLaMA3.1-Instruct-8B',
+                                 )
+    pipe = pipeline(model_path,
+                    backend_config=PytorchEngineConfig(max_batch_size=128),
+                    speculative_config=spec_cfg)
+    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
+    print(response)
+```
+
+#### Serving
+
+```shell
+lmdeploy serve api_server \
+meta-llama/Llama-3.1-8B-Instruct \
+--backend pytorch \
+--server-port 24545 \
+--speculative-draft-model yuhuili/EAGLE3-LLaMA3.1-Instruct-8B \
+--speculative-algorithm eagle3 \
+--speculative-num-draft-tokens 3 \
+--max-batch-size 128 \
+--enable-metrics
+```
+
+### Deepseek MTP
+
+#### Prepare
+
+Install [FlashMLA](https://github.com/deepseek-ai/FlashMLA?tab=readme-ov-file#installation)
+
+```shell
+git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
+cd flash-mla
+git submodule update --init --recursive
+pip install -v .
+```
+
+#### Pipeline
+
+```python
+from lmdeploy import pipeline, PytorchEngineConfig
+from lmdeploy.messages import SpeculativeConfig
+
+
+if __name__ == '__main__':
+
+    model_path = 'deepseek-ai/DeepSeek-V3'
+    spec_cfg = SpeculativeConfig(method='deepseek_mtp',
+                                 num_speculative_tokens=3,
+                                 )
+    pipe = pipeline(model_path,
+                    backend_config=PytorchEngineConfig(tp=16, max_batch_size=128),
+                    speculative_config=spec_cfg)
+    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
+    print(response)
+```
+
+#### Serving
+
+```shell
+lmdeploy serve api_server \
+deepseek-ai/DeepSeek-V3 \
+--backend pytorch \
+--server-port 24545 \
+--tp 16 \
+--speculative-algorithm deepseek_mtp \
+--speculative-num-draft-tokens 3 \
+--max-batch-size 128 \
+--enable-metrics
+```
diff --git a/docs/en/index.rst b/docs/en/index.rst
index b64c230cb8..b346be6342 100644
--- a/docs/en/index.rst
+++ b/docs/en/index.rst
@@ -103,6 +103,7 @@ Documentation
    advance/pytorch_multinodes.md
    advance/pytorch_profiling.md
    advance/metrics.md
+   advance/spec_decoding.md
 
 .. toctree::
    :maxdepth: 1
diff --git a/docs/zh_cn/advance/spec_decoding.md b/docs/zh_cn/advance/spec_decoding.md
new file mode 100644
index 0000000000..56907f0e5b
--- /dev/null
+++ b/docs/zh_cn/advance/spec_decoding.md
@@ -0,0 +1,104 @@
+# Speculative Decoding
+
+推测解码是一种优化技术,它通过引入轻量级草稿模型来预测多个后续token,再由主模型在一次前向推理中验证并接受其中最长匹配的token前缀。与标准的自回归解码相比,这种方法可使系统一次性生成多个token。
+
+> \[!NOTE\]
+> 请注意,这是lmdeploy中的实验性功能。
+
+## 示例
+
+请参考如下使用示例。
+
+### Eagle 3
+
+#### 安装依赖
+
+安装 [FlashAttention-3](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release)
+
+```shell
+git clone --depth=1 https://github.com/Dao-AILab/flash-attention.git
+cd flash-attention/hopper
+python setup.py install
+```
+
+#### Pipeline
+
+```python
+from lmdeploy import pipeline, PytorchEngineConfig
+from lmdeploy.messages import SpeculativeConfig
+
+
+if __name__ == '__main__':
+
+    model_path = 'meta-llama/Llama-3.1-8B-Instruct'
+    spec_cfg = SpeculativeConfig(method='eagle3',
+                                 num_speculative_tokens=3,
+                                 model='yuhuili/EAGLE3-LLaMA3.1-Instruct-8B',
+                                 )
+    pipe = pipeline(model_path,
+                    backend_config=PytorchEngineConfig(max_batch_size=128),
+                    speculative_config=spec_cfg)
+    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
+    print(response)
+```
+
+#### Serving
+
+```shell
+lmdeploy serve api_server \
+meta-llama/Llama-3.1-8B-Instruct \
+--backend pytorch \
+--server-port 24545 \
+--speculative-draft-model yuhuili/EAGLE3-LLaMA3.1-Instruct-8B \
+--speculative-algorithm eagle3 \
+--speculative-num-draft-tokens 3 \
+--max-batch-size 128 \
+--enable-metrics
+```
+
+### Deepseek MTP
+
+#### 安装依赖
+
+安装 [FlashMLA](https://github.com/deepseek-ai/FlashMLA?tab=readme-ov-file#installation)
+
+```shell
+git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
+cd flash-mla
+git submodule update --init --recursive
+pip install -v .
+```
+
+#### Pipeline
+
+```python
+from lmdeploy import pipeline, PytorchEngineConfig
+from lmdeploy.messages import SpeculativeConfig
+
+
+if __name__ == '__main__':
+
+    model_path = 'deepseek-ai/DeepSeek-V3'
+    spec_cfg = SpeculativeConfig(method='deepseek_mtp',
+                                 num_speculative_tokens=3,
+                                 )
+    pipe = pipeline(model_path,
+                    backend_config=PytorchEngineConfig(tp=16, max_batch_size=128),
+                    speculative_config=spec_cfg)
+    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
+    print(response)
+```
+
+#### Serving
+
+```shell
+lmdeploy serve api_server \
+deepseek-ai/DeepSeek-V3 \
+--backend pytorch \
+--server-port 24545 \
+--tp 16 \
+--speculative-algorithm deepseek_mtp \
+--speculative-num-draft-tokens 3 \
+--max-batch-size 128 \
+--enable-metrics
+```
diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst
index bd946ba96e..be0d3cc788 100644
--- a/docs/zh_cn/index.rst
+++ b/docs/zh_cn/index.rst
@@ -104,6 +104,7 @@ LMDeploy 工具箱提供以下核心功能:
    advance/pytorch_multinodes.md
    advance/pytorch_profiling.md
    advance/metrics.md
+   advance/spec_decoding.md
 
 ..
toctree:: :maxdepth: 1 diff --git a/lmdeploy/api.py b/lmdeploy/api.py index 4085c3ba03..dc818e4801 100644 --- a/lmdeploy/api.py +++ b/lmdeploy/api.py @@ -3,7 +3,7 @@ from typing import List, Literal, Optional, Union from .archs import autoget_backend_config, get_task -from .messages import PytorchEngineConfig, TurbomindEngineConfig +from .messages import PytorchEngineConfig, SpeculativeConfig, TurbomindEngineConfig from .model import ChatTemplateConfig @@ -12,6 +12,7 @@ def pipeline(model_path: str, chat_template_config: Optional[ChatTemplateConfig] = None, log_level: str = 'WARNING', max_log_len: int = None, + speculative_config: SpeculativeConfig = None, **kwargs): """ Args: @@ -68,6 +69,12 @@ def pipeline(model_path: str, if backend_config is not None else None model_path = get_model(model_path, download_dir, revision) + # spec model + if speculative_config is not None and speculative_config.model and not os.path.exists(speculative_config.model): + download_dir = backend_config.download_dir \ + if backend_config is not None else None + speculative_config.model = get_model(speculative_config.model, download_dir) + _, pipeline_class = get_task(model_path) if not isinstance(backend_config, PytorchEngineConfig): # set auto backend mode @@ -80,6 +87,7 @@ def pipeline(model_path: str, backend_config=backend_config, chat_template_config=chat_template_config, max_log_len=max_log_len, + speculative_config=speculative_config, **kwargs) diff --git a/lmdeploy/cli/cli.py b/lmdeploy/cli/cli.py index d71198791f..1e635a8df8 100644 --- a/lmdeploy/cli/cli.py +++ b/lmdeploy/cli/cli.py @@ -3,7 +3,8 @@ import os from ..version import __version__ -from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, FlexibleArgumentParser, convert_args +from .utils import (ArgumentHelper, DefaultsAndTypesHelpFormatter, FlexibleArgumentParser, convert_args, + get_speculative_config) class CLI(object): @@ -44,12 +45,12 @@ def add_parser_chat(): ', "baichuan-inc/baichuan2-7b-chat" and so on') # common args ArgumentHelper.backend(parser) - # # chat template args + # chat template args ArgumentHelper.chat_template(parser) # model args ArgumentHelper.revision(parser) ArgumentHelper.download_dir(parser) - # + # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') ArgumentHelper.adapters(pt_group) @@ -77,6 +78,9 @@ def add_parser_chat(): ArgumentHelper.rope_scaling_factor(tb_group) ArgumentHelper.communicator(tb_group) + # speculative decoding + ArgumentHelper.add_spec_group(parser) + @staticmethod def add_parser_checkenv(): """Add parser for check_env command.""" @@ -168,7 +172,13 @@ def get_gpu_topo(): @staticmethod def chat(args): from .chat import main + kwargs = convert_args(args) + speculative_config = get_speculative_config(args) + to_remove = ['speculative_algorithm', 'speculative_draft_model', 'speculative_num_draft_tokens'] + for key in to_remove: + kwargs.pop(key) + kwargs['speculative_config'] = speculative_config main(**kwargs) @staticmethod diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 6a9e9f2b13..9cf36608b3 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -3,7 +3,8 @@ from lmdeploy.utils import get_max_batch_size from .cli import CLI -from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args, get_chat_template, get_lora_adapters +from .utils import (ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args, get_chat_template, get_lora_adapters, + get_speculative_config) class SubCliServe: @@ -144,6 +145,9 @@ def 
add_parser_api_server(): vision_group = parser.add_argument_group('Vision model arguments') ArgumentHelper.vision_max_batch_size(vision_group) + # spec decode + ArgumentHelper.add_spec_group(parser) + @staticmethod def add_parser_proxy(): """Add parser for proxy server command.""" @@ -247,61 +251,68 @@ def api_server(args): enable_metrics=args.enable_metrics, hf_overrides=args.hf_overrides) chat_template_config = get_chat_template(args.chat_template, args.model_path) + speculative_config = get_speculative_config(args) from lmdeploy.messages import VisionConfig vision_config = VisionConfig(args.vision_max_batch_size) if args.dp == 1: from lmdeploy.serve.openai.api_server import serve as run_api_server - run_api_server(args.model_path, - model_name=args.model_name, - backend=backend, - backend_config=backend_config, - chat_template_config=chat_template_config, - vision_config=vision_config, - server_name=args.server_name, - server_port=args.server_port, - allow_origins=args.allow_origins, - allow_credentials=args.allow_credentials, - allow_methods=args.allow_methods, - allow_headers=args.allow_headers, - allow_terminate_by_client=args.allow_terminate_by_client, - log_level=args.log_level.upper(), - api_keys=args.api_keys, - ssl=args.ssl, - proxy_url=args.proxy_url, - max_log_len=args.max_log_len, - disable_fastapi_docs=args.disable_fastapi_docs, - max_concurrent_requests=args.max_concurrent_requests, - reasoning_parser=args.reasoning_parser, - tool_call_parser=args.tool_call_parser) + run_api_server( + args.model_path, + model_name=args.model_name, + backend=backend, + backend_config=backend_config, + chat_template_config=chat_template_config, + vision_config=vision_config, + server_name=args.server_name, + server_port=args.server_port, + allow_origins=args.allow_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allow_methods, + allow_headers=args.allow_headers, + allow_terminate_by_client=args.allow_terminate_by_client, + log_level=args.log_level.upper(), + api_keys=args.api_keys, + ssl=args.ssl, + proxy_url=args.proxy_url, + max_log_len=args.max_log_len, + disable_fastapi_docs=args.disable_fastapi_docs, + max_concurrent_requests=args.max_concurrent_requests, + reasoning_parser=args.reasoning_parser, + tool_call_parser=args.tool_call_parser, + speculative_config=speculative_config, + ) else: from lmdeploy.serve.openai.launch_server import launch_server - launch_server(args.nnodes, - args.node_rank, - args.model_path, - model_name=args.model_name, - backend=backend, - backend_config=backend_config, - chat_template_config=chat_template_config, - vision_config=vision_config, - server_name=args.server_name, - server_port=args.server_port, - allow_origins=args.allow_origins, - allow_credentials=args.allow_credentials, - allow_methods=args.allow_methods, - allow_headers=args.allow_headers, - allow_terminate_by_client=args.allow_terminate_by_client, - log_level=args.log_level.upper(), - api_keys=args.api_keys, - ssl=args.ssl, - proxy_url=args.proxy_url, - max_log_len=args.max_log_len, - disable_fastapi_docs=args.disable_fastapi_docs, - max_concurrent_requests=args.max_concurrent_requests, - reasoning_parser=args.reasoning_parser, - tool_call_parser=args.tool_call_parser) + launch_server( + args.nnodes, + args.node_rank, + args.model_path, + model_name=args.model_name, + backend=backend, + backend_config=backend_config, + chat_template_config=chat_template_config, + vision_config=vision_config, + server_name=args.server_name, + server_port=args.server_port, + 
allow_origins=args.allow_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allow_methods, + allow_headers=args.allow_headers, + allow_terminate_by_client=args.allow_terminate_by_client, + log_level=args.log_level.upper(), + api_keys=args.api_keys, + ssl=args.ssl, + proxy_url=args.proxy_url, + max_log_len=args.max_log_len, + disable_fastapi_docs=args.disable_fastapi_docs, + max_concurrent_requests=args.max_concurrent_requests, + reasoning_parser=args.reasoning_parser, + tool_call_parser=args.tool_call_parser, + speculative_config=speculative_config, + ) @staticmethod def proxy(args): diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index bfd94182d0..0e2d88b88a 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -99,6 +99,19 @@ def get_chat_template(chat_template: str, model_path: str = None): return None +def get_speculative_config(args): + """Get speculative config from args.""" + from lmdeploy.messages import SpeculativeConfig + speculative_config = None + if args.speculative_algorithm is not None: + speculative_config = SpeculativeConfig( + method=args.speculative_algorithm, + model=args.speculative_draft_model, + num_speculative_tokens=args.speculative_num_draft_tokens, + ) + return speculative_config + + class ArgumentHelper: """Helper class to add unified argument.""" @@ -654,6 +667,26 @@ def dllm_confidence_threshold(parser): default=0.85, help='The confidence threshold for dllm.') + def add_spec_group(parser): + spec_group = parser.add_argument_group('Speculative decoding arguments') + spec_group.add_argument('--speculative-algorithm', + type=str, + default=None, + choices=['eagle', 'eagle3', 'deepseek_mtp'], + help='The speculative algorithm to use. `None` means speculative decoding is disabled') + + spec_group.add_argument('--speculative-draft-model', + type=str, + default=None, + help='The path to speculative draft model') + + spec_group.add_argument('--speculative-num-draft-tokens', + type=int, + default=1, + help='The number of speculative tokens to generate per step') + + return spec_group + # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py class FlexibleArgumentParser(argparse.ArgumentParser): diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 90a38a7f6a..35487ee908 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -522,6 +522,7 @@ class RequestMetrics: """ token_timestamp: float = 0.0 engine_events: List[EngineEvent] = field(default_factory=list) + spec_info: Optional[Dict[str, Any]] = None @dataclass @@ -560,3 +561,17 @@ class VisionConfig: """ max_batch_size: int = 1 thread_safe: bool = False + + +@dataclass +class SpeculativeConfig: + """Speculative decoding config. + + Args: + method (str): the speculative decoding method. + model (str): the path of speculative model. 
+ num_speculative_tokens (int): number of generated token of draft model per step + """ + method: str + model: str = '' + num_speculative_tokens: int = 1 diff --git a/lmdeploy/metrics/loggers.py b/lmdeploy/metrics/loggers.py index 730f1d6d00..137f4d6459 100644 --- a/lmdeploy/metrics/loggers.py +++ b/lmdeploy/metrics/loggers.py @@ -6,7 +6,9 @@ from datetime import datetime from typing import List -from lmdeploy.metrics.stats import FinishedRequestStats, IterationStats, SchedulerStats +import numpy as np + +from lmdeploy.metrics.stats import FinishedRequestStats, IterationStats, SchedulerStats, SpeculativeDecodingStats from lmdeploy.utils import get_logger logger = get_logger('lmdeploy') @@ -22,6 +24,10 @@ def record_schedule(self, stats: SchedulerStats) -> None: def record_iteration(self, stats: IterationStats) -> None: ... + @abstractmethod + def record_specdecode(self, stats: SpeculativeDecodingStats) -> None: + ... + def log(self): # noqa pass @@ -37,6 +43,11 @@ def _reset(self, now): self.last_log_time = now self.total_prompt_tokens = 0 self.total_generation_tokens = 0 + # spec decode + self.num_drafts: int = 0 + self.num_draft_tokens: int = 0 + self.num_accepted_tokens: int = 0 + self.num_accepted_tokens_per_pos: np.ndarray = None def record_schedule(self, stats: SchedulerStats): self.last_scheduler_stats = stats @@ -48,9 +59,42 @@ def record_iteration(self, stats: IterationStats): self.total_prompt_tokens += stats.prompt_tokens self.total_generation_tokens += stats.new_generation_tokens + def record_specdecode(self, stats: SpeculativeDecodingStats): + """Record spec decoding stats.""" + if stats.num_drafts <= 0: + return + if self.num_accepted_tokens_per_pos is None: + self.num_accepted_tokens_per_pos = np.zeros(stats.num_spec_tokens) + self.num_drafts += stats.num_drafts + self.num_draft_tokens += stats.num_draft_tokens + self.num_accepted_tokens += stats.num_accepted_tokens + self.num_accepted_tokens_per_pos += stats.num_accepted_tokens_per_pos + def record_finish(self, stats: FinishedRequestStats): pass + def _get_log_spec_msg(self): + """Get spec decoding logging msg.""" + if self.num_drafts == 0: + return '' + + draft_acceptance_rate = (self.num_accepted_tokens / self.num_draft_tokens * + 100 if self.num_draft_tokens > 0 else float('nan')) + + # Conventionally, mean acceptance length includes the bonus token + mean_acceptance_length = 1 + (self.num_accepted_tokens / self.num_drafts) + + acceptance_rates = self.num_accepted_tokens_per_pos / self.num_drafts + rates_str = ', '.join(f'{p:.3f}' for p in acceptance_rates) + + log_msg = ('SpecDecoding metrics: ' + f'Draft acceptance rate: {draft_acceptance_rate:.2f}%, ' + f'Mean acceptance length: {mean_acceptance_length:.2f}, ' + f'Accepted: {self.num_accepted_tokens} tokens, ' + f'Drafted: {self.num_draft_tokens} tokens, ' + f'Per-position acceptance rate: {rates_str}') + return log_msg + def log(self): now = time.perf_counter() if self.total_prompt_tokens == 0 and self.total_generation_tokens == 0: @@ -61,6 +105,7 @@ def log(self): prompt_throughput = self.total_prompt_tokens / (now - self.last_log_time) generation_throughput = self.total_generation_tokens / (now - self.last_log_time) + spec_log_msg = self._get_log_spec_msg() self._reset(now) scheduler_stats = self.last_scheduler_stats @@ -74,7 +119,9 @@ def log(self): f'Running: {scheduler_stats.num_running_reqs} reqs, ' f'Waiting: {scheduler_stats.num_waiting_reqs} reqs, ' f'GPU KV cache usage: {scheduler_stats.gpu_cache_usage * 100 :.1f}%') - print(log_msg) + print(log_msg, 
flush=True) + if spec_log_msg: + print(spec_log_msg, flush=True) class PrometheusStatLogger(StatLoggerBase): @@ -282,6 +329,9 @@ def record_finish(self, stats: FinishedRequestStats) -> None: self.histogram_num_prompt_tokens_request.observe(stats.prompt_tokens) self.histogram_num_generation_tokens_request.observe(stats.generation_tokens) + def record_specdecode(self, stats: SpeculativeDecodingStats) -> None: + pass + def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]: """Builds a list of buckets with increasing powers of 10 multiplied by diff --git a/lmdeploy/metrics/metrics_processor.py b/lmdeploy/metrics/metrics_processor.py index 2b87c0214b..ae18e0983b 100644 --- a/lmdeploy/metrics/metrics_processor.py +++ b/lmdeploy/metrics/metrics_processor.py @@ -119,7 +119,7 @@ async def _run_metrics_handler(self): try: # fetch update_data = await self.metrics_queue.get() - outputs, req_state, iteration_stats = update_data + outputs, req_state, iteration_stats, specdecode_stats = update_data # update request state according the engine events req_state.update_from_events(outputs.req_metrics.engine_events) @@ -128,9 +128,15 @@ async def _run_metrics_handler(self): # some attributes of req_state will also be updated, e.g., lastest_token_time iteration_stats.update_from_output(outputs, req_state) + # spec decode + if specdecode_stats is not None: + specdecode_stats.update_from_output(outputs) + # record iteration stats for stat_logger in self.stat_loggers: stat_logger.record_iteration(iteration_stats) + if specdecode_stats is not None: + stat_logger.record_specdecode(specdecode_stats) if outputs.status == ResponseType.FINISH: # record finished request stats @@ -153,7 +159,6 @@ async def udpate_schedule_stats(self, schedule_metrics: ScheduleMetrics): def queue_update(self, update_data: tuple): if not is_metrics_enabled() or self.metrics_queue is None: return - self.metrics_queue.put_nowait(update_data) def increment_total_requests(self): diff --git a/lmdeploy/metrics/stats.py b/lmdeploy/metrics/stats.py index cb4be2201c..0d6962a824 100644 --- a/lmdeploy/metrics/stats.py +++ b/lmdeploy/metrics/stats.py @@ -5,6 +5,8 @@ from dataclasses import dataclass from typing import List, Optional +import numpy as np + from lmdeploy.messages import EngineEvent, EngineOutput, ResponseType, ScheduleMetrics @@ -218,3 +220,59 @@ def update_from_output(self, outputs: EngineOutput, req_state: RequestState): if outputs.status != ResponseType.SUCCESS: req_state.finish_reason = outputs.status req_state.finish_time = self.iteration_timestamp + req_state.generation_tokens = outputs.num_token + + +# modify from vllm +@dataclass +class SpeculativeDecodingStats: + """Speculative decoding stats.""" + + num_spec_tokens: int + num_drafts: int = 0 + num_draft_tokens: int = 0 + num_accepted_tokens: int = 0 + num_accepted_tokens_per_pos: np.ndarray = None + + def __post_init__(self): + assert self.num_spec_tokens > 0 + self.num_accepted_tokens_per_pos = np.zeros(self.num_spec_tokens) + + def update_from_output(self, outputs: EngineOutput): + """Update from engine output.""" + if spec_info := getattr(outputs.req_metrics, 'spec_info', None): + self.num_drafts += 1 + self.num_draft_tokens += spec_info['num_draft_tokens'] + self.num_accepted_tokens += spec_info['num_accepted_tokens'] + self.num_accepted_tokens_per_pos[:spec_info['num_accepted_tokens']] += 1 + + def update_per_draft(self, num_draft_tokens: int, num_accepted_tokens: int): + """Update with per draft stats.""" + if num_draft_tokens > 0: + self.num_drafts += 1 
+ self.num_draft_tokens += num_draft_tokens + self.num_accepted_tokens += num_accepted_tokens + self.num_accepted_tokens_per_pos[:num_accepted_tokens] += 1 + + def __repr__(self): + """Return a human-readable string representation.""" + draft_acceptance_rate = (self.num_accepted_tokens / self.num_draft_tokens * + 100 if self.num_draft_tokens > 0 else float('nan')) + + # Conventionally, mean acceptance length includes the bonus token + mean_acceptance_length = 1 + (self.num_accepted_tokens / + self.num_drafts) if self.num_drafts > 0 else float('nan') + + acceptance_rates = self.num_accepted_tokens_per_pos / self.num_drafts if self.num_drafts > 0 else [ + float('nan') + ] * self.num_accepted_tokens + rates_str = ', '.join(f'{p:.3f}' for p in acceptance_rates) + + return ('SpeculativeDecodingStats(' + f'num_spec_tokens={self.num_spec_tokens}, ' + f'num_drafts={self.num_drafts}, ' + f'num_draft_tokens={self.num_draft_tokens}, ' + f'num_accepted_tokens={self.num_accepted_tokens}, ' + f'draft_acceptance_rate={draft_acceptance_rate:.2f}%, ' + f'mean_acceptance_length={mean_acceptance_length:.2f}, ' + f'per_position_acceptance_rate={rates_str})') diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index b241c384b2..9752165def 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -42,6 +42,9 @@ class TritonAttentionMetadata(AttentionMetadata): num_splits: torch.Tensor = None cu_seqlens_q: torch.Tensor = None cu_seqlens_k: torch.Tensor = None + # flash attn + scheduler_metadata: torch.Tensor = None + max_kv_seqlen: int = None def _cdiv(a, b): @@ -291,10 +294,11 @@ def forward( kv_seqlens = attn_metadata.kv_seqlens kv_flatten_size = attn_metadata.kv_flatten_size quant_policy = attn_metadata.quant_policy + max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) + batch_size = q_seqlens.size(0) if attn_metadata.is_decoding: - max_q_seqlen = 1 - else: - max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) + max_q_seqlen = max_q_seqlen // batch_size + fill_max_q_seqlen = max_q_seqlen if attn_metadata.fill_seqlens is not None: fill_seqlens = attn_metadata.fill_seqlens @@ -323,7 +327,7 @@ def forward( is_decoding = attn_metadata.is_decoding if is_decoding: - query = query.unsqueeze(1) + query = query.unflatten(0, (batch_size, max_q_seqlen)) if kv_seqlens.dtype == torch.int64: kv_seqlens = kv_seqlens.to(torch.int32) attn_output = self.flash_mla_fwd(query, @@ -421,8 +425,9 @@ def __init__( causal=causal, **kwargs, ) - from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func + from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache self.flash_attn_varlen_func_v3 = flash_attn_varlen_func + self.flash_attn_with_kvcache_v3 = flash_attn_with_kvcache def forward( self, @@ -447,10 +452,12 @@ def forward( kv_seqlens = attn_metadata.kv_seqlens kv_flatten_size = attn_metadata.kv_flatten_size quant_policy = attn_metadata.quant_policy + batch_size = q_seqlens.size(0) + + max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) if attn_metadata.is_decoding: - max_q_seqlen = 1 - else: - max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) + max_q_seqlen = max_q_seqlen // batch_size + fill_max_q_seqlen = max_q_seqlen if attn_metadata.fill_seqlens is not None: fill_seqlens = attn_metadata.fill_seqlens @@ -473,26 +480,44 @@ def forward( v_scales_zeros=v_scales_zeros, quant_policy=quant_policy, ) 
- - q_shape = query.shape - o_shape = q_shape[:-1] + (self.v_head_size, ) - attn_output = query.new_empty(o_shape) - if is_decoding: - self.paged_attention_fwd( - query, - k_cache, - v_cache, - attn_output, - block_offsets, - kv_seqlens=kv_seqlens, - k_scales_zeros=k_scales_zeros, - v_scales_zeros=v_scales_zeros, - quant_policy=quant_policy, - window_size=self.sliding_window, - sm_scale=self.scale, - logit_softcapping=self.logit_softcapping, - ) + # spec decoding + if max_q_seqlen > 1: + sliding_window = (-1, -1) if self.sliding_window is None else self.sliding_window + if isinstance(sliding_window, int): + sliding_window = (sliding_window, sliding_window) + query = query.unflatten(0, (-1, max_q_seqlen)) + attn_output = self.flash_attn_with_kvcache_v3( + query, + k_cache, + v_cache, + cache_seqlens=attn_metadata.kv_seqlens.to(torch.int32), + max_seqlen_q=max_q_seqlen, + scheduler_metadata=attn_metadata.scheduler_metadata, + page_table=block_offsets, + softmax_scale=self.scale, + causal=self.causal, + window_size=sliding_window, + softcap=-1.0 if self.logit_softcapping is None else self.logit_softcapping, + ) + else: + q_shape = query.shape + o_shape = q_shape[:-1] + (self.v_head_size, ) + attn_output = query.new_empty(o_shape) + self.paged_attention_fwd( + query, + k_cache, + v_cache, + attn_output, + block_offsets, + kv_seqlens=kv_seqlens, + k_scales_zeros=k_scales_zeros, + v_scales_zeros=v_scales_zeros, + quant_policy=quant_policy, + window_size=self.sliding_window, + sm_scale=self.scale, + logit_softcapping=self.logit_softcapping, + ) else: flatten_k, flatten_v = self.flatten_kv_cache( k_cache, diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index deb6c66bfd..2cb0d2d887 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -68,21 +68,21 @@ def __init__( pool: Tuple[int, int], model_config: ModelConfig, device: torch.device, + decode_query_len: int = 1, ): self.model = model self.ctx_mgr = model.ctx_mgr self.model_config = model_config - self.meta = CudaGraphMeta( - max_batchs=max_batches, - max_tokens=max_tokens, - num_blocks=num_blocks, - is_decoding=is_decoding, - device=device, - input_buffers=dict(), - output_buffers=dict(), - vocab_size=self.model_config.vocab_size, - ) + self.meta = CudaGraphMeta(max_batchs=max_batches, + max_tokens=max_tokens, + num_blocks=num_blocks, + is_decoding=is_decoding, + device=device, + input_buffers=dict(), + output_buffers=dict(), + vocab_size=self.model_config.vocab_size, + decode_query_len=decode_query_len) self.device = device self.max_batches = max_batches self.max_tokens = max_tokens @@ -109,22 +109,22 @@ def capture(self, **kwargs): # so we set thread_safe capture mode here. 
with torch.cuda.graph(self._graph, pool=self.pool, stream=current_stream, capture_error_mode='thread_local'): output = self.model(**padded_kwargs) - - output_buffers = dict(logits=output) + output_buffers = output + if isinstance(output, torch.Tensor): + output_buffers = dict(hidden_states=output) self.meta.output_buffers = output_buffers return output @record_function('forward_cudagraph') def forward(self, **kwargs): """forward.""" - num_tokens = kwargs['input_ids'].size(-1) assert self._graph is not None self.model.fill_buffers_cudagraph(self.meta, **kwargs) context = self.ctx_mgr.current_context() self.model.update_context_cudagraph(self.meta, context) self._graph.replay() - output = self.meta.output_buffers['logits'][:, :num_tokens] + output = self.model.get_outputs_cudagraph(self.meta, **kwargs) return output def __del__(self): @@ -187,17 +187,19 @@ def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, pas batch_size = attn_metadata.q_seqlens.size(0) meta = self.get_meta() enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch + # for draft model to distinguish inputs from target model and itself + query_len = input_ids.size(1) // batch_size if meta.padding_batch_size is None: batch_size = self._get_capture_tokens(batch_size) else: batch_size = self._get_capture_tokens(meta.padding_batch_size) - return (batch_size, is_decoding, enable_microbatch) + return (batch_size, is_decoding, enable_microbatch, query_len) - def _get_max_tokens(self, graph_key: tuple): + def _get_max_tokens(self, graph_key: tuple, input_ids: torch.Tensor, q_seqlens: torch.Tensor): max_batches = graph_key[0] is_decoding = graph_key[1] assert is_decoding - return self.cudagraph_strategy.get_max_tokens(max_batches) + return self.cudagraph_strategy.get_max_tokens(max_batches, input_ids, q_seqlens) def __call__(self, **kwargs): """call.""" @@ -213,16 +215,20 @@ def __call__(self, **kwargs): graph_key = self.get_graph_key(**kwargs) max_batches = graph_key[0] is_decoding = graph_key[1] + decode_query_len = graph_key[3] if graph_key not in self._runner_map: - max_tokens = self._get_max_tokens(graph_key) - runner = CUDASingleGraphRunner(self.model, - max_batches=max_batches, - max_tokens=max_tokens, - num_blocks=self.num_blocks, - is_decoding=is_decoding, - pool=self.graph_pool_handle, - model_config=self.model_config, - device=self.device) + max_tokens = self._get_max_tokens(graph_key, kwargs['input_ids'], kwargs['attn_metadata'].q_seqlens) + runner = CUDASingleGraphRunner( + self.model, + max_batches=max_batches, + max_tokens=max_tokens, + num_blocks=self.num_blocks, + is_decoding=is_decoding, + pool=self.graph_pool_handle, + model_config=self.model_config, + device=self.device, + decode_query_len=decode_query_len, + ) runner.capture(**kwargs) self._runner_map[graph_key] = runner else: diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index d6b77de59e..b155ce4124 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Tuple +from typing import Optional, Tuple import torch @@ -19,6 +19,46 @@ def _get_meta_flashmla(kv_seqlens, num_attention_heads): return tile_scheduler_metadata, num_splits +def _get_meta_flashattn( + batch_size: int, + max_seqlen_q: int, + max_seqlen_k: int, + num_heads_q: int, + num_heads_kv: int, + headdim: int, + cache_seqlens: torch.Tensor, + qkv_dtype=torch.bfloat16, + headdim_v=None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + causal=True, + window_size=(-1, -1), # -1 means infinite context window + num_splits=0, +): + """Get scheduler metadata for flash attn.""" + from flash_attn_interface import get_scheduler_metadata + + metadata = get_scheduler_metadata( + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads_q, + num_heads_kv, + headdim, + cache_seqlens, + qkv_dtype=qkv_dtype, + headdim_v=headdim_v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + page_size=page_size, + causal=causal, + window_size=window_size, + num_splits=num_splits, + ) + return metadata + + class CudaOpsBackend(DefaultOpsBackend): """Cuda layer backend.""" @@ -121,6 +161,28 @@ def update_meta_flashmla(cls, attn_metadata, num_attention_heads): if attn_metadata.block_offsets.dtype != torch.int32: attn_metadata.block_offsets = attn_metadata.block_offsets.to(torch.int32) + @classmethod + def update_meta_flashattn(cls, attn_metadata, step_context): + batch_size = attn_metadata.q_seqlens.size(0) + max_seqlen_q = step_context.input_ids.size(1) // batch_size + block_size = step_context.kv_caches[0][0].size(1) + window_size = (step_context.model_config.sliding_window, ) * 2 + scheduler_metadata = _get_meta_flashattn( + batch_size=batch_size, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=step_context.max_kv_seqlen, + num_heads_q=step_context.model_config.num_attention_heads, + num_heads_kv=step_context.model_config.num_key_value_heads, + headdim=step_context.model_config.head_dim, + cache_seqlens=attn_metadata.kv_seqlens.to(torch.int32), + qkv_dtype=step_context.model_config.dtype, + page_size=block_size, + window_size=window_size, + ) + attn_metadata.scheduler_metadata = scheduler_metadata + attn_metadata.max_kv_seqlen = step_context.max_kv_seqlen + return attn_metadata + @classmethod def update_step_context(cls, step_context): """Update step context.""" @@ -130,11 +192,20 @@ def update_step_context(cls, step_context): kv_seqlens = step_context.kv_seqlens kv_start_loc = None kv_flatten_size = None - cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(q_seqlens, dim=0, dtype=torch.int32), (1, 0)) - cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(kv_seqlens, dim=0, dtype=torch.int32), (1, 0)) + use_flash_mla = step_context.model_config.use_flash_mla + use_flash_attn3_decoding = step_context.model_config.model_paradigm == 'ar_spec' + + if use_flash_mla or use_flash_attn3_decoding: + step_context.block_offsets = step_context.block_offsets.to(torch.int32) + + cu_seqlens_q = None + cu_seqlens_k = None if not step_context.is_decoding: + cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(q_seqlens, dim=0, dtype=torch.int32), (1, 0)) + cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(kv_seqlens, dim=0, dtype=torch.int32), (1, 0)) kv_start_loc = kv_seqlens.cumsum(0) - kv_seqlens kv_flatten_size = step_context.sum_kv_seqlen + attn_metadata = attn_meta_cls( step_context.is_decoding, step_context.block_offsets, @@ -147,9 +218,15 @@ def update_step_context(cls, step_context): 
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, ) - if getattr(step_context.model_config, 'use_flash_mla', False) is True: + if use_flash_mla: + if step_context.is_decoding is True: + decode_query_len = step_context.input_ids.size(1) // q_seqlens.size(0) + cls.update_meta_flashmla(attn_metadata, + step_context.model_config.num_attention_heads * decode_query_len) + + if use_flash_attn3_decoding: if step_context.is_decoding is True: - cls.update_meta_flashmla(attn_metadata, step_context.model_config.num_attention_heads) + attn_metadata = cls.update_meta_flashattn(attn_metadata, step_context) cross_seqlens = step_context.cross_seqlens cross_kv_seqlens = step_context.cross_kv_seqlens diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index ac3459e045..66ea232712 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -7,6 +7,7 @@ from lmdeploy.messages import PytorchEngineConfig from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend +from lmdeploy.pytorch.utils import maybe_register_config_serialize_by_value def _update_torch_dtype(config: 'ModelConfig', dtype: str): @@ -213,12 +214,16 @@ def get_head_size(self): return self.head_dim @classmethod - def from_pretrained(cls, - pretrained_model_name_or_path: str, - trust_remote_code: bool = True, - dtype: str = 'auto', - dist_config: DistConfig = None, - hf_overrides: Dict[str, Any] = None): + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + trust_remote_code: bool = True, + dtype: str = 'auto', + dist_config: DistConfig = None, + hf_overrides: Dict[str, Any] = None, + is_draft_model: bool = False, + spec_method: str = None, + ): """Instantiate one of the configuration classes of the library from a pretrained model configuration. @@ -239,24 +244,35 @@ def from_pretrained(cls, # phi3 + trust_remote_code leads to error when tp. 
hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path) - model_config = cls.from_hf_config(hf_config, - pretrained_model_name_or_path, - dtype=dtype, - dist_config=dist_config) + model_config = cls.from_hf_config( + hf_config, + pretrained_model_name_or_path, + dtype=dtype, + dist_config=dist_config, + is_draft_model=is_draft_model, + spec_method=spec_method, + ) if hf_overrides is not None: logger = get_logger('lmdeploy') logger.warning(f'Overriding HF config with {hf_overrides}') override_hf_config(model_config.hf_config, hf_overrides) + # for serialization of transformers modules + maybe_register_config_serialize_by_value(trust_remote_code) + return model_config @classmethod - def from_hf_config(cls, - hf_config: Any, - model_path: str = None, - dtype: str = 'auto', - dist_config: DistConfig = None): + def from_hf_config( + cls, + hf_config: Any, + model_path: str = None, + dtype: str = 'auto', + dist_config: DistConfig = None, + is_draft_model: bool = False, + spec_method: str = None, + ): """From huggingface config.""" from lmdeploy.pytorch.configurations import AutoModelConfigBuilder if dist_config is None: @@ -266,7 +282,11 @@ def from_hf_config(cls, else: tp = 1 - model_config = AutoModelConfigBuilder.build(hf_config, model_path, tp=tp) + model_config = AutoModelConfigBuilder.build(hf_config, + model_path, + tp=tp, + is_draft_model=is_draft_model, + spec_method=spec_method) if model_config.k_head_dim is None: assert model_config.head_dim is not None @@ -352,3 +372,49 @@ def from_engine_config(cls, engine_config: PytorchEngineConfig): logprobs_mode=engine_config.logprobs_mode, dllm_config=dllm_config) return misc_config + + +@dataclass +class SpecDecodeConfig: + model: str + method: str + cache_config: CacheConfig = None + num_speculative_tokens: int = 1 + model_config: ModelConfig = None + + @classmethod + def from_config( + cls, + method: str, + num_speculative_tokens: int, + model: str, + target_cache_cfg: CacheConfig, + target_model: str = None, + dtype: str = 'auto', + ): + model = model or target_model + model_config = ModelConfig.from_pretrained(model, + trust_remote_code=True, + dtype=dtype, + is_draft_model=True, + spec_method=method) + cache_config = None + # include medusa + no_caches = ['medusa'] + if method not in no_caches: + cache_config = CacheConfig(max_batches=target_cache_cfg.max_batches, + block_size=target_cache_cfg.block_size, + num_cpu_blocks=target_cache_cfg.num_cpu_blocks, + num_gpu_blocks=target_cache_cfg.num_gpu_blocks, + cache_max_entry_count=target_cache_cfg.cache_max_entry_count, + max_prefill_token_num=target_cache_cfg.max_prefill_token_num, + device_type=target_cache_cfg.device_type, + migration_backend=target_cache_cfg.migration_backend) + obj = cls( + model=model, + method=method, + cache_config=cache_config, + model_config=model_config, + num_speculative_tokens=num_speculative_tokens, + ) + return obj diff --git a/lmdeploy/pytorch/configurations/deepseek_v2.py b/lmdeploy/pytorch/configurations/deepseek_v2.py index f83abe38f2..924a68e026 100644 --- a/lmdeploy/pytorch/configurations/deepseek_v2.py +++ b/lmdeploy/pytorch/configurations/deepseek_v2.py @@ -13,7 +13,7 @@ def condition(cls, hf_config): return hf_config.model_type in ['deepseek_v3', 'deepseek_v2', 'kimi_k2'] @classmethod - def build(cls, hf_config, model_path: str = None, **kwargs): + def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False, spec_method: str = None, **kwargs): """build.""" head_dim = (hf_config.kv_lora_rank + hf_config.qk_rope_head_dim) 
k_head_dim = head_dim @@ -25,15 +25,33 @@ def build(cls, hf_config, model_path: str = None, **kwargs): # update num_kv_heads for tp mode num_key_value_heads = cls.update_num_kv_heads(hf_config, tp, num_key_value_heads) hf_config.use_flash_mla = flash_mla_available() + num_layers = hf_config.num_hidden_layers + model_paradigm = 'ar' - return ModelConfig(hidden_size=hf_config.hidden_size, - num_layers=hf_config.num_hidden_layers, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - bos_token_id=hf_config.bos_token_id, - eos_token_id=hf_config.eos_token_id, - head_dim=head_dim, - k_head_dim=k_head_dim, - v_head_dim=v_head_dim, - vocab_size=hf_config.vocab_size, - use_flash_mla=hf_config.use_flash_mla) + if spec_method is not None: + assert spec_method == 'deepseek_mtp' + + # draft model cfg + if is_draft_model: + num_layers = hf_config.num_nextn_predict_layers + hf_config.architectures[0] = 'DeepseekMTPModel' + # remove for correct mapping when building the patched model + del hf_config.auto_map + + if is_draft_model or spec_method is not None: + model_paradigm = 'ar_spec' + + return ModelConfig( + hidden_size=hf_config.hidden_size, + num_layers=num_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + head_dim=head_dim, + k_head_dim=k_head_dim, + v_head_dim=v_head_dim, + vocab_size=hf_config.vocab_size, + use_flash_mla=hf_config.use_flash_mla, + model_paradigm=model_paradigm, + ) diff --git a/lmdeploy/pytorch/configurations/llama.py b/lmdeploy/pytorch/configurations/llama.py new file mode 100644 index 0000000000..885535c588 --- /dev/null +++ b/lmdeploy/pytorch/configurations/llama.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .builder import AutoModelConfigBuilder +from .default import DefaultModelConfigBuilder + + +class LlamaModelConfigBuilder(AutoModelConfigBuilder): + + @classmethod + def condition(cls, hf_config): + """config.""" + return hf_config.architectures[0] in ['LlamaForCausalLM'] + + @classmethod + def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False, spec_method: str = None, **kwargs): + """Build llama.""" + cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs) + + if is_draft_model: + # update draft model arch + assert spec_method is not None + hf_config.architectures[0] = spec_method.capitalize() + hf_config.architectures[0] + cfg.vocab_size = getattr(hf_config, 'draft_vocab_size', hf_config.vocab_size) + cfg.model_paradigm = 'ar_spec' + elif spec_method is not None: + # add aux_hidden_state_layers for eagle3 + if spec_method == 'eagle3': + num_layers = cfg.num_layers + hf_config.aux_hidden_state_layers = (2, num_layers // 2, num_layers - 3) + cfg.model_paradigm = 'ar_spec' + cfg.hf_config = hf_config + return cfg diff --git a/lmdeploy/pytorch/configurations/utils.py b/lmdeploy/pytorch/configurations/utils.py index 430d4bd724..dfdd50512e 100644 --- a/lmdeploy/pytorch/configurations/utils.py +++ b/lmdeploy/pytorch/configurations/utils.py @@ -19,3 +19,18 @@ def flash_mla_available(): except ImportError: logger.warning('For higher performance, please install flash_mla https://github.com/deepseek-ai/FlashMLA') return use_flash_mla + + +def flash_attn_v3_available(): + """Check if flash attn v3 is available.""" + use_fa3 = False + try: + # Now flash-attention only support FA3 for sm90a && cuda >= 12.3 + if (torch.cuda.get_device_capability()[0] == 9) and (torch.version.cuda >= '12.3'): + import flash_attn_interface # noqa: F401 + assert torch.ops.flash_attn_3 is not None + use_fa3 = True + except Exception: + logger.warning('For higher performance, please install FlashAttention-3 ' + 'https://github.com/Dao-AILab/flash-attention') + return use_fa3 diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index fa6da852cd..4c0b4f8c93 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -10,8 +10,9 @@ import numpy as np import torch +from torch.profiler import record_function -from lmdeploy.messages import PytorchEngineConfig, RequestMetrics, ResponseType +from lmdeploy.messages import PytorchEngineConfig, RequestMetrics, ResponseType, SpeculativeConfig from lmdeploy.pytorch.disagg.config import EngineRole from lmdeploy.pytorch.disagg.conn.engine_conn import EngineP2PConnection from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest, @@ -20,7 +21,7 @@ from lmdeploy.utils import get_logger, get_max_batch_size, get_model, logging_timer from ..adapter.adapter import AdapterManager -from ..config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SchedulerConfig +from ..config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SchedulerConfig, SpecDecodeConfig from ..messages import MessageStatus, SchedulerSequence, UpdateTokenMode from ..model_inputs import ModelInputs, VisionModelInputs from ..paging import Scheduler @@ -149,6 +150,26 @@ def _build_misc_config(engine_config: PytorchEngineConfig): return misc_config +def _build_specdecode_config(target_model, speculative_config: SpeculativeConfig, engine_config: PytorchEngineConfig, + cache_config: CacheConfig): + """Build spec decode config.""" + 
specdecode_config = None + if speculative_config is not None: + draft_model = speculative_config.model + if draft_model and not os.path.exists(speculative_config.model): + draft_model = get_model(draft_model, engine_config.download_dir, engine_config.revision) + + specdecode_config = SpecDecodeConfig.from_config( + method=speculative_config.method, + num_speculative_tokens=speculative_config.num_speculative_tokens, + model=draft_model, + target_model=target_model, + target_cache_cfg=cache_config, + dtype=engine_config.dtype, + ) + return specdecode_config + + def _build_seq_meta(cache_config: CacheConfig, strategy: Any): from lmdeploy.pytorch.messages import SequenceMeta @@ -242,6 +263,7 @@ def __init__(self, engine: 'Engine'): super().__init__(engine) self.scheduler = self.engine.scheduler self.forward_inputs = None + self.spec_decoding = engine.specdecode_config is not None self.dp = self.engine.dist_config.dp self.role = self.engine.cache_config.role @@ -310,7 +332,7 @@ async def prefetch_next_inputs(self): else: num_running = scheduler.num_running() is_decoding = self.forward_inputs['inputs'].is_decoding - running_threshold = (self.scheduler_config.max_batches // 4) if is_decoding else 0 + running_threshold = (self.scheduler_config.max_batches // 4) if is_decoding or self.spec_decoding else 0 if num_running > running_threshold: enable = True @@ -337,10 +359,13 @@ class Engine(EngineBase): trust_remote_code (bool): Trust remote code. """ - def __init__(self, - model_path: str, - engine_config: PytorchEngineConfig = None, - trust_remote_code: bool = True) -> None: + def __init__( + self, + model_path: str, + engine_config: PytorchEngineConfig = None, + trust_remote_code: bool = True, + speculative_config: SpeculativeConfig = None, + ) -> None: # make sure engine config exist engine_config = _update_engine_config(engine_config) @@ -376,20 +401,27 @@ def __init__(self, dist_config = _build_dist_config(engine_config) misc_config = _build_misc_config(engine_config) + # spec decode + self.specdecode_config = _build_specdecode_config(model_path, speculative_config, engine_config, cache_config) + # build model agent - self.executor = build_executor(model_path, - cache_config=cache_config, - backend_config=backend_config, - dist_config=dist_config, - misc_config=misc_config, - adapters=adapters, - device_type=engine_config.device_type, - distributed_executor_backend=engine_config.distributed_executor_backend, - dtype=engine_config.dtype) + self.executor = build_executor( + model_path, + cache_config=cache_config, + backend_config=backend_config, + dist_config=dist_config, + misc_config=misc_config, + adapters=adapters, + device_type=engine_config.device_type, + distributed_executor_backend=engine_config.distributed_executor_backend, + dtype=engine_config.dtype, + specdecode_config=self.specdecode_config, + ) self.executor.init() # strategies - self.strategy_factory = build_strategy_factory(self.model_config, self.executor.misc_config) + self.strategy_factory = build_strategy_factory(self.model_config, self.executor.misc_config, + self.specdecode_config) self.sampling_strategy = self.strategy_factory.build_sampling_strategy() self.model_agent_strategy = self.strategy_factory.build_model_agent_strategy() self.engine_strategy = self.strategy_factory.build_engine_strategy(cache_config=cache_config, @@ -433,6 +465,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, engine_config: PytorchEngineConfig = None, trust_remote_code: bool = True, + speculative_config: SpeculativeConfig = None, 
**kwargs): """Lmdeploy python inference engine. @@ -453,15 +486,21 @@ def from_pretrained(cls, if engine_config is not None and engine_config.enable_mp_engine: from .mp_engine import build_mp_engine backend = engine_config.mp_engine_backend - return build_mp_engine(backend=backend, - model_path=pretrained_model_name_or_path, - engine_config=engine_config, - trust_remote_code=trust_remote_code) + return build_mp_engine( + backend=backend, + model_path=pretrained_model_name_or_path, + engine_config=engine_config, + trust_remote_code=trust_remote_code, + speculative_config=speculative_config, + ) if len(kwargs) > 0: logger.debug(f'Get unexpected kwargs: {kwargs}') - return cls(model_path=pretrained_model_name_or_path, - engine_config=engine_config, - trust_remote_code=trust_remote_code) + return cls( + model_path=pretrained_model_name_or_path, + engine_config=engine_config, + trust_remote_code=trust_remote_code, + speculative_config=speculative_config, + ) def _download_adapters(self, adapters: Dict[str, str], engine_config: PytorchEngineConfig): """Download adapters.""" @@ -737,6 +776,7 @@ def __has_values(input_multimodals): @torch.inference_mode() @logging_timer('CreateModelInputs', logger) + @record_function('CreateModelInputs') def create_model_inputs(self, messages: SeqList, is_prefill: bool): """Create model inputs from messages. @@ -749,6 +789,7 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): # input ids token_ids = [msg.token_ids for msg in messages] + input_ids = torch.as_tensor(np.concatenate(token_ids))[None] # seqlens @@ -806,7 +847,6 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): # vision inputs vision_model_inputs = self._create_vision_model_inputs(messages, model_inputs) model_inputs.vision_inputs = vision_model_inputs - return model_inputs def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, stopped: torch.Tensor, @@ -826,6 +866,7 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL) msg.status = MessageStatus.STOPPED + @record_function('make_infer_outputs') def _make_infer_outputs( self, batched_outputs: BatchedOutputs, @@ -868,9 +909,14 @@ def _make_infer_outputs( num_logprobs = msg.sampling_param.num_logprobs cur_logprobs = None if num_logprobs >= 0: - cur_logprobs = (logprobs.vals[idx][:num_logprobs + 1], logprobs.indices[idx][:num_logprobs + 1]) - - req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events) + cur_logprobs = (logprobs.vals[idx, :num_logprobs + 1], logprobs.indices[idx, :num_logprobs + 1]) + # get spec stats info + spec_info = None + if self.specdecode_config is not None and is_decoding and self.engine_config.enable_metrics: + num_draft_tokens = self.specdecode_config.num_speculative_tokens + num_accepted_tokens = (batched_outputs.next_token_ids[idx] > -1).sum() - 1 + spec_info = dict(num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens) + req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events, spec_info=spec_info) out = InferOutput(session_id=session_id, resp=msg.resp, finish=finish, @@ -889,6 +935,8 @@ def _make_forward_inputs(self, prefill: bool, enable_empty: bool = False): def __need_logits(seqs: SeqList): """Need logits.""" + if self.specdecode_config is not None: + return True return any(seq.return_logits for seq in seqs) scheduler = self.scheduler @@ -1127,7 +1175,6 @@ async def _async_loop_main( if idx == 
num_loops - 1: scheduler.collect_migration_done() forward_inputs, next_running = await inputs_maker.prefetch_next_inputs() - # send output out = await self.executor.get_output_async() if out is not None: diff --git a/lmdeploy/pytorch/engine/executor/__init__.py b/lmdeploy/pytorch/engine/executor/__init__.py index 517b0d8f5f..ec7b736015 100644 --- a/lmdeploy/pytorch/engine/executor/__init__.py +++ b/lmdeploy/pytorch/engine/executor/__init__.py @@ -3,7 +3,7 @@ from typing import Dict from lmdeploy.pytorch import envs -from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig +from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig from lmdeploy.utils import get_logger from .base import ExecutorBase @@ -53,25 +53,32 @@ def _log_and_set_backend(message: str, executor_backend: str): # return _log_and_set_backend(f'local device_count({device_count})>=world_size({world_size}),', 'mp') -def build_executor(model_path: str, - cache_config: CacheConfig, - backend_config: BackendConfig, - dist_config: DistConfig, - misc_config: MiscConfig, - adapters: Dict[str, str] = None, - device_type: str = 'cuda', - distributed_executor_backend: str = None, - dtype: str = 'auto') -> ExecutorBase: +def build_executor( + model_path: str, + cache_config: CacheConfig, + backend_config: BackendConfig, + dist_config: DistConfig, + misc_config: MiscConfig, + adapters: Dict[str, str] = None, + device_type: str = 'cuda', + distributed_executor_backend: str = None, + dtype: str = 'auto', + specdecode_config: SpecDecodeConfig = None, +) -> ExecutorBase: """Build model agent executor.""" logger = get_logger('lmdeploy') dp = dist_config.dp world_size = dist_config.world_size - model_config = ModelConfig.from_pretrained(model_path, - trust_remote_code=True, - dtype=dtype, - hf_overrides=misc_config.hf_overrides, - dist_config=dist_config) + model_config = ModelConfig.from_pretrained( + model_path, + trust_remote_code=True, + dtype=dtype, + hf_overrides=misc_config.hf_overrides, + dist_config=dist_config, + is_draft_model=False, + spec_method=None if specdecode_config is None else specdecode_config.method, + ) if distributed_executor_backend is None: distributed_executor_backend = get_distributed_executor_backend(world_size, dp, device_type, logger) @@ -99,6 +106,7 @@ def build_executor(model_path: str, misc_config=misc_config, adapters=adapters, device_type=device_type, + specdecode_config=specdecode_config, ) elif distributed_executor_backend == 'mp': from .mp_executor import MPExecutor @@ -111,6 +119,7 @@ def build_executor(model_path: str, misc_config=misc_config, adapters=adapters, device_type=device_type, + specdecode_config=specdecode_config, ) elif distributed_executor_backend == 'ray': from .ray_executor import RayExecutor @@ -124,6 +133,7 @@ def build_executor(model_path: str, adapters=adapters, device_type=device_type, dtype=dtype, + specdecode_config=specdecode_config, ) else: raise RuntimeError(f'Unsupported distributed_executor_backend: {distributed_executor_backend}.') diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py index 9e50843a80..fa1621521b 100644 --- a/lmdeploy/pytorch/engine/executor/base.py +++ b/lmdeploy/pytorch/engine/executor/base.py @@ -4,7 +4,7 @@ import contextlib from typing import Any, Dict, List, Optional -from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig +from lmdeploy.pytorch.config import 
BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch from lmdeploy.pytorch.engine.cache_engine import CacheEngine @@ -24,6 +24,7 @@ def __init__(self, dist_config: DistConfig, misc_config: MiscConfig, adapters: Dict[str, str] = None, + specdecode_config: SpecDecodeConfig = None, device_type: str = 'cuda'): """Initialize Executor.""" cache_config.window_size = model_config.sliding_window @@ -40,6 +41,7 @@ def __init__(self, self.tp = dist_config.tp self.world_size = dist_config.world_size self.device_type = device_type + self.specdecode_config = specdecode_config def download_models(self): """Download model.""" @@ -53,11 +55,11 @@ def gather_free_mem(self): """Gather available memory.""" raise NotImplementedError('Not Implemented.') - def set_cache_config(self, cache_config: CacheConfig): + def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None): """Set all cache config.""" raise NotImplementedError('Not Implemented.') - def set_model_config(self, model_config: ModelConfig): + def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None): """Set all model config.""" raise NotImplementedError('Not Implemented.') @@ -156,6 +158,9 @@ def _adjust_block_size(self): def update_configs(self): """Update cache config.""" self._adjust_block_size() + # spec + if self.specdecode_config and self.specdecode_config.cache_config: + self.specdecode_config.cache_config.block_size = self.cache_config.block_size cache_config = self.cache_config model_config = self.model_config free_mems = self.gather_free_mem() @@ -166,12 +171,26 @@ def update_configs(self): tp = self.dist_config.attn_config.tp cache_block_size = CacheEngine.get_cache_block_size(cache_config.block_size, model_config, tp, cache_config.quant_policy) - runtime_mem, max_prefill_token_num = self._get_runtime_size(free_mem, cache_block_size, vocal_size) + spec_cache_config = None + spec_model_config = None + spec_cache_block_size = 0 + if self.specdecode_config: + spec_model_config = self.specdecode_config.model_config + if spec_cache_config := self.specdecode_config.cache_config: + spec_cache_block_size = CacheEngine.get_cache_block_size(spec_cache_config.block_size, + spec_model_config) + + runtime_mem, max_prefill_token_num = self._get_runtime_size(free_mem, cache_block_size + spec_cache_block_size, + vocal_size) if cache_config.max_prefill_token_num != max_prefill_token_num: if max_prefill_token_num <= 0: raise RuntimeError('No enough gpu memory for runtime.') cache_config.max_prefill_token_num = max_prefill_token_num logger.warning(f'No enough memory. 
Update max_prefill_token_num={max_prefill_token_num}') + + if spec_cache_config is not None: + spec_cache_config.max_prefill_token_num = max_prefill_token_num + free_mem -= runtime_mem logger.debug(f'estimated max runtime memory: {runtime_mem >> 20} mb') available_mem = free_mem * cache_config.cache_max_entry_count @@ -180,8 +199,11 @@ def update_configs(self): cache_config.num_gpu_blocks = int(available_mem / cache_block_size) if cache_config.num_gpu_blocks <= 0: raise RuntimeError('No enough gpu memory for kv cache.') - self.set_cache_config(cache_config) - self.set_model_config(model_config) + if spec_cache_config is not None: + spec_cache_config.num_gpu_blocks = cache_config.num_gpu_blocks + + self.set_cache_config(cache_config, spec_cache_config) + self.set_model_config(model_config, spec_model_config) def init(self): """init.""" @@ -192,6 +214,9 @@ def init(self): logger.info('Building GraphRunner and warmup ops, please waiting.') self.build_graph_runner() logger.info(f'Building CacheEngine with config: \n{self.cache_config}.') + if self.specdecode_config: + if spec_cache_config := self.specdecode_config.cache_config: + logger.info(f'Building Spec CacheEngine with config: \n{spec_cache_config}.') self.build_cache_engine() logger.info('Warming up model.') self.warmup() diff --git a/lmdeploy/pytorch/engine/executor/base_worker.py b/lmdeploy/pytorch/engine/executor/base_worker.py index 56d8d5f58b..77b97b0ac4 100644 --- a/lmdeploy/pytorch/engine/executor/base_worker.py +++ b/lmdeploy/pytorch/engine/executor/base_worker.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional from lmdeploy.pytorch.backends.selector import get_backend -from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig +from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig from lmdeploy.pytorch.devices import DeviceContext from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch @@ -31,6 +31,7 @@ def __init__( adapters: Dict[str, str] = None, device_type: str = 'cuda', log_level: int = 30, + specdecode_config: SpecDecodeConfig = None, ): self.model_path = model_path self.model_config = model_config @@ -45,7 +46,7 @@ def __init__( self.tp = dist_config.tp self.world_size = dist_config.world_size self.device_type = device_type - + self.specdecode_config = specdecode_config logger.setLevel(log_level) self.out_que: asyncio.Queue = None self._output_loop: asyncio.Task = None @@ -94,27 +95,30 @@ def build_model(self): """Build model.""" self.device_ctx = DeviceContext(device_type=self.device_type) - self.model_agent = build_model_agent(model_path=self.model_path, - model_config=self.model_config, - cache_config=self.cache_config, - backend_config=self.backend_config, - misc_config=self.misc_config, - device_ctx=self.device_ctx, - dist_ctx=self.dist_ctx, - adapters=self.adapters) + self.model_agent = build_model_agent( + model_path=self.model_path, + model_config=self.model_config, + cache_config=self.cache_config, + backend_config=self.backend_config, + misc_config=self.misc_config, + device_ctx=self.device_ctx, + dist_ctx=self.dist_ctx, + adapters=self.adapters, + specdecode_config=self.specdecode_config, + ) self.model_agent.build_model() def get_free_mem(self): """Gather free mem.""" return self.model_agent.get_free_mem() - def set_cache_config(self, cache_config: CacheConfig): + def 
set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None): """Set all cache config.""" - self.model_agent.set_cache_config(cache_config) + self.model_agent.set_cache_config(cache_config, spec_cache_config) - def set_model_config(self, model_config: ModelConfig): + def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None): """Set all model config.""" - self.model_agent.set_model_config(model_config) + self.model_agent.set_model_config(model_config, spec_model_config) def build_graph_runner(self): """Build graph runner.""" diff --git a/lmdeploy/pytorch/engine/executor/mp_executor.py b/lmdeploy/pytorch/engine/executor/mp_executor.py index 18aa65028d..f12e0cdcc9 100644 --- a/lmdeploy/pytorch/engine/executor/mp_executor.py +++ b/lmdeploy/pytorch/engine/executor/mp_executor.py @@ -15,7 +15,7 @@ import torch.multiprocessing as mp from lmdeploy.pytorch.backends.selector import init_backend -from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig +from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig from lmdeploy.utils import get_logger, try_import_deeplink from .base import ExecutorBase @@ -225,6 +225,7 @@ def __init__(self, dist_config: DistConfig, misc_config: MiscConfig, adapters: Dict[str, str] = None, + specdecode_config: SpecDecodeConfig = None, device_type: str = 'cuda'): """Initialize Executor.""" super().__init__(model_path=model_path, @@ -233,6 +234,7 @@ def __init__(self, backend_config=backend_config, dist_config=dist_config, misc_config=misc_config, + specdecode_config=specdecode_config, adapters=adapters, device_type=device_type) @@ -264,6 +266,7 @@ def __init__(self, backend_config=backend_config, dist_config=dist_config, misc_config=misc_config, + specdecode_config=specdecode_config, adapters=adapters, device_type=device_type, log_level=logger.level) @@ -350,13 +353,13 @@ def gather_free_mem(self): ret = self.collective_rpc('get_free_mem') return ret - def set_cache_config(self, cache_config: CacheConfig): + def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None): """Set all cache config.""" - self.collective_rpc('set_cache_config', args=(cache_config, )) + self.collective_rpc('set_cache_config', args=(cache_config, spec_cache_config)) - def set_model_config(self, model_config: ModelConfig): + def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None): """Set all cache config.""" - self.collective_rpc('set_model_config', args=(model_config, )) + self.collective_rpc('set_model_config', args=(model_config, spec_model_config)) def build_graph_runner(self): """Build graph runner.""" @@ -425,6 +428,7 @@ def __init__( model_config: ModelConfig, dist_config: DistConfig, misc_config: MiscConfig, + specdecode_config: SpecDecodeConfig = None, adapters: Dict[str, str] = None, device_type: str = 'cuda', log_level: int = 30, @@ -436,6 +440,7 @@ def __init__( model_config=model_config, dist_config=dist_config, misc_config=misc_config, + specdecode_config=specdecode_config, adapters=adapters, device_type=device_type, log_level=log_level, @@ -486,6 +491,7 @@ def _main_loop( backend_config: BackendConfig, dist_config: DistConfig, misc_config: MiscConfig, + specdecode_config: SpecDecodeConfig = None, adapters: Dict[str, str] = None, device_type: str = 'cuda', log_level: int = 30, @@ -507,6 +513,7 @@ def handle_sigterm(signum, frame): 
model_config=model_config, dist_config=dist_config, misc_config=misc_config, + specdecode_config=specdecode_config, adapters=adapters, device_type=device_type, log_level=log_level) diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py index 327d56a5ca..448618f477 100644 --- a/lmdeploy/pytorch/engine/executor/ray_executor.py +++ b/lmdeploy/pytorch/engine/executor/ray_executor.py @@ -13,7 +13,7 @@ from lmdeploy.pytorch import envs as _envs from lmdeploy.pytorch.backends.selector import init_backend -from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig +from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig from lmdeploy.pytorch.devices import DeviceContext, get_device_manager from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch @@ -150,21 +150,18 @@ def __init__( model_path: str, cache_config: CacheConfig, backend_config: BackendConfig, + model_config: ModelConfig, dist_config: DistConfig, misc_config: MiscConfig, adapters: Dict[str, str] = None, device_type: str = 'cuda', dtype: str = 'auto', log_level: int = 30, + specdecode_config: SpecDecodeConfig = None, ): init_backend(device_type) try_import_deeplink(device_type) - model_config = ModelConfig.from_pretrained(model_path, - dtype=dtype, - hf_overrides=misc_config.hf_overrides, - dist_config=dist_config) - super().__init__( model_path=model_path, cache_config=cache_config, @@ -175,6 +172,7 @@ def __init__( adapters=adapters, device_type=device_type, log_level=log_level, + specdecode_config=specdecode_config, ) self.node_ip = ray.util.get_node_ip_address() self._remote_logger = RemoteLogger() @@ -222,25 +220,31 @@ def exit(self): class RayExecutor(ExecutorBase): """Ray executor.""" - def __init__(self, - model_path: str, - model_config: ModelConfig, - cache_config: CacheConfig, - backend_config: BackendConfig, - dist_config: DistConfig, - misc_config: MiscConfig, - adapters: Dict[str, str] = None, - device_type: str = 'cuda', - dtype: str = 'auto'): + def __init__( + self, + model_path: str, + model_config: ModelConfig, + cache_config: CacheConfig, + backend_config: BackendConfig, + dist_config: DistConfig, + misc_config: MiscConfig, + adapters: Dict[str, str] = None, + device_type: str = 'cuda', + dtype: str = 'auto', + specdecode_config: SpecDecodeConfig = None, + ): """Initialize Executor.""" - super().__init__(model_path=model_path, - model_config=model_config, - cache_config=cache_config, - backend_config=backend_config, - dist_config=dist_config, - misc_config=misc_config, - adapters=adapters, - device_type=device_type) + super().__init__( + model_path=model_path, + model_config=model_config, + cache_config=cache_config, + backend_config=backend_config, + dist_config=dist_config, + misc_config=misc_config, + adapters=adapters, + device_type=device_type, + specdecode_config=specdecode_config, + ) self.dp_rank = dist_config.dp_rank device_ctx = DeviceContext(device_type) @@ -266,6 +270,7 @@ def __init__(self, worker_kwargs = dict( model_path=model_path, cache_config=cache_config, + model_config=model_config, backend_config=backend_config, dist_config=dist_config, misc_config=misc_config, @@ -273,6 +278,7 @@ def __init__(self, device_type=device_type, dtype=dtype, log_level=logger.level, + specdecode_config=specdecode_config, ) logger.info('Init ray 
workers.') @@ -317,13 +323,13 @@ def gather_free_mem(self): """Gather available memory.""" return self.collective_rpc('get_free_mem') - def set_cache_config(self, cache_config: CacheConfig): + def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None): """Set all cache config.""" - self.collective_rpc('set_cache_config', (cache_config, )) + self.collective_rpc('set_cache_config', (cache_config, spec_cache_config)) - def set_model_config(self, model_config: ModelConfig): + def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None): """Set all model config.""" - self.collective_rpc('set_model_config', (model_config, )) + self.collective_rpc('set_model_config', (model_config, spec_model_config)) def build_graph_runner(self): """Build graph runner.""" diff --git a/lmdeploy/pytorch/engine/executor/uni_executor.py b/lmdeploy/pytorch/engine/executor/uni_executor.py index 8e7fab9ca2..9cf86dd915 100644 --- a/lmdeploy/pytorch/engine/executor/uni_executor.py +++ b/lmdeploy/pytorch/engine/executor/uni_executor.py @@ -2,7 +2,7 @@ import asyncio from typing import Dict, List -from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig +from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SpecDecodeConfig from lmdeploy.pytorch.devices import DeviceContext from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch @@ -17,14 +17,17 @@ class UniExecutor(ExecutorBase): """Single node single device Executor.""" - def __init__(self, - model_path: str, - model_config: ModelConfig, - cache_config: CacheConfig, - backend_config: BackendConfig, - misc_config: MiscConfig, - adapters: Dict[str, str] = None, - device_type: str = 'cuda'): + def __init__( + self, + model_path: str, + model_config: ModelConfig, + cache_config: CacheConfig, + backend_config: BackendConfig, + misc_config: MiscConfig, + adapters: Dict[str, str] = None, + device_type: str = 'cuda', + specdecode_config: SpecDecodeConfig = None, + ): """Initialize Executor.""" super().__init__(model_path=model_path, model_config=model_config, @@ -33,16 +36,20 @@ def __init__(self, dist_config=DistConfig(), misc_config=misc_config, adapters=adapters, - device_type=device_type) + device_type=device_type, + specdecode_config=specdecode_config) self.device_ctx = DeviceContext(device_type=device_type) - self.model_agent = build_model_agent(model_path=model_path, - model_config=model_config, - cache_config=cache_config, - backend_config=backend_config, - misc_config=misc_config, - device_ctx=self.device_ctx, - adapters=adapters) + self.model_agent = build_model_agent( + model_path=model_path, + model_config=model_config, + cache_config=cache_config, + backend_config=backend_config, + misc_config=misc_config, + device_ctx=self.device_ctx, + adapters=adapters, + specdecode_config=specdecode_config, + ) def download_models(self): """Download model.""" @@ -56,13 +63,13 @@ def gather_free_mem(self): """Gather available memory.""" return [self.model_agent.get_free_mem()] - def set_cache_config(self, cache_config: CacheConfig): + def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None): """Set all cache config.""" - self.model_agent.set_cache_config(cache_config) + self.model_agent.set_cache_config(cache_config, spec_cache_config) - def set_model_config(self, 
model_config: ModelConfig): + def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig): """Set all cache config.""" - self.model_agent.set_model_config(model_config) + self.model_agent.set_model_config(model_config, spec_model_config) def build_graph_runner(self): """Build graph runner.""" diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 2ba2850c75..a53cf93e53 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -20,11 +20,12 @@ from lmdeploy.utils import get_logger from ..backends import get_backend -from ..config import BackendConfig, CacheConfig, MiscConfig, ModelConfig +from ..config import BackendConfig, CacheConfig, MiscConfig, ModelConfig, SpecDecodeConfig from ..devices import DeviceContext, get_device_manager from ..distributed import DistContext, get_dist_manager from ..model_inputs import ModelInputs, step_ctx_manager from ..models.patch import BuildModelContext, add_adapters, build_patched_model, update_custom_module_map +from ..spec_decode import SpecModelAgent from ..strategies import build_strategy_factory from ..strategies.base.model_agent import ExtraInputs, ExtraOutputs, StoppingCriteria from ..utils import get_gpu_memory @@ -223,6 +224,7 @@ def model_forward( inputs: ModelInputs, cache_engine: CacheEngine, stream: torch.cuda.Stream = None, + output_position_ids: bool = False, ): """Perform model forward.""" stream = stream or torch.cuda.current_stream() @@ -235,6 +237,7 @@ def model_forward( kv_caches=cache_engine.gpu_cache, kv_quant_policy=cache_engine.cache_config.quant_policy, ) + with ctx_mgr.context(context): model_metas = None model_metas = model.update_model_metas( @@ -246,12 +249,13 @@ def model_forward( context=context, ) output = model(**input_dict) - + if not isinstance(output, dict): + output = dict(hidden_states=output) # InternVL-3.5-Flash will change the seqlen, model_metas during forward - model_metas = context.model_metas - seq_length = context.q_seqlens[:len(inputs.seq_length)] - - return dict(hidden_states=output, model_metas=model_metas, seq_length=seq_length) + output.update(dict(model_metas=model_metas, seq_length=context.q_seqlens[:len(inputs.seq_length)])) + if output_position_ids: + output.update(dict(position_ids=context.position_ids)) + return output def _try_to_cuda(val, non_blocking: bool = False): @@ -301,15 +305,18 @@ class BaseModelAgent: trust_remote_code (bool): Trust remote code """ - def __init__(self, - model_path: str, - model_config: ModelConfig, - cache_config: CacheConfig, - backend_config: BackendConfig, - misc_config: MiscConfig, - dist_ctx: DistContext, - device_ctx: DeviceContext, - adapters: Dict[str, str] = None): + def __init__( + self, + model_path: str, + model_config: ModelConfig, + cache_config: CacheConfig, + backend_config: BackendConfig, + misc_config: MiscConfig, + dist_ctx: DistContext, + device_ctx: DeviceContext, + adapters: Dict[str, str] = None, + specdecode_config: SpecDecodeConfig = None, + ): self.model_config = model_config self.cache_config = cache_config @@ -345,7 +352,6 @@ def __init__(self, self.tp = tp self.world_size = world_size self.tp_rank = tp_rank - self.patched_model = None self.cache_engine = None self.profiler: AgentProfiler = None @@ -365,10 +371,23 @@ def __init__(self, int(getenv('ENABLE_MICROBATCH_DECODE_BATCHSIZE_THRESHOLD', 2)) # strategy - self.strategy_factory = build_strategy_factory(model_config, misc_config) + self.strategy_factory = 
build_strategy_factory(model_config, misc_config, specdecode_config=specdecode_config) self.inputs_strategy = self.strategy_factory.build_model_inputs_strategy() self.agent_strategy = self.strategy_factory.build_model_agent_strategy() + # spec decoding + self.spec_agent = None + self.specdecode_config = specdecode_config + + # only support spec model with tp1 + if specdecode_config is not None: + self.spec_agent = SpecModelAgent(specdecode_config, + backend_config, + dist_ctx, + self.inputs_strategy, + self.agent_strategy, + device=device) + @contextmanager def all_context(self): device_mgr = get_device_manager() @@ -376,13 +395,17 @@ def all_context(self): with device_mgr.context(self.device_ctx), dist_mgr.context(self.dist_ctx): yield - def set_cache_config(self, cache_config: CacheConfig): + def set_cache_config(self, cache_config: CacheConfig, spec_cache_config: CacheConfig = None): """Set all cache config.""" self.cache_config = cache_config + if self.spec_agent is not None: + self.spec_agent.set_cache_config(spec_cache_config) - def set_model_config(self, model_config: ModelConfig): + def set_model_config(self, model_config: ModelConfig, spec_model_config: ModelConfig = None): """Set model config.""" self.model_config = model_config + if self.spec_agent is not None: + self.spec_agent.set_model_config(spec_model_config) def get_free_mem(self): """Gather available memory.""" @@ -394,8 +417,9 @@ def get_free_mem(self): def warmup(self): """warmup.""" # TODO: disable for now, do not remove the comments. - with self.all_context(): + with self.all_context(), torch.cuda.stream(self.stream), torch.inference_mode(): max_batches = self.cache_config.max_batches + num_tokens = max_batches dist_ctx = get_dist_manager().current_context() dp = dist_ctx.dp @@ -420,6 +444,10 @@ def warmup(self): inputs.build_dp_meta() self._forward_impl(inputs) + # warmup draft model + if self.spec_agent is not None: + self.spec_agent.warmup(max_batches, self.model_config) + def _slice_outs(self, inputs: torch.Tensor, seq_length: torch.LongTensor): """Slice outputs.""" return self.agent_strategy.slice_outputs(inputs, seq_length) @@ -450,6 +478,8 @@ def __init__(self, max_seq_len): self._start = 0 self._output: torch.Tensor = None self._device: torch.device = None + # aux hidden states for eagle3 + self._aux_output: torch.Tensor = None def gather(self, output): """gather.""" @@ -466,6 +496,16 @@ def gather(self, output): out_logits = tmp_output.new_empty(1, self._max_seq_len, tmp_output.size(-1), device='cpu') self._device = tmp_output.device out_logits[:, start:start + seq_len].copy_(tmp_output, non_blocking=True) + + # egale3 + if 'aux_hidden_states' in output: + tmp_aux = output['aux_hidden_states'] + aux_out = self._aux_output + if aux_out is None: + aux_out = tmp_aux.new_empty(1, self._max_seq_len, tmp_aux.size(-1), device='cpu') + aux_out[:, start:start + seq_len].copy_(tmp_aux, non_blocking=True) + self._aux_output = aux_out + self._start = start + seq_len self._output = out_logits @@ -476,9 +516,11 @@ def get_output(self): self._output.numel() // self._output.size(-1), device=self._output.device, dtype=self._output.dtype) - return strategy.slice_outputs(self._output, seqlen) + return strategy.slice_outputs(self._output, seqlen), self._aux_output torch.cuda.synchronize() - return self._output.to(self._device) + if self._aux_output is not None: + self._aux_output = self._aux_output.to(self._device) + return self._output.to(self._device), self._aux_output __forward = self.async_forward @@ -496,13 +538,18 @@ async 
def __long_context_single_forward(new_inputs, max_seqlen: int): model_metas = tmp_out.get('model_metas') output_gather.gather(tmp_out) tmp_out.pop('hidden_states', None) - tmp_out['hidden_states'] = output_gather.get_output() + tmp_out.pop('aux_hidden_states', None) + tmp_out.pop('position_ids', None) + + tmp_out['hidden_states'], aux_hidden_states = output_gather.get_output() + if aux_hidden_states is not None: + tmp_out['aux_hidden_states'] = aux_hidden_states return tmp_out origin_inputs = inputs - # make long context inputs is_long_context = inputs.input_ids.numel() > max_prefill_token_num and not inputs.is_decoding + max_seqlen = 0 if is_long_context: seq_len = inputs.seq_length @@ -535,7 +582,10 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int): for _ in range(dummy_loop): await __forward(dummy_inputs) - hidden_states = ret.pop('hidden_states') + if self.spec_agent is not None: + hidden_states, ret = self.spec_agent.update_main_model_outputs(ret, origin_inputs) + else: + hidden_states = ret.pop('hidden_states') logits = self.get_logits(hidden_states) ret['logits'] = logits return ret @@ -702,7 +752,11 @@ async def __prepare_dp(): seq_length = inputs.seq_length seq_length = output.get('seq_length', inputs.seq_length) last_logits = self._slice_outs(logits, seq_length) # [bs, 1, prob] -> [bs, prob] - extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, seq_length) + is_last_step = (idx == loop_count - 1) + extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, + inputs, + output, + is_last_step=is_last_step) model_metas = output.get('model_metas') # output empty for dummy inputs @@ -720,6 +774,12 @@ async def __prepare_dp(): # post sampling next_token_ids, extra_inputs = self.agent_strategy.post_sampling(inputs, last_logits, next_token_ids, extra_inputs) + # spec decoding + if self.spec_agent is not None: + extra_inputs = await self.spec_agent.async_model_forward(next_token_ids, inputs, extra_inputs, + sampling_inputs) + next_token_ids = extra_inputs.next_token_ids + logits = None with self._broadcast_next_token(next_token_ids, extra_inputs, enable=need_broadcast_next): logger.debug(f' rank[{rank}]: synchronize token ids [{idx}]') @@ -734,13 +794,14 @@ async def __prepare_dp(): logger.debug(f' rank[{rank}]: Output [{idx}]') extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs) self._push_output( - BatchedOutputs(next_token_ids=next_token_ids, - logits=logits if return_logits else None, - stopped=stopped, - stop_pos=stop_pos, - model_metas=model_metas, - logprobs=logprobs, - extra_outputs=extra_outputs)) + BatchedOutputs( + next_token_ids=next_token_ids if self.spec_agent is None else extra_inputs.output_token_ids, + logits=logits if return_logits else None, + stopped=stopped, + stop_pos=stop_pos, + model_metas=model_metas, + logprobs=logprobs, + extra_outputs=extra_outputs)) else: # Avoid adding the ADInplaceOrView dispatch key to `next_token_ids`, # as it can trigger recompilation on different ranks when using torch.compile. 
@@ -943,6 +1004,11 @@ def build_model(self): """Build model api.""" with self.all_context(): self._build_model() + if self.spec_agent is not None: + self.spec_agent.build_model(self.misc_config.empty_init, + self.patched_model, + model_format=self.misc_config.model_format, + build_model_ctx=self.build_model_ctx) def build_graph_runner(self): """Build graph runner.""" @@ -953,6 +1019,8 @@ def build_graph_runner(self): cache_config=self.cache_config, backend_config=self.backend_config, device=self.device) + if self.spec_agent is not None: + self.spec_agent.build_graph_runner() def build_cache_engine(self): """Build cache engine.""" @@ -967,6 +1035,8 @@ def build_cache_engine(self): tp_rank=self.tp_rank, world_size=tp, cache_stream=self.cache_stream) + if self.spec_agent is not None: + self.spec_agent.build_cache_engine(self.cache_stream) def _forward_impl(self, inputs: ModelInputs): output = model_forward( @@ -974,6 +1044,7 @@ def _forward_impl(self, inputs: ModelInputs): inputs, self.cache_engine, stream=self.stream, + output_position_ids=self.spec_agent is not None, ) return output @@ -1003,6 +1074,10 @@ def reset_graph_runner(self): if hasattr(self.patched_model, 'reset'): self.patched_model.reset() + if self.spec_agent is not None: + if self.spec_agent.proposer.model is not None and hasattr(self.spec_agent.proposer.model, 'reset'): + self.spec_agent.proposer.model.reset() + @torch.inference_mode() def update_params(self, request: UpdateParamsRequest): """Update params.""" @@ -1166,14 +1241,17 @@ def step(self): self._ready_event.record() -def build_model_agent(model_path: str, - model_config: ModelConfig, - cache_config: CacheConfig, - backend_config: BackendConfig, - misc_config: MiscConfig, - dist_ctx: DistContext = None, - device_ctx: DeviceContext = None, - adapters: Dict[str, str] = None): +def build_model_agent( + model_path: str, + model_config: ModelConfig, + cache_config: CacheConfig, + backend_config: BackendConfig, + misc_config: MiscConfig, + dist_ctx: DistContext = None, + device_ctx: DeviceContext = None, + adapters: Dict[str, str] = None, + specdecode_config: SpecDecodeConfig = None, +): """Create model agent. 
Args: @@ -1202,5 +1280,6 @@ def build_model_agent(model_path: str, adapters=adapters, dist_ctx=dist_ctx, device_ctx=device_ctx, + specdecode_config=specdecode_config, ) return model_agent diff --git a/lmdeploy/pytorch/engine/mp_engine/zmq_engine.py b/lmdeploy/pytorch/engine/mp_engine/zmq_engine.py index abdbe20ec5..47ed1d1a12 100644 --- a/lmdeploy/pytorch/engine/mp_engine/zmq_engine.py +++ b/lmdeploy/pytorch/engine/mp_engine/zmq_engine.py @@ -6,7 +6,7 @@ import torch.multiprocessing as mp -from lmdeploy.messages import PytorchEngineConfig +from lmdeploy.messages import PytorchEngineConfig, SpeculativeConfig from lmdeploy.utils import get_logger from .base import MPEngine @@ -29,20 +29,29 @@ def cancel_async_tasks(loop: asyncio.AbstractEventLoop): class ZMQMPEngine(MPEngine): - def __init__(self, model_path: str, engine_config: PytorchEngineConfig = None, **kwargs) -> None: + def __init__(self, + model_path: str, + engine_config: PytorchEngineConfig = None, + speculative_config: SpeculativeConfig = None, + **kwargs) -> None: """Initialize mp engine.""" from .zmq_rpc import AsyncRPCClient self.shared_dict = None self.port = None self.proc = None - self._start_mp_proc(model_path, engine_config) + self._start_mp_proc(model_path, engine_config, speculative_config=speculative_config) self.rpc_client = AsyncRPCClient(port=self.port) super().__init__() atexit.register(self.close) - def _start_mp_proc(self, model_path: str, engine_config: PytorchEngineConfig = None): + def _start_mp_proc( + self, + model_path: str, + engine_config: PytorchEngineConfig = None, + speculative_config: SpeculativeConfig = None, + ): """Start mp proc.""" logger.debug('Starting engine multi-process.') with mp.Manager() as manager: @@ -57,6 +66,7 @@ def _start_mp_proc(self, model_path: str, engine_config: PytorchEngineConfig = N model_path=model_path, engine_config=engine_config, log_level=log_level, + speculative_config=speculative_config, )), name='mp_engine_proc', ) @@ -68,11 +78,14 @@ def _start_mp_proc(self, model_path: str, engine_config: PytorchEngineConfig = N self.port = self.shared_dict['rpc_server_port'] @staticmethod - def _mp_proc(shared_dict: dict, - condition: mp.Condition, - model_path: str, - engine_config: PytorchEngineConfig = None, - log_level: str = 'WARNING'): + def _mp_proc( + shared_dict: dict, + condition: mp.Condition, + model_path: str, + engine_config: PytorchEngineConfig = None, + log_level: str = 'WARNING', + speculative_config: SpeculativeConfig = None, + ): """Mp process function.""" from lmdeploy.pytorch.engine import Engine @@ -92,6 +105,7 @@ def _mp_proc(shared_dict: dict, engine = Engine.from_pretrained( model_path, engine_config=engine_config, + speculative_config=speculative_config, ) loop = asyncio.new_event_loop() diff --git a/lmdeploy/pytorch/kernels/cuda/flash_mla.py b/lmdeploy/pytorch/kernels/cuda/flash_mla.py index 1a3209edeb..69a7c28d1c 100644 --- a/lmdeploy/pytorch/kernels/cuda/flash_mla.py +++ b/lmdeploy/pytorch/kernels/cuda/flash_mla.py @@ -42,4 +42,4 @@ def flash_mla_fwd( softmax_scale, causal, ) - return out.squeeze(1) + return out.flatten(0, 1) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 13e35fd1ae..1200684c0e 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -144,6 +144,8 @@ class ModelInputs: model_metas: List[Dict[str, Any]] = None dp_meta: 'DPMeta' = None enable_microbatch: bool = False + target_hidden_states: torch.Tensor = None + target_position_ids: torch.Tensor = None def step(self, 
input_ids: torch.LongTensor, step_seqlens: torch.Tensor = None): """Update input ids.""" @@ -242,22 +244,27 @@ def __make_next_vision_inputs(flatten_mms: List, start: int): if isinstance(max_q_seqlen, torch.Tensor): max_q_seqlen = max_q_seqlen.item() max_kv_seqlen += max_q_seqlen - inp = ModelInputs( - input_ids=self.input_ids[:, start:end], - seq_length=input_ids.new_tensor([end - start]), - block_offsets=self.block_offsets, - history_lengths=self.history_lengths + start, - is_decoding=self.is_decoding, - num_ignored_history=self.num_ignored_history, - max_q_seqlen=max_q_seqlen, - max_kv_seqlen=max_kv_seqlen, - sum_kv_seqlen=max_kv_seqlen, - local_adapter_ids=self.local_adapter_ids, - vision_inputs=vision_inputs, - model_metas=self.model_metas, - cross_length=cross_length, - history_cross_length=history_cross_length, - ) + + target_hidden_states = self.target_hidden_states[:, start: + end] if self.target_hidden_states is not None else None + target_position_ids = self.target_position_ids[:, + start:end] if self.target_position_ids is not None else None + inp = ModelInputs(input_ids=self.input_ids[:, start:end], + seq_length=input_ids.new_tensor([end - start]), + block_offsets=self.block_offsets, + history_lengths=self.history_lengths + start, + is_decoding=self.is_decoding, + num_ignored_history=self.num_ignored_history, + max_q_seqlen=max_q_seqlen, + max_kv_seqlen=max_kv_seqlen, + sum_kv_seqlen=max_kv_seqlen, + local_adapter_ids=self.local_adapter_ids, + vision_inputs=vision_inputs, + model_metas=self.model_metas, + cross_length=cross_length, + history_cross_length=history_cross_length, + target_hidden_states=target_hidden_states, + target_position_ids=target_position_ids) ret.append(inp) history_cross_length = cross_length @@ -308,6 +315,7 @@ class StepContext: kv_caches: List is_decoding: bool sum_kv_seqlen: int + max_kv_seqlen: int = None local_adapter_ids: torch.LongTensor = None input_embeddings: torch.Tensor = None input_embedding_indexing: torch.Tensor = None @@ -321,6 +329,8 @@ class StepContext: model_metas: List[Dict[str, Any]] = None dp_meta: DPMeta = None enable_microbatch: bool = False + # for draft model + target_hidden_states: torch.Tensor = None _outputs: Dict = field(default_factory=dict) @@ -353,7 +363,6 @@ def new( # position ids attention_mask, position_ids = cls.get_mask_and_position_ids(inputs) - position_ids = position_ids[None] # [num_tokens] -> [1, num_tokens] q_start_loc = q_seqlens.cumsum(0) - q_seqlens # cross @@ -381,6 +390,7 @@ def new( kv_caches=kv_caches, is_decoding=inputs.is_decoding, sum_kv_seqlen=inputs.sum_kv_seqlen, + max_kv_seqlen=inputs.max_kv_seqlen, local_adapter_ids=inputs.local_adapter_ids, vision_inputs=inputs.vision_inputs, kv_quant_policy=kv_quant_policy, @@ -389,6 +399,7 @@ def new( cross_kv_seqlens=cross_kv_seqlens, dp_meta=inputs.dp_meta, enable_microbatch=inputs.enable_microbatch, + target_hidden_states=inputs.target_hidden_states, ) ret = get_backend().update_step_context(ret) @@ -400,12 +411,14 @@ def get_mask_and_position_ids(cls, inputs: ModelInputs): q_seqlens = inputs.seq_length history_seqlens = inputs.history_lengths max_q_seqlen = inputs.max_q_seqlen - + target_position_ids = inputs.target_position_ids # decoding if max_q_seqlen == 1: attention_mask = torch.ones_like(q_seqlens)[:, None] - position_ids = history_seqlens.unsqueeze(-1).clone() - position_ids = position_ids.flatten() + if target_position_ids is not None: + position_ids = target_position_ids + else: + position_ids = history_seqlens.unsqueeze(0).clone() return 
attention_mask, position_ids num_tokens = inputs.input_ids.numel() @@ -418,11 +431,13 @@ def get_mask_and_position_ids(cls, inputs: ModelInputs): ranges = torch.arange(0, max_q_seqlen, device=device) position_ids = history_seqlens[:, None] + ranges[None, :] position_ids = position_ids.flatten() - return attention_mask, position_ids + return attention_mask, position_ids[None] # get mask mask_range = torch.arange(max_q_seqlen, device=device)[None, :] attention_mask = (mask_range < q_seqlens[:, None]).long() + if target_position_ids is not None: + return attention_mask, target_position_ids # position_ids indices = attention_mask.long().cumsum(-1) - 1 @@ -430,7 +445,8 @@ def get_mask_and_position_ids(cls, inputs: ModelInputs): indices[1:] += q_seqlens.cumsum(0)[:-1, None] position_ids_1d = position_ids.new_empty(num_tokens) position_ids_1d[indices.flatten()] = position_ids.flatten() - return attention_mask, position_ids_1d + position_ids = position_ids_1d[None] + return attention_mask, position_ids @dataclass diff --git a/lmdeploy/pytorch/models/deepseek_mtp.py b/lmdeploy/pytorch/models/deepseek_mtp.py new file mode 100644 index 0000000000..3abe1e5971 --- /dev/null +++ b/lmdeploy/pytorch/models/deepseek_mtp.py @@ -0,0 +1,760 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding, + build_rotary_params) +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_down_linear, build_gateup_linear, build_o_proj, + build_rowwise_linear) +from lmdeploy.pytorch.nn.moe import build_fused_moe +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight +from lmdeploy.utils import get_logger + +from .deepseek_v2 import DeepseekV2Attention, DeepseekV2DecoderLayer, MoEGate, yarn_get_mscale +from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin + +logger = get_logger('lmdeploy') + + +class DeepseekV2BMM(nn.Module): + """Wrapped bmm.""" + + def __init__(self, batch: int, in_features: int, out_features: int, dtype: torch.dtype, device: torch.device): + super().__init__() + + weight = self.create_weight(batch, in_features, out_features, dtype=dtype, device=device) + weight = torch.nn.Parameter(weight, requires_grad=False) + self.register_parameter('weight', weight) + weight.weight_loader = self.weight_loader + + self.batch = batch + self.in_features = in_features + self.out_features = out_features + self.dtype = dtype + self.device = device + + def create_weight(self, batch: int, in_features: int, out_features: int, dtype: torch.dtype, device: torch.device): + """Create weight.""" + return torch.empty((batch, in_features, out_features), dtype=dtype, device=device) + + def weight_loader(self, param: nn.Parameter, weight: torch.Tensor): + """Weight loader.""" + param.data.copy_(weight) + + def forward(self, x: torch.Tensor, output: torch.Tensor): + """forward.""" + torch.bmm(x.transpose(0, 1), self.weight, out=output.transpose(0, 1)) + + +class DeepseekV2Attention(DeepseekV2Attention): + """Deepseekv2 attention.""" + + def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None): + nn.Module.__init__(self) + quantization_config = getattr(config, 'quantization_config', None) + self.q_lora_rank = 
config.q_lora_rank + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1) + num_key_value_heads = getattr(config, 'num_key_value_heads', 1) + use_flash_mla = getattr(config, 'use_flash_mla', False) + + if self.q_lora_rank is None: + self.q_proj = build_colwise_linear( + self.hidden_size, + self.num_heads * self.q_head_dim, + bias=False, + dtype=dtype, + device=device, + is_tp=False, + quant_config=quantization_config, + dp_disable_tp=True, + ) + else: + self.q_a_proj = build_colwise_linear( + self.hidden_size, + config.q_lora_rank, + bias=config.attention_bias, + dtype=dtype, + device=device, + is_tp=False, + quant_config=quantization_config, + ) + self.q_a_layernorm = RMSNorm(config.q_lora_rank, + 1e-6, + quant_config=quantization_config, + dtype=dtype, + device=device) + self.q_b_proj = build_colwise_linear( + config.q_lora_rank, + self.num_heads * self.q_head_dim, + bias=False, + dtype=dtype, + device=device, + is_tp=False, + quant_config=quantization_config, + dp_disable_tp=True, + ) + + self.kv_a_proj_with_mqa = build_colwise_linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + dtype=dtype, + device=device, + is_tp=False, + quant_config=quantization_config, + ) + self.kv_a_layernorm = RMSNorm(config.kv_lora_rank, + 1e-6, + quant_config=quantization_config, + dtype=dtype, + device=device) + self.kc = DeepseekV2BMM(self.num_heads, + config.qk_nope_head_dim, + config.kv_lora_rank, + dtype=dtype, + device=device) + + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + self.softmax_scale = self.q_head_dim**(-0.5) + + if config.rope_scaling is not None: + mscale_all_dim = config.rope_scaling.get('mscale_all_dim', 0) + scaling_factor = config.rope_scaling['factor'] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + self.attn_fwd = Attention(self.num_heads, + config.kv_lora_rank + self.qk_rope_head_dim, + scale=self.softmax_scale, + num_kv_heads=num_key_value_heads, + v_head_size=config.kv_lora_rank, + num_replicate_kv_heads=num_replicate_kv_heads, + use_flash_mla=use_flash_mla) + + self.vc = DeepseekV2BMM(self.num_heads, config.kv_lora_rank, self.v_head_dim, dtype=dtype, device=device) + self.o_proj = build_o_proj( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + dtype=dtype, + device=device, + is_tp=False, + quant_config=quantization_config, + ) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + num_heads = self.num_heads + nope_size = self.kv_lora_rank + q_len = hidden_states.size(1) + + # qkv_proj + query_states, key_states, value_states, q_pe, k_pe = self._qkv_proj(hidden_states, num_heads=num_heads) + + cos, sin = rotary_pos_emb + q_pe, k_pe = self.apply_rotary_pos_emb( + q_pe, + k_pe, + cos, + sin, + inplace=False, + ) + query_states[..., nope_size:] = q_pe + key_states[..., nope_size:] = k_pe + + attn_output = self.attn_fwd( + query_states, + 
key_states, + value_states, + past_key_value[0], + past_key_value[0][..., :nope_size], + attn_metadata, + k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2], + v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3], + inplace=True, + ) + attn_bmm_out = attn_output.new_empty(q_len, num_heads, self.v_head_dim) + + self.vc(attn_output, attn_bmm_out) + attn_output = attn_bmm_out.flatten(-2, -1)[None] + attn_output = self.o_proj(attn_output) + + return attn_output + + +class DeepseekV2MoE(nn.Module): + """Deepseek v2 MoE.""" + + def __init__(self, config: Any, layer_idx, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.moe_intermediate_size + self.num_experts = config.n_routed_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + self.renormalize = self.top_k > 1 and self.norm_topk_prob + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + + self.gate = MoEGate(config, dtype=dtype, device=device, info=None) + self.experts = build_fused_moe( + self.hidden_dim, + self.ffn_dim, + self.num_experts, + top_k=self.top_k, + renormalize=False, + dtype=dtype, + device=device, + all_reduce=False, + quant_config=quantization_config, + layer_idx=layer_idx, + ) + self.shared_experts = None + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) + self.shared_experts = DeepseekV2MLP( + config=config, + intermediate_size=intermediate_size, + dtype=dtype, + device=device, + ) + + def forward(self, hidden_states: torch.Tensor): + """forward.""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + topk_weights, topk_ids = self.gate(hidden_states) + + out_states = self.experts( + hidden_states, + topk_weights, + topk_ids, + ) + + if self.shared_experts is not None: + shared_states = self.shared_experts(hidden_states) + out_states += shared_states + out_states = out_states.reshape(batch_size, sequence_length, -1) + + return out_states + + +class DeepseekV2MLP(nn.Module): + """Deepseek v2 mlp.""" + + def __init__(self, + config: Any, + intermediate_size: int = None, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + + quantization_config = getattr(config, 'quantization_config', None) + # gate up + if intermediate_size is None: + intermediate_size = config.intermediate_size + self.gate_up_proj = build_gateup_linear( + config.hidden_size, + [intermediate_size, intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=False, + ) + + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.down_proj = build_down_linear( + intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=False, + all_reduce=False, + ) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class DeepseekV2DecoderLayer(DeepseekV2DecoderLayer): + """Deepseekv2 decoder layer.""" + + def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None): + nn.Module.__init__(self) + 
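+        # bypass the parent __init__ so the attention/MLP/norm submodules are rebuilt below as non-TP
+        # variants (speculative decoding currently only supports a tp=1 draft model, per model_agent)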
self.layer_idx = layer_idx + quantization_config = None + + # build attention layer + self.self_attn = DeepseekV2Attention(config, dtype=dtype, device=device) + + # mlp + self.mlp = (DeepseekV2MoE(config, layer_idx, dtype=dtype, device=device) if + (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0) else DeepseekV2MLP(config, dtype=dtype, device=device)) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) + + +# modify from vllm + + +class SharedHead(nn.Module): + """Deepseekv2 shared head.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) + # build lm_head + self.head = build_rowwise_linear(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(hidden_states) + + +class DeepSeekMultiTokenPredictorLayer(nn.Module): + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + quantization_config = getattr(config, 'quantization_config', None) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, device=device) + self.eh_proj = build_colwise_linear( + config.hidden_size * 2, + config.hidden_size, + bias=False, + dtype=dtype, + device=device, + is_tp=False, + quant_config=quantization_config, + dp_disable_tp=True, + ) + + self.shared_head = SharedHead(config=config, dtype=dtype, device=device) + + self.mtp_block = DeepseekV2DecoderLayer(config, layer_idx=layer_idx, dtype=dtype, device=device) + + emb_type = RopeType.LinearScaling + rope_dim = config.qk_rope_head_dim if getattr(config, 'use_mla', True) else (config.hidden_size // + config.num_attention_heads) + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + + rope_params = dict(emb_type=emb_type, dim=rope_dim, max_position_embeddings=rope_max_pos_emb, base=rope_base) + update_params = build_rotary_params(config) + rope_params.update(update_params) + self.rotary_emb = build_rotary_embedding(**rope_params) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + previous_hidden_states: torch.Tensor, + past_key_value: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + assert inputs_embeds is not None + + # masking inputs at position 0, as not needed by MTP + inputs_embeds[position_ids == 0] = 0 + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = 
self.eh_proj(torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + # rotary emb + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + hidden_states, residual = self.mtp_block( + hidden_states, + rotary_pos_emb, + past_key_value, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + return hidden_states + + +class DeepSeekMultiTokenPredictor(nn.Module): + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + self.config = config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + DeepSeekMultiTokenPredictorLayer( + config, + idx, + dtype=dtype, + device=device, + ) + for idx in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers) + }) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + previous_hidden_states: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = (spec_step_idx % self.num_mtp_layers) + layer_idx = self.mtp_start_layer_idx + current_step_idx + past_key_value = past_key_values[current_step_idx] + return self.layers[str(layer_idx)]( + input_ids, + position_ids, + previous_hidden_states, + past_key_value, + inputs_embeds=inputs_embeds, + attn_metadata=attn_metadata, + spec_step_index=current_step_idx, + ) + + def get_logits( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = (spec_step_idx % self.num_mtp_layers) + mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)] + + hidden_states = mtp_layer.shared_head(hidden_states) + logits = mtp_layer.shared_head.head(hidden_states) + return logits + + +class DeepseekMTPModel(nn.Module, CudaGraphMixin): + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.quantization_config = getattr(config, 'quantization_config', None) + self.dtype = dtype + self.ctx_mgr = ctx_mgr + self.model = DeepSeekMultiTokenPredictor(config, dtype=dtype, device=device) + + self._load_buffers = dict() + + def get_logits(self, hidden_states: torch.Tensor, spec_step_idx: int = 0): + """Compute logits of the model output.""" + return self.model.get_logits(hidden_states, spec_step_idx=spec_step_idx) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + target_hidden_states: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, + position_ids, + target_hidden_states, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + spec_step_idx=spec_step_idx) + return hidden_states + + def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): + """Make cudagraph buffers from forward inputs.""" + max_tokens = graph_meta.max_tokens + + input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) + input_buffers['target_hidden_states'] = 
input_buffers['input_ids'].new_zeros(1, + max_tokens, + self.config.hidden_size, + dtype=self.dtype) + + return input_buffers + + def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: torch.Tensor, **kwargs): + """Fill cudagraph buffers from forward inputs.""" + + new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, input_ids=input_ids, **kwargs) + + num_tokens = input_ids.size(-1) + input_buffers = graph_meta.input_buffers + target_hidden_states = kwargs.get('target_hidden_states') + assert target_hidden_states is not None + input_buffers['target_hidden_states'][:, :num_tokens] = target_hidden_states + new_inputs['target_hidden_states'] = input_buffers['target_hidden_states'] + return new_inputs + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """Prepare input.""" + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + target_hidden_states = context.target_hidden_states + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + target_hidden_states=target_hidden_states, + ) + + def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], + expert_params_mapping: List): + """Load weight experts.""" + for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + def _load_weight_attention(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], + update_pe_mapping: List): + """Load weight attention.""" + device = next(iter(params_dict.values())).device + + def __update_pe(weight, head_dim: int, pe_dim_offset: int): + # (num_heads, q_head_dim, input_dim) + weight = weight.unflatten(0, (-1, head_dim)) + # (num_heads, nope_head_dim, input_dim) + w_pe = weight[:, pe_dim_offset:] + # (num_heads, nope_head_dim//2, 2, input_dim) + new_w_pe = w_pe.unflatten(1, (-1, 2)) + # (num_heads, nope_head_dim, input_dim) + new_w_pe = new_w_pe.transpose(1, 2).flatten(1, 2) + weight[:, pe_dim_offset:] = new_w_pe + weight = weight.flatten(0, 1) + return weight + + def __load_kcvc(name: str, weight: torch.Tensor): + """Load kc and vc from weight.""" + config = self.config + v_head_dim = config.v_head_dim + qk_nope_head_dim = config.qk_nope_head_dim + w_kc, w_vc = weight.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split([qk_nope_head_dim, v_head_dim], + dim=1) + w_vc = w_vc.transpose(1, 2).contiguous() + kc_param_name = name.replace('.kv_b_proj', '.kc') + param_kc = params_dict[kc_param_name] + load_weight(param_kc, w_kc) + vc_param_name = name.replace('.kv_b_proj', '.vc') + param_vc = params_dict[vc_param_name] + load_weight(param_vc, w_vc) + + def __dequant_weight(weight: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype): + """Dequant weight.""" + dim_w0, dim_w1 = weight.shape + dim_s0, dim_s1 = scale.shape + assert dim_w0 % dim_s0 == 0 + assert dim_w1 % dim_s1 == 0 + group0 = dim_w0 // dim_s0 + group1 = dim_w1 // dim_s1 + weight = weight.reshape(dim_s0, group0, dim_s1, group1) + scale = scale.reshape(dim_s0, 1, dim_s1, 1) + 
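+            # broadcasting scales each (group0, group1) block of the weight by its per-block scale,
+            # then the result is cast back to the requested dtype below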
weight = weight.to(scale.dtype) * scale + weight = weight.to(dtype) + weight = weight.reshape(dim_w0, dim_w1) + return weight + + def __load_kcvc_blocked_fp8(name: str, loaded_weight: torch.Tensor): + """Dequant weight.""" + if name.endswith('.weight'): + weight_name = name + scale_name = name.replace('.weight', '.scale') + elif name.endswith('.weight_scale_inv'): + weight_name = name.replace('.weight_scale_inv', '.weight') + scale_name = name + self._load_buffers[name] = loaded_weight + if (weight_name in self._load_buffers and scale_name in self._load_buffers): + weight = self._load_buffers.pop(weight_name) + scale = self._load_buffers.pop(scale_name) + kc_param_name = weight_name.replace('.kv_b_proj', '.kc') + dtype = params_dict[kc_param_name].dtype + weight = __dequant_weight(weight, scale, dtype) + __load_kcvc(weight_name, weight) + + for (mod_name, head_dim, pe_dim_offset) in update_pe_mapping: + if mod_name not in name: + continue + if name.endswith('.weight_scale_inv'): + weight = loaded_weight + else: + loaded_weight = loaded_weight.to(device) + weight = __update_pe(loaded_weight, head_dim, pe_dim_offset) + param = params_dict[name] + load_weight(param, weight) + break + else: + if '.kv_b_proj' in name: + quantization_config = self.quantization_config + quant_method = None + if quantization_config is not None: + quant_method = quantization_config.get('quant_method') + + loaded_weight = loaded_weight.to(device) + if quant_method == 'fp8': + # update blocked fp8 weight + __load_kcvc_blocked_fp8(name, loaded_weight) + else: + __load_kcvc(name, loaded_weight) + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load weights.""" + + def __skip_nextn(name, nextn_keys): + for nextn_key in nextn_keys: + if nextn_key in name: + return True + return False + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + config = self.config + + qk_rope_head_dim = config.qk_rope_head_dim + kv_lora_rank = config.kv_lora_rank + qk_nope_head_dim = config.qk_nope_head_dim + q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + kv_dim = kv_lora_rank + qk_rope_head_dim + update_pe_mapping = [('q_proj', q_head_dim, qk_nope_head_dim), ('q_b_proj', q_head_dim, qk_nope_head_dim), + ('kv_a_proj_with_mqa', kv_dim, kv_lora_rank)] + + num_experts = self.config.n_routed_experts + expert_params_mapping = [] + for exp_id in range(num_experts): + gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate') + up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up') + down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down') + expert_params_mapping += [gate_param, up_param, down_param] + + num_hidden_layers = self.config.num_hidden_layers + + num_nextn_predict_layers = getattr(self.config, 'num_nextn_predict_layers', 1) + nextn_keys = [f'.layers.{num_hidden_layers+i}' for i in range(num_nextn_predict_layers)] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + # keep nextn + if not __skip_nextn(name, nextn_keys): + continue + if '.layers' in name: + layer_idx = int(name.split('layers.')[1].split('.')[0]) + name = self._rewrite_spec_layer_name(layer_idx, name) + if '.experts' in name: + self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping) + elif '.self_attn' in name and 
getattr(config, 'use_mla', True): + # attention + self._load_weight_attention(name, loaded_weight, params_dict, update_pe_mapping) + else: + # other + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: + """Rewrite the weight name to match the format of the original model. + + Add .mtp_block for modules in transformer layer block for spec layer + """ + spec_layer_weight_names = ['embed_tokens', 'enorm', 'hnorm', 'eh_proj', 'shared_head'] + spec_layer_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + break + if not spec_layer_weight: + # treat rest weights as weights for transformer layer block + name = name.replace(f'model.layers.{spec_layer}.', f'model.layers.{spec_layer}.mtp_block.') + return name diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 5d9ce0a854..2af9ff2ccb 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -12,13 +12,13 @@ build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -from .utils.cudagraph import CudaGraphMixin +from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin class LlamaAttention(nn.Module): """Rewrite module of LlamaAttention.""" - def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch.device = None): + def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch.device = None, is_tp: bool = True): super().__init__() quantization_config = getattr(config, 'quantization_config', None) num_heads = config.num_attention_heads @@ -37,6 +37,7 @@ def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch dtype=dtype, device=device, num_replicate_kv_heads=num_replicate_kv_heads, + is_tp=is_tp, ) # rotary embedding @@ -57,7 +58,7 @@ def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch quant_config=quantization_config, dtype=dtype, device=device, - is_tp=True) + is_tp=is_tp) def forward( self, @@ -105,7 +106,7 @@ def forward( class LlamaMLP(nn.Module): """Llama mlp.""" - def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch.device = None): + def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch.device = None, is_tp: bool = True): super().__init__() quantization_config = getattr(config, 'quantization_config', None) # gate up @@ -117,7 +118,7 @@ def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch dtype=dtype, device=device, quant_config=quantization_config, - is_tp=True, + is_tp=is_tp, ) # silu and mul @@ -130,7 +131,7 @@ def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch quant_config=quantization_config, dtype=dtype, device=device, - is_tp=True) + is_tp=is_tp) def forward(self, x): """forward.""" @@ -142,16 +143,21 @@ def forward(self, x): class LlamaDecoderLayer(nn.Module): """Llama decoder layer.""" - def __init__(self, config: LlamaConfig, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None): + def __init__(self, + config: LlamaConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None, + is_tp: 
bool = True): super().__init__() self.layer_idx = layer_idx quantization_config = getattr(config, 'quantization_config', None) # build attention layer - self.self_attn = LlamaAttention(config, dtype=dtype, device=device) + self.self_attn = LlamaAttention(config, dtype=dtype, device=device, is_tp=is_tp) # build MLP - self.mlp = LlamaMLP(config, dtype=dtype, device=device) + self.mlp = LlamaMLP(config, dtype=dtype, device=device, is_tp=is_tp) # build input layer norm self.input_layernorm = RMSNorm(config.hidden_size, @@ -217,7 +223,7 @@ def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch LlamaDecoderLayer(config, layer_idx, dtype=dtype, device=device) for layer_idx in range(config.num_hidden_layers) ]) - + self.aux_hidden_state_layers: Tuple[int] = getattr(config, 'aux_hidden_state_layers', tuple()) # build norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) @@ -245,10 +251,14 @@ def forward( cos, sin = cos[0], sin[0] rotary_pos_emb = (cos, sin) + # for eagle3 + aux_hidden_states = [] # decoding residual = None for idx, decoder_layer in enumerate(self.layers): past_key_value = past_key_values[idx] + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append(hidden_states + residual) hidden_states, residual = decoder_layer( hidden_states, rotary_pos_emb=rotary_pos_emb, @@ -260,6 +270,9 @@ def forward( # norm hidden_states, _ = self.norm(hidden_states, residual) + if len(aux_hidden_states) > 0: + aux_hidden_states = torch.cat(aux_hidden_states, dim=-1) + return dict(hidden_states=hidden_states, aux_hidden_states=aux_hidden_states) return hidden_states def get_input_embeddings(self): @@ -331,6 +344,15 @@ def get_input_embeddings(self): """Get input embeddings.""" return self.model.get_input_embeddings() + def get_outputs_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: torch.Tensor, **kwargs): + """Get outputs from buffers.""" + num_tokens = input_ids.size(-1) + outputs = dict() + outputs['hidden_states'] = graph_meta.output_buffers['hidden_states'][:, :num_tokens] + if 'aux_hidden_states' in graph_meta.output_buffers: + outputs['aux_hidden_states'] = graph_meta.output_buffers['aux_hidden_states'][:, :num_tokens] + return outputs + def prepare_inputs_for_generation( self, past_key_values: List[List[torch.Tensor]], diff --git a/lmdeploy/pytorch/models/llama_eagle.py b/lmdeploy/pytorch/models/llama_eagle.py new file mode 100644 index 0000000000..d581e1ff3d --- /dev/null +++ b/lmdeploy/pytorch/models/llama_eagle.py @@ -0,0 +1,237 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
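The `aux_hidden_state_layers` hook added to `LlamaModel.forward` above stores the pre-layer hidden state (`hidden_states + residual`) of each configured layer and concatenates the results along the feature dimension; the EAGLE-3 draft head later projects this widened tensor back to `hidden_size` through its `fc` layer. A minimal sketch of the collection step, with made-up sizes and a random stand-in for the decoder layers:

```python
import torch

hidden_size, num_layers = 8, 6
aux_layers = (1, 3, 5)                          # assumed config.aux_hidden_state_layers (typically three layers)
hidden_states = torch.randn(1, 4, hidden_size)  # [1, num_tokens, hidden_size]

aux_hidden_states = []
for idx in range(num_layers):
    if idx in aux_layers:
        aux_hidden_states.append(hidden_states)  # the real model stores hidden_states + residual here
    hidden_states = hidden_states + torch.randn_like(hidden_states)  # stand-in for a decoder layer

aux_hidden_states = torch.cat(aux_hidden_states, dim=-1)
print(aux_hidden_states.shape)                  # torch.Size([1, 4, 24]), i.e. 3 * hidden_size features
```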
+ +from typing import Any, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext +from lmdeploy.pytorch.nn import build_rotary_embedding_from_config +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + +from .llama import LlamaDecoderLayer +from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin + + +class EagleLlamaDecoderLayer(LlamaDecoderLayer): + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None) -> None: + super().__init__(config, layer_idx, dtype=dtype, device=device, is_tp=False) + + # Skip the input_layernorm + # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 + if layer_idx == 0: + del self.input_layernorm + setattr(self, 'input_layernorm', lambda x: x) + + +class EagleLlamaModel(nn.Module): + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + EagleLlamaDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + # build fc + self.fc = nn.Linear( + config.hidden_size * 2, + config.hidden_size, + bias=False, + dtype=dtype, + device=device, + ) + + # build rotary embedding in LlamaModel + self.rotary_emb = build_rotary_embedding_from_config(config) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + previous_hidden_states: Optional[torch.FloatTensor] = None, + ): + """Rewrite of LlamaModel.forward.""" + # token embedding + if inputs_embeds is None: + assert input_ids is not None + inputs_embeds = self.embed_tokens(input_ids) + previous_hidden_states = previous_hidden_states.to(inputs_embeds) + hidden_states = torch.cat([inputs_embeds, previous_hidden_states], dim=-1) + hidden_states = self.fc(hidden_states) + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + ) + hidden_states = hidden_states + residual + return hidden_states + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.embed_tokens + + +class EagleLlamaForCausalLM(nn.Module, CudaGraphMixin): + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, config, ctx_mgr, dtype=None, device=None): + nn.Module.__init__(self) + self.config = config + self.ctx_mgr = ctx_mgr + self.dtype = dtype + # build LLamaModel + self.model = EagleLlamaModel(config, dtype=dtype, device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: 
torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + target_hidden_states: torch.Tensor = None, + **kwargs, + ): + """Model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + previous_hidden_states=target_hidden_states, + ) + return hidden_states + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """Prepare input.""" + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + target_hidden_states = context.target_hidden_states + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + target_hidden_states=target_hidden_states, + ) + + def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): + """Make cudagraph buffers from forward inputs.""" + max_tokens = graph_meta.max_tokens + + input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) + input_buffers['target_hidden_states'] = input_buffers['input_ids'].new_zeros(1, + max_tokens, + self.config.hidden_size, + dtype=self.dtype) + + return input_buffers + + def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): + """Fill cudagraph buffers from forward inputs.""" + + new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) + + num_tokens = kwargs['input_ids'].size(-1) + + is_decoding = graph_meta.is_decoding + input_buffers = graph_meta.input_buffers + padded_num_tokens = new_inputs['input_ids'].size(-1) + + target_hidden_states = kwargs.get('target_hidden_states') + assert target_hidden_states is not None + input_buffers['target_hidden_states'][:, :num_tokens] = target_hidden_states + if is_decoding: + new_inputs['target_hidden_states'] = input_buffers['target_hidden_states'][:, :padded_num_tokens, :] + else: + new_inputs['target_hidden_states'] = input_buffers['target_hidden_states'] + + return new_inputs + + def update_weights(self): + """Update weights.""" + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.model.get_input_embeddings() + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load weights.""" + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + name = 'model.' 
+ name + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/llama_eagle3.py b/lmdeploy/pytorch/models/llama_eagle3.py new file mode 100644 index 0000000000..42f21f80f9 --- /dev/null +++ b/lmdeploy/pytorch/models/llama_eagle3.py @@ -0,0 +1,317 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Any, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext +from lmdeploy.pytorch.nn import RMSNorm, build_rotary_embedding_from_config +from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + +from .llama import LlamaDecoderLayer +from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin + + +class Eagle3LlamaDecoderLayer(LlamaDecoderLayer): + """Llama decoder layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__(config, layer_idx, dtype=dtype, device=device, is_tp=False) + self.layer_idx = layer_idx + + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + + # override attention qkv + self.self_attn.qkv_proj = build_qkv_proj( + 2 * hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=False, + ) + + self.hidden_norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) + + def forward( + self, + embeds: torch.Tensor, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + attn_metadata: Any = None, + ): + + residual = hidden_states + embeds = self.input_layernorm(embeds) + hidden_states = self.hidden_norm(hidden_states) + hidden_states = torch.cat([embeds, hidden_states], dim=-1) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (hidden_states, residual) + return outputs + + +class Eagle3LlamaModel(nn.Module): + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + self.config = config + self.dtype = dtype + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build layer + self.midlayer = 
Eagle3LlamaDecoderLayer(config, layer_idx=0, dtype=dtype, device=device) + target_hidden_size = getattr(config, 'target_hidden_size', config.hidden_size) + self.fc = build_rowwise_linear( + target_hidden_size * 3, + config.hidden_size, + bias=False, + dtype=dtype, + device=device, + ) + + self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) + # build rotary embedding in LlamaModel + self.rotary_emb = build_rotary_embedding_from_config(config) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + previous_hidden_states: Optional[torch.FloatTensor] = None, + ): + """Rewrite of LlamaModel.forward.""" + # token embedding + if inputs_embeds is None: + assert input_ids is not None + inputs_embeds = self.embed_tokens(input_ids).to(self.dtype) + previous_hidden_states = previous_hidden_states.to(inputs_embeds) + if previous_hidden_states.shape[-1] != inputs_embeds.shape[-1]: + # previous_hidden_states if from target model + previous_hidden_states = self.fc(previous_hidden_states) + # rotary embedding + cos, sin = self.rotary_emb(previous_hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + past_key_value = past_key_values[0] + hidden_states, residual = self.midlayer( + inputs_embeds, + previous_hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + hidden_states, hidden_states_prenorm = self.norm(hidden_states, residual) + outputs = dict(hidden_states=hidden_states, hidden_states_prenorm=hidden_states_prenorm) + return outputs + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.embed_tokens + + +class Eagle3LlamaForCausalLM(nn.Module, CudaGraphMixin): + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, config, ctx_mgr, dtype=None, device=None): + nn.Module.__init__(self) + self.config = config + self.ctx_mgr = ctx_mgr + self.dtype = dtype + + if config.num_hidden_layers != 1: + raise ValueError('eagle3 only supports 1 decode layer') + + # build LLamaModel + self.model = Eagle3LlamaModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.draft_vocab_size, + bias=False, + dtype=dtype, + device=device) + self.draft_id_to_target_id = nn.Parameter( + torch.zeros(self.config.draft_vocab_size, dtype=torch.long, device=device), + requires_grad=False, + ) + self.include_embed_tokens = False + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + target_hidden_states: torch.Tensor = None, + **kwargs, + ): + """Model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + previous_hidden_states=target_hidden_states, + ) + return hidden_states + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """Prepare input.""" + 
input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + target_hidden_states = context.target_hidden_states + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + target_hidden_states=target_hidden_states, + ) + + def get_logits(self, hidden_states: torch.Tensor): + """Compute logits of the model output.""" + logits = self.lm_head(hidden_states) + return logits + + def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): + """Make cudagraph buffers from forward inputs.""" + max_tokens = graph_meta.max_tokens + target_hidden_states = kwargs.get('target_hidden_states') + assert target_hidden_states is not None + target_hidden_size = target_hidden_states.size(-1) + input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) + input_buffers['target_hidden_states'] = input_buffers['input_ids'].new_zeros(1, + max_tokens, + target_hidden_size, + dtype=self.dtype) + + return input_buffers + + def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): + """Fill cudagraph buffers from forward inputs.""" + + new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) + + num_tokens = kwargs['input_ids'].size(-1) + + input_buffers = graph_meta.input_buffers + + target_hidden_states = kwargs.get('target_hidden_states') + assert target_hidden_states is not None + input_buffers['target_hidden_states'][:, :num_tokens] = target_hidden_states + + new_inputs['target_hidden_states'] = input_buffers['target_hidden_states'] + + return new_inputs + + def get_outputs_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: torch.Tensor, **kwargs): + """Get outputs from buffers.""" + num_tokens = input_ids.size(-1) + outputs = dict() + outputs['hidden_states'] = graph_meta.output_buffers['hidden_states'][:, :num_tokens] + outputs['hidden_states_prenorm'] = graph_meta.output_buffers['hidden_states_prenorm'][:, :num_tokens] + return outputs + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.model.get_input_embeddings() + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load weights.""" + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'd2t' in name: + name = 'draft_id_to_target_id' + base = torch.arange(self.config.draft_vocab_size, + device=loaded_weight.device, + dtype=loaded_weight.dtype) + loaded_weight += base + elif 'lm_head.weight' not in name: + name = 'model.' 
+ name + if 'embed_tokens' in name: + self.include_embed_tokens = True + if 't2d' in name: + continue + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 498e2c6554..3d5823fdda 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -232,3 +232,13 @@ }) CUSTOM_MODULE_MAP = dict() + +# spec models +# eagle llama +MODULE_MAP.update({'EagleLlamaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama_eagle.EagleLlamaForCausalLM'}) + +# eagle3 llama +MODULE_MAP.update({'Eagle3LlamaForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama_eagle3.Eagle3LlamaForCausalLM'}) + +# deepseek mtp +MODULE_MAP.update({'DeepseekMTPModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_mtp.DeepseekMTPModel'}) diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index 065aef97d1..a4b07bb949 100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -34,6 +34,7 @@ class CudaGraphMeta: input_buffers: BuffType = None output_buffers: BuffType = None vocab_size: int = 1 + decode_query_len: int = 1 class CudaGraphMixin: @@ -57,6 +58,7 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs) -> max_tokens = graph_meta.max_tokens num_blocks = graph_meta.num_blocks device = graph_meta.device + decode_query_len = graph_meta.decode_query_len input_buffers: BuffType = dict() input_buffers['input_ids'] = torch.randint(0, @@ -64,16 +66,27 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs) -> dtype=torch.int64, device=device) input_buffers['position_ids'] = torch.zeros((1, max_tokens), dtype=torch.int64, device=device) - if getattr(self.config, 'use_flash_mla', False) is True: - import flash_mla + seqlens_dtype = torch.int64 + use_flash_mla = getattr(self.config, 'use_flash_mla', False) + # use fa3 decode kernel for spec decode + use_flash_attn3_decoding = decode_query_len > 1 and not use_flash_mla + if use_flash_mla is True: + import flash_mla + if graph_meta.is_decoding: + seqlens_dtype = torch.int32 # create buffers for flash mla input_buffers['tile_scheduler_metadata'], input_buffers['num_splits'] = flash_mla.get_mla_metadata( - torch.ones(max_batches, dtype=torch.int32, device=device), self.config.num_attention_heads, 1) + torch.ones(max_batches, dtype=torch.int32, device=device), + self.config.num_attention_heads * decode_query_len, 1) + + if use_flash_attn3_decoding is True: + seqlens_dtype = torch.int32 + input_buffers['scheduler_metadata'] = torch.zeros(max_batches + 1, dtype=torch.int32, device=device) # flash_mla requires block_offsets and kv_lens int32 - input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks), dtype=torch.int32, device=device) - input_buffers['qkv_lens'] = torch.zeros(3, max_batches, dtype=torch.int32, device=device) + input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks), dtype=seqlens_dtype, device=device) + 
input_buffers['qkv_lens'] = torch.zeros(3, max_batches, dtype=seqlens_dtype, device=device) input_buffers['q_start_loc'] = input_buffers['qkv_lens'][0] input_buffers['q_seqlens'] = input_buffers['qkv_lens'][1] @@ -89,7 +102,6 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p **kwargs) -> Dict[str, Tensor]: """Fill cudagraph buffers from forward inputs.""" - is_decoding = graph_meta.is_decoding block_offsets: Tensor = attn_metadata.block_offsets q_start_loc: Tensor = attn_metadata.q_start_loc q_seqlens: Tensor = attn_metadata.q_seqlens @@ -98,7 +110,7 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p batch_size, num_blocks = block_offsets.size() num_tokens = input_ids.size(-1) - + decode_query_len = graph_meta.decode_query_len # fill buffer input_buffers['input_ids'].random_(0, graph_meta.vocab_size) input_buffers['input_ids'][:, :num_tokens] = input_ids @@ -121,16 +133,39 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p attn_metadata.q_start_loc = input_buffers['q_start_loc'] attn_metadata.q_seqlens = input_buffers['q_seqlens'] attn_metadata.kv_seqlens = input_buffers['kv_seqlens'] - if getattr(self.config, 'use_flash_mla', False) is True: + + use_flash_mla = getattr(self.config, 'use_flash_mla', False) + # use fa3 decode kernel for spec decode + use_flash_attn3_decoding = decode_query_len > 1 and not use_flash_mla + + if use_flash_mla is True: import flash_mla - tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(attn_metadata.kv_seqlens.to(torch.int32), - self.config.num_attention_heads, 1) + tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata( + attn_metadata.kv_seqlens.to(torch.int32), self.config.num_attention_heads * decode_query_len, 1) # here we use copy_ instead of = to avoid using new allocated mem for cuda graph input_buffers['tile_scheduler_metadata'].copy_(tile_scheduler_metadata) input_buffers['num_splits'][:new_batch_size + 1].copy_(num_splits[:new_batch_size + 1]) attn_metadata.tile_scheduler_metadata = input_buffers['tile_scheduler_metadata'] attn_metadata.num_splits = input_buffers['num_splits'] + if use_flash_attn3_decoding: + from flash_attn_interface import get_scheduler_metadata + block_size = past_key_values[0][0].size(1) + scheduler_metadata = get_scheduler_metadata( + batch_size=batch_size, + max_seqlen_q=decode_query_len, + max_seqlen_k=attn_metadata.max_kv_seqlen, + num_heads_q=self.config.num_attention_heads, + num_heads_kv=self.config.num_key_value_heads, + headdim=self.config.head_dim, + cache_seqlens=attn_metadata.kv_seqlens.to(torch.int32), + qkv_dtype=self.config.torch_dtype, + page_size=block_size, + ) + input_buffers['scheduler_metadata'].zero_() + input_buffers['scheduler_metadata'][:batch_size + 1].copy_(scheduler_metadata[:batch_size + 1]) + attn_metadata.scheduler_metadata = input_buffers['scheduler_metadata'] + new_inputs = dict( past_key_values=past_key_values, attn_metadata=attn_metadata, @@ -141,18 +176,11 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p # TODO: update cross_attn_metadata here new_inputs['cross_attn_metadata'] = cross_attn_metadata - if is_decoding: - new_inputs['input_ids'] = input_buffers['input_ids'] - new_inputs['position_ids'] = input_buffers['position_ids'] - else: - new_inputs['input_ids'] = input_buffers['input_ids'] - new_inputs['position_ids'] = input_buffers['position_ids'] + new_inputs['input_ids'] = input_buffers['input_ids'] + new_inputs['position_ids'] = 
input_buffers['position_ids'] if inputs_embeds is not None: - if is_decoding: - new_inputs['inputs_embeds'] = input_buffers['inputs_embeds'] - else: - new_inputs['inputs_embeds'] = input_buffers['inputs_embeds'] + new_inputs['inputs_embeds'] = input_buffers['inputs_embeds'] new_inputs.update(kwargs) return new_inputs @@ -170,3 +198,10 @@ def update_context_cudagraph(self, graph_meta: CudaGraphMeta, context: StepConte context.q_seqlens = input_buffers['q_seqlens'] context.kv_seqlens = input_buffers['kv_seqlens'] context.q_start_loc = input_buffers['q_start_loc'] + + def get_outputs_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, **kwargs): + """Get outputs from buffers.""" + num_tokens = input_ids.size(-1) + outputs = dict() + outputs['hidden_states'] = graph_meta.output_buffers['hidden_states'][:, :num_tokens] + return outputs diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index e19bd18141..eb903bacc5 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -37,13 +37,14 @@ class Scheduler: cache_config (CacheConfig): The config of cache info. """ - def __init__(self, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - seq_meta: SequenceMeta = None) -> None: + def __init__( + self, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + seq_meta: SequenceMeta = None, + ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config - self.sessions: Dict[int, SchedulerSession] = OrderedDict() # For Disaggregation @@ -300,7 +301,7 @@ def __evict_for_seq(seq: SchedulerSequence, num_required_blocks: int): def schedule(self, is_prefill: bool, prealloc_size: int = 0): """Schedule inputs for next steps.""" if is_prefill: - output = self._schedule_prefill(0) + output = self._schedule_prefill(prealloc_size) else: output = self._schedule_decoding(prealloc_size) running, swap_in_map, swap_out_map, copy_map = output diff --git a/lmdeploy/pytorch/spec_decode/__init__.py b/lmdeploy/pytorch/spec_decode/__init__.py new file mode 100644 index 0000000000..869409f6ca --- /dev/null +++ b/lmdeploy/pytorch/spec_decode/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from .spec_agent import SpecModelAgent + +__all__ = ['SpecModelAgent'] diff --git a/lmdeploy/pytorch/spec_decode/proposers/__init__.py b/lmdeploy/pytorch/spec_decode/proposers/__init__.py new file mode 100644 index 0000000000..a95a19e6f7 --- /dev/null +++ b/lmdeploy/pytorch/spec_decode/proposers/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) OpenMMLab. All rights reserved. + +from .deepseek_mtp import DeepseekMTP # noqa F401 +from .eagle import Eagle # noqa F401 +from .eagle3 import Eagle3 # noqa F401 diff --git a/lmdeploy/pytorch/spec_decode/proposers/base.py b/lmdeploy/pytorch/spec_decode/proposers/base.py new file mode 100644 index 0000000000..9f6d073788 --- /dev/null +++ b/lmdeploy/pytorch/spec_decode/proposers/base.py @@ -0,0 +1,152 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Any, Dict, List, Optional + +import torch +from mmengine import Registry +from torch.profiler import record_function + +from lmdeploy.utils import get_logger + +from ...config import ModelConfig, SpecDecodeConfig +from ...engine.cache_engine import CacheEngine +from ...model_inputs import ModelInputs, step_ctx_manager +from ...models.patch import build_patched_model, update_custom_module_map +from ...strategies.base.model_agent import ExtraInputs +from ...weight_loader.model_weight_loader import load_model_weights + +SPEC_PROPOSERS = Registry('spec_proposers') + +logger = get_logger('lmdeploy') + + +@torch.inference_mode() +def draft_model_forward( + model: torch.nn.Module, + inputs: ModelInputs, + model_config: Optional[ModelConfig] = None, + cache_engine: Optional[CacheEngine] = None, +): + """Perform model forward.""" + stream = torch.cuda.current_stream() + with torch.cuda.stream(stream), step_ctx_manager(model.ctx_mgr): + # forward + ctx_mgr = model.ctx_mgr + kv_caches = None if cache_engine is None else cache_engine.gpu_cache + context = ctx_mgr.build_context( + inputs=inputs, + model_config=model_config, + kv_caches=kv_caches, + ) + with ctx_mgr.context(context): + model_metas = None + model_metas = model.update_model_metas( + past_key_values=kv_caches, + context=context, + ) + input_dict = model.prepare_inputs_for_generation( + past_key_values=kv_caches, + context=context, + ) + outputs = model(**input_dict) + if not isinstance(outputs, dict): + outputs = dict(hidden_states=outputs) + outputs.update(dict(model_metas=model_metas)) + return outputs + + +class BaseSpecProposer: + + def __init__(self, specdecode_config: SpecDecodeConfig, device: torch.device = None): + self.specdecode_config = specdecode_config + self.model = None + self.device = device + self.lm_head = None + self.num_speculative_tokens = specdecode_config.num_speculative_tokens + self.target_model = None + + def build_model(self, + empty_init: bool, + target_model: torch.nn.Module = None, + model_format=None, + build_model_ctx=None): + if self.specdecode_config is None: + return + model_path = self.specdecode_config.model + model_config = self.specdecode_config.model_config + custom_module_map = model_config.custom_module_map + if custom_module_map is not None: + update_custom_module_map(custom_module_map) + logger.debug('build draft model') + patched_model = build_patched_model( + model_config, + device=self.device, + model_format=model_format, + build_model_ctx=build_model_ctx, + ) + logger.debug('loading weights for draft model.') + if not empty_init: + load_model_weights(patched_model, model_path, device=self.device) + self.model = patched_model + self.target_model = target_model + + def get_outputs(self, + model_outputs: Dict[str, torch.Tensor], + model_inputs: ModelInputs, + extra_inputs: ExtraInputs = None): + """Get outputs.""" + raise NotImplementedError() + + @record_function('draft_model_forward') + def _forward(self, model_inputs: ModelInputs, cache_engine: CacheEngine = None): + """Forward.""" + return draft_model_forward( + self.model, + model_inputs, + model_config=self.specdecode_config.model_config, + cache_engine=cache_engine, + ) + + def update_inputs_decoding(self, model_inputs: ModelInputs, extra_inputs: ExtraInputs, next_input_ids: torch.Tensor, + target_hidden_states: torch.Tensor, model_metas: List[Any]): + """Update to decoding inputs.""" + model_inputs.is_decoding = True + batch_size = model_inputs.seq_length.size(0) + model_inputs.input_ids = next_input_ids + 
model_inputs.max_q_seqlen = 1 + model_inputs.max_kv_seqlen += 1 + model_inputs.sum_kv_seqlen += model_inputs.seq_length.numel() + model_inputs.history_lengths += model_inputs.seq_length + if extra_inputs.num_rejected_tokens is not None: + model_inputs.history_lengths -= extra_inputs.num_rejected_tokens + model_inputs.seq_length = model_inputs.seq_length.new_ones(batch_size) + model_inputs.target_position_ids = model_inputs.history_lengths.unsqueeze(0).clone() + model_inputs.model_metas = model_metas + model_inputs.target_hidden_states = target_hidden_states + return model_inputs + + @record_function('draft_get_logits') + def get_logits(self, hidden_states: torch.Tensor): + """Get logits of model output.""" + draft_model = self.model + if not isinstance(draft_model, torch.nn.Module): + draft_model = draft_model.model + + if hasattr(draft_model, 'get_logits'): + logits = draft_model.get_logits(hidden_states) + else: + logits = self.target_model.get_logits(hidden_states) + return logits + + def get_target_hidden_size(self, model_config: ModelConfig): + """Get target hidden size.""" + return model_config.hidden_size + + +def build_specdecode_proposer(specdecode_config: SpecDecodeConfig, device: str = 'cuda'): + """Build spec decoding proposer.""" + method = specdecode_config.method + if method in SPEC_PROPOSERS.module_dict: + spec_cls = SPEC_PROPOSERS.module_dict[method] + obj = spec_cls(specdecode_config, device=device) + return obj + raise ValueError(f'{method} not found in {SPEC_PROPOSERS.module_dict.keys()}') diff --git a/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py b/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py new file mode 100644 index 0000000000..de19beb761 --- /dev/null +++ b/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict + +import torch + +from lmdeploy.utils import get_logger + +from ...model_inputs import ModelInputs +from ...strategies.ar_spec.model_agent import ARSpecExtraInputs +from .base import SPEC_PROPOSERS, BaseSpecProposer + +logger = get_logger('lmdeploy') + + +@SPEC_PROPOSERS.register_module(name='deepseek_mtp') +class DeepseekMTP(BaseSpecProposer): + + def get_outputs(self, + model_outputs: Dict[str, torch.Tensor], + model_inputs: ModelInputs, + extra_inputs: ARSpecExtraInputs = None): + """Get outputs.""" + hidden_states = model_outputs['hidden_states'] + model_metas = model_outputs['model_metas'] + if extra_inputs is not None and extra_inputs.last_token_indices is not None: + # for long input + if (not model_inputs.is_decoding) and model_inputs.seq_length.size(0) == 1: + hidden_states = hidden_states[:, -1:] + else: + last_token_loc = extra_inputs.last_token_indices + hidden_states = hidden_states[:, last_token_loc] + + logits = self.get_logits(hidden_states)[0] + draft_token_ids = logits.argmax(dim=-1, keepdim=True) + return draft_token_ids, model_metas, hidden_states diff --git a/lmdeploy/pytorch/spec_decode/proposers/eagle.py b/lmdeploy/pytorch/spec_decode/proposers/eagle.py new file mode 100644 index 0000000000..f9a84a154b --- /dev/null +++ b/lmdeploy/pytorch/spec_decode/proposers/eagle.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
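`DeepseekMTP.get_outputs` above reduces the draft step to a greedy proposal: gather the last-token hidden state of each sequence, project it through an LM head, and take the argmax as the next draft token. A minimal sketch with made-up shapes (the real code reuses the draft or target model's `get_logits`):

```python
import torch

num_tokens, hidden_size, vocab_size = 5, 8, 32
hidden_states = torch.randn(1, num_tokens, hidden_size)   # packed tokens of all sequences
seq_length = torch.tensor([2, 3])                         # two sequences of length 2 and 3
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

last_token_loc = seq_length.cumsum(0) - 1                 # index of each sequence's last token
last_hidden = hidden_states[:, last_token_loc]            # [1, batch_size, hidden_size]
logits = lm_head(last_hidden)[0]                          # [batch_size, vocab_size]
draft_token_ids = logits.argmax(dim=-1, keepdim=True)     # [batch_size, 1] greedy draft tokens
print(draft_token_ids.shape)                              # torch.Size([2, 1])
```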
+ +from .base import SPEC_PROPOSERS +from .deepseek_mtp import DeepseekMTP + + +@SPEC_PROPOSERS.register_module(name='eagle') +class Eagle(DeepseekMTP): + """Eagle.""" diff --git a/lmdeploy/pytorch/spec_decode/proposers/eagle3.py b/lmdeploy/pytorch/spec_decode/proposers/eagle3.py new file mode 100644 index 0000000000..1ca4e703b2 --- /dev/null +++ b/lmdeploy/pytorch/spec_decode/proposers/eagle3.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict + +import torch + +from lmdeploy.utils import get_logger + +from ...config import ModelConfig +from ...model_inputs import ModelInputs +from ...strategies.base.model_agent import ExtraInputs +from .base import SPEC_PROPOSERS +from .deepseek_mtp import DeepseekMTP + +logger = get_logger('lmdeploy') + + +@SPEC_PROPOSERS.register_module(name='eagle3') +class Eagle3(DeepseekMTP): + + def build_model(self, + empty_init: bool, + target_model: torch.nn.Module = None, + model_format=None, + build_model_ctx=None): + super().build_model(empty_init, + target_model=target_model, + model_format=model_format, + build_model_ctx=build_model_ctx) + self.draft_id_to_target_id = self.model.draft_id_to_target_id + if not self.model.include_embed_tokens: + logger.info('Using embed_tokens from target model.') + del self.model.model.embed_tokens + self.model.model.embed_tokens = target_model.get_input_embeddings() + + def get_target_hidden_size(self, model_config: ModelConfig): + """Get target hidden size.""" + hf_config = self.specdecode_config.model_config.hf_config + hidden_size = getattr(hf_config, 'target_hidden_size', hf_config.hidden_size) + return hidden_size * 3 + + def get_outputs(self, + model_outputs: Dict[str, torch.Tensor], + model_inputs: ModelInputs, + extra_inputs: ExtraInputs = None): + """Get outputs.""" + hidden_states = model_outputs['hidden_states'] + hidden_states_prenorm = model_outputs['hidden_states_prenorm'] + model_metas = model_outputs['model_metas'] + if extra_inputs is not None and extra_inputs.last_token_indices is not None: + # for long input + if (not model_inputs.is_decoding) and model_inputs.seq_length.size(0) == 1: + hidden_states = hidden_states[:, -1:] + hidden_states_prenorm = hidden_states_prenorm[:, -1:] + else: + last_token_loc = extra_inputs.last_token_indices + hidden_states = hidden_states[:, last_token_loc] + hidden_states_prenorm = hidden_states_prenorm[:, last_token_loc] + + logits = self.get_logits(hidden_states)[0] + draft_token_ids = logits.argmax(dim=-1, keepdim=True) + # token mapping + draft_token_ids = self.draft_id_to_target_id[draft_token_ids] + return draft_token_ids, model_metas, hidden_states_prenorm diff --git a/lmdeploy/pytorch/spec_decode/reject_sampler.py b/lmdeploy/pytorch/spec_decode/reject_sampler.py new file mode 100644 index 0000000000..b2c4e34946 --- /dev/null +++ b/lmdeploy/pytorch/spec_decode/reject_sampler.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
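The `reject_sampler` module introduced below verifies draft tokens with a greedy acceptance rule: keep the longest prefix of draft tokens that matches the target model's greedy choices, emit the target's token at the first mismatch, and append the bonus token only when every draft token matched. A minimal, plain-Python sketch of that rule on toy values (the actual implementation is the batched torch version that follows):

```python
from typing import List


def greedy_accept(draft: List[int], target: List[int], bonus: int) -> List[int]:
    """Accept the matching prefix, then one correction token or the bonus token."""
    accepted = []
    for d, t in zip(draft, target):
        accepted.append(t)          # the target token is emitted either way
        if d != t:
            return accepted         # first mismatch: stop after the correction token
    accepted.append(bonus)          # every draft token matched: add the bonus token
    return accepted


assert greedy_accept([5, 3, 7], [5, 3, 9], bonus=2) == [5, 3, 9]      # 2 accepted + 1 correction
assert greedy_accept([5, 3, 7], [5, 3, 7], bonus=2) == [5, 3, 7, 2]   # all accepted + bonus token
```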
+import enum
+from typing import Optional
+
+import torch
+from torch import LongTensor, Tensor, nn
+from torch.profiler import record_function
+
+
+class SamplePolicy(enum.Enum):
+    """Sample policy."""
+
+    ALL_GREEDY = enum.auto()
+
+
+class RejectionSampler(nn.Module):
+    """Rejection sampler for speculative decoding."""
+
+    def __init__(self, sample_policy: SamplePolicy = SamplePolicy.ALL_GREEDY):
+        super().__init__()
+        self.sample_policy = sample_policy
+
+    def forward(
+        self,
+        target_logits: Tensor,
+        draft_token_ids: LongTensor,
+        bonus_token_ids: LongTensor,
+        draft_probs: Optional[Tensor] = None,
+    ):
+        """Forward.
+
+        Args:
+            target_logits (Tensor): The logits of the target model in shape of
+                [batch_size, num_spec_tokens, vocab_size].
+            draft_token_ids (LongTensor): The input draft tokens in shape of [batch_size, num_spec_tokens].
+            bonus_token_ids (LongTensor): The bonus token ids in shape of [batch_size, 1].
+            draft_probs (Tensor): The probability of the draft model in shape of
+                [batch_size, num_spec_tokens, vocab_size]. Default to ``None``.
+        """
+        output_token_ids, num_rejected_tokens, last_token_ids = rejection_sample(
+            target_logits,
+            draft_token_ids,
+            bonus_token_ids,
+            draft_probs=draft_probs,
+        )
+        return output_token_ids, num_rejected_tokens, last_token_ids
+
+
+@record_function('rejection_sample')
+def rejection_sample(
+    target_probs: Tensor,
+    draft_token_ids: LongTensor,
+    bonus_token_ids: LongTensor,
+    sample_policy: SamplePolicy = SamplePolicy.ALL_GREEDY,
+    draft_probs: Optional[Tensor] = None,
+):
+    """Rejection sample.
+
+    Args:
+        target_probs (Tensor): The logits or probabilities of the target model in shape of
+            [batch_size, num_spec_tokens, vocab_size].
+        draft_token_ids (LongTensor): The input draft tokens in shape of [batch_size, num_spec_tokens].
+        bonus_token_ids (LongTensor): The bonus token ids in shape of [batch_size, 1].
+        sample_policy (SamplePolicy): The sampling policy. Only ``SamplePolicy.ALL_GREEDY`` is supported.
+        draft_probs (Tensor): The probability of the draft model in shape of
+            [batch_size, num_spec_tokens, vocab_size]. Default to ``None``.
+    """
+    assert draft_probs is None or draft_probs.is_contiguous()
+    assert sample_policy == SamplePolicy.ALL_GREEDY, 'only support all greedy sampling policy'
+
+    target_argmax_tokens = target_probs.argmax(dim=-1)
+    return greedy_reject_sampler(draft_token_ids, target_argmax_tokens, bonus_token_ids)
+
+
+def greedy_reject_sampler(draft_token_ids, target_token_ids, bonus_token_ids):
+    """Greedy reject sampler.
+
+    1. keep target tokens that are equal to the draft tokens
+    2. keep the first target token that differs from the draft
+    3. append the bonus token if every draft token is accepted
+    Rejected positions are filled with -1.
+
+    Args:
+        draft_token_ids: (batch_size, num_spec_tokens)
+        target_token_ids: (batch_size, num_spec_tokens)
+        bonus_token_ids: (batch_size, 1)
+    Returns:
+        output_token_ids: (batch_size, num_spec_tokens + 1)
+        num_rejected_tokens: (batch_size,)
+        last_token_ids: (batch_size,)
+    """
+    masks = draft_token_ids == target_token_ids
+    batch_size, num_spec_tokens = draft_token_ids.shape
+    # check rest draft tokens
+    range_data = torch.arange(num_spec_tokens, device=draft_token_ids.device)[None, :]
+    equals = (masks.cumsum(dim=1) - 1) == range_data
+    num_rejected_tokens = num_spec_tokens - equals.sum(dim=1)
+    first_diff_indices = torch.argmin(equals.int(), dim=1, keepdim=True)
+    keeps = range_data.repeat(batch_size, 1) <= first_diff_indices
+    keeps = keeps | equals
+    keep_token_ids = torch.where(keeps, target_token_ids, -1)
+    # add bonus tokens
+    keep_bonus_ids = torch.where(equals[:, -1:], bonus_token_ids, -1)
+    output_token_ids = torch.cat([keep_token_ids, keep_bonus_ids], dim=1)
+    # get last token ids
+    last_indices = (torch.cat([keeps, equals[:, -1:]], dim=1).cumsum(dim=1) - 1)[:, -1].flatten()
+    last_token_ids = output_token_ids[torch.arange(batch_size, device=draft_token_ids.device), last_indices]
+    return output_token_ids, num_rejected_tokens, last_token_ids
diff --git a/lmdeploy/pytorch/spec_decode/spec_agent.py b/lmdeploy/pytorch/spec_decode/spec_agent.py
new file mode 100644
index 0000000000..ff6dc5992a
--- /dev/null
+++ b/lmdeploy/pytorch/spec_decode/spec_agent.py
@@ -0,0 +1,272 @@
+# Copyright (c) OpenMMLab.
All rights reserved. + +import asyncio +from typing import Dict + +import torch + +from lmdeploy.utils import get_logger + +from ..backends import get_backend +from ..config import BackendConfig, CacheConfig, ModelConfig, SpecDecodeConfig +from ..distributed import DistContext +from ..engine.cache_engine import CacheEngine +from ..engine.logits_process import SamplingInputs +from ..model_inputs import ModelInputs +from ..strategies.ar_spec.model_agent import ARSpecExtraInputs +from ..strategies.base.model_agent import ExtraInputs +from .proposers.base import build_specdecode_proposer +from .reject_sampler import RejectionSampler + +logger = get_logger('lmdeploy') + + +class SpecModelAgent: + """Speculative model agent.""" + + def __init__( + self, + specdecode_config: SpecDecodeConfig, + backend_config: BackendConfig, + dist_ctx: DistContext, + inputs_strategy, + agent_strategy, + device: str = 'cuda', + ): + self.method = specdecode_config.method + self.model_config = specdecode_config.model_config + self.cache_config = specdecode_config.cache_config + self.num_spec_tokens = specdecode_config.num_speculative_tokens + self.backend_config = backend_config + self.device = device + self.dist_ctx = dist_ctx + + self.proposer = build_specdecode_proposer(specdecode_config, device=device) + self.cache_engine = None + self.inputs_strategy = inputs_strategy + self.agent_strategy = agent_strategy + self.to_runable = dist_ctx.dp > 1 or dist_ctx.rank % dist_ctx.tp == 0 + self.rejection_sampler = RejectionSampler() + + def set_cache_config(self, cache_config: CacheConfig): + """Set all cache config.""" + self.cache_config = cache_config + + def set_model_config(self, model_config: ModelConfig): + """Set model config.""" + self.model_config = model_config + + def build_model(self, empty_init: bool, target_model=None, model_format=None, build_model_ctx=None): + """Build draft model.""" + if not self.to_runable: + return + self.proposer.build_model(empty_init, + target_model=target_model, + model_format=model_format, + build_model_ctx=build_model_ctx) + + def build_graph_runner(self): + """Build graph runner.""" + if not self.to_runable: + return + backend = get_backend() + self.proposer.model = backend.build_graph_runner(self.proposer.model, + model_config=self.model_config, + cache_config=self.cache_config, + backend_config=self.backend_config, + device=self.device) + + def build_cache_engine(self, cache_stream: torch.cuda.Stream): + """Build cache engine.""" + if not self.to_runable: + return + if self.cache_config is not None: + self.cache_engine = CacheEngine(self.cache_config, + self.model_config, + rank=0, + tp_rank=0, + world_size=1, + cache_stream=cache_stream) + + def rejection_sampling(self, next_token_ids, model_inputs: 'ModelInputs', extra_inputs: ExtraInputs): + """Do rejection sampling.""" + num_rejected_tokens = None + bonus_token_ids = output_token_ids = next_token_ids.unsqueeze(-1) + last_token_indices = model_inputs.seq_length.cumsum(0) - 1 + if model_inputs.is_decoding: + # only do rejection sample for decoding with draft tokens + input_draft_token_ids = model_inputs.input_ids.squeeze(0).unflatten(0, (-1, self.num_spec_tokens + 1))[:, + 1:] + output_token_ids, num_rejected_tokens, next_token_ids = self.rejection_sampler( + extra_inputs.target_logits, + input_draft_token_ids, + bonus_token_ids, + ) + # update last token indices + last_token_indices = last_token_indices - num_rejected_tokens + + # create new inputs + input_ids = model_inputs.input_ids.clone() + seq_length = 
model_inputs.seq_length
+        # offset by 1 token
+        input_ids[:, :-1] = model_inputs.input_ids[:, 1:]
+        # update next tokens
+        input_ids[:, last_token_indices] = next_token_ids
+        # use new inputs
+        new_model_inputs = ModelInputs(
+            input_ids=input_ids,
+            seq_length=seq_length,
+            max_kv_seqlen=model_inputs.max_kv_seqlen,
+            max_q_seqlen=model_inputs.max_q_seqlen,
+            sum_kv_seqlen=model_inputs.sum_kv_seqlen,
+            history_lengths=model_inputs.history_lengths.clone(),
+            block_offsets=model_inputs.block_offsets,
+            num_ignored_history=model_inputs.num_ignored_history,
+            is_decoding=model_inputs.is_decoding,
+            target_hidden_states=extra_inputs.target_hidden_states,
+            target_position_ids=extra_inputs.target_position_ids,
+        )
+        new_extra_inputs = ARSpecExtraInputs(
+            next_token_ids=next_token_ids,
+            last_token_indices=last_token_indices,
+            num_rejected_tokens=num_rejected_tokens,
+            output_token_ids=output_token_ids,
+            loop_last_step=extra_inputs.loop_last_step,
+        )
+        return new_model_inputs, new_extra_inputs
+
+    def _forward_impl(self, inputs: ModelInputs):
+        """Forward impl."""
+        output = self.proposer._forward(inputs, cache_engine=self.cache_engine)
+        return output
+
+    async def async_forward(self, inputs: ModelInputs):
+        """Model forward.
+
+        Args:
+            inputs (ModelInputs): The input data comes from _make_inputs.
+        """
+        output = self._forward_impl(inputs)
+        await asyncio.sleep(0)
+        return output
+
+    async def _async_model_forward(self, inputs: ModelInputs, extra_inputs: ExtraInputs,
+                                   sampling_inputs: SamplingInputs):
+        """Model forward.
+
+        Args:
+            inputs (ModelInputs): The input data comes from _make_inputs.
+        """
+        max_prefill_token_num = self.cache_config.max_prefill_token_num
+
+        async def __long_context_single_forward(new_inputs):
+            """One large sequence."""
+            model_metas = new_inputs[0].model_metas
+            for inp in new_inputs:
+                inp.model_metas = model_metas
+                output = await self.async_forward(inp)
+                model_metas = output.get('model_metas')
+            return output
+
+        # make long context inputs
+        is_long_context = inputs.input_ids.numel() > max_prefill_token_num and not inputs.is_decoding
+
+        if is_long_context:
+            seq_len = inputs.seq_length
+            batch_size = seq_len.size(0)
+            assert batch_size == 1, 'Batched long context is not supported.'
+ inputs_li = inputs.split(max_prefill_token_num) + outputs = await __long_context_single_forward(inputs_li) + else: + outputs = await self.async_forward(inputs) + + loop_count = self.num_spec_tokens - 1 + draft_token_ids, model_metas, target_hidden_states = self.proposer.get_outputs(outputs, inputs, extra_inputs) + draft_tokens_li = [draft_token_ids] + if loop_count > 0: + # set last_token_indices to None for decoding + extra_inputs.last_token_indices = None + inputs = self.proposer.update_inputs_decoding(inputs, extra_inputs, draft_token_ids.transpose(0, 1), + target_hidden_states, model_metas) + for loop_idx in range(loop_count): + outputs = await self.async_forward(inputs) + draft_token_ids, model_metas, target_hidden_states = self.proposer.get_outputs(outputs, inputs) + draft_tokens_li.append(draft_token_ids) + if loop_idx < loop_count - 1: + step_seqlens = inputs.seq_length.new_ones(inputs.seq_length.size(0)) + inputs.step(draft_token_ids.transpose(0, 1), step_seqlens) + inputs.model_metas = model_metas + inputs.target_hidden_states = target_hidden_states + if inputs.target_position_ids is not None: + inputs.target_position_ids += 1 + + output_draft_ids = torch.cat(draft_tokens_li, dim=-1) + return output_draft_ids + + async def async_model_forward( + self, + next_token_ids: torch.Tensor, + model_inputs: ModelInputs, + extra_inputs: ExtraInputs, + sampling_inputs: SamplingInputs, + ): + """Draft model forward.""" + draft_model_inputs, draft_extra_inputs = self.rejection_sampling(next_token_ids, model_inputs, extra_inputs) + next_draft_ids = await self._async_model_forward(draft_model_inputs, draft_extra_inputs, sampling_inputs) + draft_extra_inputs.output_draft_token_ids = next_draft_ids + return draft_extra_inputs + + def warmup(self, max_batches: int, target_model_config: ModelConfig): + """warmup.""" + if not self.to_runable: + return + + target_hidden_size = self.proposer.get_target_hidden_size(target_model_config) + + # warmup prefill + inputs = self.inputs_strategy.make_dummy(max_batches, + is_decoding=False, + device='cuda', + vocab_size=self.model_config.vocab_size, + target_hidden_size=target_hidden_size, + target_dtype=self.model_config.dtype) + + self._forward_impl(inputs) + + capture_batch_sizes = self.proposer.model.get_capture_batch_sizes() + capture_batch_sizes = sorted(capture_batch_sizes, reverse=True) + + for batch_size in capture_batch_sizes: + # decode with num_spec_tokens + 1 per seq + inputs = self.inputs_strategy.make_dummy( + batch_size, + is_decoding=True, + device='cuda', + vocab_size=self.model_config.vocab_size, + max_q_seqlen=self.num_spec_tokens + 1, + target_hidden_size=target_hidden_size, + target_dtype=self.model_config.dtype, + ) + self._forward_impl(inputs) + # decode 1 tokens per sequence + inputs = self.inputs_strategy.make_dummy( + batch_size, + is_decoding=True, + device='cuda', + vocab_size=self.model_config.vocab_size, + max_q_seqlen=1, + target_hidden_size=self.model_config.hidden_size, + target_dtype=self.model_config.dtype, + ) + self._forward_impl(inputs) + + def update_main_model_outputs(self, output: Dict[str, torch.Tensor], model_inputs: ModelInputs): + """Update outputs of main model.""" + hidden_states = output['hidden_states'] + if not model_inputs.is_decoding: + logits_indices = model_inputs.seq_length.cumsum(0) - 1 + hidden_states = hidden_states[:, logits_indices] + if 'aux_hidden_states' in output: + # replace with aux + output['hidden_states'] = output.pop('aux_hidden_states') + return hidden_states, output diff --git 
a/lmdeploy/pytorch/strategies/__init__.py b/lmdeploy/pytorch/strategies/__init__.py index c0f5da1262..2cd8cddc9d 100644 --- a/lmdeploy/pytorch/strategies/__init__.py +++ b/lmdeploy/pytorch/strategies/__init__.py @@ -1,8 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from lmdeploy.pytorch.config import MiscConfig, ModelConfig +from lmdeploy.pytorch.config import MiscConfig, ModelConfig, SpecDecodeConfig -def build_strategy_factory(model_config: ModelConfig, misc_config: MiscConfig): +def build_strategy_factory(model_config: ModelConfig, + misc_config: MiscConfig, + specdecode_config: SpecDecodeConfig = None): """Build strategy factory.""" model_paradigm = model_config.model_paradigm @@ -12,5 +14,9 @@ def build_strategy_factory(model_config: ModelConfig, misc_config: MiscConfig): elif model_paradigm == 'dllm': from .dllm import DLLMStrategyFactory return DLLMStrategyFactory(model_config=model_config, dllm_config=misc_config.dllm_config) + elif model_paradigm == 'ar_spec': + from .ar_spec import ARSpecStrategyFactory + assert specdecode_config is not None, 'specdecode_config must be provided for ar_spec model' + return ARSpecStrategyFactory(model_config=model_config, specdecode_config=specdecode_config) else: raise RuntimeError(f'Unsupported model paradigm: {model_paradigm}') diff --git a/lmdeploy/pytorch/strategies/ar/cudagraph.py b/lmdeploy/pytorch/strategies/ar/cudagraph.py index e3749bcfc2..142539fb9b 100644 --- a/lmdeploy/pytorch/strategies/ar/cudagraph.py +++ b/lmdeploy/pytorch/strategies/ar/cudagraph.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +import torch + from ..base.cudagraph import CudagraphStrategy class ARCudagraphStrategy(CudagraphStrategy): - def get_max_tokens(self, batch_size: int) -> int: + def get_max_tokens(self, batch_size: int, input_ids: torch.Tensor, q_seqlens: torch.Tensor) -> int: """Get max tokens.""" return batch_size diff --git a/lmdeploy/pytorch/strategies/ar/model_agent.py b/lmdeploy/pytorch/strategies/ar/model_agent.py index 94429dae3f..727653df4e 100644 --- a/lmdeploy/pytorch/strategies/ar/model_agent.py +++ b/lmdeploy/pytorch/strategies/ar/model_agent.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional import torch from torch.profiler import record_function @@ -63,7 +63,8 @@ def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> t last_idx = seq_length.cumsum(-1) - 1 return inputs[last_idx] - def slice_extra_inputs(self, extra_inputs: ARExtraInputs, seq_length: torch.LongTensor) -> ARExtraInputs: + def slice_extra_inputs(self, extra_inputs: ARExtraInputs, model_inputs: ModelInputs, + model_outputs: Dict[str, torch.Tensor], **kwargs) -> ARExtraInputs: """Slice outputs.""" return extra_inputs diff --git a/lmdeploy/pytorch/strategies/ar_spec/__init__.py b/lmdeploy/pytorch/strategies/ar_spec/__init__.py new file mode 100644 index 0000000000..416d20460c --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar_spec/__init__.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
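+# Strategy factory for the autoregressive speculative-decoding ('ar_spec') paradigm: it wires up the cudagraph, sampling, model-inputs, model-agent, engine and sequence strategies used by the draft/verify loop.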
+from typing import TYPE_CHECKING + +from lmdeploy.pytorch.config import ModelConfig, SpecDecodeConfig +from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy + +if TYPE_CHECKING: + from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy + from lmdeploy.pytorch.strategies.base.model_inputs import ModelInputsStrategy + from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy + from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy + from lmdeploy.pytorch.strategies.base.engine import EngineStrategy + from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig + +from ..base import StrategyFactoryBase + + +class ARSpecStrategyFactory(StrategyFactoryBase): + + def __init__(self, model_config: ModelConfig, specdecode_config: SpecDecodeConfig): + """config.""" + self.model_config = model_config + self.specdecode_config = specdecode_config + self.pad_token_id = model_config.bos_token_id or 0 + + def build_cudagraph_strategy(self) -> 'CudagraphStrategy': + """Build cudagraph strategy.""" + from .cudagraph import ARSpecCudagraphStrategy + return ARSpecCudagraphStrategy(self.specdecode_config.num_speculative_tokens) + + def build_sampling_strategy(self) -> 'SamplingStrategy': + """Build sampling strategy.""" + from .sampling import ARSpecSamplingStrategy + pad_token_id = self.model_config.bos_token_id + pad_token_id = 0 if pad_token_id is None else pad_token_id + return ARSpecSamplingStrategy(pad_token_id) + + def build_model_inputs_strategy(self) -> 'ModelInputsStrategy': + """Build model inputs strategy.""" + from .model_inputs import ARSpecModelInputsStrategy + return ARSpecModelInputsStrategy(self.specdecode_config.num_speculative_tokens) + + def build_model_agent_strategy(self) -> 'ModelAgentStrategy': + """Build model agent strategy.""" + from .model_agent import ARSpecModelAgentStrategy + return ARSpecModelAgentStrategy(self.specdecode_config.num_speculative_tokens) + + def build_engine_strategy(self, cache_config: 'CacheConfig', + scheduler_config: 'SchedulerConfig') -> 'EngineStrategy': + """Build engine strategy.""" + from .engine import ARSpecEngineStrategy + return ARSpecEngineStrategy(cache_config=cache_config, + scheduler_config=scheduler_config, + num_spec_tokens=self.specdecode_config.num_speculative_tokens) + + def build_sequence_strategy(self) -> SequenceStrategy: + from .sequence import ARSpecSequenceStrategy + return ARSpecSequenceStrategy() diff --git a/lmdeploy/pytorch/strategies/ar_spec/cudagraph.py b/lmdeploy/pytorch/strategies/ar_spec/cudagraph.py new file mode 100644 index 0000000000..5ff1c779e8 --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar_spec/cudagraph.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from ..base.cudagraph import CudagraphStrategy + + +class ARSpecCudagraphStrategy(CudagraphStrategy): + + def __init__(self, num_spec_tokens: int): + super().__init__() + self.num_spec_tokens = num_spec_tokens + + def get_max_tokens(self, batch_size: int, input_ids: torch.Tensor, q_seqlens: torch.Tensor) -> int: + """Get max tokens.""" + num_tokens = input_ids.size(1) + orig_batch = q_seqlens.size(0) + if num_tokens == orig_batch: + return batch_size + + assert num_tokens % (self.num_spec_tokens + 1) == 0, 'The input_ids length must be divisible by batch_size.' 
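+ # during speculative decoding every sequence contributes num_spec_tokens + 1 query tokens, so the padded graph needs batch_size * (num_spec_tokens + 1) token slots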
+ return batch_size * (self.num_spec_tokens + 1) diff --git a/lmdeploy/pytorch/strategies/ar_spec/engine.py b/lmdeploy/pytorch/strategies/ar_spec/engine.py new file mode 100644 index 0000000000..c7cf2f9495 --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar_spec/engine.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig + +from ..base.engine import EngineStrategy + + +class ARSpecEngineStrategy(EngineStrategy): + """AR Engine Strategy.""" + + def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, num_spec_tokens: int) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + self.num_spec_tokens = num_spec_tokens + + def get_prealloc_size(self, is_decoding: bool): + """Get prealloc_size.""" + return self.scheduler_config.prefill_interval * (1 + + self.num_spec_tokens) if is_decoding else self.num_spec_tokens + + def get_num_loops(self, is_decoding: bool) -> int: + """Get num_loops.""" + return self.scheduler_config.prefill_interval if is_decoding else 1 diff --git a/lmdeploy/pytorch/strategies/ar_spec/model_agent.py b/lmdeploy/pytorch/strategies/ar_spec/model_agent.py new file mode 100644 index 0000000000..0f4dc19891 --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar_spec/model_agent.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.profiler import record_function + +import lmdeploy.pytorch.distributed as dist +from lmdeploy.pytorch.distributed import DistContext +from lmdeploy.pytorch.engine.logits_process import SamplingInputs +from lmdeploy.pytorch.messages import SchedulerSequence +from lmdeploy.pytorch.model_inputs import ModelInputs + +from ..ar.model_agent import ARStoppingCriteria +from ..base.model_agent import ExtraInputs, ExtraOutputs, ModelAgentStrategy + +SeqList = List[SchedulerSequence] + + +@dataclass +class ARSpecExtraInputs(ExtraInputs): + """ARSpec extra inputs.""" + # draft model inputs + target_logits: torch.Tensor = None + target_hidden_states: torch.Tensor = None + target_position_ids: torch.Tensor = None + next_token_ids: torch.LongTensor = None + last_token_indices: torch.LongTensor = None + + # draft model outputs + output_draft_token_ids: torch.Tensor = None + num_rejected_tokens: torch.Tensor = None + output_token_ids: torch.Tensor = None + loop_last_step: bool = None + + def __repr__(self): + return (f'ARSpecExtraInputs(next_token_ids={self.next_token_ids}, ' + f'output_draft_token_ids={self.output_draft_token_ids}, ' + f'last_token_indices={self.last_token_indices}, ' + f'num_rejected_tokens={self.num_rejected_tokens}, ' + f'output_token_ids={self.output_token_ids}, ' + f'loop_last_step={self.loop_last_step})') + + def broadcast(self, src: int, group, async_op=False): + dist.broadcast(self.output_draft_token_ids, src=src, group=group, async_op=async_op) + handle = dist.broadcast(self.num_rejected_tokens, src=src, group=group, async_op=async_op) + return handle + + +@dataclass +class ARSpecExtraOutputs(ExtraOutputs): + """ARSpec extra outputs.""" + # output the draft tokens to seq only for last loop step + draft_token_ids: torch.Tensor = None + + def __repr__(self): + return (f'ARSpecExtraOutputs(draft_token_ids={self.draft_token_ids})') + + +@dataclass +class ARSpecStoppingCriteria(ARStoppingCriteria): + num_appendable_ids: torch.Tensor + + 
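+ # num_appendable_ids: remaining generation budget (max_new_tokens minus tokens already generated) per sequence
+ # step() checks all tokens accepted in one verification pass at once: it builds a per-token stop mask from the budget and stop words, locates the first stop with argmax, and maps 'not found' to position -1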
@record_function('stopping_criteria') + def step(self, + next_token_ids: torch.Tensor, + stop_words: torch.Tensor, + inputs: Optional[ModelInputs] = None, + extra_inputs: Optional[ARSpecExtraInputs] = None): + """Check whether to stop generation.""" + token_ids = extra_inputs.output_token_ids + + if token_ids.ndim == 1: + token_ids = token_ids.unsqueeze(-1) + valid_tokens = token_ids > -1 + mask = (self.num_appendable_ids.unsqueeze(-1) - valid_tokens.cumsum(dim=-1)) <= 0 + if stop_words is not None: + token_ids_rsp = token_ids.unsqueeze(-1).repeat(1, 1, stop_words.numel()) + stop_words_rsp = stop_words.reshape(1, 1, -1) + assert stop_words_rsp.ndim == token_ids_rsp.ndim == 3 + stop_mask = (token_ids_rsp == stop_words_rsp).any(-1) + mask = mask ^ stop_mask + # find the index of first `1`, if not found, would be 0 + index = torch.argmax(mask.int(), dim=-1, keepdim=True) + # update index of 0 to -1 if not found + stop_pos = torch.where(index == 0, mask[:, 0:1].int() - 1, index).ravel() + stopped = stop_pos != -1 + num_valid_tokens = valid_tokens.sum(dim=-1) + num_appendable_ids = self.num_appendable_ids - num_valid_tokens + one_ids = torch.clamp_max(num_appendable_ids, 0) + num_appendable_ids = torch.where(stopped, one_ids, num_appendable_ids) + return stopped, stop_pos, ARSpecStoppingCriteria(num_appendable_ids=num_appendable_ids) + + +class ARSpecModelAgentStrategy(ModelAgentStrategy): + + def __init__(self, num_spec_tokens: int): + self.num_spec_tokens = num_spec_tokens + + def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor: + """Slice outputs.""" + # batch size == 1 + if len(seq_length) == 1: + return inputs[-1:] + + if len(seq_length) == inputs.size(0): + return inputs + last_idx = seq_length.cumsum(-1) - 1 + return inputs[last_idx] + + def slice_extra_inputs(self, + extra_inputs: ARSpecExtraInputs, + model_inputs: ModelInputs, + model_outputs: Dict[str, torch.Tensor], + is_last_step: bool = None, + **kwargs) -> ARSpecExtraInputs: + """Slice outputs.""" + extra_inputs = ARSpecExtraInputs() + extra_inputs.target_hidden_states = model_outputs.get('hidden_states') + extra_inputs.target_position_ids = model_outputs.get('position_ids', None) + if model_inputs.is_decoding: + batch_size = model_inputs.seq_length.size(0) + logits = model_outputs['logits'][0] + extra_inputs.target_logits = logits.unflatten(0, (batch_size, -1))[:, :-1] + + # extra_inputs. 
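+ # record whether this is the final step of the engine loop; draft tokens are surfaced to sequences only on that step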
+ extra_inputs.loop_last_step = is_last_step + return extra_inputs + + def _step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor): + """step.""" + sampling_inputs.num_ignore_eos = sampling_inputs.num_ignore_eos - 1 + + all_ids = sampling_inputs.all_ids + if all_ids is not None: + sampling_inputs.all_ids = torch.cat([all_ids, next_token_ids[:, None]], 1) + + return sampling_inputs + + def make_stopping_criteria(self, seqs: SeqList) -> ARSpecStoppingCriteria: + """Create stopping criteria.""" + num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs] + num_appendable = torch.tensor(num_appendable) + return ARSpecStoppingCriteria(num_appendable_ids=num_appendable) + + def make_extra_inputs(self, seqs: 'SeqList') -> ExtraInputs: + """Create extra inputs.""" + return ARSpecExtraInputs() + + def make_extra_outputs(self, extra_inputs: ARSpecExtraInputs) -> ARSpecExtraOutputs: + """Create extra outputs.""" + output = ARSpecExtraOutputs() + # only output draft tokens to seq for last loop step + if extra_inputs.loop_last_step is True: + output.draft_token_ids = extra_inputs.output_draft_token_ids + return output + + def update_inputs_for_next_step(self, model_inputs: 'ModelInputs', sampling_inputs: 'SamplingInputs', + next_token_ids: torch.Tensor, model_metas: Any, extra_inputs: ARSpecExtraInputs, + **kwargs): + """Step next inputs.""" + model_inputs.model_metas = model_metas + step_seqlens = model_inputs.seq_length + batch_size = step_seqlens.size(0) + + step_seqlens = model_inputs.seq_length - extra_inputs.num_rejected_tokens + input_ids = next_token_ids.new_empty((batch_size, self.num_spec_tokens + 1)) + input_ids[:, 0] = next_token_ids + input_ids[:, 1:] = extra_inputs.output_draft_token_ids + input_ids = input_ids.flatten()[None, :] + model_inputs.step(input_ids, step_seqlens) + self._step_sampling_inputs(sampling_inputs, next_token_ids) + return model_inputs, extra_inputs + + def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor, + extra_inputs: ARSpecExtraInputs): + """Post sampling.""" + return next_token_ids, extra_inputs + + def make_dummy_next_token(self, inputs: 'ModelInputs', logits: torch.Tensor, extra_inputs: ExtraInputs): + """Make dummy next token for broadcast.""" + with torch.inference_mode(): + next_token_ids = inputs.input_ids.new_zeros(logits.size(0)) + extra_inputs.output_draft_token_ids = inputs.input_ids.new_zeros((logits.size(0), self.num_spec_tokens)) + extra_inputs.num_rejected_tokens = inputs.input_ids.new_zeros(logits.size(0)) + return next_token_ids, extra_inputs + + @contextmanager + def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ARSpecExtraInputs, + dist_ctx: DistContext): + """Broadcast next token ids and extra inputs.""" + tp_gpu_group = dist_ctx.tp_gpu_group + dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True) + handle = extra_inputs.broadcast(src=0, group=tp_gpu_group, async_op=True) + yield + handle.wait() diff --git a/lmdeploy/pytorch/strategies/ar_spec/model_inputs.py b/lmdeploy/pytorch/strategies/ar_spec/model_inputs.py new file mode 100644 index 0000000000..d6862cd8b9 --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar_spec/model_inputs.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
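+# Dummy-input construction for ar_spec warmup: besides the usual dummy batch, it attaches random target hidden states shaped for the draft model.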
+import torch + +from lmdeploy.pytorch.model_inputs import ModelInputs + +from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs + + +class ARSpecModelInputsStrategy(ModelInputsStrategy): + + def __init__(self, num_spec_tokens: int): + self.num_spec_tokens = num_spec_tokens + + def make_dummy( + self, + batch_size: int, + is_decoding: bool, + device: str = 'cpu', + dummy_block_id: int = 0, + vocab_size: int = 1, + max_q_seqlen: int = 1, + target_hidden_size: int = None, + target_dtype: torch.dtype = torch.bfloat16, + ) -> ModelInputs: + """Create dummy model inputs.""" + inputs = make_dummy_inputs(batch_size, + max_q_seqlen=max_q_seqlen, + is_decoding=is_decoding, + device=device, + dummy_block_id=dummy_block_id, + vocab_size=vocab_size) + if target_hidden_size is not None: + inputs.target_hidden_states = torch.randn((1, batch_size * max_q_seqlen, target_hidden_size), + dtype=target_dtype, + device=device) + return inputs diff --git a/lmdeploy/pytorch/strategies/ar_spec/sampling.py b/lmdeploy/pytorch/strategies/ar_spec/sampling.py new file mode 100644 index 0000000000..3d5bf670ca --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar_spec/sampling.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..ar.sampling import ARSamplingStrategy + + +class ARSpecSamplingStrategy(ARSamplingStrategy): + """Sampling strategy for AR with spec models.""" diff --git a/lmdeploy/pytorch/strategies/ar_spec/sequence.py b/lmdeploy/pytorch/strategies/ar_spec/sequence.py new file mode 100644 index 0000000000..ba4236e988 --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar_spec/sequence.py @@ -0,0 +1,189 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import time +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import numpy as np +from torch import Tensor + +from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.engine.model_agent import BatchedOutputs +from lmdeploy.pytorch.messages import (InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam, + SchedulerSession, UpdateTokenMode, _to_ndarray) + +from ..ar.sequence import ARSequenceStrategy, SchedulerSequenceDefault + +SeqList = List['SchedulerSequenceARSpec'] + + +@dataclass +class SchedulerSequenceARSpec(SchedulerSequenceDefault): + + def __post_init__(self): + """Post init.""" + super().__post_init__() + self._num_spec_ids: int = 0 + self._num_new_valid: int = 0 + self._num_valid_ids: int = len(self.history_cache) + self._strategy: ARSpecSequenceStrategy = self._seq_meta.strategy + + @property + def num_valid_ids(self): + return self._num_valid_ids + + @property + def num_spec_ids(self): + return self._num_spec_ids + + @property + def generated_ids(self) -> np.ndarray: + end = self.num_valid_ids + start = end - self.num_new_tokens + return self.history_cache._token_ids[start:end] + + def set_stop_pos(self, pos: int): + val = self._num_new_valid - pos - 1 + self._num_valid_ids -= val + self.num_new_tokens -= val + self._num_token_ids = 1 + self._num_history_ids -= val + + self._num_spec_ids = 0 + self._num_new_valid = 0 + self.history_cache.resize(self.num_valid_ids) + + def _update_token_ids_inputs(self, token_ids: np.ndarray): + """Append tokens.""" + num_tokens = len(token_ids) + self.output_start_pos = self.num_valid_ids + num_tokens + self._num_valid_ids = self.num_history_ids + num_tokens + self._num_token_ids = num_tokens + self.num_new_tokens = 0 + self._num_spec_ids = 0 + self._num_new_valid = 0 + self.history_cache.append(token_ids) 
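+ # note: history_cache may hold trailing draft tokens beyond _num_valid_ids; only the verified prefix counts toward generated_ids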
+ + def _update_token_ids_prefill(self, token_ids: np.ndarray, draft_token_ids: np.ndarray): + """Update token ids for prefill.""" + num_valid = len(token_ids) + self._num_spec_ids = len(draft_token_ids) + token_ids = np.concatenate([token_ids, draft_token_ids]) + num_tokens = len(token_ids) + self._num_history_ids += self._num_token_ids + self._num_token_ids = num_tokens + self.num_new_tokens += num_valid + self._num_new_valid = num_valid + self._num_valid_ids = self.num_history_ids + num_valid + self.history_cache.append(token_ids) + + def _update_token_ids_decode(self, token_ids: np.ndarray, draft_token_ids: np.ndarray = None): + """Update token ids for decode.""" + valid_ids = token_ids[token_ids > -1] + num_valid = len(valid_ids) + self.num_new_tokens = self.num_new_tokens + num_valid + + self._num_new_valid = num_valid + self._num_valid_ids += num_valid + self._num_history_ids = self.num_valid_ids - 1 + + # last step has spec ids + if self.num_spec_ids > 0: + token_ids = valid_ids[-1:] + else: + token_ids = valid_ids + + num_tokens = len(token_ids) + + if draft_token_ids is not None: + num_tokens = 1 + len(draft_token_ids) + token_ids = np.concatenate([token_ids, draft_token_ids]) + self._num_spec_ids = len(draft_token_ids) + else: + self._num_spec_ids = 0 + + self._num_token_ids = num_tokens + if self.num_history_ids < len(self.history_cache): + self.history_cache.resize(self.num_history_ids) + self.history_cache.append(token_ids) + + def update_token_ids(self, + token_ids: Tensor, + multimodals: MultiModalInputs = None, + embeddings: List[InputEmbeddings] = None, + model_meta: Dict[str, Any] = None, + draft_token_ids: Tensor = None, + mode: UpdateTokenMode = UpdateTokenMode.INPUTS, + **kwargs): + """Update token ids, old token ids will be added to history.""" + # update history image nums + self._update_embeddings(embeddings) + + # update multimodals + self._update_multimodals(multimodals) + + self.arrive_time = time.perf_counter() + + token_ids: np.ndarray = _to_ndarray(token_ids) + if draft_token_ids is not None: + draft_token_ids = _to_ndarray(draft_token_ids) + if mode == UpdateTokenMode.INPUTS: + self._update_token_ids_inputs(token_ids) + elif mode == UpdateTokenMode.PREFILL: + self._update_token_ids_prefill(token_ids, draft_token_ids) + else: + self._update_token_ids_decode(token_ids, draft_token_ids) + if model_meta is not None: + self.model_meta = model_meta + + +class ARSpecSequenceStrategy(ARSequenceStrategy): + + def make_sequence(self, + seq_id: int, + session: 'SchedulerSession', + sampling_param: 'SamplingParam' = None, + adapter_name: str = None, + migration_request: Optional[MigrationRequest] = None, + resp_cache: bool = False, + preserve_cache: bool = False) -> 'SchedulerSequenceARSpec': + """Make sequence.""" + return SchedulerSequenceARSpec(seq_id=seq_id, + session=session, + sampling_param=sampling_param, + adapter_name=adapter_name, + migration_request=migration_request, + resp_cache=resp_cache, + preserve_cache=preserve_cache) + + def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_decoding: bool) -> None: + """Update running sequences.""" + next_token_ids = batched_outputs.next_token_ids + extra_outputs = batched_outputs.extra_outputs + stopped = batched_outputs.stopped + stopped = stopped.tolist() + model_metas = batched_outputs.model_metas + if model_metas is None: + model_metas = [None] * len(running) + stop_pos = batched_outputs.stop_pos + + batch_size = len(running) + next_token_ids = next_token_ids.view(batch_size, -1).numpy() 
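+ # draft tokens are attached by the model agent only on the last loop step; otherwise fall back to per-sequence None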
+ if extra_outputs is None or extra_outputs.draft_token_ids is None: + draft_token_ids = [None] * batch_size + else: + draft_token_ids = extra_outputs.draft_token_ids.numpy() + stop_pos = stop_pos.tolist() + update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL + + for idx, token in enumerate(next_token_ids): + msg = running[idx] + stop = stopped[idx] + model_meta = model_metas[idx] + if msg.status != MessageStatus.LOCKED: + continue + cur_draft_tokens = draft_token_ids[idx] + # fill token + msg.update_token_ids(token, draft_token_ids=cur_draft_tokens, model_meta=model_meta, mode=update_mode) + if stop: + msg.set_stop_pos(stop_pos[idx]) + msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED diff --git a/lmdeploy/pytorch/strategies/base/cudagraph.py b/lmdeploy/pytorch/strategies/base/cudagraph.py index 795c3b5350..8bf728329d 100644 --- a/lmdeploy/pytorch/strategies/base/cudagraph.py +++ b/lmdeploy/pytorch/strategies/base/cudagraph.py @@ -1,10 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABC, abstractmethod +import torch + class CudagraphStrategy(ABC): @abstractmethod - def get_max_tokens(self, batch_size: int) -> int: + def get_max_tokens(self, batch_size: int, input_ids: torch.Tensor, q_seqlens: torch.Tensor) -> int: """Get max tokens.""" pass diff --git a/lmdeploy/pytorch/strategies/base/model_agent.py b/lmdeploy/pytorch/strategies/base/model_agent.py index 53dfd57a3f..e5974c5bdc 100644 --- a/lmdeploy/pytorch/strategies/base/model_agent.py +++ b/lmdeploy/pytorch/strategies/base/model_agent.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager from dataclasses import dataclass, fields -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import numpy as np import torch @@ -105,7 +105,8 @@ def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> t pass @abstractmethod - def slice_extra_inputs(self, extra_inputs: ExtraInputs, seq_length: torch.LongTensor) -> ExtraInputs: + def slice_extra_inputs(self, extra_inputs: ExtraInputs, model_inputs: 'ModelInputs', + model_outputs: Dict[str, torch.Tensor], **kwargs) -> ExtraInputs: """Slice outputs.""" pass diff --git a/lmdeploy/pytorch/strategies/dllm/cudagraph.py b/lmdeploy/pytorch/strategies/dllm/cudagraph.py index 2e388b22de..44688b642e 100644 --- a/lmdeploy/pytorch/strategies/dllm/cudagraph.py +++ b/lmdeploy/pytorch/strategies/dllm/cudagraph.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import torch + from ..base.cudagraph import CudagraphStrategy @@ -8,6 +10,6 @@ def __init__(self, block_size: int) -> None: super().__init__() self.block_size = block_size - def get_max_tokens(self, batch_size: int) -> int: + def get_max_tokens(self, batch_size: int, input_ids: torch.Tensor, q_seqlens: torch.Tensor) -> int: """Get max tokens.""" return batch_size * self.block_size diff --git a/lmdeploy/pytorch/strategies/dllm/model_agent.py b/lmdeploy/pytorch/strategies/dllm/model_agent.py index 849708819b..72d3e89a86 100644 --- a/lmdeploy/pytorch/strategies/dllm/model_agent.py +++ b/lmdeploy/pytorch/strategies/dllm/model_agent.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional import numpy as np import torch @@ -154,9 +154,10 @@ def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> t inputs = inputs[index] return inputs - def slice_extra_inputs(self, extra_inputs: DLLMExtraInputs, seq_length: torch.LongTensor) -> DLLMExtraInputs: + def slice_extra_inputs(self, extra_inputs: DLLMExtraInputs, model_inputs: ModelInputs, + model_outputs: Dict[str, torch.Tensor], **kwargs) -> DLLMExtraInputs: """Slice outputs.""" - dllm_mask = self.slice_outputs(extra_inputs.dllm_mask, seq_length) + dllm_mask = self.slice_outputs(extra_inputs.dllm_mask, model_inputs.seq_length) return DLLMExtraInputs(dllm_mask=dllm_mask) def _step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor, diff --git a/lmdeploy/pytorch/third_party/flash_attn_interface.py b/lmdeploy/pytorch/third_party/flash_attn_interface.py index 2cc0e3cdf1..5dbcc2fbec 100644 --- a/lmdeploy/pytorch/third_party/flash_attn_interface.py +++ b/lmdeploy/pytorch/third_party/flash_attn_interface.py @@ -2,6 +2,7 @@ import functools from flash_attn_interface import flash_attn_varlen_func as _flash_attn_varlen_func +from flash_attn_interface import flash_attn_with_kvcache as _flash_attn_with_kvcache @functools.wraps(_flash_attn_varlen_func) @@ -11,3 +12,9 @@ def flash_attn_varlen_func(*args, **kwargs): # for old api return output[0] return output + + +@functools.wraps(_flash_attn_with_kvcache) +def flash_attn_with_kvcache(*args, **kwargs): + output = _flash_attn_with_kvcache(*args, **kwargs) + return output diff --git a/lmdeploy/pytorch/utils.py b/lmdeploy/pytorch/utils.py index de67c23578..5b96d11d42 100644 --- a/lmdeploy/pytorch/utils.py +++ b/lmdeploy/pytorch/utils.py @@ -6,6 +6,10 @@ import psutil +from lmdeploy.utils import get_logger + +logger = get_logger('lmdeploy') + def get_gpu_memory(device_id: int = None) -> int: """Returns the free and total physical memory of the GPU in bytes.""" @@ -27,3 +31,67 @@ def bind_sigature(input_names: str, args: Sequence, kwargs: Dict): sig = Signature([Parameter(name, kind) for name in input_names]) bind = sig.bind(*args, **kwargs) return bind.arguments + + +# from vllm +def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None: + """Try to register HF model configuration class to serialize by value With + trust_remote_code, the config class is typically an instance of a custom + class imported from the HF modules cache. + + The class will not be + importable in spawned workers by default (and won't exist at all on + other nodes), which breaks serialization of the config. + In this function we tell the cloudpickle serialization library to pass + instances of these generated classes by value instead of by reference, + i.e. the class definition is serialized along with its data so that the + class module does not need to be importable on the receiving end. This + registration only works if the modules cache has already been + initialized. + See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs + """ # noqa: E501 + if not trust_remote_code: + return + + try: + import transformers_modules + except ImportError: + logger.debug('Could not import transformers_modules used for remote' + ' code. 
If remote code is not needed remove' + ' `--trust-remote-code`.') + return + + try: + import cloudpickle + cloudpickle.register_pickle_by_value(transformers_modules) + + # ray vendors its own version of cloudpickle + try: + import ray + except ImportError: + return + + ray.cloudpickle.register_pickle_by_value(transformers_modules) + + # multiprocessing uses pickle to serialize arguments when using spawn + # Here we get pickle to use cloudpickle to serialize ModelConfig objects + # that contain instances of the custom config class to avoid + # serialization problems if the generated module (and model) has a `.` + # in its name + import multiprocessing + import pickle + + from lmdeploy.pytorch.config import ModelConfig + + def _reduce_modelconfig(mc: ModelConfig): + return (pickle.loads, (cloudpickle.dumps(mc), )) + + multiprocessing.reducer.register(ModelConfig, _reduce_modelconfig) + + except Exception as e: + logger.warning( + 'Unable to register remote classes used by' + ' trust_remote_code with by-value serialization. This may' + ' lead to a later error. If remote code is not needed' + ' remove `--trust-remote-code`', + exc_info=e) diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 5322bcfd55..d8032a7cb1 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -18,9 +18,10 @@ from lmdeploy import Tokenizer from lmdeploy.archs import get_model_arch from lmdeploy.logger import RequestLogger -from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, Response, ResponseType, TurbomindEngineConfig +from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig, Response, ResponseType, SpeculativeConfig, + TurbomindEngineConfig) from lmdeploy.metrics.metrics_processor import metrics_processor -from lmdeploy.metrics.stats import IterationStats, RequestState +from lmdeploy.metrics.stats import IterationStats, RequestState, SpeculativeDecodingStats from lmdeploy.model import MODELS, BaseChatTemplate, ChatTemplateConfig, best_match_model from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest, DistServeInitRequest) @@ -300,10 +301,11 @@ def __init__(self, backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None, chat_template_config: Optional[ChatTemplateConfig] = None, max_log_len: int = None, + speculative_config: SpeculativeConfig = None, **kwargs) -> None: logger.info(f'input backend={backend}, backend_config={backend_config}') logger.info(f'input chat_template_config={chat_template_config}') - + logger.info(f'speculative_config={speculative_config}') backend_config = backend_config or (TurbomindEngineConfig() if backend == 'turbomind' else PytorchEngineConfig()) self.model_name = model_name if model_name else model_path @@ -322,12 +324,17 @@ def __init__(self, self.session_len = (_get_and_verify_max_len(cfg, None) if backend_config.session_len is None else backend_config.session_len) backend_config.session_len = self.session_len + if speculative_config is not None and backend == 'turbomind': + logger.warning('speculative decoding is not supported by turbomind ') # build backend engine if backend == 'turbomind': self.engine = self._build_turbomind(model_path=model_path, backend_config=backend_config, **kwargs) self.hf_tm_cfg = self.engine.config elif backend == 'pytorch': - self.engine = self._build_pytorch(model_path=model_path, backend_config=backend_config, **kwargs) + self.engine = self._build_pytorch(model_path=model_path, + 
backend_config=backend_config, + speculative_config=speculative_config, + **kwargs) self.hf_tm_cfg = getattr(self.engine.model_config, 'hf_config', None) else: raise ValueError(f'unsupported backend {backend}') @@ -348,6 +355,8 @@ def __init__(self, self.request_logger = RequestLogger(max_log_len) self.internal_thread = _EventLoopThread(daemon=True) self.limiter: asyncio.Semaphore = None + self.num_spec_token = 0 if backend == 'turbomind' or speculative_config is None \ + else speculative_config.num_speculative_tokens # build stat loggers self._build_stat_loggers() @@ -383,10 +392,11 @@ def _build_turbomind(self, def _build_pytorch(self, model_path: str, backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None, + speculative_config: SpeculativeConfig = None, **kwargs): """Innter build method for pytorch backend.""" from lmdeploy.pytorch.engine import Engine - return Engine.from_pretrained(model_path, engine_config=backend_config) + return Engine.from_pretrained(model_path, engine_config=backend_config, speculative_config=speculative_config) def _build_stat_loggers(self): self.stat_loggers = [] @@ -849,7 +859,9 @@ def is_error(status): req_state = RequestState(prompt_tokens=input_len) # per-requst state async for outputs in gen: iteration_stats = IterationStats() # per-iteration stats - metrics_processor.queue_update((outputs, req_state, iteration_stats)) + specdecode_stats = SpeculativeDecodingStats( + self.num_spec_token) if self.num_spec_token > 0 else None + metrics_processor.queue_update((outputs, req_state, iteration_stats, specdecode_stats)) # decode res if is_error(outputs.status): break diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index cdbbaa7315..5f817d0077 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -22,7 +22,8 @@ from starlette.routing import Mount from lmdeploy.archs import get_task -from lmdeploy.messages import GenerationConfig, LogitsProcessor, PytorchEngineConfig, TurbomindEngineConfig +from lmdeploy.messages import (GenerationConfig, LogitsProcessor, PytorchEngineConfig, SpeculativeConfig, + TurbomindEngineConfig) from lmdeploy.metrics.metrics_processor import metrics_processor from lmdeploy.model import ChatTemplateConfig from lmdeploy.pytorch.disagg.config import DistServeEngineConfig @@ -1313,6 +1314,7 @@ def serve(model_path: str, reasoning_parser: Optional[str] = None, tool_call_parser: Optional[str] = None, allow_terminate_by_client: bool = False, + speculative_config: Optional[SpeculativeConfig] = None, **kwargs): """An example to perform model inference through the command line interface. 
@@ -1391,6 +1393,7 @@ def serve(model_path: str, backend_config=backend_config, chat_template_config=chat_template_config, max_log_len=max_log_len, + speculative_config=speculative_config, **kwargs) # set reasoning parser and tool parser set_parsers(reasoning_parser, tool_call_parser) diff --git a/requirements/runtime_ascend.txt b/requirements/runtime_ascend.txt index c9b887ccd4..927e6f7c3d 100644 --- a/requirements/runtime_ascend.txt +++ b/requirements/runtime_ascend.txt @@ -1,4 +1,5 @@ accelerate>=0.29.3 +cloudpickle dlinfer-ascend>=0.1.3 einops fastapi diff --git a/requirements/runtime_cuda.txt b/requirements/runtime_cuda.txt index e3742c58f8..d9834413b9 100644 --- a/requirements/runtime_cuda.txt +++ b/requirements/runtime_cuda.txt @@ -1,5 +1,6 @@ accelerate>=0.29.3 aiohttp +cloudpickle einops fastapi fire diff --git a/requirements/runtime_rocm.txt b/requirements/runtime_rocm.txt index 47d6f66fcd..cf8091d251 100644 --- a/requirements/runtime_rocm.txt +++ b/requirements/runtime_rocm.txt @@ -1,4 +1,5 @@ accelerate>=0.29.3 +cloudpickle einops fastapi fire