Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,6 @@ TensorRT-LLM
benchmarks/results
profiling_results*


# Direnv
.envrc
2 changes: 2 additions & 0 deletions components/src/dynamo/common/config_dump/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
add_config_dump_args,
dump_config,
get_config_dump,
get_config_endpoint,
register_encoder,
)
from dynamo.common.config_dump.environment import get_environment_vars
Expand All @@ -25,6 +26,7 @@
"add_config_dump_args",
"dump_config",
"get_config_dump",
"get_config_endpoint",
"get_environment_vars",
"get_gpu_info",
"get_runtime_info",
Expand Down
35 changes: 28 additions & 7 deletions components/src/dynamo/common/config_dump/config_dumper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
import pathlib
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

from dynamo.common._version import __version__

Expand Down Expand Up @@ -77,6 +77,16 @@ def _get_vllm_version() -> Optional[str]:
return None


async def get_config_endpoint(config: Any, request=None):
    """
    Endpoint handler that streams back the serialized config dump.

    Args:
        config: Any JSON-serializable object containing the backend
            configuration; forwarded verbatim to ``get_config_dump``.
        request: presumably the incoming endpoint request payload; unused
            here — TODO confirm against the endpoint-serving framework.

    Yields:
        A single dict with ``"status"`` (``"success"`` or ``"error"``) and
        ``"message"`` — the JSON config-dump string on success, or the
        exception text on failure (the full traceback is logged).
    """
    try:
        # TODO: Putting the dict instead of the string doesn't get sent
        # through the endpoint correctly...
        yield {"status": "success", "message": get_config_dump(config)}
    except Exception as e:
        # Boundary handler: never let an endpoint call raise — report the
        # failure to the caller and log the traceback server-side.
        logger.exception("Unexpected error dumping config")
        yield {"status": "error", "message": str(e)}


def dump_config(dump_config_to: Optional[str], config: Any) -> None:
"""
Dump the configuration to a file or stdout.
Expand All @@ -101,17 +111,26 @@ def dump_config(dump_config_to: Optional[str], config: Any) -> None:
logger.info(f"Dumped config to {dump_path.resolve()}")
except (OSError, IOError):
logger.exception(f"Failed to dump config to {dump_config_to}")
logger.info(f"CONFIG_DUMP: {config_dump_payload}")
logger.debug(f"CONFIG_DUMP: {config_dump_payload}")
except Exception:
logger.exception("Unexpected error dumping config")
logger.info(f"CONFIG_DUMP: {config_dump_payload}")
logger.debug(f"CONFIG_DUMP: {config_dump_payload}")
else:
logger.info(f"CONFIG_DUMP: {config_dump_payload}")
logger.debug(f"CONFIG_DUMP: {config_dump_payload}")


def get_config_dump(config: Any, extra_info: Optional[Dict[str, Any]] = None) -> str:
    """
    Collect comprehensive config information about a backend instance and
    return it serialized as a canonical JSON string.
    """
    # Gather the raw data first, then serialize, so the two concerns stay
    # separate and the intermediate dict is easy to inspect when debugging.
    dump_data = _get_config_dump_data(config, extra_info)
    return canonical_json_encoder.encode(dump_data)


def _get_config_dump_data(
config: Any, extra_info: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Collect comprehensive config information about a backend instance.

Args:
config: Any JSON-serializable object containing the backend configuration.
Expand Down Expand Up @@ -148,7 +167,7 @@ def get_config_dump(config: Any, extra_info: Optional[Dict[str, Any]] = None) ->
if extra_info:
config_dump.update(extra_info)

return canonical_json_encoder.encode(config_dump)
return config_dump

except Exception as e:
logger.error(f"Error collecting config dump: {e}")
Expand All @@ -157,7 +176,7 @@ def get_config_dump(config: Any, extra_info: Optional[Dict[str, Any]] = None) ->
"error": f"Failed to collect config dump: {str(e)}",
"system_info": get_system_info(), # Always try to include basic system info
}
return canonical_json_encoder.encode(error_info)
return error_info


def add_config_dump_args(parser: argparse.ArgumentParser):
Expand All @@ -176,13 +195,15 @@ def add_config_dump_args(parser: argparse.ArgumentParser):


@functools.singledispatch
def _preprocess_for_encode(obj: object) -> object:
def _preprocess_for_encode(obj: object) -> Union[Dict[str, Any], str]:
"""
Single dispatch function for preprocessing objects before JSON encoding.

