diff --git a/.gitignore b/.gitignore index 6ddd5ea67f..48437bef33 100644 --- a/.gitignore +++ b/.gitignore @@ -103,5 +103,6 @@ TensorRT-LLM benchmarks/results profiling_results* + # Direnv .envrc diff --git a/components/src/dynamo/common/config_dump/__init__.py b/components/src/dynamo/common/config_dump/__init__.py index 792d68d585..39bf30547c 100644 --- a/components/src/dynamo/common/config_dump/__init__.py +++ b/components/src/dynamo/common/config_dump/__init__.py @@ -12,6 +12,7 @@ add_config_dump_args, dump_config, get_config_dump, + get_config_endpoint, register_encoder, ) from dynamo.common.config_dump.environment import get_environment_vars @@ -25,6 +26,7 @@ "add_config_dump_args", "dump_config", "get_config_dump", + "get_config_endpoint", "get_environment_vars", "get_gpu_info", "get_runtime_info", diff --git a/components/src/dynamo/common/config_dump/config_dumper.py b/components/src/dynamo/common/config_dump/config_dumper.py index 2a06a040f9..1589d03f5d 100644 --- a/components/src/dynamo/common/config_dump/config_dumper.py +++ b/components/src/dynamo/common/config_dump/config_dumper.py @@ -8,7 +8,7 @@ import logging import pathlib from enum import Enum -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from dynamo.common._version import __version__ @@ -77,6 +77,16 @@ def _get_vllm_version() -> Optional[str]: return None +async def get_config_endpoint(config: Any, request=None): + try: + # TODO: Putting the dict instead of the string doesn't get sent + # through the endpoint correctly... + yield {"status": "success", "message": get_config_dump(config)} + except Exception as e: + logger.exception("Unexpected error dumping config") + yield {"status": "error", "message": str(e)} + + def dump_config(dump_config_to: Optional[str], config: Any) -> None: """ Dump the configuration to a file or stdout. @@ -101,17 +111,26 @@ def dump_config(dump_config_to: Optional[str], config: Any) -> None: logger.info(f"Dumped config to {dump_path.resolve()}") except (OSError, IOError): logger.exception(f"Failed to dump config to {dump_config_to}") - logger.info(f"CONFIG_DUMP: {config_dump_payload}") + logger.debug(f"CONFIG_DUMP: {config_dump_payload}") except Exception: logger.exception("Unexpected error dumping config") - logger.info(f"CONFIG_DUMP: {config_dump_payload}") + logger.debug(f"CONFIG_DUMP: {config_dump_payload}") else: - logger.info(f"CONFIG_DUMP: {config_dump_payload}") + logger.debug(f"CONFIG_DUMP: {config_dump_payload}") def get_config_dump(config: Any, extra_info: Optional[Dict[str, Any]] = None) -> str: """ Collect comprehensive config information about a backend instance. + """ + return canonical_json_encoder.encode(_get_config_dump_data(config, extra_info)) + + +def _get_config_dump_data( + config: Any, extra_info: Optional[Dict[str, Any]] = None +) -> Dict[str, Any]: + """ + Collect comprehensive config information about a backend instance. Args: config: Any JSON-serializable object containing the backend configuration. @@ -148,7 +167,7 @@ def get_config_dump(config: Any, extra_info: Optional[Dict[str, Any]] = None) -> if extra_info: config_dump.update(extra_info) - return canonical_json_encoder.encode(config_dump) + return config_dump except Exception as e: logger.error(f"Error collecting config dump: {e}") @@ -157,7 +176,7 @@ def get_config_dump(config: Any, extra_info: Optional[Dict[str, Any]] = None) -> "error": f"Failed to collect config dump: {str(e)}", "system_info": get_system_info(), # Always try to include basic system info } - return canonical_json_encoder.encode(error_info) + return error_info def add_config_dump_args(parser: argparse.ArgumentParser): @@ -176,13 +195,15 @@ def add_config_dump_args(parser: argparse.ArgumentParser): @functools.singledispatch -def _preprocess_for_encode(obj: object) -> object: +def _preprocess_for_encode(obj: object) -> Union[Dict[str, Any], str]: """ Single dispatch function for preprocessing objects before JSON encoding. This function should be extended using @register_encoder decorator for backend-specific types. """ + if isinstance(obj, dict): + return obj if dataclasses.is_dataclass(obj) and not isinstance(obj, type): return dataclasses.asdict(obj) logger.warning(f"Unknown type {type(obj)}, using __dict__ or str(obj)") diff --git a/components/src/dynamo/frontend/main.py b/components/src/dynamo/frontend/main.py index cc0caba8ba..6896a3616a 100644 --- a/components/src/dynamo/frontend/main.py +++ b/components/src/dynamo/frontend/main.py @@ -210,6 +210,11 @@ def parse_args(): help="Start KServe gRPC server.", ) add_config_dump_args(parser) + parser.add_argument( + "--extremely-unsafe-do-not-use-in-prod-expose-dump-config", + action="store_true", + help="DO NOT USE IN PROD! Exposes the `/dump_config` endpoint which will dump config + environment variables + system info + GPU info + installed packages.", + ) flags = parser.parse_args() @@ -274,6 +279,10 @@ async def async_main(): kwargs["tls_key_path"] = flags.tls_key_path if flags.namespace: kwargs["namespace"] = flags.namespace + if flags.extremely_unsafe_do_not_use_in_prod_expose_dump_config: + kwargs[ + "extremely_unsafe_do_not_use_in_prod_expose_dump_config" + ] = flags.extremely_unsafe_do_not_use_in_prod_expose_dump_config if is_static: # out=dyn:// diff --git a/components/src/dynamo/sglang/main.py b/components/src/dynamo/sglang/main.py index e852c56c8c..7f150e0d95 100644 --- a/components/src/dynamo/sglang/main.py +++ b/components/src/dynamo/sglang/main.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import functools import logging import signal import sys @@ -9,7 +10,7 @@ import sglang as sgl import uvloop -from dynamo.common.config_dump import dump_config +from dynamo.common.config_dump import dump_config, get_config_endpoint from dynamo.llm import ModelInput, ModelType from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime.logging import configure_dynamo_logging @@ -75,6 +76,7 @@ async def init(runtime: DistributedRuntime, config: Config): await component.create_service() generate_endpoint = component.endpoint(dynamo_args.endpoint) + dump_config_endpoint = component.endpoint("dump_config") prefill_client = None if config.serving_mode == DisaggregationMode.DECODE: @@ -115,6 +117,10 @@ async def init(runtime: DistributedRuntime, config: Config): dynamo_args, readiness_gate=ready_event, ), + dump_config_endpoint.serve_endpoint( + functools.partial(get_config_endpoint, config), + metrics_labels=[("model", server_args.served_model_name)], + ), ) except Exception as e: logging.error(f"Failed to serve endpoints: {e}") @@ -140,6 +146,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): await component.create_service() generate_endpoint = component.endpoint(dynamo_args.endpoint) + dump_config_endpoint = component.endpoint("dump_config") handler = PrefillWorkerHandler(component, engine, config) @@ -151,7 +158,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): graceful_shutdown=True, metrics_labels=[("model", server_args.served_model_name)], health_check_payload=health_check_payload, - ) + ), + dump_config_endpoint.serve_endpoint( + functools.partial(get_config_endpoint, config), + metrics_labels=[("model", server_args.served_model_name)], + ), ] try: @@ -175,6 +186,7 @@ async def init_embedding(runtime: DistributedRuntime, config: Config): await component.create_service() generate_endpoint = component.endpoint(dynamo_args.endpoint) + dump_config_endpoint = component.endpoint("dump_config") # publisher instantiates the metrics and kv event publishers publisher, metrics_task, metrics_labels = await setup_sgl_metrics( @@ -206,6 +218,10 @@ async def init_embedding(runtime: DistributedRuntime, config: Config): output_type=ModelType.Embedding, readiness_gate=ready_event, ), + dump_config_endpoint.serve_endpoint( + functools.partial(get_config_endpoint, config), + metrics_labels=[("model", server_args.served_model_name)], + ), ) except Exception as e: logging.error(f"Failed to serve embedding endpoints: {e}") @@ -229,6 +245,7 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config) await component.create_service() generate_endpoint = component.endpoint(dynamo_args.endpoint) + dump_config_endpoint = component.endpoint("dump_config") # For processor, we need to connect to the encode worker encode_worker_client = ( @@ -260,6 +277,10 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config) input_type=ModelInput.Text, readiness_gate=ready_event, ), + dump_config_endpoint.serve_endpoint( + functools.partial(get_config_endpoint, config), + metrics_labels=[("model", server_args.served_model_name)], + ), ) except Exception as e: logging.error(f"Failed to serve endpoints: {e}") @@ -278,6 +299,7 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con await component.create_service() generate_endpoint = component.endpoint(dynamo_args.endpoint) + dump_config_endpoint = component.endpoint("dump_config") # For encode worker, we need to connect to the downstream LLM worker pd_worker_client = ( @@ -297,7 +319,11 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con handler.generate, graceful_shutdown=True, metrics_labels=[("model", server_args.served_model_name)], - ) + ), + dump_config_endpoint.serve_endpoint( + functools.partial(get_config_endpoint, config), + metrics_labels=[("model", server_args.served_model_name)], + ), ] try: @@ -319,6 +345,7 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config): await component.create_service() generate_endpoint = component.endpoint(dynamo_args.endpoint) + dump_config_endpoint = component.endpoint("config") engine = sgl.Engine(server_args=server_args) @@ -337,10 +364,16 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config): await handler.async_init() try: - await generate_endpoint.serve_endpoint( - handler.generate, - metrics_labels=[("model", server_args.served_model_name)], - graceful_shutdown=True, + await asyncio.gather( + generate_endpoint.serve_endpoint( + handler.generate, + metrics_labels=[("model", server_args.served_model_name)], + graceful_shutdown=True, + ), + dump_config_endpoint.serve_endpoint( + functools.partial(get_config_endpoint, config), + metrics_labels=[("model", server_args.served_model_name)], + ), ) except Exception as e: logging.error(f"Failed to serve endpoints: {e}") @@ -361,6 +394,7 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co await component.create_service() generate_endpoint = component.endpoint(dynamo_args.endpoint) + dump_config_endpoint = component.endpoint("dump_config") handler = MultimodalPrefillWorkerHandler(component, engine, config) await handler.async_init() @@ -374,7 +408,11 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co graceful_shutdown=True, metrics_labels=[("model", server_args.served_model_name)], health_check_payload=health_check_payload, - ) + ), + dump_config_endpoint.serve_endpoint( + functools.partial(get_config_endpoint, config), + metrics_labels=[("model", server_args.served_model_name)], + ), ) except Exception as e: logging.error(f"Failed to serve endpoints: {e}") diff --git a/components/src/dynamo/trtllm/main.py b/components/src/dynamo/trtllm/main.py index a453638492..147f0682c6 100644 --- a/components/src/dynamo/trtllm/main.py +++ b/components/src/dynamo/trtllm/main.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import functools import json import logging import os @@ -34,7 +35,7 @@ from transformers import AutoConfig import dynamo.nixl_connect as nixl_connect -from dynamo.common.config_dump import dump_config +from dynamo.common.config_dump import dump_config, get_config_endpoint from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime.logging import configure_dynamo_logging @@ -282,9 +283,9 @@ async def init(runtime: DistributedRuntime, config: Config): connector = nixl_connect.Connector() await connector.initialize() - dump_config( - config.dump_config_to, {"engine_args": engine_args, "dynamo_args": config} - ) + # dump config to file/stdout + config_to_dump = {"engine_args": engine_args, "dynamo_args": config} + dump_config(config.dump_config_to, config_to_dump) async with get_llm_engine(engine_args) as engine: endpoint = component.endpoint(config.endpoint) @@ -357,6 +358,7 @@ async def init(runtime: DistributedRuntime, config: Config): # Get health check payload (checks env var and falls back to TensorRT-LLM default) health_check_payload = TrtllmHealthCheckPayload(tokenizer=tokenizer).to_dict() + dump_config_endpoint = component.endpoint("dump_config") if config.publish_events_and_metrics and is_first_worker(config): # Initialize and pass in the publisher to the request handler to # publish events and metrics. @@ -374,15 +376,26 @@ async def init(runtime: DistributedRuntime, config: Config): ) as publisher: handler_config.publisher = publisher handler = RequestHandlerFactory().get_request_handler(handler_config) - await endpoint.serve_endpoint( - handler.generate, - metrics_labels=metrics_labels, - health_check_payload=health_check_payload, + await asyncio.gather( + endpoint.serve_endpoint( + handler.generate, + metrics_labels=metrics_labels, + health_check_payload=health_check_payload, + ), + dump_config_endpoint.serve_endpoint( + functools.partial(get_config_endpoint, config_to_dump), + metrics_labels=metrics_labels, + ), ) else: handler = RequestHandlerFactory().get_request_handler(handler_config) - await endpoint.serve_endpoint( - handler.generate, health_check_payload=health_check_payload + await asyncio.gather( + endpoint.serve_endpoint( + handler.generate, health_check_payload=health_check_payload + ), + dump_config_endpoint.serve_endpoint( + functools.partial(get_config_endpoint, config_to_dump), + ), ) diff --git a/components/src/dynamo/vllm/main.py b/components/src/dynamo/vllm/main.py index f81f1a310f..68d08ac637 100644 --- a/components/src/dynamo/vllm/main.py +++ b/components/src/dynamo/vllm/main.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import functools import logging import os import signal @@ -12,7 +13,7 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.async_llm import AsyncLLM -from dynamo.common.config_dump import dump_config +from dynamo.common.config_dump import dump_config, get_config_endpoint from dynamo.llm import ( ModelInput, ModelRuntimeConfig, @@ -173,6 +174,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): generate_endpoint = component.endpoint(config.endpoint) clear_endpoint = component.endpoint("clear_kv_blocks") + dump_config_endpoint = component.endpoint("dump_config") engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config) @@ -205,6 +207,10 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): clear_endpoint.serve_endpoint( handler.clear_kv_blocks, metrics_labels=[("model", config.model)] ), + dump_config_endpoint.serve_endpoint( + functools.partial(get_config_endpoint, config), + metrics_labels=[("model", config.model)], + ), ) logger.debug("serve_endpoint completed for prefill worker") except Exception as e: @@ -225,6 +231,7 @@ async def init(runtime: DistributedRuntime, config: Config): generate_endpoint = component.endpoint(config.endpoint) clear_endpoint = component.endpoint("clear_kv_blocks") + dump_config_endpoint = component.endpoint("dump_config") prefill_router_client = ( await runtime.namespace(config.namespace) @@ -314,6 +321,10 @@ async def init(runtime: DistributedRuntime, config: Config): clear_endpoint.serve_endpoint( handler.clear_kv_blocks, metrics_labels=[("model", config.model)] ), + dump_config_endpoint.serve_endpoint( + functools.partial(get_config_endpoint, config), + metrics_labels=[("model", config.model)], + ), ) logger.debug("serve_endpoint completed for decode worker") except Exception as e: diff --git a/lib/bindings/python/rust/llm/entrypoint.rs b/lib/bindings/python/rust/llm/entrypoint.rs index 73b81fa869..6cbb083d9a 100644 --- a/lib/bindings/python/rust/llm/entrypoint.rs +++ b/lib/bindings/python/rust/llm/entrypoint.rs @@ -119,13 +119,14 @@ pub(crate) struct EntrypointArgs { tls_key_path: Option, extra_engine_args: Option, namespace: Option, + extremely_unsafe_do_not_use_in_prod_expose_dump_config: Option, } #[pymethods] impl EntrypointArgs { #[allow(clippy::too_many_arguments)] #[new] - #[pyo3(signature = (engine_type, model_path=None, model_name=None, model_config=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None))] + #[pyo3(signature = (engine_type, model_path=None, model_name=None, model_config=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, extremely_unsafe_do_not_use_in_prod_expose_dump_config=None))] pub fn new( engine_type: EngineType, model_path: Option, @@ -142,6 +143,7 @@ impl EntrypointArgs { tls_key_path: Option, extra_engine_args: Option, namespace: Option, + extremely_unsafe_do_not_use_in_prod_expose_dump_config: Option, ) -> PyResult { let endpoint_id_obj: Option = endpoint_id.as_deref().map(EndpointId::from); if (tls_cert_path.is_some() && tls_key_path.is_none()) @@ -167,6 +169,7 @@ impl EntrypointArgs { tls_key_path, extra_engine_args, namespace, + extremely_unsafe_do_not_use_in_prod_expose_dump_config, }) } } @@ -200,7 +203,11 @@ pub fn make_engine<'p>( .tls_key_path(args.tls_key_path.clone()) .is_mocker(matches!(args.engine_type, EngineType::Mocker)) .extra_engine_args(args.extra_engine_args.clone()) - .namespace(args.namespace.clone()); + .namespace(args.namespace.clone()) + .extremely_unsafe_do_not_use_in_prod_expose_dump_config( + args.extremely_unsafe_do_not_use_in_prod_expose_dump_config + .unwrap_or(false), + ); pyo3_async_runtimes::tokio::future_into_py(py, async move { let local_model = builder.build().await.map_err(to_pyerr)?; let inner = select_engine(distributed_runtime, args, local_model) diff --git a/lib/llm/src/entrypoint/input/http.rs b/lib/llm/src/entrypoint/input/http.rs index 8dd22cee9c..56add958e7 100644 --- a/lib/llm/src/entrypoint/input/http.rs +++ b/lib/llm/src/entrypoint/input/http.rs @@ -51,13 +51,17 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul } http_service_builder = http_service_builder.with_request_template(engine_config.local_model().request_template()); + http_service_builder = http_service_builder + .extremely_unsafe_do_not_use_in_prod_expose_dump_config( + local_model.extremely_unsafe_do_not_use_in_prod_expose_dump_config(), + ); let http_service = match engine_config { EngineConfig::Dynamic(_) => { let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; let etcd_client = distributed_runtime.etcd_client(); // This allows the /health endpoint to query etcd for active instances - http_service_builder = http_service_builder.with_etcd_client(etcd_client.clone()); + http_service_builder = http_service_builder.with_drt(Some(distributed_runtime.clone())); let http_service = http_service_builder.build()?; match etcd_client { Some(ref etcd_client) => { diff --git a/lib/llm/src/http/service.rs b/lib/llm/src/http/service.rs index 7f163a200f..f489c05d4d 100644 --- a/lib/llm/src/http/service.rs +++ b/lib/llm/src/http/service.rs @@ -21,6 +21,7 @@ mod openai; pub mod disconnect; +pub mod dump_config; pub mod error; pub mod health; pub mod metrics; diff --git a/lib/llm/src/http/service/dump_config.rs b/lib/llm/src/http/service/dump_config.rs new file mode 100644 index 0000000000..ebfa216546 --- /dev/null +++ b/lib/llm/src/http/service/dump_config.rs @@ -0,0 +1,228 @@ +use super::{RouteDoc, service_v2}; +use anyhow::anyhow; +use axum::{Json, Router, http::Method, http::StatusCode, response::IntoResponse, routing::get}; +use dynamo_runtime::{ + component::Instance, + instances::list_all_instances, + pipeline::{AsyncEngine, Context, PushRouter, RouterMode}, + protocols::maybe_error::MaybeError, +}; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; +use std::sync::Arc; +use tokio_stream::StreamExt; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ConfigInstance { + instance: Instance, + config: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct DumpConfigRequest {} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct DumpConfigResponse(serde_json::Value); + +impl MaybeError for DumpConfigResponse { + fn from_err(err: Box) -> Self { + Self(json!({ + "error": format!("{:?}", err) + })) + } + + fn err(&self) -> Option { + // Only return an error if the response contains an "error" field or "status": "error" + if let Some(error_msg) = self.0.get("error") { + return Some(anyhow!("Config dump error: {}", error_msg)); + } + if let Some(status) = self.0.get("status") && status == "error" { + let message = self + .0 + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("Unknown error"); + return Some(anyhow!("Config dump failed: {}", message)); + } + None + } +} + +pub fn dump_config_router( + state: Arc, + path: Option, +) -> (Vec, Router) { + let config_path = path.unwrap_or_else(|| "/dump_config".to_string()); + let docs: Vec = vec![RouteDoc::new(Method::GET, &config_path)]; + let router = Router::new() + .route(&config_path, get(get_config_handler)) + .with_state(state); + (docs, router) +} + +async fn get_config_handler_inner( + state: Arc, +) -> Result, String> { + let etcd_client = state.etcd_client().ok_or("No etcd client found")?; + + let instances = list_all_instances(etcd_client) + .await + .map_err(|e| e.to_string())?; + + if instances.is_empty() { + return Ok(Json(json!({ + "message": "No active instances found" + }))); + } + + let drt = state.drt().ok_or("No distributed runtime available")?; + let mut configs = Vec::new(); + + for instance in instances { + // Skip non-dump_config endpoints + if instance.endpoint != "dump_config" { + continue; + } + + tracing::debug!( + "Fetching config from instance: namespace={}, component={}, endpoint={}, id={}", + instance.namespace, + instance.component, + instance.endpoint, + instance.instance_id + ); + + match fetch_instance_config(drt, &instance).await { + Ok(config) => { + configs.push(ConfigInstance { + instance: instance.clone(), + config, + }); + } + Err(e) => { + tracing::warn!( + "Failed to fetch config from instance {}: {}", + instance.instance_id, + e + ); + // Continue with other instances even if one fails + configs.push(ConfigInstance { + instance: instance.clone(), + config: json!({ + "error": format!("Failed to fetch config: {}", e) + }), + }); + } + } + } + + Ok(Json(json!(configs))) +} + +async fn fetch_instance_config( + drt: &dynamo_runtime::DistributedRuntime, + instance: &Instance, +) -> Result { + // Create an endpoint for this specific instance's dump_config endpoint + let endpoint = drt + .namespace(&instance.namespace) + .map_err(|e| format!("Failed to create namespace: {}", e))? + .component(&instance.component) + .map_err(|e| format!("Failed to create component: {}", e))? + .endpoint(&instance.endpoint); + + // Create a client for this endpoint + let client = endpoint + .client() + .await + .map_err(|e| format!("Failed to create client: {}", e))?; + + // TODO: this is very hacky and needs to be improved, should I be tracking all the endpoints as they come up? + // Wait for the client to discover instances from etcd + client + .wait_for_instances() + .await + .map_err(|e| format!("Failed to wait for instances: {}", e))?; + + // Additional wait: Give the background monitor_instance_source task time to populate instance_avail + // The Client spawns a background task that updates instance_avail from instance_source, + // but it runs asynchronously. We need to ensure it has run at least once. + let max_retries = 50; // 50 * 10ms = 500ms max wait + for _ in 0..max_retries { + let avail_ids = client.instance_ids_avail(); + if avail_ids.contains(&instance.instance_id) { + break; + } + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + } + + // Final check: ensure the instance is available + let avail_ids = client.instance_ids_avail(); + if !avail_ids.contains(&instance.instance_id) { + return Err(format!( + "Instance {} not found in available instances after waiting. Available: {:?}", + instance.instance_id, + avail_ids.as_ref() + )); + } + + // Create a router that targets this specific instance + let router: PushRouter = + PushRouter::from_client(client, RouterMode::Direct(instance.instance_id)) + .await + .map_err(|e| format!("Failed to create router: {}", e))?; + + // Create the request + let request = Context::new(DumpConfigRequest {}); + + // Call the endpoint + let mut stream = router + .generate(request) + .await + .map_err(|e| format!("Failed to generate request: {}", e))?; + + // Collect the response (dump_config should return a single response) + let mut responses = Vec::new(); + while let Some(response) = stream.next().await { + responses.push(response.0); + } + + // Get the first response or error if empty + let mut response = responses + .into_iter() + .next() + .ok_or_else(|| "No response received".to_string())?; + // Should be of the format {"data": {"message": "json_string"}} + // I'm not sure why I can't nest more than one level, but when + // I do, it passes through weird json + let message = response + .get_mut("data") + .map(|v| v.take()) + .ok_or_else(|| format!("No data field in response {:?}", response))? + .get_mut("message") + .map(|v| v.take()) + .ok_or_else(|| format!("No message field in response {:?}", response))?; + if let Some(message_str) = message.as_str() { + // The message is itself a json string + tracing::warn!("message: {}", message_str); + let message_json = serde_json::from_str(message_str) + .map_err(|e| format!("Failed to parse message as json: {}", e))?; + Ok(message_json) + } else { + Err(format!("message field is not a string {:?}", response)) + } +} + +async fn get_config_handler( + axum::extract::State(state): axum::extract::State>, +) -> impl IntoResponse { + match get_config_handler_inner(state).await { + Ok(response) => (StatusCode::OK, response), + Err(error) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "message": error + })), + ), + } +} diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index 21d3e8422c..46d8f1c229 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -18,6 +18,7 @@ use crate::request_template::RequestTemplate; use anyhow::Result; use axum_server::tls_rustls::RustlsConfig; use derive_builder::Builder; +use dynamo_runtime::DistributedRuntime; use dynamo_runtime::logging::make_request_span; use dynamo_runtime::transports::etcd; use std::net::SocketAddr; @@ -31,6 +32,7 @@ pub struct State { metrics: Arc, manager: Arc, etcd_client: Option, + drt: Option, flags: StateFlags, } @@ -76,6 +78,7 @@ impl State { manager, metrics: Arc::new(Metrics::default()), etcd_client: None, + drt: None, flags: StateFlags { chat_endpoints_enabled: AtomicBool::new(false), cmpl_endpoints_enabled: AtomicBool::new(false), @@ -85,11 +88,13 @@ impl State { } } - pub fn new_with_etcd(manager: Arc, etcd_client: Option) -> Self { + pub fn new_with_drt(manager: Arc, drt: Option) -> Self { + let etcd_client = drt.as_ref().and_then(|drt| drt.etcd_client().clone()); Self { manager, metrics: Arc::new(Metrics::default()), etcd_client, + drt, flags: StateFlags { chat_endpoints_enabled: AtomicBool::new(false), cmpl_endpoints_enabled: AtomicBool::new(false), @@ -119,6 +124,10 @@ impl State { pub fn sse_keep_alive(&self) -> Option { None } + + pub fn drt(&self) -> Option<&DistributedRuntime> { + self.drt.as_ref() + } } #[derive(Clone)] @@ -167,11 +176,14 @@ pub struct HttpServiceConfig { #[builder(default = "true")] enable_responses_endpoints: bool, + #[builder(default = "false")] + extremely_unsafe_do_not_use_in_prod_expose_dump_config: bool, + #[builder(default = "None")] request_template: Option, #[builder(default = "None")] - etcd_client: Option, + drt: Option, } impl HttpService { @@ -278,6 +290,8 @@ static HTTP_SVC_METRICS_PATH_ENV: &str = "DYN_HTTP_SVC_METRICS_PATH"; static HTTP_SVC_MODELS_PATH_ENV: &str = "DYN_HTTP_SVC_MODELS_PATH"; /// Environment variable to set the health endpoint path (default: `/health`) static HTTP_SVC_HEALTH_PATH_ENV: &str = "DYN_HTTP_SVC_HEALTH_PATH"; +/// Environment variable to set the dump_config endpoint path (default: `/dump_config`) +static HTTP_SVC_DUMP_CONFIG_PATH_ENV: &str = "DYN_HTTP_SVC_DUMP_CONFIG_PATH"; /// Environment variable to set the live endpoint path (default: `/live`) static HTTP_SVC_LIVE_PATH_ENV: &str = "DYN_HTTP_SVC_LIVE_PATH"; /// Environment variable to set the chat completions endpoint path (default: `/v1/chat/completions`) @@ -294,8 +308,7 @@ impl HttpServiceConfigBuilder { let config: HttpServiceConfig = self.build_internal()?; let model_manager = Arc::new(ModelManager::new()); - let etcd_client = config.etcd_client; - let state = Arc::new(State::new_with_etcd(model_manager, etcd_client)); + let state = Arc::new(State::new_with_drt(model_manager, config.drt)); state .flags @@ -326,6 +339,15 @@ impl HttpServiceConfigBuilder { super::health::health_check_router(state.clone(), var(HTTP_SVC_HEALTH_PATH_ENV).ok()), super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()), ]; + if config.extremely_unsafe_do_not_use_in_prod_expose_dump_config { + tracing::warn!( + "Exposing unsafe dump_config endpoint. IF YOU SEE THIS IN PRODUCTION, YOU ARE DOING SOMETHING WRONG." + ); + routes.push(super::dump_config::dump_config_router( + state.clone(), + var(HTTP_SVC_DUMP_CONFIG_PATH_ENV).ok(), + )); + } let endpoint_routes = HttpServiceConfigBuilder::get_endpoints_router(state.clone(), &config.request_template); @@ -355,8 +377,8 @@ impl HttpServiceConfigBuilder { self } - pub fn with_etcd_client(mut self, etcd_client: Option) -> Self { - self.etcd_client = Some(etcd_client); + pub fn with_drt(mut self, distributed_runtime: Option) -> Self { + self.drt = Some(distributed_runtime); self } diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index 27a8eed5de..22bcacde8d 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -59,6 +59,7 @@ pub struct LocalModelBuilder { user_data: Option, custom_template_path: Option, namespace: Option, + extremely_unsafe_do_not_use_in_prod_expose_dump_config: bool, } impl Default for LocalModelBuilder { @@ -83,6 +84,7 @@ impl Default for LocalModelBuilder { user_data: Default::default(), custom_template_path: Default::default(), namespace: Default::default(), + extremely_unsafe_do_not_use_in_prod_expose_dump_config: false, } } } @@ -148,6 +150,14 @@ impl LocalModelBuilder { self.namespace = namespace; self } + pub fn extremely_unsafe_do_not_use_in_prod_expose_dump_config( + &mut self, + extremely_unsafe_do_not_use_in_prod_expose_dump_config: bool, + ) -> &mut Self { + self.extremely_unsafe_do_not_use_in_prod_expose_dump_config = + extremely_unsafe_do_not_use_in_prod_expose_dump_config; + self + } pub fn request_template(&mut self, template_file: Option) -> &mut Self { self.template_file = template_file; @@ -229,6 +239,8 @@ impl LocalModelBuilder { router_config: self.router_config.take().unwrap_or_default(), runtime_config: self.runtime_config.clone(), namespace: self.namespace.clone(), + extremely_unsafe_do_not_use_in_prod_expose_dump_config: self + .extremely_unsafe_do_not_use_in_prod_expose_dump_config, }); } @@ -308,6 +320,8 @@ impl LocalModelBuilder { router_config: self.router_config.take().unwrap_or_default(), runtime_config: self.runtime_config.clone(), namespace: self.namespace.clone(), + extremely_unsafe_do_not_use_in_prod_expose_dump_config: self + .extremely_unsafe_do_not_use_in_prod_expose_dump_config, }) } } @@ -325,6 +339,7 @@ pub struct LocalModel { router_config: RouterConfig, runtime_config: ModelRuntimeConfig, namespace: Option, + extremely_unsafe_do_not_use_in_prod_expose_dump_config: bool, } impl LocalModel { @@ -379,6 +394,10 @@ impl LocalModel { self.namespace.as_deref() } + pub fn extremely_unsafe_do_not_use_in_prod_expose_dump_config(&self) -> bool { + self.extremely_unsafe_do_not_use_in_prod_expose_dump_config + } + pub fn is_gguf(&self) -> bool { // GGUF is the only file (not-folder) we accept, so we don't need to check the extension // We will error when we come to parse it