diff --git a/components/backends/sglang/README.md b/components/backends/sglang/README.md
index a66b23f6d3..f5877789f9 100644
--- a/components/backends/sglang/README.md
+++ b/components/backends/sglang/README.md
@@ -117,9 +117,7 @@ uv pip install maturin
 cd $DYNAMO_HOME/lib/bindings/python
 maturin develop --uv
 cd $DYNAMO_HOME
-# installs sglang supported version along with dynamo
-# include the prerelease flag to install flashinfer rc versions
-uv pip install --prerelease=allow -e .[sglang]
+uv pip install -e .[sglang]
 ```
diff --git a/components/src/dynamo/sglang/main.py b/components/src/dynamo/sglang/main.py
index 070a4e2ca9..044ed2d6b2 100644
--- a/components/src/dynamo/sglang/main.py
+++ b/components/src/dynamo/sglang/main.py
@@ -26,6 +26,7 @@
     MultimodalPrefillWorkerHandler,
     MultimodalProcessorHandler,
     MultimodalWorkerHandler,
+    NativeApiHandler,
     PrefillWorkerHandler,
 )
@@ -74,7 +75,13 @@ async def init(runtime: DistributedRuntime, config: Config):
     generate_endpoint = component.endpoint(dynamo_args.endpoint)
 
+    # publisher instantiates the metrics and kv event publishers
+    publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
+        engine, config, component, generate_endpoint
+    )
+
     prefill_client = None
+    native_api_tasks = []
     if config.serving_mode == DisaggregationMode.DECODE:
         logging.info("Initializing prefill client")
         prefill_client = (
@@ -83,11 +90,11 @@ async def init(runtime: DistributedRuntime, config: Config):
             .endpoint("generate")
             .client()
         )
-
-    # publisher instantiates the metrics and kv event publishers
-    publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
-        engine, config, component, generate_endpoint
-    )
+    # TODO: implement other native APIs and come up with clean layer to apply to agg/disagg/etc
+    if config.serving_mode == DisaggregationMode.AGGREGATED:
+        native_api_tasks = await NativeApiHandler(
+            component, engine, metrics_labels
+        ).init_native_apis()
 
     # Readiness gate: requests wait until model is registered
     ready_event = asyncio.Event()
@@ -97,7 +104,6 @@ async def init(runtime: DistributedRuntime, config: Config):
     health_check_payload = SglangHealthCheckPayload(engine).to_dict()
 
     try:
-        # Start endpoint immediately and register model concurrently
         # Requests queue until ready_event is set
         await asyncio.gather(
             generate_endpoint.serve_endpoint(
@@ -113,6 +119,7 @@ async def init(runtime: DistributedRuntime, config: Config):
                 dynamo_args,
                 readiness_gate=ready_event,
             ),
+            *native_api_tasks,
         )
     except Exception as e:
         logging.error(f"Failed to serve endpoints: {e}")
diff --git a/components/src/dynamo/sglang/request_handlers/__init__.py b/components/src/dynamo/sglang/request_handlers/__init__.py
index bfc76e4ce0..e2708aa3ae 100644
--- a/components/src/dynamo/sglang/request_handlers/__init__.py
+++ b/components/src/dynamo/sglang/request_handlers/__init__.py
@@ -17,6 +17,7 @@
     MultimodalProcessorHandler,
     MultimodalWorkerHandler,
 )
+from .native_api_handler import NativeApiHandler
 
 __all__ = [
     "BaseWorkerHandler",
@@ -28,6 +29,7 @@
     # Multimodal handlers
     "MultimodalEncodeWorkerHandler",
     "MultimodalPrefillWorkerHandler",
+    "NativeApiHandler",
     "MultimodalProcessorHandler",
     "MultimodalWorkerHandler",
 ]
diff --git a/components/src/dynamo/sglang/request_handlers/native_api_handler.py b/components/src/dynamo/sglang/request_handlers/native_api_handler.py
new file mode 100644
index 0000000000..0b9cc767ff
--- /dev/null
+++ b/components/src/dynamo/sglang/request_handlers/native_api_handler.py
@@ -0,0 +1,68 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+# SGLang Native APIs: https://docs.sglang.ai/basic_usage/native_api.html
+# Code: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py
+
+import asyncio
+import logging
+from typing import List, Optional, Tuple
+
+import sglang as sgl
+
+from dynamo._core import Component
+
+
+class NativeApiHandler:
+    """Handler to add sglang native API endpoints to workers"""
+
+    def __init__(
+        self,
+        component: Component,
+        engine: sgl.Engine,
+        metrics_labels: Optional[List[Tuple[str, str]]] = None,
+    ):
+        self.component = component
+        self.engine = engine
+        self.metrics_labels = metrics_labels
+        self.native_api_tasks = []
+
+    async def init_native_apis(
+        self,
+    ) -> List[asyncio.Task]:
+        """
+        Initialize and register native API endpoints.
+        Returns list of tasks to be gathered.
+        """
+        logging.info("Initializing native SGLang API endpoints")
+
+        self.tm = self.engine.tokenizer_manager
+
+        tasks = []
+
+        model_info_ep = self.component.endpoint("get_model_info")
+        tasks.extend(
+            [
+                model_info_ep.serve_endpoint(
+                    self.get_model_info,
+                    graceful_shutdown=True,
+                    metrics_labels=self.metrics_labels,
+                    http_endpoint_path="/get_model_info",
+                ),
+            ]
+        )
+
+        self.native_api_tasks = tasks
+        logging.info(f"Registered {len(tasks)} native API endpoints")
+        return tasks
+
+    async def get_model_info(self, request: dict):
+        _ = request
+        result = {
+            "model_path": self.tm.server_args.model_path,
+            "tokenizer_path": self.tm.server_args.tokenizer_path,
+            "preferred_sampling_params": self.tm.server_args.preferred_sampling_params,
+            "weight_version": self.tm.server_args.weight_version,
+        }
+
+        yield {"data": [result]}
diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs
index b2d1b351aa..b85d64d323 100644
--- a/lib/bindings/python/rust/lib.rs
+++ b/lib/bindings/python/rust/lib.rs
@@ -643,7 +643,7 @@ impl Component {
 #[pymethods]
 impl Endpoint {
-    #[pyo3(signature = (generator, graceful_shutdown = true, metrics_labels = None, health_check_payload = None))]
+    #[pyo3(signature = (generator, graceful_shutdown = true, metrics_labels = None, health_check_payload = None, http_endpoint_path = None))]
     fn serve_endpoint<'p>(
         &self,
         py: Python<'p>,
@@ -651,6 +651,7 @@ impl Endpoint {
         graceful_shutdown: Option<bool>,
         metrics_labels: Option<Vec<(String, String)>>,
         health_check_payload: Option<&Bound<'p, PyDict>>,
+        http_endpoint_path: Option<&str>,
     ) -> PyResult<Bound<'p, PyAny>> {
         let engine = Arc::new(engine::PythonAsyncEngine::new(
             generator,
@@ -688,6 +689,10 @@ impl Endpoint {
             builder = builder.health_check_payload(payload);
         }
 
+        if let Some(http_endpoint_path) = http_endpoint_path {
+            builder = builder.http_endpoint_path(http_endpoint_path);
+        }
+
         let graceful_shutdown = graceful_shutdown.unwrap_or(true);
         pyo3_async_runtimes::tokio::future_into_py(py, async move {
             builder
diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi
index 8037116221..5980f20bb5 100644
--- a/lib/bindings/python/src/dynamo/_core.pyi
+++ b/lib/bindings/python/src/dynamo/_core.pyi
@@ -115,7 +115,7 @@ class Endpoint:
         ...
 
-    async def serve_endpoint(self, handler: RequestHandler, graceful_shutdown: bool = True, metrics_labels: Optional[List[Tuple[str, str]]] = None, health_check_payload: Optional[Dict[str, Any]] = None) -> None:
+    async def serve_endpoint(self, handler: RequestHandler, graceful_shutdown: bool = True, metrics_labels: Optional[List[Tuple[str, str]]] = None, health_check_payload: Optional[Dict[str, Any]] = None, http_endpoint_path: Optional[str] = None) -> None:
         """
         Serve an endpoint discoverable by all connected clients at
         `{{ namespace }}/components/{{ component_name }}/endpoints/{{ endpoint_name }}`
diff --git a/lib/llm/src/entrypoint/input/http.rs b/lib/llm/src/entrypoint/input/http.rs
index 8dd22cee9c..38057e8fca 100644
--- a/lib/llm/src/entrypoint/input/http.rs
+++ b/lib/llm/src/entrypoint/input/http.rs
@@ -17,6 +17,7 @@ use crate::{
         completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
     },
 };
+use dynamo_runtime::component::INSTANCE_ROOT_PATH;
 use dynamo_runtime::transports::etcd;
 use dynamo_runtime::{DistributedRuntime, Runtime};
 use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
@@ -55,11 +56,10 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
     let http_service = match engine_config {
         EngineConfig::Dynamic(_) => {
             let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
-            let etcd_client = distributed_runtime.etcd_client();
             // This allows the /health endpoint to query etcd for active instances
-            http_service_builder = http_service_builder.with_etcd_client(etcd_client.clone());
+            http_service_builder = http_service_builder.with_drt(Some(distributed_runtime.clone()));
             let http_service = http_service_builder.build()?;
-            match etcd_client {
+            match distributed_runtime.etcd_client() {
                 Some(ref etcd_client) => {
                     let router_config = engine_config.local_model().router_config();
                     // Listen for models registering themselves in etcd, add them to HTTP service
@@ -71,7 +71,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
                     } else {
                         Some(namespace.to_string())
                     };
-                    run_watcher(
+                    run_model_watcher(
                         distributed_runtime,
                         http_service.state().manager_clone(),
                         etcd_client.clone(),
@@ -84,6 +84,10 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
                         http_service.state().metrics_clone(),
                     )
                     .await?;
+
+                    // Start dynamic HTTP endpoint watcher
+                    run_endpoint_watcher(etcd_client.clone(), Arc::new(http_service.clone()))
+                        .await?;
                 }
                 None => {
                     // Static endpoints don't need discovery
@@ -221,7 +225,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
 /// Spawns a task that watches for new models in etcd at network_prefix,
 /// and registers them with the ModelManager so that the HTTP service can use them.
 #[allow(clippy::too_many_arguments)]
-async fn run_watcher(
+async fn run_model_watcher(
     runtime: DistributedRuntime,
     model_manager: Arc<ModelManager>,
     etcd_client: etcd::Client,
@@ -265,6 +269,24 @@ async fn run_watcher(
     Ok(())
 }
 
+/// Spawns a task that watches instance records for dynamic HTTP endpoints and updates the
+/// DynamicEndpointWatcher held in the HTTP service state.
+async fn run_endpoint_watcher(
+    etcd_client: etcd::Client,
+    http_service: Arc<service_v2::HttpService>,
+) -> anyhow::Result<()> {
+    if let Some(dep_watcher) = http_service.state().dynamic_registry() {
+        let instances_watcher = etcd_client
+            .kv_get_and_watch_prefix(INSTANCE_ROOT_PATH)
+            .await?;
+        let (_prefix2, _watcher2, instances_rx) = instances_watcher.dissolve();
+        tokio::spawn(async move {
+            dep_watcher.watch(instances_rx).await;
+        });
+    }
+    Ok(())
+}
+
 /// Updates HTTP service endpoints based on available model types
 fn update_http_endpoints(service: Arc<service_v2::HttpService>, model_type: ModelUpdate) {
     tracing::debug!(
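With the endpoint watcher wired into the frontend above and the wildcard route added in `dynamic_endpoint.rs` below, a worker endpoint served with `http_endpoint_path` becomes reachable over plain HTTP. The following is a minimal sketch, not part of this patch, of calling it from a test client; the host and port are assumptions borrowed from the serve-test config, and the nested response shape mirrors `ModelInfoPayload` at the end of this diff:

```python
# Sketch only: call a dynamically exposed worker endpoint through the frontend.
import requests

resp = requests.post("http://localhost:8000/get_model_info", json={})  # port is an assumption
resp.raise_for_status()
payload = resp.json()                      # {"responses": [...]} built by dynamic_endpoint.rs
annotated = payload["responses"][0]        # one entry per worker instance that answered
model_info = annotated["data"]["data"][0]  # the handler yields {"data": [ {...} ]}
print(model_info["model_path"])
```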
diff --git a/lib/llm/src/http/service.rs b/lib/llm/src/http/service.rs
index 7f163a200f..6f4e9182e4 100644
--- a/lib/llm/src/http/service.rs
+++ b/lib/llm/src/http/service.rs
@@ -21,6 +21,8 @@ mod openai;
 
 pub mod disconnect;
+pub mod dynamic_endpoint;
+pub mod dynamic_registry;
 pub mod error;
 pub mod health;
 pub mod metrics;
diff --git a/lib/llm/src/http/service/clear_kv_blocks.rs b/lib/llm/src/http/service/clear_kv_blocks.rs
deleted file mode 100644
index ee1cc3bc3e..0000000000
--- a/lib/llm/src/http/service/clear_kv_blocks.rs
+++ /dev/null
@@ -1,248 +0,0 @@
-// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-use super::{service_v2, RouteDoc};
-use axum::{http::Method, response::IntoResponse, routing::post, Json, Router};
-use serde_json::json;
-use std::sync::Arc;
-
-use dynamo_runtime::{pipeline::PushRouter, stream::StreamExt};
-
-pub const CLEAR_KV_ENDPOINT: &str = "clear_kv_blocks";
-
-pub fn clear_kv_blocks_router(
-    state: Arc<service_v2::State>,
-    path: Option<String>,
-) -> (Vec<RouteDoc>, Router) {
-    let path = path.unwrap_or_else(|| "/clear_kv_blocks".to_string());
-
-    let docs: Vec<RouteDoc> = vec![RouteDoc::new(Method::POST, &path)];
-
-    let router = Router::new()
-        .route(&path, post(clear_kv_blocks_handler))
-        .with_state(state);
-
-    (docs, router)
-}
-
-async fn clear_kv_blocks_handler(
-    axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
-) -> impl IntoResponse {
-    let model_entries = state.manager().get_model_entries();
-
-    // if there are no active workers
-    if model_entries.is_empty() {
-        return Json(serde_json::json!({
-            "message": "No active worker groups found"
-        }));
-    }
-
-    let distributed = match state.runtime() {
-        Some(runtime) => runtime,
-        None => {
-            return Json(serde_json::json!({
-                "message": "Failed to create distributed runtime",
-            }));
-        }
-    };
-
-    let mut cleared_workers = Vec::new();
-    let mut failed_workers = Vec::new();
-
-    // update cleared and failed workers
-    let mut add_worker_result = |success: bool,
-                                 name: String,
-                                 status: &str,
-                                 ns: &str,
-                                 comp: &str,
-                                 message: Option<String>| {
-        let mut result = json!({
-            "name": name,
-            "endpoint": format!("{}/{}/{}", ns, comp, CLEAR_KV_ENDPOINT),
-            "status": status,
-        });
-        if success {
-            if let Some(m) = message {
-                result["response"] = json!(m);
-            }
-            cleared_workers.push(result);
-        } else {
-            if let Some(m) = message {
-                result["error"] = json!(m);
-            }
-            failed_workers.push(result);
-        }
-    };
-
-    // create client for each model entry
-    for entry in &model_entries {
-        let namespace = &entry.endpoint_id.namespace;
-        let component = &entry.endpoint_id.component;
-        let entry_name = entry.name.to_string();
-
-        tracing::debug!("Processing worker group: {}/{}", namespace, component);
-
-        let namespace_obj = match distributed.namespace(namespace) {
-            Ok(ns) => ns,
-            Err(e) => {
-                add_worker_result(
-                    false,
-                    entry_name,
-                    "Failed to get namespace",
-                    namespace,
-                    component,
-                    Some(e.to_string()),
-                );
-                continue;
-            }
-        };
-
-        let component_obj = match namespace_obj.component(component) {
-            Ok(comp) => comp,
-            Err(e) => {
-                add_worker_result(
-                    false,
-                    entry_name,
-                    "Failed to get component",
-                    namespace,
-                    component,
-                    Some(e.to_string()),
-                );
-                continue;
-            }
-        };
-
-        let endpoint: dynamo_runtime::component::Endpoint =
-            component_obj.endpoint(CLEAR_KV_ENDPOINT);
-
-        let client = match endpoint.client().await {
-            Ok(c) => c,
-            Err(e) => {
-                add_worker_result(
-                    false,
-                    entry_name,
-                    "Failed to get client",
-                    namespace,
-                    component,
-                    Some(e.to_string()),
-                );
-                continue;
-            }
-        };
-
-        let router = match PushRouter::<(), serde_json::Value>::from_client(
-            client.clone(),
-            Default::default(),
-        )
-        .await
-        {
-            Ok(r) => r,
-            Err(e) => {
-                add_worker_result(
-                    false,
-                    entry_name,
-                    "Failed to create router",
-                    namespace,
-                    component,
-                    Some(e.to_string()),
-                );
-                continue;
-            }
-        };
-
-        let instances = match component_obj.list_instances().await {
-            Ok(instances) => instances,
-            Err(e) => {
-                add_worker_result(
-                    false,
-                    entry_name,
-                    "Failed to get instances for worker group",
-                    namespace,
-                    component,
-                    Some(e.to_string()),
-                );
-                continue;
-            }
-        };
-
-        if instances.is_empty() {
-            add_worker_result(
-                false,
-                entry_name,
-                "No instances found for worker group",
-                namespace,
-                component,
-                None,
-            );
-            continue;
-        }
-
-        let instances_filtered = instances
-            .clone()
-            .into_iter()
-            .filter(|instance| instance.endpoint == CLEAR_KV_ENDPOINT)
-            .collect::<Vec<_>>();
-
-        if instances_filtered.is_empty() {
-            let found_endpoints: Vec<String> = instances
-                .iter()
-                .map(|instance| instance.endpoint.clone())
-                .collect();
-            add_worker_result(
-                false,
-                entry_name,
-                &format!(
-                    "Worker group doesn't support clear_kv_blocks. Supported endpoints: {}",
-                    found_endpoints.join(", ")
-                ),
-                namespace,
-                component,
-                None,
-            );
-            continue;
-        }
-
-        for instance in &instances_filtered {
-            let instance_name = format!("{}-instance-{}", entry.name, instance.id());
-            match router.direct(().into(), instance.id()).await {
-                Ok(mut stream) => match stream.next().await {
-                    Some(response) => {
-                        add_worker_result(
-                            true,
-                            instance_name,
-                            "Successfully cleared kv blocks for instance",
-                            namespace,
-                            component,
-                            Some(response.to_string()),
-                        );
-                    }
-                    None => {
-                        add_worker_result(
-                            false,
-                            instance_name,
-                            "No response from instance",
-                            namespace,
-                            component,
-                            None,
-                        );
-                    }
-                },
-                Err(e) => {
-                    add_worker_result(
-                        false,
-                        instance_name,
-                        "Failed to send request for instance",
-                        namespace,
-                        component,
-                        Some(e.to_string()),
-                    );
-                }
-            }
-        }
-    }
-
-    Json(serde_json::json!({
-        "cleared_workers": cleared_workers,
-        "failed_workers": failed_workers
-    }))
-}
diff --git a/lib/llm/src/http/service/dynamic_endpoint.rs b/lib/llm/src/http/service/dynamic_endpoint.rs
new file mode 100644
index 0000000000..9adfa6e010
--- /dev/null
+++ b/lib/llm/src/http/service/dynamic_endpoint.rs
@@ -0,0 +1,99 @@
+// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Dynamic endpoint handler that fans out requests to all instances that registered
+//! the matching HTTP endpoint path, using the background registry.
+//! Returns 404 if no instances have registered the endpoint.
+
+use super::{RouteDoc, service_v2};
+use crate::types::Annotated;
+use axum::{
+    Json, Router,
+    http::{Method, StatusCode},
+    response::IntoResponse,
+    routing::post,
+};
+use dynamo_runtime::{pipeline::PushRouter, stream::StreamExt};
+use std::sync::Arc;
+
+pub fn dynamic_endpoint_router(
+    state: Arc<service_v2::State>,
+    path: Option<String>,
+) -> (Vec<RouteDoc>, Router) {
+    let wildcard_path = "/{*path}";
+    let path = path.unwrap_or_else(|| wildcard_path.to_string());
+
+    let docs: Vec<RouteDoc> = vec![RouteDoc::new(Method::POST, &path)];
+
+    let router = Router::new()
+        .route(&path, post(dynamic_endpoint_handler))
+        .with_state(state);
+
+    (docs, router)
+}
+
+async fn inner_dynamic_endpoint_handler(
+    state: Arc<service_v2::State>,
+    path: String,
+    body: serde_json::Value,
+) -> Result<Json<serde_json::Value>, (StatusCode, &'static str)> {
+    let fmt_path = format!("/{}", &path);
+    let registry = state.dynamic_registry();
+    let registry = match registry {
+        Some(r) => r,
+        None => {
+            return Err((
+                StatusCode::INTERNAL_SERVER_ERROR,
+                "Dynamic registry not found",
+            ));
+        }
+    };
+    let target_clients = match registry.get_clients(&fmt_path).await {
+        Some(clients) if !clients.is_empty() => clients,
+        _ => return Err((StatusCode::NOT_FOUND, "Endpoint not found")),
+    };
+
+    // For now broadcast to all instances using direct routing
+    let mut all_responses = Vec::new();
+    for client in target_clients {
+        let router = PushRouter::<serde_json::Value, Annotated<serde_json::Value>>::from_client(
+            client.clone(),
+            Default::default(),
+        )
+        .await
+        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Failed to get router"))?;
+
+        let ids = client.instance_ids_avail().clone();
+        for id in ids.iter() {
+            let mut stream = router.direct(body.clone().into(), *id).await.map_err(|e| {
+                tracing::error!("Failed to route (direct): {:?}", e);
+                (StatusCode::INTERNAL_SERVER_ERROR, "Failed to route")
+            })?;
+            while let Some(resp) = stream.next().await {
+                all_responses.push(resp);
+            }
+        }
+    }
+
+    Ok(Json(serde_json::json!({
+        "responses": all_responses
+    })))
+}
+
+async fn dynamic_endpoint_handler(
+    axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
+    axum::extract::Path(path): axum::extract::Path<String>,
+    body: Option<Json<serde_json::Value>>,
+) -> impl IntoResponse {
+    let body = body.map(|Json(v)| v).unwrap_or(serde_json::json!({}));
+    inner_dynamic_endpoint_handler(state, path, body)
+        .await
+        .map_err(|(status_code, err_string)| {
+            (
+                status_code,
+                Json(serde_json::json!({
+                    "message": err_string
+                })),
+            )
+        })
+}
diff --git a/lib/llm/src/http/service/dynamic_registry.rs b/lib/llm/src/http/service/dynamic_registry.rs
new file mode 100644
index 0000000000..0930788407
--- /dev/null
+++ b/lib/llm/src/http/service/dynamic_registry.rs
@@ -0,0 +1,214 @@
+// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Dynamic HTTP endpoint watcher for native HTTP paths.
+//!
+//! This watcher maintains a small, in-memory mapping from HTTP path -> set of
+//! `EndpointId` and a cache of `EndpointId` -> `Client` (one per endpoint).
+//! It consumes etcd watch events for instance records and updates the mapping
+//! on PUT/DELETE. The HTTP hot path performs a read-only lookup to get Clients
+//! and does not touch etcd.
+
+use std::collections::{HashMap, HashSet};
+use std::sync::Arc;
+
+use tokio::sync::{RwLock, mpsc::Receiver};
+
+use dynamo_runtime::DistributedRuntime;
+use dynamo_runtime::component::Client;
+use dynamo_runtime::component::Instance;
+use dynamo_runtime::protocols::EndpointId;
+use dynamo_runtime::transports::etcd::WatchEvent;
+
+fn normalize_path(path: &str) -> String {
+    if path.is_empty() {
+        return "/".to_string();
+    }
+    if path.starts_with('/') {
+        path.to_string()
+    } else {
+        format!("/{}", path)
+    }
+}
+
+#[derive(Default)]
+struct RegistryInner {
+    // Only 1 entry per EndpointId
+    paths: HashMap<String, HashSet<EndpointId>>,
+    endpoint_clients: HashMap<EndpointId, Client>,
+    // Maps etcd key to its (path, endpoint) for easier deletes
+    instance_index: HashMap<String, (String, EndpointId)>,
+}
+
+#[derive(Clone)]
+pub struct DynamicEndpointWatcher {
+    drt: Option<DistributedRuntime>,
+    inner: Arc<RwLock<RegistryInner>>,
+}
+
+impl DynamicEndpointWatcher {
+    pub fn new(drt: Option<DistributedRuntime>) -> Self {
+        Self {
+            drt,
+            inner: Arc::new(RwLock::new(RegistryInner::default())),
+        }
+    }
+
+    pub async fn watch(&self, mut rx: Receiver<WatchEvent>) {
+        while let Some(evt) = rx.recv().await {
+            match evt {
+                WatchEvent::Put(kv) => {
+                    let key = match kv.key_str() {
+                        Ok(k) => k.to_string(),
+                        Err(e) => {
+                            tracing::warn!("Invalid UTF-8 in instance key: {e:?}");
+                            continue;
+                        }
+                    };
+                    match serde_json::from_slice::<Instance>(kv.value()) {
+                        Ok(instance) => {
+                            if let Err(e) = self.add_instance(&key, instance).await {
+                                tracing::warn!("Failed to process instance PUT: {e:?}");
+                            }
+                        }
+                        Err(err) => {
+                            tracing::warn!("Failed to parse instance on PUT: {err:?}");
+                        }
+                    }
+                }
+                WatchEvent::Delete(kv) => {
+                    let key = match kv.key_str() {
+                        Ok(k) => k.to_string(),
+                        Err(e) => {
+                            tracing::warn!("Invalid UTF-8 in instance key on DELETE: {e:?}");
+                            continue;
+                        }
+                    };
+                    self.remove_instance(&key).await;
+                }
+            }
+        }
+    }
+
+    async fn ensure_client(&self, eid: &EndpointId) -> anyhow::Result<Client> {
+        if let Some(c) = self.inner.read().await.endpoint_clients.get(eid) {
+            return Ok(c.clone());
+        }
+        let drt = self
+            .drt
+            .clone()
+            .ok_or_else(|| anyhow::anyhow!("No DistributedRuntime available"))?;
+        let ns = drt
+            .namespace(eid.namespace.clone())
+            .map_err(|e| anyhow::anyhow!("namespace(): {e}"))?;
+        let comp = ns
+            .component(eid.component.clone())
+            .map_err(|e| anyhow::anyhow!("component(): {e}"))?;
+        let ep = comp.endpoint(eid.name.clone());
+        let client = ep.client().await?;
+        // Ensure at least one instance is observed before publishing the client
+        let _ = client.wait_for_instances().await?;
+        self.inner
+            .write()
+            .await
+            .endpoint_clients
+            .insert(eid.clone(), client.clone());
+        tracing::info!(
+            path = %eid.as_url(),
+            namespace = %eid.namespace,
+            component = %eid.component,
+            endpoint = %eid.name,
+            "Dynamic HTTP endpoint client ready"
+        );
+        Ok(client)
+    }
+
+    async fn add_instance(&self, key: &str, instance: Instance) -> anyhow::Result<()> {
+        let Some(path) = instance.http_endpoint_path.as_ref() else {
+            // not a dynamic HTTP endpoint; ignore
+            return Ok(());
+        };
+        let path = normalize_path(path);
+
+        let endpoint_id = EndpointId {
+            namespace: instance.namespace,
+            component: instance.component,
+            name: instance.endpoint,
+        };
+
+        let mut guard = self.inner.write().await;
+
+        guard
+            .instance_index
+            .insert(key.to_string(), (path.clone(), endpoint_id.clone()));
+
+        let set = guard.paths.entry(path.clone()).or_insert_with(HashSet::new);
+        let inserted_new = set.insert(endpoint_id.clone());
+        let need_client = inserted_new && !guard.endpoint_clients.contains_key(&endpoint_id);
+        drop(guard);
+
+        if need_client {
+            if let Err(e) = self.ensure_client(&endpoint_id).await {
+                tracing::warn!("Failed to create client for dynamic endpoint triple: {e:?}");
+            }
+            tracing::info!(
+                http_path = %path,
+                namespace = %endpoint_id.namespace,
+                component = %endpoint_id.component,
+                endpoint = %endpoint_id.name,
+                "Registered dynamic HTTP endpoint path"
+            );
+        }
+
+        Ok(())
+    }
+
+    async fn remove_instance(&self, key: &str) {
+        let (_path, endpoint_id) = {
+            let mut guard = self.inner.write().await;
+            match guard.instance_index.remove(key) {
+                Some(v) => {
+                    if let Some(set) = guard.paths.get_mut(&v.0) {
+                        set.remove(&v.1);
+                        if set.is_empty() {
+                            guard.paths.remove(&v.0);
+                        }
+                    }
+                    v
+                }
+                None => return,
+            }
+        };
+
+        let still_used = {
+            let guard = self.inner.read().await;
+            guard.paths.values().any(|set| set.contains(&endpoint_id))
+        };
+        if !still_used {
+            let mut guard = self.inner.write().await;
+            if guard.endpoint_clients.remove(&endpoint_id).is_some() {
+                tracing::info!(
+                    namespace = %endpoint_id.namespace,
+                    component = %endpoint_id.component,
+                    endpoint = %endpoint_id.name,
+                    "Removed dynamic HTTP endpoint client"
+                );
+            }
+        }
+    }
+
+    /// Get a cloned list of clients for a path. Returns None if the path is unknown.
+    pub async fn get_clients(&self, path: &str) -> Option<Vec<Client>> {
+        let path = normalize_path(path);
+        let guard = self.inner.read().await;
+        let triples: Vec<EndpointId> = guard
+            .paths
+            .get(&path)
+            .map(|set| set.iter().cloned().collect())?;
+        let clients = triples
+            .into_iter()
+            .filter_map(|t| guard.endpoint_clients.get(&t).cloned())
+            .collect::<Vec<_>>();
+        Some(clients)
+    }
+}
diff --git a/lib/llm/src/http/service/health.rs b/lib/llm/src/http/service/health.rs
index 6be55254ad..97f9cd5609 100644
--- a/lib/llm/src/http/service/health.rs
+++ b/lib/llm/src/http/service/health.rs
@@ -52,8 +52,11 @@ async fn live_handler(
 async fn health_handler(
     axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
 ) -> impl IntoResponse {
-    let instances = if let Some(etcd_client) = state.etcd_client() {
-        match list_all_instances(etcd_client).await {
+    let drt = state
+        .distributed_runtime()
+        .expect("Failed to get distributed runtime");
+    let instances = if let Some(etcd_client) = drt.etcd_client() {
+        match list_all_instances(&etcd_client).await {
             Ok(instances) => instances,
             Err(err) => {
                 tracing::warn!("Failed to fetch instances from etcd: {}", err);
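The watcher above keys off the `http_endpoint_path` field that the runtime changes later in this patch add to the instance record written to etcd. As a rough illustration only, with field names taken from the `Instance` struct and the concrete values and `transport` encoding being assumptions, a registered record might look like:

```python
# Illustrative instance record; values and the transport encoding are assumptions.
instance_record = {
    "component": "backend",
    "endpoint": "get_model_info",
    "namespace": "dynamo",
    "instance_id": 1234567890,  # lease id of the serving worker
    "transport": {"nats_tcp": "<nats subject>"},
    "http_endpoint_path": "/get_model_info",
}
```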
diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs
index 21d3e8422c..ebcac32174 100644
--- a/lib/llm/src/http/service/service_v2.rs
+++ b/lib/llm/src/http/service/service_v2.rs
@@ -11,6 +11,7 @@ use std::time::Duration;
 
 use super::Metrics;
 use super::RouteDoc;
+use super::dynamic_registry::DynamicEndpointWatcher;
 use super::metrics;
 use crate::discovery::ModelManager;
 use crate::endpoint_type::EndpointType;
@@ -18,19 +19,19 @@ use crate::request_template::RequestTemplate;
 use anyhow::Result;
 use axum_server::tls_rustls::RustlsConfig;
 use derive_builder::Builder;
+use dynamo_runtime::DistributedRuntime;
 use dynamo_runtime::logging::make_request_span;
-use dynamo_runtime::transports::etcd;
 use std::net::SocketAddr;
 use tokio::task::JoinHandle;
 use tokio_util::sync::CancellationToken;
 use tower_http::trace::TraceLayer;
 
 /// HTTP service shared state
-#[derive(Default)]
 pub struct State {
     metrics: Arc<Metrics>,
     manager: Arc<ModelManager>,
-    etcd_client: Option<etcd::Client>,
+    distributed_runtime: Option<DistributedRuntime>,
+    dynamic_registry: Option<DynamicEndpointWatcher>,
     flags: StateFlags,
 }
@@ -75,7 +76,8 @@ impl State {
         Self {
             manager,
             metrics: Arc::new(Metrics::default()),
-            etcd_client: None,
+            distributed_runtime: None,
+            dynamic_registry: None,
             flags: StateFlags {
                 chat_endpoints_enabled: AtomicBool::new(false),
                 cmpl_endpoints_enabled: AtomicBool::new(false),
@@ -85,11 +87,12 @@
         }
     }
 
-    pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: Option<etcd::Client>) -> Self {
+    pub fn new_with_drt(manager: Arc<ModelManager>, drt: Option<DistributedRuntime>) -> Self {
         Self {
             manager,
            metrics: Arc::new(Metrics::default()),
-            etcd_client,
+            distributed_runtime: drt.clone(),
+            dynamic_registry: Some(DynamicEndpointWatcher::new(drt)),
             flags: StateFlags {
                 chat_endpoints_enabled: AtomicBool::new(false),
                 cmpl_endpoints_enabled: AtomicBool::new(false),
@@ -111,8 +114,12 @@
         self.manager.clone()
     }
 
-    pub fn etcd_client(&self) -> Option<&etcd::Client> {
-        self.etcd_client.as_ref()
+    pub fn distributed_runtime(&self) -> Option<&DistributedRuntime> {
+        self.distributed_runtime.as_ref()
+    }
+
+    pub fn dynamic_registry(&self) -> Option<DynamicEndpointWatcher> {
+        self.dynamic_registry.clone()
     }
 
     // TODO
@@ -171,7 +178,7 @@ pub struct HttpServiceConfig {
     request_template: Option<RequestTemplate>,
 
     #[builder(default = "None")]
-    etcd_client: Option<etcd::Client>,
+    distributed_runtime: Option<DistributedRuntime>,
 }
 
 impl HttpService {
@@ -294,8 +301,8 @@ impl HttpServiceConfigBuilder {
         let config: HttpServiceConfig = self.build_internal()?;
 
         let model_manager = Arc::new(ModelManager::new());
-        let etcd_client = config.etcd_client;
-        let state = Arc::new(State::new_with_etcd(model_manager, etcd_client));
+        let drt = config.distributed_runtime;
+        let state = Arc::new(State::new_with_drt(model_manager, drt));
 
         state
             .flags
@@ -316,6 +323,8 @@
         // Note: Metrics polling task will be started in run() method to have access to cancellation token
 
+        // Start dynamic endpoint watcher: rely on upstream to provide rx; handled in http.rs run()
+
         let mut router = axum::Router::new();
         let mut all_docs = Vec::new();
 
@@ -325,6 +334,7 @@
             super::openai::list_models_router(state.clone(), var(HTTP_SVC_MODELS_PATH_ENV).ok()),
             super::health::health_check_router(state.clone(), var(HTTP_SVC_HEALTH_PATH_ENV).ok()),
             super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()),
+            super::dynamic_endpoint::dynamic_endpoint_router(state.clone(), None),
         ];
 
         let endpoint_routes =
@@ -355,8 +365,8 @@
         self
     }
 
-    pub fn with_etcd_client(mut self, etcd_client: Option<etcd::Client>) -> Self {
-        self.etcd_client = Some(etcd_client);
+    pub fn with_drt(mut self, drt: Option<DistributedRuntime>) -> Self {
+        self.distributed_runtime = Some(drt);
         self
     }
diff --git a/lib/runtime/src/component.rs b/lib/runtime/src/component.rs
index 9b38c1c12c..8deabf50b6 100644
--- a/lib/runtime/src/component.rs
+++ b/lib/runtime/src/component.rs
@@ -98,6 +98,8 @@ pub struct Instance {
     pub namespace: String,
     pub instance_id: i64,
     pub transport: TransportType,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub http_endpoint_path: Option<String>,
 }
 
 impl Instance {
@@ -460,7 +462,7 @@ impl Endpoint {
             .expect("Endpoint name and component name should be valid")
     }
 
-    /// The fully path of an instance in etcd
+    /// The full path of an instance in etcd
     pub fn etcd_path_with_lease_id(&self, lease_id: i64) -> String {
         format!("{INSTANCE_ROOT_PATH}/{}", self.unique_path(lease_id))
     }
diff --git a/lib/runtime/src/component/endpoint.rs b/lib/runtime/src/component/endpoint.rs
index c5026b8a07..13c62fedb2 100644
--- a/lib/runtime/src/component/endpoint.rs
+++ b/lib/runtime/src/component/endpoint.rs
@@ -44,6 +44,10 @@ pub struct EndpointConfig {
     #[educe(Debug(ignore))]
     #[builder(default, setter(into, strip_option))]
     health_check_payload: Option<serde_json::Value>,
+
+    /// Expose this endpoint over HTTP at this path
+    #[builder(default, setter(into, strip_option))]
+    http_endpoint_path: Option<String>,
 }
 
 impl EndpointConfigBuilder {
@@ -67,6 +71,7 @@ impl EndpointConfigBuilder {
             metrics_labels,
             graceful_shutdown,
             health_check_payload,
+            http_endpoint_path,
         ) = self.build_internal()?.dissolve();
         let lease = lease.or(endpoint.drt().primary_lease());
         let lease_id = lease.as_ref().map(|l| l.id()).unwrap_or(0);
@@ -128,6 +133,7 @@ impl EndpointConfigBuilder {
         let subject = endpoint.subject_to(lease_id);
         let etcd_path = endpoint.etcd_path_with_lease_id(lease_id);
         let etcd_client = endpoint.component.drt.etcd_client.clone();
+        let http_endpoint_path = http_endpoint_path.clone();
 
         // Register health check target in SystemHealth if provided
         if let Some(health_check_payload) = &health_check_payload {
@@ -137,6 +143,7 @@ impl EndpointConfigBuilder {
                 namespace: namespace_name.clone(),
                 instance_id: lease_id,
                 transport: TransportType::NatsTcp(subject.clone()),
+                http_endpoint_path: http_endpoint_path.clone(),
             };
             tracing::debug!(endpoint_name = %endpoint_name, "Registering endpoint health check target");
             let guard = system_health.lock().unwrap();
@@ -234,6 +241,7 @@ impl EndpointConfigBuilder {
             namespace: namespace_name,
             instance_id: lease_id,
             transport: TransportType::NatsTcp(subject),
+            http_endpoint_path,
         };
 
         let info = serde_json::to_vec_pretty(&info)?;
diff --git a/lib/runtime/src/protocols.rs b/lib/runtime/src/protocols.rs
index 33f8efd79e..6713308709 100644
--- a/lib/runtime/src/protocols.rs
+++ b/lib/runtime/src/protocols.rs
@@ -36,7 +36,7 @@ pub struct Component {
 ///
 /// Example format: `"namespace/component/endpoint"`
 ///
-#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
+#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
 pub struct EndpointId {
     pub namespace: String,
     pub component: String,
diff --git a/tests/serve/test_sglang.py b/tests/serve/test_sglang.py
index 9d6de3dbd4..fc763c1072 100644
--- a/tests/serve/test_sglang.py
+++ b/tests/serve/test_sglang.py
@@ -19,6 +19,7 @@
     completion_payload_default,
     embedding_payload,
     embedding_payload_default,
+    model_info_payload_default,
 )
 
 logger = logging.getLogger(__name__)
@@ -42,7 +43,11 @@ class SGLangConfig(EngineConfig):
         model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         env={},
         models_port=8000,
-        request_payloads=[chat_payload_default(), completion_payload_default()],
+        request_payloads=[
+            chat_payload_default(),
+            completion_payload_default(),
+            model_info_payload_default(),
+        ],
     ),
     "disaggregated": SGLangConfig(
         name="disaggregated",
diff --git a/tests/utils/payload_builder.py b/tests/utils/payload_builder.py
index 26404d3896..501b7e55c1 100644
--- a/tests/utils/payload_builder.py
+++ b/tests/utils/payload_builder.py
@@ -9,6 +9,7 @@
     CompletionPayload,
     EmbeddingPayload,
     MetricsPayload,
+    ModelInfoPayload,
 )
 
 # Common default text prompt used across tests
@@ -76,6 +77,19 @@ def metric_payload_default(
     )
 
 
+def model_info_payload_default(
+    repeat_count: int = 1,
+    expected_response: Optional[List[str]] = None,
+    expected_log: Optional[List[str]] = None,
+) -> ModelInfoPayload:
+    return ModelInfoPayload(
+        body={},
+        repeat_count=repeat_count,
+        expected_log=expected_log or [],
+        expected_response=expected_response or ["Model:"],
+    )
+
+
 def chat_payload(
     content: Union[str, List[Dict[str, Any]]],
     repeat_count: int = 1,
diff --git a/tests/utils/payloads.py b/tests/utils/payloads.py
index e7b547e576..02e2b1c36b 100644
--- a/tests/utils/payloads.py
+++ b/tests/utils/payloads.py
@@ -190,6 +190,35 @@ def response_handler(self, response: Any) -> str:
         return EmbeddingPayload.extract_embeddings(response)
 
 
+@dataclass
+class ModelInfoPayload(BasePayload):
+    """Payload for get_model_info endpoint."""
+
+    endpoint: str = "/get_model_info"
+
+    @staticmethod
+    def extract_model_info(response):
+        """
+        Process get_model_info API responses.
+        """
+        response.raise_for_status()
+        result = response.json()
+        assert "responses" in result, "Missing 'responses' in response"
+        assert len(result["responses"]) > 0, "Empty responses in response"
+
+        data = result["responses"][0].get("data", {})
+        assert "data" in data, "Missing 'data' in response data"
+        assert len(data["data"]) > 0, "Empty data in response"
+
+        model_info = data["data"][0]
+        assert "model_path" in model_info, "Missing 'model_path' in model info"
+
+        return f"Model: {model_info['model_path']}"
+
+    def response_handler(self, response: Any) -> str:
+        return ModelInfoPayload.extract_model_info(response)
+
+
 @dataclass
 class MetricsPayload(BasePayload):
     endpoint: str = "/metrics"
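For reference, a hedged sketch of how another worker could opt into the same mechanism through the Python bindings, using only the `serve_endpoint` signature added in `_core.pyi`; the endpoint name, path, and labels below are made up:

```python
# Hypothetical registration; names and labels are illustrative only.
async def server_info(request: dict):
    # Handlers are async generators; the frontend fan-out collects each yielded item.
    yield {"data": [{"status": "ok"}]}

info_ep = component.endpoint("server_info")  # `component` obtained as in main.py
await info_ep.serve_endpoint(
    server_info,
    graceful_shutdown=True,
    metrics_labels=[("model", "example")],
    http_endpoint_path="/server_info",  # served by the frontend's wildcard POST route
)
```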