7 changes: 3 additions & 4 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "vllm-detector-adapter"
version = "0.9.0"
version = "0.10.0"
authors = [
{ name = "Gaurav Kumbhat", email = "[email protected]" },
{ name = "Evaline Ju", email = "[email protected]" },
@@ -15,9 +15,8 @@ dependencies = ["orjson>=3.10.16,<3.11"]
[project.optional-dependencies]
vllm-tgis-adapter = ["vllm-tgis-adapter>=0.8.0,<0.9.0"]
vllm = [
# Note: vllm < 0.10.0 has issues with transformers >= 4.54.0
"vllm @ git+https://github.com/vllm-project/[email protected] ; sys_platform == 'darwin'",
"vllm>=0.10.1,<0.11.1 ; sys_platform != 'darwin'",
"vllm @ git+https://github.com/vllm-project/[email protected] ; sys_platform == 'darwin'",
"vllm>=0.11.1,<0.12.1 ; sys_platform != 'darwin'",
]

## Dev Extra Sets ##
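Since the vLLM floor moves from 0.10.1 to 0.11.1, a quick runtime check can confirm that the resolved engine actually satisfies the new pin. A minimal sketch, not part of this change (packaging is assumed to be available, as vLLM already depends on it):

# Standard
from importlib.metadata import version

# Third Party
from packaging.version import Version

# illustrative guard: fail fast if the installed vLLM is older than the
# lower bound now declared in pyproject.toml
assert Version(version("vllm")) >= Version("0.11.1"), "vLLM older than 0.11.1"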
22 changes: 14 additions & 8 deletions tests/generative_detectors/test_base.py
@@ -1,7 +1,7 @@
# Standard
from dataclasses import dataclass
from typing import Optional
from unittest.mock import patch
from dataclasses import dataclass, field
from typing import Any, Optional
from unittest.mock import MagicMock, patch
import asyncio

# Third Party
@@ -49,17 +49,23 @@ class MockHFConfig:
@dataclass
class MockModelConfig:
task = "generate"
runner_type = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
embedding_mode = False
multimodal_config = MultiModalConfig()
diff_sampling_param: Optional[dict] = None
hf_config = MockHFConfig()
logits_processor_pattern = None
logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False

def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
@@ -78,18 +84,18 @@ async def _async_serving_detection_completion_init():
"""Initialize a chat completion base with string templates"""
engine = MockEngine()
engine.errored = False
model_config = await engine.get_model_config()
engine.model_config = MockModelConfig()
engine.input_processor = MagicMock()
engine.io_processor = MagicMock()
models = OpenAIServingModels(
engine_client=engine,
model_config=model_config,
base_model_paths=BASE_MODEL_PATHS,
)

detection_completion = ChatCompletionDetectionBase(
task_template="hello {{user_text}}",
output_template="bye {{text}}",
engine_client=engine,
model_config=model_config,
models=models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
25 changes: 14 additions & 11 deletions tests/generative_detectors/test_granite_guardian.py
@@ -1,8 +1,8 @@
# Standard
from dataclasses import dataclass
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Optional
from unittest.mock import patch
from typing import Any, Optional
from unittest.mock import MagicMock, patch
import asyncio
import json

@@ -76,27 +76,30 @@ class MockHFConfig:
@dataclass
class MockModelConfig:
task = "generate"
runner_type = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
embedding_mode = False
multimodal_config = MultiModalConfig()
diff_sampling_param: Optional[dict] = None
hf_config = MockHFConfig()
logits_processor_pattern = None
logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False

def get_diff_sampling_param(self):
return self.diff_sampling_param or {}


@dataclass
class MockEngine:
async def get_model_config(self):
return MockModelConfig()

async def get_tokenizer(self):
return MockTokenizer()

@@ -105,18 +108,18 @@ async def _granite_guardian_init():
"""Initialize a granite guardian"""
engine = MockEngine()
engine.errored = False
model_config = await engine.get_model_config()
engine.model_config = MockModelConfig()
engine.input_processor = MagicMock()
engine.io_processor = MagicMock()
models = OpenAIServingModels(
engine_client=engine,
model_config=model_config,
base_model_paths=BASE_MODEL_PATHS,
)

granite_guardian = GraniteGuardian(
task_template=None,
output_template=None,
engine_client=engine,
model_config=model_config,
models=models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
25 changes: 14 additions & 11 deletions tests/generative_detectors/test_llama_guard.py
@@ -1,8 +1,8 @@
# Standard
from dataclasses import dataclass
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Optional
from unittest.mock import patch
from typing import Any, Optional
from unittest.mock import MagicMock, patch
import asyncio

# Third Party
@@ -54,27 +54,30 @@ class MockHFConfig:
@dataclass
class MockModelConfig:
task = "generate"
runner_type = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
embedding_mode = False
multimodal_config = MultiModalConfig()
diff_sampling_param: Optional[dict] = None
hf_config = MockHFConfig()
logits_processor_pattern = None
logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False

def get_diff_sampling_param(self):
return self.diff_sampling_param or {}


@dataclass
class MockEngine:
async def get_model_config(self):
return MockModelConfig()

async def get_tokenizer(self):
return MockTokenizer()

@@ -83,18 +86,18 @@ async def _llama_guard_init():
"""Initialize a llama guard"""
engine = MockEngine()
engine.errored = False
model_config = await engine.get_model_config()
engine.model_config = MockModelConfig()
engine.input_processor = MagicMock()
engine.io_processor = MagicMock()
models = OpenAIServingModels(
engine_client=engine,
model_config=model_config,
base_model_paths=BASE_MODEL_PATHS,
)

llama_guard_detection = LlamaGuard(
task_template=None,
output_template=None,
engine_client=engine,
model_config=model_config,
models=models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
52 changes: 19 additions & 33 deletions vllm_detector_adapter/api_server.py
@@ -9,7 +9,6 @@
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.datastructures import State
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.launcher import serve_http
@@ -19,7 +18,10 @@
from vllm.entrypoints.openai.protocol import ErrorInfo, ErrorResponse
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.utils import FlexibleArgumentParser, is_valid_ipv6_address, set_ulimit
from vllm.entrypoints.utils import process_lora_modules
from vllm.reasoning import ReasoningParserManager
from vllm.utils import is_valid_ipv6_address, set_ulimit
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION
import uvloop

@@ -36,14 +38,6 @@
)
from vllm_detector_adapter.utils import LocalEnvVarArgumentParser

try:
# Third Party
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
except ImportError:
# Third Party
from vllm.reasoning import ReasoningParserManager


TIMEOUT_KEEP_ALIVE = 5 # seconds

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
@@ -61,42 +55,46 @@ def chat_detection(

async def init_app_state_with_detectors(
engine_client: EngineClient,
config, # ModelConfig | VllmConfig
state: State,
args: Namespace,
) -> None:
"""Add detection capabilities to app state"""
vllm_config = engine_client.vllm_config

if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]

if args.disable_log_requests:
request_logger = None
else:
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
else:
request_logger = None

base_model_paths = [
BaseModelPath(name=name, model_path=args.model) for name in served_model_names
]

resolved_chat_template = load_chat_template(args.chat_template)

model_config = config
if type(config) != ModelConfig: # VllmConfig
model_config = config.model_config
# Merge default_mm_loras into the static lora_modules
default_mm_loras = (
vllm_config.lora_config.default_mm_loras
if vllm_config.lora_config is not None
else {}
)
lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)

state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=args.lora_modules,
lora_modules=lora_modules,
)

# Use vllm app state init
# init_app_state became async in https://github.com/vllm-project/vllm/pull/11727
# ref. https://github.com/opendatahub-io/vllm-tgis-adapter/pull/207
maybe_coroutine = api_server.init_app_state(engine_client, config, state, args)
maybe_coroutine = api_server.init_app_state(engine_client, state, args)
if inspect.isawaitable(maybe_coroutine):
await maybe_coroutine

@@ -107,7 +105,6 @@ async def init_app_state_with_detectors(
args.task_template,
args.output_template,
engine_client,
model_config,
state.openai_serving_models,
args.response_role,
request_logger=request_logger,
@@ -196,18 +193,7 @@ async def validation_exception_handler(
content=err.model_dump(), status_code=HTTPStatus.BAD_REQUEST
)

# api_server.init_app_state takes vllm_config
# ref. https://github.com/vllm-project/vllm/pull/16572
if hasattr(engine_client, "get_vllm_config"):
vllm_config = await engine_client.get_vllm_config()
await init_app_state_with_detectors(
engine_client, vllm_config, app.state, args
)
else:
model_config = await engine_client.get_model_config()
await init_app_state_with_detectors(
engine_client, model_config, app.state, args
)
await init_app_state_with_detectors(engine_client, app.state, args)

def _listen_addr(a: str) -> str:
if is_valid_ipv6_address(a):
2 changes: 1 addition & 1 deletion vllm_detector_adapter/start_with_tgis_adapter.py
@@ -21,7 +21,7 @@
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.openai import api_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
import uvloop

if TYPE_CHECKING:
16 changes: 15 additions & 1 deletion vllm_detector_adapter/utils.py
@@ -4,7 +4,7 @@
import os

# Third Party
from vllm.utils import FlexibleArgumentParser, StoreBoolean
from vllm.utils.argparse_utils import FlexibleArgumentParser


class DetectorType(Enum):
@@ -16,6 +16,20 @@ class DetectorType(Enum):
TEXT_CONTEXT_DOC = auto()


# This is taken from vLLM < 0.11.1 for backwards compatibility.
# vLLM versions >=0.11.1 no longer include StoreBoolean.
class StoreBoolean(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
if values.lower() == "true":
setattr(namespace, self.dest, True)
elif values.lower() == "false":
setattr(namespace, self.dest, False)
else:
raise ValueError(
f"Invalid boolean value: {values}. Expected 'true' or 'false'."
)


# LocalEnvVarArgumentParser and dependent functions taken from
# https://github.com/opendatahub-io/vllm-tgis-adapter/blob/main/src/vllm_tgis_adapter/tgis_utils/args.py
# vllm by default parses args from CLI, not from env vars, but env var overrides
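For reference, a minimal sketch of how the backported StoreBoolean action can be wired into an argparse parser; the flag name is illustrative and not taken from this change:

# Standard
import argparse

# Local
from vllm_detector_adapter.utils import StoreBoolean

parser = argparse.ArgumentParser()
# StoreBoolean maps the literal strings "true"/"false" to real booleans
# and raises ValueError for anything else
parser.add_argument("--enable-feature", action=StoreBoolean, default=False)

args = parser.parse_args(["--enable-feature", "true"])
assert args.enable_feature is True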