diff --git a/pyproject.toml b/pyproject.toml
index b8daeb4..3f11ab5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vllm-detector-adapter"
-version = "0.9.0"
+version = "0.10.0"
 authors = [
     { name = "Gaurav Kumbhat", email = "kumbhat.gaurav@gmail.com" },
     { name = "Evaline Ju", email = "evaline.ju@ibm.com" },
@@ -15,9 +15,8 @@ dependencies = ["orjson>=3.10.16,<3.11"]
 [project.optional-dependencies]
 vllm-tgis-adapter = ["vllm-tgis-adapter>=0.8.0,<0.9.0"]
 vllm = [
-    # Note: vllm < 0.10.0 has issues with transformers >= 4.54.0
-    "vllm @ git+https://github.com/vllm-project/vllm.git@v0.11.0 ; sys_platform == 'darwin'",
-    "vllm>=0.10.1,<0.11.1 ; sys_platform != 'darwin'",
+    "vllm @ git+https://github.com/vllm-project/vllm.git@v0.12.0 ; sys_platform == 'darwin'",
+    "vllm>=0.11.1,<0.12.1 ; sys_platform != 'darwin'",
 ]
 
 ## Dev Extra Sets ##
diff --git a/tests/generative_detectors/test_base.py b/tests/generative_detectors/test_base.py
index 29f517d..ebb0a61 100644
--- a/tests/generative_detectors/test_base.py
+++ b/tests/generative_detectors/test_base.py
@@ -1,7 +1,7 @@
 # Standard
-from dataclasses import dataclass
-from typing import Optional
-from unittest.mock import patch
+from dataclasses import dataclass, field
+from typing import Any, Optional
+from unittest.mock import MagicMock, patch
 import asyncio
 
 # Third Party
@@ -49,17 +49,23 @@ class MockHFConfig:
 @dataclass
 class MockModelConfig:
     task = "generate"
+    runner_type = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
     max_model_len = 100
     tokenizer_revision = None
-    embedding_mode = False
     multimodal_config = MultiModalConfig()
-    diff_sampling_param: Optional[dict] = None
     hf_config = MockHFConfig()
     logits_processor_pattern = None
+    logits_processors: list[str] | None = None
+    diff_sampling_param: dict | None = None
     allowed_local_media_path: str = ""
+    allowed_media_domains: list[str] | None = None
+    encoder_config = None
+    generation_config: str = "auto"
+    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
+    skip_tokenizer_init = False
 
     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}
@@ -78,10 +84,11 @@ async def _async_serving_detection_completion_init():
     """Initialize a chat completion base with string templates"""
     engine = MockEngine()
     engine.errored = False
-    model_config = await engine.get_model_config()
+    engine.model_config = MockModelConfig()
+    engine.input_processor = MagicMock()
+    engine.io_processor = MagicMock()
     models = OpenAIServingModels(
         engine_client=engine,
-        model_config=model_config,
         base_model_paths=BASE_MODEL_PATHS,
     )
 
@@ -89,7 +96,6 @@ async def _async_serving_detection_completion_init():
         task_template="hello {{user_text}}",
         output_template="bye {{text}}",
         engine_client=engine,
-        model_config=model_config,
         models=models,
         response_role="assistant",
         chat_template=CHAT_TEMPLATE,
diff --git a/tests/generative_detectors/test_granite_guardian.py b/tests/generative_detectors/test_granite_guardian.py
index 1432a5e..daf7882 100644
--- a/tests/generative_detectors/test_granite_guardian.py
+++ b/tests/generative_detectors/test_granite_guardian.py
@@ -1,8 +1,8 @@
 # Standard
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from http import HTTPStatus
-from typing import Optional
-from unittest.mock import patch
+from typing import Any, Optional
+from unittest.mock import MagicMock, patch
 import asyncio
 import json
 
@@ -76,17 +76,23 @@ class MockHFConfig:
 @dataclass
 class MockModelConfig:
     task = "generate"
+    runner_type = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
     max_model_len = 100
     tokenizer_revision = None
-    embedding_mode = False
     multimodal_config = MultiModalConfig()
-    diff_sampling_param: Optional[dict] = None
     hf_config = MockHFConfig()
     logits_processor_pattern = None
+    logits_processors: list[str] | None = None
+    diff_sampling_param: dict | None = None
     allowed_local_media_path: str = ""
+    allowed_media_domains: list[str] | None = None
+    encoder_config = None
+    generation_config: str = "auto"
+    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
+    skip_tokenizer_init = False
 
     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}
@@ -94,9 +100,6 @@ def get_diff_sampling_param(self):
 
 @dataclass
 class MockEngine:
-    async def get_model_config(self):
-        return MockModelConfig()
-
     async def get_tokenizer(self):
         return MockTokenizer()
 
@@ -105,10 +108,11 @@ async def _granite_guardian_init():
     """Initialize a granite guardian"""
     engine = MockEngine()
     engine.errored = False
-    model_config = await engine.get_model_config()
+    engine.model_config = MockModelConfig()
+    engine.input_processor = MagicMock()
+    engine.io_processor = MagicMock()
     models = OpenAIServingModels(
         engine_client=engine,
-        model_config=model_config,
         base_model_paths=BASE_MODEL_PATHS,
     )
 
@@ -116,7 +120,6 @@ async def _granite_guardian_init():
         task_template=None,
         output_template=None,
         engine_client=engine,
-        model_config=model_config,
         models=models,
         response_role="assistant",
         chat_template=CHAT_TEMPLATE,
diff --git a/tests/generative_detectors/test_llama_guard.py b/tests/generative_detectors/test_llama_guard.py
index 2754bb1..7f0a3f0 100644
--- a/tests/generative_detectors/test_llama_guard.py
+++ b/tests/generative_detectors/test_llama_guard.py
@@ -1,8 +1,8 @@
 # Standard
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from http import HTTPStatus
-from typing import Optional
-from unittest.mock import patch
+from typing import Any, Optional
+from unittest.mock import MagicMock, patch
 import asyncio
 
 # Third Party
@@ -54,17 +54,23 @@ class MockHFConfig:
 @dataclass
 class MockModelConfig:
     task = "generate"
+    runner_type = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
    tokenizer_mode = "auto"
     max_model_len = 100
     tokenizer_revision = None
-    embedding_mode = False
     multimodal_config = MultiModalConfig()
-    diff_sampling_param: Optional[dict] = None
     hf_config = MockHFConfig()
     logits_processor_pattern = None
+    logits_processors: list[str] | None = None
+    diff_sampling_param: dict | None = None
     allowed_local_media_path: str = ""
+    allowed_media_domains: list[str] | None = None
+    encoder_config = None
+    generation_config: str = "auto"
+    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
+    skip_tokenizer_init = False
 
     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}
@@ -72,9 +78,6 @@ def get_diff_sampling_param(self):
 
 @dataclass
 class MockEngine:
-    async def get_model_config(self):
-        return MockModelConfig()
-
     async def get_tokenizer(self):
         return MockTokenizer()
 
@@ -83,10 +86,11 @@ async def _llama_guard_init():
     """Initialize a llama guard"""
     engine = MockEngine()
     engine.errored = False
-    model_config = await engine.get_model_config()
+    engine.model_config = MockModelConfig()
+    engine.input_processor = MagicMock()
+    engine.io_processor = MagicMock()
     models = OpenAIServingModels(
         engine_client=engine,
-        model_config=model_config,
         base_model_paths=BASE_MODEL_PATHS,
     )
 
@@ -94,7 +98,6 @@ async def _llama_guard_init():
         task_template=None,
         output_template=None,
         engine_client=engine,
-        model_config=model_config,
         models=models,
         response_role="assistant",
         chat_template=CHAT_TEMPLATE,
diff --git a/vllm_detector_adapter/api_server.py b/vllm_detector_adapter/api_server.py
index 3bf241d..17318bd 100644
--- a/vllm_detector_adapter/api_server.py
+++ b/vllm_detector_adapter/api_server.py
@@ -9,7 +9,6 @@
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse
 from starlette.datastructures import State
-from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import load_chat_template
 from vllm.entrypoints.launcher import serve_http
@@ -19,7 +18,10 @@
 from vllm.entrypoints.openai.protocol import ErrorInfo, ErrorResponse
 from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-from vllm.utils import FlexibleArgumentParser, is_valid_ipv6_address, set_ulimit
+from vllm.entrypoints.utils import process_lora_modules
+from vllm.reasoning import ReasoningParserManager
+from vllm.utils import is_valid_ipv6_address, set_ulimit
+from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.version import __version__ as VLLM_VERSION
 import uvloop
 
@@ -36,14 +38,6 @@
 )
 from vllm_detector_adapter.utils import LocalEnvVarArgumentParser
 
-try:
-    # Third Party
-    from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
-except ImportError:
-    # Third Party
-    from vllm.reasoning import ReasoningParserManager
-
-
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
 # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
@@ -61,20 +55,21 @@ def chat_detection(
 
 async def init_app_state_with_detectors(
     engine_client: EngineClient,
-    config,  # ModelConfig | VllmConfig
     state: State,
     args: Namespace,
 ) -> None:
     """Add detection capabilities to app state"""
+    vllm_config = engine_client.vllm_config
+
     if args.served_model_name is not None:
         served_model_names = args.served_model_name
     else:
         served_model_names = [args.model]
 
-    if args.disable_log_requests:
-        request_logger = None
-    else:
+    if args.enable_log_requests:
         request_logger = RequestLogger(max_log_len=args.max_log_len)
+    else:
+        request_logger = None
 
     base_model_paths = [
         BaseModelPath(name=name, model_path=args.model) for name in served_model_names
@@ -82,21 +77,24 @@ async def init_app_state_with_detectors(
 
     resolved_chat_template = load_chat_template(args.chat_template)
 
-    model_config = config
-    if type(config) != ModelConfig:  # VllmConfig
-        model_config = config.model_config
+    # Merge default_mm_loras into the static lora_modules
+    default_mm_loras = (
+        vllm_config.lora_config.default_mm_loras
+        if vllm_config.lora_config is not None
+        else {}
+    )
+    lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)
 
     state.openai_serving_models = OpenAIServingModels(
         engine_client=engine_client,
-        model_config=model_config,
         base_model_paths=base_model_paths,
-        lora_modules=args.lora_modules,
+        lora_modules=lora_modules,
     )
 
     # Use vllm app state init
     # init_app_state became async in https://github.com/vllm-project/vllm/pull/11727
     # ref. https://github.com/opendatahub-io/vllm-tgis-adapter/pull/207
-    maybe_coroutine = api_server.init_app_state(engine_client, config, state, args)
+    maybe_coroutine = api_server.init_app_state(engine_client, state, args)
     if inspect.isawaitable(maybe_coroutine):
         await maybe_coroutine
 
@@ -107,7 +105,6 @@ async def init_app_state_with_detectors(
         args.task_template,
         args.output_template,
         engine_client,
-        model_config,
         state.openai_serving_models,
         args.response_role,
         request_logger=request_logger,
@@ -196,18 +193,7 @@ async def validation_exception_handler(
            content=err.model_dump(), status_code=HTTPStatus.BAD_REQUEST
        )
 
-    # api_server.init_app_state takes vllm_config
-    # ref. https://github.com/vllm-project/vllm/pull/16572
-    if hasattr(engine_client, "get_vllm_config"):
-        vllm_config = await engine_client.get_vllm_config()
-        await init_app_state_with_detectors(
-            engine_client, vllm_config, app.state, args
-        )
-    else:
-        model_config = await engine_client.get_model_config()
-        await init_app_state_with_detectors(
-            engine_client, model_config, app.state, args
-        )
+    await init_app_state_with_detectors(engine_client, app.state, args)
 
    def _listen_addr(a: str) -> str:
        if is_valid_ipv6_address(a):
diff --git a/vllm_detector_adapter/start_with_tgis_adapter.py b/vllm_detector_adapter/start_with_tgis_adapter.py
index a889950..d56aaa1 100644
--- a/vllm_detector_adapter/start_with_tgis_adapter.py
+++ b/vllm_detector_adapter/start_with_tgis_adapter.py
@@ -21,7 +21,7 @@
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.openai import api_server
 from vllm.entrypoints.openai.cli_args import make_arg_parser
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils.argparse_utils import FlexibleArgumentParser
 import uvloop
 
 if TYPE_CHECKING:
diff --git a/vllm_detector_adapter/utils.py b/vllm_detector_adapter/utils.py
index 62e77d7..f18da76 100644
--- a/vllm_detector_adapter/utils.py
+++ b/vllm_detector_adapter/utils.py
@@ -4,7 +4,7 @@
 import os
 
 # Third Party
-from vllm.utils import FlexibleArgumentParser, StoreBoolean
+from vllm.utils.argparse_utils import FlexibleArgumentParser
 
 
 class DetectorType(Enum):
@@ -16,6 +16,20 @@ class DetectorType(Enum):
     TEXT_CONTEXT_DOC = auto()
 
 
+# This is taken from vLLM < 0.11.1 for backwards compatibility.
+# vLLM versions >=0.11.1 no longer include StoreBoolean.
+class StoreBoolean(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        if values.lower() == "true":
+            setattr(namespace, self.dest, True)
+        elif values.lower() == "false":
+            setattr(namespace, self.dest, False)
+        else:
+            raise ValueError(
+                f"Invalid boolean value: {values}. Expected 'true' or 'false'."
+            )
+
+
 # LocalEnvVarArgumentParser and dependent functions taken from
 # https://github.com/opendatahub-io/vllm-tgis-adapter/blob/main/src/vllm_tgis_adapter/tgis_utils/args.py
 # vllm by default parses args from CLI, not from env vars, but env var overrides
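
Usage note (not part of the patch): the sketch below illustrates how the backported StoreBoolean action in vllm_detector_adapter/utils.py is expected to behave once this diff is applied. The "--log-requests" flag name is hypothetical and chosen only for illustration.

    # Illustrative sketch only; "--log-requests" is a made-up flag name.
    import argparse

    from vllm_detector_adapter.utils import StoreBoolean

    parser = argparse.ArgumentParser()
    # StoreBoolean accepts the literal strings "true"/"false" (case-insensitive)
    # and raises ValueError for anything else.
    parser.add_argument("--log-requests", action=StoreBoolean)

    args = parser.parse_args(["--log-requests", "true"])
    assert args.log_requests is True

    args = parser.parse_args(["--log-requests", "FALSE"])
    assert args.log_requests is False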