diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md
index 3fcee8212c2..3f423561722 100644
--- a/docs/references/environment_variables.md
+++ b/docs/references/environment_variables.md
@@ -95,6 +95,8 @@ SGLang supports various environment variables that can be used to configure its
 | --- | --- | --- |
 | `SGLANG_TORCH_PROFILER_DIR` | Directory for PyTorch profiler output | `/tmp` |
 | `SGLANG_PROFILE_WITH_STACK` | Set `with_stack` option (bool) for PyTorch profiler (capture stack trace) | `true` |
+| `SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS` | Set `BatchSpanProcessor.schedule_delay_millis` when tracing is enabled | `500` |
+| `SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE` | Set `BatchSpanProcessor.max_export_batch_size` when tracing is enabled | `64` |
 
 ## Storage & Caching
 
diff --git a/docs/references/production_request_trace.md b/docs/references/production_request_trace.md
index 928e5fd3fc8..2450ccfda13 100644
--- a/docs/references/production_request_trace.md
+++ b/docs/references/production_request_trace.md
@@ -1,4 +1,4 @@
-SGlang exports request trace data based on the OpenTelemetry Collector. You can enable tracing by adding the `--enable-trace` and configure the OpenTelemetry Collector endpoint using `--oltp-traces-endpoint` when launching the server.
+SGLang exports request trace data based on the OpenTelemetry Collector. You can enable tracing by adding the `--enable-trace` flag and configuring the OpenTelemetry Collector endpoint with `--otlp-traces-endpoint` when launching the server.
 
 You can find example screenshots of the visualization in https://github.com/sgl-project/sglang/issues/8965.
 
@@ -22,7 +22,13 @@ This section explains how to configure the request tracing and export the trace
 3. start your SGLang server with tracing enabled
 
    ```bash
-    python -m sglang.launch_server --enable-trace --oltp-traces-endpoint 0.0.0.0:4317
+    # set the OTLP exporter env variables (optional; the values below are the defaults)
+    export SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS=500
+    export SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE=64
+    # start the prefill and decode servers
+    python -m sglang.launch_server --enable-trace --otlp-traces-endpoint 0.0.0.0:4317
+    # start the mini load balancer
+    python -m sglang_router.launch_router --enable-trace --otlp-traces-endpoint 0.0.0.0:4317
    ```
 
    Replace `0.0.0.0:4317` with the actual endpoint of the opentelemetry collector. If you launched the openTelemetry collector with tracing_compose.yaml, the default receiving port is 4317.
@@ -39,9 +45,9 @@ We have already inserted instrumentation points in the tokenizer and scheduler m
    Every process involved in tracing during the initialization phase should execute:
    ```python
-    process_tracing_init(oltp_traces_endpoint, server_name)
+    process_tracing_init(otlp_traces_endpoint, server_name)
    ```
-   The oltp_traces_endpoint is obtained from the arguments, and you can set server_name freely, but it should remain consistent across all processes.
+   The `otlp_traces_endpoint` is obtained from the server arguments; you can set `server_name` freely, but it must remain consistent across all processes.
 
    Every thread involved in tracing during the initialization phase should execute:
    ```python
@@ -95,24 +101,52 @@ We have already inserted instrumentation points in the tokenizer and scheduler m
    trace_set_proc_propagate_context(rid, req.trace_context)
    ```
 
+5. When the request execution flow transfers to another node (PD disaggregation), the trace context needs to be explicitly propagated.
+    - sender: Execute the following code before sending the request to the peer node via HTTP:
+    ```python
+    trace_context = trace_get_remote_propagate_context(bootstrap_room_list)
+    headers = {"trace_context": trace_context}
+    session.post(url, headers=headers)
+    ```
+    - receiver: Execute the following code after receiving the request via HTTP:
+    ```python
+    trace_set_remote_propagate_context(request.headers['trace_context'])
+    ```
+
 ## How to Extend the Tracing Framework to Support Complex Tracing Scenarios
 
 The currently provided tracing package still has potential for further development. If you wish to build more advanced features upon it, you must first understand its existing design principles.
 
-The core of the tracing framework's implementation lies in the design of the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a trace context with a three-level structure.
-
-The core of the tracing framework implementation lies in the design of the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a three-level trace context structure: `SglangTraceReqContext`, `SglangTraceThreadContext`, and `SglangTraceSliceContext`. Their relationship is as follows:
+The core of the tracing framework's implementation lies in the design of the span structure and the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a two-level trace context structure (`SglangTraceReqContext` and `SglangTraceThreadContext`) and a four-level span structure. The trace contexts relate as follows:
 
 ```
 SglangTraceReqContext (req_id="req-123")
 ├── SglangTraceThreadContext(thread_label="scheduler", tp_rank=0)
-│   └── SglangTraceSliceContext (name="prefill")  # cur slice
 |
 └── SglangTraceThreadContext(thread_label="scheduler", tp_rank=1)
-    └── SglangTraceSliceContext (name="prefill")  # cur slice
 ```
 
-Each traced request maintains a global `SglangTraceReqContext`. For every thread processing the request, a corresponding `SglangTraceThreadContext` is recorded and composed within the `SglangTraceReqContext`. Within each thread, every currently traced slice (possibly nested) is represented by a `SglangTraceSliceContext`, which is stored in the `SglangTraceThreadContext`. Generate a span and release the corresponding context when slice tracing, thread tracing, or request tracing ends.
+Each traced request maintains a global `SglangTraceReqContext`. For every thread processing the request, a corresponding `SglangTraceThreadContext` is recorded and composed within the `SglangTraceReqContext`. Within each thread, every currently traced slice (possibly nested) is stored in a list. In addition to this hierarchy, each slice records its previous slice via `Span.add_link()`, which can be used to reconstruct the execution flow. When the request execution flow transfers to a new thread, the trace context must be explicitly propagated; in the framework, this is represented by `SglangTracePropagateContext`, which carries the span contexts of the request span and the previous slice span.
+
+
+We designed a four-level span structure, consisting of `bootstrap_room_span`, `req_root_span`, `thread_span`, and `slice_span`. Among them, `req_root_span` and `thread_span` correspond to `SglangTraceReqContext` and `SglangTraceThreadContext`, respectively, and `slice_span` is stored within the `SglangTraceThreadContext`.
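+
+A simplified sketch of how `trace_req_start()` (in `python/sglang/srt/tracing/trace.py`) builds the top of this hierarchy; the real implementation additionally reuses a propagated remote context when one exists:
+
+```python
+# One bootstrap room span is shared by all nodes serving this request.
+bootstrap_room_span = tracer.start_span(
+    name=f"Bootstrap Room {hex(bootstrap_room)}",
+    attributes={"bootstrap_room": str(hex(bootstrap_room))},
+)
+ctx = trace.set_span_in_context(bootstrap_room_span)
+
+# Each node (router / prefill / decode) parents its own req_root_span under it,
+# so per-node attributes never collide on a single shared span.
+root_span = tracer.start_span(name=f"{role} Req {rid[:8]}", context=ctx)
+```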
+The `bootstrap_room_span` is designed to accommodate PD disaggregation. On different nodes, we may want to add node-specific attributes to the `req_root_span`. However, if a single `req_root_span` were shared across all nodes, the Prefill and Decode nodes would not be allowed to add attributes, due to constraints imposed by OpenTelemetry's design.
+
+```
+bootstrap room span
+├── router req root span
+|   └── router thread span
+|       └── slice span
+├── prefill req root span
+|   ├── tokenizer thread span
+|   |   └── slice span
+|   └── scheduler thread span
+|       └── slice span
+└── decode req root span
+    ├── tokenizer thread span
+    |   └── slice span
+    └── scheduler thread span
+        └── slice span
+```
diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py
index 5e05cdd7408..51b91009ec1 100644
--- a/python/sglang/srt/disaggregation/decode.py
+++ b/python/sglang/srt/disaggregation/decode.py
@@ -58,6 +58,11 @@
     ReqToTokenPool,
     SWAKVPool,
 )
+from sglang.srt.tracing.trace import (
+    trace_event_batch,
+    trace_slice_batch,
+    trace_slice_end,
+)
 from sglang.srt.utils import get_int_env_var, require_mlp_sync
 from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
@@ -313,6 +318,7 @@ def add(self, req: Req, is_retracted: bool = False) -> None:
         )
 
         req.add_latency(RequestStage.DECODE_PREPARE)
+        trace_slice_end(RequestStage.DECODE_PREPARE, req.rid, auto_next_anon=True)
         self.queue.append(
             DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
         )
@@ -528,6 +534,9 @@ def pop_preallocated(self) -> List[DecodeRequest]:
                 time.perf_counter()
             )
             decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
+            trace_slice_end(
+                RequestStage.DECODE_BOOTSTRAP, decode_req.req.rid, auto_next_anon=True
+            )
 
         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -776,8 +785,19 @@ def pop_transferred(self) -> List[Req]:
                     [decode_req.req], decode_req.req.return_logprob
                 )
                 self.tree_cache.cache_finished_req(decode_req.req)
+                trace_slice_end(
+                    RequestStage.DECODE_QUICK_FINISH,
+                    decode_req.req.rid,
+                    thread_finish_flag=True,
+                )
             else:
                 transferred_reqs.append(decode_req.req)
+                trace_slice_end(
+                    RequestStage.DECODE_TRANSFERRED,
+                    decode_req.req.rid,
+                    auto_next_anon=True,
+                )
+
         elif poll in [
             KVPoll.Bootstrapping,
             KVPoll.WaitingForInput,
@@ -823,6 +843,7 @@ def event_loop_normal_disagg_decode(self: Scheduler):
                 self.stream_output(
                     batch.reqs, any(req.return_logprob for req in batch.reqs)
                 )
+                trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
                 if prepare_mlp_sync_flag:
                     self._prepare_idle_batch_and_run(None)
             else:
@@ -872,6 +893,7 @@ def event_loop_overlap_disagg_decode(self: Scheduler):
                 self.stream_output(
                     batch.reqs, any(req.return_logprob for req in batch.reqs)
                 )
+                trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
                 if prepare_mlp_sync_flag:
                     batch_, batch_result = self._prepare_idle_batch_and_run(
                         None, delay_process=True
@@ -954,6 +976,9 @@ def get_next_disagg_decode_batch_to_run(
             self.running_batch = self.update_running_batch(self.running_batch)
             ret = self.running_batch if not self.running_batch.is_empty() else None
 
+        if ret:
+            attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()}
+            trace_event_batch("schedule", ret.reqs, attrs=attrs)
         return ret
 
     def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py
index 447fffb546f..8fad0c0dce4 100644
--- 
a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -53,6 +53,7 @@ NSATokenToKVPool, SWAKVPool, ) +from sglang.srt.tracing.trace import trace_event_batch, trace_slice, trace_slice_end from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj, require_mlp_sync if TYPE_CHECKING: @@ -198,6 +199,7 @@ def add(self, req: Req, num_kv_heads: int) -> None: self._process_req(req) req.add_latency(RequestStage.PREFILL_PREPARE) self.queue.append(req) + trace_slice_end(RequestStage.PREFILL_PREPARE, req.rid, auto_next_anon=True) def extend(self, reqs: List[Req], num_kv_heads: int) -> None: for req in reqs: @@ -289,6 +291,10 @@ def pop_bootstrapped( req.time_stats.wait_queue_entry_time = time.perf_counter() req.add_latency(RequestStage.PREFILL_BOOTSTRAP) + trace_slice_end( + RequestStage.PREFILL_BOOTSTRAP, req.rid, auto_next_anon=True + ) + self.queue = [ entry for i, entry in enumerate(self.queue) if i not in indices_to_remove ] @@ -316,6 +322,9 @@ def event_loop_normal_disagg_prefill(self: Scheduler) -> None: ) self.process_prefill_chunk() batch = self.get_new_batch_prefill() + if batch: + attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()} + trace_event_batch("schedule", batch.reqs, attrs=attrs) if require_mlp_sync(self.server_args): batch = self.prepare_mlp_sync_batch(batch) @@ -348,6 +357,9 @@ def event_loop_overlap_disagg_prefill(self: Scheduler) -> None: ) self.process_prefill_chunk() batch = self.get_new_batch_prefill() + if batch: + attrs = {"bid": hex(id(batch)), "batch_size": batch.batch_size()} + trace_event_batch("schedule", batch.reqs, attrs=attrs) if require_mlp_sync(self.server_args): batch = self.prepare_mlp_sync_batch(batch) @@ -423,6 +435,7 @@ def process_batch_result_disagg_prefill( req.output_ids.append(next_token_id) self.tree_cache.cache_unfinished_req(req) # update the tree and lock req.add_latency(RequestStage.PREFILL_FORWARD) + trace_slice(RequestStage.PREFILL_FORWARD, req.rid, auto_next_anon=True) self.disagg_prefill_inflight_queue.append(req) if self.spec_algorithm.is_eagle() and batch.spec_info is not None: req.output_topk_p = batch.spec_info.topk_p[i] @@ -487,6 +500,9 @@ def process_batch_result_disagg_prefill( if self.enable_overlap: self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx) + trace_slice( + RequestStage.PREFILL_CHUNKED_FORWARD, req.rid, auto_next_anon=True + ) self.maybe_send_health_check_signal() @@ -558,6 +574,9 @@ def process_disagg_prefill_inflight_queue( req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE) self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index) req.metadata_buffer_index = -1 + trace_slice( + RequestStage.PREFILL_TRANSFER_KV_CACHE, req.rid, thread_finish_flag=True + ) self.disagg_prefill_inflight_queue = undone_reqs diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index f79a551434d..eebbcca4ac0 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -143,10 +143,13 @@ def __init__(self, **kwargs): # Enable tracing if server_args.enable_trace: - process_tracing_init(server_args.oltp_traces_endpoint, "sglang") - if server_args.disaggregation_mode == "null": - thread_label = "Tokenizer" - trace_set_thread_info(thread_label) + process_tracing_init(server_args.otlp_traces_endpoint, "sglang") + thread_label = "Tokenizer" + if server_args.disaggregation_mode == "prefill": + thread_label = "Prefill Tokenizer" + elif server_args.disaggregation_mode == 
"decode": + thread_label = "Decode Tokenizer" + trace_set_thread_info(thread_label) try: self.loop = asyncio.get_running_loop() diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index ee2586fde67..6ec95fdd84e 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -220,9 +220,12 @@ async def lifespan(fast_api_app: FastAPI): # Init tracing if server_args.enable_trace: - process_tracing_init(server_args.oltp_traces_endpoint, "sglang") - if server_args.disaggregation_mode == "null": - trace_set_thread_info(thread_label) + process_tracing_init(server_args.otlp_traces_endpoint, "sglang") + if server_args.disaggregation_mode == "prefill": + thread_label = "Prefill" + thread_label + elif server_args.disaggregation_mode == "decode": + thread_label = "Decode" + thread_label + trace_set_thread_info(thread_label) # Initialize OpenAI serving handlers fast_api_app.state.openai_serving_completion = OpenAIServingCompletion( diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index c47dc048ad1..de00e73a4d4 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -129,6 +129,8 @@ class Envs: SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1) SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial") SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp") + SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS = EnvInt(500) + SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE = EnvInt(64) # Scheduler: memory leak test SGLANG_TEST_RETRACT = EnvBool(False) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 6209852ade5..92437aa80a0 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -34,13 +34,21 @@ TokenizedGenerateReqInput, WatchLoadUpdateReq, ) -from sglang.srt.managers.schedule_batch import Req +from sglang.srt.managers.schedule_batch import Req, RequestStage from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.server_args import ( DP_ATTENTION_HANDSHAKE_PORT_DELTA, PortArgs, ServerArgs, ) +from sglang.srt.tracing.trace import ( + process_tracing_init, + trace_get_proc_propagate_context, + trace_set_proc_propagate_context, + trace_set_thread_info, + trace_slice_end, + trace_slice_start, +) from sglang.srt.utils import ( bind_port, configure_logger, @@ -170,11 +178,22 @@ def send_control_message(self, obj): def handle_load_update_req(self, obj): self.dp_budget.update_budget(obj) + def dispatching_with_trace(self, req: Req): + if self.server_args.enable_trace: + trace_set_proc_propagate_context(req.rid, req.trace_context) + trace_slice_start(RequestStage.DC_DISPATCH, req.rid) + req.trace_context = trace_get_proc_propagate_context(req.rid) + + self.dispatching(req) + + if self.server_args.enable_trace: + trace_slice_end(RequestStage.DC_DISPATCH, req.rid, thread_finish_flag=True) + def init_dispatcher(self): self._request_dispatcher = TypeBasedDispatcher( [ - (TokenizedGenerateReqInput, self.dispatching), - (TokenizedEmbeddingReqInput, self.dispatching), + (TokenizedGenerateReqInput, self.dispatching_with_trace), + (TokenizedEmbeddingReqInput, self.dispatching_with_trace), (BlockReqInput, self.send_to_all_workers), (WatchLoadUpdateReq, self.handle_load_update_req), ] @@ -487,6 +506,14 @@ def run_data_parallel_controller_process( pipe_writer, ): kill_itself_when_parent_died() + if server_args.enable_trace: + 
process_tracing_init(server_args.otlp_traces_endpoint, "sglang") + thread_label = "DP Controller" + if server_args.disaggregation_mode == "prefill": + thread_label = "Prefill DP Controller" + elif server_args.disaggregation_mode == "decode": + thread_label = "Decode DP Controller" + trace_set_thread_info(thread_label) setproctitle.setproctitle("sglang::data_parallel_controller") faulthandler.enable() configure_logger(server_args) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index be2de0cc709..a55cfa49259 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -392,13 +392,23 @@ def merge(self, other: MultimodalInputs): class RequestStage(str, enum.Enum): - # prefill + # Tokenizer + TOKENIZE = "tokenize" + TOKENIZER_DISPATCH = "dispatch" + + # DP controller + DC_DISPATCH = "dc_dispatch" + + # common/non-disaggregation PREFILL_WAITING = "prefill_waiting" + REQUEST_PROCESS = "request_process" + DECODE_LOOP = "decode_loop" + PREFILL_FORWARD = "prefill_forward" + PREFILL_CHUNKED_FORWARD = "chunked_prefill" # disaggregation prefill PREFILL_PREPARE = "prefill_prepare" PREFILL_BOOTSTRAP = "prefill_bootstrap" - PREFILL_FORWARD = "prefill_forward" PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache" # disaggregation decode @@ -406,6 +416,8 @@ class RequestStage(str, enum.Enum): DECODE_BOOTSTRAP = "decode_bootstrap" DECODE_WAITING = "decode_waiting" DECODE_TRANSFERRED = "decode_transferred" + DECODE_FAKE_OUTPUT = "fake_output" + DECODE_QUICK_FINISH = "quick_finish" class Req: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 4346b0d9ac4..a8eba543e84 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -156,6 +156,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.tracing.trace import ( process_tracing_init, + trace_event_batch, trace_set_proc_propagate_context, trace_set_thread_info, trace_slice_batch, @@ -1376,7 +1377,7 @@ def _add_request_to_queue(self, req: Req, is_retracted: bool = False): self._prefetch_kvcache(req) self.waiting_queue.append(req) req.time_stats.wait_queue_entry_time = time.perf_counter() - trace_slice_end("process req", req.rid, auto_next_anon=True) + trace_slice_end(RequestStage.REQUEST_PROCESS, req.rid, auto_next_anon=True) elif self.disaggregation_mode == DisaggregationMode.PREFILL: self._prefetch_kvcache(req) self.disagg_prefill_bootstrap_queue.add( @@ -1639,6 +1640,10 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: if need_dp_attn_preparation: ret = self.prepare_mlp_sync_batch(ret) + if ret: + attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()} + trace_event_batch("schedule", ret.reqs, attrs=attrs) + return ret def get_num_allocatable_reqs(self, running_bs): @@ -2012,13 +2017,10 @@ def process_batch_result( ): if batch.forward_mode.is_decode(): self.process_batch_result_decode(batch, result) - if self.enable_trace: - trace_slice_batch("decode loop", batch.reqs) + trace_slice_batch(RequestStage.DECODE_LOOP, batch.reqs) elif batch.forward_mode.is_extend(): self.process_batch_result_prefill(batch, result) - if self.enable_trace: - trace_slice_batch("prefill", batch.reqs) elif batch.forward_mode.is_idle(): if self.enable_overlap: @@ -2743,10 +2745,13 @@ def run_scheduler_process( # Set up tracing if server_args.enable_trace: - process_tracing_init(server_args.oltp_traces_endpoint, "sglang") - if 
server_args.disaggregation_mode == "null": - thread_label = "Scheduler" - trace_set_thread_info(thread_label, tp_rank, dp_rank) + process_tracing_init(server_args.otlp_traces_endpoint, "sglang") + thread_label = "Scheduler" + if server_args.disaggregation_mode == "prefill": + thread_label = "Prefill Scheduler" + elif server_args.disaggregation_mode == "decode": + thread_label = "Decode Scheduler" + trace_set_thread_info(thread_label, tp_rank, dp_rank) # Create a scheduler and run the event loop try: diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index e06fac95aea..3cd13b9058d 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -14,7 +14,13 @@ BatchEmbeddingOutput, BatchTokenIDOutput, ) -from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch +from sglang.srt.managers.schedule_batch import ( + BaseFinishReason, + Req, + RequestStage, + ScheduleBatch, +) +from sglang.srt.tracing.trace import trace_slice from sglang.srt.utils.common import ceil_div if TYPE_CHECKING: @@ -160,6 +166,14 @@ def process_batch_result_prefill( ) self.abort_request(AbortReq(rid=req.rid)) req.grammar.finished = req.finished() + + trace_slice( + RequestStage.PREFILL_FORWARD, + req.rid, + auto_next_anon=not req.finished(), + thread_finish_flag=req.finished(), + ) + else: # being chunked reqs' prefill is not finished req.is_chunked -= 1 @@ -188,6 +202,12 @@ def process_batch_result_prefill( ) logprob_pt += num_input_logprobs + trace_slice( + RequestStage.PREFILL_CHUNKED_FORWARD, + req.rid, + auto_next_anon=True, + ) + else: # embedding or reward model is_sparse = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set() @@ -224,6 +244,13 @@ def process_batch_result_prefill( # being chunked reqs' prefill is not finished req.is_chunked -= 1 + trace_slice( + RequestStage.PREFILL_FORWARD, + req.rid, + auto_next_anon=not req.finished(), + thread_finish_flag=req.finished(), + ) + self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) def _resolve_spec_overlap_token_ids( diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 885da6a9842..6fd69f84b64 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -68,6 +68,7 @@ ) from sglang.srt.managers.mm_utils import TensorTransportMode from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors +from sglang.srt.managers.schedule_batch import RequestStage from sglang.srt.managers.scheduler import is_health_check_generate_req from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin @@ -79,6 +80,7 @@ trace_get_proc_propagate_context, trace_req_finish, trace_req_start, + trace_set_remote_propagate_context, trace_slice_end, trace_slice_start, ) @@ -383,6 +385,10 @@ async def generate_request( self.auto_create_handle_loop() obj.normalize_batch_and_arguments() + if request: + if "trace_context" in request.headers: + trace_set_remote_propagate_context(request.headers["trace_context"]) + if self.server_args.tokenizer_worker_num > 1: self._attach_multi_http_worker_info(obj) @@ -605,7 +611,7 @@ async def _tokenize_one_request( mm_inputs = None self._validate_one_request(obj, input_ids) - trace_slice_end("tokenize", 
obj.rid)
+        trace_slice_end(RequestStage.TOKENIZE, obj.rid)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
         )
@@ -798,7 +804,7 @@ async def _batch_tokenize_and_process(
                     req, req.text, input_ids_list[i], None, None, token_type_ids
                 )
             )
-            trace_slice_end("tokenize", req.rid)
+            trace_slice_end(RequestStage.TOKENIZE, req.rid)
 
         logger.debug(f"Completed batch processing for {batch_size} requests")
         return tokenized_objs
@@ -850,12 +856,14 @@ def _send_one_request(
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
-        trace_slice_start("dispatch", obj.rid)
+        trace_slice_start(RequestStage.TOKENIZER_DISPATCH, obj.rid)
         tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
         self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
-        trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
+        trace_slice_end(
+            RequestStage.TOKENIZER_DISPATCH, obj.rid, thread_finish_flag=True
+        )
         return state
 
     def _send_batch_request(
@@ -2088,7 +2096,12 @@ def _trace_request_start(
             bootstrap_room = (
                 obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
             )
-            trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
+            trace_req_start(
+                obj.rid,
+                bootstrap_room,
+                ts=int(created_time * 1e9),
+                role=self.server_args.disaggregation_mode,
+            )
             trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
         else:
             for i in range(len(obj.rid)):
@@ -2097,7 +2110,12 @@ def _trace_request_start(
                     if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
                     else None
                 )
-                trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
+                trace_req_start(
+                    obj.rid[i],
+                    bootstrap_room,
+                    ts=int(created_time * 1e9),
+                    role=self.server_args.disaggregation_mode,
+                )
                 trace_slice_start(
                     "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
                 )
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 589ff6c9078..010e7e9cdb6 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -288,7 +288,7 @@ class ServerArgs:
     enable_request_time_stats_logging: bool = False
     kv_events_config: Optional[str] = None
     enable_trace: bool = False
-    oltp_traces_endpoint: str = "localhost:4317"
+    otlp_traces_endpoint: str = "localhost:4317"
 
     # API related
     api_key: Optional[str] = None
@@ -2315,7 +2315,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
             help="Enable opentelemetry trace",
         )
         parser.add_argument(
-            "--oltp-traces-endpoint",
+            "--otlp-traces-endpoint",
             type=str,
             default="localhost:4317",
             help="Config opentelemetry collector endpoint if --enable-trace is set. format: <ip>:<port>",
         )
diff --git a/python/sglang/srt/tracing/trace.py b/python/sglang/srt/tracing/trace.py
index f637a8d776d..e3f9c871644 100644
--- a/python/sglang/srt/tracing/trace.py
+++ b/python/sglang/srt/tracing/trace.py
@@ -15,6 +15,8 @@
 
 from __future__ import annotations
 
+import base64
+import json
 import logging
 import os
 import random
@@ -24,6 +26,8 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
+from sglang.srt.utils import get_int_env_var
+
 if TYPE_CHECKING:
     from sglang.srt.managers.scheduler import Req
 
@@ -85,6 +89,8 @@ class SglangTraceReqContext:
     # Indicates whether this instance is a replica from the main process.
     # When True, root_span is None and only root_span_context is preserved.
is_copy: bool = False + bootstrap_room_span: Optional[trace.span.Span] = None + bootstrap_room_span_context: Optional[context.Context] = None root_span: Optional[trace.span.Span] = None root_span_context: Optional[context.Context] = None @@ -96,8 +102,7 @@ class SglangTracePropagateContext: def to_dict(self): carrier: dict[str, str] = {} - context.attach(self.root_span_context) - propagate.inject(carrier) + propagate.inject(carrier, self.root_span_context) if self.prev_span_context: return { @@ -149,6 +154,7 @@ def generate_span_id(self) -> int: # global variables +remote_trace_contexts: Dict[str, SglangTracePropagateContext] = {} threads_info: Dict[int, SglangTraceThreadInfo] = {} reqs_context: Dict[str, SglangTraceReqContext] = {} @@ -193,8 +199,17 @@ def process_tracing_init(otlp_endpoint, server_name): resource=resource, id_generator=SglangTraceCustomIdGenerator() ) + schedule_delay_millis = get_int_env_var( + "SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS", 500 + ) + max_export_batch_size = get_int_env_var( + "SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE", 64 + ) + processor = BatchSpanProcessor( - OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True) + OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True), + schedule_delay_millis=schedule_delay_millis, + max_export_batch_size=max_export_batch_size, ) tracer_provider.add_span_processor(processor) trace.set_tracer_provider(tracer_provider) @@ -266,7 +281,9 @@ def __create_thread_context(pid, req_span_context, ts: Optional[int] = None): return thread_context -def trace_get_proc_propagate_context(rid) -> Optional[Dict[str, Any]]: +def trace_get_proc_propagate_context( + rid, remote_propagate=False +) -> Optional[Dict[str, Any]]: if not tracing_enabled: return None @@ -283,9 +300,11 @@ def trace_get_proc_propagate_context(rid) -> Optional[Dict[str, Any]]: elif thread_context.last_span_context: prev_span_context = thread_context.last_span_context - trace_context = SglangTracePropagateContext( - reqs_context[rid].root_span_context, prev_span_context - ) + root_span_context = reqs_context[rid].root_span_context + if remote_propagate: + root_span_context = reqs_context[rid].bootstrap_room_span_context + + trace_context = SglangTracePropagateContext(root_span_context, prev_span_context) return trace_context.to_dict() @@ -327,10 +346,54 @@ def trace_set_proc_propagate_context(rid, trace_context: Optional[Dict[str, Any] ].last_span_context = trace_context.prev_span_context +def trace_get_remote_propagate_context(bootstrap_room_list: List[str]): + if not tracing_enabled: + return "" + + reqs_trace_contexts = {} + for bootstrap_room in bootstrap_room_list: + # In the router, rid is also the bootstrap room. 
+ bootstrap_room = str(bootstrap_room) + + if bootstrap_room not in reqs_context: + continue + + _context = trace_get_proc_propagate_context( + bootstrap_room, remote_propagate=True + ) + reqs_trace_contexts[bootstrap_room] = _context + + json_str = json.dumps(reqs_trace_contexts, ensure_ascii=False) + return base64.b64encode(json_str.encode("utf-8")).decode("utf-8") + + +def trace_set_remote_propagate_context(base64_str): + if not tracing_enabled: + return + + if base64_str is None or base64_str == "" or base64_str == "None": + return + + base64_bytes = base64.b64decode(base64_str) + json_str = base64_bytes.decode("utf-8") + remote_reqs_trace_contexts = json.loads(json_str) + + for bootstrap_room in remote_reqs_trace_contexts: + if bootstrap_room in remote_trace_contexts: + continue + + remote_trace_contexts[bootstrap_room] = ( + SglangTracePropagateContext.instance_from_dict( + remote_reqs_trace_contexts[bootstrap_room] + ) + ) + + def trace_req_start( rid: str, bootstrap_room: Optional[int] = None, ts: Optional[int] = None, + role: Optional[str] = "null", ): if not tracing_enabled: return @@ -344,6 +407,7 @@ def trace_req_start( return # create req context and root span + bootstrap_room = 0 if bootstrap_room is None else bootstrap_room reqs_context[rid] = SglangTraceReqContext( rid=rid, start_time_ns=ts, @@ -352,23 +416,42 @@ def trace_req_start( is_copy=False, ) + # create bootstrap room span + tracer = threads_info[pid].tracer + if str(bootstrap_room) not in remote_trace_contexts: + attrs = {"bootstrap_room": str(hex(bootstrap_room))} + bootstrap_room_span = tracer.start_span( + name=f"Bootstrap Room {hex(bootstrap_room)}", + start_time=ts, + attributes=attrs, + ) + reqs_context[rid].bootstrap_room_span = bootstrap_room_span + bootstrap_room_span_context = trace.set_span_in_context(bootstrap_room_span) + else: + bootstrap_room_span_context = remote_trace_contexts[ + str(bootstrap_room) + ].root_span_context + # Drop the worker_id added by MultiTokenizer orig_rid = rid.split("_")[-1] - tracer = threads_info[pid].tracer + role = "" if role == "null" else role + attrs = {"rid": orig_rid} root_span = tracer.start_span( - name=f"Req {orig_rid[:8]}", + name=f"{role} Req {orig_rid[:8]}", start_time=ts, + context=bootstrap_room_span_context, + attributes=attrs, ) root_span.set_attributes( { "rid": rid, - "bootstrap_room": bootstrap_room if bootstrap_room else "None", } ) reqs_context[rid].root_span = root_span reqs_context[rid].root_span_context = trace.set_span_in_context(root_span) + reqs_context[rid].bootstrap_room_span_context = bootstrap_room_span_context # create thread context and thread span reqs_context[rid].threads_context[pid] = __create_thread_context( @@ -376,6 +459,10 @@ def trace_req_start( reqs_context[rid].root_span_context, ts, ) + if str(bootstrap_room) in remote_trace_contexts: + reqs_context[rid].threads_context[pid].last_span_context = ( + remote_trace_contexts[str(bootstrap_room)].prev_span_context + ) def trace_req_finish( @@ -399,6 +486,10 @@ def trace_req_finish( req_context.root_span.set_attributes(attrs) req_context.root_span.end(end_time=ts) + if str(req_context.bootstrap_room) in remote_trace_contexts: + del remote_trace_contexts[str(req_context.bootstrap_room)] + else: + req_context.bootstrap_room_span.end(end_time=ts) del reqs_context[rid] @@ -518,7 +609,9 @@ def trace_slice_end( # Add event to the current slice on the same thread with the same rid. 
-def trace_event(name: str, rid: str, ts: Optional[int] = None): +def trace_event( + name: str, rid: str, ts: Optional[int] = None, attrs: Dict[str, Any] = None +): if not tracing_enabled: return @@ -539,7 +632,7 @@ def trace_event(name: str, rid: str, ts: Optional[int] = None): ts = ts or __get_cur_time_ns() slice_info = thread_context.cur_slice_stack[-1] - slice_info.span.add_event(name=name, timestamp=ts) + slice_info.span.add_event(name=name, timestamp=ts, attributes=attrs) # Add attrs to the current slice on the same thread with the same rid. @@ -569,6 +662,9 @@ def trace_slice_batch( name: str, reqs: List[Req], ): + if not tracing_enabled: + return + for req in reqs: trace_slice( name, @@ -576,3 +672,16 @@ def trace_slice_batch( auto_next_anon=not req.finished(), thread_finish_flag=req.finished(), ) + + +def trace_event_batch( + name: str, + reqs: List[Req], + ts: Optional[int] = None, + attrs: Dict[str, Any] = None, +): + if not tracing_enabled: + return + + for req in reqs: + trace_event(name, req.rid, ts=ts, attrs=attrs) diff --git a/scripts/convert_otel_2_perfetto.py b/scripts/convert_otel_2_perfetto.py new file mode 100644 index 00000000000..42f1127bcac --- /dev/null +++ b/scripts/convert_otel_2_perfetto.py @@ -0,0 +1,375 @@ +import argparse +import bisect +import json +import time +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, Iterable, List, Tuple + +parser = argparse.ArgumentParser( + description="Convert SGLang OTEL trace files to Perfetto format.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, +) +parser.add_argument( + "-i", + "--input", + dest="input_file", + required=True, + type=str, + help="Path to the input OTEL trace file (JSON or JSONL format).", +) +parser.add_argument( + "-o", + "--output", + dest="output_file", + type=str, + default="sglang_trace_perfetto.json", + help="Path to the output Perfetto JSON file.", +) +parser.add_argument( + "-f", "--torch-file", dest="torch_file", help="specify torch profile file" +) + +args = parser.parse_args() + +perfetto_data = None +if args.torch_file: + with open(args.torch_file, "r", encoding="utf-8") as file: + perfetto_data = json.load(file) + baseline = perfetto_data["baseTimeNanoseconds"] +else: + baseline = 0 + + +def id_generator(): + i = 0 + while True: + yield i + i += 1 + + +relation_id_gen = id_generator() + + +class SpanLayoutContainer: + def __init__(self): + self.intervals = [] + + def check_overlap(self, start, end): + idx = bisect.bisect_left(self.intervals, (start, float("-inf"))) + + if idx > 0: + prev_start, prev_end = self.intervals[idx - 1] + if prev_end > start: + return True + + if idx < len(self.intervals): + next_start, next_end = self.intervals[idx] + if next_start < end: + return True + return False + + def insert_span(self, start, end): + bisect.insort_left(self.intervals, (start, end)) + + +def new_metadata_level1(name: str, pid): + return { + "name": "process_name", + "ph": "M", + "pid": pid, + "args": {"name": name}, + } + + +def new_metadata_level2(name: str, pid, slot_seq): + return { + "name": "thread_name", + "ph": "M", + "pid": pid, + "tid": slot_seq, + "args": {"name": name}, + } + + +def __find_line(graph, trans_graph_status, slot_meta_data, pid, start, end): + if pid in trans_graph_status: + line = trans_graph_status[pid] + if start == end: + return line + # check conflict + if not graph[pid][line].check_overlap(start, end): + return line + + if pid not in graph: + line = 1 + graph[pid] = {line: SpanLayoutContainer()} + 
trans_graph_status[pid] = line
+        slot_meta_data.append(new_metadata_level2("slot", pid, line))
+        return line
+
+    for line in graph[pid]:
+        if not graph[pid][line].check_overlap(start, end):
+            trans_graph_status[pid] = line
+            return line
+
+    new_line = len(graph[pid]) + 1
+    graph[pid][new_line] = SpanLayoutContainer()
+    trans_graph_status[pid] = new_line
+    slot_meta_data.append(new_metadata_level2("slot", pid, new_line))
+    return new_line
+
+
+OtelSpan = Dict[str, Any]
+
+
+def load_otel_data(path: str | Path):
+    p = Path(path)
+    with p.open("rt", encoding="utf-8") as f:
+        first = f.read(1)
+        f.seek(0)
+        if first == "[":
+            data = json.load(f)  # JSON array
+        else:
+            data = [json.loads(line) for line in f if line.strip()]  # JSONL
+    return data
+
+
+def extract_all_otel_spans(otel_data):
+    otel_spans = []
+    for line_data in otel_data:
+        for resource_spans in line_data["resourceSpans"]:
+            for scope_spans in resource_spans["scopeSpans"]:
+                for span in scope_spans["spans"]:
+                    if "attributes" in span:
+                        attributes_dict = {
+                            attr.get("key"): next(
+                                iter(attr.get("value", {}).values()), None
+                            )
+                            for attr in span["attributes"]
+                        }
+                        span["attributes"] = attributes_dict
+                    else:
+                        span["attributes"] = {}
+                    otel_spans.append(span)
+    return otel_spans
+
+
+def build_otel_span_tree(otel_spans):
+    span_id_map = {span["spanId"]: span for span in otel_spans}
+    for span in otel_spans:
+        span["child"] = []
+
+    bootstrap_room_spans = []
+
+    for span in otel_spans:
+        span_id = span["spanId"]
+        parent_span_id = span.get("parentSpanId", "")
+        if parent_span_id == "":
+            # spans without a parent are the bootstrap room (root) spans
+            bootstrap_room_spans.append(span)
+        elif parent_span_id in span_id_map:
+            parent_span = span_id_map[parent_span_id]
+            parent_span["child"].append(span)
+
+        link_spans = []
+        if "links" in span:
+            for link in span["links"]:
+                link_span = span_id_map.get(link["spanId"])
+                if link_span:
+                    link_spans.append(link_span)
+        span["links"] = link_spans
+
+    return bootstrap_room_spans
+
+
+def generate_perfetto_span(otel_bootstrap_room_spans, thread_meta_data):
+    for bootstrap_room_span in otel_bootstrap_room_spans:
+        bootstrap_room = bootstrap_room_span["attributes"]["bootstrap_room"]
+        bootstrap_room_span["spans"] = []
+
+        for node_req_span in bootstrap_room_span["child"]:
+            rid = node_req_span["attributes"]["rid"]
+
+            for thread_span in node_req_span["child"]:
+                pid = int(thread_span["attributes"]["pid"])
+                thread_name = f'{thread_span["attributes"]["host_id"][:8]}:{thread_span["attributes"]["thread_label"]}'
+                if "tp_rank" in thread_span["attributes"]:
+                    thread_name += f"-TP{thread_span['attributes']['tp_rank']}"
+
+                if pid not in thread_meta_data:
+                    thread_meta_data[pid] = new_metadata_level1(thread_name, pid)
+
+                for span in thread_span["child"]:
+                    span["attributes"]["bootstrap_room"] = bootstrap_room
+                    span["attributes"]["rid"] = rid
+                    span["host_id"] = thread_span["attributes"]["host_id"]
+                    span["pid"] = pid
+
+                    span["startTimeUnixNano"] = int(span["startTimeUnixNano"])
+                    span["endTimeUnixNano"] = int(span["endTimeUnixNano"])
+                    ts = span["startTimeUnixNano"]
+                    dur = span["endTimeUnixNano"] - ts
+
+                    perfetto_span = {
+                        "ph": "X",
+                        "name": span.get("name", "unknown"),
+                        "cat": "sglang",
+                        "ts": (ts - baseline) / 1000.0,
+                        "dur": (dur - 1000) / 1000.0,
+                        "pid": pid,
+                        "tid": 0,
+                        "args": span["attributes"],
+                    }
+
+                    span["perfetto_span"] = perfetto_span
+                    bootstrap_room_span["spans"].append(span)
+
+
+def generate_perfetto_span_layout(otel_bootstrap_room_spans, slot_meta_data):
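+    # Packs each request's spans into per-pid "slot" rows (Perfetto tids):
+    # a span first tries the row used by the request's previous span on that
+    # pid, then the first row whose recorded intervals do not overlap it; if
+    # every row conflicts, a new row is opened and a matching "slot"
+    # thread_name metadata entry is emitted (see __find_line above).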
+    for bootstrap_room_span in otel_bootstrap_room_spans:
+        bootstrap_room_span["spans"] = sorted(
+            bootstrap_room_span["spans"], key=lambda x: int(x["startTimeUnixNano"])
+        )
+
+    otel_bootstrap_room_spans = sorted(
+        otel_bootstrap_room_spans, key=lambda x: int(x["spans"][0]["startTimeUnixNano"])
+    )
+    graph = {}
+    for bootstrap_room_span in otel_bootstrap_room_spans:
+        req_thread_status = {}
+        for span in bootstrap_room_span["spans"]:
+            line = __find_line(
+                graph,
+                req_thread_status,
+                slot_meta_data,
+                span["perfetto_span"]["pid"],
+                span["startTimeUnixNano"],
+                span["endTimeUnixNano"],
+            )
+            graph[span["perfetto_span"]["pid"]][line].insert_span(
+                span["startTimeUnixNano"], span["endTimeUnixNano"]
+            )
+            span["perfetto_span"]["tid"] = line
+
+
+def generate_perfetto_events(otel_bootstrap_room_spans):
+    for bootstrap_room_span in otel_bootstrap_room_spans:
+        for span in bootstrap_room_span["spans"]:
+            span["perfetto_events"] = []
+            if "events" in span:
+                for event in span["events"]:
+                    attributes_dict = {
+                        attr.get("key"): next(
+                            iter(attr.get("value", {}).values()), None
+                        )
+                        for attr in event["attributes"]
+                    }
+                    perfetto_event = {
+                        "ph": "i",
+                        "cat": "sglang",
+                        "ts": (int(event["timeUnixNano"]) - baseline) / 1000.0,
+                        "pid": span["perfetto_span"]["pid"],
+                        "tid": span["perfetto_span"]["tid"],
+                        "name": event.get("name", "unknown"),
+                        "args": attributes_dict,
+                    }
+
+                    span["perfetto_events"].append(perfetto_event)
+
+
+def generate_perfetto_links(otel_bootstrap_room_spans):
+    for bootstrap_room_span in otel_bootstrap_room_spans:
+        for span in bootstrap_room_span["spans"]:
+            span["perfetto_links"] = []
+            if "links" in span:
+                for link_span in span["links"]:
+                    if "correlation" in link_span["perfetto_span"]["args"]:
+                        id = link_span["perfetto_span"]["args"]["correlation"]
+                    else:
+                        id = next(relation_id_gen)
+                        link_span["perfetto_span"]["args"]["correlation"] = id
+
+                    perfetto_start_node = {
+                        "ph": "s",
+                        "id": id,
+                        "pid": link_span["perfetto_span"]["pid"],
+                        "tid": link_span["perfetto_span"]["tid"],
+                        "ts": link_span["perfetto_span"]["ts"],
+                        "cat": "ac2g",
+                        "name": "ac2g",
+                    }
+
+                    perfetto_end_node = {
+                        "ph": "f",
+                        "id": id,
+                        "pid": span["perfetto_span"]["pid"],
+                        "tid": span["perfetto_span"]["tid"],
+                        "ts": span["perfetto_span"]["ts"],
+                        "cat": "ac2g",
+                        "name": "ac2g",
+                        "bp": "e",
+                    }
+
+                    span["perfetto_links"].append(perfetto_start_node)
+                    span["perfetto_links"].append(perfetto_end_node)
+
+
+def gather_all_perfetto_elems(
+    otel_bootstrap_room_spans, thread_meta_data, slot_meta_data
+):
+    elems = []
+    elems.extend(thread_meta_data.values())
+    elems.extend(slot_meta_data)
+    for bootstrap_room_span in otel_bootstrap_room_spans:
+        for span in bootstrap_room_span["spans"]:
+            elems.append(span["perfetto_span"])
+            elems.extend(span["perfetto_events"])
+            elems.extend(span["perfetto_links"])
+
+    return elems
+
+
+def write_json(perfetto_elems):
+    global perfetto_data
+
+    if args.torch_file:
+        perfetto_data["traceEvents"].extend(perfetto_elems)
+        filtered_data = [
+            item
+            for item in perfetto_data["traceEvents"]
+            if item.get("cat") != "gpu_user_annotation"
+        ]
+        perfetto_data["traceEvents"] = filtered_data
+    else:
+        perfetto_data = perfetto_elems
+
+    with open(args.output_file, "w", encoding="utf-8") as file:
+        json.dump(perfetto_data, file, ensure_ascii=False, indent=4)
+
+
+def main():
+    start_time = time.time()
+    otel_data = load_otel_data(args.input_file)
+    otel_spans = extract_all_otel_spans(otel_data)
+    otel_bootstrap_room_spans = build_otel_span_tree(otel_spans)
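+    # The passes below decorate the span tree in place: build Perfetto "X"
+    # slices, pack them into per-pid rows, emit instant events and flow-link
+    # nodes, and finally gather everything for serialization.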
+    thread_meta_data = {}
+    generate_perfetto_span(otel_bootstrap_room_spans, thread_meta_data)
+    slot_meta_data = []
+    generate_perfetto_span_layout(otel_bootstrap_room_spans, slot_meta_data)
+    generate_perfetto_events(otel_bootstrap_room_spans)
+    generate_perfetto_links(otel_bootstrap_room_spans)
+    perfetto_elems = gather_all_perfetto_elems(
+        otel_bootstrap_room_spans, thread_meta_data, slot_meta_data
+    )
+    write_json(perfetto_elems)
+    end_time = time.time()
+    execution_time = end_time - start_time
+    print("\nConversion finished successfully!")
+    print(f"Output written to: {args.output_file}")
+    print(f"Execution time: {execution_time * 1000:.4f} ms")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py
index 506842f843f..50d96da6e1f 100644
--- a/sgl-router/py_src/sglang_router/launch_router.py
+++ b/sgl-router/py_src/sglang_router/launch_router.py
@@ -38,9 +38,22 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
         router_args = args
 
     if router_args.mini_lb:
+        if router_args.enable_trace:
+            from sglang.srt.tracing.trace import (
+                process_tracing_init,
+                trace_set_thread_info,
+            )
+
+            process_tracing_init(router_args.otlp_traces_endpoint, "sglang")
+            trace_set_thread_info("Mini lb")
+
         mini_lb = MiniLoadBalancer(router_args)
         mini_lb.start()
     else:
+        # TODO: support tracing for the router (Rust).
+        del router_args.enable_trace
+        del router_args.otlp_traces_endpoint
+
         if Router is None:
             raise RuntimeError("Rust Router is not installed")
         router_args._validate_router_args()
diff --git a/sgl-router/py_src/sglang_router/mini_lb.py b/sgl-router/py_src/sglang_router/mini_lb.py
index 920d5c38fc1..4fcd1cc1753 100644
--- a/sgl-router/py_src/sglang_router/mini_lb.py
+++ b/sgl-router/py_src/sglang_router/mini_lb.py
@@ -18,6 +18,14 @@
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 from sglang_router.router_args import RouterArgs
 
+from sglang.srt.tracing.trace import (
+    trace_get_remote_propagate_context,
+    trace_req_finish,
+    trace_req_start,
+    trace_slice_end,
+    trace_slice_start,
+)
+
 logger = logging.getLogger(__name__)
 
 AIOHTTP_STREAM_READ_CHUNK_SIZE = (
@@ -46,6 +54,7 @@ def __init__(
         self.prefill_urls = [url[0] for url in router_args.prefill_urls]
         self.prefill_bootstrap_ports = [url[1] for url in router_args.prefill_urls]
         self.decode_urls = router_args.decode_urls
+        self.enable_trace = router_args.enable_trace
 
     def _validate_router_args(self, router_args: RouterArgs):
         logger.warning(
@@ -91,11 +100,33 @@ async def generate(
                 total=self.timeout
             )  # Add timeout for request reliability
         ) as session:
+            headers = {}
+            bootstrap_room_list = []
+            if self.enable_trace:
+                bootstrap_room_list = (
+                    modified_request["bootstrap_room"]
+                    if isinstance(modified_request["bootstrap_room"], list)
+                    else [modified_request["bootstrap_room"]]
+                )
+                trace_context = trace_get_remote_propagate_context(bootstrap_room_list)
+                headers = {"trace_context": trace_context}
+
             tasks = [
-                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
-                session.post(f"{decode_server}/{endpoint}", json=modified_request),
+                session.post(
+                    f"{prefill_server}/{endpoint}",
+                    json=modified_request,
+                    headers=headers,
+                ),
+                session.post(
+                    f"{decode_server}/{endpoint}",
+                    json=modified_request,
+                    headers=headers,
+                ),
             ]
+            for bootstrap_room in bootstrap_room_list:
+                trace_slice_end("mini_lb_launch", bootstrap_room, auto_next_anon=True)
+
             # Wait for both responses 
to complete. Prefill should end first. prefill_response, decode_response = await asyncio.gather(*tasks) @@ -114,6 +145,14 @@ async def generate( else: ret_json = await decode_response.json() + for bootstrap_room in bootstrap_room_list: + trace_slice_end( + "wait_PD_finish", + bootstrap_room, + thread_finish_flag=True, + ) + trace_req_finish(bootstrap_room) + return ORJSONResponse( content=ret_json, status_code=decode_response.status, @@ -131,10 +170,36 @@ async def stream_results(): ) # Add timeout for request reliability ) as session: # Create the tasks for both prefill and decode requests + headers = {} + bootstrap_room_list = [] + if self.enable_trace: + bootstrap_room_list = ( + modified_request["bootstrap_room"] + if isinstance(modified_request["bootstrap_room"], list) + else [modified_request["bootstrap_room"]] + ) + trace_context = trace_get_remote_propagate_context( + bootstrap_room_list + ) + headers = {"trace_context": trace_context} + tasks = [ - session.post(f"{prefill_server}/{endpoint}", json=modified_request), - session.post(f"{decode_server}/{endpoint}", json=modified_request), + session.post( + f"{prefill_server}/{endpoint}", + json=modified_request, + headers=headers, + ), + session.post( + f"{decode_server}/{endpoint}", + json=modified_request, + headers=headers, + ), ] + + for bootstrap_room in bootstrap_room_list: + trace_slice_end( + "mini_lb_launch", bootstrap_room, auto_next_anon=True + ) # Wait for both responses to complete. Since this is streaming, they return immediately. prefill_response, decode_response = await asyncio.gather(*tasks) @@ -174,6 +239,14 @@ async def stream_results(): ): yield chunk + for bootstrap_room in bootstrap_room_list: + trace_slice_end( + "wait_PD_finish", + bootstrap_room, + thread_finish_flag=True, + ) + trace_req_finish(bootstrap_room) + return StreamingResponse( stream_results(), media_type="text/event-stream", @@ -367,7 +440,10 @@ async def handle_completion_request(request_data: dict): def _generate_bootstrap_room(): - return random.randint(0, 2**63 - 1) + bootstrap_room = random.randint(0, 2**63 - 1) + trace_req_start(bootstrap_room, bootstrap_room, role="router") + trace_slice_start("mini_lb_launch", bootstrap_room) + return bootstrap_room # We may utilize `GenerateReqInput`'s logic later diff --git a/sgl-router/py_src/sglang_router/router_args.py b/sgl-router/py_src/sglang_router/router_args.py index 53f804e04b1..813a0b2a1a8 100644 --- a/sgl-router/py_src/sglang_router/router_args.py +++ b/sgl-router/py_src/sglang_router/router_args.py @@ -112,6 +112,9 @@ class RouterArgs: client_cert_path: Optional[str] = None client_key_path: Optional[str] = None ca_cert_paths: List[str] = dataclasses.field(default_factory=list) + # Trace + enable_trace: bool = False + otlp_traces_endpoint: str = "localhost:4317" @staticmethod def add_cli_args( @@ -608,6 +611,17 @@ def add_cli_args( default=[], help="Path(s) to CA certificate(s) for verifying worker TLS certificates. Can specify multiple CAs.", ) + parser.add_argument( + f"--{prefix}enable-trace", + action="store_true", + help="Enable opentelemetry trace", + ) + parser.add_argument( + f"--{prefix}otlp-traces-endpoint", + type=str, + default="localhost:4317", + help="Config opentelemetry collector endpoint if --enable-trace is set. 
format: :", + ) @classmethod def from_cli_args( diff --git a/test/srt/test_tracing.py b/test/srt/test_tracing.py index a3e6de6b52b..173b15b506c 100644 --- a/test/srt/test_tracing.py +++ b/test/srt/test_tracing.py @@ -74,7 +74,7 @@ def test_trace_enable(self): DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_URL_FOR_TEST, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--enable-trace", "--oltp-traces-endpoint", "0.0.0.0:4317"], + other_args=["--enable-trace", "--otlp-traces-endpoint", "0.0.0.0:4317"], ) try: @@ -121,7 +121,7 @@ def test_trace_engine_enable(self): model_path=model_path, random_seed=42, enable_trace=True, - oltp_traces_endpoint="localhost:4317", + otlp_traces_endpoint="localhost:4317", ) try: @@ -148,7 +148,7 @@ def test_trace_engine_encode(self): model_path=model_path, random_seed=42, enable_trace=True, - oltp_traces_endpoint="localhost:4317", + otlp_traces_endpoint="localhost:4317", is_embedding=True, )