This function should be extended using @register_encoder decorator
for backend-specific types.
"""
if isinstance(obj, dict):
return obj
if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
return dataclasses.asdict(obj)
logger.warning(f"Unknown type {type(obj)}, using __dict__ or str(obj)")
Expand Down
9 changes: 9 additions & 0 deletions components/src/dynamo/frontend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ def parse_args():
help="Start KServe gRPC server.",
)
add_config_dump_args(parser)
parser.add_argument(
"--extremely-unsafe-do-not-use-in-prod-expose-dump-config",
action="store_true",
help="DO NOT USE IN PROD! Exposes the `/dump_config` endpoint which will dump config + environment variables + system info + GPU info + installed packages.",
)

flags = parser.parse_args()

Expand Down Expand Up @@ -274,6 +279,10 @@ async def async_main():
kwargs["tls_key_path"] = flags.tls_key_path
if flags.namespace:
kwargs["namespace"] = flags.namespace
if flags.extremely_unsafe_do_not_use_in_prod_expose_dump_config:
kwargs[
"extremely_unsafe_do_not_use_in_prod_expose_dump_config"
] = flags.extremely_unsafe_do_not_use_in_prod_expose_dump_config

if is_static:
# out=dyn://<static_endpoint>
Expand Down
54 changes: 46 additions & 8 deletions components/src/dynamo/sglang/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import functools
import logging
import signal
import sys

import sglang as sgl
import uvloop

from dynamo.common.config_dump import dump_config
from dynamo.common.config_dump import dump_config, get_config_endpoint
from dynamo.llm import ModelInput, ModelType
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
Expand Down Expand Up @@ -75,6 +76,7 @@ async def init(runtime: DistributedRuntime, config: Config):
await component.create_service()

generate_endpoint = component.endpoint(dynamo_args.endpoint)
dump_config_endpoint = component.endpoint("dump_config")

prefill_client = None
if config.serving_mode == DisaggregationMode.DECODE:
Expand Down Expand Up @@ -115,6 +117,10 @@ async def init(runtime: DistributedRuntime, config: Config):
dynamo_args,
readiness_gate=ready_event,
),
dump_config_endpoint.serve_endpoint(
functools.partial(get_config_endpoint, config),
metrics_labels=[("model", server_args.served_model_name)],
),
)
except Exception as e:
logging.error(f"Failed to serve endpoints: {e}")
Expand All @@ -140,6 +146,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
await component.create_service()

generate_endpoint = component.endpoint(dynamo_args.endpoint)
dump_config_endpoint = component.endpoint("dump_config")

handler = PrefillWorkerHandler(component, engine, config)

Expand All @@ -151,7 +158,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
graceful_shutdown=True,
metrics_labels=[("model", server_args.served_model_name)],
health_check_payload=health_check_payload,
)
),
dump_config_endpoint.serve_endpoint(
functools.partial(get_config_endpoint, config),
metrics_labels=[("model", server_args.served_model_name)],
),
]

try:
Expand All @@ -175,6 +186,7 @@ async def init_embedding(runtime: DistributedRuntime, config: Config):
await component.create_service()

generate_endpoint = component.endpoint(dynamo_args.endpoint)
dump_config_endpoint = component.endpoint("dump_config")

# publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
Expand Down Expand Up @@ -206,6 +218,10 @@ async def init_embedding(runtime: DistributedRuntime, config: Config):
output_type=ModelType.Embedding,
readiness_gate=ready_event,
),
dump_config_endpoint.serve_endpoint(
functools.partial(get_config_endpoint, config),
metrics_labels=[("model", server_args.served_model_name)],
),
)
except Exception as e:
logging.error(f"Failed to serve embedding endpoints: {e}")
Expand All @@ -229,6 +245,7 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config)
await component.create_service()

generate_endpoint = component.endpoint(dynamo_args.endpoint)
dump_config_endpoint = component.endpoint("dump_config")

# For processor, we need to connect to the encode worker
encode_worker_client = (
Expand Down Expand Up @@ -260,6 +277,10 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config)
input_type=ModelInput.Text,
readiness_gate=ready_event,
),
dump_config_endpoint.serve_endpoint(
functools.partial(get_config_endpoint, config),
metrics_labels=[("model", server_args.served_model_name)],
),
)
except Exception as e:
logging.error(f"Failed to serve endpoints: {e}")
Expand All @@ -278,6 +299,7 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
await component.create_service()

generate_endpoint = component.endpoint(dynamo_args.endpoint)
dump_config_endpoint = component.endpoint("dump_config")

# For encode worker, we need to connect to the downstream LLM worker
pd_worker_client = (
Expand All @@ -297,7 +319,11 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
handler.generate,
graceful_shutdown=True,
metrics_labels=[("model", server_args.served_model_name)],
)
),
dump_config_endpoint.serve_endpoint(
functools.partial(get_config_endpoint, config),
metrics_labels=[("model", server_args.served_model_name)],
),
]

try:
Expand All @@ -319,6 +345,7 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
await component.create_service()

generate_endpoint = component.endpoint(dynamo_args.endpoint)
dump_config_endpoint = component.endpoint("config")

engine = sgl.Engine(server_args=server_args)

Expand All @@ -337,10 +364,16 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
await handler.async_init()

try:
await generate_endpoint.serve_endpoint(
handler.generate,
metrics_labels=[("model", server_args.served_model_name)],
graceful_shutdown=True,
await asyncio.gather(
generate_endpoint.serve_endpoint(
handler.generate,
metrics_labels=[("model", server_args.served_model_name)],
graceful_shutdown=True,
),
dump_config_endpoint.serve_endpoint(
functools.partial(get_config_endpoint, config),
metrics_labels=[("model", server_args.served_model_name)],
),
)
except Exception as e:
logging.error(f"Failed to serve endpoints: {e}")
Expand All @@ -361,6 +394,7 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co
await component.create_service()

generate_endpoint = component.endpoint(dynamo_args.endpoint)
dump_config_endpoint = component.endpoint("dump_config")

handler = MultimodalPrefillWorkerHandler(component, engine, config)
await handler.async_init()
Expand All @@ -374,7 +408,11 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co
graceful_shutdown=True,
metrics_labels=[("model", server_args.served_model_name)],
health_check_payload=health_check_payload,
)
),
dump_config_endpoint.serve_endpoint(
functools.partial(get_config_endpoint, config),
metrics_labels=[("model", server_args.served_model_name)],
),
)
except Exception as e:
logging.error(f"Failed to serve endpoints: {e}")
Expand Down
33 changes: 23 additions & 10 deletions components/src/dynamo/trtllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import functools
import json
import logging
import os
Expand Down Expand Up @@ -34,7 +35,7 @@
from transformers import AutoConfig

import dynamo.nixl_connect as nixl_connect
from dynamo.common.config_dump import dump_config
from dynamo.common.config_dump import dump_config, get_config_endpoint
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
Expand Down Expand Up @@ -282,9 +283,9 @@ async def init(runtime: DistributedRuntime, config: Config):
connector = nixl_connect.Connector()
await connector.initialize()

dump_config(
config.dump_config_to, {"engine_args": engine_args, "dynamo_args": config}
)
# dump config to file/stdout
config_to_dump = {"engine_args": engine_args, "dynamo_args": config}
dump_config(config.dump_config_to, config_to_dump)

async with get_llm_engine(engine_args) as engine:
endpoint = component.endpoint(config.endpoint)
Expand Down Expand Up @@ -357,6 +358,7 @@ async def init(runtime: DistributedRuntime, config: Config):
# Get health check payload (checks env var and falls back to TensorRT-LLM default)
health_check_payload = TrtllmHealthCheckPayload(tokenizer=tokenizer).to_dict()

dump_config_endpoint = component.endpoint("dump_config")
if config.publish_events_and_metrics and is_first_worker(config):
# Initialize and pass in the publisher to the request handler to
# publish events and metrics.
Expand All @@ -374,15 +376,26 @@ async def init(runtime: DistributedRuntime, config: Config):
) as publisher:
handler_config.publisher = publisher
handler = RequestHandlerFactory().get_request_handler(handler_config)
await endpoint.serve_endpoint(
handler.generate,
metrics_labels=metrics_labels,
health_check_payload=health_check_payload,
await asyncio.gather(
endpoint.serve_endpoint(
handler.generate,
metrics_labels=metrics_labels,
health_check_payload=health_check_payload,
),
dump_config_endpoint.serve_endpoint(
functools.partial(get_config_endpoint, config_to_dump),
metrics_labels=metrics_labels,
),
)
else:
handler = RequestHandlerFactory().get_request_handler(handler_config)
await endpoint.serve_endpoint(
handler.generate, health_check_payload=health_check_payload
await asyncio.gather(
endpoint.serve_endpoint(
handler.generate, health_check_payload=health_check_payload
),
dump_config_endpoint.serve_endpoint(
functools.partial(get_config_endpoint, config_to_dump),
),
)


Expand Down
Loading