diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml
index 2b19741f7b..16ad395046 100644
--- a/deploy/docker-compose.yml
+++ b/deploy/docker-compose.yml
@@ -15,10 +15,8 @@
 # IMPORTANT NOTE: Make sure this is in sync with lib/runtime/docker-compose.yml
 networks:
-  server:
-    driver: bridge
-  monitoring:
-    driver: bridge
+  dynamo:
+    driver: bridge
 
 # Note that the images are pinned to specific versions to avoid breaking changes.
 services:
@@ -30,8 +28,7 @@ services:
       - 6222:6222
       - 8222:8222  # the endpoints include /varz, /healthz, ...
     networks:
-      - server
-      - monitoring
+      - dynamo
 
   etcd-server:
     image: bitnamilegacy/etcd:3.6.1
@@ -41,9 +38,7 @@ services:
       - 2379:2379  # this port exposes the /metrics endpoint
       - 2380:2380
     networks:
-      - server
-      - monitoring
-
+      - dynamo
 
 # All the services below are part of the metrics profile and monitoring network.
 # The exporter translates from /varz and other stats to Prometheus metrics
@@ -53,7 +48,7 @@ services:
     ports:
       - 7777:7777
     networks:
-      - monitoring
+      - dynamo
     profiles: [metrics]
     depends_on:
      - nats-server
@@ -84,7 +79,7 @@ services:
       - DCGM_EXPORTER_LISTEN=:9401
     runtime: nvidia  # Specify the NVIDIA runtime
     networks:
-      - monitoring
+      - dynamo
 
 # To access Prometheus from another machine, you may need to disable the firewall on your host. On Ubuntu:
 # sudo ufw allow 9090/tcp
@@ -104,7 +99,7 @@ services:
     # Example to pull from the /query endpoint:
     # {__name__=~"DCGM.*", job="dcgm-exporter"}
     networks:
-      - monitoring
+      - dynamo
     ports:
       - "9090:9090"
     profiles: [metrics]
@@ -143,7 +138,7 @@ services:
     ports:
       - "3001:3001"
     networks:
-      - monitoring
+      - dynamo
     profiles: [metrics]
     depends_on:
       - prometheus
diff --git a/examples/multimodal/components/processor.py b/examples/multimodal/components/processor.py
index b972220f5c..718aa83c23 100644
--- a/examples/multimodal/components/processor.py
+++ b/examples/multimodal/components/processor.py
@@ -15,6 +15,7 @@
 import argparse
 import asyncio
+import copy
 import json
 import logging
 import os
@@ -22,7 +23,7 @@
 import sys
 import uuid
 from enum import Enum
-from typing import AsyncIterator, Tuple, Union
+from typing import AsyncIterator, Optional, Tuple, Union
 
 import uvloop
 from transformers import AutoTokenizer
@@ -32,11 +33,11 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import FlexibleArgumentParser
 
-from dynamo.llm import ModelInput, ModelType, register_llm
+from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
 from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
 from dynamo.runtime.logging import configure_dynamo_logging
+from dynamo._core import parse_tool_calls_py
 
-# To import example local module
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
 from utils.args import Config, base_parse_args, parse_endpoint
 from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
@@ -65,7 +66,7 @@ class Processor(ProcessMixIn):
     def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
         DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
         DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.processor.generate"
-        DEFAULT_DOWNSTREAM_ENDPOINT = f"dyn://{DYN_NAMESPACE}.encoder.generate"
+        DEFAULT_DOWNSTREAM_ENDPOINT = f"dyn://{DYN_NAMESPACE}.llm.generate"
 
         parser = FlexibleArgumentParser(
             description="vLLM based processor for Dynamo LLM."
@@ -93,7 +94,7 @@ def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
             "--downstream-endpoint",
             type=str,
             default=DEFAULT_DOWNSTREAM_ENDPOINT,
-            help=f"The endpoint string of the downstream encoder in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_DOWNSTREAM_ENDPOINT}'",
+            help=f"The endpoint string of the downstream LLM worker in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_DOWNSTREAM_ENDPOINT}'",
         )
 
         args, config = base_parse_args(parser)
@@ -104,23 +105,26 @@ def __init__(
         self,
         args: argparse.Namespace,
         engine_args: AsyncEngineArgs,
-        encode_worker_client: Client,
+        llm_worker_client: Client,
+        custom_template_path: Optional[str] = None,
+        tool_call_parser: Optional[str] = None,
     ):
-        self.encode_worker_client = encode_worker_client
+        self.llm_worker_client = llm_worker_client
         self.prompt_template = args.prompt_template
         self.engine_args = engine_args
         self.model_config = self.engine_args.create_model_config()
         self.default_sampling_params = self.model_config.get_diff_sampling_param()
-        self.tokenizer = self._create_tokenizer(self.engine_args)
+        self.tokenizer = self._create_tokenizer(self.engine_args, custom_template_path)
         self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
         self.completions_processor = CompletionsProcessor(
             self.tokenizer, self.model_config
         )
+        self.tool_call_parser = tool_call_parser
 
     def cleanup(self):
         pass
 
-    def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
+    def _create_tokenizer(self, engine_args: AsyncEngineArgs, custom_template_path: Optional[str] = None) -> AnyTokenizer:
         """Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
         model_path = engine_args.model
@@ -132,6 +136,16 @@ def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
             truncation_side="left",
             use_fast=True,  # VLLM might use the fast tokenizer for efficiency
         )
+        # Store custom template path but DON'T set it as default on the tokenizer
+        # We'll apply it conditionally using shallow copy (thread-safe)
+        # The template itself handles whether tools are present or not with {% if tools %} logic
+        if custom_template_path:
+            logger.info(f"Custom chat template path: {custom_template_path}")
+            with open(custom_template_path, 'r') as f:
+                self.custom_tool_template = f.read()
+            logger.info("Custom chat template loaded (will be used for all requests when specified)")
+        else:
+            self.custom_tool_template = None
         return base_tokenizer
 
     # Main method to parse the request and send the request to the vllm worker.
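For reference, the shallow-copy pattern relied on above works because `chat_template` is a plain instance attribute: mutating it on a `copy.copy()` of the tokenizer leaves the shared instance untouched and skips deep-copying large state such as the vocabulary. A minimal sketch, not part of the patch, using a hypothetical `DummyTokenizer` stand-in rather than a real Hugging Face tokenizer:

```python
import copy


class DummyTokenizer:
    """Hypothetical stand-in for a Hugging Face tokenizer."""

    def __init__(self):
        self.vocab = {f"tok{i}": i for i in range(100_000)}  # large shared state
        self.chat_template = None  # default template used for ordinary requests


base = DummyTokenizer()

# Per-request override: shallow copy, then mutate only the copy's chat_template.
per_request = copy.copy(base)
per_request.chat_template = "{% if tools %}...{% endif %}"

assert per_request.vocab is base.vocab  # heavy state is shared, not duplicated
assert base.chat_template is None       # concurrent requests still see the default
```

Real tokenizers behave the same way, which is what keeps the per-request override in `generate` safe under concurrency as long as the shared tokenizer itself is never mutated.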
@@ -142,7 +156,7 @@ async def _generate(
         request_type: RequestType,
     ):
         request_id = str(uuid.uuid4().hex)
-        logger.debug(f"Got raw request: {raw_request}")
+        logger.info(f"Got raw request: {raw_request}")
         (
             request,
             conversation,
@@ -162,7 +176,7 @@ async def _generate(
         # This API could accept Pydantic class, but SamplingParams
         # in vLLMMultimodalRequest is not a Pydantic class and will
         # cause TypeError: unsupported type SamplingParams
-        response_generator = await self.encode_worker_client.round_robin(
+        response_generator = await self.llm_worker_client.round_robin(
             worker_request.model_dump_json()
         )
@@ -211,35 +225,143 @@ async def generate(self, raw_request: MultiModalRequest):
             # If the request is not MultiModalRequest, convert it to MultiModalRequest
             raw_request = MultiModalRequest.model_validate(raw_request)
 
-        # Ensure the configured template includes the placeholder
-        template = self.prompt_template
-        if "<prompt>" not in template:
-            raise ValueError("prompt_template must contain '<prompt>' placeholder")
-
-        # Safely extract user text
-        try:
-            user_text = raw_request.messages[0].content[0].text
-        except (IndexError, AttributeError) as e:
-            raise ValueError(f"Invalid message structure: {e}")
-
-        prompt = template.replace("<prompt>", user_text)
-
-        msg = {
-            "role": "user",
-            "content": prompt,
-        }
-
-        # Set stream=True - the http frontend will handle aggregation of
-        # streamed chunks into a single http response, or stream them
-        # back as SSE responses based on the stream flag in the request.
-        chat_request = ChatCompletionRequest(
-            model=raw_request.model,
-            messages=[msg],
-            stream=True,
-            max_tokens=raw_request.max_tokens,
-            temperature=raw_request.temperature,
-            request_id=str(uuid.uuid4()),
-        )
+        # If tools are provided, apply the chat template with tools
+        # We need to apply the Jinja template but NOT process images (keep them as URLs)
+        if raw_request.tools and len(raw_request.tools) > 0:
+            # Convert messages and tools to dicts for template rendering
+            messages_for_template = []
+            for msg in raw_request.messages:
+                # Flatten multi-part content into a single text string for template
+                content_parts = []
+                for content in msg.content:
+                    if content.type == "text":
+                        content_parts.append(content.text)
+                    elif content.type == "image_url":
+                        # Use Qwen's vision tokens
+                        content_parts.append("<|vision_start|><|image_pad|><|vision_end|>")
+                    elif content.type == "video_url":
+                        # Use similar format for video if needed
+                        content_parts.append("<|vision_start|><|video_pad|><|vision_end|>")
+
+                messages_for_template.append({
+                    "role": msg.role,
+                    "content": "".join(content_parts)  # Join without spaces to keep tokens together
+                })
+
+            # Convert tools to dicts
+            tools_dicts = [tool.model_dump() if hasattr(tool, 'model_dump') else tool for tool in raw_request.tools]
+
+            # Apply the custom tool calling template
+            if self.custom_tool_template:
+                # THREAD-SAFE: Create a shallow copy to avoid race conditions when mutating chat_template
+                # Shallow copy is fast (~1-5 µs) and only copies attribute references, not large objects like vocab
+                temp_tokenizer = copy.copy(self.tokenizer)
+                temp_tokenizer.chat_template = self.custom_tool_template
+                prompt = temp_tokenizer.apply_chat_template(
+                    messages_for_template,
+                    tools=tools_dicts,
+                    add_generation_prompt=True,
+                    tokenize=False
+                )
+            else:
+                # Use default tokenizer template with tools
+                prompt = self.tokenizer.apply_chat_template(
+                    messages_for_template,
+                    tools=tools_dicts,
+                    add_generation_prompt=True,
+                    tokenize=False
+                )
+
+            # Create a simple message with the formatted prompt (like the original path)
+            msg = {
+                "role": "user",
+                "content": prompt,
+            }
+
+            chat_request = ChatCompletionRequest(
+                model=raw_request.model,
+                messages=[msg],
+                stream=raw_request.stream,
+                max_tokens=raw_request.max_tokens,
+                temperature=raw_request.temperature,
+                request_id=str(uuid.uuid4()),
+                tools=raw_request.tools,
+                tool_choice=raw_request.tool_choice,
+            )
+        else:
+            # Check if this is a multimodal request (has images/video)
+            has_multimodal = any(
+                item.type in ["image_url", "video_url"]
+                for msg in raw_request.messages
+                for item in msg.content
+            )
+
+            if has_multimodal:
+                # Original path: manual template replacement for multimodal non-tool calls
+                # Ensure the configured template includes the placeholder
+                template = self.prompt_template
+                if "<prompt>" not in template:
+                    raise ValueError("prompt_template must contain '<prompt>' placeholder")
+
+                # Extract all text from content items (handles image-only, text-only, or mixed)
+                text_parts = []
+                for item in raw_request.messages[0].content:
+                    if item.type == "text":
+                        text_parts.append(item.text)
+
+                # Use empty string if no text (image-only case)
+                user_text = " ".join(text_parts) if text_parts else ""
+
+                prompt = template.replace("<prompt>", user_text)
+
+                msg = {
+                    "role": "user",
+                    "content": prompt,
+                }
+
+                chat_request = ChatCompletionRequest(
+                    model=raw_request.model,
+                    messages=[msg],
+                    stream=raw_request.stream,
+                    max_tokens=raw_request.max_tokens,
+                    temperature=raw_request.temperature,
+                    request_id=str(uuid.uuid4()),
+                    tools=raw_request.tools,
+                    tool_choice=raw_request.tool_choice,
+                )
+            else:
+                # Text-only chat: use tokenizer's chat template
+                messages_for_template = []
+                for msg in raw_request.messages:
+                    # Flatten content to string
+                    content_text = " ".join([item.text for item in msg.content if item.type == "text"])
+                    messages_for_template.append({
+                        "role": msg.role,
+                        "content": content_text
+                    })
+
+                # Apply chat template
+                prompt = self.tokenizer.apply_chat_template(
+                    messages_for_template,
+                    add_generation_prompt=True,
+                    tokenize=False
+                )
+
+                msg = {
+                    "role": "user",
+                    "content": prompt,
+                }
+
+                chat_request = ChatCompletionRequest(
+                    model=raw_request.model,
+                    messages=[msg],
+                    stream=raw_request.stream,
+                    max_tokens=raw_request.max_tokens,
+                    temperature=raw_request.temperature,
+                    request_id=str(uuid.uuid4()),
+                    tools=raw_request.tools,
+                    tool_choice=raw_request.tool_choice,
+                )
 
         multimodal_input = MultiModalInput()
         for message in raw_request.messages:
@@ -251,9 +373,13 @@ async def generate(self, raw_request: MultiModalRequest):
                     raise ValueError("Cannot provide both image and video URLs")
                 multimodal_input.video_url = item.video_url.url
 
-        if multimodal_input.image_url is None and multimodal_input.video_url is None:
-            raise ValueError("Either image URL or video URL is required")
+        # Allow text-only messages (no image/video required)
+        # This enables both pure text chat and multimodal use cases
 
+        # Buffer chunks when tool calling is enabled to clear content after parsing
+        accumulated_content = ""
+        buffered_chunks = []
+
         async for response in self._generate(
             chat_request, multimodal_input, RequestType.CHAT
         ):
@@ -263,8 +389,86 @@ async def generate(self, raw_request: MultiModalRequest):
             # reconstructing back the OpenAI chat response as dynamo egress expects it
             if response.startswith("data: [DONE]"):
                 break
-            response = json.loads(response.lstrip("data: "))
-            yield response
+
+            # Handle both streaming (with "data: " prefix) and non-streaming responses
+            if response.startswith("data: "):
+                response = json.loads(response.lstrip("data: "))
+            else:
+                response = json.loads(response)
+                # Convert non-streaming format (message) to streaming format (delta)
+                if "choices" in response and "message" in response["choices"][0]:
+                    message_content = response["choices"][0]["message"]["content"]
+                    response["choices"][0]["delta"] = {"content": message_content, "role": "assistant"}
+                    del response["choices"][0]["message"]
+                    response["object"] = "chat.completion.chunk"
+
+            # Buffer chunks and accumulate content when tool calling is configured
+            if (
+                self.tool_call_parser
+                and raw_request.tools
+                and "choices" in response
+                and len(response["choices"]) > 0
+            ):
+                choice = response["choices"][0]
+
+                # Buffer this chunk
+                buffered_chunks.append(response)
+
+                # Accumulate delta content
+                if "delta" in choice and choice["delta"].get("content"):
+                    accumulated_content += choice["delta"]["content"]
+
+                # Parse when we hit the end (finish_reason is set)
+                finish_reason = choice.get("finish_reason")
+                if finish_reason == "stop":
+                    if accumulated_content:
+                        logger.info(f"Attempting to parse accumulated tool calls (length={len(accumulated_content)}) with parser: {self.tool_call_parser}")
+                        try:
+                            tool_calls, normal_text = parse_tool_calls_py(accumulated_content, self.tool_call_parser)
+                            logger.info(f"Parse result: {len(tool_calls) if tool_calls else 0} tool calls found")
+
+                            if tool_calls:
+                                # Convert tool calls to OpenAI format
+                                tool_call_chunks = []
+                                for idx, tc in enumerate(tool_calls):
+                                    tool_call_chunks.append({
+                                        "index": idx,
+                                        "id": tc["id"],
+                                        "type": tc["type"],
+                                        "function": {
+                                            "name": tc["function"]["name"],
+                                            "arguments": tc["function"]["arguments"]
+                                        }
+                                    })
+
+                                # Clear content from ALL buffered chunks (per OpenAI spec)
+                                for buffered_chunk in buffered_chunks:
+                                    if "choices" in buffered_chunk and len(buffered_chunk["choices"]) > 0:
+                                        buffered_choice = buffered_chunk["choices"][0]
+                                        if "delta" in buffered_choice:
+                                            buffered_choice["delta"]["content"] = ""
+                                        elif "message" in buffered_choice:
+                                            buffered_choice["message"]["content"] = ""
+
+                                # Add tool_calls to the final chunk
+                                if "delta" in choice:
+                                    choice["delta"]["tool_calls"] = tool_call_chunks
+                                elif "message" in choice:
+                                    choice["message"]["tool_calls"] = tool_call_chunks
+
+                                choice["finish_reason"] = "tool_calls"
+                                logger.info(f"Cleared content from {len(buffered_chunks)} chunks and added {len(tool_calls)} tool call(s) to final chunk")
+                        except Exception as e:
+                            logger.warning(f"Failed to parse tool calls: {e}", exc_info=True)
+                            # Continue with original response if parsing fails
+
+                    # Yield all buffered chunks now that we've processed them
+                    for chunk in buffered_chunks:
+                        yield chunk
+                    buffered_chunks = []
+            else:
+                # No tool calling, yield immediately
+                yield response
 
 
 async def graceful_shutdown(runtime):
@@ -311,26 +515,39 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
     parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
         args.downstream_endpoint
     )
-    encode_worker_client = (
+    llm_worker_client = (
         await runtime.namespace(parsed_namespace)
         .component(parsed_component_name)
         .endpoint(parsed_endpoint_name)
        .client()
     )
-    handler = Processor(args, config.engine_args, encode_worker_client)
+    handler = Processor(
+        args,
+        config.engine_args,
+        llm_worker_client,
+        config.custom_jinja_template,
+        config.tool_call_parser,
+    )
 
-    logger.info("Waiting for Encoder Worker Instances ...")
-    await encode_worker_client.wait_for_instances()
+    logger.info("Waiting for LLM Worker Instances ...")
+    await llm_worker_client.wait_for_instances()
 
     # Register the endpoint as entrypoint to a model
+    logger.info(f"Config: {config.tool_call_parser}, {config.reasoning_parser}, {config.custom_jinja_template}")
+    runtime_config = ModelRuntimeConfig()
+    runtime_config.tool_call_parser = config.tool_call_parser
+    runtime_config.reasoning_parser = config.reasoning_parser
+
     await register_llm(
-        ModelInput.Text,  # Custom processor is used and this type bypasses SDK processor
+        ModelInput.Text,
         ModelType.Chat,
         generate_endpoint,
         config.model,
         config.served_model_name,
         kv_cache_block_size=config.engine_args.block_size,
+        runtime_config=runtime_config,
+        custom_template_path=config.custom_jinja_template,
     )
 
     logger.info(f"Starting to serve the {args.endpoint} endpoint...")
diff --git a/examples/multimodal/components/worker.py b/examples/multimodal/components/worker.py
index c3258fdbfb..6e17311ab2 100644
--- a/examples/multimodal/components/worker.py
+++ b/examples/multimodal/components/worker.py
@@ -44,7 +44,7 @@ class VllmBaseWorker:
     @classmethod
     def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
         parser = FlexibleArgumentParser(
-            description="vLLM based encoder for Dynamo LLM."
+            description="vLLM based worker for Dynamo LLM."
         )
         parser.add_argument(
             "--endpoint",
@@ -270,38 +270,42 @@ async def generate(self, request: vLLMMultimodalRequest):
             request.multimodal_input.image_url is None
             and request.multimodal_input.video_url is None
         ):
-            # Process embeddings using the connector
-            # Create a descriptor based on the embedding shape.
-            embeddings = torch.empty(
-                request.embeddings_shape,
-                dtype=self.EMBEDDINGS_DTYPE,
-                device=self.EMBEDDINGS_DEVICE,
-            )
-            descriptor = connect.Descriptor(embeddings)
-
-            if descriptor is None:
-                raise RuntimeError(
-                    "Descriptor is None in PD worker - cannot process embeddings"
+            # Check if embeddings are provided via connector (for disaggregated serving)
+            if request.embeddings_shape is not None:
+                # Process embeddings using the connector
+                # Create a descriptor based on the embedding shape.
+                embeddings = torch.empty(
+                    request.embeddings_shape,
+                    dtype=self.EMBEDDINGS_DTYPE,
+                    device=self.EMBEDDINGS_DEVICE,
                 )
-
-            read_op = await self._connector.begin_read(
-                request.serialized_request, descriptor
-            )
-            await read_op.wait_for_completion()
-            if "video" in self.engine_args.model.lower():
-                video_numpy = embeddings.numpy()
-                multi_modal_data = construct_mm_data(
-                    self.engine_args.model,
-                    self.EMBEDDINGS_DTYPE,
-                    video_numpy=video_numpy,
+                descriptor = connect.Descriptor(embeddings)
+                if descriptor is None:
+                    raise RuntimeError(
+                        "Descriptor is None in PD worker - cannot process embeddings"
+                    )
+
+                read_op = await self._connector.begin_read(
+                    request.serialized_request, descriptor
                 )
+                await read_op.wait_for_completion()
+                if "video" in self.engine_args.model.lower():
+                    video_numpy = embeddings.numpy()
+                    multi_modal_data = construct_mm_data(
+                        self.engine_args.model,
+                        self.EMBEDDINGS_DTYPE,
+                        video_numpy=video_numpy,
+                    )
+                else:
+                    multi_modal_data = construct_mm_data(
+                        self.engine_args.model,
+                        self.EMBEDDINGS_DTYPE,
+                        image_embeds=embeddings,
+                        image_grid_thw=request.image_grid_thw,
+                    )
             else:
-                multi_modal_data = construct_mm_data(
-                    self.engine_args.model,
-                    self.EMBEDDINGS_DTYPE,
-                    image_embeds=embeddings,
-                    image_grid_thw=request.image_grid_thw,
-                )
+                # Text-only request: no multimodal data
+                multi_modal_data = None
         else:
             # Use PIL image instead of image embeddings
             multi_modal_data = {
diff --git a/examples/multimodal/launch/agg.sh b/examples/multimodal/launch/agg.sh
index 8a5a908142..d48afeb996 100755
--- a/examples/multimodal/launch/agg.sh
+++ b/examples/multimodal/launch/agg.sh
@@ -8,6 +8,8 @@ trap 'echo Cleaning up...; kill 0' EXIT
 MODEL_NAME="llava-hf/llava-1.5-7b-hf"
 PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
 PROVIDED_PROMPT_TEMPLATE=""
+TOOL_CALL_PARSER=""
+CUSTOM_TEMPLATE=""
 
 # Parse command line arguments
 while [[ $# -gt 0 ]]; do
@@ -20,12 +22,30 @@ while [[ $# -gt 0 ]]; do
             PROVIDED_PROMPT_TEMPLATE=$2
             shift 2
             ;;
+        --dyn-tool-call-parser)
+            TOOL_CALL_PARSER=$2
+            shift 2
+            ;;
+        --custom-jinja-template)
+            CUSTOM_TEMPLATE=$2
+            shift 2
+            ;;
         -h|--help)
             echo "Usage: $0 [OPTIONS]"
             echo "Options:"
-            echo "  --model              Specify the model to use (default: $MODEL_NAME)"
-            echo "  --prompt-template
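The diff is truncated above, mid help text. Purely for orientation (not part of the patch), a hypothetical invocation of the updated launch script: the two new flags come from the argument parsing added above, while the model name, parser name, and template path are placeholders.

```bash
# Hypothetical usage sketch: flag names are from the patch, values are placeholders.
cd examples/multimodal
./launch/agg.sh \
  --model Qwen/Qwen2.5-VL-7B-Instruct \
  --dyn-tool-call-parser hermes \
  --custom-jinja-template /path/to/tool_chat_template.jinja
```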